2
0
mirror of https://github.com/offen/website.git synced 2024-11-22 17:10:29 +01:00

allow checking JWTs against multiple public keys

This commit is contained in:
Frederik Ring 2019-07-19 14:49:35 +02:00
parent 9d98cb63b6
commit f678a9aa56
11 changed files with 310 additions and 63 deletions

View File

@ -3,10 +3,13 @@ from os import environ
from flask import Flask from flask import Flask
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from flask_admin import Admin from flask_admin import Admin
from werkzeug.utils import import_string
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(environ.get("CONFIG_CLASS"))
cfg = import_string(environ.get("CONFIG_CLASS"))()
app.config.from_object(cfg)
db = SQLAlchemy(app) db = SQLAlchemy(app)

View File

@ -83,10 +83,19 @@ def post_login():
@json_error @json_error
def get_login(): def get_login():
auth_cookie = request.cookies.get(COOKIE_KEY) auth_cookie = request.cookies.get(COOKIE_KEY)
public_keys = app.config["JWT_PUBLIC_KEYS"]
token = None
token_err = None
for public_key in public_keys:
try: try:
token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"]) token = jwt.decode(auth_cookie, public_key)
except jwt.exceptions.PyJWTError as unauthorized_error: break
return jsonify({"error": str(unauthorized_error), "status": 401}), 401 except Exception as decode_err:
token_err = decode_err
if not token:
return jsonify({"error": str(token_err), "status": 401}), 401
try: try:
match = User.query.get(token["priv"]["userId"]) match = User.query.get(token["priv"]["userId"])
@ -117,5 +126,4 @@ def key():
This route is not supposed to be called by client-side applications, so This route is not supposed to be called by client-side applications, so
no CORS configuration is added no CORS configuration is added
""" """
public_key = app.config["JWT_PUBLIC_KEY"].strip() return jsonify({"keys": app.config["JWT_PUBLIC_KEYS"]})
return jsonify({"key": public_key})

View File

@ -1,14 +1,68 @@
import json
from os import environ from os import environ
class BaseConfig(object): class BaseConfig(object):
SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_TRACK_MODIFICATIONS = False
FLASK_ADMIN_SWATCH = "flatly" FLASK_ADMIN_SWATCH = "flatly"
class EnvConfig(BaseConfig): class EnvConfig(BaseConfig):
SECRET_KEY = environ.get("SESSION_SECRET") SECRET_KEY = environ.get("SESSION_SECRET")
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING") SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*") CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "") JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "") JWT_PUBLIC_KEYS = [environ.get("JWT_PUBLIC_KEY", "")]
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN") COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
SERVER_HOST = environ.get("SERVER_HOST") SERVER_HOST = environ.get("SERVER_HOST")
class SecretsManagerConfig(BaseConfig):
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
SERVER_HOST = environ.get("SERVER_HOST")
def __init__(self):
import boto3
session = boto3.session.Session()
self.client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
self.SECRET_KEY = self.get_secret("sessionSecret")
self.SQLALCHEMY_DATABASE_URI = self.get_secret("mysqlConnectionString")
current_version = self.get_secret("jwtKeyPair")
key_pair = json.loads(current_version)
previous_version = self.get_secret("jwtKeyPair", previous=True)
previous_key_pair = (
json.loads(previous_version)
if previous_version is not None
else {"public": None}
)
self.JWT_PRIVATE_KEY = key_pair["private"]
self.JWT_PUBLIC_KEYS = [
k for k in [key_pair["public"], previous_key_pair["public"]] if k
]
def get_secret(self, secret_name, previous=False):
import base64
from botocore.exceptions import ClientError
try:
ssm_response = self.client.get_secret_value(
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name),
VersionStage=("AWSPREVIOUS" if previous else "AWSCURRENT"),
)
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException" and previous:
# A secret might not have a previous version yet. It is left
# up to the caller to handle the None return in this case
return None
raise e
if "SecretString" in ssm_response:
return ssm_response["SecretString"]
return base64.b64decode(ssm_response["SecretBinary"])

View File

View File

@ -1,9 +1,20 @@
import base64 import base64
from os import environ from os import environ
import boto3
from botocore.exceptions import ClientError
from passlib.hash import bcrypt from passlib.hash import bcrypt
session = boto3.session.Session()
boto_client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
basic_auth_user = get_secret(boto_client, "basicAuthUser")
hashed_basic_auth_password = get_secret(boto_client, "hashedBasicAuthPassword")
def build_api_arn(method_arn): def build_api_arn(method_arn):
arn_chunks = method_arn.split(":") arn_chunks = method_arn.split(":")
aws_region = arn_chunks[3] aws_region = arn_chunks[3]
@ -38,6 +49,15 @@ def build_response(api_arn, allow):
} }
def get_secret(secret_name):
ssm_response = boto_client.get_secret_value(
SecretId="{}/accounts/{}".format(environ.get("STAGE"), secret_name)
)
if "SecretString" in ssm_response:
return ssm_response["SecretString"]
return base64.b64decode(ssm_response["SecretBinary"])
def handler(event, context): def handler(event, context):
api_arn = build_api_arn(event["methodArn"]) api_arn = build_api_arn(event["methodArn"])
@ -50,11 +70,10 @@ def handler(event, context):
user = credentials[0] user = credentials[0]
password = credentials[1] password = credentials[1]
if user != environ.get("BASIC_AUTH_USER"): if user != basic_auth_user:
return build_response(api_arn, False) return build_response(api_arn, False)
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD") if not bcrypt.verify(password, hashed_basic_auth_password):
if not bcrypt.verify(password, hashed_password):
return build_response(api_arn, False) return build_response(api_arn, False)
return build_response(api_arn, True) return build_response(api_arn, True)

View File

@ -0,0 +1,143 @@
import io
import boto3
import logging
import json
from os import environ
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
logger = logging.getLogger()
logger.setLevel(logging.INFO)
def handler(event, context):
arn = event["SecretId"]
token = event["ClientRequestToken"]
step = event["Step"]
session = boto3.session.Session()
service_client = session.client(
service_name="secretsmanager", region_name=environ.get("AWS_REGION")
)
# Make sure the version is staged correctly
metadata = service_client.describe_secret(SecretId=arn)
if not metadata["RotationEnabled"]:
logger.error("Secret %s is not enabled for rotation" % arn)
raise ValueError("Secret %s is not enabled for rotation" % arn)
versions = metadata["VersionIdsToStages"]
if token not in versions:
logger.error(
"Secret version %s has no stage for rotation of secret %s.", token, arn
)
raise ValueError(
"Secret version %s has no stage for rotation of secret %s." % (token, arn)
)
if "AWSCURRENT" in versions[token]:
logger.info(
"Secret version %s already set as AWSCURRENT for secret %s.", token, arn
)
return
elif "AWSPENDING" not in versions[token]:
logger.error(
"Secret version %s not set as AWSPENDING for rotation of secret %s.",
token,
arn,
)
raise ValueError(
"Secret version %s not set as AWSPENDING for rotation of secret %s."
% (token, arn)
)
if step == "createSecret":
create_secret(service_client, arn, token)
elif step == "setSecret":
set_secret(service_client, arn, token)
elif step == "testSecret":
test_secret(service_client, arn, token)
elif step == "finishSecret":
finish_secret(service_client, arn, token)
else:
raise ValueError("Invalid step parameter")
def create_key_pair(**kwargs):
key = rsa.generate_private_key(
backend=default_backend(), public_exponent=65537, **kwargs
)
public_key = key.public_key().public_bytes(
serialization.Encoding.PEM, serialization.PublicFormat.PKCS1
)
pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
return {"private": pem.decode(), "public": public_key.decode()}
def create_secret(service_client, arn, token):
service_client.get_secret_value(SecretId=arn, VersionStage="AWSCURRENT")
try:
service_client.get_secret_value(
SecretId=arn, VersionId=token, VersionStage="AWSPENDING"
)
logger.info("createSecret: Successfully retrieved secret for %s." % arn)
except service_client.exceptions.ResourceNotFoundException:
secret = create_key_pair(key_size=2048)
service_client.put_secret_value(
SecretId=arn,
ClientRequestToken=token,
SecretString=json.dumps(secret).encode().decode("unicode_escape"),
VersionStages=["AWSPENDING"],
)
logger.info(
"createSecret: Successfully put secret for ARN %s and version %s."
% (arn, token)
)
def set_secret(service_client, arn, token):
pass
def test_secret(service_client, arn, token):
pass
def finish_secret(service_client, arn, token):
metadata = service_client.describe_secret(SecretId=arn)
current_version = None
for version in metadata["VersionIdsToStages"]:
if "AWSCURRENT" in metadata["VersionIdsToStages"][version]:
if version == token:
# The correct version is already marked as current, return
logger.info(
"finishSecret: Version %s already marked as AWSCURRENT for %s",
version,
arn,
)
return
current_version = version
break
# Finalize by staging the secret version current
service_client.update_secret_version_stage(
SecretId=arn,
VersionStage="AWSCURRENT",
MoveToVersionId=token,
RemoveFromVersionId=current_version,
)
logger.info(
"finishSecret: Successfully set AWSCURRENT stage to version %s for secret %s."
% (token, arn)
)

View File

@ -9,3 +9,4 @@ bcrypt==3.1.7
PyMySQL==0.9.3 PyMySQL==0.9.3
mysqlclient==1.4.2.post1 mysqlclient==1.4.2.post1
requests==2.22.0 requests==2.22.0
cryptography==2.7

View File

@ -23,9 +23,9 @@ plugins:
custom: custom:
stage: ${opt:stage, self:provider.stage} stage: ${opt:stage, self:provider.stage}
origin: origin:
production: vault.offen.dev production: https://vault.offen.dev
staging: vault-staging.offen.dev staging: https://vault-staging.offen.dev
alpha: vault-alpha.offen.dev alpha: https://vault-alpha.offen.dev
serverHost: serverHost:
production: https://server.offen.dev production: https://server.offen.dev
staging: https://server-staging.offen.dev staging: https://server-staging.offen.dev
@ -55,10 +55,13 @@ custom:
functions: functions:
authorizer: authorizer:
handler: authorizer.handler handler: lambdas.authorizer.handler
environment: environment:
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true} STAGE: ${self:custom.stage}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true} rotateKeys:
handler: lambdas.rotate_keys.handler
environment:
STAGE: ${self:custom.stage}
app: app:
handler: wsgi_handler.handler handler: wsgi_handler.handler
timeout: 30 timeout: 30
@ -84,16 +87,11 @@ functions:
path: '/{proxy+}' path: '/{proxy+}'
method: any method: any
environment: environment:
CONFIG_CLASS: accounts.config.EnvConfig CONFIG_CLASS: accounts.config.SecretsManagerConfig
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}} STAGE: ${self:custom.stage}
CORS_ORIGIN: ${self:custom.origin.${self:custom.stage}}
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}} COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}} SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
resources: resources:
Resources: Resources:

View File

@ -55,7 +55,7 @@ class TestKey(unittest.TestCase):
rv = self.app.get("/api/key") rv = self.app.get("/api/key")
assert rv.status.startswith("200") assert rv.status.startswith("200")
data = json.loads(rv.data) data = json.loads(rv.data)
assert data["key"] == environ.get("JWT_PUBLIC_KEY") assert data["keys"] == [environ.get("JWT_PUBLIC_KEY")]
class TestJWT(unittest.TestCase): class TestJWT(unittest.TestCase):

View File

@ -39,38 +39,31 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
return return
} }
keyRes, keyErr := fetchKey(keyURL) keys, keysErr := fetchKeys(keyURL)
if keyErr != nil { if keysErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error fetching key: %v", keyErr), http.StatusInternalServerError) RespondWithJSONError(w, fmt.Errorf("jwt: error fetching keys: %v", keysErr), http.StatusInternalServerError)
return return
} }
keyBytes, _ := pem.Decode([]byte(keyRes)) var token *jwt.Token
if keyBytes == nil { var tokenErr error
RespondWithJSONError(w, errors.New("jwt: no PEM block found in given key"), http.StatusInternalServerError) // the response can contain multiple keys to try as some of them
return // might have been retired with signed tokens still in use until
// their expiry
for _, key := range keys {
token, tokenErr = tryParse(key, jwtValue)
if tokenErr == nil {
break
}
} }
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes) if tokenErr != nil {
if parseErr != nil { RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token signature: %v", tokenErr), http.StatusForbidden)
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing key: %v", parseErr), http.StatusBadGateway)
return
}
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
if !pubKeyOk {
RespondWithJSONError(w, errors.New("jwt: given key is not of type RSA public key"), http.StatusInternalServerError)
return
}
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
if jwtErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
return return
} }
if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil { if err := token.Verify(jwt.WithAcceptableSkew(0)); err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token: %v", err), http.StatusForbidden) RespondWithJSONError(w, fmt.Errorf("jwt: error verifying token claims: %v", err), http.StatusForbidden)
return return
} }
@ -86,11 +79,34 @@ func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Req
} }
} }
type keyResponse struct { func tryParse(key []byte, tokenValue string) (*jwt.Token, error) {
Key string `json:"key"` keyBytes, _ := pem.Decode([]byte(key))
if keyBytes == nil {
return nil, errors.New("jwt: no PEM block found in given key")
}
parseResult, parseErr := x509.ParsePKIXPublicKey(keyBytes.Bytes)
if parseErr != nil {
return nil, fmt.Errorf("jwt: error parsing key: %v", parseErr)
}
pubKey, pubKeyOk := parseResult.(*rsa.PublicKey)
if !pubKeyOk {
return nil, errors.New("jwt: given key is not of type RSA public key")
}
token, jwtErr := jwt.ParseVerify(strings.NewReader(tokenValue), jwa.RS256, pubKey)
if jwtErr != nil {
return nil, fmt.Errorf("jwt: error parsing token: %v", jwtErr)
}
return token, nil
} }
func fetchKey(keyURL string) ([]byte, error) { type keyResponse struct {
Keys []string `json:"keys"`
}
func fetchKeys(keyURL string) ([][]byte, error) {
fetchRes, fetchErr := http.Get(keyURL) fetchRes, fetchErr := http.Get(keyURL)
if fetchErr != nil { if fetchErr != nil {
return nil, fetchErr return nil, fetchErr
@ -100,5 +116,10 @@ func fetchKey(keyURL string) ([]byte, error) {
if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil { if err := json.NewDecoder(fetchRes.Body).Decode(&payload); err != nil {
return nil, err return nil, err
} }
return []byte(payload.Key), nil
asBytes := [][]byte{}
for _, key := range payload.Keys {
asBytes = append(asBytes, []byte(key))
}
return asBytes, nil
} }

View File

@ -108,10 +108,10 @@ func TestJWTProtect(t *testing.T) {
}, },
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"not really a key"}`)) w.Write([]byte(`{"keys":["not really a key","me neither"]}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusForbidden,
}, },
{ {
"invalid key", "invalid key",
@ -121,10 +121,10 @@ func TestJWTProtect(t *testing.T) {
}, },
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`)) w.Write([]byte(`{"keys":["-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"]}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusBadGateway, http.StatusForbidden,
}, },
{ {
"valid key, bad auth token", "valid key, bad auth token",
@ -135,7 +135,7 @@ func TestJWTProtect(t *testing.T) {
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -157,7 +157,7 @@ func TestJWTProtect(t *testing.T) {
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -178,7 +178,7 @@ func TestJWTProtect(t *testing.T) {
})(), })(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -194,7 +194,7 @@ func TestJWTProtect(t *testing.T) {
})(), })(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },
@ -217,7 +217,7 @@ func TestJWTProtect(t *testing.T) {
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { func(r *http.Request, claims map[string]interface{}) error {
@ -244,7 +244,7 @@ func TestJWTProtect(t *testing.T) {
nil, nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"keys":["%s"]}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil }, func(r *http.Request, claims map[string]interface{}) error { return nil },