From 02069521822c3c6d952cd231bd3951f3fa743891 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Sun, 9 Mar 2025 21:16:42 +0100 Subject: [PATCH] chore: replace gin with standard lib net/http --- cmd/wg-portal/main.go | 44 ++- frontend/src/router/index.js | 6 + go.mod | 44 +-- go.sum | 130 ++------ .../api/core/middleware/cors/middleware.go | 214 +++++++++++++ .../core/middleware/cors/middleware_test.go | 101 ++++++ .../app/api/core/middleware/cors/options.go | 133 ++++++++ .../api/core/middleware/cors/options_test.go | 96 ++++++ .../app/api/core/middleware/cors/wildcard.go | 33 ++ .../api/core/middleware/cors/wildcard_test.go | 94 ++++++ .../api/core/middleware/csrf/middleware.go | 137 ++++++++ .../core/middleware/csrf/middleware_test.go | 251 +++++++++++++++ .../app/api/core/middleware/csrf/options.go | 88 ++++++ .../api/core/middleware/csrf/options_test.go | 75 +++++ .../app/api/core/middleware/csrf/token.go | 90 ++++++ .../api/core/middleware/csrf/token_test.go | 81 +++++ .../api/core/middleware/logging/middleware.go | 199 ++++++++++++ .../middleware/logging/middleware_test.go | 148 +++++++++ .../api/core/middleware/logging/options.go | 80 +++++ .../core/middleware/logging/options_test.go | 88 ++++++ .../app/api/core/middleware/logging/writer.go | 45 +++ .../core/middleware/logging/writer_test.go | 85 +++++ .../core/middleware/recovery/middleware.go | 133 ++++++++ .../middleware/recovery/middleware_test.go | 149 +++++++++ .../api/core/middleware/recovery/options.go | 129 ++++++++ .../core/middleware/recovery/options_test.go | 100 ++++++ .../api/core/middleware/tracing/middleware.go | 69 ++++ .../middleware/tracing/middleware_test.go | 118 +++++++ .../api/core/middleware/tracing/options.go | 85 +++++ .../core/middleware/tracing/options_test.go | 75 +++++ internal/app/api/core/request/basic.go | 259 +++++++++++++++ internal/app/api/core/request/basic_test.go | 221 +++++++++++++ internal/app/api/core/respond/basic.go | 100 ++++++ internal/app/api/core/respond/basic_test.go | 273 ++++++++++++++++ internal/app/api/core/respond/template.go | 46 +++ .../app/api/core/respond/template_test.go | 67 ++++ internal/app/api/core/server.go | 140 +++------ internal/app/api/v0/handlers/base.go | 107 ++++--- .../v0/handlers/endpoint_authentication.go | 188 ++++++----- .../app/api/v0/handlers/endpoint_config.go | 77 +++-- .../api/v0/handlers/endpoint_interfaces.go | 248 ++++++++------- .../app/api/v0/handlers/endpoint_peers.go | 297 +++++++++--------- .../app/api/v0/handlers/endpoint_testing.go | 37 ++- .../app/api/v0/handlers/endpoint_users.go | 237 +++++++------- .../v0/handlers/middleware_authentication.go | 111 ------- internal/app/api/v0/handlers/session.go | 92 ------ .../app/api/v0/handlers/web_authentication.go | 126 ++++++++ internal/app/api/v0/handlers/web_session.go | 88 ++++++ internal/app/api/v1/handlers/base.go | 29 +- .../app/api/v1/handlers/endpoint_interface.go | 137 ++++---- .../app/api/v1/handlers/endpoint_metrics.go | 84 ++--- internal/app/api/v1/handlers/endpoint_peer.go | 156 ++++----- .../api/v1/handlers/endpoint_provisioning.go | 108 ++++--- internal/app/api/v1/handlers/endpoint_user.go | 131 ++++---- .../v1/handlers/middleware_authentication.go | 93 ------ .../app/api/v1/handlers/web_authentication.go | 101 ++++++ internal/config/web.go | 2 + internal/domain/context.go | 17 - 58 files changed, 5302 insertions(+), 1390 deletions(-) create mode 100644 internal/app/api/core/middleware/cors/middleware.go create mode 100644 internal/app/api/core/middleware/cors/middleware_test.go create mode 100644 internal/app/api/core/middleware/cors/options.go create mode 100644 internal/app/api/core/middleware/cors/options_test.go create mode 100644 internal/app/api/core/middleware/cors/wildcard.go create mode 100644 internal/app/api/core/middleware/cors/wildcard_test.go create mode 100644 internal/app/api/core/middleware/csrf/middleware.go create mode 100644 internal/app/api/core/middleware/csrf/middleware_test.go create mode 100644 internal/app/api/core/middleware/csrf/options.go create mode 100644 internal/app/api/core/middleware/csrf/options_test.go create mode 100644 internal/app/api/core/middleware/csrf/token.go create mode 100644 internal/app/api/core/middleware/csrf/token_test.go create mode 100644 internal/app/api/core/middleware/logging/middleware.go create mode 100644 internal/app/api/core/middleware/logging/middleware_test.go create mode 100644 internal/app/api/core/middleware/logging/options.go create mode 100644 internal/app/api/core/middleware/logging/options_test.go create mode 100644 internal/app/api/core/middleware/logging/writer.go create mode 100644 internal/app/api/core/middleware/logging/writer_test.go create mode 100644 internal/app/api/core/middleware/recovery/middleware.go create mode 100644 internal/app/api/core/middleware/recovery/middleware_test.go create mode 100644 internal/app/api/core/middleware/recovery/options.go create mode 100644 internal/app/api/core/middleware/recovery/options_test.go create mode 100644 internal/app/api/core/middleware/tracing/middleware.go create mode 100644 internal/app/api/core/middleware/tracing/middleware_test.go create mode 100644 internal/app/api/core/middleware/tracing/options.go create mode 100644 internal/app/api/core/middleware/tracing/options_test.go create mode 100644 internal/app/api/core/request/basic.go create mode 100644 internal/app/api/core/request/basic_test.go create mode 100644 internal/app/api/core/respond/basic.go create mode 100644 internal/app/api/core/respond/basic_test.go create mode 100644 internal/app/api/core/respond/template.go create mode 100644 internal/app/api/core/respond/template_test.go delete mode 100644 internal/app/api/v0/handlers/middleware_authentication.go delete mode 100644 internal/app/api/v0/handlers/session.go create mode 100644 internal/app/api/v0/handlers/web_authentication.go create mode 100644 internal/app/api/v0/handlers/web_session.go delete mode 100644 internal/app/api/v1/handlers/middleware_authentication.go create mode 100644 internal/app/api/v1/handlers/web_authentication.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index d8f8d33..edd4c48 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -7,6 +7,7 @@ import ( "syscall" "time" + "github.com/go-playground/validator/v10" evbus "github.com/vardius/message-bus" "github.com/h44z/wg-portal/internal" @@ -101,21 +102,48 @@ func main() { err = backend.Startup(ctx) internal.AssertNoError(err) - apiFrontend := handlersV0.NewRestApi(cfg, backend) + validatorManager := validator.New() + // region API v0 (SPA frontend) + + apiV0Session := handlersV0.NewSessionWrapper(cfg) + apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session) + + apiV0EndpointAuth := handlersV0.NewAuthEndpoint(backend, apiV0Auth, apiV0Session, validatorManager) + apiV0EndpointUsers := handlersV0.NewUserEndpoint(backend, apiV0Auth, validatorManager) + apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(backend, apiV0Auth, validatorManager) + apiV0EndpointPeers := handlersV0.NewPeerEndpoint(backend, apiV0Auth, validatorManager) + apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) + apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) + + apiFrontend := handlersV0.NewRestApi(apiV0Session, + apiV0EndpointAuth, + apiV0EndpointUsers, + apiV0EndpointInterfaces, + apiV0EndpointPeers, + apiV0EndpointConfig, + apiV0EndpointTest, + ) + + // endregion API v0 (SPA frontend) + + // region API v1 (User REST API) + + apiV1Auth := handlersV1.NewAuthenticationHandler(userManager) apiV1BackendUsers := backendV1.NewUserService(cfg, userManager) apiV1BackendPeers := backendV1.NewPeerService(cfg, wireGuardManager, userManager) apiV1BackendInterfaces := backendV1.NewInterfaceService(cfg, wireGuardManager) apiV1BackendProvisioning := backendV1.NewProvisioningService(cfg, userManager, wireGuardManager, cfgFileManager) apiV1BackendMetrics := backendV1.NewMetricsService(cfg, database, userManager, wireGuardManager) - apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1BackendUsers) - apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1BackendPeers) - apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1BackendInterfaces) - apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1BackendProvisioning) - apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1BackendMetrics) + + apiV1EndpointUsers := handlersV1.NewUserEndpoint(apiV1Auth, validatorManager, apiV1BackendUsers) + apiV1EndpointPeers := handlersV1.NewPeerEndpoint(apiV1Auth, validatorManager, apiV1BackendPeers) + apiV1EndpointInterfaces := handlersV1.NewInterfaceEndpoint(apiV1Auth, validatorManager, apiV1BackendInterfaces) + apiV1EndpointProvisioning := handlersV1.NewProvisioningEndpoint(apiV1Auth, validatorManager, + apiV1BackendProvisioning) + apiV1EndpointMetrics := handlersV1.NewMetricsEndpoint(apiV1Auth, validatorManager, apiV1BackendMetrics) apiV1 := handlersV1.NewRestApi( - userManager, apiV1EndpointUsers, apiV1EndpointPeers, apiV1EndpointInterfaces, @@ -123,6 +151,8 @@ func main() { apiV1EndpointMetrics, ) + // endregion API v1 (User REST API) + webSrv, err := core.NewServer(cfg, apiFrontend, apiV1) internal.AssertNoError(err) diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js index 25adf58..5bf47ad 100644 --- a/frontend/src/router/index.js +++ b/frontend/src/router/index.js @@ -4,6 +4,7 @@ import LoginView from '../views/LoginView.vue' import InterfaceView from '../views/InterfaceView.vue' import {authStore} from '@/stores/auth' +import {securityStore} from '@/stores/security' import {notify} from "@kyvg/vue3-notification"; const router = createRouter({ @@ -63,6 +64,7 @@ const router = createRouter({ router.beforeEach(async (to) => { const auth = authStore() + const sec = securityStore() // check if the request was a successful oauth login if ('wgLoginState' in to.query && !auth.IsAuthenticated) { @@ -112,6 +114,10 @@ router.beforeEach(async (to) => { auth.SetReturnUrl(to.fullPath) // store original destination before starting the auth process return '/login' } + + if (publicPages.includes(to.path)) { + await sec.LoadSecurityProperties() // make sure we have a valid csrf token + } }) export default router diff --git a/go.mod b/go.mod index 4e331db..e0cae55 100644 --- a/go.mod +++ b/go.mod @@ -4,26 +4,25 @@ go 1.24.0 require ( github.com/a8m/envsubst v1.4.3 + github.com/alexedwards/scs/v2 v2.8.0 github.com/coreos/go-oidc/v3 v3.12.0 - github.com/gin-contrib/cors v1.7.3 - github.com/gin-contrib/sessions v1.0.2 - github.com/gin-gonic/gin v1.10.0 github.com/glebarez/sqlite v1.11.0 github.com/go-ldap/ldap/v3 v3.4.10 + github.com/go-pkgz/routegroup v1.3.1 + github.com/go-playground/validator/v10 v10.25.0 github.com/google/uuid v1.6.0 github.com/prometheus-community/pro-bing v0.6.1 - github.com/prometheus/client_golang v1.21.0 + github.com/prometheus/client_golang v1.21.1 github.com/stretchr/testify v1.10.0 github.com/swaggo/swag v1.16.4 - github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca github.com/vardius/message-bus v1.1.5 github.com/vishvananda/netlink v1.3.0 github.com/xhit/go-simple-mail/v2 v2.16.0 github.com/yeqown/go-qrcode/v2 v2.2.5 github.com/yeqown/go-qrcode/writer/compressed v1.0.1 - golang.org/x/crypto v0.35.0 - golang.org/x/oauth2 v0.27.0 - golang.org/x/sys v0.30.0 + golang.org/x/crypto v0.36.0 + golang.org/x/oauth2 v0.28.0 + golang.org/x/sys v0.31.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -37,15 +36,10 @@ require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bytedance/sonic v1.12.9 // indirect - github.com/bytedance/sonic/loader v0.2.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cloudwego/base64x v0.1.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dchest/uniuri v1.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect - github.com/gin-contrib/sse v1.0.0 // indirect github.com/glebarez/go-sqlite v1.22.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect github.com/go-jose/go-jose/v4 v4.0.5 // indirect @@ -55,16 +49,11 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.25.0 // indirect github.com/go-sql-driver/mysql v1.9.0 // indirect github.com/go-test/deep v1.1.1 // indirect - github.com/goccy/go-json v0.10.5 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/gorilla/context v1.1.2 // indirect - github.com/gorilla/securecookie v1.1.2 // indirect - github.com/gorilla/sessions v1.4.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.2 // indirect @@ -73,9 +62,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/josharian/native v1.1.0 // indirect - github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -83,28 +70,21 @@ require ( github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/microsoft/go-mssqldb v1.8.0 // indirect - github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect - github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 // indirect - github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.12 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/yeqown/reedsolomon v1.0.0 // indirect - golang.org/x/arch v0.14.0 // indirect - golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect - golang.org/x/net v0.35.0 // indirect - golang.org/x/sync v0.11.0 // indirect - golang.org/x/text v0.22.0 // indirect - golang.org/x/tools v0.30.0 // indirect + golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect + golang.org/x/net v0.37.0 // indirect + golang.org/x/sync v0.12.0 // indirect + golang.org/x/text v0.23.0 // indirect + golang.org/x/tools v0.31.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect google.golang.org/protobuf v1.36.5 // indirect modernc.org/libc v1.61.13 // indirect diff --git a/go.sum b/go.sum index 575af42..409eda5 100644 --- a/go.sum +++ b/go.sum @@ -29,52 +29,27 @@ github.com/a8m/envsubst v1.4.3 h1:kDF7paGK8QACWYaQo6KtyYBozY2jhQrTuNNuUxQkhJY= github.com/a8m/envsubst v1.4.3/go.mod h1:4jjHWQlZoaXPoLQUb7H2qT4iLkZDdmEQiOUogdUmqVU= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alexedwards/scs/v2 v2.8.0 h1:h31yUYoycPuL0zt14c0gd+oqxfRwIj6SOjHdKRZxhEw= +github.com/alexedwards/scs/v2 v2.8.0/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw= -github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= -github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20181103040241-659414f458e1/go.mod h1:dkChI7Tbtx7H1Tj7TqGSZMOeGpMP5gLHtjroHd4agiI= -github.com/bytedance/sonic v1.12.9 h1:Od1BvK55NnewtGaJsTDeAOSnLVO2BTSLOe0+ooKokmQ= -github.com/bytedance/sonic v1.12.9/go.mod h1:uVvFidNmlt9+wa31S1urfwwthTWteBgG0hWuoKAXTx8= -github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/bytedance/sonic/loader v0.2.3 h1:yctD0Q3v2NOGfSWPLPvG2ggA2kV6TS6s4wioyEqssH0= -github.com/bytedance/sonic/loader v0.2.3/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= -github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= -github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo= github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4= -github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= -github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= -github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= -github.com/gin-contrib/cors v1.7.3 h1:hV+a5xp8hwJoTw7OY+a70FsL8JkVVFTXw9EcfrYUdns= -github.com/gin-contrib/cors v1.7.3/go.mod h1:M3bcKZhxzsvI+rlRSkkxHyljJt1ESd93COUvemZ79j4= -github.com/gin-contrib/sessions v0.0.0-20190101140330-dc5246754963/go.mod h1:4lkInX8nHSR62NSmhXM3xtPeMSyfiR58NaEz+om1lHM= -github.com/gin-contrib/sessions v1.0.2 h1:UaIjUvTH1cMeOdj3in6dl+Xb6It8RiKRF9Z1anbUyCA= -github.com/gin-contrib/sessions v1.0.2/go.mod h1:KxKxWqWP5LJVDCInulOl4WbLzK2KSPlLesfZ66wRvMs= -github.com/gin-contrib/sse v0.0.0-20170109093832-22d885f9ecc7/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= -github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= -github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= -github.com/gin-gonic/gin v1.3.0/go.mod h1:7cKuhb5qV2ggCFctp2fJQ+ErvciLZrIeoOSOm6mUr7Y= -github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= -github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= -github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/go-asn1-ber/asn1-ber v1.5.7 h1:DTX+lbVTWaTw1hQ+PbZPlnDZPEIs0SS/GCZAl535dDk= github.com/go-asn1-ber/asn1-ber v1.5.7/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= @@ -89,6 +64,8 @@ github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9Z github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-pkgz/routegroup v1.3.1 h1:XAVWskX8Iup6HoQD9zv+gJx4DOJC2DSkKBHCMeeW8/s= +github.com/go-pkgz/routegroup v1.3.1/go.mod h1:kDDPDRLRiRY1vnENrZJw1jQAzQX7fvsbsHGRQFNQfKc= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -102,8 +79,6 @@ github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1 github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= @@ -112,32 +87,18 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= -github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= -github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= -github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= -github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= -github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -169,21 +130,10 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kidstuff/mongostore v0.0.0-20181113001930-e650cd85ee4b/go.mod h1:g2nVr8KZVXJSS97Jo8pJ0jgq29P6H7dG0oplUA86MQw= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= -github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -192,7 +142,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= @@ -201,26 +150,17 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= -github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc= github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw= github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= -github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= @@ -228,17 +168,14 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus-community/pro-bing v0.6.1 h1:EQukUOma9YFZRPe4DGSscxUf9LH07rpqwisNWjSZrgU= github.com/prometheus-community/pro-bing v0.6.1/go.mod h1:jNCOI3D7pmTCeaoF41cNS6uaxeFY/Gmc3ffwbuJVzAQ= -github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA= -github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quasoft/memstore v0.0.0-20180925164028-84a050167438/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -246,8 +183,6 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -262,13 +197,6 @@ github.com/swaggo/swag v1.16.4/go.mod h1:VBsHJRsDvfYvqoiMKnsdwhNV9LEMHgEDZcyVYX0 github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns= github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817 h1:q0hKh5a5FRkhuTb5JNfgjzpzvYLHjH0QOgPZPYnRWGA= github.com/toorop/go-dkim v0.0.0-20250226130143-9025cce95817/go.mod h1:BzWtXXrXzZUvMacR0oF/fbDDgUPO8L36tDMmRAf14ns= -github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= -github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= -github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca h1:lpvAjPK+PcxnbcB8H7axIb4fMNwjX9bE4DzwPjGg8aE= -github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca/go.mod h1:XXKxNbpoLihvvT7orUZbs/iZayg1n4ip7iJakJPAwA8= github.com/vardius/message-bus v1.1.5 h1:YSAC2WB4HRlwc4neFPTmT88kzzoiQ+9WRRbej/E/LZc= github.com/vardius/message-bus v1.1.5/go.mod h1:6xladCV2lMkUAE4bzzS85qKOiB5miV7aBVRafiTJGqw= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= @@ -285,8 +213,6 @@ github.com/yeqown/go-qrcode/writer/compressed v1.0.1/go.mod h1:BJScsGUIKM+eg0CCL github.com/yeqown/reedsolomon v1.0.0 h1:x1h/Ej/uJnNu8jaX7GLHBWmZKCAWjEJTetkqaabr4B0= github.com/yeqown/reedsolomon v1.0.0/go.mod h1:P76zpcn2TCuL0ul1Fso373qHRc69LKwAw/Iy6g1WiiM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4= -golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= @@ -300,18 +226,17 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= -golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= -golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= -golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= -golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -329,11 +254,10 @@ golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -341,9 +265,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20181228144115-9a3f9b0469bb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -364,8 +287,8 @@ golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -393,16 +316,16 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= -golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= +golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= +golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= @@ -411,12 +334,8 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= -gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= -gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -458,6 +377,5 @@ modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/internal/app/api/core/middleware/cors/middleware.go b/internal/app/api/core/middleware/cors/middleware.go new file mode 100644 index 0000000..3aec5b2 --- /dev/null +++ b/internal/app/api/core/middleware/cors/middleware.go @@ -0,0 +1,214 @@ +package cors + +import ( + "net/http" + "slices" + "strconv" + "strings" +) + +// Middleware is a type that creates a new CORS middleware. The CORS middleware +// adds Cross-Origin Resource Sharing headers to the response. This middleware should +// be used to allow cross-origin requests to your server. +type Middleware struct { + o options + + varyHeaders string // precomputed Vary header + allOrigins bool // all origins are allowed +} + +// New returns a new CORS middleware with the provided options. +func New(opts ...Option) *Middleware { + o := newOptions(opts...) + + m := &Middleware{ + o: o, + } + + // set vary headers + if m.o.allowPrivateNetworks { + m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network" + } else { + m.varyHeaders = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers" + } + + if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" { + m.allOrigins = true + } + + return m +} + +// Handler returns the CORS middleware handler. +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle preflight requests and stop the chain as some other + // middleware may not handle OPTIONS requests correctly. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + m.handlePreflight(w, r) + w.WriteHeader(http.StatusNoContent) // always return 204 No Content + return + } + + // handle normal CORS requests + m.handleNormal(w, r) + next.ServeHTTP(w, r) // execute the next handler + }) +} + +// region internal-helpers + +// handlePreflight handles preflight requests. If the request was successful, this function will +// write the CORS headers and return. If the request was not successful, this function will +// not add any CORS headers and return - thus the CORS request is considered invalid. +func (m *Middleware) handlePreflight(w http.ResponseWriter, r *http.Request) { + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + w.Header().Add("Vary", m.varyHeaders) + + // check origin + origin := r.Header.Get("Origin") + if origin == "" { + return // not a valid CORS request + } + + if !m.originAllowed(origin) { + return + } + + // check method + reqMethod := r.Header.Get("Access-Control-Request-Method") + if !m.methodAllowed(reqMethod) { + return + } + + // check headers + reqHeaders := r.Header.Get("Access-Control-Request-Headers") + if !m.headersAllowed(reqHeaders) { + return + } + + // set CORS headers for the successful preflight request + if m.allOrigins { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin + } + w.Header().Set("Access-Control-Allow-Methods", reqMethod) + if reqHeaders != "" { + // Spec says: Since the list of headers can be unbounded, simply returning supported headers + // from Access-Control-Request-Headers can be enough + w.Header().Set("Access-Control-Allow-Headers", reqHeaders) + } + if m.o.allowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + if m.o.allowPrivateNetworks && r.Header.Get("Access-Control-Request-Private-Network") == "true" { + w.Header().Set("Access-Control-Allow-Private-Network", "true") + } + if m.o.maxAge > 0 { + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(m.o.maxAge)) + } +} + +// handleNormal handles normal CORS requests. If the request was successful, this function will +// write the CORS headers to the response. If the request was not successful, this function will +// not add any CORS headers to the response. In this case, the CORS request is considered invalid. +func (m *Middleware) handleNormal(w http.ResponseWriter, r *http.Request) { + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + w.Header().Add("Vary", "Origin") + + // check origin + origin := r.Header.Get("Origin") + if origin == "" { + return // not a valid CORS request + } + + if !m.originAllowed(origin) { + return + } + + // check method + if !m.methodAllowed(r.Method) { + return + } + + // set CORS headers for the successful CORS request + if m.allOrigins { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + w.Header().Set("Access-Control-Allow-Origin", origin) // return original origin + } + if len(m.o.exposedHeaders) > 0 { + w.Header().Set("Access-Control-Expose-Headers", strings.Join(m.o.exposedHeaders, ", ")) + } + if m.o.allowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } +} + +func (m *Middleware) originAllowed(origin string) bool { + if len(m.o.allowedOrigins) == 1 && m.o.allowedOrigins[0] == "*" { + return true // everything is allowed + } + + // check simple origins + if slices.Contains(m.o.allowedOrigins, origin) { + return true + } + + // check wildcard origins + for _, allowedOrigin := range m.o.allowedOriginPatterns { + if allowedOrigin.match(origin) { + return true + } + } + + return false +} + +func (m *Middleware) methodAllowed(method string) bool { + if method == http.MethodOptions { + return true // preflight request is always allowed + } + + if len(m.o.allowedMethods) == 1 && m.o.allowedMethods[0] == "*" { + return true // everything is allowed + } + + if slices.Contains(m.o.allowedMethods, method) { + return true + } + + return false +} + +func (m *Middleware) headersAllowed(headers string) bool { + if headers == "" { + return true // no headers are requested + } + + if len(m.o.allowedHeaders) == 0 { + return false // no headers are allowed + } + + if _, ok := m.o.allowedHeaders["*"]; ok { + return true // everything is allowed + } + + // split headers by comma (according to definition, the headers are sorted and in lowercase) + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Request-Headers + for header := range strings.SplitSeq(headers, ",") { + if _, ok := m.o.allowedHeaders[strings.TrimSpace(header)]; !ok { + return false + } + } + + return true +} + +// endregion internal-helpers diff --git a/internal/app/api/core/middleware/cors/middleware_test.go b/internal/app/api/core/middleware/cors/middleware_test.go new file mode 100644 index 0000000..e832645 --- /dev/null +++ b/internal/app/api/core/middleware/cors/middleware_test.go @@ -0,0 +1,101 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestMiddleware_New(t *testing.T) { + m := New(WithAllowedOrigins("*")) + + if len(m.varyHeaders) == 0 { + t.Errorf("expected vary headers to be populated, got %v", m.varyHeaders) + } + if !m.allOrigins { + t.Errorf("expected allOrigins to be true, got %v", m.allOrigins) + } +} + +func TestMiddleware_Handler_normal(t *testing.T) { + m := New(WithAllowedOrigins("http://example.com")) + + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Origin", "http://example.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Result().StatusCode != http.StatusOK { + t.Errorf("expected status code 200, got %d", w.Result().StatusCode) + } + + if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" { + t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s", + w.Header().Get("Access-Control-Allow-Origin")) + } +} + +func TestMiddleware_Handler_preflight(t *testing.T) { + m := New(WithAllowedOrigins("http://example.com")) + + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodOptions, "http://example.com", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set("Access-Control-Request-Method", http.MethodGet) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Result().StatusCode != http.StatusNoContent { + t.Errorf("expected status code 204, got %d", w.Result().StatusCode) + } + + if w.Header().Get("Access-Control-Allow-Origin") != "http://example.com" { + t.Errorf("expected Access-Control-Allow-Origin to be 'http://example.com', got %s", + w.Header().Get("Access-Control-Allow-Origin")) + } +} + +func TestMiddleware_originAllowed(t *testing.T) { + m := New(WithAllowedOrigins("http://example.com")) + + if !m.originAllowed("http://example.com") { + t.Errorf("expected origin 'http://example.com' to be allowed") + } + + if m.originAllowed("http://notallowed.com") { + t.Errorf("expected origin 'http://notallowed.com' to be not allowed") + } +} + +func TestMiddleware_methodAllowed(t *testing.T) { + m := New(WithAllowedMethods(http.MethodGet, http.MethodPost)) + + if !m.methodAllowed(http.MethodGet) { + t.Errorf("expected method 'GET' to be allowed") + } + + if m.methodAllowed(http.MethodDelete) { + t.Errorf("expected method 'DELETE' to be not allowed") + } +} + +func TestMiddleware_headersAllowed(t *testing.T) { + m := New(WithAllowedHeaders("Content-Type", "Authorization")) + + if !m.headersAllowed("content-type, authorization") { + t.Errorf("expected headers 'Content-Type, Authorization' to be allowed") + } + + if m.headersAllowed("x-custom-header") { + t.Errorf("expected header 'X-Custom-Header' to be not allowed") + } +} diff --git a/internal/app/api/core/middleware/cors/options.go b/internal/app/api/core/middleware/cors/options.go new file mode 100644 index 0000000..5675ac7 --- /dev/null +++ b/internal/app/api/core/middleware/cors/options.go @@ -0,0 +1,133 @@ +package cors + +import ( + "net/http" + "strings" +) + +type void struct{} + +// options is a struct that contains options for the CORS middleware. +// It uses the functional options pattern for flexible configuration. +type options struct { + allowedOrigins []string // origins without wildcards + allowedOriginPatterns []wildcard // origins with wildcards + allowedMethods []string + allowedHeaders map[string]void + exposedHeaders []string // these are in addition to the CORS-safelisted response headers + allowCredentials bool + allowPrivateNetworks bool + maxAge int +} + +// Option is a type that is used to set options for the CORS middleware. +// It implements the functional options pattern. +type Option func(*options) + +// WithAllowedOrigins sets the allowed origins for the CORS middleware. +// If the special "*" value is present in the list, all origins will be allowed. +// An origin may contain a wildcard (*) to replace 0 or more characters +// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty. +// Only one wildcard can be used per origin. +// By default, all origins are allowed (*). +func WithAllowedOrigins(origins ...string) Option { + return func(o *options) { + o.allowedOrigins = nil + o.allowedOriginPatterns = nil + + for _, origin := range origins { + if len(origin) > 1 && strings.Contains(origin, "*") { + o.allowedOriginPatterns = append( + o.allowedOriginPatterns, + newWildcard(origin), + ) + } else { + o.allowedOrigins = append(o.allowedOrigins, origin) + } + } + } +} + +// WithAllowedMethods sets the allowed methods for the CORS middleware. +// By default, all methods are allowed (*). +func WithAllowedMethods(methods ...string) Option { + return func(o *options) { + o.allowedMethods = methods + } +} + +// WithAllowedHeaders sets the allowed headers for the CORS middleware. +// By default, all headers are allowed (*). +func WithAllowedHeaders(headers ...string) Option { + return func(o *options) { + o.allowedHeaders = make(map[string]void) + + for _, header := range headers { + // allowed headers are always checked in lowercase + o.allowedHeaders[strings.ToLower(header)] = void{} + } + } +} + +// WithExposedHeaders sets the exposed headers for the CORS middleware. +// By default, no headers are exposed. +func WithExposedHeaders(headers ...string) Option { + return func(o *options) { + o.exposedHeaders = nil + + for _, header := range headers { + o.exposedHeaders = append(o.exposedHeaders, http.CanonicalHeaderKey(header)) + } + } +} + +// WithAllowCredentials sets the allow credentials option for the CORS middleware. +// This setting indicates whether the request can include user credentials like +// cookies, HTTP authentication or client side SSL certificates. +// By default, credentials are not allowed. +func WithAllowCredentials(allow bool) Option { + return func(o *options) { + o.allowCredentials = allow + } +} + +// WithAllowPrivateNetworks sets the allow private networks option for the CORS middleware. +// This setting indicates whether to accept cross-origin requests over a private network. +func WithAllowPrivateNetworks(allow bool) Option { + return func(o *options) { + o.allowPrivateNetworks = allow + } +} + +// WithMaxAge sets the max age (in seconds) for the CORS middleware. +// The maximum age indicates how long (in seconds) the results of a preflight request +// can be cached. A value of 0 means that no Access-Control-Max-Age header is sent back, +// resulting in browsers using their default value (5s by spec). +// If you need to force a 0 max-age, set it to a negative value (ie: -1). +// By default, the max age is 7200 seconds. +func WithMaxAge(age int) Option { + return func(o *options) { + o.maxAge = age + } +} + +// newOptions is a function that returns a new options struct with sane default values. +func newOptions(opts ...Option) options { + o := options{ + allowedOrigins: []string{"*"}, + allowedMethods: []string{ + http.MethodHead, http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, + }, + allowedHeaders: map[string]void{"*": {}}, + exposedHeaders: nil, + allowCredentials: false, + allowPrivateNetworks: false, + maxAge: 0, + } + + for _, opt := range opts { + opt(&o) + } + + return o +} diff --git a/internal/app/api/core/middleware/cors/options_test.go b/internal/app/api/core/middleware/cors/options_test.go new file mode 100644 index 0000000..6957dab --- /dev/null +++ b/internal/app/api/core/middleware/cors/options_test.go @@ -0,0 +1,96 @@ +package cors + +import ( + "maps" + "net/http" + "slices" + "testing" +) + +func TestWithAllowedOrigins(t *testing.T) { + tests := []struct { + name string + origins []string + wantNormal []string + wantWildcard []wildcard + }{ + { + name: "No origins", + origins: []string{}, + wantNormal: nil, + wantWildcard: nil, + }, + { + name: "Single origin", + origins: []string{"http://example.com"}, + wantNormal: []string{"http://example.com"}, + wantWildcard: nil, + }, + { + name: "Wildcard origin", + origins: []string{"http://*.example.com"}, + wantNormal: nil, + wantWildcard: []wildcard{newWildcard("http://*.example.com")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := newOptions(WithAllowedOrigins(tt.origins...)) + if !slices.Equal(o.allowedOrigins, tt.wantNormal) { + t.Errorf("got %v, want %v", o, tt.wantNormal) + } + if !slices.Equal(o.allowedOriginPatterns, tt.wantWildcard) { + t.Errorf("got %v, want %v", o, tt.wantWildcard) + } + }) + } +} + +func TestWithAllowedMethods(t *testing.T) { + methods := []string{http.MethodGet, http.MethodPost} + o := newOptions(WithAllowedMethods(methods...)) + if !slices.Equal(o.allowedMethods, methods) { + t.Errorf("got %v, want %v", o.allowedMethods, methods) + } +} + +func TestWithAllowedHeaders(t *testing.T) { + headers := []string{"Content-Type", "Authorization"} + o := newOptions(WithAllowedHeaders(headers...)) + expectedHeaders := map[string]void{"content-type": {}, "authorization": {}} + if !maps.Equal(o.allowedHeaders, expectedHeaders) { + t.Errorf("got %v, want %v", o.allowedHeaders, expectedHeaders) + } +} + +func TestWithExposedHeaders(t *testing.T) { + headers := []string{"X-Custom-Header"} + o := newOptions(WithExposedHeaders(headers...)) + expectedHeaders := []string{http.CanonicalHeaderKey("X-Custom-Header")} + if !slices.Equal(o.exposedHeaders, expectedHeaders) { + t.Errorf("got %v, want %v", o.exposedHeaders, expectedHeaders) + } +} + +func TestWithAllowCredentials(t *testing.T) { + o := newOptions(WithAllowCredentials(true)) + if !o.allowCredentials { + t.Errorf("got %v, want %v", o.allowCredentials, true) + } +} + +func TestWithAllowPrivateNetworks(t *testing.T) { + o := newOptions(WithAllowPrivateNetworks(true)) + if !o.allowPrivateNetworks { + t.Errorf("got %v, want %v", o.allowPrivateNetworks, true) + } +} + +func TestWithMaxAge(t *testing.T) { + maxAge := 3600 + o := newOptions(WithMaxAge(maxAge)) + if o.maxAge != maxAge { + t.Errorf("got %v, want %v", o.maxAge, maxAge) + } +} diff --git a/internal/app/api/core/middleware/cors/wildcard.go b/internal/app/api/core/middleware/cors/wildcard.go new file mode 100644 index 0000000..01d352d --- /dev/null +++ b/internal/app/api/core/middleware/cors/wildcard.go @@ -0,0 +1,33 @@ +package cors + +import "strings" + +// wildcard is a type that represents a wildcard string. +// This type allows faster matching of strings with a wildcard +// in comparison to using regex. +type wildcard struct { + prefix string + suffix string +} + +// match returns true if the string s has the prefix and suffix of the wildcard. +func (w wildcard) match(s string) bool { + return len(s) >= len(w.prefix)+len(w.suffix) && + strings.HasPrefix(s, w.prefix) && + strings.HasSuffix(s, w.suffix) +} + +func newWildcard(s string) wildcard { + if i := strings.IndexByte(s, '*'); i >= 0 { + return wildcard{ + prefix: s[:i], + suffix: s[i+1:], + } + } + + // fallback, usually this case should not happen + return wildcard{ + prefix: s, + suffix: "", + } +} diff --git a/internal/app/api/core/middleware/cors/wildcard_test.go b/internal/app/api/core/middleware/cors/wildcard_test.go new file mode 100644 index 0000000..93f18a7 --- /dev/null +++ b/internal/app/api/core/middleware/cors/wildcard_test.go @@ -0,0 +1,94 @@ +package cors + +import "testing" + +func TestWildcardMatch(t *testing.T) { + tests := []struct { + name string + wildcard wildcard + input string + expected bool + }{ + { + name: "Match with prefix and suffix", + wildcard: newWildcard("http://*.example.com"), + input: "http://sub.example.com", + expected: true, + }, + { + name: "No match with different prefix", + wildcard: newWildcard("http://*.example.com"), + input: "https://sub.example.com", + expected: false, + }, + { + name: "No match with different suffix", + wildcard: newWildcard("http://*.example.com"), + input: "http://sub.example.org", + expected: false, + }, + { + name: "Match with empty suffix", + wildcard: newWildcard("http://*"), + input: "http://example.com", + expected: true, + }, + { + name: "Match with empty prefix", + wildcard: newWildcard("*.example.com"), + input: "sub.example.com", + expected: true, + }, + { + name: "No match with empty prefix and different suffix", + wildcard: newWildcard("*.example.com"), + input: "sub.example.org", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.wildcard.match(tt.input); got != tt.expected { + t.Errorf("wildcard.match(%s) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} + +func TestNewWildcard(t *testing.T) { + tests := []struct { + name string + input string + expected wildcard + }{ + { + name: "Wildcard with prefix and suffix", + input: "http://*.example.com", + expected: wildcard{prefix: "http://", suffix: ".example.com"}, + }, + { + name: "Wildcard with empty suffix", + input: "http://*", + expected: wildcard{prefix: "http://", suffix: ""}, + }, + { + name: "Wildcard with empty prefix", + input: "*.example.com", + expected: wildcard{prefix: "", suffix: ".example.com"}, + }, + { + name: "No wildcard character", + input: "http://example.com", + expected: wildcard{prefix: "http://example.com", suffix: ""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newWildcard(tt.input); got != tt.expected { + t.Errorf("newWildcard(%s) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} diff --git a/internal/app/api/core/middleware/csrf/middleware.go b/internal/app/api/core/middleware/csrf/middleware.go new file mode 100644 index 0000000..ffa7bc2 --- /dev/null +++ b/internal/app/api/core/middleware/csrf/middleware.go @@ -0,0 +1,137 @@ +package csrf + +import ( + "context" + "net/http" + "slices" +) + +// ContextValueIdentifier is the context value identifier for the CSRF token. +// The token is only stored in the context if the RefreshToken function was called before. +const ContextValueIdentifier = "_csrf_token" + +// Middleware is a type that creates a new CSRF middleware. The CSRF middleware +// can be used to mitigate Cross-Site Request Forgery attacks. +type Middleware struct { + o options +} + +// New returns a new CSRF middleware with the provided options. +func New(sessionReader SessionReader, sessionWriter SessionWriter, opts ...Option) *Middleware { + opts = append(opts, withSessionReader(sessionReader), withSessionWriter(sessionWriter)) + o := newOptions(opts...) + + m := &Middleware{ + o: o, + } + + checkForPRNG() + + return m +} + +// Handler returns the CSRF middleware handler. This middleware validates the CSRF token and calls the specified +// error handler if an invalid CSRF token was found. +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if slices.Contains(m.o.ignoreMethods, r.Method) { + next.ServeHTTP(w, r) // skip CSRF check for ignored methods + return + } + + // get the token from the request + token := m.o.tokenGetter(r) + storedToken := m.o.sessionGetter(r) + + if !tokenEqual(token, storedToken) { + m.o.errCallback(w, r) + return + } + + next.ServeHTTP(w, r) // execute the next handler + }) +} + +// RefreshToken generates a new CSRF Token and stores it in the session. The token is also passed to subsequent handlers +// via the context value ContextValueIdentifier. +func (m *Middleware) RefreshToken(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if GetToken(r.Context()) != "" { + // token already generated higher up in the chain + next.ServeHTTP(w, r) + return + } + + // generate a new token + token := generateToken(m.o.tokenLength) + key := generateToken(m.o.tokenLength) + + // mask the token + maskedToken := maskToken(token, key) + + // store the encoded token in the session + encodedToken := encodeToken(maskedToken) + m.o.sessionWriter(r, encodedToken) + + // pass the token down the chain via the context + r = r.WithContext(setToken(r.Context(), encodedToken)) + + next.ServeHTTP(w, r) + }) +} + +// region token-access + +// GetToken retrieves the CSRF token from the given context. Ensure that the RefreshToken function was called before, +// otherwise, no token is populated in the context. +func GetToken(ctx context.Context) string { + token, ok := ctx.Value(ContextValueIdentifier).(string) + if !ok { + return "" + } + + return token +} + +// endregion token-access + +// region internal-helpers + +func setToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, ContextValueIdentifier, token) +} + +// defaultTokenGetter is the default token getter function for the CSRF middleware. +// It checks the request form values, URL query parameters, and headers for the CSRF token. +// The order of precedence is: +// 1. Header "X-CSRF-TOKEN" +// 2. Header "X-XSRF-TOKEN" +// 3. URL query parameter "_csrf" +// 4. Form value "_csrf" +func defaultTokenGetter(r *http.Request) string { + if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 { + return t + } + + if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 { + return t + } + + if t := r.URL.Query().Get("_csrf"); len(t) > 0 { + return t + } + + if t := r.FormValue("_csrf"); len(t) > 0 { + return t + } + + return "" +} + +// defaultErrorHandler is the default error handler function for the CSRF middleware. +// It writes a 403 Forbidden response. +func defaultErrorHandler(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "CSRF token mismatch", http.StatusForbidden) +} + +// endregion internal-helpers diff --git a/internal/app/api/core/middleware/csrf/middleware_test.go b/internal/app/api/core/middleware/csrf/middleware_test.go new file mode 100644 index 0000000..78ae770 --- /dev/null +++ b/internal/app/api/core/middleware/csrf/middleware_test.go @@ -0,0 +1,251 @@ +package csrf + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/h44z/wg-portal/internal/app/api/core/request" +) + +func TestMiddleware_Handler(t *testing.T) { + sessionToken := "stored-token" + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + } + m := New(sessionReader, sessionWriter) + + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + method string + token string + wantStatus int + }{ + {"ValidToken", "POST", "stored-token", http.StatusOK}, + {"ValidToken2", "PUT", "stored-token", http.StatusOK}, + {"ValidToken3", "GET", "stored-token", http.StatusOK}, + {"InvalidToken", "POST", "invalid-token", http.StatusForbidden}, + {"IgnoredMethod", "GET", "", http.StatusOK}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/", nil) + req.Header.Set("X-CSRF-TOKEN", tt.token) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != tt.wantStatus { + t.Errorf("Handler() status = %d, want %d", status, tt.wantStatus) + } + }) + } +} + +func TestMiddleware_RefreshToken(t *testing.T) { + sessionToken := "" + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + } + m := New(sessionReader, sessionWriter) + + handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := GetToken(r.Context()) + if token == "" { + t.Errorf("RefreshToken() did not set token in context") + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK) + } + + if sessionToken == "" { + t.Errorf("RefreshToken() did not set token in session") + } +} + +func TestMiddleware_RefreshToken_chained(t *testing.T) { + sessionToken := "" + tokenWrites := 0 + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + tokenWrites++ + } + m := New(sessionReader, sessionWriter) + + handler := m.RefreshToken(m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := GetToken(r.Context()) + if token == "" { + t.Errorf("RefreshToken() did not set token in context") + } + w.WriteHeader(http.StatusOK) + }))) + + req := httptest.NewRequest("POST", "/", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("RefreshToken() status = %d, want %d", status, http.StatusOK) + } + + if sessionToken == "" { + t.Errorf("RefreshToken() did not set token in session") + } + + if tokenWrites != 1 { + t.Errorf("RefreshToken() wrote token to session more than once: %d", tokenWrites) + } +} + +func TestMiddleware_RefreshToken_Handler(t *testing.T) { + sessionToken := "" + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + } + m := New(sessionReader, sessionWriter) + + // simulate two requests: first one GET request with the RefreshToken handler, the next one is a PUT request with + // the token from the first request added as X-CSRF-TOKEN header + + // first request + retrievedToken := "" + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + handler := m.RefreshToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + retrievedToken = GetToken(r.Context()) + if retrievedToken == "" { + t.Errorf("RefreshToken() did not set token in context") + } + w.WriteHeader(http.StatusAccepted) + })) + handler.ServeHTTP(rr, req) + if status := rr.Code; status != http.StatusAccepted { + t.Errorf("Handler() status = %d, want %d", status, http.StatusAccepted) + } + if retrievedToken == "" { + t.Errorf("no token retrieved") + } + if retrievedToken != sessionToken { + t.Errorf("token in context does not match token in session") + } + + // second request + req = httptest.NewRequest("PUT", "/", nil) + req.Header.Set("X-CSRF-TOKEN", retrievedToken) + rr = httptest.NewRecorder() + handler = m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + handler.ServeHTTP(rr, req) + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler() status = %d, want %d", status, http.StatusOK) + } +} + +func TestMiddleware_Handler_FormBody(t *testing.T) { + sessionToken := "stored-token" + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + } + m := New(sessionReader, sessionWriter) + + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyData, err := request.BodyString(r) + if err != nil { + t.Errorf("Handler() error = %v, want nil", err) + } + // ensure that the body is empty - ParseForm() should have been called before by the CSRF middleware + if bodyData != "" { + t.Errorf("Handler() bodyData = %s, want empty", bodyData) + } + + if r.FormValue("_csrf") != "stored-token" { + t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token") + } + + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Form = make(map[string][]string) + req.Form.Add("_csrf", "stored-token") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler() status = %d, want %d", status, http.StatusOK) + } +} + +func TestMiddleware_Handler_FormBodyAvailable(t *testing.T) { + sessionToken := "stored-token" + sessionReader := func(r *http.Request) string { + return sessionToken + } + sessionWriter := func(r *http.Request, token string) { + sessionToken = token + } + m := New(sessionReader, sessionWriter) + + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyData, err := request.BodyString(r) + if err != nil { + t.Errorf("Handler() error = %v, want nil", err) + } + // ensure that the body is not empty, as the CSRF middleware should not have read the body + if bodyData != "the original body" { + t.Errorf("Handler() bodyData = %s, want %s", bodyData, "the original body") + } + + // check if the token is available in the form values (from query parameters) + if r.FormValue("_csrf") != "stored-token" { + t.Errorf("Handler() _csrf = %s, want %s", r.FormValue("_csrf"), "stored-token") + } + + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/?_csrf=stored-token", nil) + req.Header.Set("Content-Type", "text/plain") + req.Body = io.NopCloser(strings.NewReader("the original body")) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("Handler() status = %d, want %d", status, http.StatusOK) + } +} diff --git a/internal/app/api/core/middleware/csrf/options.go b/internal/app/api/core/middleware/csrf/options.go new file mode 100644 index 0000000..0b64078 --- /dev/null +++ b/internal/app/api/core/middleware/csrf/options.go @@ -0,0 +1,88 @@ +package csrf + +import "net/http" + +type SessionReader func(r *http.Request) string +type SessionWriter func(r *http.Request, token string) + +// options is a struct that contains options for the CSRF middleware. +// It uses the functional options pattern for flexible configuration. +type options struct { + tokenLength int + ignoreMethods []string + + errCallbackOverride bool + errCallback func(w http.ResponseWriter, r *http.Request) + + tokenGetterOverride bool + tokenGetter func(r *http.Request) string + + sessionGetter SessionReader + sessionWriter SessionWriter +} + +// Option is a type that is used to set options for the CSRF middleware. +// It implements the functional options pattern. +type Option func(*options) + +// WithTokenLength is a method that sets the token length for the CSRF middleware. +// The default value is 32. +func WithTokenLength(length int) Option { + return func(o *options) { + o.tokenLength = length + } +} + +// WithErrorCallback is a method that sets the error callback function for the CSRF middleware. +// The error callback function is called when the CSRF token is invalid. +// The default behavior is to write a 403 Forbidden response. +func WithErrorCallback(fn func(w http.ResponseWriter, r *http.Request)) Option { + return func(o *options) { + o.errCallback = fn + o.errCallbackOverride = true + } +} + +// WithTokenGetter is a method that sets the token getter function for the CSRF middleware. +// The token getter function is called to get the CSRF token from the request. +// The default behavior is to get the token from the "X-CSRF-Token" header. +func WithTokenGetter(fn func(r *http.Request) string) Option { + return func(o *options) { + o.tokenGetter = fn + o.tokenGetterOverride = true + } +} + +// withSessionReader is a method that sets the session reader function for the CSRF middleware. +// The session reader function is called to get the CSRF token from the session. +func withSessionReader(fn SessionReader) Option { + return func(o *options) { + o.sessionGetter = fn + } +} + +// withSessionWriter is a method that sets the session writer function for the CSRF middleware. +// The session writer function is called to write the CSRF token to the session. +func withSessionWriter(fn SessionWriter) Option { + return func(o *options) { + o.sessionWriter = fn + } +} + +// newOptions is a function that returns a new options struct with sane default values. +func newOptions(opts ...Option) options { + o := options{ + tokenLength: 32, + ignoreMethods: []string{"GET", "HEAD", "OPTIONS"}, + errCallbackOverride: false, + errCallback: defaultErrorHandler, + tokenGetterOverride: false, + tokenGetter: defaultTokenGetter, + } + + for _, opt := range opts { + opt(&o) + } + + return o +} diff --git a/internal/app/api/core/middleware/csrf/options_test.go b/internal/app/api/core/middleware/csrf/options_test.go new file mode 100644 index 0000000..b6d2e18 --- /dev/null +++ b/internal/app/api/core/middleware/csrf/options_test.go @@ -0,0 +1,75 @@ +package csrf + +import ( + "net/http" + "testing" +) + +func TestWithTokenLength(t *testing.T) { + o := newOptions(WithTokenLength(64)) + if o.tokenLength != 64 { + t.Errorf("WithTokenLength() = %d, want %d", o.tokenLength, 64) + } +} + +func TestWithErrorCallback(t *testing.T) { + callback := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + } + o := newOptions(WithErrorCallback(callback)) + if !o.errCallbackOverride { + t.Errorf("WithErrorCallback() did not set errCallbackOverride to true") + } + if o.errCallback == nil { + t.Errorf("WithErrorCallback() did not set errCallback") + } +} + +func TestWithTokenGetter(t *testing.T) { + getter := func(r *http.Request) string { + return "test-token" + } + o := newOptions(WithTokenGetter(getter)) + if !o.tokenGetterOverride { + t.Errorf("WithTokenGetter() did not set tokenGetterOverride to true") + } + if o.tokenGetter == nil { + t.Errorf("WithTokenGetter() did not set tokenGetter") + } +} + +func TestWithSessionReader(t *testing.T) { + reader := func(r *http.Request) string { + return "session-token" + } + o := newOptions(withSessionReader(reader)) + if o.sessionGetter == nil { + t.Errorf("withSessionReader() did not set sessionGetter") + } +} + +func TestWithSessionWriter(t *testing.T) { + writer := func(r *http.Request, token string) { + // do nothing + } + o := newOptions(withSessionWriter(writer)) + if o.sessionWriter == nil { + t.Errorf("withSessionWriter() did not set sessionWriter") + } +} + +func TestNewOptionsDefaults(t *testing.T) { + o := newOptions() + if o.tokenLength != 32 { + t.Errorf("newOptions() default tokenLength = %d, want %d", o.tokenLength, 32) + } + if len(o.ignoreMethods) != 3 { + t.Errorf("newOptions() default ignoreMethods length = %d, want %d", len(o.ignoreMethods), 3) + } + if o.errCallback == nil { + t.Errorf("newOptions() default errCallback is nil") + } + if o.tokenGetter == nil { + t.Errorf("newOptions() default tokenGetter is nil") + } +} diff --git a/internal/app/api/core/middleware/csrf/token.go b/internal/app/api/core/middleware/csrf/token.go new file mode 100644 index 0000000..fa47dee --- /dev/null +++ b/internal/app/api/core/middleware/csrf/token.go @@ -0,0 +1,90 @@ +package csrf + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "slices" +) + +// checkForPRNG is a function that checks if a cryptographically secure PRNG is available. +// If it is not available, the function panics. +func checkForPRNG() { + buf := make([]byte, 1) + _, err := io.ReadFull(rand.Reader, buf) + + if err != nil { + panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err)) + } +} + +// generateToken is a function that generates a secure random CSRF token. +func generateToken(length int) []byte { + bytes := make([]byte, length) + + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + panic(err) + } + + return bytes +} + +// encodeToken is a function that encodes a token to a base64 string. +func encodeToken(token []byte) string { + return base64.URLEncoding.EncodeToString(token) +} + +// decodeToken is a function that decodes a base64 string to a token. +func decodeToken(token string) ([]byte, error) { + return base64.URLEncoding.DecodeString(token) +} + +// maskToken is a function that masks a token with a given key. +// The returned byte slice contains the key + the masked token. +// The key needs to have the same length as the token, otherwise the function panics. +// So the resulting slice has a length of len(token) * 2. +func maskToken(token, key []byte) []byte { + if len(token) != len(key) { + panic("token and key must have the same length") + } + + // masked contains the key in the first half and the XOR masked token in the second half + tokenLength := len(token) + masked := make([]byte, tokenLength*2) + for i := 0; i < len(token); i++ { + masked[i] = key[i] + masked[i+tokenLength] = token[i] ^ key[i] // XOR mask + } + + return masked +} + +// unmaskToken is a function that unmask a token which contains the key in the first half. +// The returned byte slice contains the unmasked token, it has exactly half the length of the input slice. +func unmaskToken(masked []byte) []byte { + tokenLength := len(masked) / 2 + token := make([]byte, tokenLength) + for i := 0; i < tokenLength; i++ { + token[i] = masked[i] ^ masked[i+tokenLength] // XOR unmask + } + + return token +} + +// tokenEqual is a function that compares two tokens for equality. +func tokenEqual(a, b string) bool { + decodedA, err := decodeToken(a) + if err != nil { + return false + } + decodedB, err := decodeToken(b) + if err != nil { + return false + } + + unmaskedA := unmaskToken(decodedA) + unmaskedB := unmaskToken(decodedB) + + return slices.Equal(unmaskedA, unmaskedB) +} diff --git a/internal/app/api/core/middleware/csrf/token_test.go b/internal/app/api/core/middleware/csrf/token_test.go new file mode 100644 index 0000000..67055e9 --- /dev/null +++ b/internal/app/api/core/middleware/csrf/token_test.go @@ -0,0 +1,81 @@ +package csrf + +import ( + "encoding/base64" + "testing" +) + +func TestCheckForPRNG(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("checkForPRNG() panicked: %v", r) + } + }() + checkForPRNG() +} + +func TestGenerateToken(t *testing.T) { + length := 32 + token := generateToken(length) + if len(token) != length { + t.Errorf("generateToken() returned token of length %d, expected %d", len(token), length) + } +} + +func TestEncodeToken(t *testing.T) { + token := []byte("testtoken") + encoded := encodeToken(token) + expected := base64.URLEncoding.EncodeToString(token) + if encoded != expected { + t.Errorf("encodeToken() = %v, want %v", encoded, expected) + } +} + +func TestDecodeToken(t *testing.T) { + token := "dGVzdHRva2Vu" + expected := []byte("testtoken") + decoded, err := decodeToken(token) + if err != nil { + t.Errorf("decodeToken() error = %v", err) + } + if string(decoded) != string(expected) { + t.Errorf("decodeToken() = %v, want %v", decoded, expected) + } +} + +func TestMaskToken(t *testing.T) { + token := []byte("testtoken") + key := []byte("keykeykey") + masked := maskToken(token, key) + if len(masked) != len(token)*2 { + t.Errorf("maskToken() returned masked token of length %d, expected %d", len(masked), len(token)*2) + } +} + +func TestUnmaskToken(t *testing.T) { + token := []byte("testtoken") + key := []byte("keykeykey") + masked := maskToken(token, key) + unmasked := unmaskToken(masked) + if string(unmasked) != string(token) { + t.Errorf("unmaskToken() = %v, want %v", unmasked, token) + } +} + +func TestTokenEqual(t *testing.T) { + tokenA := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03})) + tokenB := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x04, 0x05, 0x06})) + if !tokenEqual(tokenA, tokenB) { + t.Errorf("tokenEqual() = false, want true") + } + + tokenC := encodeToken(maskToken([]byte{0x01, 0x02, 0x03}, []byte{0x07, 0x08, 0x09})) + if !tokenEqual(tokenA, tokenC) { + t.Errorf("tokenEqual() = false, want true") + } + + tokenD := encodeToken(maskToken([]byte{0x09, 0x02, 0x03}, []byte{0x04, 0x05, 0x06})) + if tokenEqual(tokenA, tokenD) { + t.Errorf("tokenEqual() = true, want false") + } +} diff --git a/internal/app/api/core/middleware/logging/middleware.go b/internal/app/api/core/middleware/logging/middleware.go new file mode 100644 index 0000000..dd400ab --- /dev/null +++ b/internal/app/api/core/middleware/logging/middleware.go @@ -0,0 +1,199 @@ +package logging + +import ( + "fmt" + "log/slog" + "net/http" + "strings" + "time" +) + +// LogLevel is an enumeration of the different log levels. +type LogLevel int + +const ( + LogLevelDebug LogLevel = iota + LogLevelInfo + LogLevelWarn + LogLevelError +) + +// Logger is an interface that defines the methods that a logger must implement. +// This allows the logging middleware to be used with different logging libraries. +type Logger interface { + // Debugf logs a message at debug level. + Debugf(format string, args ...any) + // Infof logs a message at info level. + Infof(format string, args ...any) + // Warnf logs a message at warn level. + Warnf(format string, args ...any) + // Errorf logs a message at error level. + Errorf(format string, args ...any) +} + +// Middleware is a type that creates a new logging middleware. The logging middleware +// logs information about each request. +type Middleware struct { + o options +} + +// New returns a new logging middleware with the provided options. +func New(opts ...Option) *Middleware { + o := newOptions(opts...) + + m := &Middleware{ + o: o, + } + + return m +} + +// Handler returns the logging middleware handler. +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := newWriterWrapper(w) + start := time.Now() + defer func() { + info := m.extractInfoMap(r, start, ww) + + if m.o.logger == nil { + msg, args := m.buildSlogMessageAndArguments(info) + m.logMsg(msg, args...) + } else { + msg := m.buildNormalLogMessage(info) + m.logMsg(msg) + } + }() + + next.ServeHTTP(ww, r) + }) +} + +func (m *Middleware) extractInfoMap(r *http.Request, start time.Time, ww *writerWrapper) map[string]any { + info := make(map[string]any) + + info["method"] = r.Method + info["path"] = r.URL.Path + info["protocol"] = r.Proto + info["clientIP"] = r.Header.Get("X-Forwarded-For") + if info["clientIP"] == "" { + // If the X-Forwarded-For header is not set, use the remote address without the port number. + lastColonIndex := strings.LastIndex(r.RemoteAddr, ":") + switch lastColonIndex { + case -1: + info["clientIP"] = r.RemoteAddr + default: + info["clientIP"] = r.RemoteAddr[:lastColonIndex] + } + } + info["userAgent"] = r.UserAgent() + info["referer"] = r.Header.Get("Referer") + info["duration"] = time.Since(start).String() + info["status"] = ww.StatusCode + info["dataLength"] = ww.WrittenBytes + + if m.o.headerRequestIdKey != "" { + info["headerRequestId"] = r.Header.Get(m.o.headerRequestIdKey) + } + if m.o.contextRequestIdKey != "" { + info["contextRequestId"], _ = r.Context().Value(m.o.contextRequestIdKey).(string) + } + + return info +} + +func (m *Middleware) buildNormalLogMessage(info map[string]any) string { + switch { + case info["headerRequestId"] != nil && info["contextRequestId"] != nil: + return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s ctx=%s", + info["method"], info["path"], info["protocol"], + info["status"], info["dataLength"], + info["duration"], + info["clientIP"], info["userAgent"], info["referer"], + info["headerRequestId"], info["contextRequestId"]) + case info["headerRequestId"] != nil: + return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - rid=%s", + info["method"], info["path"], info["protocol"], + info["status"], info["dataLength"], + info["duration"], + info["clientIP"], info["userAgent"], info["referer"], + info["headerRequestId"]) + case info["contextRequestId"] != nil: + return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s - ctx=%s", + info["method"], info["path"], info["protocol"], + info["status"], info["dataLength"], + info["duration"], + info["clientIP"], info["userAgent"], info["referer"], + info["contextRequestId"]) + default: + return fmt.Sprintf("%s %s %s - %d %d - %s - %s %s %s", + info["method"], info["path"], info["protocol"], + info["status"], info["dataLength"], + info["duration"], + info["clientIP"], info["userAgent"], info["referer"]) + } +} + +func (m *Middleware) buildSlogMessageAndArguments(info map[string]any) (message string, args []any) { + message = fmt.Sprintf("%s %s", info["method"], info["path"]) + + // Use a fixed order for the keys, so that the message is always the same. + // Skip method and path as they are already in the message. + keys := []string{ + "protocol", + "status", + "dataLength", + "duration", + "clientIP", + "userAgent", + "referer", + "headerRequestId", + "contextRequestId", + } + for _, k := range keys { + if v, ok := info[k]; ok { + args = append(args, k, v) // only add key, value if it exists + } + } + + return +} + +func (m *Middleware) addPrefix(message string) string { + if m.o.prefix != "" { + return m.o.prefix + " " + message + } + return message +} + +func (m *Middleware) logMsg(message string, args ...any) { + message = m.addPrefix(message) + + if m.o.logger != nil { + switch m.o.logLevel { + case LogLevelDebug: + m.o.logger.Debugf(message, args...) + case LogLevelInfo: + m.o.logger.Infof(message, args...) + case LogLevelWarn: + m.o.logger.Warnf(message, args...) + case LogLevelError: + m.o.logger.Errorf(message, args...) + default: + m.o.logger.Infof(message, args...) + } + } else { + switch m.o.logLevel { + case LogLevelDebug: + slog.Debug(message, args...) + case LogLevelInfo: + slog.Info(message, args...) + case LogLevelWarn: + slog.Warn(message, args...) + case LogLevelError: + slog.Error(message, args...) + default: + slog.Info(message, args...) + } + } +} diff --git a/internal/app/api/core/middleware/logging/middleware_test.go b/internal/app/api/core/middleware/logging/middleware_test.go new file mode 100644 index 0000000..ca0caaf --- /dev/null +++ b/internal/app/api/core/middleware/logging/middleware_test.go @@ -0,0 +1,148 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type mockLogger struct { + messages []string +} + +func (m *mockLogger) Debugf(format string, _ ...any) { + m.messages = append(m.messages, "DEBUG: "+format) +} +func (m *mockLogger) Infof(format string, _ ...any) { + m.messages = append(m.messages, "INFO: "+format) +} +func (m *mockLogger) Warnf(format string, _ ...any) { + m.messages = append(m.messages, "WARN: "+format) +} +func (m *mockLogger) Errorf(format string, _ ...any) { + m.messages = append(m.messages, "ERROR: "+format) +} + +func TestMiddleware_Normal(t *testing.T) { + logger := &mockLogger{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusTeapot { + t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status) + } + + expected := "Hello, World!" + if rr.Body.String() != expected { + t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String()) + } + + if len(logger.messages) == 0 { + t.Errorf("expected log messages, got none") + } + + if len(logger.messages) != 0 && !strings.Contains(logger.messages[0], "ERROR: GET /foo") { + t.Errorf("expected log message to contain request info, got %v", logger.messages[0]) + } +} + +func TestMiddleware_Extended(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithContextRequestIdKey("requestId"), WithHeaderRequestIdKey("X-Request-Id")). + Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusTeapot { + t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, status) + } + + expected := "Hello, World!" + if rr.Body.String() != expected { + t.Errorf("expected response body to be %v, got %v", expected, rr.Body.String()) + } +} + +func TestMiddleware_Logger_remoteAddr(t *testing.T) { + logger := &mockLogger{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.RemoteAddr = "xhamster.com:1234" + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + +} + +func TestMiddleware_Logger_remoteAddrNoPort(t *testing.T) { + logger := &mockLogger{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.RemoteAddr = "xhamster.com" + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + +} + +func TestMiddleware_Logger_remoteAddrV6(t *testing.T) { + logger := &mockLogger{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.RemoteAddr = "[::1]:4711" + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + +} + +func TestMiddleware_Logger_remoteAddrV6NoPort(t *testing.T) { + logger := &mockLogger{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("Hello, World!")) + }) + + middleware := New(WithLogger(logger), WithLevel(LogLevelError)).Handler(handler) + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + req.RemoteAddr = "[::1]" + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + +} diff --git a/internal/app/api/core/middleware/logging/options.go b/internal/app/api/core/middleware/logging/options.go new file mode 100644 index 0000000..d97389e --- /dev/null +++ b/internal/app/api/core/middleware/logging/options.go @@ -0,0 +1,80 @@ +package logging + +// options is a struct that contains options for the logging middleware. +// It uses the functional options pattern for flexible configuration. +type options struct { + logLevel LogLevel + logger Logger + prefix string + + contextRequestIdKey string + headerRequestIdKey string +} + +// Option is a type that is used to set options for the logging middleware. +// It implements the functional options pattern. +type Option func(*options) + +// WithLevel is a method that sets the log level for the logging middleware. +// Possible values are LogLevelDebug, LogLevelInfo, LogLevelWarn, and LogLevelError. +// The default value is LogLevelInfo. +func WithLevel(level LogLevel) Option { + return func(o *options) { + o.logLevel = level + } +} + +// WithPrefix is a method that sets the prefix for the logging middleware. +// If a prefix is set, it will be prepended to each log message. A space will +// be added between the prefix and the log message. +// The default value is an empty string. +func WithPrefix(prefix string) Option { + return func(o *options) { + o.prefix = prefix + } +} + +// WithContextRequestIdKey is a method that sets the key for the request ID in the +// request context. If a key is set, the logging middleware will use this key to +// retrieve the request ID from the request context. +// The default value is an empty string, meaning the request ID will not be logged. +func WithContextRequestIdKey(key string) Option { + return func(o *options) { + o.contextRequestIdKey = key + } +} + +// WithHeaderRequestIdKey is a method that sets the key for the request ID in the +// request headers. If a key is set, the logging middleware will use this key to +// retrieve the request ID from the request headers. +// The default value is an empty string, meaning the request ID will not be logged. +func WithHeaderRequestIdKey(key string) Option { + return func(o *options) { + o.headerRequestIdKey = key + } +} + +// WithLogger is a method that sets the logger for the logging middleware. +// If a logger is set, the logging middleware will use this logger to log messages. +// The default logger is the structured slog logger. +func WithLogger(logger Logger) Option { + return func(o *options) { + o.logger = logger + } +} + +// newOptions is a function that returns a new options struct with sane default values. +func newOptions(opts ...Option) options { + o := options{ + logLevel: LogLevelInfo, + logger: nil, + prefix: "", + contextRequestIdKey: "", + } + + for _, opt := range opts { + opt(&o) + } + + return o +} diff --git a/internal/app/api/core/middleware/logging/options_test.go b/internal/app/api/core/middleware/logging/options_test.go new file mode 100644 index 0000000..80e3722 --- /dev/null +++ b/internal/app/api/core/middleware/logging/options_test.go @@ -0,0 +1,88 @@ +package logging + +import ( + "testing" +) + +func TestWithLevel(t *testing.T) { + // table test to check all possible log levels + levels := []LogLevel{ + LogLevelDebug, + LogLevelInfo, + LogLevelWarn, + LogLevelError, + } + + for _, level := range levels { + opt := WithLevel(level) + o := newOptions(opt) + + if o.logLevel != level { + t.Errorf("expected log level to be %v, got %v", level, o.logLevel) + } + } +} + +func TestWithPrefix(t *testing.T) { + prefix := "TEST" + opt := WithPrefix(prefix) + o := newOptions(opt) + + if o.prefix != prefix { + t.Errorf("expected prefix to be %v, got %v", prefix, o.prefix) + } +} + +func TestWithContextRequestIdKey(t *testing.T) { + key := "contextKey" + opt := WithContextRequestIdKey(key) + o := newOptions(opt) + + if o.contextRequestIdKey != key { + t.Errorf("expected contextRequestIdKey to be %v, got %v", key, o.contextRequestIdKey) + } +} + +func TestWithHeaderRequestIdKey(t *testing.T) { + key := "headerKey" + opt := WithHeaderRequestIdKey(key) + o := newOptions(opt) + + if o.headerRequestIdKey != key { + t.Errorf("expected headerRequestIdKey to be %v, got %v", key, o.headerRequestIdKey) + } +} + +func TestWithLogger(t *testing.T) { + logger := &mockLogger{} + opt := WithLogger(logger) + o := newOptions(opt) + + if o.logger != logger { + t.Errorf("expected logger to be %v, got %v", logger, o.logger) + } +} + +func TestDefaults(t *testing.T) { + o := newOptions() + + if o.logLevel != LogLevelInfo { + t.Errorf("expected log level to be %v, got %v", LogLevelInfo, o.logLevel) + } + + if o.logger != nil { + t.Errorf("expected logger to be nil, got %v", o.logger) + } + + if o.prefix != "" { + t.Errorf("expected prefix to be empty, got %v", o.prefix) + } + + if o.contextRequestIdKey != "" { + t.Errorf("expected contextRequestIdKey to be empty, got %v", o.contextRequestIdKey) + } + + if o.headerRequestIdKey != "" { + t.Errorf("expected headerRequestIdKey to be empty, got %v", o.headerRequestIdKey) + } +} diff --git a/internal/app/api/core/middleware/logging/writer.go b/internal/app/api/core/middleware/logging/writer.go new file mode 100644 index 0000000..4e3c42f --- /dev/null +++ b/internal/app/api/core/middleware/logging/writer.go @@ -0,0 +1,45 @@ +package logging + +import ( + "net/http" +) + +// writerWrapper wraps a http.ResponseWriter and tracks the number of bytes written to it. +// It also tracks the http response code passed to the WriteHeader func of +// the ResponseWriter. +type writerWrapper struct { + http.ResponseWriter + + // StatusCode is the last http response code passed to the WriteHeader func of + // the ResponseWriter. If no such call is made, a default code of http.StatusOK + // is assumed instead. + StatusCode int + + // WrittenBytes is the number of bytes successfully written by the Write or + // ReadFrom function of the ResponseWriter. ResponseWriters may also write + // data to their underlaying connection directly (e.g. headers), but those + // are not tracked. Therefor the number of Written bytes will usually match + // the size of the response body. + WrittenBytes int64 +} + +// WriteHeader wraps the WriteHeader method of the ResponseWriter and tracks the +// http response code passed to it. +func (w *writerWrapper) WriteHeader(code int) { + w.StatusCode = code + w.ResponseWriter.WriteHeader(code) +} + +// Write wraps the Write method of the ResponseWriter and tracks the number of bytes +// written to it. +func (w *writerWrapper) Write(data []byte) (int, error) { + n, err := w.ResponseWriter.Write(data) + w.WrittenBytes += int64(n) + return n, err +} + +// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter. +// It initializes the StatusCode to http.StatusOK. +func newWriterWrapper(w http.ResponseWriter) *writerWrapper { + return &writerWrapper{ResponseWriter: w, StatusCode: http.StatusOK} +} diff --git a/internal/app/api/core/middleware/logging/writer_test.go b/internal/app/api/core/middleware/logging/writer_test.go new file mode 100644 index 0000000..a5d07d5 --- /dev/null +++ b/internal/app/api/core/middleware/logging/writer_test.go @@ -0,0 +1,85 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriterWrapper_WriteHeader(t *testing.T) { + rr := httptest.NewRecorder() + ww := newWriterWrapper(rr) + + ww.WriteHeader(http.StatusNotFound) + + if ww.StatusCode != http.StatusNotFound { + t.Errorf("expected status code to be %v, got %v", http.StatusNotFound, ww.StatusCode) + } + if rr.Code != http.StatusNotFound { + t.Errorf("expected recorder status code to be %v, got %v", http.StatusNotFound, rr.Code) + } +} + +func TestWriterWrapper_Write(t *testing.T) { + rr := httptest.NewRecorder() + ww := newWriterWrapper(rr) + + data := []byte("Hello, World!") + n, err := ww.Write(data) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if n != len(data) { + t.Errorf("expected written bytes to be %v, got %v", len(data), n) + } + if ww.WrittenBytes != int64(len(data)) { + t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes) + } + if rr.Body.String() != string(data) { + t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String()) + } +} + +func TestWriterWrapper_WriteWithHeaders(t *testing.T) { + rr := httptest.NewRecorder() + ww := newWriterWrapper(rr) + + data := []byte("Hello, World!") + n, err := ww.Write(data) + + ww.Header().Set("Content-Type", "text/plain") + ww.Header().Set("X-Some-Header", "some-value") + ww.WriteHeader(http.StatusTeapot) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if n != len(data) { + t.Errorf("expected written bytes to be %v, got %v", len(data), n) + } + if ww.WrittenBytes != int64(len(data)) { + t.Errorf("expected WrittenBytes to be %v, got %v", len(data), ww.WrittenBytes) + } + if rr.Body.String() != string(data) { + t.Errorf("expected response body to be %v, got %v", string(data), rr.Body.String()) + } + if ww.StatusCode != http.StatusTeapot { + t.Errorf("expected status code to be %v, got %v", http.StatusTeapot, ww.StatusCode) + } +} + +func TestNewWriterWrapper(t *testing.T) { + rr := httptest.NewRecorder() + ww := newWriterWrapper(rr) + + if ww.StatusCode != http.StatusOK { + t.Errorf("expected initial status code to be %v, got %v", http.StatusOK, ww.StatusCode) + } + if ww.WrittenBytes != 0 { + t.Errorf("expected initial WrittenBytes to be %v, got %v", 0, ww.WrittenBytes) + } + if ww.ResponseWriter != rr { + t.Errorf("expected ResponseWriter to be %v, got %v", rr, ww.ResponseWriter) + } +} diff --git a/internal/app/api/core/middleware/recovery/middleware.go b/internal/app/api/core/middleware/recovery/middleware.go new file mode 100644 index 0000000..f4b6c46 --- /dev/null +++ b/internal/app/api/core/middleware/recovery/middleware.go @@ -0,0 +1,133 @@ +package recovery + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "runtime/debug" + "strings" +) + +// Logger is an interface that defines the methods that a logger must implement. +// This allows the logging middleware to be used with different logging libraries. +type Logger interface { + // Errorf logs a message at error level. + Errorf(format string, args ...any) +} + +// Middleware is a type that creates a new recovery middleware. The recovery middleware +// recovers from panics and returns an Internal Server Error response. This middleware should +// be the first middleware in the middleware chain, so that it can recover from panics in other +// middlewares. +type Middleware struct { + o options +} + +// New returns a new recovery middleware with the provided options. +func New(opts ...Option) *Middleware { + o := newOptions(opts...) + + m := &Middleware{ + o: o, + } + + return m +} + +// Handler returns the recovery middleware handler. +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + stack := debug.Stack() + + realErr, ok := err.(error) + if !ok { + realErr = fmt.Errorf("%v", err) + } + + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + brokenPipe := isBrokenPipeError(realErr) + + // Log the error and stack trace + if m.o.logCallback != nil { + m.o.logCallback(realErr, stack, brokenPipe) + } + + switch { + case brokenPipe && m.o.brokenPipeCallback != nil: + m.o.brokenPipeCallback(realErr, stack, w, r) + case !brokenPipe && m.o.errCallback != nil: + m.o.errCallback(realErr, stack, w, r) + default: + // no callback set, simply recover and do nothing... + } + } + }() + + next.ServeHTTP(w, r) + }) +} + +func addPrefix(o options, message string) string { + if o.defaultLogPrefix != "" { + return o.defaultLogPrefix + " " + message + } + return message +} + +// defaultErrCallback is the default error callback function for the recovery middleware. +// It writes a JSON response with an Internal Server Error status code. If the exposeStackTrace option is +// enabled, the stack trace is included in the response. +func getDefaultErrCallback(o options) func(err error, stack []byte, w http.ResponseWriter, r *http.Request) { + return func(err error, stack []byte, w http.ResponseWriter, r *http.Request) { + responseBody := map[string]interface{}{ + "error": "Internal Server Error", + } + if o.exposeStackTrace && len(stack) > 0 { + responseBody["stack"] = string(stack) + } + + jsonBody, _ := json.Marshal(responseBody) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write(jsonBody) + } +} + +// getDefaultLogCallback is the default log callback function for the recovery middleware. +// It logs the error and stack trace using the structured slog logger or the provided logger in Error level. +func getDefaultLogCallback(o options) func(error, []byte, bool) { + return func(err error, stack []byte, brokenPipe bool) { + if brokenPipe { + return // by default, ignore broken pipe errors + } + + switch { + case o.useSlog: + slog.Error(addPrefix(o, err.Error()), "stack", string(stack)) + case o.logger != nil: + o.logger.Errorf(fmt.Sprintf("%s; stacktrace=%s", addPrefix(o, err.Error()), string(stack))) + default: + // no logger set, do nothing... + } + } +} + +func isBrokenPipeError(err error) bool { + var syscallErr *os.SyscallError + if errors.As(err, &syscallErr) { + errMsg := strings.ToLower(syscallErr.Err.Error()) + if strings.Contains(errMsg, "broken pipe") || + strings.Contains(errMsg, "connection reset by peer") { + return true + } + } + + return false +} diff --git a/internal/app/api/core/middleware/recovery/middleware_test.go b/internal/app/api/core/middleware/recovery/middleware_test.go new file mode 100644 index 0000000..d4658c3 --- /dev/null +++ b/internal/app/api/core/middleware/recovery/middleware_test.go @@ -0,0 +1,149 @@ +package recovery + +import ( + "errors" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +type mockLogger struct{} + +func (m *mockLogger) Errorf(_ string, _ ...any) {} + +func TestMiddleware(t *testing.T) { + tests := []struct { + name string + options []Option + panicSimulator func() + expectedStatus int + expectedBody string + expectStack bool + }{ + { + name: "default behavior", + options: []Option{}, + panicSimulator: func() { + panic(errors.New("test panic")) + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: `{"error":"Internal Server Error"}`, + }, + { + name: "custom error callback", + options: []Option{ + WithErrCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + w.Write([]byte("custom error")) + }), + }, + panicSimulator: func() { + panic(errors.New("test panic")) + }, + expectedStatus: http.StatusTeapot, + expectedBody: "custom error", + }, + { + name: "broken pipe error", + options: []Option{ + WithBrokenPipeCallback(func(err error, stack []byte, w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("broken pipe")) + }), + }, + panicSimulator: func() { + panic(&os.SyscallError{Err: errors.New("broken pipe")}) + }, + expectedStatus: http.StatusServiceUnavailable, + expectedBody: "broken pipe", + }, + { + name: "default callback broken pipe error", + options: nil, + panicSimulator: func() { + panic(&os.SyscallError{Err: errors.New("broken pipe")}) + }, + expectedStatus: http.StatusOK, + expectedBody: "", + }, + { + name: "default callback normal error", + options: nil, + panicSimulator: func() { + panic("something went wrong") + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "{\"error\":\"Internal Server Error\"}", + }, + { + name: "default callback with stack trace", + options: []Option{ + WithExposeStackTrace(true), + }, + panicSimulator: func() { + panic("something went wrong") + }, + expectedStatus: http.StatusInternalServerError, + expectedBody: "\"stack\":", + expectStack: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := New(tt.options...).Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tt.panicSimulator() + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %v, got %v", tt.expectedStatus, rr.Code) + } + if !tt.expectStack && rr.Body.String() != tt.expectedBody { + t.Errorf("expected body %v, got %v", tt.expectedBody, rr.Body.String()) + } + if tt.expectStack && !strings.Contains(rr.Body.String(), tt.expectedBody) { + t.Errorf("expected body to contain %v, got %v", tt.expectedBody, rr.Body.String()) + } + }) + } +} + +func TestIsBrokenPipeError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "broken pipe error", + err: &os.SyscallError{Err: errors.New("broken pipe")}, + expected: true, + }, + { + name: "connection reset by peer error", + err: &os.SyscallError{Err: errors.New("connection reset by peer")}, + expected: true, + }, + { + name: "other error", + err: errors.New("other error"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isBrokenPipeError(tt.err) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/internal/app/api/core/middleware/recovery/options.go b/internal/app/api/core/middleware/recovery/options.go new file mode 100644 index 0000000..470e4b9 --- /dev/null +++ b/internal/app/api/core/middleware/recovery/options.go @@ -0,0 +1,129 @@ +package recovery + +import "net/http" + +// options is a struct that contains options for the recovery middleware. +// It uses the functional options pattern for flexible configuration. +type options struct { + logger Logger + useSlog bool + + errCallbackOverride bool + errCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request) + brokenPipeCallbackOverride bool + brokenPipeCallback func(err error, stack []byte, w http.ResponseWriter, r *http.Request) + + exposeStackTrace bool + defaultLogPrefix string + logCallbackOverride bool + logCallback func(err error, stack []byte, brokenPipe bool) +} + +// Option is a type that is used to set options for the recovery middleware. +// It implements the functional options pattern. +type Option func(*options) + +// WithErrCallback sets the error callback function for the recovery middleware. +// The error callback function is called when a panic is recovered by the middleware. +// This function completely overrides the default behavior of the middleware. It is the +// responsibility of the user to handle the error and write a response to the client. +// +// Ensure that this function does not panic, as it will be called in a deferred function! +func WithErrCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option { + return func(o *options) { + o.errCallback = fn + o.errCallbackOverride = true + } +} + +// WithBrokenPipeCallback sets the broken pipe callback function for the recovery middleware. +// The broken pipe callback function is called when a broken pipe error is recovered by the middleware. +// This function completely overrides the default behavior of the middleware. It is the responsibility +// of the user to handle the error and write a response to the client. +// +// Ensure that this function does not panic, as it will be called in a deferred function! +func WithBrokenPipeCallback(fn func(err error, stack []byte, w http.ResponseWriter, r *http.Request)) Option { + return func(o *options) { + o.brokenPipeCallback = fn + o.brokenPipeCallbackOverride = true + } +} + +// WithLogCallback sets the log callback function for the recovery middleware. +// The log callback function is called when a panic is recovered by the middleware. +// This function allows the user to log the error and stack trace. The default behavior is to log +// the error and stack trace in Error level. +// This function completely overrides the default behavior of the middleware. +// +// Ensure that this function does not panic, as it will be called in a deferred function! +func WithLogCallback(fn func(err error, stack []byte, brokenPipe bool)) Option { + return func(o *options) { + o.logCallback = fn + o.logCallbackOverride = true + } +} + +// WithLogger is a method that sets the logger for the logging middleware. +// If a logger is set, the logging middleware will use this logger to log messages. +// The default logger is the structured slog logger, see WithSlog. +func WithLogger(logger Logger) Option { + return func(o *options) { + o.logger = logger + } +} + +// WithSlog is a method that sets whether the recovery middleware should use the structured slog logger. +// If set to true, the middleware will use the structured slog logger. If set to false, the middleware +// will not use any logger unless one is explicitly set with the WithLogger option. +// The default value is true. +func WithSlog(useSlog bool) Option { + return func(o *options) { + o.useSlog = useSlog + } +} + +// WithDefaultLogPrefix is a method that sets the default log prefix for the recovery middleware. +// If a default log prefix is set and the default log callback is used, the prefix will be prepended +// to each log message. A space will be added between the prefix and the log message. +// The default value is an empty string. +func WithDefaultLogPrefix(defaultLogPrefix string) Option { + return func(o *options) { + o.defaultLogPrefix = defaultLogPrefix + } +} + +// WithExposeStackTrace is a method that sets whether the stack trace should be exposed in the response. +// If set to true, the stack trace will be included in the response body. If set to false, the stack trace +// will not be included in the response body. This only applies to the default error callback. +// The default value is false. +func WithExposeStackTrace(exposeStackTrace bool) Option { + return func(o *options) { + o.exposeStackTrace = exposeStackTrace + } +} + +// newOptions is a function that returns a new options struct with sane default values. +func newOptions(opts ...Option) options { + o := options{ + logger: nil, + useSlog: true, + errCallback: nil, + brokenPipeCallback: nil, // by default, ignore broken pipe errors + exposeStackTrace: false, + defaultLogPrefix: "", + logCallback: nil, + } + + for _, opt := range opts { + opt(&o) + } + + if o.errCallback == nil && !o.errCallbackOverride { + o.errCallback = getDefaultErrCallback(o) + } + if o.logCallback == nil && !o.logCallbackOverride { + o.logCallback = getDefaultLogCallback(o) + } + + return o +} diff --git a/internal/app/api/core/middleware/recovery/options_test.go b/internal/app/api/core/middleware/recovery/options_test.go new file mode 100644 index 0000000..82fe3e0 --- /dev/null +++ b/internal/app/api/core/middleware/recovery/options_test.go @@ -0,0 +1,100 @@ +package recovery + +import ( + "net/http" + "testing" +) + +func TestWithErrCallback(t *testing.T) { + callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {} + opt := WithErrCallback(callback) + o := newOptions(opt) + + if o.errCallback == nil { + t.Errorf("expected errCallback to be set, got nil") + } +} + +func TestWithBrokenPipeCallback(t *testing.T) { + callback := func(err error, stack []byte, w http.ResponseWriter, r *http.Request) {} + opt := WithBrokenPipeCallback(callback) + o := newOptions(opt) + + if o.brokenPipeCallback == nil { + t.Errorf("expected brokenPipeCallback to be set, got nil") + } +} + +func TestWithLogCallback(t *testing.T) { + callback := func(err error, stack []byte, brokenPipe bool) {} + opt := WithLogCallback(callback) + o := newOptions(opt) + + if o.logCallback == nil { + t.Errorf("expected logCallback to be set, got nil") + } +} + +func TestWithLogger(t *testing.T) { + logger := &mockLogger{} + opt := WithLogger(logger) + o := newOptions(opt) + + if o.logger != logger { + t.Errorf("expected logger to be %v, got %v", logger, o.logger) + } +} + +func TestWithSlog(t *testing.T) { + opt := WithSlog(false) + o := newOptions(opt) + + if o.useSlog != false { + t.Errorf("expected useSlog to be false, got %v", o.useSlog) + } +} + +func TestWithDefaultLogPrefix(t *testing.T) { + prefix := "PREFIX" + opt := WithDefaultLogPrefix(prefix) + o := newOptions(opt) + + if o.defaultLogPrefix != prefix { + t.Errorf("expected defaultLogPrefix to be %v, got %v", prefix, o.defaultLogPrefix) + } +} + +func TestWithExposeStackTrace(t *testing.T) { + opt := WithExposeStackTrace(true) + o := newOptions(opt) + + if o.exposeStackTrace != true { + t.Errorf("expected exposeStackTrace to be true, got %v", o.exposeStackTrace) + } +} + +func TestNewOptionsDefaults(t *testing.T) { + o := newOptions() + + if o.logger != nil { + t.Errorf("expected logger to be nil, got %v", o.logger) + } + if o.useSlog != true { + t.Errorf("expected useSlog to be true, got %v", o.useSlog) + } + if o.errCallback == nil { + t.Errorf("expected errCallback to be set, got nil") + } + if o.brokenPipeCallback != nil { + t.Errorf("expected brokenPipeCallback to be nil, got %T", o.brokenPipeCallback) + } + if o.exposeStackTrace != false { + t.Errorf("expected exposeStackTrace to be false, got %T", o.exposeStackTrace) + } + if o.defaultLogPrefix != "" { + t.Errorf("expected defaultLogPrefix to be empty, got %T", o.defaultLogPrefix) + } + if o.logCallback == nil { + t.Errorf("expected logCallback to be set, got nil") + } +} diff --git a/internal/app/api/core/middleware/tracing/middleware.go b/internal/app/api/core/middleware/tracing/middleware.go new file mode 100644 index 0000000..1c8be6b --- /dev/null +++ b/internal/app/api/core/middleware/tracing/middleware.go @@ -0,0 +1,69 @@ +package tracing + +import ( + "context" + "math/rand" + "net/http" +) + +// Middleware is a type that creates a new tracing middleware. The tracing middleware +// can be used to trace requests based on a request ID header or parameter. +type Middleware struct { + o options + + seededRand *rand.Rand +} + +// New returns a new CORS middleware with the provided options. +func New(opts ...Option) *Middleware { + o := newOptions(opts...) + + m := &Middleware{ + o: o, + seededRand: rand.New(rand.NewSource(o.generateSeed)), + } + + return m +} + +// Handler returns the tracing middleware handler. +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var reqId string + + // read upstream header und re-use it + if m.o.upstreamReqIdHeader != "" { + reqId = r.Header.Get(m.o.upstreamReqIdHeader) + } + + // generate new id + if reqId == "" && m.o.generateLength > 0 { + reqId = m.generateRandomId() + } + + // set response header + if m.o.headerIdentifier != "" { + w.Header().Set(m.o.headerIdentifier, reqId) + } + + // set context value + if m.o.contextIdentifier != "" { + ctx := context.WithValue(r.Context(), m.o.contextIdentifier, reqId) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) // execute the next handler + }) +} + +// region internal-helpers + +func (m *Middleware) generateRandomId() string { + b := make([]byte, m.o.generateLength) + for i := range b { + b[i] = m.o.generateCharset[m.seededRand.Intn(len(m.o.generateCharset))] + } + return string(b) +} + +// endregion internal-helpers diff --git a/internal/app/api/core/middleware/tracing/middleware_test.go b/internal/app/api/core/middleware/tracing/middleware_test.go new file mode 100644 index 0000000..614d406 --- /dev/null +++ b/internal/app/api/core/middleware/tracing/middleware_test.go @@ -0,0 +1,118 @@ +package tracing + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +const defaultLength = 8 +const upstreamHeaderValue = "upstream-id" + +func TestMiddleware_Handler_WithUpstreamHeader(t *testing.T) { + m := New(WithUpstreamHeader("X-Upstream-Id")) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := r.Header.Get("X-Upstream-Id") + if reqId != upstreamHeaderValue { + t.Errorf("expected upstream request id to be 'upstream-id', got %s", reqId) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Upstream-Id", upstreamHeaderValue) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Header().Get("X-Request-Id") != upstreamHeaderValue { + t.Errorf("expected X-Request-Id header to be set in the response") + } +} + +func TestMiddleware_Handler_GenerateNewId(t *testing.T) { + idLen := 18 + m := New(WithIdLength(idLen)) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := w.Header().Get("X-Request-Id") + if len(reqId) != 18 { + t.Errorf("expected generated request id length to be %d, got %d", idLen, len(reqId)) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Header().Get("X-Request-Id") == "" || len(rr.Header().Get("X-Request-Id")) != idLen { + t.Errorf("expected X-Request-Id header to be set in the response") + } +} + +func TestMiddleware_Handler_SetContextValue(t *testing.T) { + m := New() + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := r.Context().Value("RequestId").(string) + if reqId == "" || len(reqId) != defaultLength { + t.Errorf("expected context request id to be set, got empty string") + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) +} + +func TestMiddleware_Handler_SetCustomContextValue(t *testing.T) { + m := New(WithContextIdentifier("Custom-Id")) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := r.Context().Value("Custom-Id").(string) + if reqId == "" || len(reqId) != defaultLength { + t.Errorf("expected context request id to be set, got empty string") + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) +} + +func TestMiddleware_Handler_NoIdGenerated(t *testing.T) { + m := New(WithIdLength(0)) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := w.Header().Get("X-Request-Id") + if reqId != "" { + t.Errorf("expected no request id to be generated, got %s", reqId) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) +} + +func TestMiddleware_Handler_NoIdHeaderSet(t *testing.T) { + m := New(WithHeaderIdentifier("")) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := w.Header().Get("X-Request-Id") + if reqId != "" { + t.Errorf("expected no request id to be generated, got %s", reqId) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) +} + +func TestMiddleware_Handler_NoIdContextSet(t *testing.T) { + m := New(WithHeaderIdentifier("")) + handler := m.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqId := r.Context().Value("Request-Id") + if reqId != nil { + t.Errorf("expected no context request id to be set, got %v", reqId) + } + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) +} diff --git a/internal/app/api/core/middleware/tracing/options.go b/internal/app/api/core/middleware/tracing/options.go new file mode 100644 index 0000000..0b56db1 --- /dev/null +++ b/internal/app/api/core/middleware/tracing/options.go @@ -0,0 +1,85 @@ +package tracing + +import "time" + +// options is a struct that contains options for the tracing middleware. +// It uses the functional options pattern for flexible configuration. +type options struct { + upstreamReqIdHeader string + headerIdentifier string + contextIdentifier string + generateLength int + generateCharset string + generateSeed int64 +} + +// Option is a type that is used to set options for the tracing middleware. +// It implements the functional options pattern. +type Option func(*options) + +// WithIdSeed sets the seed for the random request id. +// If no seed is provided, the current timestamp is used. +func WithIdSeed(seed int64) Option { + return func(o *options) { + o.generateSeed = seed + } +} + +// WithIdCharset sets the charset that is used to generate a random request id. +// By default, upper-case letters and numbers are used. +func WithIdCharset(charset string) Option { + return func(o *options) { + o.generateCharset = charset + } +} + +// WithIdLength specifies the length of generated random ids. +// By default, a length of 8 is used. If the length is 0, no request id will be generated. +func WithIdLength(len int) Option { + return func(o *options) { + o.generateLength = len + } +} + +// WithHeaderIdentifier specifies the header name for the request id that is added to the response headers. +// If the identifier is empty, the request id will not be added to the response headers. +func WithHeaderIdentifier(identifier string) Option { + return func(o *options) { + o.headerIdentifier = identifier + } +} + +// WithUpstreamHeader sets the upstream header name, that should be used to fetch the request id. +// If no upstream header is found, a random id will be generated if the id-length parameter is set to a value > 0. +func WithUpstreamHeader(header string) Option { + return func(o *options) { + o.upstreamReqIdHeader = header + } +} + +// WithContextIdentifier specifies the value-key for the request id that is added to the request context. +// If the identifier is empty, the request id will not be added to the context. +// If the request id is added to the context, it can be retrieved with: +// `id := r.Context().Value(THE-IDENTIFIER).(string)` +func WithContextIdentifier(identifier string) Option { + return func(o *options) { + o.contextIdentifier = identifier + } +} + +// newOptions is a function that returns a new options struct with sane default values. +func newOptions(opts ...Option) options { + o := options{ + headerIdentifier: "X-Request-Id", + contextIdentifier: "RequestId", + generateSeed: time.Now().UnixNano(), + generateCharset: "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + generateLength: 8, + } + + for _, opt := range opts { + opt(&o) + } + + return o +} diff --git a/internal/app/api/core/middleware/tracing/options_test.go b/internal/app/api/core/middleware/tracing/options_test.go new file mode 100644 index 0000000..48649f9 --- /dev/null +++ b/internal/app/api/core/middleware/tracing/options_test.go @@ -0,0 +1,75 @@ +package tracing + +import ( + "testing" +) + +func TestWithIdSeed(t *testing.T) { + o := newOptions(WithIdSeed(12345)) + if o.generateSeed != 12345 { + t.Errorf("expected generateSeed to be 12345, got %d", o.generateSeed) + } +} + +func TestWithIdCharset(t *testing.T) { + o := newOptions(WithIdCharset("abc123")) + if o.generateCharset != "abc123" { + t.Errorf("expected generateCharset to be 'abc123', got %s", o.generateCharset) + } +} + +func TestWithIdLength(t *testing.T) { + o := newOptions(WithIdLength(16)) + if o.generateLength != 16 { + t.Errorf("expected generateLength to be 16, got %d", o.generateLength) + } +} + +func TestWithHeaderIdentifier(t *testing.T) { + o := newOptions(WithHeaderIdentifier("X-Custom-Id")) + if o.headerIdentifier != "X-Custom-Id" { + t.Errorf("expected headerIdentifier to be 'X-Custom-Id', got %s", o.headerIdentifier) + } +} + +func TestWithUpstreamHeader(t *testing.T) { + o := newOptions(WithUpstreamHeader("X-Upstream-Id")) + if o.upstreamReqIdHeader != "X-Upstream-Id" { + t.Errorf("expected upstreamReqIdHeader to be 'X-Upstream-Id', got %s", o.upstreamReqIdHeader) + } +} + +func TestWithContextIdentifier(t *testing.T) { + o := newOptions(WithContextIdentifier("Request-Id")) + if o.contextIdentifier != "Request-Id" { + t.Errorf("expected contextIdentifier to be 'Request-Id', got %s", o.contextIdentifier) + } +} + +func TestDefaults(t *testing.T) { + o := newOptions() + + if o.generateLength != 8 { + t.Errorf("expected generateLength to be 8, got %d", o.generateLength) + } + + if o.generateCharset != "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" { + t.Errorf("expected generateCharset to be 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', got %s", o.generateCharset) + } + + if o.generateSeed == 0 { + t.Errorf("expected generateSeed to be non-zero") + } + + if o.headerIdentifier != "X-Request-Id" { + t.Errorf("expected headerIdentifier to be 'X-Request-Id', got %s", o.headerIdentifier) + } + + if o.upstreamReqIdHeader != "" { + t.Errorf("expected upstreamReqIdHeader to be empty, got %s", o.upstreamReqIdHeader) + } + + if o.contextIdentifier != "RequestId" { + t.Errorf("expected contextIdentifier to be 'RequestId', got %s", o.contextIdentifier) + } +} diff --git a/internal/app/api/core/request/basic.go b/internal/app/api/core/request/basic.go new file mode 100644 index 0000000..fe162cb --- /dev/null +++ b/internal/app/api/core/request/basic.go @@ -0,0 +1,259 @@ +// Package request provides functions to extract parameters from the request. +package request + +import ( + "encoding/json" + "io" + "net" + "net/http" + "net/textproto" + "slices" + "strings" +) + +const CheckPrivateProxy = "PRIVATE" + +// PathRaw returns the value of the named path parameter. +func PathRaw(r *http.Request, name string) string { + return r.PathValue(name) +} + +// Path returns the value of the named path parameter. +// The return value is trimmed of leading and trailing whitespace. +func Path(r *http.Request, name string) string { + return strings.TrimSpace(PathRaw(r, name)) +} + +// PathDefault returns the value of the named path parameter. +// If the parameter is empty, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func PathDefault(r *http.Request, name string, defaultValue string) string { + value := r.PathValue(name) + if value == "" { + return defaultValue + } + + return Path(r, name) +} + +// QueryRaw returns the value of the named query parameter. +func QueryRaw(r *http.Request, name string) string { + return r.URL.Query().Get(name) +} + +// Query returns the value of the named query parameter. +// The return value is trimmed of leading and trailing whitespace. +func Query(r *http.Request, name string) string { + return strings.TrimSpace(QueryRaw(r, name)) +} + +// QueryDefault returns the value of the named query parameter. +// If the parameter is empty, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func QueryDefault(r *http.Request, name string, defaultValue string) string { + if !r.URL.Query().Has(name) { + return defaultValue + } + + return Query(r, name) +} + +// QuerySlice returns the value of the named query parameter. +// All slice values are trimmed of leading and trailing whitespace. +func QuerySlice(r *http.Request, name string) []string { + values, ok := r.URL.Query()[name] + if !ok { + return nil + } + + result := make([]string, len(values)) + for i, value := range values { + result[i] = strings.TrimSpace(value) + } + return result +} + +// QuerySliceDefault returns the value of the named query parameter. +// If the parameter is empty, it returns the default value. +// All slice values are trimmed of leading and trailing whitespace. +func QuerySliceDefault(r *http.Request, name string, defaultValue []string) []string { + if !r.URL.Query().Has(name) { + return defaultValue + } + + return QuerySlice(r, name) +} + +// FragmentRaw returns the value of the named fragment parameter. +func FragmentRaw(r *http.Request) string { + return r.URL.Fragment +} + +// Fragment returns the value of the named fragment parameter. +// The return value is trimmed of leading and trailing whitespace. +func Fragment(r *http.Request) string { + return strings.TrimSpace(FragmentRaw(r)) +} + +// FragmentDefault returns the value of the named fragment parameter. +// If the parameter is empty, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func FragmentDefault(r *http.Request, defaultValue string) string { + if r.URL.Fragment == "" { + return defaultValue + } + + return Fragment(r) +} + +// FormRaw returns the value of the named form parameter. +func FormRaw(r *http.Request, name string) string { + return r.FormValue(name) +} + +// Form returns the value of the named form parameter. +// The return value is trimmed of leading and trailing whitespace. +func Form(r *http.Request, name string) string { + return strings.TrimSpace(FormRaw(r, name)) +} + +// DefaultForm returns the value of the named form parameter. +// If the parameter is not set, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func DefaultForm(r *http.Request, name, defaultValue string) string { + err := r.ParseForm() + if err != nil { + return defaultValue + } + + if !r.Form.Has(name) { + return defaultValue + } + + return Form(r, name) +} + +// HeaderRaw returns the value of the named header. +func HeaderRaw(r *http.Request, name string) string { + return r.Header.Get(name) +} + +// Header returns the value of the named header. +// The return value is trimmed of leading and trailing whitespace. +func Header(r *http.Request, name string) string { + return strings.TrimSpace(HeaderRaw(r, name)) +} + +// HeaderDefault returns the value of the named header. +// If the header is not set, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func HeaderDefault(r *http.Request, name, defaultValue string) string { + if _, ok := textproto.MIMEHeader(r.Header)[name]; !ok { + return defaultValue + } + + return Header(r, name) +} + +// Cookie returns the value of the named cookie. +// The return value is trimmed of leading and trailing whitespace. +func Cookie(r *http.Request, name string) string { + cookie, err := r.Cookie(name) + if err != nil { + return "" + } + + return strings.TrimSpace(cookie.Value) +} + +// CookieDefault returns the value of the named cookie. +// If the cookie is not set, it returns the default value. +// The return value is trimmed of leading and trailing whitespace. +func CookieDefault(r *http.Request, name, defaultValue string) string { + cookie, err := r.Cookie(name) + if err != nil { + return defaultValue + } + + return strings.TrimSpace(cookie.Value) +} + +// ClientIp returns the client IP address. +// +// As the request may come from a proxy, the function checks the +// X-Real-Ip and X-Forwarded-For headers to get the real client IP +// if the request IP matches one of the allowed proxy IPs. +// If the special proxy value CheckPrivateProxy ("PRIVATE") is passed, the function will +// also check the header if the request IP is a private IP address. +func ClientIp(r *http.Request, allowedProxyIp ...string) string { + ipStr, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) + switch { + case err != nil && strings.Contains(err.Error(), "missing port in address"): + ipStr = strings.TrimSpace(r.RemoteAddr) + case err != nil: + ipStr = "" + } + IP := net.ParseIP(ipStr) + if IP == nil { + return "" + } + + isProxiedRequest := false + if len(allowedProxyIp) > 0 { + if slices.Contains(allowedProxyIp, IP.String()) { + isProxiedRequest = true + } + if IP.IsPrivate() && slices.Contains(allowedProxyIp, CheckPrivateProxy) { + isProxiedRequest = true + } + } + + if isProxiedRequest { + realClientIP := r.Header.Get("X-Real-Ip") + if realClientIP == "" { + realClientIP = r.Header.Get("X-Forwarded-For") + } + if realClientIP != "" { + realIpStr, _, err := net.SplitHostPort(strings.TrimSpace(realClientIP)) + switch { + case err != nil && strings.Contains(err.Error(), "missing port in address"): + realIpStr = realClientIP + case err != nil: + realIpStr = ipStr + } + realIP := net.ParseIP(realIpStr) + if realIP == nil { + return IP.String() + } + return realIP.String() + } + } + + return IP.String() +} + +// BodyJson decodes the JSON value from the request body into the target. +// The target must be a pointer to a struct or slice. +// The function returns an error if the JSON value could not be decoded. +// The body reader is closed after reading. +func BodyJson(r *http.Request, target any) error { + defer func() { + _ = r.Body.Close() + }() + return json.NewDecoder(r.Body).Decode(target) +} + +// BodyString returns the request body as a string. +// The content is read and returned as is, without any processing. +// The body is assumed to be UTF-8 encoded. +func BodyString(r *http.Request) (string, error) { + defer func() { + _ = r.Body.Close() + }() + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return "", err + } + return string(bodyBytes), nil +} diff --git a/internal/app/api/core/request/basic_test.go b/internal/app/api/core/request/basic_test.go new file mode 100644 index 0000000..7772b4f --- /dev/null +++ b/internal/app/api/core/request/basic_test.go @@ -0,0 +1,221 @@ +package request + +import ( + "io" + "net/http" + "net/url" + "slices" + "strings" + "testing" +) + +func TestPath(t *testing.T) { + r := &http.Request{URL: &url.URL{Path: "/test/sample"}} + r.SetPathValue("first", "test") + if got := Path(r, "first"); got != "test" { + t.Errorf("Path() = %v, want %v", got, "test") + } +} + +func TestDefaultPath(t *testing.T) { + r := &http.Request{URL: &url.URL{Path: "/"}} + if got := PathDefault(r, "test", "default"); got != "default" { + t.Errorf("PathDefault() = %v, want %v", got, "default") + } +} + +func TestDefaultPath_noDefault(t *testing.T) { + r := &http.Request{URL: &url.URL{Path: "/"}} + r.SetPathValue("first", "test") + if got := PathDefault(r, "first", "test"); got != "test" { + t.Errorf("PathDefault() = %v, want %v", got, "test") + } +} + +func TestQuery(t *testing.T) { + r := &http.Request{URL: &url.URL{RawQuery: "name=value"}} + if got := Query(r, "name"); got != "value" { + t.Errorf("Query() = %v, want %v", got, "value") + } +} + +func TestDefaultQuery(t *testing.T) { + r := &http.Request{URL: &url.URL{RawQuery: ""}} + if got := QueryDefault(r, "name", "default"); got != "default" { + t.Errorf("QueryDefault() = %v, want %v", got, "default") + } +} + +func TestQuerySlice(t *testing.T) { + r := &http.Request{URL: &url.URL{RawQuery: "name=value1 &name=value2"}} + expected := []string{"value1", "value2"} + if got := QuerySlice(r, "name"); !slices.Equal(got, expected) { + t.Errorf("QuerySlice() = %v, want %v", got, expected) + } +} + +func TestQuerySlice_empty(t *testing.T) { + r := &http.Request{URL: &url.URL{RawQuery: "name=value1&name=value2"}} + if got := QuerySlice(r, "nix"); !slices.Equal(got, nil) { + t.Errorf("QuerySlice() = %v, want %v", got, nil) + } +} + +func TestDefaultQuerySlice(t *testing.T) { + r := &http.Request{URL: &url.URL{RawQuery: ""}} + defaultValue := []string{"default1", "default2"} + if got := QuerySliceDefault(r, "name", defaultValue); !slices.Equal(got, defaultValue) { + t.Errorf("QuerySliceDefault() = %v, want %v", got, defaultValue) + } +} + +func TestFragment(t *testing.T) { + r := &http.Request{URL: &url.URL{Fragment: "section"}} + if got := Fragment(r); got != "section" { + t.Errorf("Fragment() = %v, want %v", got, "section") + } +} + +func TestDefaultFragment(t *testing.T) { + r := &http.Request{URL: &url.URL{Fragment: ""}} + if got := FragmentDefault(r, "default"); got != "default" { + t.Errorf("FragmentDefault() = %v, want %v", got, "default") + } +} + +func TestForm(t *testing.T) { + r := &http.Request{Form: url.Values{"name": {"value"}}} + if got := Form(r, "name"); got != "value" { + t.Errorf("Form() = %v, want %v", got, "value") + } +} + +func TestDefaultForm(t *testing.T) { + r := &http.Request{Form: url.Values{}} + if got := DefaultForm(r, "name", "default"); got != "default" { + t.Errorf("DefaultForm() = %v, want %v", got, "default") + } +} + +func TestHeader(t *testing.T) { + r := &http.Request{Header: http.Header{"X-Test-Header": {"value"}}} + if got := Header(r, "X-Test-Header"); got != "value" { + t.Errorf("Header() = %v, want %v", got, "value") + } +} + +func TestDefaultHeader(t *testing.T) { + r := &http.Request{Header: http.Header{}} + if got := HeaderDefault(r, "X-Test-Header", "default"); got != "default" { + t.Errorf("HeaderDefault() = %v, want %v", got, "default") + } +} + +func TestCookie(t *testing.T) { + r := &http.Request{Header: http.Header{"Cookie": {"name=value"}}} + if got := Cookie(r, "name"); got != "value" { + t.Errorf("Cookie() = %v, want %v", got, "value") + } +} + +func TestDefaultCookie(t *testing.T) { + r := &http.Request{Header: http.Header{}} + if got := CookieDefault(r, "name", "default"); got != "default" { + t.Errorf("CookieDefault() = %v, want %v", got, "default") + } +} + +func TestClientIp(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.1.1:12345"} + if got := ClientIp(r); got != "192.168.1.1" { + t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1") + } +} + +func TestClientIp_invalid(t *testing.T) { + r := &http.Request{RemoteAddr: "was_isn_des"} + if got := ClientIp(r); got != "" { + t.Errorf("ClientIp() = %v, want %v", got, "") + } +} + +func TestClientIp_ignore_header(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}} + if got := ClientIp(r); got != "192.168.1.1" { + t.Errorf("ClientIp() = %v, want %v", got, "192.168.1.1") + } +} + +func TestClientIp_header1(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Forwarded-For": {"123.45.67.1"}}} + if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" { + t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1") + } +} + +func TestClientIp_header2(t *testing.T) { + r := &http.Request{RemoteAddr: "192.168.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}} + if got := ClientIp(r, CheckPrivateProxy); got != "123.45.67.1" { + t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1") + } +} + +func TestClientIp_header3(t *testing.T) { + r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}} + if got := ClientIp(r, "1.1.1.1"); got != "123.45.67.1" { + t.Errorf("ClientIp() = %v, want %v", got, "123.45.67.1") + } +} + +func TestClientIp_header4(t *testing.T) { + r := &http.Request{RemoteAddr: "8.8.8.8:12345", Header: http.Header{"X-Real-Ip": {"123.45.67.1"}}} + if got := ClientIp(r, "1.1.1.1"); got != "8.8.8.8" { + t.Errorf("ClientIp() = %v, want %v", got, "8.8.8.8") + } +} + +func TestClientIp_header_invalid(t *testing.T) { + r := &http.Request{RemoteAddr: "1.1.1.1:12345", Header: http.Header{"X-Real-Ip": {"so-sicher-nit"}}} + if got := ClientIp(r, "1.1.1.1"); got != "1.1.1.1" { + t.Errorf("ClientIp() = %v, want %v", got, "1.1.1.1") + } +} + +func TestBodyJson(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + Value int `json:"value"` + } + + jsonStr := `{"name": "test", "value": 123}` + r := &http.Request{ + Body: io.NopCloser(strings.NewReader(jsonStr)), + } + + var result TestStruct + err := BodyJson(r, &result) + if err != nil { + t.Fatalf("BodyJson() error = %v", err) + } + + expected := TestStruct{Name: "test", Value: 123} + if result != expected { + t.Errorf("BodyJson() = %v, want %v", result, expected) + } +} + +func TestBodyString(t *testing.T) { + bodyStr := "test body content" + r := &http.Request{ + Body: io.NopCloser(strings.NewReader(bodyStr)), + } + + result, err := BodyString(r) + if err != nil { + t.Fatalf("BodyString() error = %v", err) + } + + if result != bodyStr { + t.Errorf("BodyString() = %v, want %v", result, bodyStr) + } +} diff --git a/internal/app/api/core/respond/basic.go b/internal/app/api/core/respond/basic.go new file mode 100644 index 0000000..64c07ff --- /dev/null +++ b/internal/app/api/core/respond/basic.go @@ -0,0 +1,100 @@ +// Package respond provides a set of utility functions to help with the HTTP response handling. +package respond + +import ( + "encoding/json" + "io" + "net/http" + "strconv" +) + +// Status writes a response with the given status code. +// The response will not contain any data. +func Status(w http.ResponseWriter, code int) { + w.WriteHeader(code) +} + +// String writes a plain text response with the given status code and data. +// The Content-Type header is set to text/plain with a charset of utf-8. +func String(w http.ResponseWriter, code int, data string) { + w.Header().Set("Content-Type", "text/plain;charset=utf-8") + w.WriteHeader(code) + + _, _ = w.Write([]byte(data)) +} + +// JSON writes a JSON response with the given status code and data. +// If data is nil, the response will null. The status code is set to the given code. +// The Content-Type header is set to application/json. +// If the given data is not JSON serializable, the response will not contain any data. +// All encoding errors are silently ignored. +func JSON(w http.ResponseWriter, code int, data any) { + w.Header().Set("Content-Type", "application/json") + + // if no data was given, simply return null + if data == nil { + w.WriteHeader(code) + _, _ = w.Write([]byte("null")) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + + _ = json.NewEncoder(w).Encode(data) +} + +// Data writes a response with the given status code, content type, and data. +// If no content type is provided, it is detected from the data. +func Data(w http.ResponseWriter, code int, contentType string, data []byte) { + if contentType == "" { + contentType = http.DetectContentType(data) // ensure content type is set + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.WriteHeader(code) + + _, _ = w.Write(data) +} + +// Reader writes a response with the given status code, content type, and data. +// The content length is optional, it is only set if the given length is greater than 0. +func Reader(w http.ResponseWriter, code int, contentType string, contentLength int, data io.Reader) { + w.Header().Set("Content-Type", contentType) + if contentLength > 0 { + w.Header().Set("Content-Length", strconv.Itoa(contentLength)) + } + w.WriteHeader(code) + + _, _ = io.Copy(w, data) +} + +// Attachment writes a response with the given status code, content type, filename, and data. +// If no content type is provided, it is detected from the data. +func Attachment(w http.ResponseWriter, code int, filename, contentType string, data []byte) { + w.Header().Set("Content-Disposition", "attachment; filename="+filename) + + Data(w, code, contentType, data) +} + +// AttachmentReader writes a response with the given status code, content type, filename, content length, and data. +// The content length is optional, it is only set if the given length is greater than 0. +func AttachmentReader( + w http.ResponseWriter, + code int, + filename, contentType string, + contentLength int, + data io.Reader, +) { + w.Header().Set("Content-Disposition", "attachment; filename="+filename) + + Reader(w, code, contentType, contentLength, data) +} + +// Redirect writes a response with the given status code and redirects to the given URL. +// The redirect url will always be an absolute URL. If the given URL is relative, +// the original request URL is used as the base. +func Redirect(w http.ResponseWriter, r *http.Request, code int, url string) { + http.Redirect(w, r, url, code) +} diff --git a/internal/app/api/core/respond/basic_test.go b/internal/app/api/core/respond/basic_test.go new file mode 100644 index 0000000..06b3817 --- /dev/null +++ b/internal/app/api/core/respond/basic_test.go @@ -0,0 +1,273 @@ +package respond + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" +) + +func TestStatus(t *testing.T) { + rec := httptest.NewRecorder() + Status(rec, http.StatusNoContent) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + t.Errorf("expected status %d, got %d", http.StatusNoContent, res.StatusCode) + } + + body, _ := io.ReadAll(res.Body) + if len(body) != 0 { + t.Errorf("expected no body, got %s", body) + } +} + +func TestString(t *testing.T) { + rec := httptest.NewRecorder() + String(rec, http.StatusOK, "Hello, World!") + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/plain;charset=utf-8" { + t.Errorf("expected content type %s, got %s", "text/plain;charset=utf-8", contentType) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello, World!" { + t.Errorf("expected body %s, got %s", "Hello, World!", string(body)) + } +} + +func TestJSON(t *testing.T) { + rec := httptest.NewRecorder() + data := map[string]string{"hello": "world"} + JSON(rec, http.StatusOK, data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "application/json" { + t.Errorf("expected content type %s, got %s", "application/json", contentType) + } + + var body map[string]string + _ = json.NewDecoder(res.Body).Decode(&body) + if body["hello"] != "world" { + t.Errorf("expected body %v, got %v", data, body) + } +} + +func TestJSON_empty(t *testing.T) { + rec := httptest.NewRecorder() + JSON(rec, http.StatusOK, nil) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "application/json" { + t.Errorf("expected content type %s, got %s", "application/json", contentType) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "null" { + t.Errorf("expected body %s, got %s", "null", body) + } +} + +func TestData(t *testing.T) { + rec := httptest.NewRecorder() + data := []byte("Hello, World!") + Data(rec, http.StatusOK, "text/plain", data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" { + t.Errorf("expected content type %s, got %s", "text/plain", contentType) + } + + body, _ := io.ReadAll(res.Body) + if !bytes.Equal(body, data) { + t.Errorf("expected body %s, got %s", data, body) + } +} + +func TestData_noContentType(t *testing.T) { + rec := httptest.NewRecorder() + data := []byte{0x1, 0x2, 0x3, 0x4, 0x5} + Data(rec, http.StatusOK, "", data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "application/octet-stream" { + t.Errorf("expected content type %s, got %s", "application/octet-stream", contentType) + } + + body, _ := io.ReadAll(res.Body) + if !bytes.Equal(body, data) { + t.Errorf("expected body %s, got %s", data, body) + } +} + +func TestReader(t *testing.T) { + rec := httptest.NewRecorder() + data := []byte("Hello, World!") + reader := bytes.NewBufferString(string(data)) + Reader(rec, http.StatusOK, "text/plain", len(data), reader) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" { + t.Errorf("expected content type %s, got %s", "text/plain", contentType) + } + + if contentLength := res.Header.Get("Content-Length"); contentLength != strconv.Itoa(len(data)) { + t.Errorf("expected content length %d, got %s", len(data), contentLength) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello, World!" { + t.Errorf("expected body %s, got %s", "Hello, World!", string(body)) + } +} + +func TestReader_unknownLength(t *testing.T) { + rec := httptest.NewRecorder() + data := bytes.NewBufferString("Hello, World!") + Reader(rec, http.StatusOK, "text/plain", 0, data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" { + t.Errorf("expected content type %s, got %s", "text/plain", contentType) + } + + if contentLength := res.Header.Get("Content-Length"); contentLength != "" { + t.Errorf("expected no content length, got %s", contentLength) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello, World!" { + t.Errorf("expected body %s, got %s", "Hello, World!", string(body)) + } +} + +func TestAttachment(t *testing.T) { + rec := httptest.NewRecorder() + data := []byte("Hello, World!") + Attachment(rec, http.StatusOK, "example.txt", "text/plain", data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" { + t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition) + } + + body, _ := io.ReadAll(res.Body) + if !bytes.Equal(body, data) { + t.Errorf("expected body %s, got %s", data, body) + } +} + +func TestAttachmentReader(t *testing.T) { + rec := httptest.NewRecorder() + data := bytes.NewBufferString("Hello, World!") + AttachmentReader(rec, http.StatusOK, "example.txt", "text/plain", data.Len(), data) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "attachment; filename=example.txt" { + t.Errorf("expected content disposition %s, got %s", "attachment; filename=example.txt", contentDisposition) + } + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello, World!" { + t.Errorf("expected body %s, got %s", "Hello, World!", string(body)) + } +} + +func TestRedirect(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/old", nil) + url := "http://example.com/new" + + Redirect(rec, req, http.StatusMovedPermanently, url) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusMovedPermanently { + t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode) + } + + if location := res.Header.Get("Location"); location != url { + t.Errorf("expected location %s, got %s", url, location) + } +} + +func TestRedirect_relative(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/old/dir", nil) + url := "newlocation/sub" + want := "/old/newlocation/sub" + + Redirect(rec, req, http.StatusMovedPermanently, url) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusMovedPermanently { + t.Errorf("expected status %d, got %d", http.StatusMovedPermanently, res.StatusCode) + } + + if location := res.Header.Get("Location"); location != want { + t.Errorf("expected location %s, got %s", want, location) + } +} diff --git a/internal/app/api/core/respond/template.go b/internal/app/api/core/respond/template.go new file mode 100644 index 0000000..65796f3 --- /dev/null +++ b/internal/app/api/core/respond/template.go @@ -0,0 +1,46 @@ +package respond + +import ( + "fmt" + "io" + "net/http" +) + +// TplData is a map of template data. This is a convenience type for passing data to templates. +type TplData map[string]any + +// TemplateInstance is an interface that wraps the ExecuteTemplate method. +// It is implemented by the html/template and text/template packages. +type TemplateInstance interface { + // ExecuteTemplate executes a template with the given name and data. + ExecuteTemplate(wr io.Writer, name string, data any) error +} + +// TemplateRenderer is a renderer that uses a template instance to render HTML or Text templates. +type TemplateRenderer struct { + t TemplateInstance +} + +// NewTemplateRenderer creates a new HTML or Text template renderer with the given template instance. +func NewTemplateRenderer(t TemplateInstance) *TemplateRenderer { + return &TemplateRenderer{t: t} +} + +// Render renders a template with the given name and data. +// If rendering fails, it will panic with an error. +func (r *TemplateRenderer) Render(w http.ResponseWriter, code int, name, contentType string, data any) { + w.Header().Set("Content-Type", contentType) + w.WriteHeader(code) + + err := r.t.ExecuteTemplate(w, name, data) + if err != nil { + panic(fmt.Errorf("error rendering template %s: %v", name, err)) + } +} + +// HTML renders a template with the given name and data. It is a convenience method for Render. +// The content type is set to "text/html" and the encoding to "utf-8". +// If rendering fails, it will panic with an error. +func (r *TemplateRenderer) HTML(w http.ResponseWriter, code int, name string, data any) { + r.Render(w, code, name, "text/html;charset=utf-8", data) +} diff --git a/internal/app/api/core/respond/template_test.go b/internal/app/api/core/respond/template_test.go new file mode 100644 index 0000000..110d895 --- /dev/null +++ b/internal/app/api/core/respond/template_test.go @@ -0,0 +1,67 @@ +package respond + +import ( + "html/template" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +type mockTemplate struct { + tmpl *template.Template +} + +func (m *mockTemplate) ExecuteTemplate(wr io.Writer, name string, data any) error { + return m.tmpl.ExecuteTemplate(wr, name, data) +} + +func TestTemplateRenderer_Render(t *testing.T) { + tmpl := template.Must(template.New("test").Parse(`{{define "test"}}Hello, {{.}}!{{end}}`)) + renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl}) + + rec := httptest.NewRecorder() + renderer.Render(rec, http.StatusOK, "test", "text/plain", "World") + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/plain" { + t.Errorf("expected content type %s, got %s", "text/plain", contentType) + } + + body, _ := io.ReadAll(res.Body) + expectedBody := "Hello, World!" + if string(body) != expectedBody { + t.Errorf("expected body %s, got %s", expectedBody, string(body)) + } +} + +func TestTemplateRenderer_HTML(t *testing.T) { + tmpl := template.Must(template.New("test").Parse(`{{define "test"}}

Hello, {{.}}!

{{end}}`)) + renderer := NewTemplateRenderer(&mockTemplate{tmpl: tmpl}) + + rec := httptest.NewRecorder() + renderer.HTML(rec, http.StatusOK, "test", "World") + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, res.StatusCode) + } + + if contentType := res.Header.Get("Content-Type"); contentType != "text/html;charset=utf-8" { + t.Errorf("expected content type %s, got %s", "text/html;charset=utf-8", contentType) + } + + body, _ := io.ReadAll(res.Body) + expectedBody := "

Hello, World!

" + if string(body) != expectedBody { + t.Errorf("expected body %s, got %s", expectedBody, string(body)) + } +} diff --git a/internal/app/api/core/server.go b/internal/app/api/core/server.go index 441011e..fe0d1e6 100644 --- a/internal/app/api/core/server.go +++ b/internal/app/api/core/server.go @@ -2,27 +2,25 @@ package core import ( "context" - "encoding/base64" "fmt" "html/template" - "io" "io/fs" "log/slog" - "math/rand" "net/http" "os" "time" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/cors" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/logging" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/recovery" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/tracing" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/config" ) -var ( - random = rand.New(rand.NewSource(time.Now().UTC().UnixNano())) -) - const ( RequestIDKey = "X-Request-ID" ) @@ -30,19 +28,21 @@ const ( type ApiVersion string type HandlerName string -type GroupSetupFn func(group *gin.RouterGroup) +type GroupSetupFn func(group *routegroup.Bundle) type ApiEndpointSetupFunc func() (ApiVersion, GroupSetupFn) type Server struct { cfg *config.Config - server *gin.Engine - versions map[ApiVersion]*gin.RouterGroup + server *routegroup.Bundle + tpl *respond.TemplateRenderer + versions map[ApiVersion]*routegroup.Bundle } func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, error) { s := &Server{ - cfg: cfg, + cfg: cfg, + server: routegroup.New(http.NewServeMux()), } hostname, err := os.Hostname() @@ -51,69 +51,39 @@ func NewServer(cfg *config.Config, endpoints ...ApiEndpointSetupFunc) (*Server, } hostname += ", version " + internal.Version - // Setup http server - gin.SetMode(gin.ReleaseMode) - gin.DefaultWriter = io.Discard - s.server = gin.New() - + s.server.Use(recovery.New().Handler) if cfg.Web.RequestLogging { - if cfg.Advanced.LogLevel == "trace" { - gin.SetMode(gin.DebugMode) - } - s.server.Use(func(c *gin.Context) { - start := time.Now() - path := c.Request.URL.Path - raw := c.Request.URL.RawQuery + s.server.Use(logging.New(logging.WithLevel(logging.LogLevelDebug)).Handler) - c.Next() - - if raw != "" { - path = path + "?" + raw - } - - latency := time.Since(start) - status := c.Writer.Status() - clientIP := c.ClientIP() - method := c.Request.Method - errorMsg := c.Errors.ByType(gin.ErrorTypePrivate).String() - - slog.Debug("HTTP Request", - "status", status, - "latency", latency, - "client", clientIP, - "method", method, - "path", path, - "error", errorMsg, - ) + } + s.server.Use(cors.New().Handler) + s.server.Use(tracing.New( + tracing.WithContextIdentifier(RequestIDKey), + tracing.WithHeaderIdentifier(RequestIDKey), + ).Handler) + if cfg.Web.ExposeHostInfo { + s.server.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Served-By", hostname) + handler.ServeHTTP(w, r) + }) }) } - s.server.Use(gin.Recovery()).Use(func(c *gin.Context) { - c.Writer.Header().Set("X-Served-By", hostname) - c.Next() - }).Use(func(c *gin.Context) { - xRequestID := uuid(16) - - c.Request.Header.Set(RequestIDKey, xRequestID) - c.Set(RequestIDKey, xRequestID) - c.Next() - }) - // Setup templates - templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(apiTemplates, "assets/tpl/*.gohtml")) - s.server.SetHTMLTemplate(templates) + s.tpl = respond.NewTemplateRenderer( + template.Must(template.New("").ParseFS(apiTemplates, "assets/tpl/*.gohtml")), + ) // Serve static files imgFs := http.FS(fsMust(fs.Sub(apiStatics, "assets/img"))) - s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css")))) - s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js")))) - s.server.StaticFS("/img", imgFs) - s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts")))) - s.server.StaticFS("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc")))) + s.server.HandleFiles("/css", http.FS(fsMust(fs.Sub(apiStatics, "assets/css")))) + s.server.HandleFiles("/js", http.FS(fsMust(fs.Sub(apiStatics, "assets/js")))) + s.server.HandleFiles("/img", imgFs) + s.server.HandleFiles("/fonts", http.FS(fsMust(fs.Sub(apiStatics, "assets/fonts")))) + s.server.HandleFiles("/doc", http.FS(fsMust(fs.Sub(apiStatics, "assets/doc")))) // Setup routes - s.server.UseRawPath = true - s.server.UnescapePathValues = true s.setupRoutes(endpoints...) s.setupFrontendRoutes() @@ -136,9 +106,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) { err = srv.ListenAndServe() } if err != nil { - slog.Info("web service exited", - "address", listenAddress, - "error", err) + slog.Info("web service exited", "address", listenAddress, "error", err) cancelFn() } }() @@ -157,18 +125,18 @@ func (s *Server) Run(ctx context.Context, listenAddress string) { } func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) { - s.server.GET("/api", s.landingPage) - s.versions = make(map[ApiVersion]*gin.RouterGroup) + s.server.HandleFunc("GET /api", s.landingPage) + s.versions = make(map[ApiVersion]*routegroup.Bundle) for _, setupFunc := range endpoints { version, groupSetupFn := setupFunc() if _, ok := s.versions[version]; !ok { - s.versions[version] = s.server.Group(fmt.Sprintf("/api/%s", version)) + s.versions[version] = s.server.Mount(fmt.Sprintf("/api/%s", version)) // OpenAPI documentation (via RapiDoc) - s.versions[version].GET("/swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link - s.versions[version].GET("/doc.html", s.rapiDocHandler(version)) + s.versions[version].HandleFunc("GET /swagger/index.html", s.rapiDocHandler(version)) // Deprecated: old link + s.versions[version].HandleFunc("GET /doc.html", s.rapiDocHandler(version)) groupSetupFn(s.versions[version]) } @@ -177,25 +145,27 @@ func (s *Server) setupRoutes(endpoints ...ApiEndpointSetupFunc) { func (s *Server) setupFrontendRoutes() { // Serve static files - s.server.GET("/", func(c *gin.Context) { - c.Redirect(http.StatusMovedPermanently, "/app") + s.server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + respond.Redirect(w, r, http.StatusMovedPermanently, "/app") }) - s.server.GET("/favicon.ico", func(c *gin.Context) { - c.Redirect(http.StatusMovedPermanently, "/app/favicon.ico") + + s.server.HandleFunc("/favicon.ico", func(w http.ResponseWriter, r *http.Request) { + respond.Redirect(w, r, http.StatusMovedPermanently, "/app/favicon.ico") }) - s.server.StaticFS("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist")))) + + s.server.HandleFiles("/app", http.FS(fsMust(fs.Sub(frontendStatics, "frontend-dist")))) } -func (s *Server) landingPage(c *gin.Context) { - c.HTML(http.StatusOK, "index.gohtml", gin.H{ +func (s *Server) landingPage(w http.ResponseWriter, _ *http.Request) { + s.tpl.HTML(w, http.StatusOK, "index.gohtml", respond.TplData{ "Version": internal.Version, "Year": time.Now().Year(), }) } -func (s *Server) rapiDocHandler(version ApiVersion) gin.HandlerFunc { - return func(c *gin.Context) { - c.HTML(http.StatusOK, "rapidoc.gohtml", gin.H{ +func (s *Server) rapiDocHandler(version ApiVersion) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + s.tpl.HTML(w, http.StatusOK, "rapidoc.gohtml", respond.TplData{ "RapiDocSource": "/js/rapidoc-min.js", "ApiSpecUrl": fmt.Sprintf("/doc/%s_swagger.yaml", version), "Version": internal.Version, @@ -210,9 +180,3 @@ func fsMust(f fs.FS, err error) fs.FS { } return f } - -func uuid(len int) string { - bytes := make([]byte, len) - random.Read(bytes) - return base64.StdEncoding.EncodeToString(bytes)[:len] -} diff --git a/internal/app/api/v0/handlers/base.go b/internal/app/api/v0/handlers/base.go index ecbe5aa..bb32b5e 100644 --- a/internal/app/api/v0/handlers/base.go +++ b/internal/app/api/v0/handlers/base.go @@ -1,24 +1,46 @@ package handlers import ( + "context" "net/http" - "strings" - "github.com/gin-contrib/cors" - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-gonic/gin" - csrf "github.com/utrack/gin-csrf" + "github.com/go-pkgz/routegroup" - "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app/api/core" - "github.com/h44z/wg-portal/internal/app/api/v0/model" - "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/cors" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/csrf" + "github.com/h44z/wg-portal/internal/app/api/core/respond" ) -type handler interface { +type SessionMiddleware interface { + // SetData sets the session data for the given context. + SetData(ctx context.Context, val SessionData) + // GetData returns the session data for the given context. If no data is found, the default session data is returned. + GetData(ctx context.Context) SessionData + // DestroyData destroys the session data for the given context. + DestroyData(ctx context.Context) + + // GetString returns the string value for the given key. If no value is found, an empty string is returned. + GetString(ctx context.Context, key string) string + // Put sets the value for the given key. + Put(ctx context.Context, key string, value any) + // LoadAndSave is a middleware that loads the session data for the given request and saves it after the request is + // finished. + LoadAndSave(next http.Handler) http.Handler +} + +type Handler interface { + // GetName returns the name of the handler. GetName() string - RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) + // RegisterRoutes registers the routes for the handler. The session manager is passed to the handler. + RegisterRoutes(g *routegroup.Bundle) +} + +type Authenticator interface { + // LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. + LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler + // UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted. + UserIdMatch(idParameter string) func(next http.Handler) http.Handler } // To compile the API documentation use the @@ -35,54 +57,33 @@ type handler interface { // @BasePath /api/v0 // @query.collection.format multi -func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc { - authenticator := &authenticationHandler{ - app: app, - Session: GinSessionStore{sessionIdentifier: cfg.Web.SessionIdentifier}, - } - - handlers := make([]handler, 0, 1) - handlers = append(handlers, testEndpoint{}) - handlers = append(handlers, userEndpoint{app: app, authenticator: authenticator}) - handlers = append(handlers, newConfigEndpoint(app, authenticator)) - handlers = append(handlers, authEndpoint{app: app, authenticator: authenticator}) - handlers = append(handlers, interfaceEndpoint{app: app, authenticator: authenticator}) - handlers = append(handlers, peerEndpoint{app: app, authenticator: authenticator}) - +func NewRestApi( + session SessionMiddleware, + handlers ...Handler, +) core.ApiEndpointSetupFunc { return func() (core.ApiVersion, core.GroupSetupFn) { - return "v0", func(group *gin.RouterGroup) { - cookieStore := memstore.NewStore([]byte(cfg.Web.SessionSecret)) - cookieStore.Options(sessions.Options{ - Path: "/", - MaxAge: 86400, // auth session is valid for 1 day - Secure: strings.HasPrefix(cfg.Web.ExternalUrl, "https"), - HttpOnly: true, - SameSite: http.SameSiteLaxMode, + return "v0", func(group *routegroup.Bundle) { + csrfMiddleware := csrf.New(func(r *http.Request) string { + return session.GetString(r.Context(), "csrf_token") + }, func(r *http.Request, token string) { + session.Put(r.Context(), "csrf_token", token) }) - group.Use(sessions.Sessions(cfg.Web.SessionIdentifier, cookieStore)) - group.Use(cors.Default()) - group.Use(csrf.Middleware(csrf.Options{ - Secret: cfg.Web.CsrfSecret, - ErrorFunc: func(c *gin.Context) { - c.JSON(http.StatusBadRequest, model.Error{ - Code: http.StatusBadRequest, - Message: "CSRF token mismatch", - }) - c.Abort() - }, - })) - group.GET("/csrf", handleCsrfGet()) + group.Use(session.LoadAndSave) + group.Use(csrfMiddleware.Handler) + group.Use(cors.New().Handler) + + group.With(csrfMiddleware.RefreshToken).HandleFunc("GET /csrf", handleCsrfGet()) // Handler functions for _, h := range handlers { - h.RegisterRoutes(group, authenticator) + h.RegisterRoutes(group) } } } } -// handleCsrfGet returns a gorm handler function. +// handleCsrfGet returns a gorm Handler function. // // @ID base_handleCsrfGet // @Tags Security @@ -90,8 +91,12 @@ func NewRestApi(cfg *config.Config, app *app.App) core.ApiEndpointSetupFunc { // @Produce json // @Success 200 {object} string // @Router /csrf [get] -func handleCsrfGet() gin.HandlerFunc { - return func(c *gin.Context) { - c.JSON(http.StatusOK, csrf.GetToken(c)) +func handleCsrfGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + respond.JSON(w, http.StatusOK, csrf.GetToken(r.Context())) } } + +// region session wrapper + +// endregion session wrapper diff --git a/internal/app/api/v0/handlers/endpoint_authentication.go b/internal/app/api/v0/handlers/endpoint_authentication.go index ce1b975..17fd3f8 100644 --- a/internal/app/api/v0/handlers/endpoint_authentication.go +++ b/internal/app/api/v0/handlers/endpoint_authentication.go @@ -8,36 +8,62 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" "github.com/h44z/wg-portal/internal/domain" ) -type authEndpoint struct { - app *app.App - authenticator *authenticationHandler +type Session interface { + // SetData sets the session data for the given context. + SetData(ctx context.Context, val SessionData) + // GetData returns the session data for the given context. If no data is found, the default session data is returned. + GetData(ctx context.Context) SessionData + // DestroyData destroys the session data for the given context. + DestroyData(ctx context.Context) } -func (e authEndpoint) GetName() string { +type Validator interface { + Struct(s interface{}) error +} + +type AuthEndpoint struct { + app *app.App + authenticator Authenticator + session Session + validate Validator +} + +func NewAuthEndpoint(app *app.App, authenticator Authenticator, session Session, validator Validator) AuthEndpoint { + return AuthEndpoint{ + app: app, + authenticator: authenticator, + session: session, + validate: validator, + } +} + +func (e AuthEndpoint) GetName() string { return "AuthEndpoint" } -func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/auth") +func (e AuthEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/auth") - apiGroup.GET("/providers", e.handleExternalLoginProvidersGet()) - apiGroup.GET("/session", e.handleSessionInfoGet()) + apiGroup.HandleFunc("GET /providers", e.handleExternalLoginProvidersGet()) + apiGroup.HandleFunc("GET /session", e.handleSessionInfoGet()) - apiGroup.GET("/login/:provider/init", e.handleOauthInitiateGet()) - apiGroup.GET("/login/:provider/callback", e.handleOauthCallbackGet()) + apiGroup.HandleFunc("GET /login/{provider}/init", e.handleOauthInitiateGet()) + apiGroup.HandleFunc("GET /login/{provider}/callback", e.handleOauthCallbackGet()) - apiGroup.POST("/login", e.handleLoginPost()) - apiGroup.POST("/logout", authenticator.LoggedIn(), e.handleLogoutPost()) + apiGroup.HandleFunc("POST /login", e.handleLoginPost()) + apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /logout", e.handleLogoutPost()) } -// handleExternalLoginProvidersGet returns a gorm handler function. +// handleExternalLoginProvidersGet returns a gorm Handler function. // // @ID auth_handleExternalLoginProvidersGet // @Tags Authentication @@ -45,16 +71,15 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti // @Produce json // @Success 200 {object} []model.LoginProviderInfo // @Router /auth/providers [get] -func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - providers := e.app.Authenticator.GetExternalLoginProviders(ctx) +func (e AuthEndpoint) handleExternalLoginProvidersGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + providers := e.app.Authenticator.GetExternalLoginProviders(r.Context()) - c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers)) + respond.JSON(w, http.StatusOK, model.NewLoginProviderInfos(providers)) } } -// handleSessionInfoGet returns a gorm handler function. +// handleSessionInfoGet returns a gorm Handler function. // // @ID auth_handleSessionInfoGet // @Tags Authentication @@ -63,9 +88,9 @@ func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc { // @Success 200 {object} []model.SessionInfo // @Failure 500 {object} model.Error // @Router /auth/session [get] -func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc { - return func(c *gin.Context) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) handleSessionInfoGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + currentSession := e.session.GetData(r.Context()) var loggedInUid *string var firstname *string @@ -83,7 +108,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc { email = &e } - c.JSON(http.StatusOK, model.SessionInfo{ + respond.JSON(w, http.StatusOK, model.SessionInfo{ LoggedIn: currentSession.LoggedIn, IsAdmin: currentSession.IsAdmin, UserIdentifier: loggedInUid, @@ -94,7 +119,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc { } } -// handleOauthInitiateGet returns a gorm handler function. +// handleOauthInitiateGet returns a gorm Handler function. // // @ID auth_handleOauthInitiateGet // @Tags Authentication @@ -102,23 +127,24 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc { // @Produce json // @Success 200 {object} []model.LoginProviderInfo // @Router /auth/{provider}/init [get] -func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc { - return func(c *gin.Context) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + currentSession := e.session.GetData(r.Context()) - autoRedirect, _ := strconv.ParseBool(c.DefaultQuery("redirect", "false")) - returnTo := c.Query("return") - provider := c.Param("provider") + autoRedirect, _ := strconv.ParseBool(request.QueryDefault(r, "redirect", "false")) + returnTo := request.Query(r, "return") + provider := request.Path(r, "provider") var returnUrl *url.URL var returnParams string redirectToReturn := func() { - c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams) + respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams) } if returnTo != "" { if !e.isValidReturnUrl(returnTo) { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "invalid return URL"}) return } if u, err := url.Parse(returnTo); err == nil { @@ -137,34 +163,34 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc { returnParams = queryParams.Encode() redirectToReturn() } else { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "already logged in"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "already logged in"}) } return } - ctx := domain.SetUserInfoFromGin(c) - authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider) + authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(context.Background(), provider) if err != nil { if autoRedirect && e.isValidReturnUrl(returnTo) { redirectToReturn() } else { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) } return } - authSession := e.authenticator.Session.DefaultSessionData() + authSession := e.session.GetData(r.Context()) authSession.OauthState = state authSession.OauthNonce = nonce authSession.OauthProvider = provider authSession.OauthReturnTo = returnTo - e.authenticator.Session.SetData(c, authSession) + e.session.SetData(r.Context(), authSession) if autoRedirect { - c.Redirect(http.StatusFound, authCodeUrl) + respond.Redirect(w, r, http.StatusFound, authCodeUrl) } else { - c.JSON(http.StatusOK, model.OauthInitiationResponse{ + respond.JSON(w, http.StatusOK, model.OauthInitiationResponse{ RedirectUrl: authCodeUrl, State: state, }) @@ -172,7 +198,7 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc { } } -// handleOauthCallbackGet returns a gorm handler function. +// handleOauthCallbackGet returns a gorm Handler function. // // @ID auth_handleOauthCallbackGet // @Tags Authentication @@ -180,14 +206,14 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc { // @Produce json // @Success 200 {object} []model.LoginProviderInfo // @Router /auth/{provider}/callback [get] -func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc { - return func(c *gin.Context) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) handleOauthCallbackGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + currentSession := e.session.GetData(r.Context()) var returnUrl *url.URL var returnParams string redirectToReturn := func() { - c.Redirect(http.StatusFound, returnUrl.String()+"?"+returnParams) + respond.Redirect(w, r, http.StatusFound, returnUrl.String()+"?"+returnParams) } if currentSession.OauthReturnTo != "" { @@ -207,20 +233,20 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc { returnParams = queryParams.Encode() redirectToReturn() } else { - c.JSON(http.StatusBadRequest, model.Error{Message: "already logged in"}) + respond.JSON(w, http.StatusBadRequest, model.Error{Message: "already logged in"}) } return } - provider := c.Param("provider") - oauthCode := c.Query("code") - oauthState := c.Query("state") + provider := request.Path(r, "provider") + oauthCode := request.Query(r, "code") + oauthState := request.Query(r, "state") if provider != currentSession.OauthProvider { if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) { redirectToReturn() } else { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid oauth provider"}) } return @@ -229,7 +255,8 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc { if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) { redirectToReturn() } else { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "invalid oauth state"}) } return } @@ -241,12 +268,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc { if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) { redirectToReturn() } else { - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: err.Error()}) + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: err.Error()}) } return } - e.setAuthenticatedUser(c, user) + e.setAuthenticatedUser(r, user) if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) { queryParams := returnUrl.Query() @@ -254,13 +282,13 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc { returnParams = queryParams.Encode() redirectToReturn() } else { - c.JSON(http.StatusOK, user) + respond.JSON(w, http.StatusOK, user) } } } -func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) setAuthenticatedUser(r *http.Request, user *domain.User) { + currentSession := e.session.GetData(r.Context()) currentSession.LoggedIn = true currentSession.IsAdmin = user.IsAdmin @@ -274,10 +302,10 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) { currentSession.OauthProvider = "" currentSession.OauthReturnTo = "" - e.authenticator.Session.SetData(c, currentSession) + e.session.SetData(r.Context(), currentSession) } -// handleLoginPost returns a gorm handler function. +// handleLoginPost returns a gorm Handler function. // // @ID auth_handleLoginPost // @Tags Authentication @@ -285,11 +313,11 @@ func (e authEndpoint) setAuthenticatedUser(c *gin.Context, user *domain.User) { // @Produce json // @Success 200 {object} []model.LoginProviderInfo // @Router /auth/login [post] -func (e authEndpoint) handleLoginPost() gin.HandlerFunc { - return func(c *gin.Context) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) handleLoginPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + currentSession := e.session.GetData(r.Context()) if currentSession.LoggedIn { - c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"}) + respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "already logged in"}) return } @@ -298,25 +326,29 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc { Password string `json:"password" binding:"required,min=4"` } - if err := c.ShouldBindJSON(&loginData); err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &loginData); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validate.Struct(loginData); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - ctx := domain.SetUserInfoFromGin(c) - user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password) + user, err := e.app.Authenticator.PlainLogin(context.Background(), loginData.Username, loginData.Password) if err != nil { - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"}) + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "login failed"}) return } - e.setAuthenticatedUser(c, user) + e.setAuthenticatedUser(r, user) - c.JSON(http.StatusOK, user) + respond.JSON(w, http.StatusOK, user) } } -// handleLogoutPost returns a gorm handler function. +// handleLogoutPost returns a gorm Handler function. // // @ID auth_handleLogoutGet // @Tags Authentication @@ -324,22 +356,22 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc { // @Produce json // @Success 200 {object} []model.LoginProviderInfo // @Router /auth/logout [get] -func (e authEndpoint) handleLogoutPost() gin.HandlerFunc { - return func(c *gin.Context) { - currentSession := e.authenticator.Session.GetData(c) +func (e AuthEndpoint) handleLogoutPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + currentSession := e.session.GetData(r.Context()) if !currentSession.LoggedIn { // Not logged in - c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"}) + respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "not logged in"}) return } - e.authenticator.Session.DestroyData(c) - c.JSON(http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"}) + e.session.DestroyData(r.Context()) + respond.JSON(w, http.StatusOK, model.Error{Code: http.StatusOK, Message: "logout ok"}) } } // isValidReturnUrl checks if the given return URL matches the configured external URL of the application. -func (e authEndpoint) isValidReturnUrl(returnUrl string) bool { +func (e AuthEndpoint) isValidReturnUrl(returnUrl string) bool { if !strings.HasPrefix(returnUrl, e.app.Config.Web.ExternalUrl) { return false } diff --git a/internal/app/api/v0/handlers/endpoint_config.go b/internal/app/api/v0/handlers/endpoint_config.go index b4dc4f0..c81c2c9 100644 --- a/internal/app/api/v0/handlers/endpoint_config.go +++ b/internal/app/api/v0/handlers/endpoint_config.go @@ -1,7 +1,6 @@ package handlers import ( - "bytes" "embed" "fmt" "html/template" @@ -9,57 +8,61 @@ import ( "net/http" "net/url" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal" - "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" + "github.com/h44z/wg-portal/internal/config" ) //go:embed frontend_config.js.gotpl var frontendJs embed.FS -type configEndpoint struct { - app *app.App - authenticator *authenticationHandler +type ConfigEndpoint struct { + cfg *config.Config + authenticator Authenticator - tpl *template.Template + tpl *respond.TemplateRenderer } -func newConfigEndpoint(app *app.App, authenticator *authenticationHandler) configEndpoint { - ep := configEndpoint{ - app: app, +func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint { + ep := ConfigEndpoint{ + cfg: cfg, authenticator: authenticator, - tpl: template.Must(template.ParseFS(frontendJs, "frontend_config.js.gotpl")), + tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs, + "frontend_config.js.gotpl"))), } return ep } -func (e configEndpoint) GetName() string { +func (e ConfigEndpoint) GetName() string { return "ConfigEndpoint" } -func (e configEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) { - apiGroup := g.Group("/config") +func (e ConfigEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/config") - apiGroup.GET("/frontend.js", e.handleConfigJsGet()) - apiGroup.GET("/settings", e.authenticator.LoggedIn(), e.handleSettingsGet()) + apiGroup.HandleFunc("GET /frontend.js", e.handleConfigJsGet()) + apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("GET /settings", e.handleSettingsGet()) } -// handleConfigJsGet returns a gorm handler function. +// handleConfigJsGet returns a gorm Handler function. // // @ID config_handleConfigJsGet // @Tags Configuration // @Summary Get the dynamic frontend configuration javascript. // @Produce text/javascript // @Success 200 string javascript "The JavaScript contents" +// @Failure 500 // @Router /config/frontend.js [get] -func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc { - return func(c *gin.Context) { - backendUrl := fmt.Sprintf("%s/api/v0", e.app.Config.Web.ExternalUrl) - if c.GetHeader("x-wg-dev") != "" { - referer := c.Request.Header.Get("Referer") +func (e ConfigEndpoint) handleConfigJsGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + backendUrl := fmt.Sprintf("%s/api/v0", e.cfg.Web.ExternalUrl) + if request.Header(r, "x-wg-dev") != "" { + referer := request.Header(r, "Referer") host := "localhost" port := "5000" parsedReferer, err := url.Parse(referer) @@ -69,23 +72,17 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc { backendUrl = fmt.Sprintf("http://%s:%s/api/v0", host, port) // override if request comes from frontend started with npm run dev } - buf := &bytes.Buffer{} - err := e.tpl.ExecuteTemplate(buf, "frontend_config.js.gotpl", gin.H{ + + e.tpl.Render(w, http.StatusOK, "frontend_config.js.gotpl", "text/javascript", map[string]any{ "BackendUrl": backendUrl, "Version": internal.Version, - "SiteTitle": e.app.Config.Web.SiteTitle, - "SiteCompanyName": e.app.Config.Web.SiteCompanyName, + "SiteTitle": e.cfg.Web.SiteTitle, + "SiteCompanyName": e.cfg.Web.SiteCompanyName, }) - if err != nil { - c.Status(http.StatusInternalServerError) - return - } - - c.Data(http.StatusOK, "application/javascript", buf.Bytes()) } } -// handleSettingsGet returns a gorm handler function. +// handleSettingsGet returns a gorm Handler function. // // @ID config_handleSettingsGet // @Tags Configuration @@ -94,13 +91,13 @@ func (e configEndpoint) handleConfigJsGet() gin.HandlerFunc { // @Success 200 {object} model.Settings // @Success 200 string javascript "The JavaScript contents" // @Router /config/settings [get] -func (e configEndpoint) handleSettingsGet() gin.HandlerFunc { - return func(c *gin.Context) { - c.JSON(http.StatusOK, model.Settings{ - MailLinkOnly: e.app.Config.Mail.LinkOnly, - PersistentConfigSupported: e.app.Config.Advanced.ConfigStoragePath != "", - SelfProvisioning: e.app.Config.Core.SelfProvisioningAllowed, - ApiAdminOnly: e.app.Config.Advanced.ApiAdminOnly, +func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + respond.JSON(w, http.StatusOK, model.Settings{ + MailLinkOnly: e.cfg.Mail.LinkOnly, + PersistentConfigSupported: e.cfg.Advanced.ConfigStoragePath != "", + SelfProvisioning: e.cfg.Core.SelfProvisioningAllowed, + ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly, }) } } diff --git a/internal/app/api/v0/handlers/endpoint_interfaces.go b/internal/app/api/v0/handlers/endpoint_interfaces.go index 0c66d2f..fe89810 100644 --- a/internal/app/api/v0/handlers/endpoint_interfaces.go +++ b/internal/app/api/v0/handlers/endpoint_interfaces.go @@ -4,39 +4,51 @@ import ( "io" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" "github.com/h44z/wg-portal/internal/domain" ) -type interfaceEndpoint struct { +type InterfaceEndpoint struct { app *app.App - authenticator *authenticationHandler + authenticator Authenticator + validator Validator } -func (e interfaceEndpoint) GetName() string { +func NewInterfaceEndpoint(app *app.App, authenticator Authenticator, validator Validator) InterfaceEndpoint { + return InterfaceEndpoint{ + app: app, + authenticator: authenticator, + validator: validator, + } +} + +func (e InterfaceEndpoint) GetName() string { return "InterfaceEndpoint" } -func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) { - apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin)) +func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/interface") + apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin)) - apiGroup.GET("/prepare", e.handlePrepareGet()) - apiGroup.GET("/all", e.handleAllGet()) - apiGroup.GET("/get/:id", e.handleSingleGet()) - apiGroup.PUT("/:id", e.handleUpdatePut()) - apiGroup.DELETE("/:id", e.handleDelete()) - apiGroup.POST("/new", e.handleCreatePost()) - apiGroup.GET("/config/:id", e.handleConfigGet()) - apiGroup.POST("/:id/save-config", e.handleSaveConfigPost()) - apiGroup.POST("/:id/apply-peer-defaults", e.handleApplyPeerDefaultsPost()) + apiGroup.HandleFunc("GET /prepare", e.handlePrepareGet()) + apiGroup.HandleFunc("GET /all", e.handleAllGet()) + apiGroup.HandleFunc("GET /get/{id}", e.handleSingleGet()) + apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut()) + apiGroup.HandleFunc("DELETE /{id}", e.handleDelete()) + apiGroup.HandleFunc("POST /new", e.handleCreatePost()) + apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet()) + apiGroup.HandleFunc("POST /{id}/save-config", e.handleSaveConfigPost()) + apiGroup.HandleFunc("POST /{id}/apply-peer-defaults", e.handleApplyPeerDefaultsPost()) - apiGroup.GET("/peers/:id", e.handlePeersGet()) + apiGroup.HandleFunc("GET /peers/{id}", e.handlePeersGet()) } -// handlePrepareGet returns a gorm handler function. +// handlePrepareGet returns a gorm Handler function. // // @ID interfaces_handlePrepareGet // @Tags Interface @@ -45,22 +57,21 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationH // @Success 200 {object} model.Interface // @Failure 500 {object} model.Error // @Router /interface/prepare [get] -func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - in, err := e.app.PrepareInterface(ctx) +func (e InterfaceEndpoint) handlePrepareGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + in, err := e.app.PrepareInterface(r.Context()) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewInterface(in, nil)) + respond.JSON(w, http.StatusOK, model.NewInterface(in, nil)) } } -// handleAllGet returns a gorm handler function. +// handleAllGet returns a gorm Handler function. // // @ID interfaces_handleAllGet // @Tags Interface @@ -69,22 +80,21 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { // @Success 200 {object} []model.Interface // @Failure 500 {object} model.Error // @Router /interface/all [get] -func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx) +func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaces, peers, err := e.app.GetAllInterfacesAndPeers(r.Context()) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers)) + respond.JSON(w, http.StatusOK, model.NewInterfaces(interfaces, peers)) } } -// handleSingleGet returns a gorm handler function. +// handleSingleGet returns a gorm Handler function. // // @ID interfaces_handleSingleGet // @Tags Interface @@ -94,30 +104,29 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/get/{id} [get] -func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleSingleGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{ + respond.JSON(w, http.StatusBadRequest, model.Error{ Code: http.StatusInternalServerError, Message: "missing id parameter", }) return } - iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id)) + iface, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewInterface(iface, peers)) + respond.JSON(w, http.StatusOK, model.NewInterface(iface, peers)) } } -// handleConfigGet returns a gorm handler function. +// handleConfigGet returns a gorm Handler function. // // @ID interfaces_handleConfigGet // @Tags Interface @@ -127,20 +136,19 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/config/{id} [get] -func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleConfigGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{ + respond.JSON(w, http.StatusBadRequest, model.Error{ Code: http.StatusInternalServerError, Message: "missing id parameter", }) return } - config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id)) + config, err := e.app.GetInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return @@ -148,17 +156,17 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc { configString, err := io.ReadAll(config) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, string(configString)) + respond.JSON(w, http.StatusOK, string(configString)) } } -// handleUpdatePut returns a gorm handler function. +// handleUpdatePut returns a gorm Handler function. // // @ID interfaces_handleUpdatePut // @Tags Interface @@ -170,41 +178,44 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/{id} [put] -func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } var in model.Interface - err := c.BindJSON(&in) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if id != in.Identifier { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) return } - updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in)) + updatedInterface, peers, err := e.app.UpdateInterface(r.Context(), model.NewDomainInterface(&in)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers)) + respond.JSON(w, http.StatusOK, model.NewInterface(updatedInterface, peers)) } } -// handleCreatePost returns a gorm handler function. +// handleCreatePost returns a gorm Handler function. // // @ID interfaces_handleCreatePost // @Tags Interface @@ -215,30 +226,31 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/new [post] -func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var in model.Interface - err := c.BindJSON(&in) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newInterface, err := e.app.CreateInterface(ctx, model.NewDomainInterface(&in)) + newInterface, err := e.app.CreateInterface(r.Context(), model.NewDomainInterface(&in)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewInterface(newInterface, nil)) + respond.JSON(w, http.StatusOK, model.NewInterface(newInterface, nil)) } } -// handlePeersGet returns a gorm handler function. +// handlePeersGet returns a gorm Handler function. // // @ID interfaces_handlePeersGet // @Tags Interface @@ -247,31 +259,29 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc { // @Success 200 {object} []model.Peer // @Failure 500 {object} model.Error // @Router /interface/peers/{id} [get] -func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handlePeersGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{ + respond.JSON(w, http.StatusBadRequest, model.Error{ Code: http.StatusInternalServerError, Message: "missing id parameter", }) return } - _, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id)) + _, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, model.NewPeers(peers)) + respond.JSON(w, http.StatusOK, model.NewPeers(peers)) } } -// handleDelete returns a gorm handler function. +// handleDelete returns a gorm Handler function. // // @ID interfaces_handleDelete // @Tags Interface @@ -282,29 +292,28 @@ func (e interfaceEndpoint) handlePeersGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/{id} [delete] -func (e interfaceEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - err := e.app.DeleteInterface(ctx, domain.InterfaceIdentifier(id)) + err := e.app.DeleteInterface(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } -// handleSaveConfigPost returns a gorm handler function. +// handleSaveConfigPost returns a gorm Handler function. // // @ID interfaces_handleSaveConfigPost // @Tags Interface @@ -315,29 +324,28 @@ func (e interfaceEndpoint) handleDelete() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/{id}/save-config [post] -func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleSaveConfigPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - err := e.app.PersistInterfaceConfig(ctx, domain.InterfaceIdentifier(id)) + err := e.app.PersistInterfaceConfig(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } -// handleApplyPeerDefaultsPost returns a gorm handler function. +// handleApplyPeerDefaultsPost returns a gorm Handler function. // // @ID interfaces_handleApplyPeerDefaultsPost // @Tags Interface @@ -349,36 +357,38 @@ func (e interfaceEndpoint) handleSaveConfigPost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /interface/{id}/apply-peer-defaults [post] -func (e interfaceEndpoint) handleApplyPeerDefaultsPost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e InterfaceEndpoint) handleApplyPeerDefaultsPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } var in model.Interface - err := c.BindJSON(&in) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(in); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if id != in.Identifier { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) return } - err = e.app.ApplyPeerDefaults(ctx, model.NewDomainInterface(&in)) - if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + if err := e.app.ApplyPeerDefaults(r.Context(), model.NewDomainInterface(&in)); err != nil { + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } diff --git a/internal/app/api/v0/handlers/endpoint_peers.go b/internal/app/api/v0/handlers/endpoint_peers.go index 3cdc4c3..64114fa 100644 --- a/internal/app/api/v0/handlers/endpoint_peers.go +++ b/internal/app/api/v0/handlers/endpoint_peers.go @@ -4,39 +4,52 @@ import ( "io" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" "github.com/h44z/wg-portal/internal/domain" ) -type peerEndpoint struct { +type PeerEndpoint struct { app *app.App - authenticator *authenticationHandler + authenticator Authenticator + validator Validator } -func (e peerEndpoint) GetName() string { +func NewPeerEndpoint(app *app.App, authenticator Authenticator, validator Validator) PeerEndpoint { + return PeerEndpoint{ + app: app, + authenticator: authenticator, + validator: validator, + } +} + +func (e PeerEndpoint) GetName() string { return "PeerEndpoint" } -func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) { - apiGroup := g.Group("/peer", e.authenticator.LoggedIn()) +func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/peer") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) - apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet()) - apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(), e.handlePrepareGet()) - apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(), e.handleCreatePost()) - apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost()) - apiGroup.GET("/config-qr/:id", e.handleQrCodeGet()) - apiGroup.POST("/config-mail", e.handleEmailPost()) - apiGroup.GET("/config/:id", e.handleConfigGet()) - apiGroup.GET("/:id", e.handleSingleGet()) - apiGroup.PUT("/:id", e.handleUpdatePut()) - apiGroup.DELETE("/:id", e.handleDelete()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/all", e.handleAllGet()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /iface/{iface}/stats", e.handleStatsGet()) + apiGroup.HandleFunc("GET /iface/{iface}/prepare", e.handlePrepareGet()) + apiGroup.HandleFunc("POST /iface/{iface}/new", e.handleCreatePost()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /iface/{iface}/multiplenew", + e.handleCreateMultiplePost()) + apiGroup.HandleFunc("GET /config-qr/{id}", e.handleQrCodeGet()) + apiGroup.HandleFunc("POST /config-mail", e.handleEmailPost()) + apiGroup.HandleFunc("GET /config/{id}", e.handleConfigGet()) + apiGroup.HandleFunc("GET /{id}", e.handleSingleGet()) + apiGroup.HandleFunc("PUT /{id}", e.handleUpdatePut()) + apiGroup.HandleFunc("DELETE /{id}", e.handleDelete()) } -// handleAllGet returns a gorm handler function. +// handleAllGet returns a gorm Handler function. // // @ID peers_handleAllGet // @Tags Peer @@ -47,28 +60,27 @@ func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/iface/{iface}/all [get] -func (e peerEndpoint) handleAllGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - interfaceId := Base64UrlDecode(c.Param("iface")) +func (e PeerEndpoint) handleAllGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaceId := Base64UrlDecode(request.Path(r, "iface")) if interfaceId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) return } - _, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(interfaceId)) + _, peers, err := e.app.GetInterfaceAndPeers(r.Context(), domain.InterfaceIdentifier(interfaceId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeers(peers)) + respond.JSON(w, http.StatusOK, model.NewPeers(peers)) } } -// handleSingleGet returns a gorm handler function. +// handleSingleGet returns a gorm Handler function. // // @ID peers_handleSingleGet // @Tags Peer @@ -79,28 +91,27 @@ func (e peerEndpoint) handleAllGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/{id} [get] -func (e peerEndpoint) handleSingleGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - peerId := Base64UrlDecode(c.Param("id")) +func (e PeerEndpoint) handleSingleGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + peerId := Base64UrlDecode(request.Path(r, "id")) if peerId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"}) return } - peer, err := e.app.GetPeer(ctx, domain.PeerIdentifier(peerId)) + peer, err := e.app.GetPeer(r.Context(), domain.PeerIdentifier(peerId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeer(peer)) + respond.JSON(w, http.StatusOK, model.NewPeer(peer)) } } -// handlePrepareGet returns a gorm handler function. +// handlePrepareGet returns a gorm Handler function. // // @ID peers_handlePrepareGet // @Tags Peer @@ -111,28 +122,27 @@ func (e peerEndpoint) handleSingleGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/iface/{iface}/prepare [get] -func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - interfaceId := Base64UrlDecode(c.Param("iface")) +func (e PeerEndpoint) handlePrepareGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaceId := Base64UrlDecode(request.Path(r, "iface")) if interfaceId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) return } - peer, err := e.app.PreparePeer(ctx, domain.InterfaceIdentifier(interfaceId)) + peer, err := e.app.PreparePeer(r.Context(), domain.InterfaceIdentifier(interfaceId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeer(peer)) + respond.JSON(w, http.StatusOK, model.NewPeer(peer)) } } -// handleCreatePost returns a gorm handler function. +// handleCreatePost returns a gorm Handler function. // // @ID peers_handleCreatePost // @Tags Peer @@ -144,40 +154,43 @@ func (e peerEndpoint) handlePrepareGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/iface/{iface}/new [post] -func (e peerEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - interfaceId := Base64UrlDecode(c.Param("iface")) +func (e PeerEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaceId := Base64UrlDecode(request.Path(r, "iface")) if interfaceId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) return } var p model.Peer - err := c.BindJSON(&p) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &p); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(p); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if p.InterfaceIdentifier != interfaceId { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) return } - newPeer, err := e.app.CreatePeer(ctx, model.NewDomainPeer(&p)) + newPeer, err := e.app.CreatePeer(r.Context(), model.NewDomainPeer(&p)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeer(newPeer)) + respond.JSON(w, http.StatusOK, model.NewPeer(newPeer)) } } -// handleCreateMultiplePost returns a gorm handler function. +// handleCreateMultiplePost returns a gorm Handler function. // // @ID peers_handleCreateMultiplePost // @Tags Peer @@ -189,36 +202,38 @@ func (e peerEndpoint) handleCreatePost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/iface/{iface}/multiplenew [post] -func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - interfaceId := Base64UrlDecode(c.Param("iface")) +func (e PeerEndpoint) handleCreateMultiplePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaceId := Base64UrlDecode(request.Path(r, "iface")) if interfaceId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) return } var req model.MultiPeerRequest - err := c.BindJSON(&req) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &req); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(req); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newPeers, err := e.app.CreateMultiplePeers(ctx, domain.InterfaceIdentifier(interfaceId), + newPeers, err := e.app.CreateMultiplePeers(r.Context(), domain.InterfaceIdentifier(interfaceId), model.NewDomainPeerCreationRequest(&req)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeers(newPeers)) + respond.JSON(w, http.StatusOK, model.NewPeers(newPeers)) } } -// handleUpdatePut returns a gorm handler function. +// handleUpdatePut returns a gorm Handler function. // // @ID peers_handleUpdatePut // @Tags Peer @@ -230,40 +245,43 @@ func (e peerEndpoint) handleCreateMultiplePost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/{id} [put] -func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - peerId := Base64UrlDecode(c.Param("id")) +func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + peerId := Base64UrlDecode(request.Path(r, "id")) if peerId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing id parameter"}) return } var p model.Peer - err := c.BindJSON(&p) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &p); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(p); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if p.Identifier != peerId { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "peer id mismatch"}) return } - updatedPeer, err := e.app.UpdatePeer(ctx, model.NewDomainPeer(&p)) + updatedPeer, err := e.app.UpdatePeer(r.Context(), model.NewDomainPeer(&p)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeer(updatedPeer)) + respond.JSON(w, http.StatusOK, model.NewPeer(updatedPeer)) } } -// handleDelete returns a gorm handler function. +// handleDelete returns a gorm Handler function. // // @ID peers_handleDelete // @Tags Peer @@ -274,28 +292,26 @@ func (e peerEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/{id} [delete] -func (e peerEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e PeerEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - err := e.app.DeletePeer(ctx, domain.PeerIdentifier(id)) + err := e.app.DeletePeer(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } -// handleConfigGet returns a gorm handler function. +// handleConfigGet returns a gorm Handler function. // // @ID peers_handleConfigGet // @Tags Peer @@ -306,21 +322,19 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/config/{id} [get] -func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e PeerEndpoint) handleConfigGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{ + respond.JSON(w, http.StatusBadRequest, model.Error{ Code: http.StatusInternalServerError, Message: "missing id parameter", }) return } - config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id)) + config, err := e.app.GetPeerConfig(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return @@ -328,17 +342,17 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { configString, err := io.ReadAll(config) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.JSON(http.StatusOK, string(configString)) + respond.JSON(w, http.StatusOK, string(configString)) } } -// handleQrCodeGet returns a gorm handler function. +// handleQrCodeGet returns a gorm Handler function. // // @ID peers_handleQrCodeGet // @Tags Peer @@ -350,20 +364,19 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/config-qr/{id} [get] -func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - id := Base64UrlDecode(c.Param("id")) +func (e PeerEndpoint) handleQrCodeGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{ + respond.JSON(w, http.StatusBadRequest, model.Error{ Code: http.StatusInternalServerError, Message: "missing id parameter", }) return } - config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id)) + config, err := e.app.GetPeerConfigQrCode(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return @@ -371,17 +384,17 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc { configData, err := io.ReadAll(config) if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) return } - c.Data(http.StatusOK, "image/png", configData) + respond.Data(w, http.StatusOK, "image/png", configData) } } -// handleEmailPost returns a gorm handler function. +// handleEmailPost returns a gorm Handler function. // // @ID peers_handleEmailPost // @Tags Peer @@ -392,38 +405,39 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/config-mail [post] -func (e peerEndpoint) handleEmailPost() gin.HandlerFunc { - return func(c *gin.Context) { +func (e PeerEndpoint) handleEmailPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var req model.PeerMailRequest - err := c.BindJSON(&req) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &req); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(req); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if len(req.Identifiers) == 0 { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing peer identifiers"}) return } - ctx := domain.SetUserInfoFromGin(c) - peerIds := make([]domain.PeerIdentifier, len(req.Identifiers)) for i := range req.Identifiers { peerIds[i] = domain.PeerIdentifier(req.Identifiers[i]) } - err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...) - if err != nil { - c.JSON(http.StatusInternalServerError, + if err := e.app.SendPeerEmail(r.Context(), req.LinkOnly, peerIds...); err != nil { + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } -// handleStatsGet returns a gorm handler function. +// handleStatsGet returns a gorm Handler function. // // @ID peers_handleStatsGet // @Tags Peer @@ -434,23 +448,22 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /peer/iface/{iface}/stats [get] -func (e peerEndpoint) handleStatsGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - interfaceId := Base64UrlDecode(c.Param("iface")) +func (e PeerEndpoint) handleStatsGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + interfaceId := Base64UrlDecode(request.Path(r, "iface")) if interfaceId == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "missing iface parameter"}) return } - stats, err := e.app.GetPeerStats(ctx, domain.InterfaceIdentifier(interfaceId)) + stats, err := e.app.GetPeerStats(r.Context(), domain.InterfaceIdentifier(interfaceId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats)) + respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats)) } } diff --git a/internal/app/api/v0/handlers/endpoint_testing.go b/internal/app/api/v0/handlers/endpoint_testing.go index bc090c9..9f6909c 100644 --- a/internal/app/api/v0/handlers/endpoint_testing.go +++ b/internal/app/api/v0/handlers/endpoint_testing.go @@ -5,20 +5,29 @@ import ( "os" "time" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" ) -type testEndpoint struct{} +type TestEndpoint struct { + authenticator Authenticator +} -func (e testEndpoint) GetName() string { +func NewTestEndpoint(authenticator Authenticator) TestEndpoint { + return TestEndpoint{ + authenticator: authenticator, + } +} + +func (e TestEndpoint) GetName() string { return "TestEndpoint" } -func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) { - g.GET("/now", e.handleCurrentTimeGet()) - g.GET("/hostname", e.handleHostnameGet()) +func (e TestEndpoint) RegisterRoutes(g *routegroup.Bundle) { + g.HandleFunc("GET /now", e.handleCurrentTimeGet()) + g.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /hostname", e.handleHostnameGet()) } // handleCurrentTimeGet represents the GET endpoint that responds the current time @@ -31,15 +40,15 @@ func (e testEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle // @Success 200 {object} string // @Failure 500 {object} model.Error // @Router /now [get] -func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc { - return func(c *gin.Context) { +func (e TestEndpoint) handleCurrentTimeGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { if time.Now().Second() == 0 { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: "invalid time", }) } - c.JSON(http.StatusOK, time.Now().String()) + respond.JSON(w, http.StatusOK, time.Now().String()) } } @@ -53,15 +62,15 @@ func (e testEndpoint) handleCurrentTimeGet() gin.HandlerFunc { // @Success 200 {object} string // @Failure 500 {object} model.Error // @Router /hostname [get] -func (e testEndpoint) handleHostnameGet() gin.HandlerFunc { - return func(c *gin.Context) { +func (e TestEndpoint) handleHostnameGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { hostname, err := os.Hostname() if err != nil { - c.JSON(http.StatusInternalServerError, model.Error{ + respond.JSON(w, http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), }) } - c.JSON(http.StatusOK, hostname) + respond.JSON(w, http.StatusOK, hostname) } } diff --git a/internal/app/api/v0/handlers/endpoint_users.go b/internal/app/api/v0/handlers/endpoint_users.go index d3f13f5..7d65a9d 100644 --- a/internal/app/api/v0/handlers/endpoint_users.go +++ b/internal/app/api/v0/handlers/endpoint_users.go @@ -3,38 +3,50 @@ package handlers import ( "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v0/model" "github.com/h44z/wg-portal/internal/domain" ) -type userEndpoint struct { +type UserEndpoint struct { app *app.App - authenticator *authenticationHandler + authenticator Authenticator + validator Validator } -func (e userEndpoint) GetName() string { +func NewUserEndpoint(app *app.App, authenticator Authenticator, validator Validator) UserEndpoint { + return UserEndpoint{ + app: app, + authenticator: authenticator, + validator: validator, + } +} + +func (e UserEndpoint) GetName() string { return "UserEndpoint" } -func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandler) { - apiGroup := g.Group("/user", e.authenticator.LoggedIn()) +func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/user") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) - apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet()) - apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut()) - apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete()) - apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) - apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet()) - apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet()) - apiGroup.GET("/:id/interfaces", e.authenticator.UserIdMatch("id"), e.handleInterfacesGet()) - apiGroup.POST("/:id/api/enable", e.authenticator.UserIdMatch("id"), e.handleApiEnablePost()) - apiGroup.POST("/:id/api/disable", e.authenticator.UserIdMatch("id"), e.handleApiDisablePost()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}", e.handleSingleGet()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("PUT /{id}", e.handleUpdatePut()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("DELETE /{id}", e.handleDelete()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/peers", e.handlePeersGet()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/stats", e.handleStatsGet()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/interfaces", e.handleInterfacesGet()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/enable", e.handleApiEnablePost()) + apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/disable", e.handleApiDisablePost()) } -// handleAllGet returns a gorm handler function. +// handleAllGet returns a gorm Handler function. // // @ID users_handleAllGet // @Tags Users @@ -43,22 +55,20 @@ func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, _ *authenticationHandle // @Success 200 {object} []model.User // @Failure 500 {object} model.Error // @Router /user/all [get] -func (e userEndpoint) handleAllGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - users, err := e.app.GetAllUsers(ctx) +func (e UserEndpoint) handleAllGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + users, err := e.app.GetAllUsers(r.Context()) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUsers(users)) + respond.JSON(w, http.StatusOK, model.NewUsers(users)) } } -// handleSingleGet returns a gorm handler function. +// handleSingleGet returns a gorm Handler function. // // @ID users_handleSingleGet // @Tags Users @@ -68,28 +78,26 @@ func (e userEndpoint) handleAllGet() gin.HandlerFunc { // @Success 200 {object} model.User // @Failure 500 {object} model.Error // @Router /user/{id} [get] -func (e userEndpoint) handleSingleGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleSingleGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } - user, err := e.app.GetUser(ctx, domain.UserIdentifier(id)) + user, err := e.app.GetUser(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUser(user, true)) + respond.JSON(w, http.StatusOK, model.NewUser(user, true)) } } -// handleUpdatePut returns a gorm handler function. +// handleUpdatePut returns a gorm Handler function. // // @ID users_handleUpdatePut // @Tags Users @@ -101,40 +109,42 @@ func (e userEndpoint) handleSingleGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id} [put] -func (e userEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } var user model.User - err := c.BindJSON(&user) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &user); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(user); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } if id != user.Identifier { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"}) + respond.JSON(w, http.StatusBadRequest, + model.Error{Code: http.StatusBadRequest, Message: "user id mismatch"}) return } - updateUser, err := e.app.UpdateUser(ctx, model.NewDomainUser(&user)) + updateUser, err := e.app.UpdateUser(r.Context(), model.NewDomainUser(&user)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUser(updateUser, false)) + respond.JSON(w, http.StatusOK, model.NewUser(updateUser, false)) } } -// handleCreatePost returns a gorm handler function. +// handleCreatePost returns a gorm Handler function. // // @ID users_handleCreatePost // @Tags Users @@ -145,29 +155,30 @@ func (e userEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/new [post] -func (e userEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e UserEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var user model.User - err := c.BindJSON(&user) - if err != nil { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &user); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(user); err != nil { + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newUser, err := e.app.CreateUser(ctx, model.NewDomainUser(&user)) + newUser, err := e.app.CreateUser(r.Context(), model.NewDomainUser(&user)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUser(newUser, false)) + respond.JSON(w, http.StatusOK, model.NewUser(newUser, false)) } } -// handlePeersGet returns a gorm handler function. +// handlePeersGet returns a gorm Handler function. // // @ID users_handlePeersGet // @Tags Users @@ -178,29 +189,27 @@ func (e userEndpoint) handleCreatePost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id}/peers [get] -func (e userEndpoint) handlePeersGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - userId := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handlePeersGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := Base64UrlDecode(request.Path(r, "id")) if userId == "" { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"}) return } - peers, err := e.app.GetUserPeers(ctx, domain.UserIdentifier(userId)) + peers, err := e.app.GetUserPeers(r.Context(), domain.UserIdentifier(userId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeers(peers)) + respond.JSON(w, http.StatusOK, model.NewPeers(peers)) } } -// handleStatsGet returns a gorm handler function. +// handleStatsGet returns a gorm Handler function. // // @ID users_handleStatsGet // @Tags Users @@ -211,29 +220,27 @@ func (e userEndpoint) handlePeersGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id}/stats [get] -func (e userEndpoint) handleStatsGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - userId := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleStatsGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := Base64UrlDecode(request.Path(r, "id")) if userId == "" { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"}) return } - stats, err := e.app.GetUserPeerStats(ctx, domain.UserIdentifier(userId)) + stats, err := e.app.GetUserPeerStats(r.Context(), domain.UserIdentifier(userId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats)) + respond.JSON(w, http.StatusOK, model.NewPeerStats(e.app.Config.Statistics.CollectPeerData, stats)) } } -// handleInterfacesGet returns a gorm handler function. +// handleInterfacesGet returns a gorm Handler function. // // @ID users_handleInterfacesGet // @Tags Users @@ -244,29 +251,27 @@ func (e userEndpoint) handleStatsGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id}/interfaces [get] -func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - userId := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleInterfacesGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := Base64UrlDecode(request.Path(r, "id")) if userId == "" { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"}) return } - peers, err := e.app.GetUserInterfaces(ctx, domain.UserIdentifier(userId)) + peers, err := e.app.GetUserInterfaces(r.Context(), domain.UserIdentifier(userId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewInterfaces(peers, nil)) + respond.JSON(w, http.StatusOK, model.NewInterfaces(peers, nil)) } } -// handleDelete returns a gorm handler function. +// handleDelete returns a gorm Handler function. // // @ID users_handleDelete // @Tags Users @@ -277,28 +282,26 @@ func (e userEndpoint) handleInterfacesGet() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id} [delete] -func (e userEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := Base64UrlDecode(request.Path(r, "id")) if id == "" { - c.JSON(http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } - err := e.app.DeleteUser(ctx, domain.UserIdentifier(id)) + err := e.app.DeleteUser(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } -// handleApiEnablePost returns a gorm handler function. +// handleApiEnablePost returns a gorm Handler function. // // @ID users_handleApiEnablePost // @Tags Users @@ -308,29 +311,27 @@ func (e userEndpoint) handleDelete() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id}/api/enable [post] -func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - userId := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleApiEnablePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := Base64UrlDecode(request.Path(r, "id")) if userId == "" { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"}) return } - user, err := e.app.ActivateApi(ctx, domain.UserIdentifier(userId)) + user, err := e.app.ActivateApi(r.Context(), domain.UserIdentifier(userId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUser(user, true)) + respond.JSON(w, http.StatusOK, model.NewUser(user, true)) } } -// handleApiDisablePost returns a gorm handler function. +// handleApiDisablePost returns a gorm Handler function. // // @ID users_handleApiDisablePost // @Tags Users @@ -340,24 +341,22 @@ func (e userEndpoint) handleApiEnablePost() gin.HandlerFunc { // @Failure 400 {object} model.Error // @Failure 500 {object} model.Error // @Router /user/{id}/api/disable [post] -func (e userEndpoint) handleApiDisablePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - userId := Base64UrlDecode(c.Param("id")) +func (e UserEndpoint) handleApiDisablePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := Base64UrlDecode(request.Path(r, "id")) if userId == "" { - c.JSON(http.StatusBadRequest, + respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"}) return } - user, err := e.app.DeactivateApi(ctx, domain.UserIdentifier(userId)) + user, err := e.app.DeactivateApi(r.Context(), domain.UserIdentifier(userId)) if err != nil { - c.JSON(http.StatusInternalServerError, + respond.JSON(w, http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()}) return } - c.JSON(http.StatusOK, model.NewUser(user, false)) + respond.JSON(w, http.StatusOK, model.NewUser(user, false)) } } diff --git a/internal/app/api/v0/handlers/middleware_authentication.go b/internal/app/api/v0/handlers/middleware_authentication.go deleted file mode 100644 index 03e9c45..0000000 --- a/internal/app/api/v0/handlers/middleware_authentication.go +++ /dev/null @@ -1,111 +0,0 @@ -package handlers - -import ( - "net/http" - - "github.com/gin-gonic/gin" - - "github.com/h44z/wg-portal/internal/app" - "github.com/h44z/wg-portal/internal/app/api/v0/model" - "github.com/h44z/wg-portal/internal/domain" -) - -type Scope string - -const ( - ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes -) - -type authenticationHandler struct { - app *app.App - Session SessionStore -} - -// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. -func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc { - return func(c *gin.Context) { - session := h.Session.GetData(c) - - if !session.LoggedIn { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "not logged in"}) - return - } - - if !UserHasScopes(session, scopes...) { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) - return - } - - // Check if logged-in user is still valid - if !h.app.Authenticator.IsUserValid(c.Request.Context(), domain.UserIdentifier(session.UserIdentifier)) { - h.Session.DestroyData(c) - c.Abort() - c.JSON(http.StatusUnauthorized, - model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"}) - return - } - - c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{ - Id: domain.UserIdentifier(session.UserIdentifier), - IsAdmin: session.IsAdmin, - }) - - // Continue down the chain to handler etc - c.Next() - } -} - -// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted. -func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc { - return func(c *gin.Context) { - session := h.Session.GetData(c) - - if session.IsAdmin { - c.Next() // Admins can do everything - return - } - - sessionUserId := domain.UserIdentifier(session.UserIdentifier) - requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter))) - - if sessionUserId != requestUserId { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) - return - } - - // Continue down the chain to handler etc - c.Next() - } -} - -func UserHasScopes(session SessionData, scopes ...Scope) bool { - // No scopes give, so the check should succeed - if len(scopes) == 0 { - return true - } - - // check if user has admin scope - if session.IsAdmin { - return true - } - - // Check if admin scope is required - for _, scope := range scopes { - if scope == ScopeAdmin { - return false - } - } - - // For all other scopes, a logged-in user is sufficient (for now) - if session.LoggedIn { - return true - } - - return false -} diff --git a/internal/app/api/v0/handlers/session.go b/internal/app/api/v0/handlers/session.go deleted file mode 100644 index 5a34d27..0000000 --- a/internal/app/api/v0/handlers/session.go +++ /dev/null @@ -1,92 +0,0 @@ -package handlers - -import ( - "encoding/gob" - "fmt" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" -) - -func init() { - gob.Register(SessionData{}) -} - -type SessionData struct { - LoggedIn bool - IsAdmin bool - - UserIdentifier string - - Firstname string - Lastname string - Email string - - OauthState string - OauthNonce string - OauthProvider string - OauthReturnTo string -} - -type SessionStore interface { - DefaultSessionData() SessionData - - GetData(c *gin.Context) SessionData - SetData(c *gin.Context, data SessionData) - - DestroyData(c *gin.Context) -} - -type GinSessionStore struct { - sessionIdentifier string -} - -func (g GinSessionStore) GetData(c *gin.Context) SessionData { - session := sessions.Default(c) - rawSessionData := session.Get(g.sessionIdentifier) - - var sessionData SessionData - if rawSessionData != nil { - sessionData = rawSessionData.(SessionData) - } else { - // init a new default session - sessionData = g.DefaultSessionData() - session.Set(g.sessionIdentifier, sessionData) - if err := session.Save(); err != nil { - panic(fmt.Sprintf("failed to store session: %v", err)) - } - } - - return sessionData -} - -func (g GinSessionStore) DefaultSessionData() SessionData { - return SessionData{ - LoggedIn: false, - IsAdmin: false, - UserIdentifier: "", - Firstname: "", - Lastname: "", - Email: "", - OauthState: "", - OauthNonce: "", - OauthProvider: "", - OauthReturnTo: "", - } -} - -func (g GinSessionStore) SetData(c *gin.Context, data SessionData) { - session := sessions.Default(c) - session.Set(g.sessionIdentifier, data) - if err := session.Save(); err != nil { - panic(fmt.Sprintf("failed to store session: %v", err)) - } -} - -func (g GinSessionStore) DestroyData(c *gin.Context) { - session := sessions.Default(c) - session.Delete(g.sessionIdentifier) - if err := session.Save(); err != nil { - panic(fmt.Sprintf("failed to store session: %v", err)) - } -} diff --git a/internal/app/api/v0/handlers/web_authentication.go b/internal/app/api/v0/handlers/web_authentication.go new file mode 100644 index 0000000..1b6a570 --- /dev/null +++ b/internal/app/api/v0/handlers/web_authentication.go @@ -0,0 +1,126 @@ +package handlers + +import ( + "context" + "net/http" + + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" + "github.com/h44z/wg-portal/internal/app/api/v0/model" + "github.com/h44z/wg-portal/internal/domain" +) + +type Scope string + +const ( + ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes +) + +type UserAuthenticator interface { + IsUserValid(ctx context.Context, id domain.UserIdentifier) bool +} + +type AuthenticationHandler struct { + authenticator UserAuthenticator + session Session +} + +func NewAuthenticationHandler(authenticator UserAuthenticator, session Session) AuthenticationHandler { + return AuthenticationHandler{ + authenticator: authenticator, + session: session, + } +} + +// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. +func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session := h.session.GetData(r.Context()) + + if !session.LoggedIn { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "not logged in"}) + return + } + + if !UserHasScopes(session, scopes...) { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusForbidden, + model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) + return + } + + // Check if logged-in user is still valid + if !h.authenticator.IsUserValid(r.Context(), domain.UserIdentifier(session.UserIdentifier)) { + h.session.DestroyData(r.Context()) + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "session no longer available"}) + return + } + + ctx := context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{ + Id: domain.UserIdentifier(session.UserIdentifier), + IsAdmin: session.IsAdmin, + }) + r = r.WithContext(ctx) + + // Continue down the chain to Handler etc + next.ServeHTTP(w, r) + }) + } +} + +// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted. +func (h AuthenticationHandler) UserIdMatch(idParameter string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session := h.session.GetData(r.Context()) + + if session.IsAdmin { + next.ServeHTTP(w, r) // Admins can do everything + return + } + + sessionUserId := domain.UserIdentifier(session.UserIdentifier) + requestUserId := domain.UserIdentifier(Base64UrlDecode(request.Path(r, idParameter))) + + if sessionUserId != requestUserId { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusForbidden, + model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) + return + } + + // Continue down the chain to Handler etc + next.ServeHTTP(w, r) + }) + } +} + +func UserHasScopes(session SessionData, scopes ...Scope) bool { + // No scopes give, so the check should succeed + if len(scopes) == 0 { + return true + } + + // check if user has admin scope + if session.IsAdmin { + return true + } + + // Check if admin scope is required + for _, scope := range scopes { + if scope == ScopeAdmin { + return false + } + } + + // For all other scopes, a logged-in user is sufficient (for now) + if session.LoggedIn { + return true + } + + return false +} diff --git a/internal/app/api/v0/handlers/web_session.go b/internal/app/api/v0/handlers/web_session.go new file mode 100644 index 0000000..7a1d793 --- /dev/null +++ b/internal/app/api/v0/handlers/web_session.go @@ -0,0 +1,88 @@ +package handlers + +import ( + "context" + "encoding/gob" + "net/http" + "strings" + "time" + + "github.com/alexedwards/scs/v2" + + "github.com/h44z/wg-portal/internal/config" +) + +func init() { + gob.Register(SessionData{}) +} + +type SessionData struct { + LoggedIn bool + IsAdmin bool + + UserIdentifier string + + Firstname string + Lastname string + Email string + + OauthState string + OauthNonce string + OauthProvider string + OauthReturnTo string + + CsrfToken string +} + +const sessionApiV0Key = "session_api_v0" + +type SessionWrapper struct { + *scs.SessionManager +} + +func NewSessionWrapper(cfg *config.Config) *SessionWrapper { + sessionManager := scs.New() + sessionManager.Lifetime = 24 * time.Hour + sessionManager.IdleTimeout = 1 * time.Hour + sessionManager.Cookie.Name = cfg.Web.SessionIdentifier + sessionManager.Cookie.Secure = strings.HasPrefix(cfg.Web.ExternalUrl, "https") + sessionManager.Cookie.HttpOnly = true + sessionManager.Cookie.SameSite = http.SameSiteLaxMode + sessionManager.Cookie.Path = "/" + sessionManager.Cookie.Persist = false + + wrappedSessionManager := &SessionWrapper{sessionManager} + + return wrappedSessionManager +} + +func (s *SessionWrapper) SetData(ctx context.Context, value SessionData) { + s.SessionManager.Put(ctx, sessionApiV0Key, value) +} + +func (s *SessionWrapper) GetData(ctx context.Context) SessionData { + sessionData, ok := s.SessionManager.Get(ctx, sessionApiV0Key).(SessionData) + if !ok { + return s.defaultSessionData() + } + return sessionData +} + +func (s *SessionWrapper) DestroyData(ctx context.Context) { + _ = s.SessionManager.Destroy(ctx) +} + +func (s *SessionWrapper) defaultSessionData() SessionData { + return SessionData{ + LoggedIn: false, + IsAdmin: false, + UserIdentifier: "", + Firstname: "", + Lastname: "", + Email: "", + OauthState: "", + OauthNonce: "", + OauthProvider: "", + OauthReturnTo: "", + } +} diff --git a/internal/app/api/v1/handlers/base.go b/internal/app/api/v1/handlers/base.go index 6772688..ccb4fb5 100644 --- a/internal/app/api/v1/handlers/base.go +++ b/internal/app/api/v1/handlers/base.go @@ -4,17 +4,19 @@ import ( "errors" "net/http" - "github.com/gin-contrib/cors" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" "github.com/h44z/wg-portal/internal/app/api/core" + "github.com/h44z/wg-portal/internal/app/api/core/middleware/cors" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) type Handler interface { + // GetName returns the name of the handler. GetName() string - RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) + // RegisterRoutes registers the routes for the handler. The session manager is passed to the handler. + RegisterRoutes(g *routegroup.Bundle) } // To compile the API documentation use the @@ -38,18 +40,14 @@ type Handler interface { // @BasePath /api/v1 // @query.collection.format multi -func NewRestApi(userSource UserSource, handlers ...Handler) core.ApiEndpointSetupFunc { - authenticator := &authenticationHandler{ - userSource: userSource, - } - +func NewRestApi(handlers ...Handler) core.ApiEndpointSetupFunc { return func() (core.ApiVersion, core.GroupSetupFn) { - return "v1", func(group *gin.RouterGroup) { - group.Use(cors.Default()) + return "v1", func(group *routegroup.Bundle) { + group.Use(cors.New().Handler) // Handler functions for _, h := range handlers { - h.RegisterRoutes(group, authenticator) + h.RegisterRoutes(group) } } } @@ -80,3 +78,12 @@ func ParseServiceError(err error) (int, models.Error) { Message: err.Error(), } } + +type Authenticator interface { + // LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. + LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler +} + +type Validator interface { + Struct(s interface{}) error +} diff --git a/internal/app/api/v1/handlers/endpoint_interface.go b/internal/app/api/v1/handlers/endpoint_interface.go index e1b2cf1..a0c9d5f 100644 --- a/internal/app/api/v1/handlers/endpoint_interface.go +++ b/internal/app/api/v1/handlers/endpoint_interface.go @@ -4,8 +4,10 @@ import ( "context" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) @@ -19,12 +21,20 @@ type InterfaceEndpointInterfaceService interface { } type InterfaceEndpoint struct { - interfaces InterfaceEndpointInterfaceService + interfaces InterfaceEndpointInterfaceService + authenticator Authenticator + validator Validator } -func NewInterfaceEndpoint(interfaceService InterfaceEndpointInterfaceService) *InterfaceEndpoint { +func NewInterfaceEndpoint( + authenticator Authenticator, + validator Validator, + interfaceService InterfaceEndpointInterfaceService, +) *InterfaceEndpoint { return &InterfaceEndpoint{ - interfaces: interfaceService, + authenticator: authenticator, + validator: validator, + interfaces: interfaceService, } } @@ -32,15 +42,16 @@ func (e InterfaceEndpoint) GetName() string { return "InterfaceEndpoint" } -func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/interface", authenticator.LoggedIn()) +func (e InterfaceEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/interface") + apiGroup.Use(e.authenticator.LoggedIn(ScopeAdmin)) - apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) - apiGroup.GET("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleByIdGet()) + apiGroup.HandleFunc("GET /all", e.handleAllGet()) + apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet()) - apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) - apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut()) - apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete()) + apiGroup.HandleFunc("POST /new", e.handleCreatePost()) + apiGroup.HandleFunc("PUT /by-id/{id}", e.handleUpdatePut()) + apiGroup.HandleFunc("DELETE /by-id/{id}", e.handleDelete()) } // handleAllGet returns a gorm Handler function. @@ -54,17 +65,16 @@ func (e InterfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut // @Failure 500 {object} models.Error // @Router /interface/all [get] // @Security BasicAuth -func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(ctx) +func (e InterfaceEndpoint) handleAllGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + allInterfaces, allPeersPerInterface, err := e.interfaces.GetAll(r.Context()) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface)) + respond.JSON(w, http.StatusOK, models.NewInterfaces(allInterfaces, allPeersPerInterface)) } } @@ -82,23 +92,23 @@ func (e InterfaceEndpoint) handleAllGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /interface/by-id/{id} [get] // @Security BasicAuth -func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e InterfaceEndpoint) handleByIdGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - iface, interfacePeers, err := e.interfaces.GetById(ctx, domain.InterfaceIdentifier(id)) + iface, interfacePeers, err := e.interfaces.GetById(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewInterface(iface, interfacePeers)) + respond.JSON(w, http.StatusOK, models.NewInterface(iface, interfacePeers)) } } @@ -117,24 +127,26 @@ func (e InterfaceEndpoint) handleByIdGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /interface/new [post] // @Security BasicAuth -func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e InterfaceEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var iface models.Interface - err := c.BindJSON(&iface) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &iface); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(iface); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newInterface, err := e.interfaces.Create(ctx, models.NewDomainInterface(&iface)) + newInterface, err := e.interfaces.Create(r.Context(), models.NewDomainInterface(&iface)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewInterface(newInterface, nil)) + respond.JSON(w, http.StatusOK, models.NewInterface(newInterface, nil)) } } @@ -154,34 +166,43 @@ func (e InterfaceEndpoint) handleCreatePost() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /interface/by-id/{id} [put] // @Security BasicAuth -func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e InterfaceEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } var iface models.Interface - err := c.BindJSON(&iface) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &iface); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(iface); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + + if id != iface.Identifier { + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "interface id mismatch"}) return } updatedInterface, updatedInterfacePeers, err := e.interfaces.Update( - ctx, + r.Context(), domain.InterfaceIdentifier(id), models.NewDomainInterface(&iface), ) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers)) + respond.JSON(w, http.StatusOK, models.NewInterface(updatedInterface, updatedInterfacePeers)) } } @@ -200,22 +221,22 @@ func (e InterfaceEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /interface/by-id/{id} [delete] // @Security BasicAuth -func (e InterfaceEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e InterfaceEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - err := e.interfaces.Delete(ctx, domain.InterfaceIdentifier(id)) + err := e.interfaces.Delete(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } diff --git a/internal/app/api/v1/handlers/endpoint_metrics.go b/internal/app/api/v1/handlers/endpoint_metrics.go index c9bb2b0..5629d99 100644 --- a/internal/app/api/v1/handlers/endpoint_metrics.go +++ b/internal/app/api/v1/handlers/endpoint_metrics.go @@ -4,8 +4,10 @@ import ( "context" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) @@ -17,12 +19,20 @@ type MetricsEndpointStatisticsService interface { } type MetricsEndpoint struct { - metrics MetricsEndpointStatisticsService + metrics MetricsEndpointStatisticsService + authenticator Authenticator + validator Validator } -func NewMetricsEndpoint(metrics MetricsEndpointStatisticsService) *MetricsEndpoint { +func NewMetricsEndpoint( + authenticator Authenticator, + validator Validator, + metrics MetricsEndpointStatisticsService, +) *MetricsEndpoint { return &MetricsEndpoint{ - metrics: metrics, + authenticator: authenticator, + validator: validator, + metrics: metrics, } } @@ -30,12 +40,14 @@ func (e MetricsEndpoint) GetName() string { return "MetricsEndpoint" } -func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/metrics", authenticator.LoggedIn()) +func (e MetricsEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/metrics") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleMetricsForInterfaceGet()) - apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleMetricsForUserGet()) - apiGroup.GET("/by-peer/:id", authenticator.LoggedIn(), e.handleMetricsForPeerGet()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}", + e.handleMetricsForInterfaceGet()) + apiGroup.HandleFunc("GET /by-user/{id}", e.handleMetricsForUserGet()) + apiGroup.HandleFunc("GET /by-peer/{id}", e.handleMetricsForPeerGet()) } // handleMetricsForInterfaceGet returns a gorm Handler function. @@ -52,23 +64,23 @@ func (e MetricsEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authe // @Failure 500 {object} models.Error // @Router /metrics/by-interface/{id} [get] // @Security BasicAuth -func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e MetricsEndpoint) handleMetricsForInterfaceGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - interfaceMetrics, err := e.metrics.GetForInterface(ctx, domain.InterfaceIdentifier(id)) + interfaceMetrics, err := e.metrics.GetForInterface(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics)) + respond.JSON(w, http.StatusOK, models.NewInterfaceMetrics(interfaceMetrics)) } } @@ -86,23 +98,23 @@ func (e MetricsEndpoint) handleMetricsForInterfaceGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /metrics/by-user/{id} [get] // @Security BasicAuth -func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e MetricsEndpoint) handleMetricsForUserGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - user, userMetrics, err := e.metrics.GetForUser(ctx, domain.UserIdentifier(id)) + user, userMetrics, err := e.metrics.GetForUser(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUserMetrics(user, userMetrics)) + respond.JSON(w, http.StatusOK, models.NewUserMetrics(user, userMetrics)) } } @@ -120,22 +132,22 @@ func (e MetricsEndpoint) handleMetricsForUserGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /metrics/by-peer/{id} [get] // @Security BasicAuth -func (e MetricsEndpoint) handleMetricsForPeerGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e MetricsEndpoint) handleMetricsForPeerGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - peerMetrics, err := e.metrics.GetForPeer(ctx, domain.PeerIdentifier(id)) + peerMetrics, err := e.metrics.GetForPeer(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeerMetrics(peerMetrics)) + respond.JSON(w, http.StatusOK, models.NewPeerMetrics(peerMetrics)) } } diff --git a/internal/app/api/v1/handlers/endpoint_peer.go b/internal/app/api/v1/handlers/endpoint_peer.go index 1f17cf2..ad79d13 100644 --- a/internal/app/api/v1/handlers/endpoint_peer.go +++ b/internal/app/api/v1/handlers/endpoint_peer.go @@ -4,8 +4,10 @@ import ( "context" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) @@ -20,12 +22,19 @@ type PeerService interface { } type PeerEndpoint struct { - peers PeerService + peers PeerService + authenticator Authenticator + validator Validator } -func NewPeerEndpoint(peerService PeerService) *PeerEndpoint { +func NewPeerEndpoint( + authenticator Authenticator, + validator Validator, peerService PeerService, +) *PeerEndpoint { return &PeerEndpoint{ - peers: peerService, + authenticator: authenticator, + validator: validator, + peers: peerService, } } @@ -33,16 +42,18 @@ func (e PeerEndpoint) GetName() string { return "PeerEndpoint" } -func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/peer", authenticator.LoggedIn()) +func (e PeerEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/peer") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/by-interface/:id", authenticator.LoggedIn(ScopeAdmin), e.handleAllForInterfaceGet()) - apiGroup.GET("/by-user/:id", authenticator.LoggedIn(), e.handleAllForUserGet()) - apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /by-interface/{id}", + e.handleAllForInterfaceGet()) + apiGroup.HandleFunc("GET /by-user/{id}", e.handleAllForUserGet()) + apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet()) - apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) - apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut()) - apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete()) } // handleAllForInterfaceGet returns a gorm Handler function. @@ -57,23 +68,23 @@ func (e PeerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti // @Failure 500 {object} models.Error // @Router /peer/by-interface/{id} [get] // @Security BasicAuth -func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e PeerEndpoint) handleAllForInterfaceGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing interface id"}) return } - interfacePeers, err := e.peers.GetForInterface(ctx, domain.InterfaceIdentifier(id)) + interfacePeers, err := e.peers.GetForInterface(r.Context(), domain.InterfaceIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeers(interfacePeers)) + respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers)) } } @@ -90,23 +101,23 @@ func (e PeerEndpoint) handleAllForInterfaceGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /peer/by-user/{id} [get] // @Security BasicAuth -func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e PeerEndpoint) handleAllForUserGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } - interfacePeers, err := e.peers.GetForUser(ctx, domain.UserIdentifier(id)) + interfacePeers, err := e.peers.GetForUser(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeers(interfacePeers)) + respond.JSON(w, http.StatusOK, models.NewPeers(interfacePeers)) } } @@ -125,23 +136,23 @@ func (e PeerEndpoint) handleAllForUserGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /peer/by-id/{id} [get] // @Security BasicAuth -func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e PeerEndpoint) handleByIdGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - peer, err := e.peers.GetById(ctx, domain.PeerIdentifier(id)) + peer, err := e.peers.GetById(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeer(peer)) + respond.JSON(w, http.StatusOK, models.NewPeer(peer)) } } @@ -161,24 +172,26 @@ func (e PeerEndpoint) handleByIdGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /peer/new [post] // @Security BasicAuth -func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e PeerEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var peer models.Peer - err := c.BindJSON(&peer) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &peer); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(peer); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newPeer, err := e.peers.Create(ctx, models.NewDomainPeer(&peer)) + newPeer, err := e.peers.Create(r.Context(), models.NewDomainPeer(&peer)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeer(newPeer)) + respond.JSON(w, http.StatusOK, models.NewPeer(newPeer)) } } @@ -199,30 +212,33 @@ func (e PeerEndpoint) handleCreatePost() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /peer/by-id/{id} [put] // @Security BasicAuth -func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e PeerEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } var peer models.Peer - err := c.BindJSON(&peer) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &peer); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(peer); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - updatedPeer, err := e.peers.Update(ctx, domain.PeerIdentifier(id), models.NewDomainPeer(&peer)) + updatedPeer, err := e.peers.Update(r.Context(), domain.PeerIdentifier(id), models.NewDomainPeer(&peer)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeer(updatedPeer)) + respond.JSON(w, http.StatusOK, models.NewPeer(updatedPeer)) } } @@ -241,22 +257,22 @@ func (e PeerEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /peer/by-id/{id} [delete] // @Security BasicAuth -func (e PeerEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e PeerEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - err := e.peers.Delete(ctx, domain.PeerIdentifier(id)) + err := e.peers.Delete(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } diff --git a/internal/app/api/v1/handlers/endpoint_provisioning.go b/internal/app/api/v1/handlers/endpoint_provisioning.go index 548521d..c283424 100644 --- a/internal/app/api/v1/handlers/endpoint_provisioning.go +++ b/internal/app/api/v1/handlers/endpoint_provisioning.go @@ -5,8 +5,10 @@ import ( "net/http" "strings" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) @@ -23,12 +25,20 @@ type ProvisioningEndpointProvisioningService interface { } type ProvisioningEndpoint struct { - provisioning ProvisioningEndpointProvisioningService + provisioning ProvisioningEndpointProvisioningService + authenticator Authenticator + validator Validator } -func NewProvisioningEndpoint(provisioning ProvisioningEndpointProvisioningService) *ProvisioningEndpoint { +func NewProvisioningEndpoint( + authenticator Authenticator, + validator Validator, + provisioning ProvisioningEndpointProvisioningService, +) *ProvisioningEndpoint { return &ProvisioningEndpoint{ - provisioning: provisioning, + authenticator: authenticator, + validator: validator, + provisioning: provisioning, } } @@ -36,14 +46,15 @@ func (e ProvisioningEndpoint) GetName() string { return "ProvisioningEndpoint" } -func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/provisioning", authenticator.LoggedIn()) +func (e ProvisioningEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/provisioning") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/data/user-info", authenticator.LoggedIn(), e.handleUserInfoGet()) - apiGroup.GET("/data/peer-config", authenticator.LoggedIn(), e.handlePeerConfigGet()) - apiGroup.GET("/data/peer-qr", authenticator.LoggedIn(), e.handlePeerQrGet()) + apiGroup.HandleFunc("GET /data/user-info", e.handleUserInfoGet()) + apiGroup.HandleFunc("GET /data/peer-config", e.handlePeerConfigGet()) + apiGroup.HandleFunc("GET /data/peer-qr", e.handlePeerQrGet()) - apiGroup.POST("/new-peer", authenticator.LoggedIn(), e.handleNewPeerPost()) + apiGroup.HandleFunc("POST /new-peer", e.handleNewPeerPost()) } // handleUserInfoGet returns a gorm Handler function. @@ -63,24 +74,23 @@ func (e ProvisioningEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator * // @Failure 500 {object} models.Error // @Router /provisioning/data/user-info [get] // @Security BasicAuth -func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := strings.TrimSpace(c.Query("UserId")) - email := strings.TrimSpace(c.Query("Email")) +func (e ProvisioningEndpoint) handleUserInfoGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := strings.TrimSpace(request.Query(r, "UserId")) + email := strings.TrimSpace(request.Query(r, "Email")) if id == "" && email == "" { - id = string(domain.GetUserInfo(ctx).Id) + id = string(domain.GetUserInfo(r.Context()).Id) } - user, peers, err := e.provisioning.GetUserAndPeers(ctx, domain.UserIdentifier(id), email) + user, peers, err := e.provisioning.GetUserAndPeers(r.Context(), domain.UserIdentifier(id), email) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUserInformation(user, peers)) + respond.JSON(w, http.StatusOK, models.NewUserInformation(user, peers)) } } @@ -101,23 +111,23 @@ func (e ProvisioningEndpoint) handleUserInfoGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /provisioning/data/peer-config [get] // @Security BasicAuth -func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := strings.TrimSpace(c.Query("PeerId")) +func (e ProvisioningEndpoint) handlePeerConfigGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := strings.TrimSpace(request.Query(r, "PeerId")) if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - peerConfig, err := e.provisioning.GetPeerConfig(ctx, domain.PeerIdentifier(id)) + peerConfig, err := e.provisioning.GetPeerConfig(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.Data(http.StatusOK, "text/plain", peerConfig) + respond.Data(w, http.StatusOK, "text/plain", peerConfig) } } @@ -138,23 +148,23 @@ func (e ProvisioningEndpoint) handlePeerConfigGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /provisioning/data/peer-qr [get] // @Security BasicAuth -func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := strings.TrimSpace(c.Query("PeerId")) +func (e ProvisioningEndpoint) handlePeerQrGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := strings.TrimSpace(request.Query(r, "PeerId")) if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing peer id"}) return } - peerConfigQrCode, err := e.provisioning.GetPeerQrPng(ctx, domain.PeerIdentifier(id)) + peerConfigQrCode, err := e.provisioning.GetPeerQrPng(r.Context(), domain.PeerIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.Data(http.StatusOK, "image/png", peerConfigQrCode) + respond.Data(w, http.StatusOK, "image/png", peerConfigQrCode) } } @@ -174,23 +184,25 @@ func (e ProvisioningEndpoint) handlePeerQrGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /provisioning/new-peer [post] // @Security BasicAuth -func (e ProvisioningEndpoint) handleNewPeerPost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e ProvisioningEndpoint) handleNewPeerPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var req models.ProvisioningRequest - err := c.BindJSON(&req) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &req); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(req); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - peer, err := e.provisioning.NewPeer(ctx, req) + peer, err := e.provisioning.NewPeer(r.Context(), req) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewPeer(peer)) + respond.JSON(w, http.StatusOK, models.NewPeer(peer)) } } diff --git a/internal/app/api/v1/handlers/endpoint_user.go b/internal/app/api/v1/handlers/endpoint_user.go index 9279487..28902eb 100644 --- a/internal/app/api/v1/handlers/endpoint_user.go +++ b/internal/app/api/v1/handlers/endpoint_user.go @@ -4,8 +4,10 @@ import ( "context" "net/http" - "github.com/gin-gonic/gin" + "github.com/go-pkgz/routegroup" + "github.com/h44z/wg-portal/internal/app/api/core/request" + "github.com/h44z/wg-portal/internal/app/api/core/respond" "github.com/h44z/wg-portal/internal/app/api/v1/models" "github.com/h44z/wg-portal/internal/domain" ) @@ -19,12 +21,20 @@ type UserService interface { } type UserEndpoint struct { - users UserService + users UserService + authenticator Authenticator + validator Validator } -func NewUserEndpoint(userService UserService) *UserEndpoint { +func NewUserEndpoint( + authenticator Authenticator, + validator Validator, + userService UserService, +) *UserEndpoint { return &UserEndpoint{ - users: userService, + authenticator: authenticator, + validator: validator, + users: userService, } } @@ -32,14 +42,15 @@ func (e UserEndpoint) GetName() string { return "UserEndpoint" } -func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) { - apiGroup := g.Group("/user", authenticator.LoggedIn()) +func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) { + apiGroup := g.Mount("/user") + apiGroup.Use(e.authenticator.LoggedIn()) - apiGroup.GET("/all", authenticator.LoggedIn(ScopeAdmin), e.handleAllGet()) - apiGroup.GET("/by-id/:id", authenticator.LoggedIn(), e.handleByIdGet()) - apiGroup.POST("/new", authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost()) - apiGroup.PUT("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleUpdatePut()) - apiGroup.DELETE("/by-id/:id", authenticator.LoggedIn(ScopeAdmin), e.handleDelete()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("GET /all", e.handleAllGet()) + apiGroup.HandleFunc("GET /by-id/{id}", e.handleByIdGet()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("POST /new", e.handleCreatePost()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("PUT /by-id/{id}", e.handleUpdatePut()) + apiGroup.With(e.authenticator.LoggedIn(ScopeAdmin)).HandleFunc("DELETE /by-id/{id}", e.handleDelete()) } // handleAllGet returns a gorm Handler function. @@ -53,17 +64,16 @@ func (e UserEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti // @Failure 500 {object} models.Error // @Router /user/all [get] // @Security BasicAuth -func (e UserEndpoint) handleAllGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - users, err := e.users.GetAll(ctx) +func (e UserEndpoint) handleAllGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + users, err := e.users.GetAll(r.Context()) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUsers(users)) + respond.JSON(w, http.StatusOK, models.NewUsers(users)) } } @@ -82,23 +92,23 @@ func (e UserEndpoint) handleAllGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /user/by-id/{id} [get] // @Security BasicAuth -func (e UserEndpoint) handleByIdGet() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e UserEndpoint) handleByIdGet() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } - user, err := e.users.GetById(ctx, domain.UserIdentifier(id)) + user, err := e.users.GetById(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUser(user, true)) + respond.JSON(w, http.StatusOK, models.NewUser(user, true)) } } @@ -118,24 +128,26 @@ func (e UserEndpoint) handleByIdGet() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /user/new [post] // @Security BasicAuth -func (e UserEndpoint) handleCreatePost() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - +func (e UserEndpoint) handleCreatePost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var user models.User - err := c.BindJSON(&user) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &user); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(user); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - newUser, err := e.users.Create(ctx, models.NewDomainUser(&user)) + newUser, err := e.users.Create(r.Context(), models.NewDomainUser(&user)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUser(newUser, true)) + respond.JSON(w, http.StatusOK, models.NewUser(newUser, true)) } } @@ -156,30 +168,33 @@ func (e UserEndpoint) handleCreatePost() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /user/by-id/{id} [put] // @Security BasicAuth -func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e UserEndpoint) handleUpdatePut() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } var user models.User - err := c.BindJSON(&user) - if err != nil { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + if err := request.BodyJson(r, &user); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + if err := e.validator.Struct(user); err != nil { + respond.JSON(w, http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: err.Error()}) return } - updateUser, err := e.users.Update(ctx, domain.UserIdentifier(id), models.NewDomainUser(&user)) + updateUser, err := e.users.Update(r.Context(), domain.UserIdentifier(id), models.NewDomainUser(&user)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.JSON(http.StatusOK, models.NewUser(updateUser, true)) + respond.JSON(w, http.StatusOK, models.NewUser(updateUser, true)) } } @@ -198,22 +213,22 @@ func (e UserEndpoint) handleUpdatePut() gin.HandlerFunc { // @Failure 500 {object} models.Error // @Router /user/by-id/{id} [delete] // @Security BasicAuth -func (e UserEndpoint) handleDelete() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := domain.SetUserInfoFromGin(c) - - id := c.Param("id") +func (e UserEndpoint) handleDelete() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := request.Path(r, "id") if id == "" { - c.JSON(http.StatusBadRequest, models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) + respond.JSON(w, http.StatusBadRequest, + models.Error{Code: http.StatusBadRequest, Message: "missing user id"}) return } - err := e.users.Delete(ctx, domain.UserIdentifier(id)) + err := e.users.Delete(r.Context(), domain.UserIdentifier(id)) if err != nil { - c.JSON(ParseServiceError(err)) + status, model := ParseServiceError(err) + respond.JSON(w, status, model) return } - c.Status(http.StatusNoContent) + respond.Status(w, http.StatusNoContent) } } diff --git a/internal/app/api/v1/handlers/middleware_authentication.go b/internal/app/api/v1/handlers/middleware_authentication.go deleted file mode 100644 index 6a91bc3..0000000 --- a/internal/app/api/v1/handlers/middleware_authentication.go +++ /dev/null @@ -1,93 +0,0 @@ -package handlers - -import ( - "context" - "net/http" - - "github.com/gin-gonic/gin" - - "github.com/h44z/wg-portal/internal/app/api/v0/model" - "github.com/h44z/wg-portal/internal/domain" -) - -type Scope string - -const ( - ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes -) - -type UserSource interface { - GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) -} - -type authenticationHandler struct { - userSource UserSource -} - -// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. -func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc { - return func(c *gin.Context) { - username, password, ok := c.Request.BasicAuth() - if !ok || username == "" || password == "" { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"}) - return - } - - // check if user exists in DB - - ctx := domain.SetUserInfo(c.Request.Context(), domain.SystemAdminContextUserInfo()) - user, err := h.userSource.GetUser(ctx, domain.UserIdentifier(username)) - if err != nil { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"}) - return - } - - // validate API token - if err := user.CheckApiToken(password); err != nil { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"}) - return - } - - if !UserHasScopes(user, scopes...) { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) - return - } - - c.Set(domain.CtxUserInfo, &domain.ContextUserInfo{ - Id: user.Identifier, - IsAdmin: user.IsAdmin, - }) - - // Continue down the chain to Handler etc - c.Next() - } -} - -func UserHasScopes(user *domain.User, scopes ...Scope) bool { - // No scopes give, so the check should succeed - if len(scopes) == 0 { - return true - } - - // check if user has admin scope - if user.IsAdmin { - return true - } - - // Check if admin scope is required - for _, scope := range scopes { - if scope == ScopeAdmin { - return false - } - } - - return true -} diff --git a/internal/app/api/v1/handlers/web_authentication.go b/internal/app/api/v1/handlers/web_authentication.go new file mode 100644 index 0000000..952f978 --- /dev/null +++ b/internal/app/api/v1/handlers/web_authentication.go @@ -0,0 +1,101 @@ +package handlers + +import ( + "context" + "net/http" + + "github.com/h44z/wg-portal/internal/app/api/core/respond" + "github.com/h44z/wg-portal/internal/app/api/v0/model" + "github.com/h44z/wg-portal/internal/domain" +) + +type Scope string + +const ( + ScopeAdmin Scope = "ADMIN" // Admin scope contains all other scopes +) + +type UserAuthenticator interface { + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) +} + +type AuthenticationHandler struct { + authenticator UserAuthenticator +} + +func NewAuthenticationHandler(authenticator UserAuthenticator) AuthenticationHandler { + return AuthenticationHandler{ + authenticator: authenticator, + } +} + +// LoggedIn checks if a user is logged in. If scopes are given, they are validated as well. +func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username == "" || password == "" { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "missing credentials"}) + return + } + + // check if user exists in DB + + ctx := domain.SetUserInfo(r.Context(), domain.SystemAdminContextUserInfo()) + user, err := h.authenticator.GetUser(ctx, domain.UserIdentifier(username)) + if err != nil { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"}) + return + } + + // validate API token + if err := user.CheckApiToken(password); err != nil { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusUnauthorized, + model.Error{Code: http.StatusUnauthorized, Message: "invalid credentials"}) + return + } + + if !UserHasScopes(user, scopes...) { + // Abort the request with the appropriate error code + respond.JSON(w, http.StatusForbidden, + model.Error{Code: http.StatusForbidden, Message: "not enough permissions"}) + return + } + + ctx = context.WithValue(r.Context(), domain.CtxUserInfo, &domain.ContextUserInfo{ + Id: user.Identifier, + IsAdmin: user.IsAdmin, + }) + r = r.WithContext(ctx) + + // Continue down the chain to Handler etc + next.ServeHTTP(w, r) + }) + } +} + +func UserHasScopes(user *domain.User, scopes ...Scope) bool { + // No scopes give, so the check should succeed + if len(scopes) == 0 { + return true + } + + // check if user has admin scope + if user.IsAdmin { + return true + } + + // Check if admin scope is required + for _, scope := range scopes { + if scope == ScopeAdmin { + return false + } + } + + return true +} diff --git a/internal/config/web.go b/internal/config/web.go index 26bfb6f..5327d6c 100644 --- a/internal/config/web.go +++ b/internal/config/web.go @@ -3,6 +3,8 @@ package config type WebConfig struct { // RequestLogging enables logging of all HTTP requests. RequestLogging bool `yaml:"request_logging"` + // ExposeHostInfo sets whether the host information should be exposed in a response header. + ExposeHostInfo bool `yaml:"expose_host_info"` // ExternalUrl is the URL where a client can access WireGuard Portal. // This is used for the callback URL of the OAuth providers. ExternalUrl string `yaml:"external_url"` diff --git a/internal/domain/context.go b/internal/domain/context.go index f36706c..1c734a3 100644 --- a/internal/domain/context.go +++ b/internal/domain/context.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "log/slog" - - "github.com/gin-gonic/gin" ) const CtxUserInfo = "userInfo" @@ -47,21 +45,6 @@ func SystemAdminContextUserInfo() *ContextUserInfo { } } -// SetUserInfoFromGin sets the user info from the gin context to the request context. -func SetUserInfoFromGin(c *gin.Context) context.Context { - ginUserInfo, exists := c.Get(CtxUserInfo) - - info := DefaultContextUserInfo() - if exists { - if ginInfo, ok := ginUserInfo.(*ContextUserInfo); ok { - info = ginInfo - } - } - - ctx := SetUserInfo(c.Request.Context(), info) - return ctx -} - // SetUserInfo sets the user info in the context. func SetUserInfo(ctx context.Context, info *ContextUserInfo) context.Context { ctx = context.WithValue(ctx, CtxUserInfo, info)