From 3929b88fec1fa407305c89eff606d851561eeb6a Mon Sep 17 00:00:00 2001 From: Frederik Ring Date: Sat, 13 Jul 2019 22:04:12 +0200 Subject: [PATCH] serialize account claims into JWT and check in middleware --- .circleci/config.yml | 7 +++++-- accounts/Makefile | 3 +++ accounts/accounts/api.py | 27 +++++++++++++++++++++++---- accounts/accounts/models.py | 12 +++++++++++- accounts/accounts/views.py | 4 ++-- accounts/scripts/bootstrap.py | 11 ++++++++++- accounts/tests/test_jwt.py | 2 +- shared/http/jwt.go | 25 ++++++++----------------- shared/http/jwt_test.go | 4 +++- 9 files changed, 66 insertions(+), 29 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index dc6ad49..a5e558c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -87,7 +87,9 @@ jobs: echo Failed waiting for Postgres && exit 1 - run: name: Run tests - command: make test-ci + command: | + cp ~/offen/bootstrap.yml . + make test-ci shared: docker: @@ -272,7 +274,8 @@ jobs: name: Run tests command: | . venv/bin/activate - make + cp ~/offen/bootstrap.yml . + make test-ci deploy_python: docker: diff --git a/accounts/Makefile b/accounts/Makefile index f9efb2f..2c9a929 100644 --- a/accounts/Makefile +++ b/accounts/Makefile @@ -1,6 +1,9 @@ test: @pytest --disable-pytest-warnings +test-ci: bootstrap + @pytest --disable-pytest-warnings + fmt: @black . diff --git a/accounts/accounts/api.py b/accounts/accounts/api.py index 5591a64..d3120ce 100644 --- a/accounts/accounts/api.py +++ b/accounts/accounts/api.py @@ -34,6 +34,7 @@ def json_error(handler): class UnauthorizedError(Exception): pass + @app.route("/api/login", methods=["POST"]) @cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True) @json_error @@ -54,10 +55,19 @@ def post_login(): private_key = environ.get("JWT_PRIVATE_KEY", "") expiry = datetime.utcnow() + timedelta(hours=24) encoded = jwt.encode( - {"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256" + { + "ok": True, + "exp": expiry, + "priv": { + "userId": match.user_id, + "accounts": [a.account_id for a in match.accounts], + }, + }, + private_key.encode(), + algorithm="RS256", ).decode("utf-8") - resp = make_response(jsonify({"match": match})) + resp = make_response(jsonify({"user": match.serialize()})) resp.set_cookie( "auth", encoded, @@ -77,11 +87,20 @@ def get_login(): auth_cookie = request.cookies.get("auth") public_key = environ.get("JWT_PUBLIC_KEY", "") try: - jwt.decode(auth_cookie, public_key) + token = jwt.decode(auth_cookie, public_key) except jwt.exceptions.PyJWTError as unauthorized_error: return jsonify({"error": str(unauthorized_error), "status": 401}), 401 - return jsonify({"ok": True}) + try: + match = User.query.get(token["priv"]["userId"]) + except KeyError as key_err: + return ( + jsonify( + {"error": "malformed JWT claims: {}".format(key_err), "status": 401} + ), + 401, + ) + return jsonify({"user": match.serialize()}) # This route is not supposed to be called by client-side applications, so diff --git a/accounts/accounts/models.py b/accounts/accounts/models.py index b4fd6c3..942bcba 100644 --- a/accounts/accounts/models.py +++ b/accounts/accounts/models.py @@ -22,7 +22,17 @@ class User(db.Model): user_id = db.Column(db.String, primary_key=True, default=generate_key) email = db.Column(db.String, nullable=False, unique=True) hashed_password = db.Column(db.String, nullable=False) - accounts = db.relationship("AccountUserAssociation", back_populates="user") + accounts = db.relationship( + "AccountUserAssociation", back_populates="user", lazy="joined" + ) + + def serialize(self): + associated_accounts = [a.account_id for a in self.accounts] + records = [ + {"name": a.name, "accountId": a.account_id} + for a in Account.query.filter(Account.account_id.in_(associated_accounts)) + ] + return {"userId": self.user_id, "email": self.email, "accounts": records} class AccountUserAssociation(db.Model): diff --git a/accounts/accounts/views.py b/accounts/accounts/views.py index 3071006..612f9bc 100644 --- a/accounts/accounts/views.py +++ b/accounts/accounts/views.py @@ -49,7 +49,7 @@ class AccountForm(Form): class AccountView(ModelView): form = AccountForm column_display_all_relations = True - column_list = ("account_id", "name") + column_list = ("name", "account_id") def after_model_change(self, form, model, is_created): if is_created: @@ -65,7 +65,7 @@ class UserView(ModelView): inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))] column_auto_select_related = True column_display_all_relations = True - column_list = ("user_id", "email") + column_list = ("email", "user_id") form_columns = ("email", "accounts") def on_model_change(self, form, model, is_created): diff --git a/accounts/scripts/bootstrap.py b/accounts/scripts/bootstrap.py index 2fd87e7..2ab926e 100755 --- a/accounts/scripts/bootstrap.py +++ b/accounts/scripts/bootstrap.py @@ -2,7 +2,7 @@ import yaml from passlib.hash import bcrypt from accounts import db -from accounts.models import Account +from accounts.models import Account, User, AccountUserAssociation if __name__ == "__main__": db.drop_all() @@ -18,6 +18,15 @@ if __name__ == "__main__": ) db.session.add(record) + for user in data["users"]: + record = User( + email=user["email"], + hashed_password=bcrypt.hash(user["password"]), + ) + for account_id in user["accounts"]: + record.accounts.append(AccountUserAssociation(account_id=account_id)) + db.session.add(record) + db.session.commit() print("Successfully bootstrapped accounts database") diff --git a/accounts/tests/test_jwt.py b/accounts/tests/test_jwt.py index 2b0d423..956ee5f 100644 --- a/accounts/tests/test_jwt.py +++ b/accounts/tests/test_jwt.py @@ -13,7 +13,7 @@ class TestJWT(unittest.TestCase): assert rv.status.startswith("401") rv = self.app.post( - "/api/login", data=json.dumps({"username": "offen", "password": "develop"}) + "/api/login", data=json.dumps({"username": "develop@offen.dev", "password": "develop"}) ) assert rv.status.startswith("200") diff --git a/shared/http/jwt.go b/shared/http/jwt.go index 633a343..c3ad3fb 100644 --- a/shared/http/jwt.go +++ b/shared/http/jwt.go @@ -23,17 +23,15 @@ const ClaimsContextKey contextKey = "claims" // JWTProtect uses the public key located at the given URL to check if the // cookie value is signed properly. In case yes, the JWT claims will be added // to the request context -func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { +func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var jwtValue string - var isRPC bool if authCookie, err := r.Cookie(cookieName); err == nil { jwtValue = authCookie.Value } else { - if header := r.Header.Get("X-RPC-Authentication"); header != "" { + if header := r.Header.Get(headerName); header != "" { jwtValue = header - isRPC = true } } if jwtValue == "" { @@ -76,20 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler { return } - privateClaims, _ := token.Get("priv") - if isRPC { - cast, ok := privateClaims.(map[string]interface{}) - if !ok { - RespondWithJSONError(w, fmt.Errorf("jwt: malformed private claims section in token: %v", privateClaims), http.StatusBadRequest) - return - } - if cast["rpc"] != "1" { - RespondWithJSONError(w, errors.New("jwt: token claims do not allow the requested operation"), http.StatusForbidden) - return - } + privKey, _ := token.Get("priv") + claims, _ := privKey.(map[string]interface{}) + if err := authorizer(r, claims); err != nil { + RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden) + return } - - r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims)) + r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims)) next.ServeHTTP(w, r) }) } diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go index 4f3c2f7..3cc40fe 100644 --- a/shared/http/jwt_test.go +++ b/shared/http/jwt_test.go @@ -174,7 +174,9 @@ func TestJWTProtect(t *testing.T) { if test.server != nil { url = test.server.URL } - wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error { + return nil + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) w := httptest.NewRecorder()