mirror of
https://github.com/offen/website.git
synced 2024-11-22 09:00:28 +01:00
enable flask app to use different configuration sources when run in lambda
This commit is contained in:
parent
561f89b9ff
commit
92ae60149f
@ -198,6 +198,7 @@ jobs:
|
|||||||
docker:
|
docker:
|
||||||
- image: circleci/python:3.6
|
- image: circleci/python:3.6
|
||||||
environment:
|
environment:
|
||||||
|
CONFIG_CLASS: accounts.config.EnvConfig
|
||||||
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
|
MYSQL_CONNECTION_STRING: mysql://root:circle@127.0.0.1:3306/circle
|
||||||
JWT_PRIVATE_KEY: |-
|
JWT_PRIVATE_KEY: |-
|
||||||
-----BEGIN PRIVATE KEY-----
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
@ -4,18 +4,20 @@ from flask import Flask
|
|||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from flask_admin import Admin
|
from flask_admin import Admin
|
||||||
|
|
||||||
|
from accounts.config import EnvConfig
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.secret_key = environ.get("SESSION_SECRET")
|
app.config.from_object(environ.get("CONFIG_CLASS"))
|
||||||
app.config["SQLALCHEMY_DATABASE_URI"] = environ.get("MYSQL_CONNECTION_STRING")
|
|
||||||
db = SQLAlchemy(app)
|
db = SQLAlchemy(app)
|
||||||
|
|
||||||
from accounts.models import Account, User
|
from accounts.models import Account, User
|
||||||
from accounts.views import AccountView, UserView
|
from accounts.views import AccountView, UserView
|
||||||
import accounts.api
|
import accounts.api
|
||||||
|
|
||||||
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
|
admin = Admin(
|
||||||
|
app, name="offen admin", template_mode="bootstrap3", base_template="index.html"
|
||||||
admin = Admin(app, name="offen admin", template_mode="bootstrap3", base_template="index.html")
|
)
|
||||||
|
|
||||||
admin.add_view(AccountView(Account, db.session))
|
admin.add_view(AccountView(Account, db.session))
|
||||||
admin.add_view(UserView(User, db.session))
|
admin.add_view(UserView(User, db.session))
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from os import environ
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask import jsonify, make_response, request
|
from flask import jsonify, make_response, request
|
||||||
@ -37,7 +36,7 @@ class UnauthorizedError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
@app.route("/api/login", methods=["POST"])
|
@app.route("/api/login", methods=["POST"])
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
|
||||||
@json_error
|
@json_error
|
||||||
def post_login():
|
def post_login():
|
||||||
credentials = request.get_json(force=True)
|
credentials = request.get_json(force=True)
|
||||||
@ -53,7 +52,6 @@ def post_login():
|
|||||||
resp.status_code = 401
|
resp.status_code = 401
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
|
||||||
expiry = datetime.utcnow() + timedelta(hours=24)
|
expiry = datetime.utcnow() + timedelta(hours=24)
|
||||||
encoded = jwt.encode(
|
encoded = jwt.encode(
|
||||||
{
|
{
|
||||||
@ -63,7 +61,7 @@ def post_login():
|
|||||||
"accounts": [a.account_id for a in match.accounts],
|
"accounts": [a.account_id for a in match.accounts],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
private_key.encode(),
|
app.config["JWT_PRIVATE_KEY"].encode(),
|
||||||
algorithm="RS256",
|
algorithm="RS256",
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
@ -74,20 +72,19 @@ def post_login():
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
expires=expiry,
|
expires=expiry,
|
||||||
path="/",
|
path="/",
|
||||||
domain=environ.get("COOKIE_DOMAIN"),
|
domain=app.config["COOKIE_DOMAIN"],
|
||||||
samesite="strict",
|
samesite="strict",
|
||||||
)
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/login", methods=["GET"])
|
@app.route("/api/login", methods=["GET"])
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
|
||||||
@json_error
|
@json_error
|
||||||
def get_login():
|
def get_login():
|
||||||
auth_cookie = request.cookies.get(COOKIE_KEY)
|
auth_cookie = request.cookies.get(COOKIE_KEY)
|
||||||
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
|
||||||
try:
|
try:
|
||||||
token = jwt.decode(auth_cookie, public_key)
|
token = jwt.decode(auth_cookie, app.config["JWT_PUBLIC_KEY"])
|
||||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
except jwt.exceptions.PyJWTError as unauthorized_error:
|
||||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
||||||
|
|
||||||
@ -104,7 +101,7 @@ def get_login():
|
|||||||
|
|
||||||
|
|
||||||
@app.route("/api/logout", methods=["POST"])
|
@app.route("/api/logout", methods=["POST"])
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
@cross_origin(origins=[app.config["CORS_ORIGIN"]], supports_credentials=True)
|
||||||
@json_error
|
@json_error
|
||||||
def post_logout():
|
def post_logout():
|
||||||
resp = make_response("")
|
resp = make_response("")
|
||||||
@ -120,5 +117,5 @@ def key():
|
|||||||
This route is not supposed to be called by client-side applications, so
|
This route is not supposed to be called by client-side applications, so
|
||||||
no CORS configuration is added
|
no CORS configuration is added
|
||||||
"""
|
"""
|
||||||
public_key = environ.get("JWT_PUBLIC_KEY", "").strip()
|
public_key = app.config["JWT_PUBLIC_KEY"].strip()
|
||||||
return jsonify({"key": public_key})
|
return jsonify({"key": public_key})
|
||||||
|
14
accounts/accounts/config.py
Normal file
14
accounts/accounts/config.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from os import environ
|
||||||
|
|
||||||
|
class BaseConfig(object):
|
||||||
|
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||||
|
FLASK_ADMIN_SWATCH = "flatly"
|
||||||
|
|
||||||
|
class EnvConfig(BaseConfig):
|
||||||
|
SECRET_KEY = environ.get("SESSION_SECRET")
|
||||||
|
SQLALCHEMY_DATABASE_URI = environ.get("MYSQL_CONNECTION_STRING")
|
||||||
|
CORS_ORIGIN = environ.get("CORS_ORIGIN", "*")
|
||||||
|
JWT_PRIVATE_KEY = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
|
JWT_PUBLIC_KEY = environ.get("JWT_PUBLIC_KEY", "")
|
||||||
|
COOKIE_DOMAIN = environ.get("COOKIE_DOMAIN")
|
||||||
|
SERVER_HOST = environ.get("SERVER_HOST")
|
@ -1,5 +1,4 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from os import environ
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask_admin.contrib.sqla import ModelView
|
from flask_admin.contrib.sqla import ModelView
|
||||||
@ -8,7 +7,7 @@ from wtforms.validators import InputRequired, EqualTo
|
|||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from accounts import db
|
from accounts import db, app
|
||||||
from accounts.models import AccountUserAssociation
|
from accounts.models import AccountUserAssociation
|
||||||
|
|
||||||
|
|
||||||
@ -22,18 +21,17 @@ class RemoteServerException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def create_remote_account(name, account_id):
|
def create_remote_account(name, account_id):
|
||||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
|
||||||
# expires in 30 seconds as this will mean the HTTP request would have
|
# expires in 30 seconds as this will mean the HTTP request would have
|
||||||
# timed out anyways
|
# timed out anyways
|
||||||
expiry = datetime.utcnow() + timedelta(seconds=30)
|
expiry = datetime.utcnow() + timedelta(seconds=30)
|
||||||
encoded = jwt.encode(
|
encoded = jwt.encode(
|
||||||
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
||||||
private_key.encode(),
|
app.config["JWT_PRIVATE_KEY"].encode(),
|
||||||
algorithm="RS256",
|
algorithm="RS256",
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
"{}/accounts".format(environ.get("SERVER_HOST")),
|
"{}/accounts".format(app.config["SERVER_HOST"]),
|
||||||
json={"name": name, "accountId": account_id},
|
json={"name": name, "accountId": account_id},
|
||||||
headers={"X-RPC-Authentication": encoded},
|
headers={"X-RPC-Authentication": encoded},
|
||||||
)
|
)
|
||||||
|
@ -84,6 +84,7 @@ functions:
|
|||||||
path: '/{proxy+}'
|
path: '/{proxy+}'
|
||||||
method: any
|
method: any
|
||||||
environment:
|
environment:
|
||||||
|
CONFIG_CLASS: accounts.config.EnvConfig
|
||||||
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
CORS_ORIGIN: https://${self:custom.origin.${self:custom.stage}}
|
||||||
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
COOKIE_DOMAIN: ${self:custom.cookieDomain.${self:custom.stage}}
|
||||||
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
|
SERVER_HOST: ${self:custom.serverHost.${self:custom.stage}}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
from json import loads
|
|
||||||
from time import time
|
from time import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from os import environ
|
from os import environ
|
||||||
@ -55,7 +54,7 @@ class TestKey(unittest.TestCase):
|
|||||||
def test_get_key(self):
|
def test_get_key(self):
|
||||||
rv = self.app.get("/api/key")
|
rv = self.app.get("/api/key")
|
||||||
assert rv.status.startswith("200")
|
assert rv.status.startswith("200")
|
||||||
data = loads(rv.data)
|
data = json.loads(rv.data)
|
||||||
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
|
assert data["key"] == environ.get("JWT_PUBLIC_KEY")
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +131,7 @@ class TestJWT(unittest.TestCase):
|
|||||||
# PyJWT strips the padding from the base64 encoded parts which Python
|
# PyJWT strips the padding from the base64 encoded parts which Python
|
||||||
# cannot decode properly, so we need to add the padding ourselves
|
# cannot decode properly, so we need to add the padding ourselves
|
||||||
claims_part = _pad_b64_string(jwt.split(".")[1])
|
claims_part = _pad_b64_string(jwt.split(".")[1])
|
||||||
claims = loads(base64.b64decode(claims_part))
|
claims = json.loads(base64.b64decode(claims_part))
|
||||||
assert claims.get("exp") > time()
|
assert claims.get("exp") > time()
|
||||||
|
|
||||||
priv = claims.get("priv")
|
priv = claims.get("priv")
|
||||||
|
@ -117,6 +117,7 @@ services:
|
|||||||
links:
|
links:
|
||||||
- accounts_database
|
- accounts_database
|
||||||
environment:
|
environment:
|
||||||
|
CONFIG_CLASS: accounts.config.EnvConfig
|
||||||
FLASK_APP: accounts:app
|
FLASK_APP: accounts:app
|
||||||
FLASK_ENV: development
|
FLASK_ENV: development
|
||||||
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
|
MYSQL_CONNECTION_STRING: mysql+pymysql://root:develop@accounts_database:3306/mysql
|
||||||
|
Loading…
Reference in New Issue
Block a user