From 7971b7289ac927642b13b07c10be475d1082babd Mon Sep 17 00:00:00 2001 From: Frederik Ring Date: Sat, 6 Jul 2019 15:10:39 +0200 Subject: [PATCH 1/2] use gorilla mux for routing in server and kms --- shared/http/middleware.go | 72 ++++++++++++++++++++++++---------- shared/http/middleware_test.go | 10 ++--- 2 files changed, 56 insertions(+), 26 deletions(-) diff --git a/shared/http/middleware.go b/shared/http/middleware.go index 8de39ad..b993c85 100644 --- a/shared/http/middleware.go +++ b/shared/http/middleware.go @@ -1,29 +1,59 @@ package http -import "net/http" +import ( + "context" + "errors" + "net/http" +) -func CorsMiddleware(next http.Handler, origin string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET") - w.Header().Set("Access-Control-Allow-Origin", origin) - next.ServeHTTP(w, r) - }) +func CorsMiddleware(origin string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "POST,GET") + w.Header().Set("Access-Control-Allow-Origin", origin) + next.ServeHTTP(w, r) + }) + } } -func ContentTypeMiddleware(next http.Handler, contentType string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", contentType) - next.ServeHTTP(w, r) - }) +func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", contentType) + next.ServeHTTP(w, r) + }) + } } -func OptoutMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, err := r.Cookie("optout"); err == nil { - w.WriteHeader(http.StatusNoContent) - return - } - next.ServeHTTP(w, r) - }) +func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := r.Cookie(cookieName); err == nil { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func UserCookieMiddleware(cookieKey string, contextKey interface{}) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(cookieKey) + if err != nil { + RespondWithJSONError(w, err, http.StatusBadRequest) + return + } + if c.Value == "" { + RespondWithJSONError(w, errors.New("received blank user identifier"), http.StatusBadRequest) + return + } + r = r.WithContext( + context.WithValue(r.Context(), contextKey, c.Value), + ) + next.ServeHTTP(w, r) + }) + } } diff --git a/shared/http/middleware_test.go b/shared/http/middleware_test.go index bb21a4e..6b4fb56 100644 --- a/shared/http/middleware_test.go +++ b/shared/http/middleware_test.go @@ -8,9 +8,9 @@ import ( func TestCorsMiddleware(t *testing.T) { t.Run("default", func(t *testing.T) { - wrapped := CorsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := CorsMiddleware("https://www.example.net")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) - }), "https://www.example.net") + })) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) wrapped.ServeHTTP(w, r) @@ -22,9 +22,9 @@ func TestCorsMiddleware(t *testing.T) { func TestContentTypeMiddleware(t *testing.T) { t.Run("default", func(t *testing.T) { - wrapped := ContentTypeMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := ContentTypeMiddleware("application/json")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) - }), "application/json") + })) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) wrapped.ServeHTTP(w, r) @@ -35,7 +35,7 @@ func TestContentTypeMiddleware(t *testing.T) { } func TestOptoutMiddleware(t *testing.T) { - wrapped := OptoutMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := OptoutMiddleware("optout")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hey there")) })) t.Run("with header", func(t *testing.T) { From b360092cb8370b3d3eaa23a2e5e8f6bc0091d286 Mon Sep 17 00:00:00 2001 From: Frederik Ring Date: Sat, 6 Jul 2019 16:05:27 +0200 Subject: [PATCH 2/2] add tests and docs for shared package --- shared/http/errors.go | 2 ++ shared/http/middleware.go | 15 ++++++---- shared/http/middleware_test.go | 52 ++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/shared/http/errors.go b/shared/http/errors.go index 9534104..f095011 100644 --- a/shared/http/errors.go +++ b/shared/http/errors.go @@ -16,6 +16,8 @@ func jsonError(err error, status int) []byte { return b } +// RespondWithJSONError writes the given error to the given response writer +// while wrapping it into a JSON payload. func RespondWithJSONError(w http.ResponseWriter, err error, status int) { w.WriteHeader(status) w.Write(jsonError(err, status)) diff --git a/shared/http/middleware.go b/shared/http/middleware.go index b993c85..30c9e79 100644 --- a/shared/http/middleware.go +++ b/shared/http/middleware.go @@ -6,6 +6,8 @@ import ( "net/http" ) +// CorsMiddleware ensures the wrapped handler will respond with proper CORS +// headers using the given origin. func CorsMiddleware(origin string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -17,6 +19,8 @@ func CorsMiddleware(origin string) func(http.Handler) http.Handler { } } +// ContentTypeMiddleware ensuresthe wrapped handler will respond with a +// content type header of the given value. func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -26,6 +30,8 @@ func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { } } +// OptoutMiddleware drops all requests to the given handler that are sent with +// a cookie of the given name, func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,16 +44,15 @@ func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { } } +// UserCookieMiddleware ensures a cookie of the given name is present and +// attaches its value to the request's context using the given key, before +// passing it on to the wrapped handler. func UserCookieMiddleware(cookieKey string, contextKey interface{}) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := r.Cookie(cookieKey) if err != nil { - RespondWithJSONError(w, err, http.StatusBadRequest) - return - } - if c.Value == "" { - RespondWithJSONError(w, errors.New("received blank user identifier"), http.StatusBadRequest) + RespondWithJSONError(w, errors.New("received no or blank user identifier"), http.StatusBadRequest) return } r = r.WithContext( diff --git a/shared/http/middleware_test.go b/shared/http/middleware_test.go index 6b4fb56..13654b1 100644 --- a/shared/http/middleware_test.go +++ b/shared/http/middleware_test.go @@ -1,8 +1,10 @@ package http import ( + "fmt" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -68,3 +70,53 @@ func TestOptoutMiddleware(t *testing.T) { } }) } + +func TestUserCookieMiddleware(t *testing.T) { + wrapped := UserCookieMiddleware("user", 1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + value := r.Context().Value(1) + fmt.Fprintf(w, "value is %v", value) + })) + t.Run("no cookie", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusBadRequest { + t.Errorf("Unexpected status code %v", w.Code) + } + if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) + + t.Run("no value", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + wrapped.ServeHTTP(w, r) + r.AddCookie(&http.Cookie{ + Name: "user", + Value: "", + }) + if w.Code != http.StatusBadRequest { + t.Errorf("Unexpected status code %v", w.Code) + } + if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{ + Name: "user", + Value: "token", + }) + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("Unexpected status code %v", w.Code) + } + if w.Body.String() != "value is token" { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) +}