2
0
mirror of https://github.com/offen/website.git synced 2024-10-18 12:10:25 +02:00

enable flask app to use different configuration sources when run in lambda

This commit is contained in:
Frederik Ring 2019-07-19 09:59:24 +02:00
parent 561f89b9ff
commit 92ae60149f
8 changed files with 36 additions and 23 deletions

View File

@ -198,6 +198,7 @@ jobs:
docker:
- image: circleci/python:3.6
environment:
CONFIG_CLASS: accounts.config.EnvConfig
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
JWT_PRIVATE_KEY: |-
-----BEGIN PRIVATE KEY-----

View File

@ -4,18 +4,20 @@ from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_admin import Admin
from accounts.config import EnvConfig
app = Flask(__name__)
app.secret_key = environ.get("SESSION_SECRET")
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
app.config.from_object(environ.get("CONFIG_CLASS"))
db = SQLAlchemy(app)
from accounts.models import Account, User
from accounts.views import AccountView, UserView
import accounts.api
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
admin = Admin(app, name="offen admin", template_mode="bootstrap3", base_template="index.html")
admin = Admin(
app, name="offen admin", template_mode="bootstrap3", base_template="index.html"
)
admin.add_view(AccountView(Account, db.session))
admin.add_view(UserView(User, db.session))

View File

@ -1,5 +1,4 @@
from datetime import datetime, timedelta
from os import environ
from functools import wraps
from flask import jsonify, make_response, request
@ -37,7 +36,7 @@ class UnauthorizedError(Exception):
@app.route("/api/login", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
@json_error
def post_login():
credentials = request.get_json(force=True)
@ -53,7 +52,6 @@ def post_login():
resp.status_code = 401
return resp
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
encoded = jwt.encode(
{
@ -63,7 +61,7 @@ def post_login():
"accounts": [a.account_id for a in match.accounts],
},
},
private_key.encode(),
app.config["JWT_PRIVATE_KEY"].encode(),
algorithm="RS256",
).decode()
@ -74,20 +72,19 @@ def post_login():
httponly=True,
expires=expiry,
path="/",
domain=environ.get("COOKIE_DOMAIN"),
domain=app.config["COOKIE_DOMAIN"],
samesite="strict",
)
return resp
@app.route("/api/login", methods=["GET"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
@json_error
def get_login():
auth_cookie = request.cookies.get(COOKIE_KEY)
public_key = environ.get("JWT_PUBLIC_KEY", "")
try:
token = jwt.decode(auth_cookie, public_key)
token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"])
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
@ -104,7 +101,7 @@ def get_login():
@app.route("/api/logout", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
@json_error
def post_logout():
resp = make_response("")
@ -120,5 +117,5 @@ def key():
This route is not supposed to be called by client-side applications, so
no CORS configuration is added
"""
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
public_key = app.config["JWT_PUBLIC_KEY"].strip()
return jsonify({"key": public_key})

View File

@ -0,0 +1,14 @@
from os import environ
class BaseConfig(object):
SQLALCHEMY_TRACK_MODIFICATIONS = False
FLASK_ADMIN_SWATCH = "flatly"
class EnvConfig(BaseConfig):
SECRET_KEY = environ.get("SESSION_SECRET")
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "")
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
SERVER_HOST = environ.get("SERVER_HOST")

View File

@ -1,5 +1,4 @@
from datetime import datetime, timedelta
from os import environ
import requests
from flask_admin.contrib.sqla import ModelView
@ -8,7 +7,7 @@ from wtforms.validators import InputRequired, EqualTo
from passlib.hash import bcrypt
import jwt
from accounts import db
from accounts import db, app
from accounts.models import AccountUserAssociation
@ -22,18 +21,17 @@ class RemoteServerException(Exception):
def create_remote_account(name, account_id):
private_key = environ.get("JWT_PRIVATE_KEY", "")
# expires in 30 seconds as this will mean the HTTP request would have
# timed out anyways
expiry = datetime.utcnow() + timedelta(seconds=30)
encoded = jwt.encode(
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
private_key.encode(),
app.config["JWT_PRIVATE_KEY"].encode(),
algorithm="RS256",
).decode("utf-8")
r = requests.post(
"{}/accounts".format(environ.get("SERVER_HOST")),
"{}/accounts".format(app.config["SERVER_HOST"]),
json={"name": name, "accountId": account_id},
headers={"X-RPC-Authentication": encoded},
)

View File

@ -84,6 +84,7 @@ functions:
path: '/{proxy+}'
method: any
environment:
CONFIG_CLASS: accounts.config.EnvConfig
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}

View File

@ -1,7 +1,6 @@
import unittest
import json
import base64
from json import loads
from time import time
from datetime import datetime, timedelta
from os import environ
@ -55,7 +54,7 @@ class TestKey(unittest.TestCase):
def test_get_key(self):
rv = self.app.get("/api/key")
assert rv.status.startswith("200")
data = loads(rv.data)
data = json.loads(rv.data)
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
@ -132,7 +131,7 @@ class TestJWT(unittest.TestCase):
# PyJWT strips the padding from the base64 encoded parts which Python
# cannot decode properly, so we need to add the padding ourselves
claims_part = _pad_b64_string(jwt.split(".")[1])
claims = loads(base64.b64decode(claims_part))
claims = json.loads(base64.b64decode(claims_part))
assert claims.get("exp") > time()
priv = claims.get("priv")

View File

@ -117,6 +117,7 @@ services:
links:
- accounts_database
environment:
CONFIG_CLASS: accounts.config.EnvConfig
FLASK_APP: accounts:app
FLASK_ENV: development
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql