diff --git a/accounts/accounts/__init__.py b/accounts/accounts/__init__.py index dee5a29..ba3e558 100644 --- a/accounts/accounts/__init__.py +++ b/accounts/accounts/__init__.py @@ -44,6 +44,7 @@ def post_login(): expires=expiry, path="/", domain=environ.get("COOKIE_DOMAIN"), + samesite="strict" ) return resp diff --git a/accounts/serverless.yml b/accounts/serverless.yml index 87839e6..aaedaf7 100644 --- a/accounts/serverless.yml +++ b/accounts/serverless.yml @@ -33,6 +33,10 @@ custom: production: accounts.offen.dev staging: accounts-staging.offen.dev alpha: accounts-alpha.offen.dev + cookieDomain: + production: .offen.dev + staging: .offen.dev + alpha: .offen.dev customDomain: basePath: '' certificateName: '*.offen.dev' @@ -64,6 +68,7 @@ functions: environment: USER: offen CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}} + COOKIE_DOMAIN: ${self:custom.origin.${self:custom.stage}} JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}' JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}' HASHED_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedPassword~true} diff --git a/shared/http/jwt.go b/shared/http/jwt.go index 45efbf8..ee74170 100644 --- a/shared/http/jwt.go +++ b/shared/http/jwt.go @@ -14,10 +14,6 @@ import ( "github.com/lestrrat-go/jwx/jwt" ) -type keyResponse struct { - Key string `json:"key"` -} - type contextKey string // ClaimsContextKey will be used to attach a JWT claim to a request context @@ -34,27 +30,25 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { RespondWithJSONError(w, err, http.StatusForbidden) return } - keyRes, keyErr := http.Get(keyURL) + + keyRes, keyErr := fetchKey(keyURL) if keyErr != nil { RespondWithJSONError(w, keyErr, http.StatusInternalServerError) return } - defer keyRes.Body.Close() - payload := keyResponse{} - if err := json.NewDecoder(keyRes.Body).Decode(&payload); err != nil { - RespondWithJSONError(w, keyErr, http.StatusBadGateway) - return - } - keyBytes, _ := pem.Decode([]byte(payload.Key)) + + keyBytes, _ := pem.Decode([]byte(keyRes)) if keyBytes == nil { RespondWithJSONError(w, errors.New("no pem block found"), http.StatusInternalServerError) return } + parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes) if parseErr != nil { RespondWithJSONError(w, parseErr, http.StatusBadGateway) return } + pubKey, pubKeyOk := parseResult.(*rsa.PublicKey) if !pubKeyOk { RespondWithJSONError(w, errors.New("unable to use given key"), http.StatusInternalServerError) @@ -79,3 +73,20 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { }) } } + +type keyResponse struct { + Key string `json:"key"` +} + +func fetchKey(keyURL string) ([]byte, error) { + fetchRes, fetchErr := http.Get(keyURL) + if fetchErr != nil { + return nil, fetchErr + } + defer fetchRes.Body.Close() + payload := keyResponse{} + if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil { + return nil, err + } + return []byte(payload.Key), nil +} diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go new file mode 100644 index 0000000..80f0918 --- /dev/null +++ b/shared/http/jwt_test.go @@ -0,0 +1,49 @@ +package http + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestJWTProtect(t *testing.T) { + tests := []struct { + name string + cookie *http.Cookie + keyURL string + expectedStatusCode int + }{ + { + "no cookie", + nil, + "http://localhost:9999", + http.StatusForbidden, + }, + { + "bad url", + &http.Cookie{ + Name: "auth", + Value: "irrelevantgibberish", + }, + "http://localhost:9999", + http.StatusInternalServerError, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + wrappedHandler := JWTProtect("http://localhost:9999", "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + if test.cookie != nil { + r.AddCookie(test.cookie) + } + wrappedHandler.ServeHTTP(w, r) + if w.Code != test.expectedStatusCode { + t.Errorf("Unexpected status code %v", w.Code) + } + }) + } +}