2
0
mirror of https://github.com/offen/website.git synced 2024-11-22 09:00:28 +01:00

allow checking JWTs against multiple public keys

This commit is contained in:
Frederik Ring 2019-07-19 14:49:35 +02:00
parent 9d98cb63b6
commit f678a9aa56
11 changed files with 310 additions and 63 deletions

View File

@ -3,10 +3,13 @@ from os import environ
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_admin import Admin
from werkzeug.utils import import_string
app = Flask(__name__)
app.config.from_object(environ.get("CONFIG_CLASS"))
cfg = import_string(environ.get("CONFIG_CLASS"))()
app.config.from_object(cfg)
db = SQLAlchemy(app)

View File

@ -83,10 +83,19 @@ def post_login():
@json_error
def get_login():
auth_cookie = request.cookies.get(COOKIE_KEY)
try:
token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"])
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
public_keys = app.config["JWT_PUBLIC_KEYS"]
token = None
token_err = None
for public_key in public_keys:
try:
token = jwt.decode(auth_cookie, public_key)
break
except Exception as decode_err:
token_err = decode_err
if not token:
return jsonify({"error": str(token_err), "status": 401}), 401
try:
match = User.query.get(token["priv"]["userId"])
@ -117,5 +126,4 @@ def key():
This route is not supposed to be called by client-side applications, so
no CORS configuration is added
"""
public_key = app.config["JWT_PUBLIC_KEY"].strip()
return jsonify({"key": public_key})
return jsonify({"keys": app.config["JWT_PUBLIC_KEYS"]})

View File

@ -1,14 +1,68 @@
import json
from os import environ
class BaseConfig(object):
SQLALCHEMY_TRACK_MODIFICATIONS = False
FLASK_ADMIN_SWATCH = "flatly"
class EnvConfig(BaseConfig):
SECRET_KEY = environ.get("SESSION_SECRET")
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "")
JWT_PUBLIC_KEYS = [environ.get("JWT_PUBLIC_KEY", "")]
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
SERVER_HOST = environ.get("SERVER_HOST")
class SecretsManagerConfig(BaseConfig):
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
SERVER_HOST = environ.get("SERVER_HOST")
def __init__(self):
import boto3
session = boto3.session.Session()
self.client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
self.SECRET_KEY = self.get_secret("sessionSecret")
self.SQLALCHEMY_DATABASE_URI = self.get_secret("mysqlConnectionString")
current_version = self.get_secret("jwtKeyPair")
key_pair = json.loads(current_version)
previous_version = self.get_secret("jwtKeyPair", previous=True)
previous_key_pair = (
json.loads(previous_version)
if previous_version is not None
else {"public": None}
)
self.JWT_PRIVATE_KEY = key_pair["private"]
self.JWT_PUBLIC_KEYS = [
k for k in [key_pair["public"], previous_key_pair["public"]] if k
]
def get_secret(self, secret_name, previous=False):
import base64
from botocore.exceptions import ClientError
try:
ssm_response = self.client.get_secret_value(
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name),
VersionStage=("AWSPREVIOUS" if previous else "AWSCURRENT"),
)
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException" and previous:
# A secret might not have a previous version yet. It is left
# up to the caller to handle the None return in this case
return None
raise e
if "SecretString" in ssm_response:
return ssm_response["SecretString"]
return base64.b64decode(ssm_response["SecretBinary"])

View File

View File

@ -1,9 +1,20 @@
import base64
from os import environ
import boto3
from botocore.exceptions import ClientError
from passlib.hash import bcrypt
session = boto3.session.Session()
boto_client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
basic_auth_user = get_secret(boto_client, "basicAuthUser")
hashed_basic_auth_password = get_secret(boto_client, "hashedBasicAuthPassword")
def build_api_arn(method_arn):
arn_chunks = method_arn.split(":")
aws_region = arn_chunks[3]
@ -38,6 +49,15 @@ def build_response(api_arn, allow):
}
def get_secret(secret_name):
ssm_response = boto_client.get_secret_value(
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name)
)
if "SecretString" in ssm_response:
return ssm_response["SecretString"]
return base64.b64decode(ssm_response["SecretBinary"])
def handler(event, context):
api_arn = build_api_arn(event["methodArn"])
@ -50,11 +70,10 @@ def handler(event, context):
user = credentials[0]
password = credentials[1]
if user != environ.get("BASIC_AUTH_USER"):
if user != basic_auth_user:
return build_response(api_arn, False)
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
if not bcrypt.verify(password, hashed_password):
if not bcrypt.verify(password, hashed_basic_auth_password):
return build_response(api_arn, False)
return build_response(api_arn, True)

View File

@ -0,0 +1,143 @@
import io
import boto3
import logging
import json
from os import environ
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
logger = logging.getLogger()
logger.setLevel(logging.INFO)
def handler(event, context):
arn = event["SecretId"]
token = event["ClientRequestToken"]
step = event["Step"]
session = boto3.session.Session()
service_client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
# Make sure the version is staged correctly
metadata = service_client.describe_secret(SecretId=arn)
if not metadata["RotationEnabled"]:
logger.error("Secret %s is not enabled for rotation" % arn)
raise ValueError("Secret %s is not enabled for rotation" % arn)
versions = metadata["VersionIdsToStages"]
if token not in versions:
logger.error(
"Secret version %s has no stage for rotation of secret %s.", token, arn
)
raise ValueError(
"Secret version %s has no stage for rotation of secret %s." % (token, arn)
)
if "AWSCURRENT" in versions[token]:
logger.info(
"Secret version %s already set as AWSCURRENT for secret %s.", token, arn
)
return
elif "AWSPENDING" not in versions[token]:
logger.error(
"Secret version %s not set as AWSPENDING for rotation of secret %s.",
token,
arn,
)
raise ValueError(
"Secret version %s not set as AWSPENDING for rotation of secret %s."
% (token, arn)
)
if step == "createSecret":
create_secret(service_client, arn, token)
elif step == "setSecret":
set_secret(service_client, arn, token)
elif step == "testSecret":
test_secret(service_client, arn, token)
elif step == "finishSecret":
finish_secret(service_client, arn, token)
else:
raise ValueError("Invalid step parameter")
def create_key_pair(**kwargs):
key = rsa.generate_private_key(
backend=default_backend(), public_exponent=65537, **kwargs
)
public_key = key.public_key().public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.PKCS1
)
pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
return {"private": pem.decode(), "public": public_key.decode()}
def create_secret(service_client, arn, token):
service_client.get_secret_value(SecretId=arn, VersionStage="AWSCURRENT")
try:
service_client.get_secret_value(
SecretId=arn, VersionId=token, VersionStage="AWSPENDING"
)
logger.info("createSecret: Successfully retrieved secret for %s." % arn)
except service_client.exceptions.ResourceNotFoundException:
secret = create_key_pair(key_size=2048)
service_client.put_secret_value(
SecretId=arn,
ClientRequestToken=token,
SecretString=json.dumps(secret).encode().decode("unicode_escape"),
VersionStages=["AWSPENDING"],
)
logger.info(
"createSecret: Successfully put secret for ARN %s and version %s."
% (arn, token)
)
def set_secret(service_client, arn, token):
pass
def test_secret(service_client, arn, token):
pass
def finish_secret(service_client, arn, token):
metadata = service_client.describe_secret(SecretId=arn)
current_version = None
for version in metadata["VersionIdsToStages"]:
if "AWSCURRENT" in metadata["VersionIdsToStages"][version]:
if version == token:
# The correct version is already marked as current, return
logger.info(
"finishSecret: Version %s already marked as AWSCURRENT for %s",
version,
arn,
)
return
current_version = version
break
# Finalize by staging the secret version current
service_client.update_secret_version_stage(
SecretId=arn,
VersionStage="AWSCURRENT",
MoveToVersionId=token,
RemoveFromVersionId=current_version,
)
logger.info(
"finishSecret: Successfully set AWSCURRENT stage to version %s for secret %s."
% (token, arn)
)

View File

@ -9,3 +9,4 @@ bcrypt==3.1.7
PyMySQL==0.9.3
mysqlclient==1.4.2.post1
requests==2.22.0
cryptography==2.7

View File

@ -23,9 +23,9 @@ plugins:
custom:
stage: ${opt:stage, self:provider.stage}
origin:
production: vault.offen.dev
staging: vault-staging.offen.dev
alpha: vault-alpha.offen.dev
production: https://vault.offen.dev
staging: https://vault-staging.offen.dev
alpha: https://vault-alpha.offen.dev
serverHost:
production: https://server.offen.dev
staging: https://server-staging.offen.dev
@ -55,10 +55,13 @@ custom:
functions:
authorizer:
handler: authorizer.handler
handler: lambdas.authorizer.handler
environment:
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
STAGE: ${self:custom.stage}
rotateKeys:
handler: lambdas.rotate_keys.handler
environment:
STAGE: ${self:custom.stage}
app:
handler: wsgi_handler.handler
timeout: 30
@ -84,16 +87,11 @@ functions:
path: '/{proxy+}'
method: any
environment:
CONFIG_CLASS: accounts.config.EnvConfig
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
CONFIG_CLASS: accounts.config.SecretsManagerConfig
STAGE: ${self:custom.stage}
CORS_ORIGIN: ${self:custom.origin.${self:custom.stage}}
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
resources:
Resources:

View File

@ -55,7 +55,7 @@ class TestKey(unittest.TestCase):
rv = self.app.get("/api/key")
assert rv.status.startswith("200")
data = json.loads(rv.data)
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
assert data["keys"] == [environ.get("JWT_PUBLIC_KEY")]
class TestJWT(unittest.TestCase):

View File

@ -39,38 +39,31 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
return
}
keyRes, keyErr := fetchKey(keyURL)
if keyErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching key: %v", keyErr), http.StatusInternalServerError)
keys, keysErr := fetchKeys(keyURL)
if keysErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
return
}
keyBytes, _ := pem.Decode([]byte(keyRes))
if keyBytes == nil {
RespondWithJSONError(w, errors.New("jwt: no PEM block found in given key"), http.StatusInternalServerError)
return
var token *jwt.Token
var tokenErr error
// the response can contain multiple keys to try as some of them
// might have been retired with signed tokens still in use until
// their expiry
for _, key := range keys {
token, tokenErr = tryParse(key, jwtValue)
if tokenErr == nil {
break
}
}
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
if parseErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing key: %v", parseErr), http.StatusBadGateway)
return
}
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
if !pubKeyOk {
RespondWithJSONError(w, errors.New("jwt: given key is not of type RSA public key"), http.StatusInternalServerError)
return
}
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
if jwtErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
if tokenErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token signature: %v", tokenErr), http.StatusForbidden)
return
}
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token: %v", err), http.StatusForbidden)
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token claims: %v", err), http.StatusForbidden)
return
}
@ -86,11 +79,34 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
}
}
type keyResponse struct {
Key string `json:"key"`
func tryParse(key []byte, tokenValue string) (*jwt.Token, error) {
keyBytes, _ := pem.Decode([]byte(key))
if keyBytes == nil {
return nil, errors.New("jwt: no PEM block found in given key")
}
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
if parseErr != nil {
return nil, fmt.Errorf("jwt: error parsing key: %v", parseErr)
}
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
if !pubKeyOk {
return nil, errors.New("jwt: given key is not of type RSA public key")
}
token, jwtErr := jwt.ParseVerify(strings.NewReader(tokenValue), jwa.RS256, pubKey)
if jwtErr != nil {
return nil, fmt.Errorf("jwt: error parsing token: %v", jwtErr)
}
return token, nil
}
func fetchKey(keyURL string) ([]byte, error) {
type keyResponse struct {
Keys []string `json:"keys"`
}
func fetchKeys(keyURL string) ([][]byte, error) {
fetchRes, fetchErr := http.Get(keyURL)
if fetchErr != nil {
return nil, fetchErr
@ -100,5 +116,10 @@ func fetchKey(keyURL string) ([]byte, error) {
if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil {
return nil, err
}
return []byte(payload.Key), nil
asBytes := [][]byte{}
for _, key := range payload.Keys {
asBytes = append(asBytes, []byte(key))
}
return asBytes, nil
}

View File

@ -108,10 +108,10 @@ func TestJWTProtect(t *testing.T) {
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"not really a key"}`))
w.Write([]byte(`{"keys":["not really a key","me neither"]}`))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError,
http.StatusForbidden,
},
{
"invalid key",
@ -121,10 +121,10 @@ func TestJWTProtect(t *testing.T) {
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
w.Write([]byte(`{"keys":["-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"]}`))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusBadGateway,
http.StatusForbidden,
},
{
"valid key, bad auth token",
@ -135,7 +135,7 @@ func TestJWTProtect(t *testing.T) {
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -157,7 +157,7 @@ func TestJWTProtect(t *testing.T) {
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -178,7 +178,7 @@ func TestJWTProtect(t *testing.T) {
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -194,7 +194,7 @@ func TestJWTProtect(t *testing.T) {
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -217,7 +217,7 @@ func TestJWTProtect(t *testing.T) {
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error {
@ -244,7 +244,7 @@ func TestJWTProtect(t *testing.T) {
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },