diff --git a/.circleci/config.yml b/.circleci/config.yml index 618e102..3c9bf24 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -198,6 +198,7 @@ jobs: docker: - image: circleci/python:3.6 environment: + CONFIG_CLASS: accounts.config.EnvConfig MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle JWT_PRIVATE_KEY: |- -----BEGIN PRIVATE KEY----- diff --git a/accounts/accounts/__init__.py b/accounts/accounts/__init__.py index ebe3981..dc44382 100644 --- a/accounts/accounts/__init__.py +++ b/accounts/accounts/__init__.py @@ -4,18 +4,20 @@ from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_admin import Admin +from accounts.config import EnvConfig + app = Flask(__name__) -app.secret_key = environ.get("SESSION_SECRET") -app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING") +app.config.from_object(environ.get("CONFIG_CLASS")) + db = SQLAlchemy(app) from accounts.models import Account, User from accounts.views import AccountView, UserView import accounts.api -app.config["FLASK_ADMIN_SWATCH"] = "flatly" - -admin = Admin(app, name="offen admin", template_mode="bootstrap3", base_template="index.html") +admin = Admin( + app, name="offen admin", template_mode="bootstrap3", base_template="index.html" +) admin.add_view(AccountView(Account, db.session)) admin.add_view(UserView(User, db.session)) diff --git a/accounts/accounts/api.py b/accounts/accounts/api.py index 3fe5290..ad6b97f 100644 --- a/accounts/accounts/api.py +++ b/accounts/accounts/api.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from os import environ from functools import wraps from flask import jsonify, make_response, request @@ -37,7 +36,7 @@ class UnauthorizedError(Exception): @app.route("/api/login", methods=["POST"]) -@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) +@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True) @json_error def post_login(): credentials = request.get_json(force=True) @@ -53,7 +52,6 @@ def post_login(): resp.status_code = 401 return resp - private_key = environ.get("JWT_PRIVATE_KEY", "") expiry = datetime.utcnow() + timedelta(hours=24) encoded = jwt.encode( { @@ -63,7 +61,7 @@ def post_login(): "accounts": [a.account_id for a in match.accounts], }, }, - private_key.encode(), + app.config["JWT_PRIVATE_KEY"].encode(), algorithm="RS256", ).decode() @@ -74,20 +72,19 @@ def post_login(): httponly=True, expires=expiry, path="/", - domain=environ.get("COOKIE_DOMAIN"), + domain=app.config["COOKIE_DOMAIN"], samesite="strict", ) return resp @app.route("/api/login", methods=["GET"]) -@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) +@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True) @json_error def get_login(): auth_cookie = request.cookies.get(COOKIE_KEY) - public_key = environ.get("JWT_PUBLIC_KEY", "") try: - token = jwt.decode(auth_cookie, public_key) + token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"]) except jwt.exceptions.PyJWTError as unauthorized_error: return jsonify({"error": str(unauthorized_error), "status": 401}), 401 @@ -104,7 +101,7 @@ def get_login(): @app.route("/api/logout", methods=["POST"]) -@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) +@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True) @json_error def post_logout(): resp = make_response("") @@ -120,5 +117,5 @@ def key(): This route is not supposed to be called by client-side applications, so no CORS configuration is added """ - public_key = environ.get("JWT_PUBLIC_KEY", "").strip() + public_key = app.config["JWT_PUBLIC_KEY"].strip() return jsonify({"key": public_key}) diff --git a/accounts/accounts/config.py b/accounts/accounts/config.py new file mode 100644 index 0000000..3255e21 --- /dev/null +++ b/accounts/accounts/config.py @@ -0,0 +1,14 @@ +from os import environ + +class BaseConfig(object): + SQLALCHEMY_TRACK_MODIFICATIONS = False + FLASK_ADMIN_SWATCH = "flatly" + +class EnvConfig(BaseConfig): + SECRET_KEY = environ.get("SESSION_SECRET") + SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING") + CORS_ORIGIN = environ.get("CORS_ORIGIN", "*") + JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "") + JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "") + COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN") + SERVER_HOST = environ.get("SERVER_HOST") diff --git a/accounts/accounts/views.py b/accounts/accounts/views.py index 5984958..dfd930a 100644 --- a/accounts/accounts/views.py +++ b/accounts/accounts/views.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from os import environ import requests from flask_admin.contrib.sqla import ModelView @@ -8,7 +7,7 @@ from wtforms.validators import InputRequired, EqualTo from passlib.hash import bcrypt import jwt -from accounts import db +from accounts import db, app from accounts.models import AccountUserAssociation @@ -22,18 +21,17 @@ class RemoteServerException(Exception): def create_remote_account(name, account_id): - private_key = environ.get("JWT_PRIVATE_KEY", "") # expires in 30 seconds as this will mean the HTTP request would have # timed out anyways expiry = datetime.utcnow() + timedelta(seconds=30) encoded = jwt.encode( {"ok": True, "exp": expiry, "priv": {"rpc": "1"}}, - private_key.encode(), + app.config["JWT_PRIVATE_KEY"].encode(), algorithm="RS256", ).decode("utf-8") r = requests.post( - "{}/accounts".format(environ.get("SERVER_HOST")), + "{}/accounts".format(app.config["SERVER_HOST"]), json={"name": name, "accountId": account_id}, headers={"X-RPC-Authentication": encoded}, ) diff --git a/accounts/serverless.yml b/accounts/serverless.yml index 81a977f..16026ed 100644 --- a/accounts/serverless.yml +++ b/accounts/serverless.yml @@ -84,6 +84,7 @@ functions: path: '/{proxy+}' method: any environment: + CONFIG_CLASS: accounts.config.EnvConfig CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}} COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}} SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}} diff --git a/accounts/tests/test_api.py b/accounts/tests/test_api.py index 79f7b19..14ab44b 100644 --- a/accounts/tests/test_api.py +++ b/accounts/tests/test_api.py @@ -1,7 +1,6 @@ import unittest import json import base64 -from json import loads from time import time from datetime import datetime, timedelta from os import environ @@ -55,7 +54,7 @@ class TestKey(unittest.TestCase): def test_get_key(self): rv = self.app.get("/api/key") assert rv.status.startswith("200") - data = loads(rv.data) + data = json.loads(rv.data) assert data["key"] == environ.get("JWT_PUBLIC_KEY") @@ -132,7 +131,7 @@ class TestJWT(unittest.TestCase): # PyJWT strips the padding from the base64 encoded parts which Python # cannot decode properly, so we need to add the padding ourselves claims_part = _pad_b64_string(jwt.split(".")[1]) - claims = loads(base64.b64decode(claims_part)) + claims = json.loads(base64.b64decode(claims_part)) assert claims.get("exp") > time() priv = claims.get("priv") diff --git a/docker-compose.yml b/docker-compose.yml index 9da979a..5198a04 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -117,6 +117,7 @@ services: links: - accounts_database environment: + CONFIG_CLASS: accounts.config.EnvConfig FLASK_APP: accounts:app FLASK_ENV: development MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql