2019-07-07 13:21:20 +02:00
|
|
|
package http
|
|
|
|
|
|
|
|
import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwt"
)
|
|
|
|
|
|
|
|
// contextKey is the unexported type used for keys this package stores in a
// request context. Because it is a distinct named type, these keys never
// compare equal to plain string keys set by other packages.
type contextKey string

// ClaimsContextKey is the context key under which JWTProtect stores the
// verified token's "priv" claim for downstream handlers.
const ClaimsContextKey = contextKey("claims")
|
|
|
|
|
|
|
|
// JWTProtect uses the public key located at the given URL to check if the
|
|
|
|
// cookie value is signed properly. In case yes, the JWT claims will be added
|
|
|
|
// to the request context
|
|
|
|
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
authCookie, err := r.Cookie(cookieName)
|
|
|
|
if err != nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
|
|
|
keyRes, keyErr := fetchKey(keyURL)
|
2019-07-07 13:21:20 +02:00
|
|
|
if keyErr != nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching key: %v", keyErr), http.StatusInternalServerError)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
|
|
|
keyBytes, _ := pem.Decode([]byte(keyRes))
|
2019-07-07 13:21:20 +02:00
|
|
|
if keyBytes == nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, errors.New("jwt: no PEM block found in given key"), http.StatusInternalServerError)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
2019-07-07 13:21:20 +02:00
|
|
|
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
|
|
|
|
if parseErr != nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing key: %v", parseErr), http.StatusBadGateway)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
2019-07-07 13:21:20 +02:00
|
|
|
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
|
|
|
|
if !pubKeyOk {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, errors.New("jwt: given key is not of type RSA public key"), http.StatusInternalServerError)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey)
|
|
|
|
if jwtErr != nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
|
2019-07-09 14:50:31 +02:00
|
|
|
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token: %v", err), http.StatusForbidden)
|
2019-07-07 13:21:20 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
privateClaims, _ := token.Get("priv")
|
|
|
|
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
|
|
|
|
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2019-07-08 15:11:06 +02:00
|
|
|
|
|
|
|
// keyResponse is the JSON document served at the key URL. Key carries the
// PEM-encoded public key text.
type keyResponse struct {
	Key string `json:"key"`
}

// keyFetchTimeout bounds the whole key request (connect, headers and body),
// so a stalled key endpoint cannot hang the callers indefinitely. The value
// is a conservative default; tune as needed.
const keyFetchTimeout = 10 * time.Second

// keyClient is shared across calls so connections to the key endpoint can be
// reused. Unlike http.DefaultClient, it has a timeout.
var keyClient = &http.Client{Timeout: keyFetchTimeout}

// fetchKey performs a GET on keyURL and returns the "key" field of the JSON
// response body as bytes.
//
// It returns an error if the request fails or times out, if the response
// status is not 200 OK, or if the body cannot be decoded as a keyResponse.
func fetchKey(keyURL string) ([]byte, error) {
	res, err := keyClient.Get(keyURL)
	if err != nil {
		// The client error already names the method and URL.
		return nil, err
	}
	defer res.Body.Close()

	// Without this check an error page would be decoded as if it were a key
	// document, surfacing later as a confusing PEM error.
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %d from key endpoint %q", res.StatusCode, keyURL)
	}

	var payload keyResponse
	if err := json.NewDecoder(res.Body).Decode(&payload); err != nil {
		return nil, fmt.Errorf("decoding key response: %w", err)
	}
	return []byte(payload.Key), nil
}
|