diff --git a/shared/http/errors.go b/shared/http/errors.go index 9534104..f095011 100644 --- a/shared/http/errors.go +++ b/shared/http/errors.go @@ -16,6 +16,8 @@ func jsonError(err error, status int) []byte { return b } +// RespondWithJSONError writes the given error to the given response writer +// while wrapping it into a JSON payload. func RespondWithJSONError(w http.ResponseWriter, err error, status int) { w.WriteHeader(status) w.Write(jsonError(err, status)) diff --git a/shared/http/middleware.go b/shared/http/middleware.go index b993c85..30c9e79 100644 --- a/shared/http/middleware.go +++ b/shared/http/middleware.go @@ -6,6 +6,8 @@ import ( "net/http" ) +// CorsMiddleware ensures the wrapped handler will respond with proper CORS +// headers using the given origin. func CorsMiddleware(origin string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -17,6 +19,8 @@ func CorsMiddleware(origin string) func(http.Handler) http.Handler { } } +// ContentTypeMiddleware ensuresthe wrapped handler will respond with a +// content type header of the given value. func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -26,6 +30,8 @@ func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { } } +// OptoutMiddleware drops all requests to the given handler that are sent with +// a cookie of the given name, func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,16 +44,15 @@ func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { } } +// UserCookieMiddleware ensures a cookie of the given name is present and +// attaches its value to the request's context using the given key, before +// passing it on to the wrapped handler. func UserCookieMiddleware(cookieKey string, contextKey interface{}) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := r.Cookie(cookieKey) if err != nil { - RespondWithJSONError(w, err, http.StatusBadRequest) - return - } - if c.Value == "" { - RespondWithJSONError(w, errors.New("received blank user identifier"), http.StatusBadRequest) + RespondWithJSONError(w, errors.New("received no or blank user identifier"), http.StatusBadRequest) return } r = r.WithContext( diff --git a/shared/http/middleware_test.go b/shared/http/middleware_test.go index 6b4fb56..13654b1 100644 --- a/shared/http/middleware_test.go +++ b/shared/http/middleware_test.go @@ -1,8 +1,10 @@ package http import ( + "fmt" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -68,3 +70,53 @@ func TestOptoutMiddleware(t *testing.T) { } }) } + +func TestUserCookieMiddleware(t *testing.T) { + wrapped := UserCookieMiddleware("user", 1)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + value := r.Context().Value(1) + fmt.Fprintf(w, "value is %v", value) + })) + t.Run("no cookie", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusBadRequest { + t.Errorf("Unexpected status code %v", w.Code) + } + if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) + + t.Run("no value", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + wrapped.ServeHTTP(w, r) + r.AddCookie(&http.Cookie{ + Name: "user", + Value: "", + }) + if w.Code != http.StatusBadRequest { + t.Errorf("Unexpected status code %v", w.Code) + } + if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{ + Name: "user", + Value: "token", + }) + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("Unexpected status code %v", w.Code) + } + if w.Body.String() != "value is token" { + t.Errorf("Unexpected body %s", w.Body.String()) + } + }) +}