2
0
mirror of https://github.com/offen/website.git synced 2024-11-25 10:10:28 +01:00

clean up and add missing tests

This commit is contained in:
Frederik Ring 2019-07-16 10:29:24 +02:00
parent 1d07355112
commit 6325ec8afb
10 changed files with 184 additions and 17 deletions

View File

@ -290,9 +290,6 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: npm install command: npm install
- run:
name: Install psycopg2 dependencies
command: sudo apt-get install libpq-dev
- save_cache: - save_cache:
paths: paths:
- ~/offen/packages/node_modules - ~/offen/packages/node_modules

View File

@ -13,11 +13,9 @@ from accounts.models import Account, User
from accounts.views import AccountView, UserView from accounts.views import AccountView, UserView
import accounts.api import accounts.api
# set optional bootswatch theme
app.config["FLASK_ADMIN_SWATCH"] = "flatly" app.config["FLASK_ADMIN_SWATCH"] = "flatly"
admin = Admin(app, name="offen admin", template_mode="bootstrap3") admin = Admin(app, name="offen admin", template_mode="bootstrap3")
# Add administrative views here
admin.add_view(AccountView(Account, db.session)) admin.add_view(AccountView(Account, db.session))
admin.add_view(UserView(User, db.session)) admin.add_view(UserView(User, db.session))

View File

@ -65,7 +65,7 @@ def post_login():
}, },
private_key.encode(), private_key.encode(),
algorithm="RS256", algorithm="RS256",
).decode("utf-8") ).decode()
resp = make_response(jsonify({"user": match.serialize()})) resp = make_response(jsonify({"user": match.serialize()}))
resp.set_cookie( resp.set_cookie(

View File

@ -10,7 +10,7 @@ def generate_key():
class Account(db.Model): class Account(db.Model):
__tablename__ = "accounts" __tablename__ = "accounts"
account_id = db.Column(db.String(36), primary_key=True, default=generate_key) account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
name = db.Column(db.String(256), nullable=False, unique=True) name = db.Column(db.Text, nullable=False)
users = db.relationship("AccountUserAssociation", back_populates="account") users = db.relationship("AccountUserAssociation", back_populates="account")
def __repr__(self): def __repr__(self):
@ -20,8 +20,8 @@ class Account(db.Model):
class User(db.Model): class User(db.Model):
__tablename__ = "users" __tablename__ = "users"
user_id = db.Column(db.String(36), primary_key=True, default=generate_key) user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
email = db.Column(db.String(256), nullable=False, unique=True) email = db.Column(db.String(128), nullable=False, unique=True)
hashed_password = db.Column(db.String(256), nullable=False) hashed_password = db.Column(db.Text, nullable=False)
accounts = db.relationship( accounts = db.relationship(
"AccountUserAssociation", back_populates="user", lazy="joined" "AccountUserAssociation", back_populates="user", lazy="joined"
) )

View File

@ -15,10 +15,17 @@ from accounts.models import AccountUserAssociation
class RemoteServerException(Exception): class RemoteServerException(Exception):
status = 0 status = 0
def __str__(self):
return "Status {}: {}".format(
self.status, super(RemoteServerException, self).__str__()
)
def create_remote_account(name, account_id): def create_remote_account(name, account_id):
private_key = environ.get("JWT_PRIVATE_KEY", "") private_key = environ.get("JWT_PRIVATE_KEY", "")
expiry = datetime.utcnow() + timedelta(seconds=10) # expires in 30 seconds as this will mean the HTTP request would have
# timed out anyways
expiry = datetime.utcnow() + timedelta(seconds=30)
encoded = jwt.encode( encoded = jwt.encode(
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}}, {"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
private_key.encode(), private_key.encode(),
@ -27,7 +34,7 @@ def create_remote_account(name, account_id):
r = requests.post( r = requests.post(
"{}/accounts".format(environ.get("SERVER_HOST")), "{}/accounts".format(environ.get("SERVER_HOST")),
json={"name": name, "account_id": account_id}, json={"name": name, "accountId": account_id},
headers={"X-RPC-Authentication": encoded}, headers={"X-RPC-Authentication": encoded},
) )

View File

@ -53,8 +53,7 @@ def handler(event, context):
if user != environ.get("BASIC_AUTH_USER"): if user != environ.get("BASIC_AUTH_USER"):
return build_response(api_arn, False) return build_response(api_arn, False)
encoded_password = environ.get("HASHED_BASIC_AUTH_PASSWORD") hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
hashed_password = base64.standard_b64decode(encoded_password).decode()
if not bcrypt.verify(password, hashed_password): if not bcrypt.verify(password, hashed_password):
return build_response(api_arn, False) return build_response(api_arn, False)

View File

@ -7,5 +7,5 @@ pyjwt[crypto]==1.7.1
passlib==1.7.1 passlib==1.7.1
bcrypt==3.1.7 bcrypt==3.1.7
PyMySQL==0.9.3 PyMySQL==0.9.3
mysqlclient mysqlclient==1.4.2.post1
requests==2.22.0 requests==2.22.0

20
accounts/scripts/hash.py Normal file
View File

@ -0,0 +1,20 @@
import base64
import argparse
from passlib.hash import bcrypt
parser = argparse.ArgumentParser()
parser.add_argument("--password", type=str, help="The password to hash", required=True)
parser.add_argument(
"--plain",
help="Do not encode the result as base64",
default=False,
action="store_true",
)
if __name__ == "__main__":
args = parser.parse_args()
out = bcrypt.hash(args.password)
if not args.plain:
out = base64.standard_b64encode(out.encode()).decode()
print(out)

View File

@ -3,11 +3,45 @@ import json
import base64 import base64
from json import loads from json import loads
from time import time from time import time
from datetime import datetime, timedelta
from os import environ from os import environ
import jwt
from accounts import app from accounts import app
FOREIGN_PRIVATE_KEY = """
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
15NFZ2gU4rKnQ3sVAOzMoEw5
-----END PRIVATE KEY-----
"""
def _pad_b64_string(s): def _pad_b64_string(s):
while len(s) % 4 is not 0: while len(s) % 4 is not 0:
s = s + "=" s = s + "="
@ -81,6 +115,7 @@ class TestJWT(unittest.TestCase):
""" """
data = json.loads(rv.data) data = json.loads(rv.data)
assert data["user"]["userId"] is not None assert data["user"]["userId"] is not None
data["user"]["accounts"].sort(key=lambda a: a["name"])
self.assertListEqual( self.assertListEqual(
data["user"]["accounts"], data["user"]["accounts"],
[ [
@ -146,3 +181,27 @@ class TestJWT(unittest.TestCase):
rv = self.app.post("/api/logout") rv = self.app.post("/api/logout")
assert rv.status.startswith("204") assert rv.status.startswith("204")
self._assert_cookie_not_present("auth") self._assert_cookie_not_present("auth")
def test_forged_token(self):
"""
The application needs to verify that tokens that would be theoretically
valid are not signed using an unknown key.
"""
forged_token = jwt.encode(
{
"exp": datetime.utcnow() + timedelta(hours=24),
"priv": {
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
"accounts": [
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
"78403940-ae4f-4aff-a395-1e90f145cf62",
],
},
},
FOREIGN_PRIVATE_KEY,
algorithm="RS256",
).decode()
self.app.set_cookie("localhost", "auth", forged_token)
rv = self.app.get("/api/login")
assert rv.status.startswith("401")

View File

@ -3,6 +3,7 @@ package http
import ( import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cookie *http.Cookie cookie *http.Cookie
headers *http.Header
server *httptest.Server server *httptest.Server
authorizer func(r *http.Request, claims map[string]interface{}) error
expectedStatusCode int expectedStatusCode int
}{ }{
{ {
"no cookie", "no cookie",
nil, nil,
nil, nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
{ {
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil, nil,
nil,
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("here's some bytes 4 y'all")) w.Write([]byte("here's some bytes 4 y'all"))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"not really a key"}`)) w.Write([]byte(`{"key":"not really a key"}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusInternalServerError, http.StatusInternalServerError,
}, },
{ {
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`)) w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusBadGateway, http.StatusBadGateway,
}, },
{ {
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
Name: "auth", Name: "auth",
Value: "irrelevantgibberish", Value: "irrelevantgibberish",
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
{ {
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
return string(b) return string(b)
})(), })(),
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK, http.StatusOK,
}, },
{
"ok token in headers",
nil,
(func() *http.Header {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return &http.Header{
"X-RPC-Authentication": []string{string(b)},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusOK,
},
{
"bad token in headers",
nil,
(func() *http.Header {
return &http.Header{
"X-RPC-Authentication": []string{"nilly willy"},
}
})(),
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden,
},
{
"authorizer rejects",
&http.Cookie{
Name: "auth",
Value: (func() string {
token := jwt.New()
token.Set("exp", time.Now().Add(time.Hour))
token.Set("priv", map[string]interface{}{"ok": false})
keyBytes, _ := pem.Decode([]byte(privateKey))
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
b, _ := token.Sign(jwa.RS256, privKey)
return string(b)
})(),
},
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
))
})),
func(r *http.Request, claims map[string]interface{}) error {
if claims["ok"] == true {
return nil
}
return errors.New("expected ok to be true")
},
http.StatusForbidden,
},
{ {
"valid key, expired token", "valid key, expired token",
&http.Cookie{ &http.Cookie{
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
return string(b) return string(b)
})(), })(),
}, },
nil,
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte( w.Write([]byte(
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)), fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
)) ))
})), })),
func(r *http.Request, claims map[string]interface{}) error { return nil },
http.StatusForbidden, http.StatusForbidden,
}, },
} }
@ -174,9 +258,7 @@ func TestJWTProtect(t *testing.T) {
if test.server != nil { if test.server != nil {
url = test.server.URL url = test.server.URL
} }
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error { wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return nil
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK")) w.Write([]byte("OK"))
})) }))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -184,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
if test.cookie != nil { if test.cookie != nil {
r.AddCookie(test.cookie) r.AddCookie(test.cookie)
} }
if test.headers != nil {
for key, value := range *test.headers {
r.Header.Add(key, value[0])
}
}
wrappedHandler.ServeHTTP(w, r) wrappedHandler.ServeHTTP(w, r)
if w.Code != test.expectedStatusCode { if w.Code != test.expectedStatusCode {
t.Errorf("Unexpected status code %v", w.Code) t.Errorf("Unexpected status code %v", w.Code)