2
0
mirror of https://github.com/offen/website.git synced 2024-12-23 13:30:20 +01:00

Merge pull request #59 from offen/flask-admin

Add auditorium users and db backed accounts
This commit is contained in:
Frederik Ring 2019-07-17 10:14:28 +02:00 committed by GitHub
commit d044d95b5b
21 changed files with 798 additions and 93 deletions

View File

@ -87,7 +87,9 @@ jobs:
echo Failed waiting for Postgres && exit 1
- run:
name: Run tests
command: make test-ci
command: |
cp ~/offen/bootstrap.yml .
make test-ci
shared:
docker:
@ -196,7 +198,7 @@ jobs:
docker:
- image: circleci/python:3.6
environment:
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
JWT_PRIVATE_KEY: |-
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK
@ -236,7 +238,11 @@ jobs:
a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3
HQIDAQAB
-----END PUBLIC KEY-----
- image: circleci/mysql:5.7
environment:
- MYSQL_ROOT_PASSWORD=circle
- MYSQL_DATABASE=circle
- MYSQL_HOST=127.0.0.1
working_directory: ~/offen/accounts
steps:
- checkout:
@ -254,11 +260,22 @@ jobs:
paths:
- ~/offen/accounts/venv
key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
- run:
name: Waiting for MySQL to be ready
command: |
for i in `seq 1 10`;
do
nc -z localhost 3306 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for MySQL && exit 1
- run:
name: Run tests
command: |
. venv/bin/activate
make
cp ~/offen/bootstrap.yml .
make test-ci
deploy_python:
docker:

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ package-lock.json
# mkcert certificates
*.pem
venv/
bootstrap-alpha.yml

View File

@ -20,5 +20,6 @@ setup:
bootstrap:
@docker-compose run kms make bootstrap
@docker-compose run server make bootstrap
@docker-compose run accounts make bootstrap
.PHONY: setup bootstrap

View File

@ -1,7 +1,13 @@
test:
@pytest --disable-pytest-warnings
test-ci: bootstrap
@pytest --disable-pytest-warnings
fmt:
@black .
.PHONY: test fmt
bootstrap:
@python -m scripts.bootstrap
.PHONY: test fmt bootstrap

View File

@ -1,5 +1,21 @@
from os import environ
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_admin import Admin
app = Flask(__name__)
app.secret_key = environ.get("SESSION_SECRET")
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
db = SQLAlchemy(app)
import accounts.views
from accounts.models import Account, User
from accounts.views import AccountView, UserView
import accounts.api
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
admin.add_view(AccountView(Account, db.session))
admin.add_view(UserView(User, db.session))

124
accounts/accounts/api.py Normal file
View File

@ -0,0 +1,124 @@
from datetime import datetime, timedelta
from os import environ
from functools import wraps
from flask import jsonify, make_response, request
from flask_cors import cross_origin
from passlib.hash import bcrypt
import jwt
from accounts import app
from accounts.models import User
COOKIE_KEY = "auth"
def json_error(handler):
@wraps(handler)
def wrapped_handler(*args, **kwargs):
try:
return handler(*args, **kwargs)
except Exception as server_error:
return (
jsonify(
{
"error": "Internal server error: {}".format(str(server_error)),
"status": 500,
}
),
500,
)
return wrapped_handler
class UnauthorizedError(Exception):
pass
@app.route("/api/login", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def post_login():
credentials = request.get_json(force=True)
try:
match = User.query.filter_by(email=credentials["username"]).first()
if not match:
raise UnauthorizedError("bad username")
if not bcrypt.verify(credentials["password"], match.hashed_password):
raise UnauthorizedError("bad password")
except UnauthorizedError as unauthorized_error:
resp = make_response(jsonify({"error": str(unauthorized_error), "status": 401}))
resp.set_cookie(COOKIE_KEY, "", expires=0)
resp.status_code = 401
return resp
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
encoded = jwt.encode(
{
"exp": expiry,
"priv": {
"userId": match.user_id,
"accounts": [a.account_id for a in match.accounts],
},
},
private_key.encode(),
algorithm="RS256",
).decode()
resp = make_response(jsonify({"user": match.serialize()}))
resp.set_cookie(
COOKIE_KEY,
encoded,
httponly=True,
expires=expiry,
path="/",
domain=environ.get("COOKIE_DOMAIN"),
samesite="strict",
)
return resp
@app.route("/api/login", methods=["GET"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def get_login():
auth_cookie = request.cookies.get(COOKIE_KEY)
public_key = environ.get("JWT_PUBLIC_KEY", "")
try:
token = jwt.decode(auth_cookie, public_key)
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
try:
match = User.query.get(token["priv"]["userId"])
except KeyError as key_err:
return (
jsonify(
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
),
401,
)
return jsonify({"user": match.serialize()})
@app.route("/api/logout", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
def post_logout():
resp = make_response("")
resp.set_cookie(COOKIE_KEY, "", expires=0)
resp.status_code = 204
return resp
@app.route("/api/key", methods=["GET"])
@json_error
def key():
"""
This route is not supposed to be called by client-side applications, so
no CORS configuration is added
"""
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
return jsonify({"key": public_key})

View File

@ -0,0 +1,49 @@
from uuid import uuid4
from accounts import db
def generate_key():
return str(uuid4())
class Account(db.Model):
__tablename__ = "accounts"
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
name = db.Column(db.Text, nullable=False)
users = db.relationship("AccountUserAssociation", back_populates="account")
def __repr__(self):
return self.name
class User(db.Model):
__tablename__ = "users"
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
email = db.Column(db.String(128), nullable=False, unique=True)
hashed_password = db.Column(db.Text, nullable=False)
accounts = db.relationship(
"AccountUserAssociation", back_populates="user", lazy="joined"
)
def serialize(self):
associated_accounts = [a.account_id for a in self.accounts]
records = [
{"name": a.name, "accountId": a.account_id}
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
]
return {"userId": self.user_id, "email": self.email, "accounts": records}
class AccountUserAssociation(db.Model):
__tablename__ = "account_to_user"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.String(36), db.ForeignKey("users.user_id"), nullable=False)
account_id = db.Column(
db.String(36), db.ForeignKey("accounts.account_id"), nullable=False
)
user = db.relationship("User", back_populates="accounts")
account = db.relationship("Account", back_populates="users")

View File

@ -1,69 +1,104 @@
from datetime import datetime, timedelta
from os import environ
import base64
from flask import jsonify, render_template, make_response, request
from flask_cors import cross_origin
import requests
from flask_admin.contrib.sqla import ModelView
from wtforms import PasswordField, StringField, Form
from wtforms.validators import InputRequired, EqualTo
from passlib.hash import bcrypt
import jwt
from accounts import app
@app.route("/")
def home():
return render_template("index.html")
from accounts import db
from accounts.models import AccountUserAssociation
@app.route("/api/login", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
def post_login():
credentials = request.get_json(force=True)
class RemoteServerException(Exception):
status = 0
if credentials["username"] != environ.get("USER", "offen"):
return jsonify({"error": "bad username", "status": 401}), 401
def __str__(self):
return "Status {}: {}".format(
self.status, super(RemoteServerException, self).__str__()
)
hashed_password = base64.standard_b64decode(environ.get("HASHED_PASSWORD", ""))
if not bcrypt.verify(credentials["password"], hashed_password):
return jsonify({"error": "bad password", "status": 401}), 401
def create_remote_account(name, account_id):
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
try:
encoded = jwt.encode(
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
).decode("utf-8")
except jwt.exceptions.PyJWTError as encode_error:
return jsonify({"error": str(encode_error), "status": 500}), 500
# expires in 30 seconds as this will mean the HTTP request would have
# timed out anyways
expiry = datetime.utcnow() + timedelta(seconds=30)
encoded = jwt.encode(
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
private_key.encode(),
algorithm="RS256",
).decode("utf-8")
resp = make_response(jsonify({"ok": True}))
resp.set_cookie(
"auth",
encoded,
httponly=True,
expires=expiry,
path="/",
domain=environ.get("COOKIE_DOMAIN"),
samesite="strict"
r = requests.post(
"{}/accounts".format(environ.get("SERVER_HOST")),
json={"name": name, "accountId": account_id},
headers={"X-RPC-Authentication": encoded},
)
return resp
if r.status_code > 299:
err = r.json()
remote_err = RemoteServerException(err["error"])
remote_err.status = err["status"]
raise remote_err
@app.route("/api/login", methods=["GET"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
def get_login():
auth_cookie = request.cookies.get("auth")
public_key = environ.get("JWT_PUBLIC_KEY", "")
try:
jwt.decode(auth_cookie, public_key)
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
return jsonify({"ok": True})
class AccountForm(Form):
name = StringField(
"Account Name",
validators=[InputRequired()],
description="This is the account name visible to users",
)
# This route is not supposed to be called by client-side applications, so
# no CORS configuration is added
@app.route("/api/key", methods=["GET"])
def key():
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
return jsonify({"key": public_key})
class AccountView(ModelView):
form = AccountForm
column_display_all_relations = True
column_list = ("name", "account_id")
def after_model_change(self, form, model, is_created):
if is_created:
try:
create_remote_account(model.name, model.account_id)
except RemoteServerException as server_error:
db.session.delete(model)
db.session.commit()
raise server_error
class UserView(ModelView):
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
column_auto_select_related = True
column_display_all_relations = True
column_list = ("email", "user_id")
form_columns = ("email", "accounts")
form_create_rules = ("email", "password", "confirm", "accounts")
form_edit_rules = ("email", "password", "confirm", "accounts")
def on_model_change(self, form, model, is_created):
if form.password.data:
model.hashed_password = bcrypt.hash(form.password.data)
def get_create_form(self):
form = super(UserView, self).get_create_form()
form.password = PasswordField(
"Password",
validators=[
InputRequired(),
EqualTo("confirm", message="Passwords must match"),
],
)
form.confirm = PasswordField("Repeat Password", validators=[InputRequired()])
return form
def get_edit_form(self):
form = super(UserView, self).get_edit_form()
form.password = PasswordField(
"Password",
description="When left blank, the password will remain unchanged on update",
validators=[EqualTo("confirm", message="Passwords must match")],
)
form.confirm = PasswordField("Repeat Password", validators=[])
return form

View File

@ -0,0 +1,60 @@
import base64
from os import environ
from passlib.hash import bcrypt
def build_api_arn(method_arn):
arn_chunks = method_arn.split(":")
aws_region = arn_chunks[3]
aws_account_id = arn_chunks[4]
gateway_arn_chunks = arn_chunks[5].split("/")
rest_api_id = gateway_arn_chunks[0]
stage = gateway_arn_chunks[1]
return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(
aws_region, aws_account_id, rest_api_id, stage
)
def build_response(api_arn, allow):
effect = "Deny"
if allow:
effect = "Allow"
return {
"principalId": "offen",
"policyDocument": {
"Version": "2012-10-17",
"Statement": [
{
"Action": ["execute-api:Invoke"],
"Effect": effect,
"Resource": [api_arn],
}
],
},
}
def handler(event, context):
api_arn = build_api_arn(event["methodArn"])
encoded_auth = event["authorizationToken"].lstrip("Basic ")
auth_string = base64.standard_b64decode(encoded_auth).decode()
if not auth_string:
return build_response(api_arn, False)
credentials = auth_string.split(":")
user = credentials[0]
password = credentials[1]
if user != environ.get("BASIC_AUTH_USER"):
return build_response(api_arn, False)
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
if not bcrypt.verify(password, hashed_password):
return build_response(api_arn, False)
return build_response(api_arn, True)

0
accounts/bootstrap.yml Executable file
View File

View File

@ -1,2 +1,3 @@
pytest
black
pyyaml

View File

@ -1,6 +1,11 @@
Flask==1.0.2
Flask-Admin==1.5.3
Flask-Cors==3.0.8
Flask-SQLAlchemy==2.4.0
werkzeug==0.15.4
pyjwt[crypto]==1.7.1
passlib==1.7.1
bcrypt==3.1.7
PyMySQL==0.9.3
mysqlclient==1.4.2.post1
requests==2.22.0

View File

32
accounts/scripts/bootstrap.py Executable file
View File

@ -0,0 +1,32 @@
import yaml
from passlib.hash import bcrypt
from accounts import db
from accounts.models import Account, User, AccountUserAssociation
if __name__ == "__main__":
db.drop_all()
db.create_all()
with open("./bootstrap.yml", "r") as stream:
data = yaml.safe_load(stream)
for account in data["accounts"]:
record = Account(
name=account["name"],
account_id=account["id"],
)
db.session.add(record)
for user in data["users"]:
record = User(
email=user["email"],
hashed_password=bcrypt.hash(user["password"]),
)
for account_id in user["accounts"]:
record.accounts.append(AccountUserAssociation(account_id=account_id))
db.session.add(record)
db.session.commit()
print("Successfully bootstrapped accounts database")

20
accounts/scripts/hash.py Normal file
View File

@ -0,0 +1,20 @@
import base64
import argparse
from passlib.hash import bcrypt
parser = argparse.ArgumentParser()
parser.add_argument("--password", type=str, help="The password to hash", required=True)
parser.add_argument(
"--plain",
help="Do not encode the result as base64",
default=False,
action="store_true",
)
if __name__ == "__main__":
args = parser.parse_args()
out = bcrypt.hash(args.password)
if not args.plain:
out = base64.standard_b64encode(out.encode()).decode()
print(out)

View File

@ -26,6 +26,10 @@ custom:
production: vault.offen.dev
staging: vault-staging.offen.dev
alpha: vault-alpha.offen.dev
serverHost:
production: server.offen.dev
staging: server-staging.offen.dev
alpha: server-alpha.offen.dev
domain:
production: accounts.offen.dev
staging: accounts-staging.offen.dev
@ -50,19 +54,53 @@ custom:
fileName: requirements.txt
functions:
authorizer:
handler: authorizer.handler
environment:
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
app:
handler: wsgi_handler.handler
events:
- http:
path: /admin/
method: any
authorizer:
name: authorizer
resultTtlInSeconds: 0
identitySource: method.request.header.Authorization
- http:
path: /admin/{proxy+}
method: any
authorizer:
name: authorizer
resultTtlInSeconds: 0
identitySource: method.request.header.Authorization
- http:
path: '/'
method: any
- http:
path: '{proxy+}'
path: '/{proxy+}'
method: any
environment:
USER: offen
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
SERVER_URL: ${self:custom.serverHost.${self:custom.stage}}
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
HASHED_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
resources:
Resources:
GatewayResponse:
Type: 'AWS::ApiGateway::GatewayResponse'
Properties:
ResponseParameters:
gatewayresponse.header.WWW-Authenticate: "'Basic'"
ResponseType: UNAUTHORIZED
RestApiId:
Ref: 'ApiGatewayRestApi'
StatusCode: '401'

207
accounts/tests/test_api.py Normal file
View File

@ -0,0 +1,207 @@
import unittest
import json
import base64
from json import loads
from time import time
from datetime import datetime, timedelta
from os import environ
import jwt
from accounts import app
FOREIGN_PRIVATE_KEY = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
15NFZ2gU4rKnQ3sVAOzMoEw5
-----END PRIVATE KEY-----
"""
def _pad_b64_string(s):
while len(s) % 4 is not 0:
s = s + "="
return s
class TestKey(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def test_get_key(self):
rv = self.app.get("/api/key")
assert rv.status.startswith("200")
data = loads(rv.data)
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
class TestJWT(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def _assert_cookie_present(self, name):
for cookie in self.app.cookie_jar:
if cookie.name == name:
return cookie.value
raise AssertionError("Cookie named {} not found".format(name))
def _assert_cookie_not_present(self, name):
for cookie in self.app.cookie_jar:
assert cookie.name != name
def test_jwt_flow(self):
"""
First, try login attempts that are supposed to fail:
1. checking login status without any prior interaction
2. try logging in with an unknown user
3. try logging in with a known user and bad password
"""
rv = self.app.get("/api/login")
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
rv = self.app.post(
"/api/login",
data=json.dumps(
{"username": "does@not.exist", "password": "somethingsomething"}
),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "developp"}),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
"""
Next, perform a successful login
"""
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
)
assert rv.status.startswith("200")
"""
The response should contain information about the
user and full information (i.e. a name) about the associated accounts
"""
data = json.loads(rv.data)
assert data["user"]["userId"] is not None
data["user"]["accounts"].sort(key=lambda a: a["name"])
self.assertListEqual(
data["user"]["accounts"],
[
{"name": "One", "accountId": "9b63c4d8-65c0-438c-9d30-cc4b01173393"},
{"name": "Two", "accountId": "78403940-ae4f-4aff-a395-1e90f145cf62"},
],
)
"""
The claims part of the JWT is expected to contain a valid expiry,
information about the user and the associated account ids.
"""
jwt = self._assert_cookie_present("auth")
# PyJWT strips the padding from the base64 encoded parts which Python
# cannot decode properly, so we need to add the padding ourselves
claims_part = _pad_b64_string(jwt.split(".")[1])
claims = loads(base64.b64decode(claims_part))
assert claims.get("exp") > time()
priv = claims.get("priv")
assert priv is not None
assert priv.get("userId") is not None
self.assertListEqual(
priv["accounts"],
[
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
"78403940-ae4f-4aff-a395-1e90f145cf62",
],
)
"""
Checking the login status when re-using the cookie should yield
a successful response
"""
rv = self.app.get("/api/login")
assert rv.status.startswith("200")
jwt2 = self._assert_cookie_present("auth")
assert jwt2 == jwt
"""
Performing a bad login attempt when sending a valid auth cookie
is expected to destroy the cookie and leave the user logged out again
"""
rv = self.app.post(
"/api/login",
data=json.dumps(
{"username": "evil@session.takeover", "password": "develop"}
),
)
assert rv.status.startswith("401")
self._assert_cookie_not_present("auth")
"""
Explicitly logging out leaves the user without cookies
"""
rv = self.app.post(
"/api/login",
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
)
assert rv.status.startswith("200")
rv = self.app.post("/api/logout")
assert rv.status.startswith("204")
self._assert_cookie_not_present("auth")
def test_forged_token(self):
"""
The application needs to verify that tokens that would be theoretically
valid are not signed using an unknown key.
"""
forged_token = jwt.encode(
{
"exp": datetime.utcnow() + timedelta(hours=24),
"priv": {
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
"accounts": [
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
"78403940-ae4f-4aff-a395-1e90f145cf62",
],
},
},
FOREIGN_PRIVATE_KEY,
algorithm="RS256",
).decode()
self.app.set_cookie("localhost", "auth", forged_token)
rv = self.app.get("/api/login")
assert rv.status.startswith("401")

View File

@ -1,21 +0,0 @@
import unittest
import json
from accounts import app
class TestJWT(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
def test_jwt_flow(self):
rv = self.app.get("/api/login")
assert rv.status.startswith("401")
rv = self.app.post(
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
)
assert rv.status.startswith("200")
rv = self.app.get("/api/login")
assert rv.status.startswith("200")

View File

@ -24,6 +24,14 @@ services:
environment:
POSTGRES_PASSWORD: develop
accounts_database:
image: mysql:5.7
ports:
- "3306:3306"
environment:
MYSQL_DATABASE: mysql
MYSQL_ROOT_PASSWORD: develop
server:
build:
context: '.'
@ -31,6 +39,7 @@ services:
working_dir: /offen/server
volumes:
- .:/offen
- ./bootstrap.yml:/offen/server/bootstrap.yml
- serverdeps:/go/pkg/mod
environment:
POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable
@ -100,16 +109,20 @@ services:
working_dir: /offen/accounts
volumes:
- .:/offen
- ./bootstrap.yml:/offen/accounts/bootstrap.yml
- accountdeps:/root/.local
command: flask run --host 0.0.0.0
ports:
- 5000:5000
links:
- accounts_database
environment:
FLASK_APP: accounts
FLASK_APP: accounts:app
FLASK_ENV: development
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
CORS_ORIGIN: http://localhost:9977
# local password is `develop`
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
SERVER_HOST: http://server:8080
SESSION_SECRET: vndJRFJTiyjfgtTF
JWT_PRIVATE_KEY: |-
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS

View File

@ -23,12 +23,19 @@ const ClaimsContextKey contextKey = "claims"
// JWTProtect uses the public key located at the given URL to check if the
// cookie value is signed properly. In case yes, the JWT claims will be added
// to the request context
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authCookie, err := r.Cookie(cookieName)
if err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden)
var jwtValue string
if authCookie, err := r.Cookie(cookieName); err == nil {
jwtValue = authCookie.Value
} else {
if header := r.Header.Get(headerName); header != "" {
jwtValue = header
}
}
if jwtValue == "" {
RespondWithJSONError(w, errors.New("jwt: could not infer JWT value from cookie or header"), http.StatusForbidden)
return
}
@ -56,7 +63,7 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
return
}
token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey)
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
if jwtErr != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
return
@ -67,9 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
return
}
privateClaims, _ := token.Get("priv")
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
privKey, _ := token.Get("priv")
claims, _ := privKey.(map[string]interface{})
if err := authorizer(r, claims); err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
return
}
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
next.ServeHTTP(w, r)
})
}

View File

@ -3,6 +3,7 @@ package http
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
tests := []struct {
name string
cookie *http.Cookie
headers *http.Header
server *httptest.Server
authorizer func(r *http.Request, claims map[string]interface{}) error
expectedStatusCode int
}{
{
"no cookie",
nil,
nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
{
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
Value: "irrelevantgibberish",
},
nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError,
},
{
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth",
Value: "irrelevantgibberish",
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("here's some bytes 4 y'all"))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError,
},
{
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth",
Value: "irrelevantgibberish",
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"not really a key"}`))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError,
},
{
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth",
Value: "irrelevantgibberish",
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusBadGateway,
},
{
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
Name: "auth",
Value: "irrelevantgibberish",
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
{
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
return string(b)
})(),
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK,
},
{
"ok token in headers",
nil,
(func() *http.Header {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return &http.Header{
"X-RPC-Authentication": []string{string(b)},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK,
},
{
"bad token in headers",
nil,
(func() *http.Header {
return &http.Header{
"X-RPC-Authentication": []string{"nilly willy"},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
{
"authorizer rejects",
&http.Cookie{
Name: "auth",
Value: (func() string {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
token.Set("priv", map[string]interface{}{"ok": false})
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return string(b)
})(),
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error {
if claims["ok"] == true {
return nil
}
return errors.New("expected ok to be true")
},
http.StatusForbidden,
},
{
"valid key, expired token",
&http.Cookie{
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
return string(b)
})(),
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
}
@ -174,7 +258,7 @@ func TestJWTProtect(t *testing.T) {
if test.server != nil {
url = test.server.URL
}
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
w := httptest.NewRecorder()
@ -182,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
if test.cookie != nil {
r.AddCookie(test.cookie)
}
if test.headers != nil {
for key, value := range *test.headers {
r.Header.Add(key, value[0])
}
}
wrappedHandler.ServeHTTP(w, r)
if w.Code != test.expectedStatusCode {
t.Errorf("Unexpected status code %v", w.Code)