diff --git a/shared/http/jwt.go b/shared/http/jwt.go index ee74170..686d682 100644 --- a/shared/http/jwt.go +++ b/shared/http/jwt.go @@ -7,6 +7,7 @@ import ( "encoding/json" "encoding/pem" "errors" + "fmt" "net/http" "strings" @@ -27,42 +28,42 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authCookie, err := r.Cookie(cookieName) if err != nil { - RespondWithJSONError(w, err, http.StatusForbidden) + RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden) return } keyRes, keyErr := fetchKey(keyURL) if keyErr != nil { - RespondWithJSONError(w, keyErr, http.StatusInternalServerError) + RespondWithJSONError(w, fmt.Errorf("jwt: error fetching key: %v", keyErr), http.StatusInternalServerError) return } keyBytes, _ := pem.Decode([]byte(keyRes)) if keyBytes == nil { - RespondWithJSONError(w, errors.New("no pem block found"), http.StatusInternalServerError) + RespondWithJSONError(w, errors.New("jwt: no PEM block found in given key"), http.StatusInternalServerError) return } parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes) if parseErr != nil { - RespondWithJSONError(w, parseErr, http.StatusBadGateway) + RespondWithJSONError(w, fmt.Errorf("jwt: error parsing key: %v", parseErr), http.StatusBadGateway) return } pubKey, pubKeyOk := parseResult.(*rsa.PublicKey) if !pubKeyOk { - RespondWithJSONError(w, errors.New("unable to use given key"), http.StatusInternalServerError) + RespondWithJSONError(w, errors.New("jwt: given key is not of type RSA public key"), http.StatusInternalServerError) return } token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey) if jwtErr != nil { - RespondWithJSONError(w, jwtErr, http.StatusForbidden) + RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden) return } if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil { - RespondWithJSONError(w, err, http.StatusForbidden) + RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token: %v", err), http.StatusForbidden) return } diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go index 80f0918..4f3c2f7 100644 --- a/shared/http/jwt_test.go +++ b/shared/http/jwt_test.go @@ -1,38 +1,180 @@ package http import ( + "crypto/x509" + "encoding/pem" + "fmt" "net/http" "net/http/httptest" + "strings" "testing" + "time" + + "github.com/lestrrat-go/jwx/jwa" + + "github.com/lestrrat-go/jwx/jwt" ) +const publicKey = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi +Hk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5 +7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6 +QQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo +N18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F +42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ +mQIDAQAB +-----END PUBLIC KEY----- +` + +const privateKey = ` +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDbJR8cfpJFgq8F +N6at5+IeTgviqRxziySXhAJoc6R+ADhVwYobadfKxi5krsDM05wdBWL57gHA34F8 +hm3ARTns6By1vl6DySzHEuNLt5JF9Ai88Z8BsnXZH06g5Zy578J1TL9gx5zhH+jj +61njlLpBBYgyEupiWPD6zVG6d1rcrxDyU50njIRlyRdpmvXBo0mKQLugalPsKUCB +1u9Yhug3XwbU6W5Gf7dY11w3pX68iPxz10az3ZuBM9fhyfdGYH/zQEN/N4WI8xHD +nFqrroXjbNfOWWROq6RA+TbCY7mb2kXLRyJG975rB+BtO1lFk4SlNJYttuk593hF +DlioNYmZAgMBAAECggEADvr6pXgBh77nN/QV8M1pJ6kuJtBooX1hgvoDMCC3neVl +9HbGehlCJxplEXzgsR/GDDXSDkO22vhsYZbO6dXRn+A+Fi5tR5T4+qLP5t0loqKL +9l6OAA+y/qSlO1p23D8Hi/0zF+qNTtZflTUBcA06rjcymDmyzAZIctyWOajvDSbK +Df0ZvKYPnwG5gjF01hPS2VJicv/O0HXLN7elq/jio1dwvLa2JjPyXhWBkHqnJBcq +ncWP9IEJQmhQ8ijNEg78uLtiNZQ4+GcXNBlwhM7JER6X/AxSxEZ/7fjZog685yUH +3iF820SnStOJQQci/RMMPOsK6cM7BiJxGp2W12EOAQKBgQD85UdCDro6zpblpAw7 +Gw82SkWGksJXuGlTX+nj3/3iIiEb4ATCvZufYXALGNtiG0tPHDMBQCKLYrbLE1pt +9uIU/IbDFPeQk8rR/b7IHu0gv3463p6r7WVhzY2/JCororKYQk4zbuk3cNYtlV76 +ojnNY1EFDLK/1nGT6QDxDA7Z5wKBgQDd1chB2qlbljRzYFwKrWXZ0COtbnEGPnUz +rLvSlAvYlZSKuB/vXkHGepxdlAjDGgX6xkKSl1TKb8UWQ9JSv0MPGBcMPukuwCAL +BOobyvd1mln6f/C7FrATkRbrG+r8RAQTwR+eknwYYOPAS/PpXm8gZvVntiahihFd +NqQtud8QfwKBgQDGV+xzWqmkxbKDmQ4erTJZGhc9XI0fz3qL8YW3O04btTjSa/hP +4/XSItGFYpFteIqwGSXHrU1qlJlY3GzoIeFfJE9tYVxpAADqgWDIA7lnHcka0s8P +eLky48xwRSTt5ES+NgKvRCWVXeIdDjHX0LQU6ff5ReRLoRyjLPOYGiTrsQKBgAmq +z1dPWCINoauFf31XoSCk2Wktbu9+uUzPMkAzA3Ek05xX+cxMp0EnBrltQhR+hdQv +36bTwXYw+L3HptrESv/VZOu7sh2/caYJSMp9RdtyJomsGamNi47Ou9jzFoJ31FWo +DOC0MYQ+dK5koPSCkQUwd3FVlsljYu5U+0Ki3v2xAoGASIMhNHOvz+Ay2otovVFN +gfRGTnepw8znHbkr10IG97BWd4VbFnHRdpYbtk8fH0UOyUVMrcY0B2/d73Rzqze3 +iZ//FXIDTtmKnVS/ZhC2w0AH8Piziy3NW3G6jRZN6+9NpOf/BIc4pfzgUJ3RqHz/ +IeONX+52k6gz1SCjPgSUlTs= +-----END PRIVATE KEY----- +` + func TestJWTProtect(t *testing.T) { tests := []struct { name string cookie *http.Cookie - keyURL string + server *httptest.Server expectedStatusCode int }{ { "no cookie", nil, - "http://localhost:9999", + nil, http.StatusForbidden, }, { - "bad url", + "bad server", &http.Cookie{ Name: "auth", Value: "irrelevantgibberish", }, - "http://localhost:9999", + nil, http.StatusInternalServerError, }, + { + "non-json response", + &http.Cookie{ + Name: "auth", + Value: "irrelevantgibberish", + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("here's some bytes 4 y'all")) + })), + http.StatusInternalServerError, + }, + { + "bad key value", + &http.Cookie{ + Name: "auth", + Value: "irrelevantgibberish", + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"key":"not really a key"}`)) + })), + http.StatusInternalServerError, + }, + { + "invalid key", + &http.Cookie{ + Name: "auth", + Value: "irrelevantgibberish", + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`)) + })), + http.StatusBadGateway, + }, + { + "valid key, bad auth token", + &http.Cookie{ + Name: "auth", + Value: "irrelevantgibberish", + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + http.StatusForbidden, + }, + { + "valid key, valid token", + &http.Cookie{ + Name: "auth", + Value: (func() string { + token := jwt.New() + token.Set("exp", time.Now().Add(time.Hour)) + keyBytes, _ := pem.Decode([]byte(privateKey)) + privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes) + b, _ := token.Sign(jwa.RS256, privKey) + return string(b) + })(), + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + http.StatusOK, + }, + { + "valid key, expired token", + &http.Cookie{ + Name: "auth", + Value: (func() string { + token := jwt.New() + token.Set("exp", time.Now().Add(-time.Hour)) + keyBytes, _ := pem.Decode([]byte(privateKey)) + privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes) + b, _ := token.Sign(jwa.RS256, privKey) + return string(b) + })(), + }, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + http.StatusForbidden, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - wrappedHandler := JWTProtect("http://localhost:9999", "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var url string + if test.server != nil { + url = test.server.URL + } + wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) w := httptest.NewRecorder() diff --git a/shared/http/middleware.go b/shared/http/middleware.go index 30c9e79..d1dc3ee 100644 --- a/shared/http/middleware.go +++ b/shared/http/middleware.go @@ -52,7 +52,7 @@ func UserCookieMiddleware(cookieKey string, contextKey interface{}) func(http.Ha return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := r.Cookie(cookieKey) if err != nil { - RespondWithJSONError(w, errors.New("received no or blank user identifier"), http.StatusBadRequest) + RespondWithJSONError(w, errors.New("user cookie: received no or blank identifier"), http.StatusBadRequest) return } r = r.WithContext( diff --git a/shared/http/middleware_test.go b/shared/http/middleware_test.go index 13654b1..633d44b 100644 --- a/shared/http/middleware_test.go +++ b/shared/http/middleware_test.go @@ -83,7 +83,7 @@ func TestUserCookieMiddleware(t *testing.T) { if w.Code != http.StatusBadRequest { t.Errorf("Unexpected status code %v", w.Code) } - if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + if !strings.Contains(w.Body.String(), "received no or blank identifier") { t.Errorf("Unexpected body %s", w.Body.String()) } }) @@ -99,7 +99,7 @@ func TestUserCookieMiddleware(t *testing.T) { if w.Code != http.StatusBadRequest { t.Errorf("Unexpected status code %v", w.Code) } - if !strings.Contains(w.Body.String(), "received no or blank user identifier") { + if !strings.Contains(w.Body.String(), "received no or blank identifier") { t.Errorf("Unexpected body %s", w.Body.String()) } })