mirror of
https://github.com/offen/website.git
synced 2024-11-22 17:10:29 +01:00
allow checking JWTs against multiple public keys
This commit is contained in:
parent
9d98cb63b6
commit
f678a9aa56
@ -3,10 +3,13 @@ from os import environ
|
|||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from flask_admin import Admin
|
from flask_admin import Admin
|
||||||
|
from werkzeug.utils import import_string
|
||||||
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config.from_object(environ.get("CONFIG_CLASS"))
|
|
||||||
|
cfg = import_string(environ.get("CONFIG_CLASS"))()
|
||||||
|
app.config.from_object(cfg)
|
||||||
|
|
||||||
db = SQLAlchemy(app)
|
db = SQLAlchemy(app)
|
||||||
|
|
||||||
|
@ -83,10 +83,19 @@ def post_login():
|
|||||||
@json_error
|
@json_error
|
||||||
def get_login():
|
def get_login():
|
||||||
auth_cookie = request.cookies.get(COOKIE_KEY)
|
auth_cookie = request.cookies.get(COOKIE_KEY)
|
||||||
|
public_keys = app.config["JWT_PUBLIC_KEYS"]
|
||||||
|
|
||||||
|
token = None
|
||||||
|
token_err = None
|
||||||
|
for public_key in public_keys:
|
||||||
try:
|
try:
|
||||||
token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"])
|
token = jwt.decode(auth_cookie, public_key)
|
||||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
break
|
||||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
except Exception as decode_err:
|
||||||
|
token_err = decode_err
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return jsonify({"error": str(token_err), "status": 401}), 401
|
||||||
|
|
||||||
try:
|
try:
|
||||||
match = User.query.get(token["priv"]["userId"])
|
match = User.query.get(token["priv"]["userId"])
|
||||||
@ -117,5 +126,4 @@ def key():
|
|||||||
This route is not supposed to be called by client-side applications, so
|
This route is not supposed to be called by client-side applications, so
|
||||||
no CORS configuration is added
|
no CORS configuration is added
|
||||||
"""
|
"""
|
||||||
public_key = app.config["JWT_PUBLIC_KEY"].strip()
|
return jsonify({"keys": app.config["JWT_PUBLIC_KEYS"]})
|
||||||
return jsonify({"key": public_key})
|
|
||||||
|
@ -1,14 +1,68 @@
|
|||||||
|
import json
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(object):
|
class BaseConfig(object):
|
||||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||||
FLASK_ADMIN_SWATCH = "flatly"
|
FLASK_ADMIN_SWATCH = "flatly"
|
||||||
|
|
||||||
|
|
||||||
class EnvConfig(BaseConfig):
|
class EnvConfig(BaseConfig):
|
||||||
SECRET_KEY = environ.get("SESSION_SECRET")
|
SECRET_KEY = environ.get("SESSION_SECRET")
|
||||||
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
|
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
|
||||||
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
|
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
|
||||||
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
|
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "")
|
JWT_PUBLIC_KEYS = [environ.get("JWT_PUBLIC_KEY", "")]
|
||||||
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
|
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
|
||||||
SERVER_HOST = environ.get("SERVER_HOST")
|
SERVER_HOST = environ.get("SERVER_HOST")
|
||||||
|
|
||||||
|
|
||||||
|
class SecretsManagerConfig(BaseConfig):
|
||||||
|
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
|
||||||
|
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
|
||||||
|
SERVER_HOST = environ.get("SERVER_HOST")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
session = boto3.session.Session()
|
||||||
|
self.client = session.client(
|
||||||
|
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
|
||||||
|
)
|
||||||
|
|
||||||
|
self.SECRET_KEY = self.get_secret("sessionSecret")
|
||||||
|
self.SQLALCHEMY_DATABASE_URI = self.get_secret("mysqlConnectionString")
|
||||||
|
|
||||||
|
current_version = self.get_secret("jwtKeyPair")
|
||||||
|
key_pair = json.loads(current_version)
|
||||||
|
previous_version = self.get_secret("jwtKeyPair", previous=True)
|
||||||
|
previous_key_pair = (
|
||||||
|
json.loads(previous_version)
|
||||||
|
if previous_version is not None
|
||||||
|
else {"public": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.JWT_PRIVATE_KEY = key_pair["private"]
|
||||||
|
self.JWT_PUBLIC_KEYS = [
|
||||||
|
k for k in [key_pair["public"], previous_key_pair["public"]] if k
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_secret(self, secret_name, previous=False):
|
||||||
|
import base64
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
try:
|
||||||
|
ssm_response = self.client.get_secret_value(
|
||||||
|
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name),
|
||||||
|
VersionStage=("AWSPREVIOUS" if previous else "AWSCURRENT"),
|
||||||
|
)
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response["Error"]["Code"] == "ResourceNotFoundException" and previous:
|
||||||
|
# A secret might not have a previous version yet. It is left
|
||||||
|
# up to the caller to handle the None return in this case
|
||||||
|
return None
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if "SecretString" in ssm_response:
|
||||||
|
return ssm_response["SecretString"]
|
||||||
|
return base64.b64decode(ssm_response["SecretBinary"])
|
||||||
|
0
accounts/lambdas/__init__.py
Normal file
0
accounts/lambdas/__init__.py
Normal file
@ -1,9 +1,20 @@
|
|||||||
import base64
|
import base64
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
|
|
||||||
|
session = boto3.session.Session()
|
||||||
|
boto_client = session.client(
|
||||||
|
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
|
||||||
|
)
|
||||||
|
|
||||||
|
basic_auth_user = get_secret(boto_client, "basicAuthUser")
|
||||||
|
hashed_basic_auth_password = get_secret(boto_client, "hashedBasicAuthPassword")
|
||||||
|
|
||||||
|
|
||||||
def build_api_arn(method_arn):
|
def build_api_arn(method_arn):
|
||||||
arn_chunks = method_arn.split(":")
|
arn_chunks = method_arn.split(":")
|
||||||
aws_region = arn_chunks[3]
|
aws_region = arn_chunks[3]
|
||||||
@ -38,6 +49,15 @@ def build_response(api_arn, allow):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_secret(secret_name):
|
||||||
|
ssm_response = boto_client.get_secret_value(
|
||||||
|
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name)
|
||||||
|
)
|
||||||
|
if "SecretString" in ssm_response:
|
||||||
|
return ssm_response["SecretString"]
|
||||||
|
return base64.b64decode(ssm_response["SecretBinary"])
|
||||||
|
|
||||||
|
|
||||||
def handler(event, context):
|
def handler(event, context):
|
||||||
api_arn = build_api_arn(event["methodArn"])
|
api_arn = build_api_arn(event["methodArn"])
|
||||||
|
|
||||||
@ -50,11 +70,10 @@ def handler(event, context):
|
|||||||
user = credentials[0]
|
user = credentials[0]
|
||||||
password = credentials[1]
|
password = credentials[1]
|
||||||
|
|
||||||
if user != environ.get("BASIC_AUTH_USER"):
|
if user != basic_auth_user:
|
||||||
return build_response(api_arn, False)
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
|
if not bcrypt.verify(password, hashed_basic_auth_password):
|
||||||
if not bcrypt.verify(password, hashed_password):
|
|
||||||
return build_response(api_arn, False)
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
return build_response(api_arn, True)
|
return build_response(api_arn, True)
|
143
accounts/lambdas/rotate_keys.py
Normal file
143
accounts/lambdas/rotate_keys.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import io
|
||||||
|
import boto3
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from os import environ
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def handler(event, context):
|
||||||
|
arn = event["SecretId"]
|
||||||
|
token = event["ClientRequestToken"]
|
||||||
|
step = event["Step"]
|
||||||
|
|
||||||
|
session = boto3.session.Session()
|
||||||
|
service_client = session.client(
|
||||||
|
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure the version is staged correctly
|
||||||
|
metadata = service_client.describe_secret(SecretId=arn)
|
||||||
|
if not metadata["RotationEnabled"]:
|
||||||
|
logger.error("Secret %s is not enabled for rotation" % arn)
|
||||||
|
raise ValueError("Secret %s is not enabled for rotation" % arn)
|
||||||
|
versions = metadata["VersionIdsToStages"]
|
||||||
|
if token not in versions:
|
||||||
|
logger.error(
|
||||||
|
"Secret version %s has no stage for rotation of secret %s.", token, arn
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Secret version %s has no stage for rotation of secret %s." % (token, arn)
|
||||||
|
)
|
||||||
|
if "AWSCURRENT" in versions[token]:
|
||||||
|
logger.info(
|
||||||
|
"Secret version %s already set as AWSCURRENT for secret %s.", token, arn
|
||||||
|
)
|
||||||
|
return
|
||||||
|
elif "AWSPENDING" not in versions[token]:
|
||||||
|
logger.error(
|
||||||
|
"Secret version %s not set as AWSPENDING for rotation of secret %s.",
|
||||||
|
token,
|
||||||
|
arn,
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Secret version %s not set as AWSPENDING for rotation of secret %s."
|
||||||
|
% (token, arn)
|
||||||
|
)
|
||||||
|
|
||||||
|
if step == "createSecret":
|
||||||
|
create_secret(service_client, arn, token)
|
||||||
|
|
||||||
|
elif step == "setSecret":
|
||||||
|
set_secret(service_client, arn, token)
|
||||||
|
|
||||||
|
elif step == "testSecret":
|
||||||
|
test_secret(service_client, arn, token)
|
||||||
|
|
||||||
|
elif step == "finishSecret":
|
||||||
|
finish_secret(service_client, arn, token)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid step parameter")
|
||||||
|
|
||||||
|
|
||||||
|
def create_key_pair(**kwargs):
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
backend=default_backend(), public_exponent=65537, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
serialization.Encoding.PEM, serialization.PublicFormat.PKCS1
|
||||||
|
)
|
||||||
|
|
||||||
|
pem = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"private": pem.decode(), "public": public_key.decode()}
|
||||||
|
|
||||||
|
|
||||||
|
def create_secret(service_client, arn, token):
|
||||||
|
service_client.get_secret_value(SecretId=arn, VersionStage="AWSCURRENT")
|
||||||
|
try:
|
||||||
|
service_client.get_secret_value(
|
||||||
|
SecretId=arn, VersionId=token, VersionStage="AWSPENDING"
|
||||||
|
)
|
||||||
|
logger.info("createSecret: Successfully retrieved secret for %s." % arn)
|
||||||
|
except service_client.exceptions.ResourceNotFoundException:
|
||||||
|
secret = create_key_pair(key_size=2048)
|
||||||
|
service_client.put_secret_value(
|
||||||
|
SecretId=arn,
|
||||||
|
ClientRequestToken=token,
|
||||||
|
SecretString=json.dumps(secret).encode().decode("unicode_escape"),
|
||||||
|
VersionStages=["AWSPENDING"],
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"createSecret: Successfully put secret for ARN %s and version %s."
|
||||||
|
% (arn, token)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_secret(service_client, arn, token):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_secret(service_client, arn, token):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def finish_secret(service_client, arn, token):
|
||||||
|
metadata = service_client.describe_secret(SecretId=arn)
|
||||||
|
current_version = None
|
||||||
|
for version in metadata["VersionIdsToStages"]:
|
||||||
|
if "AWSCURRENT" in metadata["VersionIdsToStages"][version]:
|
||||||
|
if version == token:
|
||||||
|
# The correct version is already marked as current, return
|
||||||
|
logger.info(
|
||||||
|
"finishSecret: Version %s already marked as AWSCURRENT for %s",
|
||||||
|
version,
|
||||||
|
arn,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
current_version = version
|
||||||
|
break
|
||||||
|
|
||||||
|
# Finalize by staging the secret version current
|
||||||
|
service_client.update_secret_version_stage(
|
||||||
|
SecretId=arn,
|
||||||
|
VersionStage="AWSCURRENT",
|
||||||
|
MoveToVersionId=token,
|
||||||
|
RemoveFromVersionId=current_version,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"finishSecret: Successfully set AWSCURRENT stage to version %s for secret %s."
|
||||||
|
% (token, arn)
|
||||||
|
)
|
@ -9,3 +9,4 @@ bcrypt==3.1.7
|
|||||||
PyMySQL==0.9.3
|
PyMySQL==0.9.3
|
||||||
mysqlclient==1.4.2.post1
|
mysqlclient==1.4.2.post1
|
||||||
requests==2.22.0
|
requests==2.22.0
|
||||||
|
cryptography==2.7
|
||||||
|
@ -23,9 +23,9 @@ plugins:
|
|||||||
custom:
|
custom:
|
||||||
stage: ${opt:stage, self:provider.stage}
|
stage: ${opt:stage, self:provider.stage}
|
||||||
origin:
|
origin:
|
||||||
production: vault.offen.dev
|
production: https://vault.offen.dev
|
||||||
staging: vault-staging.offen.dev
|
staging: https://vault-staging.offen.dev
|
||||||
alpha: vault-alpha.offen.dev
|
alpha: https://vault-alpha.offen.dev
|
||||||
serverHost:
|
serverHost:
|
||||||
production: https://server.offen.dev
|
production: https://server.offen.dev
|
||||||
staging: https://server-staging.offen.dev
|
staging: https://server-staging.offen.dev
|
||||||
@ -55,10 +55,13 @@ custom:
|
|||||||
|
|
||||||
functions:
|
functions:
|
||||||
authorizer:
|
authorizer:
|
||||||
handler: authorizer.handler
|
handler: lambdas.authorizer.handler
|
||||||
environment:
|
environment:
|
||||||
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
STAGE: ${self:custom.stage}
|
||||||
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
rotateKeys:
|
||||||
|
handler: lambdas.rotate_keys.handler
|
||||||
|
environment:
|
||||||
|
STAGE: ${self:custom.stage}
|
||||||
app:
|
app:
|
||||||
handler: wsgi_handler.handler
|
handler: wsgi_handler.handler
|
||||||
timeout: 30
|
timeout: 30
|
||||||
@ -84,16 +87,11 @@ functions:
|
|||||||
path: '/{proxy+}'
|
path: '/{proxy+}'
|
||||||
method: any
|
method: any
|
||||||
environment:
|
environment:
|
||||||
CONFIG_CLASS: accounts.config.EnvConfig
|
CONFIG_CLASS: accounts.config.SecretsManagerConfig
|
||||||
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
STAGE: ${self:custom.stage}
|
||||||
|
CORS_ORIGIN: ${self:custom.origin.${self:custom.stage}}
|
||||||
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
||||||
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
|
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
|
||||||
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
|
|
||||||
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
|
|
||||||
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
|
||||||
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
|
||||||
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
|
|
||||||
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
|
|
||||||
|
|
||||||
resources:
|
resources:
|
||||||
Resources:
|
Resources:
|
||||||
|
@ -55,7 +55,7 @@ class TestKey(unittest.TestCase):
|
|||||||
rv = self.app.get("/api/key")
|
rv = self.app.get("/api/key")
|
||||||
assert rv.status.startswith("200")
|
assert rv.status.startswith("200")
|
||||||
data = json.loads(rv.data)
|
data = json.loads(rv.data)
|
||||||
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
|
assert data["keys"] == [environ.get("JWT_PUBLIC_KEY")]
|
||||||
|
|
||||||
|
|
||||||
class TestJWT(unittest.TestCase):
|
class TestJWT(unittest.TestCase):
|
||||||
|
@ -39,38 +39,31 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keyRes, keyErr := fetchKey(keyURL)
|
keys, keysErr := fetchKeys(keyURL)
|
||||||
if keyErr != nil {
|
if keysErr != nil {
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching key: %v", keyErr), http.StatusInternalServerError)
|
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
keyBytes, _ := pem.Decode([]byte(keyRes))
|
var token *jwt.Token
|
||||||
if keyBytes == nil {
|
var tokenErr error
|
||||||
RespondWithJSONError(w, errors.New("jwt: no PEM block found in given key"), http.StatusInternalServerError)
|
// the response can contain multiple keys to try as some of them
|
||||||
return
|
// might have been retired with signed tokens still in use until
|
||||||
|
// their expiry
|
||||||
|
for _, key := range keys {
|
||||||
|
token, tokenErr = tryParse(key, jwtValue)
|
||||||
|
if tokenErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
|
if tokenErr != nil {
|
||||||
if parseErr != nil {
|
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token signature: %v", tokenErr), http.StatusForbidden)
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing key: %v", parseErr), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
|
|
||||||
if !pubKeyOk {
|
|
||||||
RespondWithJSONError(w, errors.New("jwt: given key is not of type RSA public key"), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
|
|
||||||
if jwtErr != nil {
|
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
|
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token: %v", err), http.StatusForbidden)
|
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token claims: %v", err), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,11 +79,34 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type keyResponse struct {
|
func tryParse(key []byte, tokenValue string) (*jwt.Token, error) {
|
||||||
Key string `json:"key"`
|
keyBytes, _ := pem.Decode([]byte(key))
|
||||||
|
if keyBytes == nil {
|
||||||
|
return nil, errors.New("jwt: no PEM block found in given key")
|
||||||
|
}
|
||||||
|
|
||||||
|
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, fmt.Errorf("jwt: error parsing key: %v", parseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
|
||||||
|
if !pubKeyOk {
|
||||||
|
return nil, errors.New("jwt: given key is not of type RSA public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, jwtErr := jwt.ParseVerify(strings.NewReader(tokenValue), jwa.RS256, pubKey)
|
||||||
|
if jwtErr != nil {
|
||||||
|
return nil, fmt.Errorf("jwt: error parsing token: %v", jwtErr)
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchKey(keyURL string) ([]byte, error) {
|
type keyResponse struct {
|
||||||
|
Keys []string `json:"keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchKeys(keyURL string) ([][]byte, error) {
|
||||||
fetchRes, fetchErr := http.Get(keyURL)
|
fetchRes, fetchErr := http.Get(keyURL)
|
||||||
if fetchErr != nil {
|
if fetchErr != nil {
|
||||||
return nil, fetchErr
|
return nil, fetchErr
|
||||||
@ -100,5 +116,10 @@ func fetchKey(keyURL string) ([]byte, error) {
|
|||||||
if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil {
|
if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return []byte(payload.Key), nil
|
|
||||||
|
asBytes := [][]byte{}
|
||||||
|
for _, key := range payload.Keys {
|
||||||
|
asBytes = append(asBytes, []byte(key))
|
||||||
|
}
|
||||||
|
return asBytes, nil
|
||||||
}
|
}
|
||||||
|
@ -108,10 +108,10 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"not really a key"}`))
|
w.Write([]byte(`{"keys":["not really a key","me neither"]}`))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"invalid key",
|
"invalid key",
|
||||||
@ -121,10 +121,10 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
w.Write([]byte(`{"keys":["-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"]}`))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusBadGateway,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"valid key, bad auth token",
|
"valid key, bad auth token",
|
||||||
@ -135,7 +135,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
@ -157,7 +157,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
@ -178,7 +178,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
})(),
|
})(),
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
@ -194,7 +194,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
})(),
|
})(),
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
@ -217,7 +217,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error {
|
func(r *http.Request, claims map[string]interface{}) error {
|
||||||
@ -244,7 +244,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
|
Loading…
Reference in New Issue
Block a user