mirror of
https://github.com/offen/website.git
synced 2024-11-22 17:10:29 +01:00
add optional key cache to JWT middleware
This commit is contained in:
parent
8c04de556c
commit
eeba0b8b58
@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/jwa"
|
"github.com/lestrrat-go/jwx/jwa"
|
||||||
"github.com/lestrrat-go/jwx/jwt"
|
"github.com/lestrrat-go/jwx/jwt"
|
||||||
@ -22,7 +23,7 @@ const ClaimsContextKey contextKey = "claims"
|
|||||||
// JWTProtect uses the public key located at the given URL to check if the
|
// JWTProtect uses the public key located at the given URL to check if the
|
||||||
// cookie value is signed properly. In case yes, the JWT claims will be added
|
// cookie value is signed properly. In case yes, the JWT claims will be added
|
||||||
// to the request context
|
// to the request context
|
||||||
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
|
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error, cache Cache) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var jwtValue string
|
var jwtValue string
|
||||||
@ -38,10 +39,23 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, keysErr := fetchKeys(keyURL)
|
var keys [][]byte
|
||||||
if keysErr != nil {
|
if cache != nil {
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
|
lookup, lookupErr := cache.Get()
|
||||||
return
|
if lookupErr == nil {
|
||||||
|
keys = lookup
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keys == nil {
|
||||||
|
var keysErr error
|
||||||
|
keys, keysErr = fetchKeys(keyURL)
|
||||||
|
if keysErr != nil {
|
||||||
|
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cache != nil {
|
||||||
|
cache.Set(keys)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var token *jwt.Token
|
var token *jwt.Token
|
||||||
@ -117,3 +131,44 @@ func fetchKeys(keyURL string) ([][]byte, error) {
|
|||||||
}
|
}
|
||||||
return asBytes, nil
|
return asBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache can be implemented by consumers in order to define how requests
|
||||||
|
// for public keys are being cached. For most use cases, the default cache
|
||||||
|
// supplied by this package will suffice.
|
||||||
|
type Cache interface {
|
||||||
|
Get() ([][]byte, error)
|
||||||
|
Set([][]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultCache struct {
|
||||||
|
value *[][]byte
|
||||||
|
expires time.Duration
|
||||||
|
deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCacheExpiry should be used by cache instantiations without
|
||||||
|
// any particular requirements.
|
||||||
|
const DefaultCacheExpiry = time.Minute * 15
|
||||||
|
|
||||||
|
// ErrNoCache is returned on a cache lookup that did not yield a result
|
||||||
|
var ErrNoCache = errors.New("nothing found in cache")
|
||||||
|
|
||||||
|
func (c *defaultCache) Get() ([][]byte, error) {
|
||||||
|
if c.value != nil && time.Now().Before(c.deadline) {
|
||||||
|
return *c.value, nil
|
||||||
|
}
|
||||||
|
return nil, ErrNoCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *defaultCache) Set(value [][]byte) {
|
||||||
|
c.deadline = time.Now().Add(c.expires)
|
||||||
|
c.value = &value
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultKeyCache creates a simple cache that will hold a single
|
||||||
|
// value for the given expiration time
|
||||||
|
func NewDefaultKeyCache(expires time.Duration) Cache {
|
||||||
|
return &defaultCache{
|
||||||
|
expires: expires,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -257,7 +257,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.server != nil {
|
if test.server != nil {
|
||||||
url = test.server.URL
|
url = test.server.URL
|
||||||
}
|
}
|
||||||
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer, nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}))
|
}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
Loading…
Reference in New Issue
Block a user