mirror of
https://github.com/offen/website.git
synced 2024-11-25 10:10:28 +01:00
clean up and add missing tests
This commit is contained in:
parent
1d07355112
commit
6325ec8afb
@ -290,9 +290,6 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: npm install
|
command: npm install
|
||||||
- run:
|
|
||||||
name: Install psycopg2 dependencies
|
|
||||||
command: sudo apt-get install libpq-dev
|
|
||||||
- save_cache:
|
- save_cache:
|
||||||
paths:
|
paths:
|
||||||
- ~/offen/packages/node_modules
|
- ~/offen/packages/node_modules
|
||||||
|
@ -13,11 +13,9 @@ from accounts.models import Account, User
|
|||||||
from accounts.views import AccountView, UserView
|
from accounts.views import AccountView, UserView
|
||||||
import accounts.api
|
import accounts.api
|
||||||
|
|
||||||
# set optional bootswatch theme
|
|
||||||
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
|
app.config["FLASK_ADMIN_SWATCH"] = "flatly"
|
||||||
|
|
||||||
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
|
admin = Admin(app, name="offen admin", template_mode="bootstrap3")
|
||||||
# Add administrative views here
|
|
||||||
|
|
||||||
admin.add_view(AccountView(Account, db.session))
|
admin.add_view(AccountView(Account, db.session))
|
||||||
admin.add_view(UserView(User, db.session))
|
admin.add_view(UserView(User, db.session))
|
||||||
|
@ -65,7 +65,7 @@ def post_login():
|
|||||||
},
|
},
|
||||||
private_key.encode(),
|
private_key.encode(),
|
||||||
algorithm="RS256",
|
algorithm="RS256",
|
||||||
).decode("utf-8")
|
).decode()
|
||||||
|
|
||||||
resp = make_response(jsonify({"user": match.serialize()}))
|
resp = make_response(jsonify({"user": match.serialize()}))
|
||||||
resp.set_cookie(
|
resp.set_cookie(
|
||||||
|
@ -10,7 +10,7 @@ def generate_key():
|
|||||||
class Account(db.Model):
|
class Account(db.Model):
|
||||||
__tablename__ = "accounts"
|
__tablename__ = "accounts"
|
||||||
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
account_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||||
name = db.Column(db.String(256), nullable=False, unique=True)
|
name = db.Column(db.Text, nullable=False)
|
||||||
users = db.relationship("AccountUserAssociation", back_populates="account")
|
users = db.relationship("AccountUserAssociation", back_populates="account")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -20,8 +20,8 @@ class Account(db.Model):
|
|||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
user_id = db.Column(db.String(36), primary_key=True, default=generate_key)
|
||||||
email = db.Column(db.String(256), nullable=False, unique=True)
|
email = db.Column(db.String(128), nullable=False, unique=True)
|
||||||
hashed_password = db.Column(db.String(256), nullable=False)
|
hashed_password = db.Column(db.Text, nullable=False)
|
||||||
accounts = db.relationship(
|
accounts = db.relationship(
|
||||||
"AccountUserAssociation", back_populates="user", lazy="joined"
|
"AccountUserAssociation", back_populates="user", lazy="joined"
|
||||||
)
|
)
|
||||||
|
@ -15,10 +15,17 @@ from accounts.models import AccountUserAssociation
|
|||||||
class RemoteServerException(Exception):
|
class RemoteServerException(Exception):
|
||||||
status = 0
|
status = 0
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Status {}: {}".format(
|
||||||
|
self.status, super(RemoteServerException, self).__str__()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_remote_account(name, account_id):
|
def create_remote_account(name, account_id):
|
||||||
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
private_key = environ.get("JWT_PRIVATE_KEY", "")
|
||||||
expiry = datetime.utcnow() + timedelta(seconds=10)
|
# expires in 30 seconds as this will mean the HTTP request would have
|
||||||
|
# timed out anyways
|
||||||
|
expiry = datetime.utcnow() + timedelta(seconds=30)
|
||||||
encoded = jwt.encode(
|
encoded = jwt.encode(
|
||||||
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
{"ok": True, "exp": expiry, "priv": {"rpc": "1"}},
|
||||||
private_key.encode(),
|
private_key.encode(),
|
||||||
@ -27,7 +34,7 @@ def create_remote_account(name, account_id):
|
|||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
"{}/accounts".format(environ.get("SERVER_HOST")),
|
"{}/accounts".format(environ.get("SERVER_HOST")),
|
||||||
json={"name": name, "account_id": account_id},
|
json={"name": name, "accountId": account_id},
|
||||||
headers={"X-RPC-Authentication": encoded},
|
headers={"X-RPC-Authentication": encoded},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,8 +53,7 @@ def handler(event, context):
|
|||||||
if user != environ.get("BASIC_AUTH_USER"):
|
if user != environ.get("BASIC_AUTH_USER"):
|
||||||
return build_response(api_arn, False)
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
encoded_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
|
hashed_password = environ.get("HASHED_BASIC_AUTH_PASSWORD")
|
||||||
hashed_password = base64.standard_b64decode(encoded_password).decode()
|
|
||||||
if not bcrypt.verify(password, hashed_password):
|
if not bcrypt.verify(password, hashed_password):
|
||||||
return build_response(api_arn, False)
|
return build_response(api_arn, False)
|
||||||
|
|
||||||
|
@ -7,5 +7,5 @@ pyjwt[crypto]==1.7.1
|
|||||||
passlib==1.7.1
|
passlib==1.7.1
|
||||||
bcrypt==3.1.7
|
bcrypt==3.1.7
|
||||||
PyMySQL==0.9.3
|
PyMySQL==0.9.3
|
||||||
mysqlclient
|
mysqlclient==1.4.2.post1
|
||||||
requests==2.22.0
|
requests==2.22.0
|
||||||
|
20
accounts/scripts/hash.py
Normal file
20
accounts/scripts/hash.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import base64
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from passlib.hash import bcrypt
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--password", type=str, help="The password to hash", required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plain",
|
||||||
|
help="Do not encode the result as base64",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
out = bcrypt.hash(args.password)
|
||||||
|
if not args.plain:
|
||||||
|
out = base64.standard_b64encode(out.encode()).decode()
|
||||||
|
print(out)
|
@ -3,11 +3,45 @@ import json
|
|||||||
import base64
|
import base64
|
||||||
from json import loads
|
from json import loads
|
||||||
from time import time
|
from time import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
from accounts import app
|
from accounts import app
|
||||||
|
|
||||||
|
|
||||||
|
FOREIGN_PRIVATE_KEY = """
|
||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwAPFiTSLKlVvG
|
||||||
|
N97TIyDWIxPp4Ji8hAmtlMn0gdGclC2DGKA2v7orXdNkngFon0PPe08acKI5NL9P
|
||||||
|
nkVSrjWxrn8H7LeNQadwPxjYVmri4SLhBJUcAe+SoqrIZtrci+2y64mLPrl6wxBj
|
||||||
|
ZKDl8o1Qm8iZSMgJ+wRG2FrItZUBWLZ79KSB2lQkO5OWorPX3T0SPxQXqq9hc4xN
|
||||||
|
6I+qtfmv5jZTJOviMCehOs48ZlObgr/W+Kak4q/jrrqXvG3XQqVVTN/z95+2XuN4
|
||||||
|
Btj7fv24PIRE/BddDAzC/yzISYb9QqLChaxx1fqY+aSA6ou2wh1PjUiyXNnAmP2i
|
||||||
|
6UWwikILAgMBAAECggEBAJuYmc1/x+w00qeQKQubmKH27NnsVtsCF9Q/H7NrOTYl
|
||||||
|
wX6OPMVqBlnkXsgq76/gbQB2UN5dCO1t9lua3kpT/OASFfeZjEPy8OXIwlwvOdtN
|
||||||
|
kZpAhNn31CZcbIMyevZTNlbg5/4T+8HNxSU5hw0Cu2+x6UuqDj7UjVlcWBXsgchn
|
||||||
|
f8kguLHr6Q7rndC10Vv5a4Rz9fzuS2K4jEnhlJjgD22XB2SCH5kLrAikH10AW761
|
||||||
|
5g7HSiMxKSUyXc51PX3n/FkxjzT0Vm1ENeZou263VEQhke49IWLIcbLD7ShOyNaI
|
||||||
|
TuYPAyRY4o70/d/YTydRCEp/H8stB6UaVK9hlzzfoMECgYEA1e9UgW4vBueSoZv3
|
||||||
|
llc7dXlAnk6nJeCaujjPBAd0Sc3XcpMik1kb8oDgI4fwNxTYqlHu3Wp7ZLR14C4G
|
||||||
|
rlry+4rRUdxnWNcKtyOtA6km0b33V3ja4GsLViENBSQZDUe7EljER2VSRynMTog0
|
||||||
|
lfmUr+ORzWDpanEO+Ke25zhU2DsCgYEA0pxM2UjmmAepSWBAcXABjIFE09MxXVTS
|
||||||
|
NwRhdYjHJsKmGnPD8DEDJbRSHNAEN2mTD2kJW5pFThKVWtQ8WpjSXuRSkS7HzXrU
|
||||||
|
zMNZnzTDdTZl6nnui3RJtIYntSXR7ommC6ldY7nlnHnzkIEcDLwN6E/JNOB5gtTE
|
||||||
|
L4ztUpKncHECgYBO3qHX6agasorjW52mZlh8UYxaEIMcurYwSzs+sATWJLX1/npz
|
||||||
|
uhlMiOiZEMelduD9waD/Lf95u/HtCOrbopoL1DyhIlFTdkv0AooJXHX8Qz2JmPuQ
|
||||||
|
WsZeJWcoawt1UumLtP//lkIEDEvO8/X3CIEhaxNYlQ7Yd//d+e67RZA5+wKBgD6f
|
||||||
|
qR4m1iI4jPa7fw377wn3Wh7eOlx1Hziqvcv0CruUv004RPfDqxrn/k6A7/AGHWtE
|
||||||
|
oTqyqY7oaa6jUvrhXBRJMd/nmBOaRXJJV/nF96R/s1hAP1UKE+xww5fSkhSqq0vm
|
||||||
|
ZVWE7ihT/r9mFJAYzs3YA40MfjUPzPISpnKaFt2RAoGBANCtswMqztcuPDF5rL3d
|
||||||
|
rqB6jwFrXKvwrx4HxOmF/MgGPyp6MWLBEnpZDvLJo9uSafq6Q6IwOQMWWF5GO7JO
|
||||||
|
4EG9ldVugR/CtmL3+XTHE4MGPXmqHg/q/o7rItc7g11iXJTndcUZtWGwkHwl4zBF
|
||||||
|
15NFZ2gU4rKnQ3sVAOzMoEw5
|
||||||
|
-----END PRIVATE KEY-----
|
||||||
|
"""
|
||||||
|
|
||||||
def _pad_b64_string(s):
|
def _pad_b64_string(s):
|
||||||
while len(s) % 4 is not 0:
|
while len(s) % 4 is not 0:
|
||||||
s = s + "="
|
s = s + "="
|
||||||
@ -81,6 +115,7 @@ class TestJWT(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
data = json.loads(rv.data)
|
data = json.loads(rv.data)
|
||||||
assert data["user"]["userId"] is not None
|
assert data["user"]["userId"] is not None
|
||||||
|
data["user"]["accounts"].sort(key=lambda a: a["name"])
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
data["user"]["accounts"],
|
data["user"]["accounts"],
|
||||||
[
|
[
|
||||||
@ -146,3 +181,27 @@ class TestJWT(unittest.TestCase):
|
|||||||
rv = self.app.post("/api/logout")
|
rv = self.app.post("/api/logout")
|
||||||
assert rv.status.startswith("204")
|
assert rv.status.startswith("204")
|
||||||
self._assert_cookie_not_present("auth")
|
self._assert_cookie_not_present("auth")
|
||||||
|
|
||||||
|
def test_forged_token(self):
|
||||||
|
"""
|
||||||
|
The application needs to verify that tokens that would be theoretically
|
||||||
|
valid are not signed using an unknown key.
|
||||||
|
"""
|
||||||
|
forged_token = jwt.encode(
|
||||||
|
{
|
||||||
|
"exp": datetime.utcnow() + timedelta(hours=24),
|
||||||
|
"priv": {
|
||||||
|
"userId": "8bc8db1b-f32d-4376-a1cf-724bf6a597b8",
|
||||||
|
"accounts": [
|
||||||
|
"9b63c4d8-65c0-438c-9d30-cc4b01173393",
|
||||||
|
"78403940-ae4f-4aff-a395-1e90f145cf62",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FOREIGN_PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
self.app.set_cookie("localhost", "auth", forged_token)
|
||||||
|
rv = self.app.get("/api/login")
|
||||||
|
assert rv.status.startswith("401")
|
||||||
|
@ -3,6 +3,7 @@ package http
|
|||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -62,13 +63,17 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cookie *http.Cookie
|
cookie *http.Cookie
|
||||||
|
headers *http.Header
|
||||||
server *httptest.Server
|
server *httptest.Server
|
||||||
|
authorizer func(r *http.Request, claims map[string]interface{}) error
|
||||||
expectedStatusCode int
|
expectedStatusCode int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"no cookie",
|
"no cookie",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -78,6 +83,8 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -86,9 +93,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("here's some bytes 4 y'all"))
|
w.Write([]byte("here's some bytes 4 y'all"))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,9 +106,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"not really a key"}`))
|
w.Write([]byte(`{"key":"not really a key"}`))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -108,9 +119,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
w.Write([]byte(`{"key":"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCATZAMIIBCgKCAQEA2yUfHH6SRYKvBTemrefi\nHk4L4qkcc4skl4QCaHOkfgA4VcGKG2nXysYuZK7AzNOcHQVi+e4BwN+BfIZtwEU5\n7Ogctb5eg8ksxxLjS7eSRfQIvPGfAbJ12R9OoOWcue/CdUy/YMec4R/o4+tZ45S6\nQQWIMhLqYljw+s1Runda3K8Q8lOdJ4yEZckXaZr1waNJikC7oGpT7ClAgdbvWIbo\nN18G1OluRn+3WNdcN6V+vIj8c9dGs92bgTPX4cn3RmB/80BDfzeFiPMRw5xaq66F\n42zXzllkTqukQPk2wmO5m9pFy0ciRve+awfgbTtZRZOEpTSWLbbpOfd4RQ5YqDWJ\nmQIDAQAB\n-----END PUBLIC KEY-----"}`))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusBadGateway,
|
http.StatusBadGateway,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -119,11 +132,13 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
Name: "auth",
|
Name: "auth",
|
||||||
Value: "irrelevantgibberish",
|
Value: "irrelevantgibberish",
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -139,13 +154,80 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
return string(b)
|
return string(b)
|
||||||
})(),
|
})(),
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"ok token in headers",
|
||||||
|
nil,
|
||||||
|
(func() *http.Header {
|
||||||
|
token := jwt.New()
|
||||||
|
token.Set("exp", time.Now().Add(time.Hour))
|
||||||
|
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||||
|
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||||
|
b, _ := token.Sign(jwa.RS256, privKey)
|
||||||
|
return &http.Header{
|
||||||
|
"X-RPC-Authentication": []string{string(b)},
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
|
http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"bad token in headers",
|
||||||
|
nil,
|
||||||
|
(func() *http.Header {
|
||||||
|
return &http.Header{
|
||||||
|
"X-RPC-Authentication": []string{"nilly willy"},
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"authorizer rejects",
|
||||||
|
&http.Cookie{
|
||||||
|
Name: "auth",
|
||||||
|
Value: (func() string {
|
||||||
|
token := jwt.New()
|
||||||
|
token.Set("exp", time.Now().Add(time.Hour))
|
||||||
|
token.Set("priv", map[string]interface{}{"ok": false})
|
||||||
|
keyBytes, _ := pem.Decode([]byte(privateKey))
|
||||||
|
privKey, _ := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
|
||||||
|
b, _ := token.Sign(jwa.RS256, privKey)
|
||||||
|
return string(b)
|
||||||
|
})(),
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(
|
||||||
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error {
|
||||||
|
if claims["ok"] == true {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("expected ok to be true")
|
||||||
|
},
|
||||||
|
http.StatusForbidden,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"valid key, expired token",
|
"valid key, expired token",
|
||||||
&http.Cookie{
|
&http.Cookie{
|
||||||
@ -159,11 +241,13 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
return string(b)
|
return string(b)
|
||||||
})(),
|
})(),
|
||||||
},
|
},
|
||||||
|
nil,
|
||||||
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(
|
w.Write([]byte(
|
||||||
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
fmt.Sprintf(`{"key":"%s"}`, strings.ReplaceAll(publicKey, "\n", `\n`)),
|
||||||
))
|
))
|
||||||
})),
|
})),
|
||||||
|
func(r *http.Request, claims map[string]interface{}) error { return nil },
|
||||||
http.StatusForbidden,
|
http.StatusForbidden,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -174,9 +258,7 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.server != nil {
|
if test.server != nil {
|
||||||
url = test.server.URL
|
url = test.server.URL
|
||||||
}
|
}
|
||||||
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", func(r *http.Request, claims map[string]interface{}) error {
|
wrappedHandler := JWTProtect(url, "auth", "X-RPC-Authentication", test.authorizer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
return nil
|
|
||||||
})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte("OK"))
|
w.Write([]byte("OK"))
|
||||||
}))
|
}))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@ -184,6 +266,11 @@ func TestJWTProtect(t *testing.T) {
|
|||||||
if test.cookie != nil {
|
if test.cookie != nil {
|
||||||
r.AddCookie(test.cookie)
|
r.AddCookie(test.cookie)
|
||||||
}
|
}
|
||||||
|
if test.headers != nil {
|
||||||
|
for key, value := range *test.headers {
|
||||||
|
r.Header.Add(key, value[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
wrappedHandler.ServeHTTP(w, r)
|
wrappedHandler.ServeHTTP(w, r)
|
||||||
if w.Code != test.expectedStatusCode {
|
if w.Code != test.expectedStatusCode {
|
||||||
t.Errorf("Unexpected status code %v", w.Code)
|
t.Errorf("Unexpected status code %v", w.Code)
|
||||||
|
Loading…
Reference in New Issue
Block a user