diff --git a/shared/http/middleware.go b/shared/http/middleware.go index 8de39ad..b993c85 100644 --- a/shared/http/middleware.go +++ b/shared/http/middleware.go @@ -1,29 +1,59 @@ package http -import "net/http" +import ( + "context" + "errors" + "net/http" +) -func CorsMiddleware(next http.Handler, origin string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Methods", "POST,GET") - w.Header().Set("Access-Control-Allow-Origin", origin) - next.ServeHTTP(w, r) - }) +func CorsMiddleware(origin string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "POST,GET") + w.Header().Set("Access-Control-Allow-Origin", origin) + next.ServeHTTP(w, r) + }) + } } -func ContentTypeMiddleware(next http.Handler, contentType string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", contentType) - next.ServeHTTP(w, r) - }) +func ContentTypeMiddleware(contentType string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", contentType) + next.ServeHTTP(w, r) + }) + } } -func OptoutMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, err := r.Cookie("optout"); err == nil { - w.WriteHeader(http.StatusNoContent) - return - } - next.ServeHTTP(w, r) - }) +func OptoutMiddleware(cookieName string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := r.Cookie(cookieName); err == nil { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func UserCookieMiddleware(cookieKey string, contextKey interface{}) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(cookieKey) + if err != nil { + RespondWithJSONError(w, err, http.StatusBadRequest) + return + } + if c.Value == "" { + RespondWithJSONError(w, errors.New("received blank user identifier"), http.StatusBadRequest) + return + } + r = r.WithContext( + context.WithValue(r.Context(), contextKey, c.Value), + ) + next.ServeHTTP(w, r) + }) + } } diff --git a/shared/http/middleware_test.go b/shared/http/middleware_test.go index bb21a4e..6b4fb56 100644 --- a/shared/http/middleware_test.go +++ b/shared/http/middleware_test.go @@ -8,9 +8,9 @@ import ( func TestCorsMiddleware(t *testing.T) { t.Run("default", func(t *testing.T) { - wrapped := CorsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := CorsMiddleware("https://www.example.net")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) - }), "https://www.example.net") + })) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) wrapped.ServeHTTP(w, r) @@ -22,9 +22,9 @@ func TestCorsMiddleware(t *testing.T) { func TestContentTypeMiddleware(t *testing.T) { t.Run("default", func(t *testing.T) { - wrapped := ContentTypeMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := ContentTypeMiddleware("application/json")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) - }), "application/json") + })) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) wrapped.ServeHTTP(w, r) @@ -35,7 +35,7 @@ func TestContentTypeMiddleware(t *testing.T) { } func TestOptoutMiddleware(t *testing.T) { - wrapped := OptoutMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := OptoutMiddleware("optout")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hey there")) })) t.Run("with header", func(t *testing.T) {