From 25a37ba90d0b38c7cad527dce3ef9fe07a486644 Mon Sep 17 00:00:00 2001 From: Gustavo Maronato Date: Sat, 9 Mar 2024 04:44:22 -0500 Subject: [PATCH] added backend OIDC support --- README.md | 2 +- cmd/dev/dev.go | 9 ++ cmd/serve/serve.go | 10 +++ cmd/shared/shared.go | 5 ++ go.mod | 20 +++-- go.sum | 55 +++++++++--- internal/config/config.go | 62 ++++++++++--- internal/config/config_test.go | 127 ++++++++++++++++++++------- internal/errs/errors.go | 6 ++ internal/server/api/handler.go | 5 ++ internal/server/oidc/handler.go | 119 +++++++++++++++++++++++++ internal/server/oidc/router.go | 16 ++++ internal/service/oidc/oidcservice.go | 100 +++++++++++++++++++++ internal/service/user/userservice.go | 19 ++-- internal/storage/models/user.go | 14 +++ internal/util/oidc/oidc.go | 44 ++++++++++ 16 files changed, 541 insertions(+), 72 deletions(-) create mode 100644 internal/server/oidc/handler.go create mode 100644 internal/server/oidc/router.go create mode 100644 internal/service/oidc/oidcservice.go create mode 100644 internal/util/oidc/oidc.go diff --git a/README.md b/README.md index d4d8545..7c51c0b 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ The docker image has health checks enabled by default. They are done via a subco You can use the same endpoint to perform health checks manually. > Important! If you use Docker or Compose, prefer defining the `GOSHORT_HOST` and `GOSHORT_PORT` configs as environment variables instead of command line flags. -> They are used by the healthcheck command to find the address to call, so it's important that they are sync'd between the main process and the healthcheck. +> They are used by the healthcheck command to find the address to call, so it's important that they are synced between the main process and the healthcheck. ## Building and running from source You'll need: diff --git a/cmd/dev/dev.go b/cmd/dev/dev.go index ea09572..f316643 100644 --- a/cmd/dev/dev.go +++ b/cmd/dev/dev.go @@ -16,8 +16,10 @@ import ( devuiserver "git.maronato.dev/maronato/goshort/internal/server/devui" healthcheckserver "git.maronato.dev/maronato/goshort/internal/server/healthcheck" authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth" + oidcserver "git.maronato.dev/maronato/goshort/internal/server/oidc" shortserver "git.maronato.dev/maronato/goshort/internal/server/short" staticssterver "git.maronato.dev/maronato/goshort/internal/server/static" + oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc" shortservice "git.maronato.dev/maronato/goshort/internal/service/short" shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog" tokenservice "git.maronato.dev/maronato/goshort/internal/service/token" @@ -102,6 +104,7 @@ func serveAPI(ctx context.Context, cfg *config.Config) error { userService := userservice.NewUserService(cfg, storage) tokenService := tokenservice.NewTokenService(storage) shortLogService := shortlogservice.NewShortLogService(storage) + oidcService := oidcservice.NewOIDCService(ctx, cfg) // Start short log worker stopWorker, _ := shortLogService.StartWorker(ctx) @@ -112,12 +115,14 @@ func serveAPI(ctx context.Context, cfg *config.Config) error { shortHandler := shortserver.NewShortHandler(shortService, shortLogService) healthcheckHandler := healthcheckserver.NewHealthcheckHandler(storage) docsHandler := staticssterver.NewStaticHandler(cfg, "/api/docs", docs.Assets()) + oidcHandler := oidcserver.NewOIDCHandler(oidcService, userService) // Create routers apiRouter := apiserver.NewAPIRouter(apiHandler) shortRouter := shortserver.NewShortRouter(shortHandler) healthcheckRouter := healthcheckserver.NewHealthcheckRouter(healthcheckHandler) docsRouter := staticssterver.NewStaticRouter(docsHandler) + oidcRouter := oidcserver.NewOIDCRouter(oidcHandler) // Create the root URL handler by chaining short and NotFound handlers chainedRouter := handlerutils.NewChainedHandler(shortRouter, http.NotFoundHandler()) @@ -137,6 +142,10 @@ func serveAPI(ctx context.Context, cfg *config.Config) error { r.Mount("/docs", docsRouter) r.Mount("/", apiRouter) }) + // Register OIDC routes if the service is enabled + if oidcService != nil { + srv.Mux.Mount("/oidc", oidcRouter) + } srv.Mux.Mount("/healthz", healthcheckRouter) srv.Mux.Mount("/", chainedRouter) diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index f5650dd..3c38f6e 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -14,8 +14,10 @@ import ( apiserver "git.maronato.dev/maronato/goshort/internal/server/api" healthcheckserver "git.maronato.dev/maronato/goshort/internal/server/healthcheck" authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth" + oidcserver "git.maronato.dev/maronato/goshort/internal/server/oidc" shortserver "git.maronato.dev/maronato/goshort/internal/server/short" staticssterver "git.maronato.dev/maronato/goshort/internal/server/static" + oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc" shortservice "git.maronato.dev/maronato/goshort/internal/service/short" shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog" tokenservice "git.maronato.dev/maronato/goshort/internal/service/token" @@ -75,6 +77,7 @@ func exec(ctx context.Context, cfg *config.Config) error { userService := userservice.NewUserService(cfg, storage) tokenService := tokenservice.NewTokenService(storage) shortLogService := shortlogservice.NewShortLogService(storage) + oidcService := oidcservice.NewOIDCService(ctx, cfg) // Start short log worker stopWorker, _ := shortLogService.StartWorker(ctx) @@ -86,6 +89,7 @@ func exec(ctx context.Context, cfg *config.Config) error { staticHandler := staticssterver.NewStaticHandler(cfg, "/", frontend.Assets()) healthcheckHandler := healthcheckserver.NewHealthcheckHandler(storage) docsHandler := staticssterver.NewStaticHandler(cfg, "/api/docs", docs.Assets()) + oidcHandler := oidcserver.NewOIDCHandler(oidcService, userService) // Create routers apiRouter := apiserver.NewAPIRouter(apiHandler) @@ -93,6 +97,7 @@ func exec(ctx context.Context, cfg *config.Config) error { staticRouter := staticssterver.NewStaticRouter(staticHandler) healthcheckRouter := healthcheckserver.NewHealthcheckRouter(healthcheckHandler) docsRouter := staticssterver.NewStaticRouter(docsHandler) + oidcRouter := oidcserver.NewOIDCRouter(oidcHandler) // Create the root URL handler by chaining short and static routers chainedRouter := handlerutils.NewChainedHandler(shortRouter, staticRouter) @@ -103,6 +108,11 @@ func exec(ctx context.Context, cfg *config.Config) error { r.Mount("/", apiRouter) r.Mount("/docs", docsRouter) }) + // Register OIDC routes if the service is enabled + if oidcService != nil { + srv.Mux.Mount("/oidc", oidcRouter) + } + srv.Mux.Mount("/healthz", healthcheckRouter) srv.Mux.Mount("/", chainedRouter) diff --git a/cmd/shared/shared.go b/cmd/shared/shared.go index 248d2d6..0cf05e8 100644 --- a/cmd/shared/shared.go +++ b/cmd/shared/shared.go @@ -66,6 +66,11 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) { fs.StringVar(&cfg.DBURL, "db", config.DefaultDBURL, "database connection string or sqlite db path") fs.DurationVar(&cfg.SessionDuration, "session-duration", config.DefaultSessionDuration, "session duration") + fs.BoolVar(&cfg.DisableCredentialsLogin, "disable-credentials-login", config.DefaultDisableCredentialsLogin, "disable credentials login") + fs.StringVar(&cfg.OIDCIssuerURL, "oidc-issuer-url", "", "OIDC issuer URL") + fs.StringVar(&cfg.OIDCClientID, "oidc-client-id", "", "OIDC client ID") + fs.StringVar(&cfg.OIDCClientSecret, "oidc-client-secret", "", "OIDC client secret") + fs.StringVar(&cfg.OIDCRedirectURL, "oidc-redirect-url", "", "OIDC redirect URL") } // InitStorage initializes the storage depending on the config. diff --git a/go.mod b/go.mod index 3a77f1b..5bf071d 100644 --- a/go.mod +++ b/go.mod @@ -14,18 +14,26 @@ require ( github.com/uptrace/bun/extra/bundebug v1.1.16 github.com/uptrace/bun/extra/bunotel v1.1.16 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 + go.opentelemetry.io/contrib/propagators/autoprop v0.44.0 go.opentelemetry.io/otel v1.18.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0 - go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0 go.opentelemetry.io/otel/sdk v1.18.0 - golang.org/x/crypto v0.12.0 + go.opentelemetry.io/otel/trace v1.18.0 + golang.org/x/crypto v0.14.0 golang.org/x/sync v0.3.0 modernc.org/sqlite v1.25.0 ) +require ( + github.com/go-jose/go-jose/v3 v3.0.1 // indirect + golang.org/x/oauth2 v0.13.0 + google.golang.org/appengine v1.6.8 // indirect +) + require ( github.com/ajg/form v1.5.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect + github.com/coreos/go-oidc/v3 v3.9.0 github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.15.0 // indirect @@ -46,20 +54,18 @@ require ( github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - go.opentelemetry.io/contrib/propagators/autoprop v0.44.0 // indirect go.opentelemetry.io/contrib/propagators/aws v1.19.0 // indirect go.opentelemetry.io/contrib/propagators/b3 v1.19.0 // indirect go.opentelemetry.io/contrib/propagators/jaeger v1.19.0 // indirect go.opentelemetry.io/contrib/propagators/ot v1.19.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 // indirect go.opentelemetry.io/otel/metric v1.18.0 // indirect - go.opentelemetry.io/otel/trace v1.18.0 // indirect go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.14.0 // indirect - golang.org/x/sys v0.12.0 // indirect - golang.org/x/text v0.12.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.12.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect diff --git a/go.sum b/go.sum index 8beda49..bcd6095 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/alexedwards/scs/v2 v2.5.1 h1:EhAz3Kb3OSQzD8T+Ub23fKsiuvE0GzbF5Lgn0uTw github.com/alexedwards/scs/v2 v2.5.1/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/coreos/go-oidc/v3 v3.9.0 h1:0J/ogVOd4y8P0f0xUh8l9t07xRP/d8tccvjHl2dcsSo= +github.com/coreos/go-oidc/v3 v3.9.0/go.mod h1:rTKz2PYwftcrtoCzV5g5kvfJoWcm0Mk8AF8y1iAQro4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,6 +21,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -27,8 +31,10 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE= github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -82,6 +88,7 @@ github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9 github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 h1:KfYpVmrjI7JuToy5k8XV3nkapjWx48k4E4JOtVstzQI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0/go.mod h1:SeQhzAEccGVZVEy7aH87Nh0km+utSpo1pTv6eMMop48= go.opentelemetry.io/contrib/propagators/autoprop v0.44.0 h1:HgXKc1D1PrpsYKdO8fc5XuyVmYFxcY2mLVAHq8XBZMU= @@ -100,8 +107,6 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 h1:IAtl+7gua134xcV3Nie go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0/go.mod h1:w+pXobnBzh95MNIkeIuAKcHe/Uu/CX2PKIvBP6ipKRA= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0 h1:yE32ay7mJG2leczfREEhoW3VfSZIvHaB+gvVo1o8DQ8= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0/go.mod h1:G17FHPDLt74bCI7tJ4CMitEk4BXTYG4FW6XUpkPBXa4= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0 h1:hSWWvDjXHVLq9DkmB+77fl8v7+t+yYiS+eNkiplDK54= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0/go.mod h1:zG7KQql1WjZCaUJd+L/ReSYx4bjbYJxg5ws9ws+mYes= go.opentelemetry.io/otel/metric v1.18.0 h1:JwVzw94UYmbx3ej++CwLUQZxEODDj/pOuTCvzhtRrSQ= go.opentelemetry.io/otel/metric v1.18.0/go.mod h1:nNSpsVDjWGfb7chbRLUNW+PBNdcSTHD4Uu5pfFMOI0k= go.opentelemetry.io/otel/sdk v1.18.0 h1:e3bAB0wB3MljH38sHzpV/qWrOTCFrdZF2ct9F8rBkcY= @@ -114,23 +119,53 @@ go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY= +golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g= google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0= google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw= diff --git a/internal/config/config.go b/internal/config/config.go index 8e51d92..8d0c834 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,6 +54,8 @@ const ( DefaultVerbose = VerboseLevelInfo // DefaultQuiet is the default quiet mode. DefaultQuiet = false + // DefaultDisableCredentialsLogin is the default value for diable credential login. + DefaultDisableCredentialsLogin = false ) const ( @@ -77,7 +79,7 @@ const ( // Config defines the default configuration for the backend. type Config struct { // Prod is a flag that indicates if the server is running in production mode. - Prod bool + Prod bool `json:"prod"` // Debug is a flag that indicates if the server is running in debug mode. Debug bool // Host is the host to listen on. @@ -93,26 +95,37 @@ type Config struct { // SessionDuration is the duration of the session. SessionDuration time.Duration // DisableRegistration defines whether or not registration are disabled. - DisableRegistration bool + DisableRegistration bool `json:"disableRegistration"` // Verbose defines the verbosity level. - Verbose VerboseLevel + Verbose VerboseLevel `json:"verbose"` // Quiet defines whether or not the server should be quiet. Quiet bool + // DisableCredentialsLogin defines whether or not the server should disable credential login. + DisableCredentialsLogin bool `json:"disableCredentialLogin"` + // OIDCIssuerURL is the URL of the OIDC issuer. + OIDCIssuerURL string + // OIDCClientID is the client ID for OIDC. + OIDCClientID string + // OIDCClientSecret is the client secret for OIDC. + OIDCClientSecret string + // OIDCRedirectURL is the redirect URL for OIDC. + OIDCRedirectURL string } func NewConfig() *Config { return &Config{ - Prod: DefaultProd, - Debug: DefaultDebug, - Host: DefaultHost, - Port: DefaultPort, - UIPort: DefaultUIPort, - DBType: DefaultDBType, - DBURL: DefaultDBURL, - SessionDuration: DefaultSessionDuration, - DisableRegistration: DefaultDisableRegistration, - Verbose: DefaultVerbose, - Quiet: DefaultQuiet, + Prod: DefaultProd, + Debug: DefaultDebug, + Host: DefaultHost, + Port: DefaultPort, + UIPort: DefaultUIPort, + DBType: DefaultDBType, + DBURL: DefaultDBURL, + SessionDuration: DefaultSessionDuration, + DisableRegistration: DefaultDisableRegistration, + Verbose: DefaultVerbose, + Quiet: DefaultQuiet, + DisableCredentialsLogin: DefaultDisableCredentialsLogin, } } @@ -165,6 +178,27 @@ func Validate(cfg *Config) error { cfg.Verbose = VerboseLevelQuiet } + // Either OIDC or credential login has to be enabled. + if cfg.DisableCredentialsLogin && cfg.OIDCIssuerURL == "" { + return errs.Errorf("either OIDC or credential login has to be enabled", errs.ErrInvalidConfig) + } + + if cfg.OIDCIssuerURL != "" { + // If OIDC is enabled, all OIDC fields have to be set. + if cfg.OIDCClientID == "" || cfg.OIDCClientSecret == "" || cfg.OIDCRedirectURL == "" { + return errs.Errorf("all OIDC fields have to be set", errs.ErrInvalidConfig) + } + + // OIDC fields must be valid + if _, err := url.ParseRequestURI(cfg.OIDCIssuerURL); err != nil { + return errs.Errorf(fmt.Sprintf("invalid OIDC issuer URL: %s", err), errs.ErrInvalidConfig) + } + + if _, err := url.ParseRequestURI(cfg.OIDCRedirectURL); err != nil { + return errs.Errorf(fmt.Sprintf("invalid OIDC redirect URL: %s", err), errs.ErrInvalidConfig) + } + } + return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 01dd776..205fc10 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,17 +3,17 @@ package config_test import ( "testing" - . "git.maronato.dev/maronato/goshort/internal/config" + config "git.maronato.dev/maronato/goshort/internal/config" "github.com/stretchr/testify/assert" ) func TestGetDBTypeList(t *testing.T) { dbTypeList := []string{ - DBTypeMemory, - DBTypeSQLite, + config.DBTypeMemory, + config.DBTypeSQLite, } - results := GetDBTypeList() + results := config.GetDBTypeList() assert.ElementsMatch(t, dbTypeList, results, "DBTypeList should be equal to the list of available DB types") } @@ -21,17 +21,17 @@ func TestGetDBTypeList(t *testing.T) { func TestValidate(t *testing.T) { type expecttest struct { Debug *bool - Verbose *VerboseLevel + Verbose *config.VerboseLevel } debugTrue := true debugFalse := false - verboseQuiet := VerboseLevelQuiet - verboseInfo := VerboseLevelInfo + verboseQuiet := config.VerboseLevelQuiet + verboseInfo := config.VerboseLevelInfo type test struct { name string - config *Config + config *config.Config wantErr bool expected expecttest } @@ -39,18 +39,23 @@ func TestValidate(t *testing.T) { tests := []test{ { name: "Valid config", - config: &Config{ - Prod: false, - DBType: DBTypeSQLite, - DBURL: "goshort.db", - Port: "8080", - Host: "localhost", - UIPort: "3000", - Debug: false, - SessionDuration: 7, - DisableRegistration: false, - Verbose: 0, - Quiet: false, + config: &config.Config{ + Prod: false, + DBType: config.DBTypeSQLite, + DBURL: "goshort.db", + Port: "8080", + Host: "localhost", + UIPort: "3000", + Debug: false, + SessionDuration: 7, + DisableRegistration: false, + Verbose: 0, + Quiet: false, + DisableCredentialsLogin: false, + OIDCIssuerURL: "https://example.com", + OIDCClientID: "client-id", + OIDCClientSecret: "client-secret", + OIDCRedirectURL: "https://example.com/redirect", }, expected: expecttest{ Debug: &debugFalse, @@ -60,12 +65,12 @@ func TestValidate(t *testing.T) { }, { name: "Minimum config", - config: &Config{}, + config: &config.Config{}, wantErr: false, }, { name: "Invalid Host", - config: &Config{ + config: &config.Config{ Host: "invalid host", Port: "8080", }, @@ -73,7 +78,7 @@ func TestValidate(t *testing.T) { }, { name: "Invalid Port", - config: &Config{ + config: &config.Config{ Host: "localhost", Port: "invalid port", }, @@ -81,7 +86,7 @@ func TestValidate(t *testing.T) { }, { name: "Invalid UI Port", - config: &Config{ + config: &config.Config{ Host: "localhost", Port: "8080", UIPort: "invalid port", @@ -90,35 +95,35 @@ func TestValidate(t *testing.T) { }, { name: "Valid DB Type", - config: &Config{ - DBType: DBTypeMemory, + config: &config.Config{ + DBType: config.DBTypeMemory, }, wantErr: false, }, { name: "Invalid DB Type", - config: &Config{ + config: &config.Config{ DBType: "invalid db type", }, wantErr: true, }, { name: "Invalid Verbose Level (negative)", - config: &Config{ + config: &config.Config{ Verbose: -1, }, wantErr: true, }, { name: "Invalid Verbose Level (too high)", - config: &Config{ + config: &config.Config{ Verbose: 3, }, wantErr: true, }, { name: "Quiet mode sets Verbose to -1", - config: &Config{ + config: &config.Config{ Quiet: true, }, expected: expecttest{ @@ -127,18 +132,74 @@ func TestValidate(t *testing.T) { }, { name: "Verbose high sets Debug to true", - config: &Config{ - Verbose: VerboseLevelDebug, + config: &config.Config{ + Verbose: config.VerboseLevelDebug, }, expected: expecttest{ Debug: &debugTrue, }, }, + { + name: "Fails of credentials are disabled and oidc is not configured", + config: &config.Config{ + DisableCredentialsLogin: true, + }, + wantErr: true, + }, + { + name: "Succeeds of credentials are disabled and oidc is configured", + config: &config.Config{ + DisableCredentialsLogin: true, + OIDCIssuerURL: "https://example.com", + OIDCClientID: "client-id", + OIDCClientSecret: "client-secret", + OIDCRedirectURL: "https://example.com/redirect", + }, + wantErr: false, + }, + { + name: "Succeeds if both credentials and oidc are enabled", + config: &config.Config{ + OIDCIssuerURL: "https://example.com", + OIDCClientID: "client-id", + OIDCClientSecret: "client-secret", + OIDCRedirectURL: "https://example.com/redirect", + }, + wantErr: false, + }, + { + name: "Fails if issuer is not a valid URL", + config: &config.Config{ + OIDCIssuerURL: "invalid-url", + OIDCClientID: "client-id", + OIDCClientSecret: "client-secret", + OIDCRedirectURL: "https://example.com/redirect", + }, + wantErr: true, + }, + { + name: "Fails if redirect is not a valid URL", + config: &config.Config{ + OIDCIssuerURL: "https://example.com", + OIDCClientID: "client-id", + OIDCClientSecret: "client-secret", + OIDCRedirectURL: "invalid-url", + }, + wantErr: true, + }, + { + name: "Fails if oidc is partially configured", + config: &config.Config{ + OIDCIssuerURL: "https://example.com", + OIDCClientID: "client-id", + }, + wantErr: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - err := Validate(tc.config) + err := config.Validate(tc.config) if tc.wantErr { assert.Error(t, err, "Validate should return an error") diff --git a/internal/errs/errors.go b/internal/errs/errors.go index 448ffd3..8d864a2 100644 --- a/internal/errs/errors.go +++ b/internal/errs/errors.go @@ -48,6 +48,12 @@ var ( ErrStorageNotStarted = errors.New("storage not started") // ErrStorageStarted. ErrStorageStarted = errors.New("storage already started") + // ErrOIDCStateCookieMissing. + ErrOIDCStateCookieMissing = errors.New("OIDC state cookie missing") + // ErrOIDCStateCookieInvalid. + ErrOIDCStateCookieInvalid = errors.New("OIDC state cookie invalid") + // ErrCredentialsLoginDisabled. + ErrCredentialsLoginDisabled = errors.New("credentials login disabled") ) func Errorf(msg string, err error) error { diff --git a/internal/server/api/handler.go b/internal/server/api/handler.go index e8afb28..41e2a14 100644 --- a/internal/server/api/handler.go +++ b/internal/server/api/handler.go @@ -76,6 +76,8 @@ func (h *APIHandler) DeleteMe(w http.ResponseWriter, r *http.Request) { err := authmiddleware.DeleteAllUserSessions(ctx, user) if err != nil { server.RenderServerError(w, r, err) + + return } span.AddEvent("deleted all user sessions") @@ -128,6 +130,9 @@ func (h *APIHandler) Login(w http.ResponseWriter, r *http.Request) { case errors.Is(err, errs.ErrInvalidUser): // If the request was invalid, return bad request server.RenderBadRequest(w, r, err) + case errors.Is(err, errs.ErrCredentialsLoginDisabled): + // If credentials login is disabled, return forbidden + server.RenderForbidden(w, r) default: // Else, server error server.RenderServerError(w, r, err) diff --git a/internal/server/oidc/handler.go b/internal/server/oidc/handler.go new file mode 100644 index 0000000..864f34b --- /dev/null +++ b/internal/server/oidc/handler.go @@ -0,0 +1,119 @@ +package oidcserver + +import ( + "errors" + "net/http" + + "git.maronato.dev/maronato/goshort/internal/errs" + "git.maronato.dev/maronato/goshort/internal/server" + authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth" + oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc" + userservice "git.maronato.dev/maronato/goshort/internal/service/user" + "git.maronato.dev/maronato/goshort/internal/storage/models" + "git.maronato.dev/maronato/goshort/internal/util/logging" + oidcutil "git.maronato.dev/maronato/goshort/internal/util/oidc" + "git.maronato.dev/maronato/goshort/internal/util/tracing" + "go.opentelemetry.io/otel/attribute" +) + +type OIDCHandler struct { + oidc *oidcservice.OIDCService + user *userservice.UserService +} + +func NewOIDCHandler(oidc *oidcservice.OIDCService, user *userservice.UserService) *OIDCHandler { + return &OIDCHandler{ + oidc: oidc, + user: user, + } +} + +func (h *OIDCHandler) Redirect(w http.ResponseWriter, r *http.Request) { + _, span := tracing.StartSpan(r.Context(), "oidc.Redirect") + defer span.End() + // Generate redirect parameters + params := h.oidc.GetRedirectParams() + // Redirect to OIDC provider + oidcutil.DoRedirect(w, r, params) +} + +func (h *OIDCHandler) Callback(w http.ResponseWriter, r *http.Request) { + ctx, span := tracing.StartSpan(r.Context(), "oidc.Callback") + defer span.End() + + l := logging.FromCtx(ctx) + + // Validate the state + err := oidcutil.ValidateRequestState(r) + if err != nil { + server.RenderBadRequest(w, r, err) + + return + } + + span.AddEvent("State validated") + + code := r.URL.Query().Get("code") + + userInfo, err := h.oidc.CallbackExchange(ctx, code) + if err != nil { + server.RenderBadRequest(w, r, err) + + return + } + + span.AddEvent("user info retrieved") + span.SetAttributes(attribute.KeyValue{ + Key: attribute.Key("oidc.user.email"), + Value: attribute.StringValue(userInfo.Email), + }) + + // Check if user exists + user, err := h.user.FindUser(ctx, userInfo.Email) + if err != nil { + // If the user does not exist, create it + if errors.Is(err, errs.ErrUserDoesNotExist) { + // Create the new user and disable their password login. + user = &models.User{ + Username: userInfo.Email, + } + user.SetNoLoginPassword() + + user, err = h.user.CreateUser(ctx, user) + + // Handle errors + if err != nil { + switch { + case errors.Is(err, errs.ErrInvalidUser): + server.RenderBadRequest(w, r, err) + case errors.Is(err, errs.ErrRegistrationDisabled): + l.Debug("failed to create user", "err", err) + + server.RenderForbidden(w, r) + default: + l.Error("failed to create user", "error", err) + server.RenderServerError(w, r, err) + } + } + + span.AddEvent("user created") + } else { + // If the error was something else, render a server error + l.Error("failed to find user", "error", err) + + server.RenderServerError(w, r, err) + + return + } + } + + // Now that user is guaranteed to exist, log them in + span.AddEvent("user found or created") + + authmiddleware.LoginUser(ctx, user, r) + + span.AddEvent("logged in user") + + // Redirect to the home page + http.Redirect(w, r, "/", http.StatusFound) +} diff --git a/internal/server/oidc/router.go b/internal/server/oidc/router.go new file mode 100644 index 0000000..c76ca36 --- /dev/null +++ b/internal/server/oidc/router.go @@ -0,0 +1,16 @@ +package oidcserver + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +func NewOIDCRouter(h *OIDCHandler) http.Handler { + mux := chi.NewRouter() + + mux.Get("/redirect", h.Redirect) + mux.Get("/callback", h.Callback) + + return mux +} diff --git a/internal/service/oidc/oidcservice.go b/internal/service/oidc/oidcservice.go new file mode 100644 index 0000000..bd37b2e --- /dev/null +++ b/internal/service/oidc/oidcservice.go @@ -0,0 +1,100 @@ +package oidcservice + +import ( + "context" + + "git.maronato.dev/maronato/goshort/internal/config" + "git.maronato.dev/maronato/goshort/internal/errs" + randomutil "git.maronato.dev/maronato/goshort/internal/util/random" + "git.maronato.dev/maronato/goshort/internal/util/tracing" + "github.com/coreos/go-oidc/v3/oidc" + "go.opentelemetry.io/otel/attribute" + "golang.org/x/oauth2" +) + +const ( + // StateLength is the length of the state. + StateLength = 16 +) + +// OIDCService is the service that handles OIDC authentication. +type OIDCService struct { + cfg *config.Config + provider *oidc.Provider + oauth2Config *oauth2.Config +} + +// NewOIDCService creates a new OIDC service. +func NewOIDCService(ctx context.Context, cfg *config.Config) *OIDCService { + if cfg.OIDCIssuerURL == "" { + return nil + } + + provider, err := oidc.NewProvider(ctx, cfg.OIDCIssuerURL) + if err != nil { + return nil + } + + oauthConfig := oauth2.Config{ + ClientID: cfg.OIDCClientID, + ClientSecret: cfg.OIDCClientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: cfg.OIDCRedirectURL, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + + return &OIDCService{ + cfg: cfg, + provider: provider, + oauth2Config: &oauthConfig, + } +} + +// RedirectParams contains the parameters to redirect the user to the OIDC provider. +type RedirectParams struct { + State string + URL string +} + +// GetRedirectParams returns the parameters to redirect the user to the OIDC provider. +func (s *OIDCService) GetRedirectParams() *RedirectParams { + state := randomutil.GenerateSecureToken(StateLength) + + url := s.oauth2Config.AuthCodeURL(state) + + return &RedirectParams{ + State: state, + URL: url, + } +} + +// CallbackExchange exchanges the code for a token and returns the user info. +func (s *OIDCService) CallbackExchange(ctx context.Context, code string) (*oidc.UserInfo, error) { + ctx, span := tracing.StartSpan(ctx, "oidcservice.CallbackExchange") + defer span.End() + + span.SetAttributes(attribute.KeyValue{ + Key: attribute.Key("oidc.code"), + Value: attribute.StringValue(code), + }) + + oauth2Token, err := s.oauth2Config.Exchange(ctx, code) + if err != nil { + return nil, errs.Errorf("failed to exchange code for token", err) + } + + span.AddEvent("Code exchanged for token") + + userInfo, err := s.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) + if err != nil { + return nil, errs.Errorf("failed to get user info", err) + } + + span.AddEvent("User info retrieved") + span.SetAttributes(attribute.KeyValue{ + Key: attribute.Key("oidc.user.email"), + Value: attribute.StringValue(userInfo.Email), + }) + + return userInfo, nil +} diff --git a/internal/service/user/userservice.go b/internal/service/user/userservice.go index 0040483..a3924f2 100644 --- a/internal/service/user/userservice.go +++ b/internal/service/user/userservice.go @@ -22,16 +22,18 @@ const ( ) type UserService struct { - db storage.Storage - disableRegistration bool - hasher passwords.PasswordHasher + db storage.Storage + disableRegistration bool + disableCredentialsLogin bool + hasher passwords.PasswordHasher } func NewUserService(cfg *config.Config, db storage.Storage) *UserService { return &UserService{ - db: db, - disableRegistration: cfg.DisableRegistration, - hasher: bcryptpasswords.NewBcryptHasher(), + db: db, + disableRegistration: cfg.DisableRegistration, + disableCredentialsLogin: cfg.DisableCredentialsLogin, + hasher: bcryptpasswords.NewBcryptHasher(), } } @@ -81,6 +83,9 @@ func (s *UserService) DeleteUser(ctx context.Context, user *models.User) error { } func (s *UserService) AuthenticateUser(ctx context.Context, username, password string) (user *models.User, err error) { + if s.disableCredentialsLogin { + return nil, errs.ErrCredentialsLoginDisabled + } // Get user from storage user, err = s.FindUser(ctx, username) if err != nil { @@ -95,7 +100,7 @@ func (s *UserService) AuthenticateUser(ctx context.Context, username, password s } // Try to authenticate - if password == "" || user.GetPasswordHash() == "" { + if password == "" || user.GetPasswordHash() == "" || !user.CanLogin() { return nil, errs.ErrFailedAuthentication } diff --git a/internal/storage/models/user.go b/internal/storage/models/user.go index a71fc23..afd6766 100644 --- a/internal/storage/models/user.go +++ b/internal/storage/models/user.go @@ -7,6 +7,11 @@ import ( "git.maronato.dev/maronato/goshort/internal/util/passwords" ) +const ( + // noLoginPassword is the password used when the user is not allowed to login using a password. + noLoginPassword = "nologin" +) + type User struct { // ID is the user's ID. ID string `json:"id"` @@ -43,7 +48,16 @@ func (u *User) SetPassword(hasher passwords.PasswordHasher, pass string) error { return nil } +// SetNoLoginPassword sets the user's password to a no-login password. +func (u *User) SetNoLoginPassword() { + u.password = noLoginPassword +} + // GetPasswordHash Returns the password hash to be saved in a storage. func (u *User) GetPasswordHash() string { return u.password } + +func (u *User) CanLogin() bool { + return u.password != noLoginPassword && u.password != "" +} diff --git a/internal/util/oidc/oidc.go b/internal/util/oidc/oidc.go new file mode 100644 index 0000000..7c4c8de --- /dev/null +++ b/internal/util/oidc/oidc.go @@ -0,0 +1,44 @@ +package oidcutil + +import ( + "net/http" + "time" + + "git.maronato.dev/maronato/goshort/internal/errs" + oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc" +) + +const ( + stateCookieName = "gs_oidc_state" + stateCookieExpiration = time.Minute * 5 +) + +// SetStateCookie sets the OIDC code parameter as a cookie. +func SetStateCookie(w http.ResponseWriter, r *http.Request, state string) { + c := &http.Cookie{ + Name: stateCookieName, + Value: state, + MaxAge: int(stateCookieExpiration.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + http.SetCookie(w, c) +} + +func ValidateRequestState(r *http.Request) error { + state, err := r.Cookie(stateCookieName) + if err != nil { + return errs.ErrOIDCStateCookieMissing + } + + if r.URL.Query().Get("state") != state.Value { + return errs.ErrOIDCStateCookieInvalid + } + + return nil +} + +func DoRedirect(w http.ResponseWriter, r *http.Request, params *oidcservice.RedirectParams) { + SetStateCookie(w, r, params.State) + http.Redirect(w, r, params.URL, http.StatusFound) +}