diff --git a/CHANGELOG.md b/CHANGELOG.md index 6975cf73..b78a44b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,14 @@ ## Important Notes +- [#632](https://github.com/oauth2-proxy/oauth2-proxy/pull/632) There is backwards compatibility to sessions from v5 + - Any unencrypted sessions from before v5 that only contained a Username & Email will trigger a reauthentication + ## Breaking Changes ## Changes since v6.0.0 +- [#632](https://github.com/oauth2-proxy/oauth2-proxy/pull/632) Reduce session size by encoding with MessagePack and using LZ4 compression (@NickMeves) - [#675](https://github.com/oauth2-proxy/oauth2-proxy/pull/675) Fix required ruby version and deprecated option for building docs (@mkontani) - [#669](https://github.com/oauth2-proxy/oauth2-proxy/pull/669) Reduce docker context to improve build times (@JoelSpeed) - [#668](https://github.com/oauth2-proxy/oauth2-proxy/pull/668) Use req.Host in --force-https when req.URL.Host is empty (@zucaritask) diff --git a/go.mod b/go.mod index cacf7ab8..cd1d061a 100644 --- a/go.mod +++ b/go.mod @@ -9,23 +9,24 @@ require ( github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/coreos/go-oidc v2.2.1+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/frankban/quicktest v1.10.0 // indirect github.com/fsnotify/fsnotify v1.4.9 github.com/go-redis/redis/v7 v7.2.0 github.com/justinas/alice v1.2.0 - github.com/kr/pretty v0.2.0 // indirect github.com/mbland/hmacauth v0.0.0-20170912233209-44256dfd4bfa github.com/mitchellh/mapstructure v1.1.2 github.com/onsi/ginkgo v1.12.0 github.com/onsi/gomega v1.9.0 + github.com/pierrec/lz4 v2.5.2+incompatible github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/spf13/pflag v1.0.3 github.com/spf13/viper v1.6.3 github.com/stretchr/testify v1.5.1 + github.com/vmihailenco/msgpack/v4 v4.3.11 github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 - golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 - golang.org/x/net v0.0.0-20200226121028-0de0cce0169b + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 + golang.org/x/net v0.0.0-20200301022130-244492dfa37a golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d - golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect google.golang.org/api v0.20.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/square/go-jose.v2 v2.4.1 diff --git a/go.sum b/go.sum index ba0d342b..619014de 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/frankban/quicktest v1.10.0 h1:Gfh+GAJZOAoKZsIZeZbdn2JF10kN1XHNvjsvQK8gVkE= +github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= @@ -69,6 +71,8 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.4 h1:87PNWwrRvUSnqS4dlcBU/ftvOIBep4sYuBLlh6rX2wk= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/gomodule/redigo v1.7.1-0.20190322064113-39e2c31b7ca3 h1:6amM4HsNPOvMLVc2ZnyqrjeQ92YAVWn7T4WBKK87inY= github.com/gomodule/redigo v1.7.1-0.20190322064113-39e2c31b7ca3/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/gomodule/redigo v1.8.1 h1:Abmo0bI7Xf0IhdIPc7HZQzZcShdnmxeoVuDDtIQp8N8= @@ -78,6 +82,8 @@ github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -142,6 +148,8 @@ github.com/onsi/gomega v1.9.0 h1:R1uwffexN6Pr340GtYRIdZmAiN4J+iw6WG4wog1DUXg= github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUMhxq9m9ZXI= +github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -184,6 +192,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/vmihailenco/msgpack/v4 v4.3.11 h1:Q47CePddpNGNhk4GCnAx9DDtASi2rasatE0cd26cZoE= +github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= +github.com/vmihailenco/tagparser v0.1.1 h1:quXMXlA39OCbd2wAdTsGDlK9RkOk6Wuw+x37wVyIuWY= +github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 h1:1+FQ4Ns+UZtUiQ4lP0sTCyKSQ0EXoiwAdHZB0Pd5t9Q= @@ -201,8 +213,6 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -217,14 +227,14 @@ golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c h1:uOCk1iQW6Vc18bnC13MfzScl+wdKBmM9Y9kU7Z83/lw= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= @@ -243,7 +253,6 @@ golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b h1:ag/x1USPSsqHud38I9BAC88qdNLDHHtQ4mlgQIZPPNA= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -276,6 +285,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 0510a252..5bafd048 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -23,7 +23,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" + sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/stretchr/testify/assert" @@ -36,11 +36,12 @@ const ( // encoded version of this. rawCookieSecret = "secretthirtytwobytes+abcdefghijk" base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams" + clientID = "3984n253984d7348dm8234yf982t" + clientSecret = "gv3498mfc9t23y23974dm2394dm9" ) func init() { logger.SetFlags(logger.Lshortfile) - } type WebSocketOrRestHandler struct { @@ -61,24 +62,30 @@ func TestWebSocketProxy(t *testing.T) { restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) + _, err := w.Write([]byte(hostname)) + if err != nil { + t.Fatal(err) + } }), wsHandler: websocket.Handler(func(ws *websocket.Conn) { - defer ws.Close() + defer func(t *testing.T) { + if err := ws.Close(); err != nil { + t.Fatal(err) + } + }(t) var data []byte err := websocket.Message.Receive(ws, &data) if err != nil { - t.Fatalf("err %s", err) - return + t.Fatal(err) } err = websocket.Message.Send(ws, data) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } }), } backend := httptest.NewServer(&handler) - defer backend.Close() + t.Cleanup(backend.Close) backendURL, _ := url.Parse(backend.URL) @@ -87,24 +94,24 @@ func TestWebSocketProxy(t *testing.T) { opts.PassHostHeader = true proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) frontendURL, _ := url.Parse(frontend.URL) frontendWSURL := "ws://" + frontendURL.Host + "/" ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } request := []byte("hello, world!") err = websocket.Message.Send(ws, request) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } var response = make([]byte, 1024) - websocket.Message.Receive(ws, &response) + err = websocket.Message.Receive(ws, &response) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } if g, e := string(request), string(response); g != e { t.Errorf("got body %q; expected %q", g, e) @@ -123,9 +130,12 @@ func TestNewReverseProxy(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) hostname, _, _ := net.SplitHostPort(r.Host) - w.Write([]byte(hostname)) + _, err := w.Write([]byte(hostname)) + if err != nil { + t.Fatal(err) + } })) - defer backend.Close() + t.Cleanup(backend.Close) backendURL, _ := url.Parse(backend.URL) backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) @@ -135,7 +145,7 @@ func TestNewReverseProxy(t *testing.T) { proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second}) setProxyUpstreamHostHeader(proxyHandler, proxyURL) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) getReq, _ := http.NewRequest("GET", frontend.URL, nil) res, _ := http.DefaultClient.Do(getReq) @@ -151,20 +161,20 @@ func TestEncodedSlashes(t *testing.T) { w.WriteHeader(200) seen = r.RequestURI })) - defer backend.Close() + t.Cleanup(backend.Close) b, _ := url.Parse(backend.URL) proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second}) setProxyDirector(proxyHandler) frontend := httptest.NewServer(proxyHandler) - defer frontend.Close() + t.Cleanup(frontend.Close) f, _ := url.Parse(frontend.URL) encodedPath := "/a%2Fb/?c=1" getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} _, err := http.DefaultClient.Do(getReq) if err != nil { - t.Fatalf("err %s", err) + t.Fatal(err) } if seen != encodedPath { t.Errorf("got bad request %q expected %q", seen, encodedPath) @@ -173,13 +183,13 @@ func TestEncodedSlashes(t *testing.T) { func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() - opts.ClientID = "asdlkjx" - opts.ClientSecret = "alkgks" - opts.Cookie.Secret = rawCookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) @@ -189,9 +199,6 @@ func TestRobotsTxt(t *testing.T) { func TestIsValidRedirect(t *testing.T) { opts := baseTestOptions() - opts.ClientID = "skdlfj" - opts.ClientSecret = "fgkdsgj" - opts.Cookie.Secret = base64CookieSecret // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", @@ -201,10 +208,13 @@ func TestIsValidRedirect(t *testing.T) { "anyport.bar:*", ".sub.anyport.bar:*", } - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } testCases := []struct { Desc, Redirect string @@ -439,11 +449,7 @@ func TestIsValidRedirect(t *testing.T) { } func TestOpenRedirects(t *testing.T) { - opts := options.NewOptions() - opts.ClientID = "skdlfj" - opts.ClientSecret = "fgkdsgj" - opts.Cookie.Secret = rawCookieSecret - opts.EmailDomains = []string{"*"} + opts := baseTestOptions() // Should match domains that are exactly foo.bar and any subdomain of bar.foo opts.WhitelistDomains = []string{ "foo.bar", @@ -458,13 +464,19 @@ func TestOpenRedirects(t *testing.T) { assert.NoError(t, err) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } file, err := os.Open("./test/openredirects.txt") if err != nil { t.Fatal(err) } - defer file.Close() + defer func(t *testing.T) { + if err := file.Close(); err != nil { + t.Fatal(err) + } + }(t) scanner := bufio.NewScanner(file) for scanner.Scan() { @@ -544,22 +556,21 @@ func TestBasicAuthPassword(t *testing.T) { } } w.WriteHeader(200) - w.Write([]byte(payload)) + _, err := w.Write([]byte(payload)) + if err != nil { + t.Fatal(err) + } })) opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, providerServer.URL) - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. - opts.Cookie.Secret = "xyzzyplughxyzzyplughxyzzyplughxp" - opts.ClientID = "dlgkj" - opts.ClientSecret = "alkgret" opts.Cookie.Secure = false opts.PassBasicAuth = true opts.SetBasicAuth = true opts.PassUserHeaders = true opts.PreferEmailToUser = true opts.BasicAuthPassword = "This is a secure password" - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" @@ -568,7 +579,9 @@ func TestBasicAuthPassword(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) @@ -618,8 +631,8 @@ func TestBasicAuthWithEmail(t *testing.T) { opts.PassUserHeaders = false opts.PreferEmailToUser = false opts.BasicAuthPassword = "This is a secure password" - opts.Cookie.Secret = rawCookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" @@ -641,7 +654,9 @@ func TestBasicAuthWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0]) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) @@ -655,7 +670,9 @@ func TestBasicAuthWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0]) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) @@ -664,11 +681,8 @@ func TestBasicAuthWithEmail(t *testing.T) { func TestPassUserHeadersWithEmail(t *testing.T) { opts := baseTestOptions() - opts.PassBasicAuth = false - opts.PassUserHeaders = true - opts.PreferEmailToUser = false - opts.Cookie.Secret = base64CookieSecret - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" @@ -686,7 +700,9 @@ func TestPassUserHeadersWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) } @@ -699,7 +715,9 @@ func TestPassUserHeadersWithEmail(t *testing.T) { proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) } @@ -716,10 +734,10 @@ type PassAccessTokenTestOptions struct { ProxyUpstream string } -func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { - t := &PassAccessTokenTest{} +func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { + patt := &PassAccessTokenTest{} - t.providerServer = httptest.NewServer( + patt.providerServer = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var payload string switch r.URL.Path { @@ -732,35 +750,35 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes } } w.WriteHeader(200) - w.Write([]byte(payload)) + _, err := w.Write([]byte(payload)) + if err != nil { + panic(err) + } })) - t.opts = baseTestOptions() - t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) + patt.opts = baseTestOptions() + patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL) if opts.ProxyUpstream != "" { - t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream) + patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream) + } + patt.opts.Cookie.Secure = false + patt.opts.PassAccessToken = opts.PassAccessToken + err := validation.Validate(patt.opts) + if err != nil { + return nil, err } - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. - t.opts.Cookie.Secret = "xyzzyplughxyzzyplughxyzzyplughxp" - t.opts.ClientID = "slgkj" - t.opts.ClientSecret = "gfjgojl" - t.opts.Cookie.Secure = false - t.opts.PassAccessToken = opts.PassAccessToken - validation.Validate(t.opts) - providerURL, _ := url.Parse(t.providerServer.URL) + providerURL, _ := url.Parse(patt.providerServer.URL) const emailAddress = "michael.bland@gsa.gov" - t.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) - var err error - t.proxy, err = NewOAuthProxy(t.opts, func(email string) bool { + patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) + patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { return email == emailAddress }) if err != nil { - panic(err) + return nil, err } - return t + return patt, nil } func (patTest *PassAccessTokenTest) Close() { @@ -817,17 +835,20 @@ func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoin } func TestForwardAccessTokenUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, }) - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } - assert.NotEqual(t, nil, cookie) + assert.NotNil(t, cookie) // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is @@ -840,12 +861,14 @@ func TestForwardAccessTokenUpstream(t *testing.T) { } func TestStaticProxyUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, ProxyUpstream: "static://200/static-proxy", }) - - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() @@ -864,10 +887,13 @@ func TestStaticProxyUpstream(t *testing.T) { } func TestDoNotForwardAccessTokenUpstream(t *testing.T) { - patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false, }) - defer patTest.Close() + if err != nil { + t.Fatal(err) + } + t.Cleanup(patTest.Close) // A successful validation will redirect and set the auth cookie. code, cookie := patTest.getCallbackEndpoint() @@ -895,27 +921,26 @@ type SignInPageTest struct { const signInRedirectPattern = `` const signInSkipProvider = `>Found<` -func NewSignInPageTest(skipProvider bool) *SignInPageTest { +func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) { var sipTest SignInPageTest sipTest.opts = baseTestOptions() - sipTest.opts.Cookie.Secret = rawCookieSecret - sipTest.opts.ClientID = "lkdgj" - sipTest.opts.ClientSecret = "sgiufgoi" sipTest.opts.SkipProviderButton = skipProvider - validation.Validate(sipTest.opts) + err := validation.Validate(sipTest.opts) + if err != nil { + return nil, err + } - var err error sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return nil, err } sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) - return &sipTest + return &sipTest, nil } func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { @@ -926,7 +951,10 @@ func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { } func TestSignInPageIncludesTargetRedirect(t *testing.T) { - sipTest := NewSignInPageTest(false) + sipTest, err := NewSignInPageTest(false) + if err != nil { + t.Fatal(err) + } const endpoint = "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) @@ -944,7 +972,10 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { } func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { - sipTest := NewSignInPageTest(false) + sipTest, err := NewSignInPageTest(false) + if err != nil { + t.Fatal(err) + } code, body := sipTest.GetEndpoint("/oauth2/sign_in") assert.Equal(t, 200, code) @@ -959,8 +990,12 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { } func TestSignInPageSkipProvider(t *testing.T) { - sipTest := NewSignInPageTest(true) - const endpoint = "/some/random/endpoint" + sipTest, err := NewSignInPageTest(true) + if err != nil { + t.Fatal(err) + } + + endpoint := "/some/random/endpoint" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) @@ -973,8 +1008,12 @@ func TestSignInPageSkipProvider(t *testing.T) { } func TestSignInPageSkipProviderDirect(t *testing.T) { - sipTest := NewSignInPageTest(true) - const endpoint = "/sign_in" + sipTest, err := NewSignInPageTest(true) + if err != nil { + t.Fatal(err) + } + + endpoint := "/sign_in" code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) @@ -1000,27 +1039,26 @@ type ProcessCookieTestOpts struct { type OptionsModifier func(*options.Options) -func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { +func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() for _, modifier := range modifiers { modifier(pcTest.opts) } - pcTest.opts.ClientID = "asdfljk" - pcTest.opts.ClientSecret = "lkjfdsig" - pcTest.opts.Cookie.Secret = "0123456789abcdef0123456789abcdef" // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. pcTest.opts.Cookie.Refresh = time.Hour - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + if err != nil { + return nil, err + } - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + return nil, err } pcTest.proxy.provider = &TestProvider{ ValidToken: opts.providerValidateCookieResponse, @@ -1032,16 +1070,16 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi pcTest.rw = httptest.NewRecorder() pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pcTest.validateUser = true - return &pcTest + return &pcTest, nil } -func NewProcessCookieTestWithDefaults() *ProcessCookieTest { +func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }) } -func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { +func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { return NewProcessCookieTest(ProcessCookieTestOpts{ providerValidateCookieResponse: true, }, modifiers...) @@ -1063,37 +1101,51 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) } func TestLoadCookiedSession(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + t.Fatal(err) + } created := time.Now() startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() - assert.Equal(t, nil, err) + if err != nil { + t.Fatal(err) + } assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "", session.User) assert.Equal(t, startSession.AccessToken, session.AccessToken) } func TestProcessCookieNoCookieError(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + t.Fatal(err) + } session, err := pcTest.LoadCookiedSession() - assert.Equal(t, "cookie \"_oauth2_proxy\" not present", err.Error()) + assert.Error(t, err, "cookie \"_oauth2_proxy\" not present") if session != nil { t.Errorf("expected nil session. got %#v", session) } } func TestProcessCookieRefreshNotSet(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(23) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(-2) * time.Hour) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) @@ -1104,12 +1156,17 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { } func TestProcessCookieFailIfCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) @@ -1119,12 +1176,17 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.CookieRefresh = time.Hour session, err := pcTest.LoadCookiedSession() @@ -1134,18 +1196,26 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { } } -func NewUserInfoEndpointTest() *ProcessCookieTest { - pcTest := NewProcessCookieTestWithDefaults() +func NewUserInfoEndpointTest() (*ProcessCookieTest, error) { + pcTest, err := NewProcessCookieTestWithDefaults() + if err != nil { + return nil, err + } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/userinfo", nil) - return pcTest + return pcTest, nil } func TestUserInfoEndpointAccepted(t *testing.T) { - test := NewUserInfoEndpointTest() + test, err := NewUserInfoEndpointTest() + if err != nil { + t.Fatal(err) + } + startSession := &sessions.SessionState{ Email: "john.doe@example.com", AccessToken: "my_access_token"} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusOK, test.rw.Code) @@ -1154,25 +1224,36 @@ func TestUserInfoEndpointAccepted(t *testing.T) { } func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { - test := NewUserInfoEndpointTest() + test, err := NewUserInfoEndpointTest() + if err != nil { + t.Fatal(err) + } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) } -func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { - pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) +func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { + pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...) + if err != nil { + return nil, err + } pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) - return pcTest + return pcTest, nil } func TestAuthOnlyEndpointAccepted(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } + created := time.Now() startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) @@ -1181,7 +1262,10 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -1190,13 +1274,18 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { - test := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) + if err != nil { + t.Fatal(err) + } + reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -1205,11 +1294,16 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { - test := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest() + if err != nil { + t.Fatal(err) + } + created := time.Now() startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} - test.SaveSession(startSession) + err = test.SaveSession(startSession) + assert.NoError(t, err) test.validateUser = false test.proxy.ServeHTTP(test.rw, test.req) @@ -1224,15 +1318,13 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true err := validation.Validate(pcTest.opts) - if err != nil { - panic(err) - } + assert.NoError(t, err) pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1247,7 +1339,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1261,14 +1354,14 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = true - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + assert.NoError(t, err) - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1283,7 +1376,8 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1299,14 +1393,14 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = false - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + assert.NoError(t, err) - var err error pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { - panic(err) + t.Fatal(err) } pcTest.proxy.provider = &TestProvider{ ValidToken: true, @@ -1321,7 +1415,8 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { created := time.Now() startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} - pcTest.SaveSession(startSession) + err = pcTest.SaveSession(startSession) + assert.NoError(t, err) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1333,20 +1428,26 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { func TestAuthSkippedForPreflightRequests(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) - w.Write([]byte("response")) + _, err := w.Write([]byte("response")) + if err != nil { + t.Fatal(err) + } })) - defer upstream.Close() + t.Cleanup(upstream.Close) opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, upstream.URL) opts.SkipAuthPreflight = true - validation.Validate(opts) + err := validation.Validate(opts) + assert.NoError(t, err) upstreamURL, _ := url.Parse(upstream.URL) opts.SetProvider(NewTestProvider(upstreamURL, "")) proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } rw := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) proxy.ServeHTTP(rw, req) @@ -1361,16 +1462,25 @@ type SignatureAuthenticator struct { func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) { result, headerSig, computedSig := v.auth.AuthenticateRequest(r) - if result == hmacauth.ResultNoSignature { - w.Write([]byte("no signature received")) - } else if result == hmacauth.ResultMatch { - w.Write([]byte("signatures match")) - } else if result == hmacauth.ResultMismatch { - w.Write([]byte("signatures do not match:" + - "\n received: " + headerSig + - "\n computed: " + computedSig)) - } else { - panic("Unknown result value: " + result.String()) + + var msg string + switch result { + case hmacauth.ResultNoSignature: + msg = "no signature received" + case hmacauth.ResultMatch: + msg = "signatures match" + case hmacauth.ResultMismatch: + msg = fmt.Sprintf( + "signatures do not match:\n received: %s\n computed: %s", + headerSig, + computedSig) + default: + panic("unknown result value: " + result.String()) + } + + _, err := w.Write([]byte(msg)) + if err != nil { + panic(err) } } @@ -1384,24 +1494,30 @@ type SignatureTest struct { authenticator *SignatureAuthenticator } -func NewSignatureTest() *SignatureTest { +func NewSignatureTest() (*SignatureTest, error) { opts := baseTestOptions() - opts.Cookie.Secret = rawCookieSecret - opts.ClientID = "client ID" - opts.ClientSecret = "client secret" opts.EmailDomains = []string{"acm.org"} authenticator := &SignatureAuthenticator{} upstream := httptest.NewServer( http.HandlerFunc(authenticator.Authenticate)) - upstreamURL, _ := url.Parse(upstream.URL) + upstreamURL, err := url.Parse(upstream.URL) + if err != nil { + return nil, err + } opts.Upstreams = append(opts.Upstreams, upstream.URL) providerHandler := func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`{"access_token": "my_auth_token"}`)) + _, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) + if err != nil { + panic(err) + } } provider := httptest.NewServer(http.HandlerFunc(providerHandler)) - providerURL, _ := url.Parse(provider.URL) + providerURL, err := url.Parse(provider.URL) + if err != nil { + return nil, err + } opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) return &SignatureTest{ @@ -1412,7 +1528,7 @@ func NewSignatureTest() *SignatureTest { make(http.Header), httptest.NewRecorder(), authenticator, - } + }, nil } func (st *SignatureTest) Close() { @@ -1436,14 +1552,14 @@ func (fnc *fakeNetConn) Read(p []byte) (n int, err error) { return 0, io.EOF } -func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { +func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error { err := validation.Validate(st.opts) if err != nil { - panic(err) + return err } proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return err } var bodyBuf io.ReadCloser @@ -1457,7 +1573,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { Email: "mbland@acm.org", AccessToken: "my_access_token"} err = proxy.SaveSession(st.rw, req, state) if err != nil { - panic(err) + return err } for _, c := range st.rw.Result().Cookies() { req.AddCookie(c) @@ -1466,33 +1582,52 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { st.authenticator.auth = hmacauth.NewHmacAuth( crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) proxy.ServeHTTP(st.rw, req) + + return nil } -func TestNoRequestSignature(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.MakeRequestWithExpectedKey("GET", "", "") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "no signature received") -} - -func TestRequestSignatureGetRequest(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.opts.SignatureKey = "sha1:7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d" - st.MakeRequestWithExpectedKey("GET", "", "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "signatures match") -} - -func TestRequestSignaturePostRequest(t *testing.T) { - st := NewSignatureTest() - defer st.Close() - st.opts.SignatureKey = "sha1:d90df39e2d19282840252612dd7c81421a372f61" - payload := `{ "hello": "world!" }` - st.MakeRequestWithExpectedKey("POST", payload, "d90df39e2d19282840252612dd7c81421a372f61") - assert.Equal(t, 200, st.rw.Code) - assert.Equal(t, st.rw.Body.String(), "signatures match") +func TestRequestSignature(t *testing.T) { + testCases := map[string]struct { + method string + body string + key string + resp string + }{ + "No request signature": { + method: "GET", + body: "", + key: "", + resp: "no signature received", + }, + "Get request": { + method: "GET", + body: "", + key: "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d", + resp: "signatures match", + }, + "Post request": { + method: "POST", + body: `{ "hello": "world!" }`, + key: "d90df39e2d19282840252612dd7c81421a372f61", + resp: "signatures match", + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + st, err := NewSignatureTest() + if err != nil { + t.Fatal(err) + } + t.Cleanup(st.Close) + if tc.key != "" { + st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key) + } + err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key) + assert.NoError(t, err) + assert.Equal(t, 200, st.rw.Code) + assert.Equal(t, tc.resp, st.rw.Body.String()) + }) + } } func TestGetRedirect(t *testing.T) { @@ -1501,7 +1636,9 @@ func TestGetRedirect(t *testing.T) { assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) proxy, err := NewOAuthProxy(opts, func(s string) bool { return false }) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } tests := []struct { name string @@ -1535,22 +1672,21 @@ type ajaxRequestTest struct { proxy *OAuthProxy } -func newAjaxRequestTest() *ajaxRequestTest { +func newAjaxRequestTest() (*ajaxRequestTest, error) { test := &ajaxRequestTest{} test.opts = baseTestOptions() - test.opts.Cookie.Secret = base64CookieSecret - test.opts.ClientID = "gkljfdl" - test.opts.ClientSecret = "sdflkjs" - validation.Validate(test.opts) + err := validation.Validate(test.opts) + if err != nil { + return nil, err + } - var err error test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { return true }) if err != nil { - panic(err) + return nil, err } - return test + return test, nil } func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { @@ -1565,7 +1701,10 @@ func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (i } func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { - test := newAjaxRequestTest() + test, err := newAjaxRequestTest() + if err != nil { + t.Fatal(err) + } endpoint := "/test" code, rh, err := test.getEndpoint(endpoint, header) @@ -1589,7 +1728,10 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) { } func TestAjaxForbiddendRequest(t *testing.T) { - test := newAjaxRequestTest() + test, err := newAjaxRequestTest() + if err != nil { + t.Fatal(err) + } endpoint := "/test" header := make(http.Header) code, rh, err := test.getEndpoint(endpoint, header) @@ -1604,8 +1746,14 @@ func TestClearSplitCookie(t *testing.T) { opts.Cookie.Secret = base64CookieSecret opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} - store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, nil, err) + err := validation.Validate(opts) + assert.NoError(t, err) + + store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) + if err != nil { + t.Fatal(err) + } + p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1623,7 +1771,8 @@ func TestClearSplitCookie(t *testing.T) { Value: "oauth2_1", }) - p.ClearSessionCookie(rw, req) + err = p.ClearSessionCookie(rw, req) + assert.NoError(t, err) header := rw.Header() assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries") @@ -1633,8 +1782,11 @@ func TestClearSingleCookie(t *testing.T) { opts := baseTestOptions() opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} - store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, nil, err) + store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) + if err != nil { + t.Fatal(err) + } + p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1648,7 +1800,8 @@ func TestClearSingleCookie(t *testing.T) { Value: "oauth2", }) - p.ClearSessionCookie(rw, req) + err = p.ClearSessionCookie(rw, req) + assert.NoError(t, err) header := rw.Header() assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries") @@ -1686,13 +1839,16 @@ func TestGetJwtSession(t *testing.T) { verifier := oidc.NewVerifier("https://issuer.example.com", keyset, &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) - test := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { opts.PassAuthorization = true opts.SetAuthorization = true opts.SetXAuthRequest = true opts.SkipJwtBearerTokens = true opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier)) }) + if err != nil { + t.Fatal(err) + } tp, _ := test.proxy.provider.(*TestProvider) tp.GroupValidator = func(s string) bool { return true @@ -1705,7 +1861,8 @@ func TestGetJwtSession(t *testing.T) { // Bearer expires := time.Unix(1912151821, 0) - session, _ := test.proxy.GetJwtSession(test.req) + session, err := test.proxy.GetJwtSession(test.req) + assert.NoError(t, err) assert.Equal(t, session.User, "1234567890") assert.Equal(t, session.Email, "john@example.com") assert.Equal(t, session.ExpiresOn, &expires) @@ -1739,22 +1896,26 @@ func TestFindJwtBearerToken(t *testing.T) { "Authorization": {fmt.Sprintf("Bearer %s", validToken)}, } - token, _ = p.findBearerToken(getReq) + token, err := p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - no password getReq.SetBasicAuth(token, "") - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - sentinel password getReq.SetBasicAuth(token, "x-oauth-basic") - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) // Basic - any username, password matching jwt pattern getReq.SetBasicAuth("any-username-you-could-wish-for", token) - token, _ = p.findBearerToken(getReq) + token, err = p.findBearerToken(getReq) + assert.NoError(t, err) assert.Equal(t, validToken, token) failures := []string{ @@ -1785,8 +1946,6 @@ func TestFindJwtBearerToken(t *testing.T) { _, err := p.findBearerToken(getReq) assert.Error(t, err) } - - fmt.Printf("%s", token) } func Test_prepareNoCache(t *testing.T) { @@ -1807,18 +1966,22 @@ func Test_prepareNoCache(t *testing.T) { func Test_noCacheHeaders(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("upstream")) + _, err := w.Write([]byte("upstream")) + if err != nil { + t.Error(err) + } })) t.Cleanup(upstream.Close) opts := baseTestOptions() opts.Upstreams = []string{upstream.URL} opts.SkipAuthRegex = []string{".*"} - _ = validation.Validate(opts) - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return true - }) + err := validation.Validate(opts) assert.NoError(t, err) + proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + if err != nil { + t.Fatal(err) + } t.Run("not exist in response from upstream", func(t *testing.T) { rec := httptest.NewRecorder() @@ -1887,8 +2050,8 @@ func Test_noCacheHeaders(t *testing.T) { func baseTestOptions() *options.Options { opts := options.NewOptions() opts.Cookie.Secret = rawCookieSecret - opts.ClientID = "cliend-id" - opts.ClientSecret = "client-secret" + opts.ClientID = clientID + opts.ClientSecret = clientSecret opts.EmailDomains = []string{"*"} return opts } diff --git a/pkg/apis/sessions/legacy_v5_tester.go b/pkg/apis/sessions/legacy_v5_tester.go new file mode 100644 index 00000000..44cf4e73 --- /dev/null +++ b/pkg/apis/sessions/legacy_v5_tester.go @@ -0,0 +1,87 @@ +package sessions + +import ( + "fmt" + "testing" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" + "github.com/stretchr/testify/assert" +) + +// LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code +type LegacyV5TestCase struct { + Input string + Error bool + Output *SessionState +} + +// CreateLegacyV5TestCases makes various V5 JSON sessions as test cases +// +// Used for `apis/sessions/session_state_test.go` & `sessions/redis/redis_store_test.go` +// +// TODO: Remove when this is deprecated (likely V7) +func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) { + const secret = "0123456789abcdefghijklmnopqrstuv" + + created := time.Now() + createdJSON, err := created.MarshalJSON() + assert.NoError(t, err) + createdString := string(createdJSON) + e := time.Now().Add(time.Duration(1) * time.Hour) + eJSON, err := e.MarshalJSON() + assert.NoError(t, err) + eString := string(eJSON) + + cfbCipher, err := encryption.NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + legacyCipher := encryption.NewBase64Cipher(cfbCipher) + + testCases := map[string]LegacyV5TestCase{ + "User & email unencrypted": { + Input: `{"Email":"user@domain.com","User":"just-user"}`, + Error: true, + }, + "Only email unencrypted": { + Input: `{"Email":"user@domain.com"}`, + Error: true, + }, + "Just user unencrypted": { + Input: `{"User":"just-user"}`, + Error: true, + }, + "User and Email unencrypted while rest is encrypted": { + Input: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), + Error: true, + }, + "Full session with cipher": { + Input: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), + Output: &SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + CreatedAt: &created, + ExpiresOn: &e, + RefreshToken: "refresh4321", + }, + }, + "Minimal session encrypted with cipher": { + Input: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`, + Output: &SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + }, + "Unencrypted User, Email and AccessToken": { + Input: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, + Error: true, + }, + "Unencrypted User, Email and IDToken": { + Input: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, + Error: true, + }, + } + + return testCases, cfbCipher, legacyCipher +} diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 44b91bd2..2015df8c 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -1,25 +1,30 @@ package sessions import ( + "bytes" "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "time" "unicode/utf8" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" + "github.com/pierrec/lz4" + "github.com/vmihailenco/msgpack/v4" ) // SessionState is used to store information about the currently authenticated user session type SessionState struct { - AccessToken string `json:",omitempty"` - IDToken string `json:",omitempty"` - CreatedAt *time.Time `json:",omitempty"` - ExpiresOn *time.Time `json:",omitempty"` - RefreshToken string `json:",omitempty"` - Email string `json:",omitempty"` - User string `json:",omitempty"` - PreferredUsername string `json:",omitempty"` + AccessToken string `json:",omitempty" msgpack:"at,omitempty"` + IDToken string `json:",omitempty" msgpack:"it,omitempty"` + CreatedAt *time.Time `json:",omitempty" msgpack:"ca,omitempty"` + ExpiresOn *time.Time `json:",omitempty" msgpack:"eo,omitempty"` + RefreshToken string `json:",omitempty" msgpack:"rt,omitempty"` + Email string `json:",omitempty" msgpack:"e,omitempty"` + User string `json:",omitempty" msgpack:"u,omitempty"` + PreferredUsername string `json:",omitempty" msgpack:"pu,omitempty"` } // IsExpired checks whether the session has expired @@ -59,78 +64,79 @@ func (s *SessionState) String() string { return o + "}" } -// EncodeSessionState returns string representation of the current session -func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) { - var ss SessionState - if c == nil { - // Store only Email and User when cipher is unavailable - ss.Email = s.Email - ss.User = s.User - ss.PreferredUsername = s.PreferredUsername - } else { - ss = *s - for _, s := range []*string{ - &ss.Email, - &ss.User, - &ss.PreferredUsername, - &ss.AccessToken, - &ss.IDToken, - &ss.RefreshToken, - } { - err := into(s, c.Encrypt) - if err != nil { - return "", err - } +// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session +func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) { + packed, err := msgpack.Marshal(s) + if err != nil { + return nil, fmt.Errorf("error marshalling session state to msgpack: %w", err) + } + + if !compress { + return c.Encrypt(packed) + } + + compressed, err := lz4Compress(packed) + if err != nil { + return nil, err + } + return c.Encrypt(compressed) +} + +// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State +func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*SessionState, error) { + decrypted, err := c.Decrypt(data) + if err != nil { + return nil, fmt.Errorf("error decrypting the session state: %w", err) + } + + packed := decrypted + if compressed { + packed, err = lz4Decompress(decrypted) + if err != nil { + return nil, err } } - b, err := json.Marshal(ss) - return string(b), err + var ss SessionState + err = msgpack.Unmarshal(packed, &ss) + if err != nil { + return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) + } + + err = ss.validate() + if err != nil { + return nil, err + } + + return &ss, nil } -// DecodeSessionState decodes the session cookie string into a SessionState -func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { +// LegacyV5DecodeSessionState decodes a legacy JSON session cookie string into a SessionState +func LegacyV5DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) { var ss SessionState err := json.Unmarshal([]byte(v), &ss) if err != nil { return nil, fmt.Errorf("error unmarshalling session: %w", err) } - if c == nil { - // Load only Email and User when cipher is unavailable - ss = SessionState{ - Email: ss.Email, - User: ss.User, - PreferredUsername: ss.PreferredUsername, - } - } else { - // Backward compatibility with using unencrypted Email or User - // Decryption errors will leave original string - err = into(&ss.Email, c.Decrypt) - if err == nil { - if !utf8.ValidString(ss.Email) { - return nil, errors.New("invalid value for decrypted email") - } - } - err = into(&ss.User, c.Decrypt) - if err == nil { - if !utf8.ValidString(ss.User) { - return nil, errors.New("invalid value for decrypted user") - } - } - - for _, s := range []*string{ - &ss.PreferredUsername, - &ss.AccessToken, - &ss.IDToken, - &ss.RefreshToken, - } { - err := into(s, c.Decrypt) - if err != nil { - return nil, err - } + for _, s := range []*string{ + &ss.User, + &ss.Email, + &ss.PreferredUsername, + &ss.AccessToken, + &ss.IDToken, + &ss.RefreshToken, + } { + err := into(s, c.Decrypt) + if err != nil { + return nil, err } } + err = ss.validate() + if err != nil { + return nil, err + } + return &ss, nil } @@ -150,3 +156,86 @@ func into(s *string, f codecFunc) error { *s = string(d) return nil } + +// lz4Compress compresses with LZ4 +// +// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds +// at the expense of greater compression compared to other compression +// algorithms. +func lz4Compress(payload []byte) ([]byte, error) { + buf := new(bytes.Buffer) + zw := lz4.NewWriter(nil) + zw.Header = lz4.Header{ + BlockMaxSize: 65536, + CompressionLevel: 0, + } + zw.Reset(buf) + + reader := bytes.NewReader(payload) + _, err := io.Copy(zw, reader) + if err != nil { + return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) + } + err = zw.Close() + if err != nil { + return nil, fmt.Errorf("error closing lz4 writer: %w", err) + } + + compressed, err := ioutil.ReadAll(buf) + if err != nil { + return nil, fmt.Errorf("error reading lz4 buffer: %w", err) + } + + return compressed, nil +} + +// lz4Decompress decompresses with LZ4 +func lz4Decompress(compressed []byte) ([]byte, error) { + reader := bytes.NewReader(compressed) + buf := new(bytes.Buffer) + zr := lz4.NewReader(nil) + zr.Reset(reader) + _, err := io.Copy(buf, zr) + if err != nil { + return nil, fmt.Errorf("error copying lz4 stream to buffer: %w", err) + } + + payload, err := ioutil.ReadAll(buf) + if err != nil { + return nil, fmt.Errorf("error reading lz4 buffer: %w", err) + } + + return payload, nil +} + +// validate ensures the decoded session is non-empty and contains valid data +// +// Non-empty check is needed due to ensure the non-authenticated AES-CFB +// decryption doesn't result in garbage data that collides with a valid +// MessagePack header bytes (which MessagePack will unmarshal to an empty +// default SessionState). <1% chance, but observed with random test data. +// +// UTF-8 check ensures the strings are valid and not raw bytes overloaded +// into Latin-1 encoding. The occurs when legacy unencrypted fields are +// decrypted with AES-CFB which results in random bytes. +func (s *SessionState) validate() error { + for _, field := range []string{ + s.User, + s.Email, + s.PreferredUsername, + s.AccessToken, + s.IDToken, + s.RefreshToken, + } { + if !utf8.ValidString(field) { + return errors.New("invalid non-UTF8 field in session") + } + } + + empty := new(SessionState) + if *s == *empty { + return errors.New("invalid empty session unmarshalled") + } + + return nil +} diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 3e9554c5..ac554c60 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -12,132 +12,11 @@ import ( "github.com/stretchr/testify/assert" ) -const secret = "0123456789abcdefghijklmnopqrstuv" -const altSecret = "0000000000abcdefghijklmnopqrstuv" - func timePtr(t time.Time) *time.Time { return &t } -func newTestCipher(secret []byte) (encryption.Cipher, error) { - return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret) -} - -func TestSessionStateSerialization(t *testing.T) { - c, err := newTestCipher([]byte(secret)) - assert.Equal(t, nil, err) - c2, err := newTestCipher([]byte(altSecret)) - assert.Equal(t, nil, err) - s := &SessionState{ - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(c) - assert.Equal(t, nil, err) - - ss, err := DecodeSessionState(encoded, c) - t.Logf("%#v", ss) - assert.Equal(t, nil, err) - assert.Equal(t, "", ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, s.AccessToken, ss.AccessToken) - assert.Equal(t, s.IDToken, ss.IDToken) - assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) - assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - assert.Equal(t, s.RefreshToken, ss.RefreshToken) - - // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) - t.Logf("%#v", ss) - assert.NotEqual(t, nil, err) -} - -func TestSessionStateSerializationWithUser(t *testing.T) { - c, err := newTestCipher([]byte(secret)) - assert.Equal(t, nil, err) - c2, err := newTestCipher([]byte(altSecret)) - assert.Equal(t, nil, err) - s := &SessionState{ - User: "just-user", - PreferredUsername: "ju", - Email: "user@domain.com", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(c) - assert.Equal(t, nil, err) - - ss, err := DecodeSessionState(encoded, c) - t.Logf("%#v", ss) - assert.Equal(t, nil, err) - assert.Equal(t, s.User, ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, s.AccessToken, ss.AccessToken) - assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) - assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - assert.Equal(t, s.RefreshToken, ss.RefreshToken) - - // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) - t.Logf("%#v", ss) - assert.NotEqual(t, nil, err) -} - -func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(nil) - assert.Equal(t, nil, err) - - // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) - assert.Equal(t, nil, err) - assert.Equal(t, "", ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, "", ss.AccessToken) - assert.Equal(t, "", ss.RefreshToken) -} - -func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &SessionState{ - User: "just-user", - Email: "user@domain.com", - PreferredUsername: "user", - AccessToken: "token1234", - CreatedAt: timePtr(time.Now()), - ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), - RefreshToken: "refresh4321", - } - encoded, err := s.EncodeSessionState(nil) - assert.Equal(t, nil, err) - - // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) - assert.Equal(t, nil, err) - assert.Equal(t, s.User, ss.User) - assert.Equal(t, s.Email, ss.Email) - assert.Equal(t, s.PreferredUsername, ss.PreferredUsername) - assert.Equal(t, "", ss.AccessToken) - assert.Equal(t, "", ss.RefreshToken) -} - -func TestExpired(t *testing.T) { +func TestIsExpired(t *testing.T) { s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} assert.Equal(t, true, s.IsExpired()) @@ -148,161 +27,7 @@ func TestExpired(t *testing.T) { assert.Equal(t, false, s.IsExpired()) } -type testCase struct { - SessionState - Encoded string - Cipher encryption.Cipher - Error bool -} - -// TestEncodeSessionState tests EncodeSessionState with the test vector -// -// Currently only tests without cipher here because we have no way to mock -// the random generator used in EncodeSessionState. -func TestEncodeSessionState(t *testing.T) { - c := time.Now() - e := time.Now().Add(time.Duration(1) * time.Hour) - - testCases := []testCase{ - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: &c, - ExpiresOn: &e, - RefreshToken: "refresh4321", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - } - - for i, tc := range testCases { - encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) - if tc.Error { - assert.Error(t, err) - assert.Empty(t, encoded) - continue - } - assert.NoError(t, err) - assert.JSONEq(t, tc.Encoded, encoded) - } -} - -// TestDecodeSessionState testssessions.DecodeSessionState with the test vector -func TestDecodeSessionState(t *testing.T) { - created := time.Now() - createdJSON, _ := created.MarshalJSON() - createdString := string(createdJSON) - e := time.Now().Add(time.Duration(1) * time.Hour) - eJSON, _ := e.MarshalJSON() - eString := string(eJSON) - - c, err := newTestCipher([]byte(secret)) - assert.NoError(t, err) - - testCases := []testCase{ - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"user@domain.com","User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "", - }, - Encoded: `{"Email":"user@domain.com"}`, - }, - { - SessionState: SessionState{ - User: "just-user", - }, - Encoded: `{"User":"just-user"}`, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - AccessToken: "token1234", - IDToken: "rawtoken1234", - CreatedAt: &created, - ExpiresOn: &e, - RefreshToken: "refresh4321", - }, - Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), - Cipher: c, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "just-user", - }, - Encoded: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`, - Cipher: c, - }, - { - Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, - Cipher: c, - Error: true, - }, - { - Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, - Cipher: c, - Error: true, - }, - { - SessionState: SessionState{ - Email: "user@domain.com", - User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user - }, - Error: true, - Cipher: c, - }, - } - - for i, tc := range testCases { - ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err) - if tc.Error { - assert.Error(t, err) - assert.Nil(t, ss) - continue - } - assert.NoError(t, err) - if assert.NotNil(t, ss) { - assert.Equal(t, tc.User, ss.User) - assert.Equal(t, tc.Email, ss.Email) - assert.Equal(t, tc.AccessToken, ss.AccessToken) - assert.Equal(t, tc.RefreshToken, ss.RefreshToken) - assert.Equal(t, tc.IDToken, ss.IDToken) - if tc.ExpiresOn != nil { - assert.NotEqual(t, nil, ss.ExpiresOn) - assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) - } - } - } -} - -func TestSessionStateAge(t *testing.T) { +func TestAge(t *testing.T) { ss := &SessionState{} // Created at unset so should be 0 @@ -313,7 +38,149 @@ func TestSessionStateAge(t *testing.T) { assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) } -func TestIntoEncryptAndIntoDecrypt(t *testing.T) { +// TestEncodeAndDecodeSessionState encodes & decodes various session states +// and confirms the operation is 1:1 +func TestEncodeAndDecodeSessionState(t *testing.T) { + created := time.Now() + expires := time.Now().Add(time.Duration(1) * time.Hour) + + // Tokens in the test table are purposefully redundant + // Otherwise compressing small payloads could result in a compressed value + // that is larger (compression dictionary + limited like strings to compress) + // which breaks the len(compressed) < len(uncompressed) assertion. + testCases := map[string]SessionState{ + "Full session": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "No ExpiresOn": { + Email: "username@example.com", + User: "username", + PreferredUsername: "preferred.username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "No PreferredUsername": { + Email: "username@example.com", + User: "username", + AccessToken: "AccessToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + ExpiresOn: &expires, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "Minimal session": { + User: "username", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + CreatedAt: &created, + RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + }, + "Bearer authorization header created session": { + Email: "username", + User: "username", + AccessToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7", + ExpiresOn: &expires, + }, + } + + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d byte secret", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.NoError(t, err) + + cfb, err := encryption.NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + gcm, err := encryption.NewGCMCipher([]byte(secret)) + assert.NoError(t, err) + + ciphers := map[string]encryption.Cipher{ + "CFB cipher": cfb, + "GCM cipher": gcm, + } + + for cipherName, c := range ciphers { + t.Run(cipherName, func(t *testing.T) { + for testName, ss := range testCases { + t.Run(testName, func(t *testing.T) { + encoded, err := ss.EncodeSessionState(c, false) + assert.NoError(t, err) + encodedCompressed, err := ss.EncodeSessionState(c, true) + assert.NoError(t, err) + // Make sure compressed version is smaller than if not compressed + assert.Greater(t, len(encoded), len(encodedCompressed)) + + decoded, err := DecodeSessionState(encoded, c, false) + assert.NoError(t, err) + decodedCompressed, err := DecodeSessionState(encodedCompressed, c, true) + assert.NoError(t, err) + + compareSessionStates(t, decoded, decodedCompressed) + compareSessionStates(t, decoded, &ss) + }) + } + }) + } + + t.Run("Mixed cipher types cause errors", func(t *testing.T) { + for testName, ss := range testCases { + t.Run(testName, func(t *testing.T) { + cfbEncoded, err := ss.EncodeSessionState(cfb, false) + assert.NoError(t, err) + _, err = DecodeSessionState(cfbEncoded, gcm, false) + assert.Error(t, err) + + gcmEncoded, err := ss.EncodeSessionState(gcm, false) + assert.NoError(t, err) + _, err = DecodeSessionState(gcmEncoded, cfb, false) + assert.Error(t, err) + }) + } + }) + }) + } +} + +// TestLegacyV5DecodeSessionState confirms V5 JSON sessions decode +// +// TODO: Remove when this is deprecated (likely V7) +func TestLegacyV5DecodeSessionState(t *testing.T) { + testCases, cipher, legacyCipher := CreateLegacyV5TestCases(t) + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + // Legacy sessions fail in DecodeSessionState which results in + // the fallback to LegacyV5DecodeSessionState + _, err := DecodeSessionState([]byte(tc.Input), cipher, false) + assert.Error(t, err) + _, err = DecodeSessionState([]byte(tc.Input), cipher, true) + assert.Error(t, err) + + ss, err := LegacyV5DecodeSessionState(tc.Input, legacyCipher) + if tc.Error { + assert.Error(t, err) + assert.Nil(t, ss) + return + } + assert.NoError(t, err) + compareSessionStates(t, tc.Output, ss) + }) + } +} + +// Test_into tests the into helper function used in LegacyV5DecodeSessionState +// +// TODO: Remove when this is deprecated (likely V7) +func Test_into(t *testing.T) { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" // Test all 3 valid AES sizes @@ -323,8 +190,9 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) { _, err := io.ReadFull(rand.Reader, secret) assert.Equal(t, nil, err) - c, err := newTestCipher(secret) + cfb, err := encryption.NewCFBCipher(secret) assert.NoError(t, err) + c := encryption.NewBase64Cipher(cfb) // Check no errors with empty or nil strings empty := "" @@ -353,3 +221,27 @@ func TestIntoEncryptAndIntoDecrypt(t *testing.T) { }) } } + +func compareSessionStates(t *testing.T, expected *SessionState, actual *SessionState) { + if expected.CreatedAt != nil { + assert.NotNil(t, actual.CreatedAt) + assert.Equal(t, true, expected.CreatedAt.Equal(*actual.CreatedAt)) + } else { + assert.Nil(t, actual.CreatedAt) + } + if expected.ExpiresOn != nil { + assert.NotNil(t, actual.ExpiresOn) + assert.Equal(t, true, expected.ExpiresOn.Equal(*actual.ExpiresOn)) + } else { + assert.Nil(t, actual.ExpiresOn) + } + + // Compare sessions without *time.Time fields + exp := *expected + exp.CreatedAt = nil + exp.ExpiresOn = nil + act := *actual + act.CreatedAt = nil + act.ExpiresOn = nil + assert.Equal(t, exp, act) +} diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index c1158b5c..37e08ba8 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -21,12 +21,8 @@ type base64Cipher struct { // NewBase64Cipher returns a new AES Cipher for encrypting cookie values // and wrapping them in Base64 -- Supports Legacy encryption scheme -func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) { - c, err := initCipher(secret) - if err != nil { - return nil, err - } - return &base64Cipher{Cipher: c}, nil +func NewBase64Cipher(c Cipher) Cipher { + return &base64Cipher{Cipher: c} } // Encrypt encrypts a value with the embedded Cipher & Base64 encodes it diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index b552e70c..16e12929 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -13,8 +13,9 @@ import ( func TestEncodeAndDecodeAccessToken(t *testing.T) { const secret = "0123456789abcdefghijklmnopqrstuv" const token = "my access token" - c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) - assert.Equal(t, nil, err) + cfb, err := NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + c := NewBase64Cipher(cfb) encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) @@ -32,8 +33,9 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { secret, err := base64.URLEncoding.DecodeString(secretBase64) assert.Equal(t, nil, err) - c, err := NewBase64Cipher(NewCFBCipher, []byte(secret)) - assert.Equal(t, nil, err) + cfb, err := NewCFBCipher([]byte(secret)) + assert.NoError(t, err) + c := NewBase64Cipher(cfb) encoded, err := c.Encrypt([]byte(token)) assert.Equal(t, nil, err) @@ -64,8 +66,7 @@ func TestEncryptAndDecrypt(t *testing.T) { cstd, err := initCipher(secret) assert.Equal(t, nil, err) - cb64, err := NewBase64Cipher(initCipher, secret) - assert.Equal(t, nil, err) + cb64 := NewBase64Cipher(cstd) ciphers := map[string]Cipher{ "Standard": cstd, diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 6fa6b5ea..69b55d11 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -60,7 +60,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { return nil, errors.New("cookie signature not valid") } - session, err := sessionFromCookie(string(val), s.CookieCipher) + session, err := sessionFromCookie(val, s.CookieCipher) if err != nil { return nil, err } @@ -85,17 +85,26 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { } // cookieForSession serializes a session state for storage in a cookie -func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) { - return s.EncodeSessionState(c) +func cookieForSession(s *sessions.SessionState, c encryption.Cipher) ([]byte, error) { + return s.EncodeSessionState(c, true) } // sessionFromCookie deserializes a session from a cookie value -func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) { - return sessions.DecodeSessionState(v, c) +func sessionFromCookie(v []byte, c encryption.Cipher) (s *sessions.SessionState, err error) { + ss, err := sessions.DecodeSessionState(v, c, true) + // If anything fails (Decrypt, LZ4, MessagePack), try legacy JSON decode + // LZ4 will likely fail for wrong header after AES-CFB spits out garbage + // data from trying to decrypt JSON it things is ciphertext + if err != nil { + // Legacy used Base64 + AES CFB + legacyCipher := encryption.NewBase64Cipher(c) + return sessions.LegacyV5DecodeSessionState(string(v), legacyCipher) + } + return ss, nil } // setSessionCookie adds the user's session cookie to the response -func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) { for _, c := range s.makeSessionCookie(req, val, created) { http.SetCookie(rw, c) } @@ -103,12 +112,12 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques // makeSessionCookie creates an http.Cookie containing the authenticated user's // authentication details -func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { - if value != "" { - value = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, []byte(value), now) +func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) []*http.Cookie { + strValue := string(value) + if strValue != "" { + strValue = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now) } - c := s.makeCookie(req, s.Cookie.Name, value, s.Cookie.Expire, now) - + c := s.makeCookie(req, s.Cookie.Name, strValue, s.Cookie.Expire, now) if len(c.String()) > maxCookieLength { return splitCookie(c) } @@ -129,7 +138,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, // NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { - cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + cipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) if err != nil { return nil, fmt.Errorf("error initialising cipher: %v", err) } diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index a89349e8..0e0d7cd9 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -40,7 +40,7 @@ type SessionStore struct { // NewRedisSessionStore initialises a new instance of the SessionStore from // the configuration given func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { - cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) if err != nil { return nil, fmt.Errorf("error initialising cipher: %v", err) } @@ -52,7 +52,7 @@ func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cook rs := &SessionStore{ Client: client, - CookieCipher: cipher, + CookieCipher: cfbCipher, Cookie: cookieOpts, } return rs, nil @@ -146,12 +146,8 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se // Old sessions that we are refreshing would have a request cookie // New sessions don't, so we ignore the error. storeValue will check requestCookie requestCookie, _ := req.Cookie(store.Cookie.Name) - value, err := s.EncodeSessionState(store.CookieCipher) - if err != nil { - return err - } ctx := req.Context() - ticketString, err := store.storeValue(ctx, value, store.Cookie.Expire, requestCookie) + ticketString, err := store.saveSession(ctx, s, store.Cookie.Expire, requestCookie) if err != nil { return err } @@ -180,40 +176,13 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro return nil, fmt.Errorf("cookie signature not valid") } ctx := req.Context() - session, err := store.loadSessionFromString(ctx, string(val)) + session, err := store.loadSessionFromTicket(ctx, string(val)) if err != nil { return nil, fmt.Errorf("error loading session: %s", err) } return session, nil } -// loadSessionFromString loads the session based on the ticket value -func (store *SessionStore) loadSessionFromString(ctx context.Context, value string) (*sessions.SessionState, error) { - ticket, err := decodeTicket(store.Cookie.Name, value) - if err != nil { - return nil, err - } - - resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(ticket.Secret) - if err != nil { - return nil, err - } - // Use secret as the IV too, because each entry has it's own key - stream := cipher.NewCFBDecrypter(block, ticket.Secret) - stream.XORKeyStream(resultBytes, resultBytes) - - session, err := sessions.DecodeSessionState(string(resultBytes), store.CookieCipher) - if err != nil { - return nil, err - } - return session, nil -} - // Clear clears any saved session information for a given ticket cookie // from redis, and then clears the session func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { @@ -253,6 +222,80 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro return nil } +// saveSession encodes a session with a GCM cipher & saves the data into Redis +func (store *SessionStore) saveSession(ctx context.Context, s *sessions.SessionState, expiration time.Duration, requestCookie *http.Cookie) (string, error) { + ticket, err := store.getTicket(requestCookie) + if err != nil { + return "", fmt.Errorf("error getting ticket: %v", err) + } + + c, err := encryption.NewGCMCipher(ticket.Secret) + if err != nil { + return "", fmt.Errorf("error initiating cipher block %s", err) + } + + // Use AES-GCM since it provides authenticated encryption + // AES-CFB used in cookies has the cookie signing SHA to get around the lack of + // authentication in AES-CFB + ciphertext, err := s.EncodeSessionState(c, false) + if err != nil { + return "", err + } + + handle := ticket.asHandle(store.Cookie.Name) + err = store.Client.Set(ctx, handle, ciphertext, expiration) + if err != nil { + return "", err + } + return ticket.encodeTicket(store.Cookie.Name), nil +} + +// loadSessionFromTicket loads the session based on the ticket value +func (store *SessionStore) loadSessionFromTicket(ctx context.Context, value string) (*sessions.SessionState, error) { + ticket, err := decodeTicket(store.Cookie.Name, value) + if err != nil { + return nil, err + } + + resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name)) + if err != nil { + return nil, err + } + + c, err := encryption.NewGCMCipher(ticket.Secret) + if err != nil { + return nil, err + } + + session, err := sessions.DecodeSessionState(resultBytes, c, false) + if err != nil { + // The GCM cipher will error due to a legacy JSON payload not passing + // the authentication check part of AES GCM encryption. + // In that case, we can attempt to fallback to try a legacy load + legacyCipher := encryption.NewBase64Cipher(store.CookieCipher) + return legacyV5DecodeSession(resultBytes, ticket, legacyCipher) + } + return session, nil +} + +// legacyV5DecodeSession loads the session based on the ticket value +// This fallback uses V5 style encryption of Base64 + AES CFB +func legacyV5DecodeSession(resultBytes []byte, ticket *TicketData, c encryption.Cipher) (*sessions.SessionState, error) { + block, err := aes.NewCipher(ticket.Secret) + if err != nil { + return nil, err + } + // Use secret as the IV too, because each entry has it's own key + stream := cipher.NewCFBDecrypter(block, ticket.Secret) + stream.XORKeyStream(resultBytes, resultBytes) + + session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), c) + if err != nil { + return nil, err + } + return session, nil +} + // makeCookie makes a cookie, signing the value if present func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { if value != "" { @@ -268,30 +311,6 @@ func (store *SessionStore) makeCookie(req *http.Request, value string, expires t ) } -func (store *SessionStore) storeValue(ctx context.Context, value string, expiration time.Duration, requestCookie *http.Cookie) (string, error) { - ticket, err := store.getTicket(requestCookie) - if err != nil { - return "", fmt.Errorf("error getting ticket: %v", err) - } - - ciphertext := make([]byte, len(value)) - block, err := aes.NewCipher(ticket.Secret) - if err != nil { - return "", fmt.Errorf("error initiating cipher block %s", err) - } - - // Use secret as the Initialization Vector too, because each entry has it's own key - stream := cipher.NewCFBEncrypter(block, ticket.Secret) - stream.XORKeyStream(ciphertext, []byte(value)) - - handle := ticket.asHandle(store.Cookie.Name) - err = store.Client.Set(ctx, handle, ciphertext, expiration) - if err != nil { - return "", err - } - return ticket.encodeTicket(store.Cookie.Name), nil -} - // getTicket retrieves an existing ticket from the cookie if present, // or creates a new ticket func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, error) { diff --git a/pkg/sessions/redis/session_store_test.go b/pkg/sessions/redis/redis_store_test.go similarity index 65% rename from pkg/sessions/redis/session_store_test.go rename to pkg/sessions/redis/redis_store_test.go index 78dd111f..12965705 100644 --- a/pkg/sessions/redis/session_store_test.go +++ b/pkg/sessions/redis/redis_store_test.go @@ -1,6 +1,11 @@ package redis import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" "log" "os" "testing" @@ -17,6 +22,65 @@ import ( . "github.com/onsi/gomega" ) +// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession +// when a V5 encoded session is in Redis +// +// TODO: Remove when this is deprecated (likely V7) +func Test_legacyV5DecodeSession(t *testing.T) { + testCases, _, legacyCipher := sessionsapi.CreateLegacyV5TestCases(t) + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + secret := make([]byte, aes.BlockSize) + _, err := io.ReadFull(rand.Reader, secret) + g.Expect(err).ToNot(HaveOccurred()) + ticket := &TicketData{ + TicketID: "", + Secret: secret, + } + + encrypted, err := legacyStoreValue(tc.Input, ticket.Secret) + g.Expect(err).ToNot(HaveOccurred()) + + ss, err := legacyV5DecodeSession(encrypted, ticket, legacyCipher) + if tc.Error { + g.Expect(err).To(HaveOccurred()) + g.Expect(ss).To(BeNil()) + return + } + g.Expect(err).ToNot(HaveOccurred()) + + // Compare sessions without *time.Time fields + exp := *tc.Output + exp.CreatedAt = nil + exp.ExpiresOn = nil + act := *ss + act.CreatedAt = nil + act.ExpiresOn = nil + g.Expect(exp).To(Equal(act)) + }) + } +} + +// legacyStoreValue implements the legacy V5 Redis store AES-CFB value encryption +// +// TODO: Remove when this is deprecated (likely V7) +func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) { + ciphertext := make([]byte, len(value)) + block, err := aes.NewCipher(ticketSecret) + if err != nil { + return nil, fmt.Errorf("error initiating cipher block: %v", err) + } + + // Use secret as the Initialization Vector too, because each entry has it's own key + stream := cipher.NewCFBEncrypter(block, ticketSecret) + stream.XORKeyStream(ciphertext, []byte(value)) + + return ciphertext, nil +} + func TestSessionStore(t *testing.T) { logger.SetOutput(GinkgoWriter)