2
0
mirror of https://github.com/offen/website.git synced 2025-01-22 09:10:24 +01:00

serialize account claims into JWT and check in middleware

This commit is contained in:
Frederik Ring 2019-07-13 22:04:12 +02:00
parent d5be3feb7e
commit 3929b88fec
9 changed files with 66 additions and 29 deletions

View File

@ -87,7 +87,9 @@ jobs:
echo Failed waiting for Postgres && exit 1
- run:
name: Run tests
command: make test-ci
command: |
cp ~/offen/bootstrap.yml .
make test-ci
shared:
docker:
@ -272,7 +274,8 @@ jobs:
name: Run tests
command: |
. venv/bin/activate
make
cp ~/offen/bootstrap.yml .
make test-ci
deploy_python:
docker:

View File

@ -1,6 +1,9 @@
test:
@pytest --disable-pytest-warnings
test-ci: bootstrap
@pytest --disable-pytest-warnings
fmt:
@black .

View File

@ -34,6 +34,7 @@ def json_error(handler):
class UnauthorizedError(Exception):
pass
@app.route("/api/login", methods=["POST"])
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
@json_error
@ -54,10 +55,19 @@ def post_login():
private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(hours=24)
encoded = jwt.encode(
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
{
"ok": True,
"exp": expiry,
"priv": {
"userId": match.user_id,
"accounts": [a.account_id for a in match.accounts],
},
},
private_key.encode(),
algorithm="RS256",
).decode("utf-8")
resp = make_response(jsonify({"match": match}))
resp = make_response(jsonify({"user": match.serialize()}))
resp.set_cookie(
"auth",
encoded,
@ -77,11 +87,20 @@ def get_login():
auth_cookie = request.cookies.get("auth")
public_key = environ.get("JWT_PUBLIC_KEY", "")
try:
jwt.decode(auth_cookie, public_key)
token = jwt.decode(auth_cookie, public_key)
except jwt.exceptions.PyJWTError as unauthorized_error:
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
return jsonify({"ok": True})
try:
match = User.query.get(token["priv"]["userId"])
except KeyError as key_err:
return (
jsonify(
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
),
401,
)
return jsonify({"user": match.serialize()})
# This route is not supposed to be called by client-side applications, so

View File

@ -22,7 +22,17 @@ class User(db.Model):
user_id = db.Column(db.String, primary_key=True, default=generate_key)
email = db.Column(db.String, nullable=False, unique=True)
hashed_password = db.Column(db.String, nullable=False)
accounts = db.relationship("AccountUserAssociation", back_populates="user")
accounts = db.relationship(
"AccountUserAssociation", back_populates="user", lazy="joined"
)
def serialize(self):
associated_accounts = [a.account_id for a in self.accounts]
records = [
{"name": a.name, "accountId": a.account_id}
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
]
return {"userId": self.user_id, "email": self.email, "accounts": records}
class AccountUserAssociation(db.Model):

View File

@ -49,7 +49,7 @@ class AccountForm(Form):
class AccountView(ModelView):
form = AccountForm
column_display_all_relations = True
column_list = ("account_id", "name")
column_list = ("name", "account_id")
def after_model_change(self, form, model, is_created):
if is_created:
@ -65,7 +65,7 @@ class UserView(ModelView):
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
column_auto_select_related = True
column_display_all_relations = True
column_list = ("user_id", "email")
column_list = ("email", "user_id")
form_columns = ("email", "accounts")
def on_model_change(self, form, model, is_created):

View File

@ -2,7 +2,7 @@ import yaml
from passlib.hash import bcrypt
from accounts import db
from accounts.models import Account
from accounts.models import Account, User, AccountUserAssociation
if __name__ == "__main__":
db.drop_all()
@ -18,6 +18,15 @@ if __name__ == "__main__":
)
db.session.add(record)
for user in data["users"]:
record = User(
email=user["email"],
hashed_password=bcrypt.hash(user["password"]),
)
for account_id in user["accounts"]:
record.accounts.append(AccountUserAssociation(account_id=account_id))
db.session.add(record)
db.session.commit()
print("Successfully bootstrapped accounts database")

View File

@ -13,7 +13,7 @@ class TestJWT(unittest.TestCase):
assert rv.status.startswith("401")
rv = self.app.post(
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
"/api/login", data=json.dumps({"username": "develop@offen.dev", "password": "develop"})
)
assert rv.status.startswith("200")

View File

@ -23,17 +23,15 @@ const ClaimsContextKey contextKey = "claims"
// JWTProtect uses the public key located at the given URL to check if the
// cookie value is signed properly. In case yes, the JWT claims will be added
// to the request context
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var jwtValue string
var isRPC bool
if authCookie, err := r.Cookie(cookieName); err == nil {
jwtValue = authCookie.Value
} else {
if header := r.Header.Get("X-RPC-Authentication"); header != "" {
if header := r.Header.Get(headerName); header != "" {
jwtValue = header
isRPC = true
}
}
if jwtValue == "" {
@ -76,20 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
return
}
privateClaims, _ := token.Get("priv")
if isRPC {
cast, ok := privateClaims.(map[string]interface{})
if !ok {
RespondWithJSONError(w, fmt.Errorf("jwt: malformed private claims section in token: %v", privateClaims), http.StatusBadRequest)
return
}
if cast["rpc"] != "1" {
RespondWithJSONError(w, errors.New("jwt: token claims do not allow the requested operation"), http.StatusForbidden)
return
}
privKey, _ := token.Get("priv")
claims, _ := privKey.(map[string]interface{})
if err := authorizer(r, claims); err != nil {
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
return
}
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
next.ServeHTTP(w, r)
})
}

View File

@ -174,7 +174,9 @@ func TestJWTProtect(t *testing.T) {
if test.server != nil {
url = test.server.URL
}
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error {
return nil
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
w := httptest.NewRecorder()