mirror of
https://github.com/offen/website.git
synced 2024-11-22 09:00:28 +01:00
serialize account claims into JWT and check in middleware
This commit is contained in:
parent
d5be3feb7e
commit
3929b88fec
@ -87,7 +87,9 @@ jobs:
|
|||||||
echo Failed waiting for Postgres && exit 1
|
echo Failed waiting for Postgres && exit 1
|
||||||
- run:
|
- run:
|
||||||
name: Run tests
|
name: Run tests
|
||||||
command: make test-ci
|
command: |
|
||||||
|
cp ~/offen/bootstrap.yml .
|
||||||
|
make test-ci
|
||||||
|
|
||||||
shared:
|
shared:
|
||||||
docker:
|
docker:
|
||||||
@ -272,7 +274,8 @@ jobs:
|
|||||||
name: Run tests
|
name: Run tests
|
||||||
command: |
|
command: |
|
||||||
. venv/bin/activate
|
. venv/bin/activate
|
||||||
make
|
cp ~/offen/bootstrap.yml .
|
||||||
|
make test-ci
|
||||||
|
|
||||||
deploy_python:
|
deploy_python:
|
||||||
docker:
|
docker:
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
test:
|
test:
|
||||||
@pytest --disable-pytest-warnings
|
@pytest --disable-pytest-warnings
|
||||||
|
|
||||||
|
test-ci: bootstrap
|
||||||
|
@pytest --disable-pytest-warnings
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
@black .
|
@black .
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ def json_error(handler):
|
|||||||
class UnauthorizedError(Exception):
|
class UnauthorizedError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/login", methods=["POST"])
|
@app.route("/api/login", methods=["POST"])
|
||||||
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
@cross_origin(origins=[environ.get("CORS_ORIGIN", "*")], supports_credentials=True)
|
||||||
@json_error
|
@json_error
|
||||||
@ -54,10 +55,19 @@ def post_login():
|
|||||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
expiry = datetime.utcnow() + timedelta(hours=24)
|
expiry = datetime.utcnow() + timedelta(hours=24)
|
||||||
encoded = jwt.encode(
|
encoded = jwt.encode(
|
||||||
{"ok": True, "exp": expiry}, private_key.encode(), algorithm="RS256"
|
{
|
||||||
|
"ok": True,
|
||||||
|
"exp": expiry,
|
||||||
|
"priv": {
|
||||||
|
"userId": match.user_id,
|
||||||
|
"accounts": [a.account_id for a in match.accounts],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
private_key.encode(),
|
||||||
|
algorithm="RS256",
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
|
||||||
resp = make_response(jsonify({"match": match}))
|
resp = make_response(jsonify({"user": match.serialize()}))
|
||||||
resp.set_cookie(
|
resp.set_cookie(
|
||||||
"auth",
|
"auth",
|
||||||
encoded,
|
encoded,
|
||||||
@ -77,11 +87,20 @@ def get_login():
|
|||||||
auth_cookie = request.cookies.get("auth")
|
auth_cookie = request.cookies.get("auth")
|
||||||
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
public_key = environ.get("JWT_PUBLIC_KEY", "")
|
||||||
try:
|
try:
|
||||||
jwt.decode(auth_cookie, public_key)
|
token = jwt.decode(auth_cookie, public_key)
|
||||||
except jwt.exceptions.PyJWTError as unauthorized_error:
|
except jwt.exceptions.PyJWTError as unauthorized_error:
|
||||||
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
return jsonify({"error": str(unauthorized_error), "status": 401}), 401
|
||||||
|
|
||||||
return jsonify({"ok": True})
|
try:
|
||||||
|
match = User.query.get(token["priv"]["userId"])
|
||||||
|
except KeyError as key_err:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{"error": "malformed JWT claims: {}".format(key_err), "status": 401}
|
||||||
|
),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
return jsonify({"user": match.serialize()})
|
||||||
|
|
||||||
|
|
||||||
# This route is not supposed to be called by client-side applications, so
|
# This route is not supposed to be called by client-side applications, so
|
||||||
|
@ -22,7 +22,17 @@ class User(db.Model):
|
|||||||
user_id = db.Column(db.String, primary_key=True, default=generate_key)
|
user_id = db.Column(db.String, primary_key=True, default=generate_key)
|
||||||
email = db.Column(db.String, nullable=False, unique=True)
|
email = db.Column(db.String, nullable=False, unique=True)
|
||||||
hashed_password = db.Column(db.String, nullable=False)
|
hashed_password = db.Column(db.String, nullable=False)
|
||||||
accounts = db.relationship("AccountUserAssociation", back_populates="user")
|
accounts = db.relationship(
|
||||||
|
"AccountUserAssociation", back_populates="user", lazy="joined"
|
||||||
|
)
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
associated_accounts = [a.account_id for a in self.accounts]
|
||||||
|
records = [
|
||||||
|
{"name": a.name, "accountId": a.account_id}
|
||||||
|
for a in Account.query.filter(Account.account_id.in_(associated_accounts))
|
||||||
|
]
|
||||||
|
return {"userId": self.user_id, "email": self.email, "accounts": records}
|
||||||
|
|
||||||
|
|
||||||
class AccountUserAssociation(db.Model):
|
class AccountUserAssociation(db.Model):
|
||||||
|
@ -49,7 +49,7 @@ class AccountForm(Form):
|
|||||||
class AccountView(ModelView):
|
class AccountView(ModelView):
|
||||||
form = AccountForm
|
form = AccountForm
|
||||||
column_display_all_relations = True
|
column_display_all_relations = True
|
||||||
column_list = ("account_id", "name")
|
column_list = ("name", "account_id")
|
||||||
|
|
||||||
def after_model_change(self, form, model, is_created):
|
def after_model_change(self, form, model, is_created):
|
||||||
if is_created:
|
if is_created:
|
||||||
@ -65,7 +65,7 @@ class UserView(ModelView):
|
|||||||
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
|
inline_models = [(AccountUserAssociation, dict(form_columns=["id", "account"]))]
|
||||||
column_auto_select_related = True
|
column_auto_select_related = True
|
||||||
column_display_all_relations = True
|
column_display_all_relations = True
|
||||||
column_list = ("user_id", "email")
|
column_list = ("email", "user_id")
|
||||||
form_columns = ("email", "accounts")
|
form_columns = ("email", "accounts")
|
||||||
|
|
||||||
def on_model_change(self, form, model, is_created):
|
def on_model_change(self, form, model, is_created):
|
||||||
|
@ -2,7 +2,7 @@ import yaml
|
|||||||
from passlib.hash import bcrypt
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
from accounts import db
|
from accounts import db
|
||||||
from accounts.models import Account
|
from accounts.models import Account, User, AccountUserAssociation
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
db.drop_all()
|
db.drop_all()
|
||||||
@ -18,6 +18,15 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
db.session.add(record)
|
db.session.add(record)
|
||||||
|
|
||||||
|
for user in data["users"]:
|
||||||
|
record = User(
|
||||||
|
email=user["email"],
|
||||||
|
hashed_password=bcrypt.hash(user["password"]),
|
||||||
|
)
|
||||||
|
for account_id in user["accounts"]:
|
||||||
|
record.accounts.append(AccountUserAssociation(account_id=account_id))
|
||||||
|
db.session.add(record)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
print("Successfully bootstrapped accounts database")
|
print("Successfully bootstrapped accounts database")
|
||||||
|
@ -13,7 +13,7 @@ class TestJWT(unittest.TestCase):
|
|||||||
assert rv.status.startswith("401")
|
assert rv.status.startswith("401")
|
||||||
|
|
||||||
rv = self.app.post(
|
rv = self.app.post(
|
||||||
"/api/login", data=json.dumps({"username": "offen", "password": "develop"})
|
"/api/login", data=json.dumps({"username": "develop@offen.dev", "password": "develop"})
|
||||||
)
|
)
|
||||||
assert rv.status.startswith("200")
|
assert rv.status.startswith("200")
|
||||||
|
|
||||||
|
@ -23,17 +23,15 @@ const ClaimsContextKey contextKey = "claims"
|
|||||||
// JWTProtect uses the public key located at the given URL to check if the
|
// JWTProtect uses the public key located at the given URL to check if the
|
||||||
// cookie value is signed properly. In case yes, the JWT claims will be added
|
// cookie value is signed properly. In case yes, the JWT claims will be added
|
||||||
// to the request context
|
// to the request context
|
||||||
func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
func JWTProtect(keyURL, cookieName, headerName string, authorizer func(*http.Request, map[string]interface{}) error) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
var jwtValue string
|
var jwtValue string
|
||||||
var isRPC bool
|
|
||||||
if authCookie, err := r.Cookie(cookieName); err == nil {
|
if authCookie, err := r.Cookie(cookieName); err == nil {
|
||||||
jwtValue = authCookie.Value
|
jwtValue = authCookie.Value
|
||||||
} else {
|
} else {
|
||||||
if header := r.Header.Get("X-RPC-Authentication"); header != "" {
|
if header := r.Header.Get(headerName); header != "" {
|
||||||
jwtValue = header
|
jwtValue = header
|
||||||
isRPC = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if jwtValue == "" {
|
if jwtValue == "" {
|
||||||
@ -76,20 +74,13 @@ func JWTProtect(keyURL, cookieName string) func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
privateClaims, _ := token.Get("priv")
|
privKey, _ := token.Get("priv")
|
||||||
if isRPC {
|
claims, _ := privKey.(map[string]interface{})
|
||||||
cast, ok := privateClaims.(map[string]interface{})
|
if err := authorizer(r, claims); err != nil {
|
||||||
if !ok {
|
RespondWithJSONError(w, fmt.Errorf("jwt: token claims do not allow the requested operation: %v", err), http.StatusForbidden)
|
||||||
RespondWithJSONError(w, fmt.Errorf("jwt: malformed private claims section in token: %v", privateClaims), http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if cast["rpc"] != "1" {
|
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, claims))
|
||||||
RespondWithJSONError(w, errors.New("jwt: token claims do not allow the requested operation"), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r = r.WithContext(context.WithValue(r.Context(), ClaimsContextKey, privateClaims))
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -174,7 +174,9 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.server != nil {
|
if test.server != nil {
|
||||||
url = test.server.URL
|
url = test.server.URL
|
||||||
}
|
}
|
||||||
wrappedHandler := JWTProtect(url, "auth")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error {
|
||||||
|
return nil
|
||||||
|
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}))
|
}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
Loading…
Reference in New Issue
Block a user