2019-07-07 13:21:20 +02:00
|
|
|
package http
|
|
|
|
|
|
|
|
import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwt"
)
|
|
|
|
|
|
|
|
// contextKey is an unexported string type used for request-context keys, so
// values stored by this package cannot clash with plain-string keys (or keys
// of other packages) that happen to have the same spelling.
type contextKey string

// ClaimsContextKey is the request-context key under which JWTProtect stores
// the contents of the token's "priv" claim, as a map[string]interface{}.
const (
	ClaimsContextKey contextKey = "claims"
)
|
|
|
|
|
|
|
|
// JWTProtect uses the public key located at the given URL to check if the
|
|
|
|
// cookie value is signed properly. In case yes, the JWT claims will be added
|
|
|
|
// to the request context
|
2019-08-16 09:28:14 +02:00
|
|
|
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error, cache Cache) func(http.Handler) http.Handler {
|
2019-07-07 13:21:20 +02:00
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
2019-07-10 17:17:50 +02:00
|
|
|
var jwtValue string
|
|
|
|
if authCookie, err := r.Cookie(cookieName); err == nil {
|
|
|
|
jwtValue = authCookie.Value
|
|
|
|
} else {
|
2019-07-13 22:04:12 +02:00
|
|
|
if header := r.Header.Get(headerName); header != "" {
|
2019-07-10 17:17:50 +02:00
|
|
|
jwtValue = header
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if jwtValue == "" {
|
|
|
|
RespondWithJSONError(w, errors.New("jwt: could not infer JWT value from cookie or header"), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
2019-08-16 09:28:14 +02:00
|
|
|
var keys [][]byte
|
|
|
|
if cache != nil {
|
|
|
|
lookup, lookupErr := cache.Get()
|
|
|
|
if lookupErr == nil {
|
|
|
|
keys = lookup
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if keys == nil {
|
|
|
|
var keysErr error
|
|
|
|
keys, keysErr = fetchKeys(keyURL)
|
|
|
|
if keysErr != nil {
|
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if cache != nil {
|
|
|
|
cache.Set(keys)
|
|
|
|
}
|
2019-07-07 13:21:20 +02:00
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
2019-07-19 14:49:35 +02:00
|
|
|
var token *jwt.Token
|
|
|
|
var tokenErr error
|
|
|
|
// the response can contain multiple keys to try as some of them
|
|
|
|
// might have been retired with signed tokens still in use until
|
|
|
|
// their expiry
|
|
|
|
for _, key := range keys {
|
|
|
|
token, tokenErr = tryParse(key, jwtValue)
|
|
|
|
if tokenErr == nil {
|
|
|
|
break
|
|
|
|
}
|
2019-07-07 13:21:20 +02:00
|
|
|
}
|
|
|
|
|
2019-07-19 14:49:35 +02:00
|
|
|
if tokenErr != nil {
|
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token signature: %v", tokenErr), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
|
2019-07-19 14:49:35 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token claims: %v", err), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2019-07-13 22:04:12 +02:00
|
|
|
privKey, _ := token.Get("priv")
|
|
|
|
claims, _ := privKey.(map[string]interface{})
|
|
|
|
if err := authorizer(r, claims); err != nil {
|
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
|
|
|
|
return
|
2019-07-10 17:17:50 +02:00
|
|
|
}
|
2019-07-13 22:04:12 +02:00
|
|
|
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
|
2019-07-07 13:21:20 +02:00
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
2019-07-19 14:49:35 +02:00
|
|
|
func tryParse(key []byte, tokenValue string) (*jwt.Token, error) {
|
|
|
|
keyBytes, _ := pem.Decode([]byte(key))
|
|
|
|
if keyBytes == nil {
|
|
|
|
return nil, errors.New("jwt: no PEM block found in given key")
|
|
|
|
}
|
|
|
|
|
2019-07-20 16:27:56 +02:00
|
|
|
pubKey, parseErr := x509.ParsePKCS1PublicKey(keyBytes.Bytes)
|
2019-07-19 14:49:35 +02:00
|
|
|
if parseErr != nil {
|
|
|
|
return nil, fmt.Errorf("jwt: error parsing key: %v", parseErr)
|
|
|
|
}
|
|
|
|
|
|
|
|
token, jwtErr := jwt.ParseVerify(strings.NewReader(tokenValue), jwa.RS256, pubKey)
|
|
|
|
if jwtErr != nil {
|
|
|
|
return nil, fmt.Errorf("jwt: error parsing token: %v", jwtErr)
|
|
|
|
}
|
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
|
2019-07-08 15:11:06 +02:00
|
|
|
// keyResponse is the JSON document expected at the key URL: a list of
// PEM-encoded public keys under "keys".
type keyResponse struct {
	Keys []string `json:"keys"`
}

// fetchKeys downloads the key list served at keyURL and returns each entry as
// raw bytes, in the order served.
//
// It returns an error when the request fails or exceeds its timeout, when the
// server answers with a status other than 200 OK, or when the body does not
// decode as a keyResponse.
func fetchKeys(keyURL string) ([][]byte, error) {
	// Bound the request: http.Get uses a client without a timeout, so an
	// unresponsive key server would stall every request that needs keys.
	// A nil Transport means http.DefaultTransport, so connections are still
	// pooled across calls.
	const fetchTimeout = 10 * time.Second
	client := &http.Client{Timeout: fetchTimeout}

	fetchRes, fetchErr := client.Get(keyURL)
	if fetchErr != nil {
		return nil, fetchErr
	}
	defer fetchRes.Body.Close()

	// Reject error responses up front; an error body could otherwise decode
	// into an empty key list and surface later as a confusing signature error.
	if fetchRes.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %d from key endpoint", fetchRes.StatusCode)
	}

	payload := keyResponse{}
	if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil {
		return nil, err
	}

	asBytes := make([][]byte, 0, len(payload.Keys))
	for _, key := range payload.Keys {
		asBytes = append(asBytes, []byte(key))
	}
	return asBytes, nil
}
|
2019-08-16 09:28:14 +02:00
|
|
|
|
|
|
|
// Cache can be implemented by consumers in order to define how requests
// for public keys are being cached. For most use cases, the default cache
// supplied by this package will suffice.
//
// JWTProtect calls Get and Set from inside HTTP handlers, which net/http runs
// concurrently, so implementations must be safe for concurrent use.
type Cache interface {
	// Get returns the cached keys, or a non-nil error (such as ErrNoCache)
	// when nothing usable is cached.
	Get() ([][]byte, error)
	// Set stores the given keys, replacing any previously cached value.
	Set([][]byte)
}

// Compile-time check that defaultCache satisfies Cache.
var _ Cache = (*defaultCache)(nil)

// defaultCache holds a single key set until its deadline passes.
// mu guards value and deadline; expires is fixed at construction.
type defaultCache struct {
	mu       sync.RWMutex
	value    *[][]byte
	expires  time.Duration
	deadline time.Time
}

// DefaultCacheExpiry should be used by cache instantiations without
// any particular requirements.
const DefaultCacheExpiry = time.Minute * 15

// ErrNoCache is returned on a cache lookup that did not yield a result.
var ErrNoCache = errors.New("nothing found in cache")

// Get returns the stored keys if Set has been called and the deadline has not
// yet passed; otherwise it returns ErrNoCache.
func (c *defaultCache) Get() ([][]byte, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.value != nil && time.Now().Before(c.deadline) {
		return *c.value, nil
	}
	return nil, ErrNoCache
}

// Set stores value and restarts the expiry window from now.
func (c *defaultCache) Set(value [][]byte) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.deadline = time.Now().Add(c.expires)
	c.value = &value
}

// NewDefaultKeyCache creates a simple cache that holds a single value for the
// given expiration time. The returned cache is safe for concurrent use.
func NewDefaultKeyCache(expires time.Duration) Cache {
	return &defaultCache{
		expires: expires,
	}
}
|