From 6325ec8afbcaa10f1143ce1f83d33150678b3db6 Mon Sep 17 00:00:00 2001 From: Frederik Ring Date: Tue, 16 Jul 2019 10:29:24 +0200 Subject: [PATCH] clean up and add missing tests --- .circleci/config.yml | 3 -- accounts/accounts/__init__.py | 2 - accounts/accounts/api.py | 2 +- accounts/accounts/models.py | 6 +-- accounts/accounts/views.py | 11 +++- accounts/authorizer/__init__.py | 3 +- accounts/requirements.txt | 2 +- accounts/scripts/hash.py | 20 +++++++ accounts/tests/test_api.py | 59 +++++++++++++++++++++ shared/http/jwt_test.go | 93 +++++++++++++++++++++++++++++++-- 10 files changed, 184 insertions(+), 17 deletions(-) create mode 100644 accounts/scripts/hash.py diff --git a/.circleci/config.yml b/.circleci/config.yml index bba5520..618e102 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -290,9 +290,6 @@ jobs: - run: name: Install dependencies command: npm install - - run: - name: Install psycopg2 dependencies - command: sudo apt-get install libpq-dev - save_cache: paths: - ~/offen/packages/node_modules diff --git a/accounts/accounts/__init__.py b/accounts/accounts/__init__.py index 8a54231..85fce24 100644 --- a/accounts/accounts/__init__.py +++ b/accounts/accounts/__init__.py @@ -13,11 +13,9 @@ from accounts.models import Account, User from accounts.views import AccountView, UserView import accounts.api -# set optional bootswatch theme app.config["FLASK_ADMIN_SWATCH"] = "flatly" admin = Admin(app, name="offen admin", template_mode="bootstrap3") -# Add administrative views here admin.add_view(AccountView(Account, db.session)) admin.add_view(UserView(User, db.session)) diff --git a/accounts/accounts/api.py b/accounts/accounts/api.py index 4a4a582..3fe5290 100644 --- a/accounts/accounts/api.py +++ b/accounts/accounts/api.py @@ -65,7 +65,7 @@ def post_login(): }, private_key.encode(), algorithm="RS256", - ).decode("utf-8") + ).decode() resp = make_response(jsonify({"user": match.serialize()})) resp.set_cookie( diff --git a/accounts/accounts/models.py b/accounts/accounts/models.py index 0bc844d..3fe5909 100644 --- a/accounts/accounts/models.py +++ b/accounts/accounts/models.py @@ -10,7 +10,7 @@ def generate_key(): class Account(db.Model): __tablename__ = "accounts" account_id = db.Column(db.String(36), primary_key=True, default=generate_key) - name = db.Column(db.String(256), nullable=False, unique=True) + name = db.Column(db.Text, nullable=False) users = db.relationship("AccountUserAssociation", back_populates="account") def __repr__(self): @@ -20,8 +20,8 @@ class Account(db.Model): class User(db.Model): __tablename__ = "users" user_id = db.Column(db.String(36), primary_key=True, default=generate_key) - email = db.Column(db.String(256), nullable=False, unique=True) - hashed_password = db.Column(db.String(256), nullable=False) + email = db.Column(db.String(128), nullable=False, unique=True) + hashed_password = db.Column(db.Text, nullable=False) accounts = db.relationship( "AccountUserAssociation", back_populates="user", lazy="joined" ) diff --git a/accounts/accounts/views.py b/accounts/accounts/views.py index 059970a..5984958 100644 --- a/accounts/accounts/views.py +++ b/accounts/accounts/views.py @@ -15,10 +15,17 @@ from accounts.models import AccountUserAssociation class RemoteServerException(Exception): status = 0 + def __str__(self): + return "Status {}: {}".format( + self.status, super(RemoteServerException, self).__str__() + ) + def create_remote_account(name, account_id): private_key = environ.get("JWT_PRIVATE_KEY", "") - expiry = datetime.utcnow() + timedelta(seconds=10) + # expires in 30 seconds as this will mean the HTTP request would have + # timed out anyways + expiry = datetime.utcnow() + timedelta(seconds=30) encoded = jwt.encode( {"ok": True, "exp": expiry, "priv": {"rpc": "1"}}, private_key.encode(), @@ -27,7 +34,7 @@ def create_remote_account(name, account_id): r = requests.post( "{}/accounts".format(environ.get("SERVER_HOST")), - json={"name": name, "account_id": account_id}, + json={"name": name, "accountId": account_id}, headers={"X-RPC-Authentication": encoded}, ) diff --git a/accounts/authorizer/__init__.py b/accounts/authorizer/__init__.py index e4911db..c7a16f8 100644 --- a/accounts/authorizer/__init__.py +++ b/accounts/authorizer/__init__.py @@ -53,8 +53,7 @@ def handler(event, context): if user != environ.get("BASIC_AUTH_USER"): return build_response(api_arn, False) - encoded_password = environ.get("HASHED_BASIC_AUTH_PASSWORD") - hashed_password = base64.standard_b64decode(encoded_password).decode() + hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD") if not bcrypt.verify(password, hashed_password): return build_response(api_arn, False) diff --git a/accounts/requirements.txt b/accounts/requirements.txt index 132ac41..351e404 100644 --- a/accounts/requirements.txt +++ b/accounts/requirements.txt @@ -7,5 +7,5 @@ pyjwt[crypto]==1.7.1 passlib==1.7.1 bcrypt==3.1.7 PyMySQL==0.9.3 -mysqlclient +mysqlclient==1.4.2.post1 requests==2.22.0 diff --git a/accounts/scripts/hash.py b/accounts/scripts/hash.py new file mode 100644 index 0000000..23e66ab --- /dev/null +++ b/accounts/scripts/hash.py @@ -0,0 +1,20 @@ +import base64 +import argparse + +from passlib.hash import bcrypt + +parser = argparse.ArgumentParser() +parser.add_argument("--password", type=str, help="The password to hash", required=True) +parser.add_argument( + "--plain", + help="Do not encode the result as base64", + default=False, + action="store_true", +) + +if __name__ == "__main__": + args = parser.parse_args() + out = bcrypt.hash(args.password) + if not args.plain: + out = base64.standard_b64encode(out.encode()).decode() + print(out) diff --git a/accounts/tests/test_api.py b/accounts/tests/test_api.py index 246b4cb..79f7b19 100644 --- a/accounts/tests/test_api.py +++ b/accounts/tests/test_api.py @@ -3,11 +3,45 @@ import json import base64 from json import loads from time import time +from datetime import datetime, timedelta from os import environ +import jwt + from accounts import app +FOREIGN_PRIVATE_KEY = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG +N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P +nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj +ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN +6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4 +Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i +6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl +wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN +kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn +f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761 +5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI +TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3 +llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G +rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0 +lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS +NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU +zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE +L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz +uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ +WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f +qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE +oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm +ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d +rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO +4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF +15NFZ2gU4rKnQ3sVAOzMoEw5 +-----END PRIVATE KEY----- +""" + def _pad_b64_string(s): while len(s) % 4 is not 0: s = s + "=" @@ -81,6 +115,7 @@ class TestJWT(unittest.TestCase): """ data = json.loads(rv.data) assert data["user"]["userId"] is not None + data["user"]["accounts"].sort(key=lambda a: a["name"]) self.assertListEqual( data["user"]["accounts"], [ @@ -146,3 +181,27 @@ class TestJWT(unittest.TestCase): rv = self.app.post("/api/logout") assert rv.status.startswith("204") self._assert_cookie_not_present("auth") + + def test_forged_token(self): + """ + The application needs to verify that tokens that would be theoretically + valid are not signed using an unknown key. + """ + forged_token = jwt.encode( + { + "exp": datetime.utcnow() + timedelta(hours=24), + "priv": { + "userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8", + "accounts": [ + "9b63c4d8-65c0-438c-9d30-cc4b01173393", + "78403940-ae4f-4aff-a395-1e90f145cf62", + ], + }, + }, + FOREIGN_PRIVATE_KEY, + algorithm="RS256", + ).decode() + + self.app.set_cookie("localhost", "auth", forged_token) + rv = self.app.get("/api/login") + assert rv.status.startswith("401") diff --git a/shared/http/jwt_test.go b/shared/http/jwt_test.go index 3cc40fe..395c196 100644 --- a/shared/http/jwt_test.go +++ b/shared/http/jwt_test.go @@ -3,6 +3,7 @@ package http import ( "crypto/x509" "encoding/pem" + "errors" "fmt" "net/http" "net/http/httptest" @@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) { tests := []struct { name string cookie *http.Cookie + headers *http.Header server *httptest.Server + authorizer func(r *http.Request, claims map[string]interface{}) error expectedStatusCode int }{ { "no cookie", nil, nil, + nil, + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusForbidden, }, { @@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) { Value: "irrelevantgibberish", }, nil, + nil, + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusInternalServerError, }, { @@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) { Name: "auth", Value: "irrelevantgibberish", }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("here's some bytes 4 y'all")) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusInternalServerError, }, { @@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) { Name: "auth", Value: "irrelevantgibberish", }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"key":"not really a key"}`)) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusInternalServerError, }, { @@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) { Name: "auth", Value: "irrelevantgibberish", }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`)) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusBadGateway, }, { @@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) { Name: "auth", Value: "irrelevantgibberish", }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte( fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), )) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusForbidden, }, { @@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) { return string(b) })(), }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte( fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), )) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusOK, }, + { + "ok token in headers", + nil, + (func() *http.Header { + token := jwt.New() + token.Set("exp", time.Now().Add(time.Hour)) + keyBytes, _ := pem.Decode([]byte(privateKey)) + privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes) + b, _ := token.Sign(jwa.RS256, privKey) + return &http.Header{ + "X-RPC-Authentication": []string{string(b)}, + } + })(), + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, + http.StatusOK, + }, + { + "bad token in headers", + nil, + (func() *http.Header { + return &http.Header{ + "X-RPC-Authentication": []string{"nilly willy"}, + } + })(), + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, + http.StatusForbidden, + }, + { + "authorizer rejects", + &http.Cookie{ + Name: "auth", + Value: (func() string { + token := jwt.New() + token.Set("exp", time.Now().Add(time.Hour)) + token.Set("priv", map[string]interface{}{"ok": false}) + keyBytes, _ := pem.Decode([]byte(privateKey)) + privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes) + b, _ := token.Sign(jwa.RS256, privKey) + return string(b) + })(), + }, + nil, + httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte( + fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), + )) + })), + func(r *http.Request, claims map[string]interface{}) error { + if claims["ok"] == true { + return nil + } + return errors.New("expected ok to be true") + }, + http.StatusForbidden, + }, { "valid key, expired token", &http.Cookie{ @@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) { return string(b) })(), }, + nil, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte( fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), )) })), + func(r *http.Request, claims map[string]interface{}) error { return nil }, http.StatusForbidden, }, } @@ -174,9 +258,7 @@ func TestJWTProtect(t *testing.T) { if test.server != nil { url = test.server.URL } - wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error { - return nil - })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })) w := httptest.NewRecorder() @@ -184,6 +266,11 @@ func TestJWTProtect(t *testing.T) { if test.cookie != nil { r.AddCookie(test.cookie) } + if test.headers != nil { + for key, value := range *test.headers { + r.Header.Add(key, value[0]) + } + } wrappedHandler.ServeHTTP(w, r) if w.Code != test.expectedStatusCode { t.Errorf("Unexpected status code %v", w.Code)