mirror of
https://github.com/offen/website.git
synced 2024-11-22 17:10:29 +01:00
Merge pull request #59 from offen/flask-admin
Add auditorium users and db backed accounts
This commit is contained in:
commit
d044d95b5b
@ -87,7 +87,9 @@ jobs:
|
|||||||
echo Failed waiting for Postgres && exit 1
|
echo Failed waiting for Postgres && exit 1
|
||||||
- run:
|
- run:
|
||||||
name: Run tests
|
name: Run tests
|
||||||
command: make test-ci
|
command: |
|
||||||
|
cp ~/offen/bootstrap.yml .
|
||||||
|
make test-ci
|
||||||
|
|
||||||
shared:
|
shared:
|
||||||
docker:
|
docker:
|
||||||
@ -196,7 +198,7 @@ jobs:
|
|||||||
docker:
|
docker:
|
||||||
- image: circleci/python:3.6
|
- image: circleci/python:3.6
|
||||||
environment:
|
environment:
|
||||||
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
|
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
|
||||||
JWT_PRIVATE_KEY: |-
|
JWT_PRIVATE_KEY: |-
|
||||||
-----BEGIN PRIVATE KEY-----
|
-----BEGIN PRIVATE KEY-----
|
||||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK
|
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCzgU18PnRrpbVK
|
||||||
@ -236,7 +238,11 @@ jobs:
|
|||||||
a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3
|
a3B4L0waKzP5QWcO865n1HCUTnV+s4lNcphBDZCrSwTkXnVnQWVPCL7ssoQyM0u3
|
||||||
HQIDAQAB
|
HQIDAQAB
|
||||||
-----END PUBLIC KEY-----
|
-----END PUBLIC KEY-----
|
||||||
|
- image: circleci/mysql:5.7
|
||||||
|
environment:
|
||||||
|
- MYSQL_ROOT_PASSWORD=circle
|
||||||
|
- MYSQL_DATABASE=circle
|
||||||
|
- MYSQL_HOST=127.0.0.1
|
||||||
working_directory: ~/offen/accounts
|
working_directory: ~/offen/accounts
|
||||||
steps:
|
steps:
|
||||||
- checkout:
|
- checkout:
|
||||||
@ -254,11 +260,22 @@ jobs:
|
|||||||
paths:
|
paths:
|
||||||
- ~/offen/accounts/venv
|
- ~/offen/accounts/venv
|
||||||
key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
|
key: offen-accounts-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
|
||||||
|
- run:
|
||||||
|
name: Waiting for MySQL to be ready
|
||||||
|
command: |
|
||||||
|
for i in `seq 1 10`;
|
||||||
|
do
|
||||||
|
nc -z localhost 3306 && echo Success && exit 0
|
||||||
|
echo -n .
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
echo Failed waiting for MySQL && exit 1
|
||||||
- run:
|
- run:
|
||||||
name: Run tests
|
name: Run tests
|
||||||
command: |
|
command: |
|
||||||
. venv/bin/activate
|
. venv/bin/activate
|
||||||
make
|
cp ~/offen/bootstrap.yml .
|
||||||
|
make test-ci
|
||||||
|
|
||||||
deploy_python:
|
deploy_python:
|
||||||
docker:
|
docker:
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -5,3 +5,5 @@ package-lock.json
|
|||||||
# mkcert certificates
|
# mkcert certificates
|
||||||
*.pem
|
*.pem
|
||||||
venv/
|
venv/
|
||||||
|
|
||||||
|
bootstrap-alpha.yml
|
||||||
|
1
Makefile
1
Makefile
@ -20,5 +20,6 @@ setup:
|
|||||||
bootstrap:
|
bootstrap:
|
||||||
@docker-compose run kms make bootstrap
|
@docker-compose run kms make bootstrap
|
||||||
@docker-compose run server make bootstrap
|
@docker-compose run server make bootstrap
|
||||||
|
@docker-compose run accounts make bootstrap
|
||||||
|
|
||||||
.PHONY: setup bootstrap
|
.PHONY: setup bootstrap
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
test:
|
test:
|
||||||
@pytest --disable-pytest-warnings
|
@pytest --disable-pytest-warnings
|
||||||
|
|
||||||
|
test-ci: bootstrap
|
||||||
|
@pytest --disable-pytest-warnings
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
@black .
|
@black .
|
||||||
|
|
||||||
.PHONY: test fmt
|
bootstrap:
|
||||||
|
@python -m scripts.bootstrap
|
||||||
|
|
||||||
|
.PHONY: test fmt bootstrap
|
||||||
|
@ -1,5 +1,21 @@
|
|||||||
|
from os import environ
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
from flask_admin import Admin
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
app.secret_key = environ.get("SESSION_SECRET")
|
||||||
|
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
|
||||||
|
db = SQLAlchemy(app)
|
||||||
|
|
||||||
import accounts.views
|
from accounts.models import Account, User
|
||||||
|
from accounts.views import AccountView, UserView
|
||||||
|
import accounts.api
|
||||||
|
|
||||||
|
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
|
||||||
|
|
||||||
|
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
|
||||||
|
|
||||||
|
admin.add_view(AccountView(Account, db.session))
|
||||||
|
admin.add_view(UserView(User, db.session))
|
||||||
|
124
accounts/accounts/api.py
Normal file
124
accounts/accounts/api.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from os import environ
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask import jsonify, make_response, request
|
||||||
|
from flask_cors import cross_origin
|
||||||
|
from passlib.hash import bcrypt
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from accounts import app
|
||||||
|
from accounts.models import User
|
||||||
|
|
||||||
|
COOKIE_KEY = "auth"
|
||||||
|
|
||||||
|
|
||||||
|
def json_error(handler):
|
||||||
|
@wraps(handler)
|
||||||
|
def wrapped_handler(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return handler(*args, **kwargs)
|
||||||
|
except Exception as server_error:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"error": "Internal server error: {}".format(str(server_error)),
|
||||||
|
"status": 500,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapped_handler
|
||||||
|
|
||||||
|
|
||||||
|
class UnauthorizedError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/login", methods=["POST"])
|
||||||
|
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||||
|
@json_error
|
||||||
|
def post_login():
|
||||||
|
credentials = request.get_json(force=True)
|
||||||
|
try:
|
||||||
|
match = User.query.filter_by(email=credentials["username"]).first()
|
||||||
|
if not match:
|
||||||
|
raise UnauthorizedError("bad username")
|
||||||
|
if not bcrypt.verify(credentials["password"], match.hashed_password):
|
||||||
|
raise UnauthorizedError("bad password")
|
||||||
|
except UnauthorizedError as unauthorized_error:
|
||||||
|
resp = make_response(jsonify({"error": str(unauthorized_error), "status": 401}))
|
||||||
|
resp.set_cookie(COOKIE_KEY, "", expires=0)
|
||||||
|
resp.status_code = 401
|
||||||
|
return resp
|
||||||
|
|
||||||
|
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
|
expiry = datetime.utcnow() + timedelta(hours=24)
|
||||||
|
encoded = jwt.encode(
|
||||||
|
{
|
||||||
|
"exp": expiry,
|
||||||
|
"priv": {
|
||||||
|
"userId": match.user_id,
|
||||||
|
"accounts": [a.account_id for a in match.accounts],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
private_key.encode(),
|
||||||
|
algorithm="RS256",
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
resp = make_response(jsonify({"user": match.serialize()}))
|
||||||
|
resp.set_cookie(
|
||||||
|
COOKIE_KEY,
|
||||||
|
encoded,
|
||||||
|
httponly=True,
|
||||||
|
expires=expiry,
|
||||||
|
path="/",
|
||||||
|
domain=environ.get("COOKIE_DOMAIN"),
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/login", methods=["GET"])
|
||||||
|
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||||
|
@json_error
|
||||||
|
def get_login():
|
||||||
|
auth_cookie = request.cookies.get(COOKIE_KEY)
|
||||||
|
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
||||||
|
try:
|
||||||
|
token = jwt.decode(auth_cookie, public_key)
|
||||||
|
except jwt.exceptions.PyJWTError as unauthorized_error:
|
||||||
|
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
||||||
|
|
||||||
|
try:
|
||||||
|
match = User.query.get(token["priv"]["userId"])
|
||||||
|
except KeyError as key_err:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
|
||||||
|
),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
return jsonify({"user": match.serialize()})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/logout", methods=["POST"])
|
||||||
|
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||||
|
@json_error
|
||||||
|
def post_logout():
|
||||||
|
resp = make_response("")
|
||||||
|
resp.set_cookie(COOKIE_KEY, "", expires=0)
|
||||||
|
resp.status_code = 204
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/key", methods=["GET"])
|
||||||
|
@json_error
|
||||||
|
def key():
|
||||||
|
"""
|
||||||
|
This route is not supposed to be called by client-side applications, so
|
||||||
|
no CORS configuration is added
|
||||||
|
"""
|
||||||
|
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
|
||||||
|
return jsonify({"key": public_key})
|
49
accounts/accounts/models.py
Normal file
49
accounts/accounts/models.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from accounts import db
|
||||||
|
|
||||||
|
|
||||||
|
def generate_key():
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class Account(db.Model):
|
||||||
|
__tablename__ = "accounts"
|
||||||
|
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||||
|
name = db.Column(db.Text, nullable=False)
|
||||||
|
users = db.relationship("AccountUserAssociation", back_populates="account")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class User(db.Model):
|
||||||
|
__tablename__ = "users"
|
||||||
|
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||||
|
email = db.Column(db.String(128), nullable=False, unique=True)
|
||||||
|
hashed_password = db.Column(db.Text, nullable=False)
|
||||||
|
accounts = db.relationship(
|
||||||
|
"AccountUserAssociation", back_populates="user", lazy="joined"
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
associated_accounts = [a.account_id for a in self.accounts]
|
||||||
|
records = [
|
||||||
|
{"name": a.name, "accountId": a.account_id}
|
||||||
|
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
|
||||||
|
]
|
||||||
|
return {"userId": self.user_id, "email": self.email, "accounts": records}
|
||||||
|
|
||||||
|
|
||||||
|
class AccountUserAssociation(db.Model):
|
||||||
|
__tablename__ = "account_to_user"
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
|
||||||
|
user_id = db.Column(db.String(36), db.ForeignKey("users.user_id"), nullable=False)
|
||||||
|
account_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("accounts.account_id"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
user = db.relationship("User", back_populates="accounts")
|
||||||
|
account = db.relationship("Account", back_populates="users")
|
@ -1,69 +1,104 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from os import environ
|
from os import environ
|
||||||
import base64
|
|
||||||
|
|
||||||
from flask import jsonify, render_template, make_response, request
|
import requests
|
||||||
from flask_cors import cross_origin
|
from flask_admin.contrib.sqla import ModelView
|
||||||
|
from wtforms import PasswordField, StringField, Form
|
||||||
|
from wtforms.validators import InputRequired, EqualTo
|
||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from accounts import app
|
from accounts import db
|
||||||
|
from accounts.models import AccountUserAssociation
|
||||||
@app.route("/")
|
|
||||||
def home():
|
|
||||||
return render_template("index.html")
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/login", methods=["POST"])
|
class RemoteServerException(Exception):
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
status = 0
|
||||||
def post_login():
|
|
||||||
credentials = request.get_json(force=True)
|
|
||||||
|
|
||||||
if credentials["username"] != environ.get("USER", "offen"):
|
def __str__(self):
|
||||||
return jsonify({"error": "bad username", "status": 401}), 401
|
return "Status {}: {}".format(
|
||||||
|
self.status, super(RemoteServerException, self).__str__()
|
||||||
hashed_password = base64.standard_b64decode(environ.get("HASHED_PASSWORD", ""))
|
|
||||||
if not bcrypt.verify(credentials["password"], hashed_password):
|
|
||||||
return jsonify({"error": "bad password", "status": 401}), 401
|
|
||||||
|
|
||||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
|
||||||
expiry = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
try:
|
|
||||||
encoded = jwt.encode(
|
|
||||||
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
|
|
||||||
).decode("utf-8")
|
|
||||||
except jwt.exceptions.PyJWTError as encode_error:
|
|
||||||
return jsonify({"error": str(encode_error), "status": 500}), 500
|
|
||||||
|
|
||||||
resp = make_response(jsonify({"ok": True}))
|
|
||||||
resp.set_cookie(
|
|
||||||
"auth",
|
|
||||||
encoded,
|
|
||||||
httponly=True,
|
|
||||||
expires=expiry,
|
|
||||||
path="/",
|
|
||||||
domain=environ.get("COOKIE_DOMAIN"),
|
|
||||||
samesite="strict"
|
|
||||||
)
|
)
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/login", methods=["GET"])
|
def create_remote_account(name, account_id):
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
def get_login():
|
# expires in 30 seconds as this will mean the HTTP request would have
|
||||||
auth_cookie = request.cookies.get("auth")
|
# timed out anyways
|
||||||
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
expiry = datetime.utcnow() + timedelta(seconds=30)
|
||||||
|
encoded = jwt.encode(
|
||||||
|
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
||||||
|
private_key.encode(),
|
||||||
|
algorithm="RS256",
|
||||||
|
).decode("utf-8")
|
||||||
|
|
||||||
|
r = requests.post(
|
||||||
|
"{}/accounts".format(environ.get("SERVER_HOST")),
|
||||||
|
json={"name": name, "accountId": account_id},
|
||||||
|
headers={"X-RPC-Authentication": encoded},
|
||||||
|
)
|
||||||
|
|
||||||
|
if r.status_code > 299:
|
||||||
|
err = r.json()
|
||||||
|
remote_err = RemoteServerException(err["error"])
|
||||||
|
remote_err.status = err["status"]
|
||||||
|
raise remote_err
|
||||||
|
|
||||||
|
|
||||||
|
class AccountForm(Form):
|
||||||
|
name = StringField(
|
||||||
|
"Account Name",
|
||||||
|
validators=[InputRequired()],
|
||||||
|
description="This is the account name visible to users",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AccountView(ModelView):
|
||||||
|
form = AccountForm
|
||||||
|
column_display_all_relations = True
|
||||||
|
column_list = ("name", "account_id")
|
||||||
|
|
||||||
|
def after_model_change(self, form, model, is_created):
|
||||||
|
if is_created:
|
||||||
try:
|
try:
|
||||||
jwt.decode(auth_cookie, public_key)
|
create_remote_account(model.name, model.account_id)
|
||||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
except RemoteServerException as server_error:
|
||||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
return jsonify({"ok": True})
|
raise server_error
|
||||||
|
|
||||||
|
|
||||||
# This route is not supposed to be called by client-side applications, so
|
class UserView(ModelView):
|
||||||
# no CORS configuration is added
|
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
|
||||||
@app.route("/api/key", methods=["GET"])
|
column_auto_select_related = True
|
||||||
def key():
|
column_display_all_relations = True
|
||||||
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
|
column_list = ("email", "user_id")
|
||||||
return jsonify({"key": public_key})
|
form_columns = ("email", "accounts")
|
||||||
|
form_create_rules = ("email", "password", "confirm", "accounts")
|
||||||
|
form_edit_rules = ("email", "password", "confirm", "accounts")
|
||||||
|
|
||||||
|
def on_model_change(self, form, model, is_created):
|
||||||
|
if form.password.data:
|
||||||
|
model.hashed_password = bcrypt.hash(form.password.data)
|
||||||
|
|
||||||
|
def get_create_form(self):
|
||||||
|
form = super(UserView, self).get_create_form()
|
||||||
|
form.password = PasswordField(
|
||||||
|
"Password",
|
||||||
|
validators=[
|
||||||
|
InputRequired(),
|
||||||
|
EqualTo("confirm", message="Passwords must match"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
form.confirm = PasswordField("Repeat Password", validators=[InputRequired()])
|
||||||
|
return form
|
||||||
|
|
||||||
|
def get_edit_form(self):
|
||||||
|
form = super(UserView, self).get_edit_form()
|
||||||
|
form.password = PasswordField(
|
||||||
|
"Password",
|
||||||
|
description="When left blank, the password will remain unchanged on update",
|
||||||
|
validators=[EqualTo("confirm", message="Passwords must match")],
|
||||||
|
)
|
||||||
|
form.confirm = PasswordField("Repeat Password", validators=[])
|
||||||
|
return form
|
||||||
|
60
accounts/authorizer/__init__.py
Normal file
60
accounts/authorizer/__init__.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import base64
|
||||||
|
from os import environ
|
||||||
|
|
||||||
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
|
|
||||||
|
def build_api_arn(method_arn):
|
||||||
|
arn_chunks = method_arn.split(":")
|
||||||
|
aws_region = arn_chunks[3]
|
||||||
|
aws_account_id = arn_chunks[4]
|
||||||
|
|
||||||
|
gateway_arn_chunks = arn_chunks[5].split("/")
|
||||||
|
rest_api_id = gateway_arn_chunks[0]
|
||||||
|
stage = gateway_arn_chunks[1]
|
||||||
|
|
||||||
|
return "arn:aws:execute-api:{}:{}:{}/{}/*/*".format(
|
||||||
|
aws_region, aws_account_id, rest_api_id, stage
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_response(api_arn, allow):
|
||||||
|
effect = "Deny"
|
||||||
|
if allow:
|
||||||
|
effect = "Allow"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"principalId": "offen",
|
||||||
|
"policyDocument": {
|
||||||
|
"Version": "2012-10-17",
|
||||||
|
"Statement": [
|
||||||
|
{
|
||||||
|
"Action": ["execute-api:Invoke"],
|
||||||
|
"Effect": effect,
|
||||||
|
"Resource": [api_arn],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def handler(event, context):
|
||||||
|
api_arn = build_api_arn(event["methodArn"])
|
||||||
|
|
||||||
|
encoded_auth = event["authorizationToken"].lstrip("Basic ")
|
||||||
|
auth_string = base64.standard_b64decode(encoded_auth).decode()
|
||||||
|
if not auth_string:
|
||||||
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
|
credentials = auth_string.split(":")
|
||||||
|
user = credentials[0]
|
||||||
|
password = credentials[1]
|
||||||
|
|
||||||
|
if user != environ.get("BASIC_AUTH_USER"):
|
||||||
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
|
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
|
||||||
|
if not bcrypt.verify(password, hashed_password):
|
||||||
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
|
return build_response(api_arn, True)
|
0
accounts/bootstrap.yml
Executable file
0
accounts/bootstrap.yml
Executable file
@ -1,2 +1,3 @@
|
|||||||
pytest
|
pytest
|
||||||
black
|
black
|
||||||
|
pyyaml
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
Flask==1.0.2
|
Flask==1.0.2
|
||||||
|
Flask-Admin==1.5.3
|
||||||
Flask-Cors==3.0.8
|
Flask-Cors==3.0.8
|
||||||
|
Flask-SQLAlchemy==2.4.0
|
||||||
werkzeug==0.15.4
|
werkzeug==0.15.4
|
||||||
pyjwt[crypto]==1.7.1
|
pyjwt[crypto]==1.7.1
|
||||||
passlib==1.7.1
|
passlib==1.7.1
|
||||||
bcrypt==3.1.7
|
bcrypt==3.1.7
|
||||||
|
PyMySQL==0.9.3
|
||||||
|
mysqlclient==1.4.2.post1
|
||||||
|
requests==2.22.0
|
||||||
|
0
accounts/scripts/__init__.py
Normal file
0
accounts/scripts/__init__.py
Normal file
32
accounts/scripts/bootstrap.py
Executable file
32
accounts/scripts/bootstrap.py
Executable file
@ -0,0 +1,32 @@
|
|||||||
|
import yaml
|
||||||
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
|
from accounts import db
|
||||||
|
from accounts.models import Account, User, AccountUserAssociation
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
db.drop_all()
|
||||||
|
db.create_all()
|
||||||
|
|
||||||
|
with open("./bootstrap.yml", "r") as stream:
|
||||||
|
data = yaml.safe_load(stream)
|
||||||
|
|
||||||
|
for account in data["accounts"]:
|
||||||
|
record = Account(
|
||||||
|
name=account["name"],
|
||||||
|
account_id=account["id"],
|
||||||
|
)
|
||||||
|
db.session.add(record)
|
||||||
|
|
||||||
|
for user in data["users"]:
|
||||||
|
record = User(
|
||||||
|
email=user["email"],
|
||||||
|
hashed_password=bcrypt.hash(user["password"]),
|
||||||
|
)
|
||||||
|
for account_id in user["accounts"]:
|
||||||
|
record.accounts.append(AccountUserAssociation(account_id=account_id))
|
||||||
|
db.session.add(record)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
print("Successfully bootstrapped accounts database")
|
20
accounts/scripts/hash.py
Normal file
20
accounts/scripts/hash.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import base64
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--password", type=str, help="The password to hash", required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plain",
|
||||||
|
help="Do not encode the result as base64",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
out = bcrypt.hash(args.password)
|
||||||
|
if not args.plain:
|
||||||
|
out = base64.standard_b64encode(out.encode()).decode()
|
||||||
|
print(out)
|
@ -26,6 +26,10 @@ custom:
|
|||||||
production: vault.offen.dev
|
production: vault.offen.dev
|
||||||
staging: vault-staging.offen.dev
|
staging: vault-staging.offen.dev
|
||||||
alpha: vault-alpha.offen.dev
|
alpha: vault-alpha.offen.dev
|
||||||
|
serverHost:
|
||||||
|
production: server.offen.dev
|
||||||
|
staging: server-staging.offen.dev
|
||||||
|
alpha: server-alpha.offen.dev
|
||||||
domain:
|
domain:
|
||||||
production: accounts.offen.dev
|
production: accounts.offen.dev
|
||||||
staging: accounts-staging.offen.dev
|
staging: accounts-staging.offen.dev
|
||||||
@ -50,19 +54,53 @@ custom:
|
|||||||
fileName: requirements.txt
|
fileName: requirements.txt
|
||||||
|
|
||||||
functions:
|
functions:
|
||||||
|
authorizer:
|
||||||
|
handler: authorizer.handler
|
||||||
|
environment:
|
||||||
|
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
||||||
|
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
||||||
app:
|
app:
|
||||||
handler: wsgi_handler.handler
|
handler: wsgi_handler.handler
|
||||||
events:
|
events:
|
||||||
|
- http:
|
||||||
|
path: /admin/
|
||||||
|
method: any
|
||||||
|
authorizer:
|
||||||
|
name: authorizer
|
||||||
|
resultTtlInSeconds: 0
|
||||||
|
identitySource: method.request.header.Authorization
|
||||||
|
- http:
|
||||||
|
path: /admin/{proxy+}
|
||||||
|
method: any
|
||||||
|
authorizer:
|
||||||
|
name: authorizer
|
||||||
|
resultTtlInSeconds: 0
|
||||||
|
identitySource: method.request.header.Authorization
|
||||||
- http:
|
- http:
|
||||||
path: '/'
|
path: '/'
|
||||||
method: any
|
method: any
|
||||||
- http:
|
- http:
|
||||||
path: '{proxy+}'
|
path: '/{proxy+}'
|
||||||
method: any
|
method: any
|
||||||
environment:
|
environment:
|
||||||
USER: offen
|
|
||||||
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
||||||
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
||||||
|
SERVER_URL: ${self:custom.serverHost.${self:custom.stage}}
|
||||||
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
|
JWT_PRIVATE_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPrivateKey~true}'
|
||||||
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
|
JWT_PUBLIC_KEY: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/jwtPublicKey~true}'
|
||||||
HASHED_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
BASIC_AUTH_USER: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/basicAuthUser~true}
|
||||||
|
HASHED_BASIC_AUTH_PASSWORD: ${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/hashedBasicAuthPassword~true}
|
||||||
|
SESSION_SECRET: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/sessionSecret~true}'
|
||||||
|
MYSQL_CONNECTION_STRING: '${ssm:/aws/reference/secretsmanager/${self:custom.stage}/accounts/mysqlConnectionString~true}'
|
||||||
|
|
||||||
|
resources:
|
||||||
|
Resources:
|
||||||
|
GatewayResponse:
|
||||||
|
Type: 'AWS::ApiGateway::GatewayResponse'
|
||||||
|
Properties:
|
||||||
|
ResponseParameters:
|
||||||
|
gatewayresponse.header.WWW-Authenticate: "'Basic'"
|
||||||
|
ResponseType: UNAUTHORIZED
|
||||||
|
RestApiId:
|
||||||
|
Ref: 'ApiGatewayRestApi'
|
||||||
|
StatusCode: '401'
|
||||||
|
207
accounts/tests/test_api.py
Normal file
207
accounts/tests/test_api.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
import unittest
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from json import loads
|
||||||
|
from time import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from os import environ
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from accounts import app
|
||||||
|
|
||||||
|
|
||||||
|
FOREIGN_PRIVATE_KEY = """
|
||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
|
||||||
|
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
|
||||||
|
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
|
||||||
|
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
|
||||||
|
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
|
||||||
|
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
|
||||||
|
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
|
||||||
|
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
|
||||||
|
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
|
||||||
|
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
|
||||||
|
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
|
||||||
|
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
|
||||||
|
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
|
||||||
|
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
|
||||||
|
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
|
||||||
|
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
|
||||||
|
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
|
||||||
|
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
|
||||||
|
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
|
||||||
|
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
|
||||||
|
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
|
||||||
|
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
|
||||||
|
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
|
||||||
|
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
|
||||||
|
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
|
||||||
|
15NFZ2gU4rKnQ3sVAOzMoEw5
|
||||||
|
-----END PRIVATE KEY-----
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _pad_b64_string(s):
|
||||||
|
while len(s) % 4 is not 0:
|
||||||
|
s = s + "="
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class TestKey(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.app = app.test_client()
|
||||||
|
|
||||||
|
def test_get_key(self):
|
||||||
|
rv = self.app.get("/api/key")
|
||||||
|
assert rv.status.startswith("200")
|
||||||
|
data = loads(rv.data)
|
||||||
|
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWT(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.app = app.test_client()
|
||||||
|
|
||||||
|
def _assert_cookie_present(self, name):
|
||||||
|
for cookie in self.app.cookie_jar:
|
||||||
|
if cookie.name == name:
|
||||||
|
return cookie.value
|
||||||
|
raise AssertionError("Cookie named {} not found".format(name))
|
||||||
|
|
||||||
|
def _assert_cookie_not_present(self, name):
|
||||||
|
for cookie in self.app.cookie_jar:
|
||||||
|
assert cookie.name != name
|
||||||
|
|
||||||
|
def test_jwt_flow(self):
|
||||||
|
"""
|
||||||
|
First, try login attempts that are supposed to fail:
|
||||||
|
1. checking login status without any prior interaction
|
||||||
|
2. try logging in with an unknown user
|
||||||
|
3. try logging in with a known user and bad password
|
||||||
|
"""
|
||||||
|
rv = self.app.get("/api/login")
|
||||||
|
assert rv.status.startswith("401")
|
||||||
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
rv = self.app.post(
|
||||||
|
"/api/login",
|
||||||
|
data=json.dumps(
|
||||||
|
{"username": "does@not.exist", "password": "somethingsomething"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert rv.status.startswith("401")
|
||||||
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
rv = self.app.post(
|
||||||
|
"/api/login",
|
||||||
|
data=json.dumps({"username": "develop@offen.dev", "password": "developp"}),
|
||||||
|
)
|
||||||
|
assert rv.status.startswith("401")
|
||||||
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Next, perform a successful login
|
||||||
|
"""
|
||||||
|
rv = self.app.post(
|
||||||
|
"/api/login",
|
||||||
|
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
|
||||||
|
)
|
||||||
|
assert rv.status.startswith("200")
|
||||||
|
|
||||||
|
"""
|
||||||
|
The response should contain information about the
|
||||||
|
user and full information (i.e. a name) about the associated accounts
|
||||||
|
"""
|
||||||
|
data = json.loads(rv.data)
|
||||||
|
assert data["user"]["userId"] is not None
|
||||||
|
data["user"]["accounts"].sort(key=lambda a: a["name"])
|
||||||
|
self.assertListEqual(
|
||||||
|
data["user"]["accounts"],
|
||||||
|
[
|
||||||
|
{"name": "One", "accountId": "9b63c4d8-65c0-438c-9d30-cc4b01173393"},
|
||||||
|
{"name": "Two", "accountId": "78403940-ae4f-4aff-a395-1e90f145cf62"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
The claims part of the JWT is expected to contain a valid expiry,
|
||||||
|
information about the user and the associated account ids.
|
||||||
|
"""
|
||||||
|
jwt = self._assert_cookie_present("auth")
|
||||||
|
# PyJWT strips the padding from the base64 encoded parts which Python
|
||||||
|
# cannot decode properly, so we need to add the padding ourselves
|
||||||
|
claims_part = _pad_b64_string(jwt.split(".")[1])
|
||||||
|
claims = loads(base64.b64decode(claims_part))
|
||||||
|
assert claims.get("exp") > time()
|
||||||
|
|
||||||
|
priv = claims.get("priv")
|
||||||
|
assert priv is not None
|
||||||
|
|
||||||
|
assert priv.get("userId") is not None
|
||||||
|
self.assertListEqual(
|
||||||
|
priv["accounts"],
|
||||||
|
[
|
||||||
|
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
|
||||||
|
"78403940-ae4f-4aff-a395-1e90f145cf62",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Checking the login status when re-using the cookie should yield
|
||||||
|
a successful response
|
||||||
|
"""
|
||||||
|
rv = self.app.get("/api/login")
|
||||||
|
assert rv.status.startswith("200")
|
||||||
|
jwt2 = self._assert_cookie_present("auth")
|
||||||
|
assert jwt2 == jwt
|
||||||
|
|
||||||
|
"""
|
||||||
|
Performing a bad login attempt when sending a valid auth cookie
|
||||||
|
is expected to destroy the cookie and leave the user logged out again
|
||||||
|
"""
|
||||||
|
rv = self.app.post(
|
||||||
|
"/api/login",
|
||||||
|
data=json.dumps(
|
||||||
|
{"username": "evil@session.takeover", "password": "develop"}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert rv.status.startswith("401")
|
||||||
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Explicitly logging out leaves the user without cookies
|
||||||
|
"""
|
||||||
|
rv = self.app.post(
|
||||||
|
"/api/login",
|
||||||
|
data=json.dumps({"username": "develop@offen.dev", "password": "develop"}),
|
||||||
|
)
|
||||||
|
assert rv.status.startswith("200")
|
||||||
|
|
||||||
|
rv = self.app.post("/api/logout")
|
||||||
|
assert rv.status.startswith("204")
|
||||||
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
def test_forged_token(self):
|
||||||
|
"""
|
||||||
|
The application needs to verify that tokens that would be theoretically
|
||||||
|
valid are not signed using an unknown key.
|
||||||
|
"""
|
||||||
|
forged_token = jwt.encode(
|
||||||
|
{
|
||||||
|
"exp": datetime.utcnow() + timedelta(hours=24),
|
||||||
|
"priv": {
|
||||||
|
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
|
||||||
|
"accounts": [
|
||||||
|
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
|
||||||
|
"78403940-ae4f-4aff-a395-1e90f145cf62",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FOREIGN_PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
self.app.set_cookie("localhost", "auth", forged_token)
|
||||||
|
rv = self.app.get("/api/login")
|
||||||
|
assert rv.status.startswith("401")
|
@ -1,21 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import json
|
|
||||||
|
|
||||||
from accounts import app
|
|
||||||
|
|
||||||
|
|
||||||
class TestJWT(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.app = app.test_client()
|
|
||||||
|
|
||||||
def test_jwt_flow(self):
|
|
||||||
rv = self.app.get("/api/login")
|
|
||||||
assert rv.status.startswith("401")
|
|
||||||
|
|
||||||
rv = self.app.post(
|
|
||||||
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
|
|
||||||
)
|
|
||||||
assert rv.status.startswith("200")
|
|
||||||
|
|
||||||
rv = self.app.get("/api/login")
|
|
||||||
assert rv.status.startswith("200")
|
|
@ -24,6 +24,14 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_PASSWORD: develop
|
POSTGRES_PASSWORD: develop
|
||||||
|
|
||||||
|
accounts_database:
|
||||||
|
image: mysql:5.7
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
environment:
|
||||||
|
MYSQL_DATABASE: mysql
|
||||||
|
MYSQL_ROOT_PASSWORD: develop
|
||||||
|
|
||||||
server:
|
server:
|
||||||
build:
|
build:
|
||||||
context: '.'
|
context: '.'
|
||||||
@ -31,6 +39,7 @@ services:
|
|||||||
working_dir: /offen/server
|
working_dir: /offen/server
|
||||||
volumes:
|
volumes:
|
||||||
- .:/offen
|
- .:/offen
|
||||||
|
- ./bootstrap.yml:/offen/server/bootstrap.yml
|
||||||
- serverdeps:/go/pkg/mod
|
- serverdeps:/go/pkg/mod
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable
|
POSTGRES_CONNECTION_STRING: postgres://postgres:develop@server_database:5432/postgres?sslmode=disable
|
||||||
@ -100,16 +109,20 @@ services:
|
|||||||
working_dir: /offen/accounts
|
working_dir: /offen/accounts
|
||||||
volumes:
|
volumes:
|
||||||
- .:/offen
|
- .:/offen
|
||||||
|
- ./bootstrap.yml:/offen/accounts/bootstrap.yml
|
||||||
- accountdeps:/root/.local
|
- accountdeps:/root/.local
|
||||||
command: flask run --host 0.0.0.0
|
command: flask run --host 0.0.0.0
|
||||||
ports:
|
ports:
|
||||||
- 5000:5000
|
- 5000:5000
|
||||||
|
links:
|
||||||
|
- accounts_database
|
||||||
environment:
|
environment:
|
||||||
FLASK_APP: accounts
|
FLASK_APP: accounts:app
|
||||||
FLASK_ENV: development
|
FLASK_ENV: development
|
||||||
|
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
|
||||||
CORS_ORIGIN: http://localhost:9977
|
CORS_ORIGIN: http://localhost:9977
|
||||||
# local password is `develop`
|
SERVER_HOST: http://server:8080
|
||||||
HASHED_PASSWORD: JDJhJDEwJGpFRXJMOVVSQndZQlFQNjkxallkZi53aGp1cDMvRW5maGUvakZleG1pWFlnWEVXcU93ODBp
|
SESSION_SECRET: vndJRFJTiyjfgtTF
|
||||||
JWT_PRIVATE_KEY: |-
|
JWT_PRIVATE_KEY: |-
|
||||||
-----BEGIN PRIVATE KEY-----
|
-----BEGIN PRIVATE KEY-----
|
||||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS
|
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCa6AEl0RUW43YS
|
||||||
|
@ -23,12 +23,19 @@ const ClaimsContextKey contextKey = "claims"
|
|||||||
// JWTProtect uses the public key located at the given URL to check if the
|
// JWTProtect uses the public key located at the given URL to check if the
|
||||||
// cookie value is signed properly. In case yes, the JWT claims will be added
|
// cookie value is signed properly. In case yes, the JWT claims will be added
|
||||||
// to the request context
|
// to the request context
|
||||||
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
authCookie, err := r.Cookie(cookieName)
|
var jwtValue string
|
||||||
if err != nil {
|
if authCookie, err := r.Cookie(cookieName); err == nil {
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error reading cookie: %s", err), http.StatusForbidden)
|
jwtValue = authCookie.Value
|
||||||
|
} else {
|
||||||
|
if header := r.Header.Get(headerName); header != "" {
|
||||||
|
jwtValue = header
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if jwtValue == "" {
|
||||||
|
RespondWithJSONError(w, errors.New("jwt: could not infer JWT value from cookie or header"), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,7 +63,7 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, jwtErr := jwt.ParseVerify(strings.NewReader(authCookie.Value), jwa.RS256, pubKey)
|
token, jwtErr := jwt.ParseVerify(strings.NewReader(jwtValue), jwa.RS256, pubKey)
|
||||||
if jwtErr != nil {
|
if jwtErr != nil {
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
|
RespondWithJSONError(w, fmt.Errorf("jwt: error parsing token: %v", jwtErr), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
@ -67,9 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
privateClaims, _ := token.Get("priv")
|
privKey, _ := token.Get("priv")
|
||||||
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
|
claims, _ := privKey.(map[string]interface{})
|
||||||
|
if err := authorizer(r, claims); err != nil {
|
||||||
|
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cookie *http.Cookie
|
cookie *http.Cookie
|
||||||
|
headers *http.Header
|
||||||
server *httptest.Server
|
server *httptest.Server
|
||||||
|
authorizer func(r *http.Request, claims map[string]interface{}) error
|
||||||
expectedStatusCode int
|
expectedStatusCode int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"no cookie",
|
"no cookie",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("here's some bytes 4 y'all"))
|
w.Write([]byte("here's some bytes 4 y'all"))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"not really a key"}`))
|
w.Write([]byte(`{"key":"not really a key"}`))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusBadGateway,
|
http.StatusBadGateway,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
return string(b)
|
return string(b)
|
||||||
})(),
|
})(),
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"ok token in headers",
|
||||||
|
nil,
|
||||||
|
(func() *http.Header {
|
||||||
|
token := jwt.New()
|
||||||
|
token.Set("exp", time.Now().Add(time.Hour))
|
||||||
|
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||||
|
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||||
|
b, _ := token.Sign(jwa.RS256, privKey)
|
||||||
|
return &http.Header{
|
||||||
|
"X-RPC-Authentication": []string{string(b)},
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"bad token in headers",
|
||||||
|
nil,
|
||||||
|
(func() *http.Header {
|
||||||
|
return &http.Header{
|
||||||
|
"X-RPC-Authentication": []string{"nilly willy"},
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"authorizer rejects",
|
||||||
|
&http.Cookie{
|
||||||
|
Name: "auth",
|
||||||
|
Value: (func() string {
|
||||||
|
token := jwt.New()
|
||||||
|
token.Set("exp", time.Now().Add(time.Hour))
|
||||||
|
token.Set("priv", map[string]interface{}{"ok": false})
|
||||||
|
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||||
|
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||||
|
b, _ := token.Sign(jwa.RS256, privKey)
|
||||||
|
return string(b)
|
||||||
|
})(),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error {
|
||||||
|
if claims["ok"] == true {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("expected ok to be true")
|
||||||
|
},
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"valid key, expired token",
|
"valid key, expired token",
|
||||||
&http.Cookie{
|
&http.Cookie{
|
||||||
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
return string(b)
|
return string(b)
|
||||||
})(),
|
})(),
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -174,7 +258,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.server != nil {
|
if test.server != nil {
|
||||||
url = test.server.URL
|
url = test.server.URL
|
||||||
}
|
}
|
||||||
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}))
|
}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -182,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.cookie != nil {
|
if test.cookie != nil {
|
||||||
r.AddCookie(test.cookie)
|
r.AddCookie(test.cookie)
|
||||||
}
|
}
|
||||||
|
if test.headers != nil {
|
||||||
|
for key, value := range *test.headers {
|
||||||
|
r.Header.Add(key, value[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
wrappedHandler.ServeHTTP(w, r)
|
wrappedHandler.ServeHTTP(w, r)
|
||||||
if w.Code != test.expectedStatusCode {
|
if w.Code != test.expectedStatusCode {
|
||||||
t.Errorf("Unexpected status code %v", w.Code)
|
t.Errorf("Unexpected status code %v", w.Code)
|
||||||
|
Loading…
Reference in New Issue
Block a user