mirror of
https://github.com/offen/website.git
synced 2024-12-23 13:30:20 +01:00
Merge pull request #59 from offen/flask-admin
Add auditorium users and db backed accounts
This commit is contained in:
commit
d044d95b5b
@ -87,7 +87,9 @@ jobs:
|
||||
echo Failed waiting for Postgres && exit 1
|
||||
- run:
|
||||
name: Run tests
|
||||
command: make test-ci
|
||||
command: |
|
||||
cp ~/offen/bootstrap.yml .
|
||||
make test-ci
|
||||
|
||||
shared:
|
||||
docker:
|
||||
@ -196,7 +198,7 @@ jobs:
|
||||
docker:
|
||||
- image: circleci/python:3.6
|
||||
environment:
|
||||
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
|
||||
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
|
||||
JWT_PRIVATE_KEY: |-
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK
|
||||
@ -236,7 +238,11 @@ jobs:
|
||||
a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3
|
||||
HQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
|
||||
- image: circleci/mysql:5.7
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=circle
|
||||
- MYSQL_DATABASE=circle
|
||||
- MYSQL_HOST=127.0.0.1
|
||||
working_directory: ~/offen/accounts
|
||||
steps:
|
||||
- checkout:
|
||||
@ -254,11 +260,22 @@ jobs:
|
||||
paths:
|
||||
- ~/offen/accounts/venv
|
||||
key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
|
||||
- run:
|
||||
name: Waiting for MySQL to be ready
|
||||
command: |
|
||||
for i in `seq 1 10`;
|
||||
do
|
||||
nc -z localhost 3306 && echo Success && exit 0
|
||||
echo -n .
|
||||
sleep 1
|
||||
done
|
||||
echo Failed waiting for MySQL && exit 1
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
. venv/bin/activate
|
||||
make
|
||||
cp ~/offen/bootstrap.yml .
|
||||
make test-ci
|
||||
|
||||
deploy_python:
|
||||
docker:
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -5,3 +5,5 @@ package-lock.json
|
||||
# mkcert certificates
|
||||
*.pem
|
||||
venv/
|
||||
|
||||
bootstrap-alpha.yml
|
||||
|
1
Makefile
1
Makefile
@ -20,5 +20,6 @@ setup:
|
||||
bootstrap:
|
||||
@docker-compose run kms make bootstrap
|
||||
@docker-compose run server make bootstrap
|
||||
@docker-compose run accounts make bootstrap
|
||||
|
||||
.PHONY: setup bootstrap
|
||||
|
@ -1,7 +1,13 @@
|
||||
test:
|
||||
@pytest --disable-pytest-warnings
|
||||
|
||||
test-ci: bootstrap
|
||||
@pytest --disable-pytest-warnings
|
||||
|
||||
fmt:
|
||||
@black .
|
||||
|
||||
.PHONY: test fmt
|
||||
bootstrap:
|
||||
@python -m scripts.bootstrap
|
||||
|
||||
.PHONY: test fmt bootstrap
|
||||
|
@ -1,5 +1,21 @@
|
||||
from os import environ
|
||||
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_admin import Admin
|
||||
|
||||
app = Flask(__name__)
|
||||
app.secret_key = environ.get("SESSION_SECRET")
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
|
||||
db = SQLAlchemy(app)
|
||||
|
||||
import accounts.views
|
||||
from accounts.models import Account, User
|
||||
from accounts.views import AccountView, UserView
|
||||
import accounts.api
|
||||
|
||||
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
|
||||
|
||||
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
|
||||
|
||||
admin.add_view(AccountView(Account, db.session))
|
||||
admin.add_view(UserView(User, db.session))
|
||||
|
124
accounts/accounts/api.py
Normal file
124
accounts/accounts/api.py
Normal file
@ -0,0 +1,124 @@
|
||||
from datetime import datetime, timedelta
|
||||
from os import environ
|
||||
from functools import wraps
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
from flask_cors import cross_origin
|
||||
from passlib.hash import bcrypt
|
||||
import jwt
|
||||
|
||||
from accounts import app
|
||||
from accounts.models import User
|
||||
|
||||
COOKIE_KEY = "auth"
|
||||
|
||||
|
||||
def json_error(handler):
|
||||
@wraps(handler)
|
||||
def wrapped_handler(*args, **kwargs):
|
||||
try:
|
||||
return handler(*args, **kwargs)
|
||||
except Exception as server_error:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "Internal server error: {}".format(str(server_error)),
|
||||
"status": 500,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
return wrapped_handler
|
||||
|
||||
|
||||
class UnauthorizedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@app.route("/api/login", methods=["POST"])
|
||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||
@json_error
|
||||
def post_login():
|
||||
credentials = request.get_json(force=True)
|
||||
try:
|
||||
match = User.query.filter_by(email=credentials["username"]).first()
|
||||
if not match:
|
||||
raise UnauthorizedError("bad username")
|
||||
if not bcrypt.verify(credentials["password"], match.hashed_password):
|
||||
raise UnauthorizedError("bad password")
|
||||
except UnauthorizedError as unauthorized_error:
|
||||
resp = make_response(jsonify({"error": str(unauthorized_error), "status": 401}))
|
||||
resp.set_cookie(COOKIE_KEY, "", expires=0)
|
||||
resp.status_code = 401
|
||||
return resp
|
||||
|
||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||
expiry = datetime.utcnow() + timedelta(hours=24)
|
||||
encoded = jwt.encode(
|
||||
{
|
||||
"exp": expiry,
|
||||
"priv": {
|
||||
"userId": match.user_id,
|
||||
"accounts": [a.account_id for a in match.accounts],
|
||||
},
|
||||
},
|
||||
private_key.encode(),
|
||||
algorithm="RS256",
|
||||
).decode()
|
||||
|
||||
resp = make_response(jsonify({"user": match.serialize()}))
|
||||
resp.set_cookie(
|
||||
COOKIE_KEY,
|
||||
encoded,
|
||||
httponly=True,
|
||||
expires=expiry,
|
||||
path="/",
|
||||
domain=environ.get("COOKIE_DOMAIN"),
|
||||
samesite="strict",
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@app.route("/api/login", methods=["GET"])
|
||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||
@json_error
|
||||
def get_login():
|
||||
auth_cookie = request.cookies.get(COOKIE_KEY)
|
||||
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
||||
try:
|
||||
token = jwt.decode(auth_cookie, public_key)
|
||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
||||
|
||||
try:
|
||||
match = User.query.get(token["priv"]["userId"])
|
||||
except KeyError as key_err:
|
||||
return (
|
||||
jsonify(
|
||||
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
|
||||
),
|
||||
401,
|
||||
)
|
||||
return jsonify({"user": match.serialize()})
|
||||
|
||||
|
||||
@app.route("/api/logout", methods=["POST"])
|
||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||
@json_error
|
||||
def post_logout():
|
||||
resp = make_response("")
|
||||
resp.set_cookie(COOKIE_KEY, "", expires=0)
|
||||
resp.status_code = 204
|
||||
return resp
|
||||
|
||||
|
||||
@app.route("/api/key", methods=["GET"])
|
||||
@json_error
|
||||
def key():
|
||||
"""
|
||||
This route is not supposed to be called by client-side applications, so
|
||||
no CORS configuration is added
|
||||
"""
|
||||
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
|
||||
return jsonify({"key": public_key})
|
49
accounts/accounts/models.py
Normal file
49
accounts/accounts/models.py
Normal file
@ -0,0 +1,49 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from accounts import db
|
||||
|
||||
|
||||
def generate_key():
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
class Account(db.Model):
|
||||
__tablename__ = "accounts"
|
||||
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||
name = db.Column(db.Text, nullable=False)
|
||||
users = db.relationship("AccountUserAssociation", back_populates="account")
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
__tablename__ = "users"
|
||||
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||
email = db.Column(db.String(128), nullable=False, unique=True)
|
||||
hashed_password = db.Column(db.Text, nullable=False)
|
||||
accounts = db.relationship(
|
||||
"AccountUserAssociation", back_populates="user", lazy="joined"
|
||||
)
|
||||
|
||||
def serialize(self):
|
||||
associated_accounts = [a.account_id for a in self.accounts]
|
||||
records = [
|
||||
{"name": a.name, "accountId": a.account_id}
|
||||
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
|
||||
]
|
||||
return {"userId": self.user_id, "email": self.email, "accounts": records}
|
||||
|
||||
|
||||
class AccountUserAssociation(db.Model):
|
||||
__tablename__ = "account_to_user"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.user_id"), nullable=False)
|
||||
account_id = db.Column(
|
||||
db.String(36), db.ForeignKey("accounts.account_id"), nullable=False
|
||||
)
|
||||
|
||||
user = db.relationship("User", back_populates="accounts")
|
||||
account = db.relationship("Account", back_populates="users")
|
@ -1,69 +1,104 @@
|
||||
from datetime import datetime, timedelta
|
||||
from os import environ
|
||||
import base64
|
||||
|
||||
from flask import jsonify, render_template, make_response, request
|
||||
from flask_cors import cross_origin
|
||||
import requests
|
||||
from flask_admin.contrib.sqla import ModelView
|
||||
from wtforms import PasswordField, StringField, Form
|
||||
from wtforms.validators import InputRequired, EqualTo
|
||||
from passlib.hash import bcrypt
|
||||
import jwt
|
||||
|
||||
from accounts import app
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
return render_template("index.html")
|
||||
from accounts import db
|
||||
from accounts.models import AccountUserAssociation
|
||||
|
||||
|
||||
@app.route("/api/login", methods=["POST"])
|
||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||
def post_login():
|
||||
credentials = request.get_json(force=True)
|
||||
class RemoteServerException(Exception):
|
||||
status = 0
|
||||
|
||||
if credentials["username"] != environ.get("USER", "offen"):
|
||||
return jsonify({"error": "bad username", "status": 401}), 401
|
||||
def __str__(self):
|
||||
return "Status {}: {}".format(
|
||||
self.status, super(RemoteServerException, self).__str__()
|
||||
)
|
||||
|
||||
hashed_password = base64.standard_b64decode(environ.get("HASHED_PASSWORD", ""))
|
||||
if not bcrypt.verify(credentials["password"], hashed_password):
|
||||
return jsonify({"error": "bad password", "status": 401}), 401
|
||||
|
||||
def create_remote_account(name, account_id):
|
||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||
expiry = datetime.utcnow() + timedelta(hours=24)
|
||||
try:
|
||||
encoded = jwt.encode(
|
||||
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
|
||||
).decode("utf-8")
|
||||
except jwt.exceptions.PyJWTError as encode_error:
|
||||
return jsonify({"error": str(encode_error), "status": 500}), 500
|
||||
# expires in 30 seconds as this will mean the HTTP request would have
|
||||
# timed out anyways
|
||||
expiry = datetime.utcnow() + timedelta(seconds=30)
|
||||
encoded = jwt.encode(
|
||||
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
||||
private_key.encode(),
|
||||
algorithm="RS256",
|
||||
).decode("utf-8")
|
||||
|
||||
resp = make_response(jsonify({"ok": True}))
|
||||
resp.set_cookie(
|
||||
"auth",
|
||||
encoded,
|
||||
httponly=True,
|
||||
expires=expiry,
|
||||
path="/",
|
||||
domain=environ.get("COOKIE_DOMAIN"),
|
||||
samesite="strict"
|
||||
r = requests.post(
|
||||
"{}/accounts".format(environ.get("SERVER_HOST")),
|
||||
json={"name": name, "accountId": account_id},
|
||||
headers={"X-RPC-Authentication": encoded},
|
||||
)
|
||||
return resp
|
||||
|
||||
if r.status_code > 299:
|
||||
err = r.json()
|
||||
remote_err = RemoteServerException(err["error"])
|
||||
remote_err.status = err["status"]
|
||||
raise remote_err
|
||||
|
||||
|
||||
@app.route("/api/login", methods=["GET"])
|
||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||
def get_login():
|
||||
auth_cookie = request.cookies.get("auth")
|
||||
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
||||
try:
|
||||
jwt.decode(auth_cookie, public_key)
|
||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
||||
|
||||
return jsonify({"ok": True})
|
||||
class AccountForm(Form):
|
||||
name = StringField(
|
||||
"Account Name",
|
||||
validators=[InputRequired()],
|
||||
description="This is the account name visible to users",
|
||||
)
|
||||
|
||||
|
||||
# This route is not supposed to be called by client-side applications, so
|
||||
# no CORS configuration is added
|
||||
@app.route("/api/key", methods=["GET"])
|
||||
def key():
|
||||
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
|
||||
return jsonify({"key": public_key})
|
||||
class AccountView(ModelView):
|
||||
form = AccountForm
|
||||
column_display_all_relations = True
|
||||
column_list = ("name", "account_id")
|
||||
|
||||
def after_model_change(self, form, model, is_created):
|
||||
if is_created:
|
||||
try:
|
||||
create_remote_account(model.name, model.account_id)
|
||||
except RemoteServerException as server_error:
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
raise server_error
|
||||
|
||||
|
||||
class UserView(ModelView):
|
||||
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
|
||||
column_auto_select_related = True
|
||||
column_display_all_relations = True
|
||||
column_list = ("email", "user_id")
|
||||
form_columns = ("email", "accounts")
|
||||
form_create_rules = ("email", "password", "confirm", "accounts")
|
||||
form_edit_rules = ("email", "password", "confirm", "accounts")
|
||||
|
||||
def on_model_change(self, form, model, is_created):
|
||||
if form.password.data:
|
||||
model.hashed_password = bcrypt.hash(form.password.data)
|
||||
|
||||
def get_create_form(self):
|
||||
form = super(UserView, self).get_create_form()
|
||||
form.password = PasswordField(
|
||||
"Password",
|
||||
validators=[
|
||||
InputRequired(),
|
||||
EqualTo("confirm", message="Passwords must match"),
|
||||
],
|
||||
)
|
||||
form.confirm = PasswordField("Repeat Password", validators=[InputRequired()])
|
||||
return form
|
||||
|
||||
def get_edit_form(self):
|
||||
form = super(UserView, self).get_edit_form()
|
||||
form.password = PasswordField(
|
||||
"Password",
|
||||
description="When left blank, the password will remain unchanged on update",
|
||||
validators=[EqualTo("confirm", message="Passwords must match")],
|
||||
)
|
||||
form.confirm = PasswordField("Repeat Password", validators=[])
|
||||
return form
|
||||
|
60
accounts/authorizer/__init__.py
Normal file
60
accounts/authorizer/__init__.py
Normal file
@ -0,0 +1,60 @@
|
||||
import base64
|
||||
from os import environ
|
||||
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
|
||||
def build_api_arn(method_arn):
|
||||
arn_chunks = method_arn.split(":")
|
||||
aws_region = arn_chunks[3]
|
||||
aws_account_id = arn_chunks[4]
|
||||
|
||||
gateway_arn_chunks = arn_chunks[5].split("/")
|
||||
rest_api_id = gateway_arn_chunks[0]
|
||||
stage = gateway_arn_chunks[1]
|
||||
|
||||
return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(
|
||||
aws_region, aws_account_id, rest_api_id, stage
|
||||
)
|
||||
|
||||
|
||||
def build_response(api_arn, allow):
|
||||
effect = "Deny"
|
||||
if allow:
|
||||
effect = "Allow"
|
||||
|
||||
return {
|
||||
"principalId": "offen",
|
||||
"policyDocument": {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Action": ["execute-api:Invoke"],
|
||||
"Effect": effect,
|
||||
"Resource": [api_arn],
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def handler(event, context):
|
||||
api_arn = build_api_arn(event["methodArn"])
|
||||
|
||||
encoded_auth = event["authorizationToken"].lstrip("Basic ")
|
||||
auth_string = base64.standard_b64decode(encoded_auth).decode()
|
||||
if not auth_string:
|
||||
return build_response(api_arn, False)
|
||||
|
||||
credentials = auth_string.split(":")
|
||||
user = credentials[0]
|
||||
password = credentials[1]
|
||||
|
||||
if user != environ.get("BASIC_AUTH_USER"):
|
||||
return build_response(api_arn, False)
|
||||
|
||||
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
|
||||
if not bcrypt.verify(password, hashed_password):
|
||||
return build_response(api_arn, False)
|
||||
|
||||
return build_response(api_arn, True)
|
0
accounts/bootstrap.yml
Executable file
0
accounts/bootstrap.yml
Executable file
@ -1,2 +1,3 @@
|
||||
pytest
|
||||
black
|
||||
pyyaml
|
||||
|
@ -1,6 +1,11 @@
|
||||
Flask==1.0.2
|
||||
Flask-Admin==1.5.3
|
||||
Flask-Cors==3.0.8
|
||||
Flask-SQLAlchemy==2.4.0
|
||||
werkzeug==0.15.4
|
||||
pyjwt[crypto]==1.7.1
|
||||
passlib==1.7.1
|
||||
bcrypt==3.1.7
|
||||
PyMySQL==0.9.3
|
||||
mysqlclient==1.4.2.post1
|
||||
requests==2.22.0
|
||||
|
0
accounts/scripts/__init__.py
Normal file
0
accounts/scripts/__init__.py
Normal file
32
accounts/scripts/bootstrap.py
Executable file
32
accounts/scripts/bootstrap.py
Executable file
@ -0,0 +1,32 @@
|
||||
import yaml
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
from accounts import db
|
||||
from accounts.models import Account, User, AccountUserAssociation
|
||||
|
||||
if __name__ == "__main__":
|
||||
db.drop_all()
|
||||
db.create_all()
|
||||
|
||||
with open("./bootstrap.yml", "r") as stream:
|
||||
data = yaml.safe_load(stream)
|
||||
|
||||
for account in data["accounts"]:
|
||||
record = Account(
|
||||
name=account["name"],
|
||||
account_id=account["id"],
|
||||
)
|
||||
db.session.add(record)
|
||||
|
||||
for user in data["users"]:
|
||||
record = User(
|
||||
email=user["email"],
|
||||
hashed_password=bcrypt.hash(user["password"]),
|
||||
)
|
||||
for account_id in user["accounts"]:
|
||||
record.accounts.append(AccountUserAssociation(account_id=account_id))
|
||||
db.session.add(record)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
print("Successfully bootstrapped accounts database")
|
20
accounts/scripts/hash.py
Normal file
20
accounts/scripts/hash.py
Normal file
@ -0,0 +1,20 @@
|
||||
import base64
|
||||
import argparse
|
||||
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--password", type=str, help="The password to hash", required=True)
|
||||
parser.add_argument(
|
||||
"--plain",
|
||||
help="Do not encode the result as base64",
|
||||
default=False,
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
out = bcrypt.hash(args.password)
|
||||
if not args.plain:
|
||||
out = base64.standard_b64encode(out.encode()).decode()
|
||||
print(out)
|
@ -26,6 +26,10 @@ custom:
|
||||
production: vault.offen.dev
|
||||
staging: vault-staging.offen.dev
|
||||
alpha: vault-alpha.offen.dev
|
||||
serverHost:
|
||||
production: server.offen.dev
|
||||
staging: server-staging.offen.dev
|
||||
alpha: server-alpha.offen.dev
|
||||
domain:
|
||||
production: accounts.offen.dev
|
||||
staging: accounts-staging.offen.dev
|
||||
@ -50,19 +54,53 @@ custom:
|
||||
fileName: requirements.txt
|
||||
|
||||
functions:
|
||||
authorizer:
|
||||
handler: authorizer.handler
|
||||
environment:
|
||||
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
||||
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
||||
app:
|
||||
handler: wsgi_handler.handler
|
||||
events:
|
||||
- http:
|
||||
path: /admin/
|
||||
method: any
|
||||
authorizer:
|
||||
name: authorizer
|
||||
resultTtlInSeconds: 0
|
||||
identitySource: method.request.header.Authorization
|
||||
- http:
|
||||
path: /admin/{proxy+}
|
||||
method: any
|
||||
authorizer:
|
||||
name: authorizer
|
||||
resultTtlInSeconds: 0
|
||||
identitySource: method.request.header.Authorization
|
||||
- http:
|
||||
path: '/'
|
||||
method: any
|
||||
- http:
|
||||
path: '{proxy+}'
|
||||
path: '/{proxy+}'
|
||||
method: any
|
||||
environment:
|
||||
USER: offen
|
||||
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
||||
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
||||
SERVER_URL: ${self:custom.serverHost.${self:custom.stage}}
|
||||
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
|
||||
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
|
||||
HASHED_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
||||
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
||||
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
||||
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
|
||||
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
|
||||
|
||||
resources:
|
||||
Resources:
|
||||
GatewayResponse:
|
||||
Type: 'AWS::ApiGateway::GatewayResponse'
|
||||
Properties:
|
||||
ResponseParameters:
|
||||
gatewayresponse.header.WWW-Authenticate: "'Basic'"
|
||||
ResponseType: UNAUTHORIZED
|
||||
RestApiId:
|
||||
Ref: 'ApiGatewayRestApi'
|
||||
StatusCode: '401'
|
||||
|
207
accounts/tests/test_api.py
Normal file
207
accounts/tests/test_api.py
Normal file
@ -0,0 +1,207 @@
|
||||
import unittest
|
||||
import json
|
||||
import base64
|
||||
from json import loads
|
||||
from time import time
|
||||
from datetime import datetime, timedelta
|
||||
from os import environ
|
||||
|
||||
import jwt
|
||||
|
||||
from accounts import app
|
||||
|
||||
|
||||
FOREIGN_PRIVATE_KEY = """
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
|
||||
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
|
||||
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
|
||||
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
|
||||
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
|
||||
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
|
||||
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
|
||||
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
|
||||
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
|
||||
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
|
||||
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
|
||||
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
|
||||
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
|
||||
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
|
||||
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
|
||||
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
|
||||
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
|
||||
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
|
||||
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
|
||||
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
|
||||
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
|
||||
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
|
||||
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
|
||||
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
|
||||
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
|
||||
15NFZ2gU4rKnQ3sVAOzMoEw5
|
||||
-----END PRIVATE KEY-----
|
||||
"""
|
||||
|
||||
def _pad_b64_string(s):
|
||||
while len(s) % 4 is not 0:
|
||||
s = s + "="
|
||||
return s
|
||||
|
||||
|
||||
class TestKey(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.app = app.test_client()
|
||||
|
||||
def test_get_key(self):
|
||||
rv = self.app.get("/api/key")
|
||||
assert rv.status.startswith("200")
|
||||
data = loads(rv.data)
|
||||
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
|
||||
|
||||
|
||||
class TestJWT(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.app = app.test_client()
|
||||
|
||||
def _assert_cookie_present(self, name):
|
||||
for cookie in self.app.cookie_jar:
|
||||
if cookie.name == name:
|
||||
return cookie.value
|
||||
raise AssertionError("Cookie named {} not found".format(name))
|
||||
|
||||
def _assert_cookie_not_present(self, name):
|
||||
for cookie in self.app.cookie_jar:
|
||||
assert cookie.name != name
|
||||
|
||||
def test_jwt_flow(self):
|
||||
"""
|
||||
First, try login attempts that are supposed to fail:
|
||||
1. checking login status without any prior interaction
|
||||
2. try logging in with an unknown user
|
||||
3. try logging in with a known user and bad password
|
||||
"""
|
||||
rv = self.app.get("/api/login")
|
||||
assert rv.status.startswith("401")
|
||||
self._assert_cookie_not_present("auth")
|
||||
|
||||
rv = self.app.post(
|
||||
"/api/login",
|
||||
data=json.dumps(
|
||||
{"username": "does@not.exist", "password": "somethingsomething"}
|
||||
),
|
||||
)
|
||||
assert rv.status.startswith("401")
|
||||
self._assert_cookie_not_present("auth")
|
||||
|
||||
rv = self.app.post(
|
||||
"/api/login",
|
||||
data=json.dumps({"username": "develop@offen.dev", "password": "developp"}),
|
||||
)
|
||||
assert rv.status.startswith("401")
|
||||
self._assert_cookie_not_present("auth")
|
||||
|
||||
"""
|
||||
Next, perform a successful login
|
||||
"""
|
||||
rv = self.app.post(
|
||||
"/api/login",
|
||||
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
|
||||
)
|
||||
assert rv.status.startswith("200")
|
||||
|
||||
"""
|
||||
The response should contain information about the
|
||||
user and full information (i.e. a name) about the associated accounts
|
||||
"""
|
||||
data = json.loads(rv.data)
|
||||
assert data["user"]["userId"] is not None
|
||||
data["user"]["accounts"].sort(key=lambda a: a["name"])
|
||||
self.assertListEqual(
|
||||
data["user"]["accounts"],
|
||||
[
|
||||
{"name": "One", "accountId": "9b63c4d8-65c0-438c-9d30-cc4b01173393"},
|
||||
{"name": "Two", "accountId": "78403940-ae4f-4aff-a395-1e90f145cf62"},
|
||||
],
|
||||
)
|
||||
|
||||
"""
|
||||
The claims part of the JWT is expected to contain a valid expiry,
|
||||
information about the user and the associated account ids.
|
||||
"""
|
||||
jwt = self._assert_cookie_present("auth")
|
||||
# PyJWT strips the padding from the base64 encoded parts which Python
|
||||
# cannot decode properly, so we need to add the padding ourselves
|
||||
claims_part = _pad_b64_string(jwt.split(".")[1])
|
||||
claims = loads(base64.b64decode(claims_part))
|
||||
assert claims.get("exp") > time()
|
||||
|
||||
priv = claims.get("priv")
|
||||
assert priv is not None
|
||||
|
||||
assert priv.get("userId") is not None
|
||||
self.assertListEqual(
|
||||
priv["accounts"],
|
||||
[
|
||||
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
|
||||
"78403940-ae4f-4aff-a395-1e90f145cf62",
|
||||
],
|
||||
)
|
||||
|
||||
"""
|
||||
Checking the login status when re-using the cookie should yield
|
||||
a successful response
|
||||
"""
|
||||
rv = self.app.get("/api/login")
|
||||
assert rv.status.startswith("200")
|
||||
jwt2 = self._assert_cookie_present("auth")
|
||||
assert jwt2 == jwt
|
||||
|
||||
"""
|
||||
Performing a bad login attempt when sending a valid auth cookie
|
||||
is expected to destroy the cookie and leave the user logged out again
|
||||
"""
|
||||
rv = self.app.post(
|
||||
"/api/login",
|
||||
data=json.dumps(
|
||||
{"username": "evil@session.takeover", "password": "develop"}
|
||||
),
|
||||
)
|
||||
assert rv.status.startswith("401")
|
||||
self._assert_cookie_not_present("auth")
|
||||
|
||||
"""
|
||||
Explicitly logging out leaves the user without cookies
|
||||
"""
|
||||
rv = self.app.post(
|
||||
"/api/login",
|
||||
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
|
||||
)
|
||||
assert rv.status.startswith("200")
|
||||
|
||||
rv = self.app.post("/api/logout")
|
||||
assert rv.status.startswith("204")
|
||||
self._assert_cookie_not_present("auth")
|
||||
|
||||
def test_forged_token(self):
|
||||
"""
|
||||
The application needs to verify that tokens that would be theoretically
|
||||
valid are not signed using an unknown key.
|
||||
"""
|
||||
forged_token = jwt.encode(
|
||||
{
|
||||
"exp": datetime.utcnow() + timedelta(hours=24),
|
||||
"priv": {
|
||||
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
|
||||
"accounts": [
|
||||
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
|
||||
"78403940-ae4f-4aff-a395-1e90f145cf62",
|
||||
],
|
||||
},
|
||||
},
|
||||
FOREIGN_PRIVATE_KEY,
|
||||
algorithm="RS256",
|
||||
).decode()
|
||||
|
||||
self.app.set_cookie("localhost", "auth", forged_token)
|
||||
rv = self.app.get("/api/login")
|
||||
assert rv.status.startswith("401")
|
@ -1,21 +0,0 @@
|
||||
import unittest
|
||||
import json
|
||||
|
||||
from accounts import app
|
||||
|
||||
|
||||
class TestJWT(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.app = app.test_client()
|
||||
|
||||
def test_jwt_flow(self):
|
||||
rv = self.app.get("/api/login")
|
||||
assert rv.status.startswith("401")
|
||||
|
||||
rv = self.app.post(
|
||||
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
|
||||
)
|
||||
assert rv.status.startswith("200")
|
||||
|
||||
rv = self.app.get("/api/login")
|
||||
assert rv.status.startswith("200")
|
@ -24,6 +24,14 @@ services:
|
||||
environment:
|
||||
POSTGRES_PASSWORD: develop
|
||||
|
||||
accounts_database:
|
||||
image: mysql:5.7
|
||||
ports:
|
||||
- "3306:3306"
|
||||
environment:
|
||||
MYSQL_DATABASE: mysql
|
||||
MYSQL_ROOT_PASSWORD: develop
|
||||
|
||||
server:
|
||||
build:
|
||||
context: '.'
|
||||
@ -31,6 +39,7 @@ services:
|
||||
working_dir: /offen/server
|
||||
volumes:
|
||||
- .:/offen
|
||||
- ./bootstrap.yml:/offen/server/bootstrap.yml
|
||||
- serverdeps:/go/pkg/mod
|
||||
environment:
|
||||
POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable
|
||||
@ -100,16 +109,20 @@ services:
|
||||
working_dir: /offen/accounts
|
||||
volumes:
|
||||
- .:/offen
|
||||
- ./bootstrap.yml:/offen/accounts/bootstrap.yml
|
||||
- accountdeps:/root/.local
|
||||
command: flask run --host 0.0.0.0
|
||||
ports:
|
||||
- 5000:5000
|
||||
links:
|
||||
- accounts_database
|
||||
environment:
|
||||
FLASK_APP: accounts
|
||||
FLASK_APP: accounts:app
|
||||
FLASK_ENV: development
|
||||
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
|
||||
CORS_ORIGIN: http://localhost:9977
|
||||
# local password is `develop`
|
||||
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
|
||||
SERVER_HOST: http://server:8080
|
||||
SESSION_SECRET: vndJRFJTiyjfgtTF
|
||||
JWT_PRIVATE_KEY: |-
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS
|
||||
|
@ -23,12 +23,19 @@ const ClaimsContextKey contextKey = "claims"
|
||||
// JWTProtect uses the public key located at the given URL to check if the
|
||||
// cookie value is signed properly. In case yes, the JWT claims will be added
|
||||
// to the request context
|
||||
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
||||
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authCookie, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden)
|
||||
var jwtValue string
|
||||
if authCookie, err := r.Cookie(cookieName); err == nil {
|
||||
jwtValue = authCookie.Value
|
||||
} else {
|
||||
if header := r.Header.Get(headerName); header != "" {
|
||||
jwtValue = header
|
||||
}
|
||||
}
|
||||
if jwtValue == "" {
|
||||
RespondWithJSONError(w, errors.New("jwt: could not infer JWT value from cookie or header"), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@ -56,7 +63,7 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey)
|
||||
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
|
||||
if jwtErr != nil {
|
||||
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
|
||||
return
|
||||
@ -67,9 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
privateClaims, _ := token.Get("priv")
|
||||
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
|
||||
|
||||
privKey, _ := token.Get("priv")
|
||||
claims, _ := privKey.(map[string]interface{})
|
||||
if err := authorizer(r, claims); err != nil {
|
||||
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cookie *http.Cookie
|
||||
headers *http.Header
|
||||
server *httptest.Server
|
||||
authorizer func(r *http.Request, claims map[string]interface{}) error
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
"no cookie",
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
|
||||
Value: "irrelevantgibberish",
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
|
||||
Name: "auth",
|
||||
Value: "irrelevantgibberish",
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("here's some bytes 4 y'all"))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
|
||||
Name: "auth",
|
||||
Value: "irrelevantgibberish",
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"key":"not really a key"}`))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
|
||||
Name: "auth",
|
||||
Value: "irrelevantgibberish",
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
|
||||
Name: "auth",
|
||||
Value: "irrelevantgibberish",
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
|
||||
return string(b)
|
||||
})(),
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"ok token in headers",
|
||||
nil,
|
||||
(func() *http.Header {
|
||||
token := jwt.New()
|
||||
token.Set("exp", time.Now().Add(time.Hour))
|
||||
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||
b, _ := token.Sign(jwa.RS256, privKey)
|
||||
return &http.Header{
|
||||
"X-RPC-Authentication": []string{string(b)},
|
||||
}
|
||||
})(),
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"bad token in headers",
|
||||
nil,
|
||||
(func() *http.Header {
|
||||
return &http.Header{
|
||||
"X-RPC-Authentication": []string{"nilly willy"},
|
||||
}
|
||||
})(),
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"authorizer rejects",
|
||||
&http.Cookie{
|
||||
Name: "auth",
|
||||
Value: (func() string {
|
||||
token := jwt.New()
|
||||
token.Set("exp", time.Now().Add(time.Hour))
|
||||
token.Set("priv", map[string]interface{}{"ok": false})
|
||||
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||
b, _ := token.Sign(jwa.RS256, privKey)
|
||||
return string(b)
|
||||
})(),
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error {
|
||||
if claims["ok"] == true {
|
||||
return nil
|
||||
}
|
||||
return errors.New("expected ok to be true")
|
||||
},
|
||||
http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
"valid key, expired token",
|
||||
&http.Cookie{
|
||||
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
|
||||
return string(b)
|
||||
})(),
|
||||
},
|
||||
nil,
|
||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||
))
|
||||
})),
|
||||
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||
http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
@ -174,7 +258,7 @@ func TestJWTProtect(t *testing.T) {
|
||||
if test.server != nil {
|
||||
url = test.server.URL
|
||||
}
|
||||
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
w := httptest.NewRecorder()
|
||||
@ -182,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
|
||||
if test.cookie != nil {
|
||||
r.AddCookie(test.cookie)
|
||||
}
|
||||
if test.headers != nil {
|
||||
for key, value := range *test.headers {
|
||||
r.Header.Add(key, value[0])
|
||||
}
|
||||
}
|
||||
wrappedHandler.ServeHTTP(w, r)
|
||||
if w.Code != test.expectedStatusCode {
|
||||
t.Errorf("Unexpected status code %v", w.Code)
|
||||
|
Loading…
Reference in New Issue
Block a user