diff --git a/shared/http/jwt.go b/shared/http/jwt.go index 986fb6c..bb2dfdd 100644 --- a/shared/http/jwt.go +++ b/shared/http/jwt.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" @@ -22,7 +23,7 @@ const ClaimsContextKey contextKey = "claims" // JWTProtect uses the public key located at the given URL to check if the // cookie value is signed properly. In case yes, the JWT claims will be added // to the request context -func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler { +func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error, cache Cache) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var jwtValue string @@ -38,10 +39,23 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req return } - keys, keysErr := fetchKeys(keyURL) - if keysErr != nil { - RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError) - return + var keys [][]byte + if cache != nil { + lookup, lookupErr := cache.Get() + if lookupErr == nil { + keys = lookup + } + } + if keys == nil { + var keysErr error + keys, keysErr = fetchKeys(keyURL) + if keysErr != nil { + RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError) + return + } + if cache != nil { + cache.Set(keys) + } } var token *jwt.Token @@ -117,3 +131,44 @@ func fetchKeys(keyURL string) ([][]byte, error) { } return asBytes, nil } + +// Cache can be implemented by consumers in order to define how requests +// for public keys are being cached. For most use cases, the default cache +// supplied by this package will suffice. +type Cache interface { + Get() ([][]byte, error) + Set([][]byte) +} + +type defaultCache struct { + value *[][]byte + expires time.Duration + deadline time.Time +} + +// DefaultCacheExpiry should be used by cache instantiations without +// any particular requirements. +const DefaultCacheExpiry = time.Minute * 15 + +// ErrNoCache is returned on a cache lookup that did not yield a result +var ErrNoCache = errors.New("nothing found in cache") + +func (c *defaultCache) Get() ([][]byte, error) { + if c.value != nil && time.Now().Before(c.deadline) { + return *c.value, nil + } + return nil, ErrNoCache +} + +func (c *defaultCache) Set(value [][]byte) { + c.deadline = time.Now().Add(c.expires) + c.value = &value +} + +// NewDefaultKeyCache creates a simple cache that will hold a single +// value for the given expiration time +func NewDefaultKeyCache(expires time.Duration) Cache { + return &defaultCache{ + expires: expires, + } +} diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go index fe62b4e..4826efd 100644 --- a/shared/http/jwt_test.go +++ b/shared/http/jwt_test.go @@ -257,7 +257,7 @@ func TestJWTProtect(t *testing.T) { if test.server != nil { url = test.server.URL } - wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) w := httptest.NewRecorder()