2
0
mirror of https://github.com/offen/website.git synced 2024-11-22 17:10:29 +01:00

Merge pull request #59 from offen/flask-admin

Add auditorium users and db backed accounts
This commit is contained in:
Frederik Ring 2019-07-17 10:14:28 +02:00 committed by GitHub
commit d044d95b5b
21 changed files with 798 additions and 93 deletions

View File

@ -87,7 +87,9 @@ jobs:
echo Failed waiting for Postgres && exit 1 echo Failed waiting for Postgres && exit 1
- run: - run:
name: Run tests name: Run tests
command: make test-ci command: |
cp ~/offen/bootstrap.yml .
make test-ci
shared: shared:
docker: docker:
@ -196,7 +198,7 @@ jobs:
docker: docker:
- image: circleci/python:3.6 - image: circleci/python:3.6
environment: environment:
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
JWT_PRIVATE_KEY: |- JWT_PRIVATE_KEY: |-
-----BEGIN PRIVATE KEY----- -----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK
@ -236,7 +238,11 @@ jobs:
a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3 a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3
HQIDAQAB HQIDAQAB
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
- image: circleci/mysql:5.7
environment:
- MYSQL_ROOT_PASSWORD=circle
- MYSQL_DATABASE=circle
- MYSQL_HOST=127.0.0.1
working_directory: ~/offen/accounts working_directory: ~/offen/accounts
steps: steps:
- checkout: - checkout:
@ -254,11 +260,22 @@ jobs:
paths: paths:
- ~/offen/accounts/venv - ~/offen/accounts/venv
key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
- run:
name: Waiting for MySQL to be ready
command: |
for i in `seq 1 10`;
do
nc -z localhost 3306 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run: - run:
name: Run tests name: Run tests
command: | command: |
. venv/bin/activate . venv/bin/activate
make cp ~/offen/bootstrap.yml .
make test-ci
deploy_python: deploy_python:
docker: docker:

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ package-lock.json
# mkcert certificates # mkcert certificates
*.pem *.pem
venv/ venv/
bootstrap-alpha.yml

View File

@ -20,5 +20,6 @@ setup:
bootstrap: bootstrap:
@docker-compose run kms make bootstrap @docker-compose run kms make bootstrap
@docker-compose run server make bootstrap @docker-compose run server make bootstrap
@docker-compose run accounts make bootstrap
.PHONY: setup bootstrap .PHONY: setup bootstrap

View File

@ -1,7 +1,13 @@
test: test:
@pytest --disable-pytest-warnings @pytest --disable-pytest-warnings
test-ci: bootstrap
@pytest --disable-pytest-warnings
fmt: fmt:
@black . @black .
.PHONY: test fmt bootstrap:
@python -m scripts.bootstrap
.PHONY: test fmt bootstrap

View File

@ -1,5 +1,21 @@
from os import environ
from flask import Flask from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_admin import Admin
app = Flask(__name__) app = Flask(__name__)
app.secret_key = environ.get("SESSION_SECRET")
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
db = SQLAlchemy(app)
import accounts.views from accounts.models import Account, User
from accounts.views import AccountView, UserView
import accounts.api
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
admin.add_view(AccountView(Account, db.session))
admin.add_view(UserView(User, db.session))

124
accounts/accounts/api.py Normal file
View File

@ -0,0 +1,124 @@
from datetime import datetime, timedelta
from os import environ
from functools import wraps
from flask import jsonify, make_response, request
from flask_cors import cross_origin
from passlib.hash import bcrypt
import jwt
from accounts import app
from accounts.models import User
COOKIE_KEY = "auth"
def json_error(handler):
@wraps(handler)
def wrapped_handler(*args, **kwargs):
try:
return handler(*args, **kwargs)
except Exception as server_error:
return (
jsonify(
{
"error": "Internal server error: {}".format(str(server_error)),
"status": 500,
}
),
500,
)
return wrapped_handler
class UnauthorizedError(Exception):
pass
@app.route("/api/login", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def post_login():
credentials = request.get_json(force=True)
try:
match = User.query.filter_by(email=credentials["username"]).first()
if not match:
raise UnauthorizedError("bad username")
if not bcrypt.verify(credentials["password"], match.hashed_password):
raise UnauthorizedError("bad password")
except UnauthorizedError as unauthorized_error:
resp = make_response(jsonify({"error": str(unauthorized_error), "status": 401}))
resp.set_cookie(COOKIE_KEY, "", expires=0)
resp.status_code = 401
return resp
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
encoded = jwt.encode(
{
"exp": expiry,
"priv": {
"userId": match.user_id,
"accounts": [a.account_id for a in match.accounts],
},
},
private_key.encode(),
algorithm="RS256",
).decode()
resp = make_response(jsonify({"user": match.serialize()}))
resp.set_cookie(
COOKIE_KEY,
encoded,
httponly=True,
expires=expiry,
path="/",
domain=environ.get("COOKIE_DOMAIN"),
samesite="strict",
)
return resp
@app.route("/api/login", methods=["GET"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def get_login():
auth_cookie = request.cookies.get(COOKIE_KEY)
public_key = environ.get("JWT_PUBLIC_KEY", "")
try:
token = jwt.decode(auth_cookie, public_key)
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
try:
match = User.query.get(token["priv"]["userId"])
except KeyError as key_err:
return (
jsonify(
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
),
401,
)
return jsonify({"user": match.serialize()})
@app.route("/api/logout", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def post_logout():
resp = make_response("")
resp.set_cookie(COOKIE_KEY, "", expires=0)
resp.status_code = 204
return resp
@app.route("/api/key", methods=["GET"])
@json_error
def key():
"""
This route is not supposed to be called by client-side applications, so
no CORS configuration is added
"""
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
return jsonify({"key": public_key})

View File

@ -0,0 +1,49 @@
from uuid import uuid4
from accounts import db
def generate_key():
return str(uuid4())
class Account(db.Model):
__tablename__ = "accounts"
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
name = db.Column(db.Text, nullable=False)
users = db.relationship("AccountUserAssociation", back_populates="account")
def __repr__(self):
return self.name
class User(db.Model):
__tablename__ = "users"
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
email = db.Column(db.String(128), nullable=False, unique=True)
hashed_password = db.Column(db.Text, nullable=False)
accounts = db.relationship(
"AccountUserAssociation", back_populates="user", lazy="joined"
)
def serialize(self):
associated_accounts = [a.account_id for a in self.accounts]
records = [
{"name": a.name, "accountId": a.account_id}
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
]
return {"userId": self.user_id, "email": self.email, "accounts": records}
class AccountUserAssociation(db.Model):
__tablename__ = "account_to_user"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.String(36), db.ForeignKey("users.user_id"), nullable=False)
account_id = db.Column(
db.String(36), db.ForeignKey("accounts.account_id"), nullable=False
)
user = db.relationship("User", back_populates="accounts")
account = db.relationship("Account", back_populates="users")

View File

@ -1,69 +1,104 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from os import environ from os import environ
import base64
from flask import jsonify, render_template, make_response, request import requests
from flask_cors import cross_origin from flask_admin.contrib.sqla import ModelView
from wtforms import PasswordField, StringField, Form
from wtforms.validators import InputRequired, EqualTo
from passlib.hash import bcrypt from passlib.hash import bcrypt
import jwt import jwt
from accounts import app from accounts import db
from accounts.models import AccountUserAssociation
@app.route("/")
def home():
return render_template("index.html")
@app.route("/api/login", methods=["POST"]) class RemoteServerException(Exception):
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) status = 0
def post_login():
credentials = request.get_json(force=True)
if credentials["username"] != environ.get("USER", "offen"): def __str__(self):
return jsonify({"error": "bad username", "status": 401}), 401 return "Status {}: {}".format(
self.status, super(RemoteServerException, self).__str__()
hashed_password = base64.standard_b64decode(environ.get("HASHED_PASSWORD", ""))
if not bcrypt.verify(credentials["password"], hashed_password):
return jsonify({"error": "bad password", "status": 401}), 401
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
try:
encoded = jwt.encode(
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
).decode("utf-8")
except jwt.exceptions.PyJWTError as encode_error:
return jsonify({"error": str(encode_error), "status": 500}), 500
resp = make_response(jsonify({"ok": True}))
resp.set_cookie(
"auth",
encoded,
httponly=True,
expires=expiry,
path="/",
domain=environ.get("COOKIE_DOMAIN"),
samesite="strict"
) )
return resp
@app.route("/api/login", methods=["GET"]) def create_remote_account(name, account_id):
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) private_key = environ.get("JWT_PRIVATE_KEY", "")
def get_login(): # expires in 30 seconds as this will mean the HTTP request would have
auth_cookie = request.cookies.get("auth") # timed out anyways
public_key = environ.get("JWT_PUBLIC_KEY", "") expiry = datetime.utcnow() + timedelta(seconds=30)
encoded = jwt.encode(
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
private_key.encode(),
algorithm="RS256",
).decode("utf-8")
r = requests.post(
"{}/accounts".format(environ.get("SERVER_HOST")),
json={"name": name, "accountId": account_id},
headers={"X-RPC-Authentication": encoded},
)
if r.status_code > 299:
err = r.json()
remote_err = RemoteServerException(err["error"])
remote_err.status = err["status"]
raise remote_err
class AccountForm(Form):
name = StringField(
"Account Name",
validators=[InputRequired()],
description="This is the account name visible to users",
)
class AccountView(ModelView):
form = AccountForm
column_display_all_relations = True
column_list = ("name", "account_id")
def after_model_change(self, form, model, is_created):
if is_created:
try: try:
jwt.decode(auth_cookie, public_key) create_remote_account(model.name, model.account_id)
except jwt.exceptions.PyJWTError as unauthorized_error: except RemoteServerException as server_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401 db.session.delete(model)
db.session.commit()
return jsonify({"ok": True}) raise server_error
# This route is not supposed to be called by client-side applications, so class UserView(ModelView):
# no CORS configuration is added inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
@app.route("/api/key", methods=["GET"]) column_auto_select_related = True
def key(): column_display_all_relations = True
public_key = environ.get("JWT_PUBLIC_KEY", "").strip() column_list = ("email", "user_id")
return jsonify({"key": public_key}) form_columns = ("email", "accounts")
form_create_rules = ("email", "password", "confirm", "accounts")
form_edit_rules = ("email", "password", "confirm", "accounts")
def on_model_change(self, form, model, is_created):
if form.password.data:
model.hashed_password = bcrypt.hash(form.password.data)
def get_create_form(self):
form = super(UserView, self).get_create_form()
form.password = PasswordField(
"Password",
validators=[
InputRequired(),
EqualTo("confirm", message="Passwords must match"),
],
)
form.confirm = PasswordField("Repeat Password", validators=[InputRequired()])
return form
def get_edit_form(self):
form = super(UserView, self).get_edit_form()
form.password = PasswordField(
"Password",
description="When left blank, the password will remain unchanged on update",
validators=[EqualTo("confirm", message="Passwords must match")],
)
form.confirm = PasswordField("Repeat Password", validators=[])
return form

View File

@ -0,0 +1,60 @@
import base64
from os import environ
from passlib.hash import bcrypt
def build_api_arn(method_arn):
arn_chunks = method_arn.split(":")
aws_region = arn_chunks[3]
aws_account_id = arn_chunks[4]
gateway_arn_chunks = arn_chunks[5].split("/")
rest_api_id = gateway_arn_chunks[0]
stage = gateway_arn_chunks[1]
return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(
aws_region, aws_account_id, rest_api_id, stage
)
def build_response(api_arn, allow):
effect = "Deny"
if allow:
effect = "Allow"
return {
"principalId": "offen",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": ["execute-api:Invoke"],
"Effect": effect,
"Resource": [api_arn],
}
],
},
}
def handler(event, context):
api_arn = build_api_arn(event["methodArn"])
encoded_auth = event["authorizationToken"].lstrip("Basic ")
auth_string = base64.standard_b64decode(encoded_auth).decode()
if not auth_string:
return build_response(api_arn, False)
credentials = auth_string.split(":")
user = credentials[0]
password = credentials[1]
if user != environ.get("BASIC_AUTH_USER"):
return build_response(api_arn, False)
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
if not bcrypt.verify(password, hashed_password):
return build_response(api_arn, False)
return build_response(api_arn, True)

0
accounts/bootstrap.yml Executable file
View File

View File

@ -1,2 +1,3 @@
pytest pytest
black black
pyyaml

View File

@ -1,6 +1,11 @@
Flask==1.0.2 Flask==1.0.2
Flask-Admin==1.5.3
Flask-Cors==3.0.8 Flask-Cors==3.0.8
Flask-SQLAlchemy==2.4.0
werkzeug==0.15.4 werkzeug==0.15.4
pyjwt[crypto]==1.7.1 pyjwt[crypto]==1.7.1
passlib==1.7.1 passlib==1.7.1
bcrypt==3.1.7 bcrypt==3.1.7
PyMySQL==0.9.3
mysqlclient==1.4.2.post1
requests==2.22.0

View File

32
accounts/scripts/bootstrap.py Executable file
View File

@ -0,0 +1,32 @@
import yaml
from passlib.hash import bcrypt
from accounts import db
from accounts.models import Account, User, AccountUserAssociation
if __name__ == "__main__":
db.drop_all()
db.create_all()
with open("./bootstrap.yml", "r") as stream:
data = yaml.safe_load(stream)
for account in data["accounts"]:
record = Account(
name=account["name"],
account_id=account["id"],
)
db.session.add(record)
for user in data["users"]:
record = User(
email=user["email"],
hashed_password=bcrypt.hash(user["password"]),
)
for account_id in user["accounts"]:
record.accounts.append(AccountUserAssociation(account_id=account_id))
db.session.add(record)
db.session.commit()
print("Successfully bootstrapped accounts database")

20
accounts/scripts/hash.py Normal file
View File

@ -0,0 +1,20 @@
import base64
import argparse
from passlib.hash import bcrypt
parser = argparse.ArgumentParser()
parser.add_argument("--password", type=str, help="The password to hash", required=True)
parser.add_argument(
"--plain",
help="Do not encode the result as base64",
default=False,
action="store_true",
)
if __name__ == "__main__":
args = parser.parse_args()
out = bcrypt.hash(args.password)
if not args.plain:
out = base64.standard_b64encode(out.encode()).decode()
print(out)

View File

@ -26,6 +26,10 @@ custom:
production: vault.offen.dev production: vault.offen.dev
staging: vault-staging.offen.dev staging: vault-staging.offen.dev
alpha: vault-alpha.offen.dev alpha: vault-alpha.offen.dev
serverHost:
production: server.offen.dev
staging: server-staging.offen.dev
alpha: server-alpha.offen.dev
domain: domain:
production: accounts.offen.dev production: accounts.offen.dev
staging: accounts-staging.offen.dev staging: accounts-staging.offen.dev
@ -50,19 +54,53 @@ custom:
fileName: requirements.txt fileName: requirements.txt
functions: functions:
authorizer:
handler: authorizer.handler
environment:
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
app: app:
handler: wsgi_handler.handler handler: wsgi_handler.handler
events: events:
- http:
path: /admin/
method: any
authorizer:
name: authorizer
resultTtlInSeconds: 0
identitySource: method.request.header.Authorization
- http:
path: /admin/{proxy+}
method: any
authorizer:
name: authorizer
resultTtlInSeconds: 0
identitySource: method.request.header.Authorization
- http: - http:
path: '/' path: '/'
method: any method: any
- http: - http:
path: '{proxy+}' path: '/{proxy+}'
method: any method: any
environment: environment:
USER: offen
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}} CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}} COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
SERVER_URL: ${self:custom.serverHost.${self:custom.stage}}
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}' JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}' JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
HASHED_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true} BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
resources:
Resources:
GatewayResponse:
Type: 'AWS::ApiGateway::GatewayResponse'
Properties:
ResponseParameters:
gatewayresponse.header.WWW-Authenticate: "'Basic'"
ResponseType: UNAUTHORIZED
RestApiId:
Ref: 'ApiGatewayRestApi'
StatusCode: '401'

207
accounts/tests/test_api.py Normal file
View File

@ -0,0 +1,207 @@
import unittest
import json
import base64
from json import loads
from time import time
from datetime import datetime, timedelta
from os import environ
import jwt
from accounts import app
FOREIGN_PRIVATE_KEY = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
15NFZ2gU4rKnQ3sVAOzMoEw5
-----END PRIVATE KEY-----
"""
def _pad_b64_string(s):
while len(s) % 4 is not 0:
s = s + "="
return s
class TestKey(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def test_get_key(self):
rv = self.app.get("/api/key")
assert rv.status.startswith("200")
data = loads(rv.data)
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
class TestJWT(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def _assert_cookie_present(self, name):
for cookie in self.app.cookie_jar:
if cookie.name == name:
return cookie.value
raise AssertionError("Cookie named {} not found".format(name))
def _assert_cookie_not_present(self, name):
for cookie in self.app.cookie_jar:
assert cookie.name != name
def test_jwt_flow(self):
"""
First, try login attempts that are supposed to fail:
1. checking login status without any prior interaction
2. try logging in with an unknown user
3. try logging in with a known user and bad password
"""
rv = self.app.get("/api/login")
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
rv = self.app.post(
"/api/login",
data=json.dumps(
{"username": "does@not.exist", "password": "somethingsomething"}
),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "developp"}),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
"""
Next, perform a successful login
"""
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
)
assert rv.status.startswith("200")
"""
The response should contain information about the
user and full information (i.e. a name) about the associated accounts
"""
data = json.loads(rv.data)
assert data["user"]["userId"] is not None
data["user"]["accounts"].sort(key=lambda a: a["name"])
self.assertListEqual(
data["user"]["accounts"],
[
{"name": "One", "accountId": "9b63c4d8-65c0-438c-9d30-cc4b01173393"},
{"name": "Two", "accountId": "78403940-ae4f-4aff-a395-1e90f145cf62"},
],
)
"""
The claims part of the JWT is expected to contain a valid expiry,
information about the user and the associated account ids.
"""
jwt = self._assert_cookie_present("auth")
# PyJWT strips the padding from the base64 encoded parts which Python
# cannot decode properly, so we need to add the padding ourselves
claims_part = _pad_b64_string(jwt.split(".")[1])
claims = loads(base64.b64decode(claims_part))
assert claims.get("exp") > time()
priv = claims.get("priv")
assert priv is not None
assert priv.get("userId") is not None
self.assertListEqual(
priv["accounts"],
[
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
"78403940-ae4f-4aff-a395-1e90f145cf62",
],
)
"""
Checking the login status when re-using the cookie should yield
a successful response
"""
rv = self.app.get("/api/login")
assert rv.status.startswith("200")
jwt2 = self._assert_cookie_present("auth")
assert jwt2 == jwt
"""
Performing a bad login attempt when sending a valid auth cookie
is expected to destroy the cookie and leave the user logged out again
"""
rv = self.app.post(
"/api/login",
data=json.dumps(
{"username": "evil@session.takeover", "password": "develop"}
),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
"""
Explicitly logging out leaves the user without cookies
"""
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
)
assert rv.status.startswith("200")
rv = self.app.post("/api/logout")
assert rv.status.startswith("204")
self._assert_cookie_not_present("auth")
def test_forged_token(self):
"""
The application needs to verify that tokens that would be theoretically
valid are not signed using an unknown key.
"""
forged_token = jwt.encode(
{
"exp": datetime.utcnow() + timedelta(hours=24),
"priv": {
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
"accounts": [
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
"78403940-ae4f-4aff-a395-1e90f145cf62",
],
},
},
FOREIGN_PRIVATE_KEY,
algorithm="RS256",
).decode()
self.app.set_cookie("localhost", "auth", forged_token)
rv = self.app.get("/api/login")
assert rv.status.startswith("401")

View File

@ -1,21 +0,0 @@
import unittest
import json
from accounts import app
class TestJWT(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def test_jwt_flow(self):
rv = self.app.get("/api/login")
assert rv.status.startswith("401")
rv = self.app.post(
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
)
assert rv.status.startswith("200")
rv = self.app.get("/api/login")
assert rv.status.startswith("200")

View File

@ -24,6 +24,14 @@ services:
environment: environment:
POSTGRES_PASSWORD: develop POSTGRES_PASSWORD: develop
accounts_database:
image: mysql:5.7
ports:
- "3306:3306"
environment:
MYSQL_DATABASE: mysql
MYSQL_ROOT_PASSWORD: develop
server: server:
build: build:
context: '.' context: '.'
@ -31,6 +39,7 @@ services:
working_dir: /offen/server working_dir: /offen/server
volumes: volumes:
- .:/offen - .:/offen
- ./bootstrap.yml:/offen/server/bootstrap.yml
- serverdeps:/go/pkg/mod - serverdeps:/go/pkg/mod
environment: environment:
POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable
@ -100,16 +109,20 @@ services:
working_dir: /offen/accounts working_dir: /offen/accounts
volumes: volumes:
- .:/offen - .:/offen
- ./bootstrap.yml:/offen/accounts/bootstrap.yml
- accountdeps:/root/.local - accountdeps:/root/.local
command: flask run --host 0.0.0.0 command: flask run --host 0.0.0.0
ports: ports:
- 5000:5000 - 5000:5000
links:
- accounts_database
environment: environment:
FLASK_APP: accounts FLASK_APP: accounts:app
FLASK_ENV: development FLASK_ENV: development
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
CORS_ORIGIN: http://localhost:9977 CORS_ORIGIN: http://localhost:9977
# local password is `develop` SERVER_HOST: http://server:8080
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp SESSION_SECRET: vndJRFJTiyjfgtTF
JWT_PRIVATE_KEY: |- JWT_PRIVATE_KEY: |-
-----BEGIN PRIVATE KEY----- -----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS

View File

@ -23,12 +23,19 @@ const ClaimsContextKey contextKey = "claims"
// JWTProtect uses the public key located at the given URL to check if the // JWTProtect uses the public key located at the given URL to check if the
// cookie value is signed properly. In case yes, the JWT claims will be added // cookie value is signed properly. In case yes, the JWT claims will be added
// to the request context // to the request context
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authCookie, err := r.Cookie(cookieName) var jwtValue string
if err != nil { if authCookie, err := r.Cookie(cookieName); err == nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden) jwtValue = authCookie.Value
} else {
if header := r.Header.Get(headerName); header != "" {
jwtValue = header
}
}
if jwtValue == "" {
RespondWithJSONError(w, errors.New("jwt: could not infer JWT value from cookie or header"), http.StatusForbidden)
return return
} }
@ -56,7 +63,7 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
return return
} }
token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey) token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
if jwtErr != nil { if jwtErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden) RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
return return
@ -67,9 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
return return
} }
privateClaims, _ := token.Get("priv") privKey, _ := token.Get("priv")
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims)) claims, _ := privKey.(map[string]interface{})
if err := authorizer(r, claims); err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
return
}
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }

View File

@ -3,6 +3,7 @@ package http
import ( import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cookie *http.Cookie cookie *http.Cookie
headers *http.Header
server *httptest.Server server *httptest.Server
authorizer func(r *http.Request, claims map[string]interface{}) error
expectedStatusCode int expectedStatusCode int
}{ }{
{ {
"no cookie", "no cookie",
nil, nil,
nil, nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
{ {
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil, nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("here's some bytes 4 y'all")) w.Write([]byte("here's some bytes 4 y'all"))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"not really a key"}`)) w.Write([]byte(`{"key":"not really a key"}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`)) w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusBadGateway, http.StatusBadGateway,
}, },
{ {
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
{ {
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
return string(b) return string(b)
})(), })(),
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK, http.StatusOK,
}, },
{
"ok token in headers",
nil,
(func() *http.Header {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return &http.Header{
"X-RPC-Authentication": []string{string(b)},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK,
},
{
"bad token in headers",
nil,
(func() *http.Header {
return &http.Header{
"X-RPC-Authentication": []string{"nilly willy"},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
{
"authorizer rejects",
&http.Cookie{
Name: "auth",
Value: (func() string {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
token.Set("priv", map[string]interface{}{"ok": false})
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return string(b)
})(),
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error {
if claims["ok"] == true {
return nil
}
return errors.New("expected ok to be true")
},
http.StatusForbidden,
},
{ {
"valid key, expired token", "valid key, expired token",
&http.Cookie{ &http.Cookie{
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
return string(b) return string(b)
})(), })(),
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
} }
@ -174,7 +258,7 @@ func TestJWTProtect(t *testing.T) {
if test.server != nil { if test.server != nil {
url = test.server.URL url = test.server.URL
} }
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK")) w.Write([]byte("OK"))
})) }))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -182,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
if test.cookie != nil { if test.cookie != nil {
r.AddCookie(test.cookie) r.AddCookie(test.cookie)
} }
if test.headers != nil {
for key, value := range *test.headers {
r.Header.Add(key, value[0])
}
}
wrappedHandler.ServeHTTP(w, r) wrappedHandler.ServeHTTP(w, r)
if w.Code != test.expectedStatusCode { if w.Code != test.expectedStatusCode {
t.Errorf("Unexpected status code %v", w.Code) t.Errorf("Unexpected status code %v", w.Code)