added backend OIDC support
This commit is contained in:
parent
c00d6ecca7
commit
25a37ba90d
|
@ -77,7 +77,7 @@ The docker image has health checks enabled by default. They are done via a subco
|
|||
You can use the same endpoint to perform health checks manually.
|
||||
|
||||
> Important! If you use Docker or Compose, prefer defining the `GOSHORT_HOST` and `GOSHORT_PORT` configs as environment variables instead of command line flags.
|
||||
> They are used by the healthcheck command to find the address to call, so it's important that they are sync'd between the main process and the healthcheck.
|
||||
> They are used by the healthcheck command to find the address to call, so it's important that they are synced between the main process and the healthcheck.
|
||||
|
||||
## Building and running from source
|
||||
You'll need:
|
||||
|
|
|
@ -16,8 +16,10 @@ import (
|
|||
devuiserver "git.maronato.dev/maronato/goshort/internal/server/devui"
|
||||
healthcheckserver "git.maronato.dev/maronato/goshort/internal/server/healthcheck"
|
||||
authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth"
|
||||
oidcserver "git.maronato.dev/maronato/goshort/internal/server/oidc"
|
||||
shortserver "git.maronato.dev/maronato/goshort/internal/server/short"
|
||||
staticssterver "git.maronato.dev/maronato/goshort/internal/server/static"
|
||||
oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc"
|
||||
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
|
||||
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
|
||||
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
|
||||
|
@ -102,6 +104,7 @@ func serveAPI(ctx context.Context, cfg *config.Config) error {
|
|||
userService := userservice.NewUserService(cfg, storage)
|
||||
tokenService := tokenservice.NewTokenService(storage)
|
||||
shortLogService := shortlogservice.NewShortLogService(storage)
|
||||
oidcService := oidcservice.NewOIDCService(ctx, cfg)
|
||||
|
||||
// Start short log worker
|
||||
stopWorker, _ := shortLogService.StartWorker(ctx)
|
||||
|
@ -112,12 +115,14 @@ func serveAPI(ctx context.Context, cfg *config.Config) error {
|
|||
shortHandler := shortserver.NewShortHandler(shortService, shortLogService)
|
||||
healthcheckHandler := healthcheckserver.NewHealthcheckHandler(storage)
|
||||
docsHandler := staticssterver.NewStaticHandler(cfg, "/api/docs", docs.Assets())
|
||||
oidcHandler := oidcserver.NewOIDCHandler(oidcService, userService)
|
||||
|
||||
// Create routers
|
||||
apiRouter := apiserver.NewAPIRouter(apiHandler)
|
||||
shortRouter := shortserver.NewShortRouter(shortHandler)
|
||||
healthcheckRouter := healthcheckserver.NewHealthcheckRouter(healthcheckHandler)
|
||||
docsRouter := staticssterver.NewStaticRouter(docsHandler)
|
||||
oidcRouter := oidcserver.NewOIDCRouter(oidcHandler)
|
||||
|
||||
// Create the root URL handler by chaining short and NotFound handlers
|
||||
chainedRouter := handlerutils.NewChainedHandler(shortRouter, http.NotFoundHandler())
|
||||
|
@ -137,6 +142,10 @@ func serveAPI(ctx context.Context, cfg *config.Config) error {
|
|||
r.Mount("/docs", docsRouter)
|
||||
r.Mount("/", apiRouter)
|
||||
})
|
||||
// Register OIDC routes if the service is enabled
|
||||
if oidcService != nil {
|
||||
srv.Mux.Mount("/oidc", oidcRouter)
|
||||
}
|
||||
srv.Mux.Mount("/healthz", healthcheckRouter)
|
||||
srv.Mux.Mount("/", chainedRouter)
|
||||
|
||||
|
|
|
@ -14,8 +14,10 @@ import (
|
|||
apiserver "git.maronato.dev/maronato/goshort/internal/server/api"
|
||||
healthcheckserver "git.maronato.dev/maronato/goshort/internal/server/healthcheck"
|
||||
authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth"
|
||||
oidcserver "git.maronato.dev/maronato/goshort/internal/server/oidc"
|
||||
shortserver "git.maronato.dev/maronato/goshort/internal/server/short"
|
||||
staticssterver "git.maronato.dev/maronato/goshort/internal/server/static"
|
||||
oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc"
|
||||
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
|
||||
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
|
||||
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
|
||||
|
@ -75,6 +77,7 @@ func exec(ctx context.Context, cfg *config.Config) error {
|
|||
userService := userservice.NewUserService(cfg, storage)
|
||||
tokenService := tokenservice.NewTokenService(storage)
|
||||
shortLogService := shortlogservice.NewShortLogService(storage)
|
||||
oidcService := oidcservice.NewOIDCService(ctx, cfg)
|
||||
|
||||
// Start short log worker
|
||||
stopWorker, _ := shortLogService.StartWorker(ctx)
|
||||
|
@ -86,6 +89,7 @@ func exec(ctx context.Context, cfg *config.Config) error {
|
|||
staticHandler := staticssterver.NewStaticHandler(cfg, "/", frontend.Assets())
|
||||
healthcheckHandler := healthcheckserver.NewHealthcheckHandler(storage)
|
||||
docsHandler := staticssterver.NewStaticHandler(cfg, "/api/docs", docs.Assets())
|
||||
oidcHandler := oidcserver.NewOIDCHandler(oidcService, userService)
|
||||
|
||||
// Create routers
|
||||
apiRouter := apiserver.NewAPIRouter(apiHandler)
|
||||
|
@ -93,6 +97,7 @@ func exec(ctx context.Context, cfg *config.Config) error {
|
|||
staticRouter := staticssterver.NewStaticRouter(staticHandler)
|
||||
healthcheckRouter := healthcheckserver.NewHealthcheckRouter(healthcheckHandler)
|
||||
docsRouter := staticssterver.NewStaticRouter(docsHandler)
|
||||
oidcRouter := oidcserver.NewOIDCRouter(oidcHandler)
|
||||
|
||||
// Create the root URL handler by chaining short and static routers
|
||||
chainedRouter := handlerutils.NewChainedHandler(shortRouter, staticRouter)
|
||||
|
@ -103,6 +108,11 @@ func exec(ctx context.Context, cfg *config.Config) error {
|
|||
r.Mount("/", apiRouter)
|
||||
r.Mount("/docs", docsRouter)
|
||||
})
|
||||
// Register OIDC routes if the service is enabled
|
||||
if oidcService != nil {
|
||||
srv.Mux.Mount("/oidc", oidcRouter)
|
||||
}
|
||||
|
||||
srv.Mux.Mount("/healthz", healthcheckRouter)
|
||||
srv.Mux.Mount("/", chainedRouter)
|
||||
|
||||
|
|
|
@ -66,6 +66,11 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) {
|
|||
|
||||
fs.StringVar(&cfg.DBURL, "db", config.DefaultDBURL, "database connection string or sqlite db path")
|
||||
fs.DurationVar(&cfg.SessionDuration, "session-duration", config.DefaultSessionDuration, "session duration")
|
||||
fs.BoolVar(&cfg.DisableCredentialsLogin, "disable-credentials-login", config.DefaultDisableCredentialsLogin, "disable credentials login")
|
||||
fs.StringVar(&cfg.OIDCIssuerURL, "oidc-issuer-url", "", "OIDC issuer URL")
|
||||
fs.StringVar(&cfg.OIDCClientID, "oidc-client-id", "", "OIDC client ID")
|
||||
fs.StringVar(&cfg.OIDCClientSecret, "oidc-client-secret", "", "OIDC client secret")
|
||||
fs.StringVar(&cfg.OIDCRedirectURL, "oidc-redirect-url", "", "OIDC redirect URL")
|
||||
}
|
||||
|
||||
// InitStorage initializes the storage depending on the config.
|
||||
|
|
20
go.mod
20
go.mod
|
@ -14,18 +14,26 @@ require (
|
|||
github.com/uptrace/bun/extra/bundebug v1.1.16
|
||||
github.com/uptrace/bun/extra/bunotel v1.1.16
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.44.0
|
||||
go.opentelemetry.io/otel v1.18.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0
|
||||
go.opentelemetry.io/otel/sdk v1.18.0
|
||||
golang.org/x/crypto v0.12.0
|
||||
go.opentelemetry.io/otel/trace v1.18.0
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/sync v0.3.0
|
||||
modernc.org/sqlite v1.25.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-jose/go-jose/v3 v3.0.1 // indirect
|
||||
golang.org/x/oauth2 v0.13.0
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
|
||||
github.com/coreos/go-oidc/v3 v3.9.0
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fatih/color v1.15.0 // indirect
|
||||
|
@ -46,20 +54,18 @@ require (
|
|||
github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.2 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.44.0 // indirect
|
||||
go.opentelemetry.io/contrib/propagators/aws v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/propagators/b3 v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/propagators/jaeger v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/propagators/ot v1.19.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.18.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.18.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/tools v0.12.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
|
||||
|
|
55
go.sum
55
go.sum
|
@ -4,6 +4,8 @@ github.com/alexedwards/scs/v2 v2.5.1 h1:EhAz3Kb3OSQzD8T+Ub23fKsiuvE0GzbF5Lgn0uTw
|
|||
github.com/alexedwards/scs/v2 v2.5.1/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
|
||||
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
|
||||
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/coreos/go-oidc/v3 v3.9.0 h1:0J/ogVOd4y8P0f0xUh8l9t07xRP/d8tccvjHl2dcsSo=
|
||||
github.com/coreos/go-oidc/v3 v3.9.0/go.mod h1:rTKz2PYwftcrtoCzV5g5kvfJoWcm0Mk8AF8y1iAQro4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
@ -19,6 +21,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
|
|||
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
|
||||
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
||||
github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA=
|
||||
github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
|
@ -27,8 +31,10 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
|
|||
github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE=
|
||||
github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
|
@ -82,6 +88,7 @@ github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9
|
|||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 h1:KfYpVmrjI7JuToy5k8XV3nkapjWx48k4E4JOtVstzQI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0/go.mod h1:SeQhzAEccGVZVEy7aH87Nh0km+utSpo1pTv6eMMop48=
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.44.0 h1:HgXKc1D1PrpsYKdO8fc5XuyVmYFxcY2mLVAHq8XBZMU=
|
||||
|
@ -100,8 +107,6 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 h1:IAtl+7gua134xcV3Nie
|
|||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0/go.mod h1:w+pXobnBzh95MNIkeIuAKcHe/Uu/CX2PKIvBP6ipKRA=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0 h1:yE32ay7mJG2leczfREEhoW3VfSZIvHaB+gvVo1o8DQ8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0/go.mod h1:G17FHPDLt74bCI7tJ4CMitEk4BXTYG4FW6XUpkPBXa4=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0 h1:hSWWvDjXHVLq9DkmB+77fl8v7+t+yYiS+eNkiplDK54=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0/go.mod h1:zG7KQql1WjZCaUJd+L/ReSYx4bjbYJxg5ws9ws+mYes=
|
||||
go.opentelemetry.io/otel/metric v1.18.0 h1:JwVzw94UYmbx3ej++CwLUQZxEODDj/pOuTCvzhtRrSQ=
|
||||
go.opentelemetry.io/otel/metric v1.18.0/go.mod h1:nNSpsVDjWGfb7chbRLUNW+PBNdcSTHD4Uu5pfFMOI0k=
|
||||
go.opentelemetry.io/otel/sdk v1.18.0 h1:e3bAB0wB3MljH38sHzpV/qWrOTCFrdZF2ct9F8rBkcY=
|
||||
|
@ -114,23 +119,53 @@ go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
|||
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY=
|
||||
golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss=
|
||||
golang.org/x/tools v0.12.0/go.mod h1:Sc0INKfu04TlqNoRA1hgpFZbhYXHPr4V5DzpSBTPqQM=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g=
|
||||
google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw=
|
||||
|
|
|
@ -54,6 +54,8 @@ const (
|
|||
DefaultVerbose = VerboseLevelInfo
|
||||
// DefaultQuiet is the default quiet mode.
|
||||
DefaultQuiet = false
|
||||
// DefaultDisableCredentialsLogin is the default value for diable credential login.
|
||||
DefaultDisableCredentialsLogin = false
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -77,7 +79,7 @@ const (
|
|||
// Config defines the default configuration for the backend.
|
||||
type Config struct {
|
||||
// Prod is a flag that indicates if the server is running in production mode.
|
||||
Prod bool
|
||||
Prod bool `json:"prod"`
|
||||
// Debug is a flag that indicates if the server is running in debug mode.
|
||||
Debug bool
|
||||
// Host is the host to listen on.
|
||||
|
@ -93,11 +95,21 @@ type Config struct {
|
|||
// SessionDuration is the duration of the session.
|
||||
SessionDuration time.Duration
|
||||
// DisableRegistration defines whether or not registration are disabled.
|
||||
DisableRegistration bool
|
||||
DisableRegistration bool `json:"disableRegistration"`
|
||||
// Verbose defines the verbosity level.
|
||||
Verbose VerboseLevel
|
||||
Verbose VerboseLevel `json:"verbose"`
|
||||
// Quiet defines whether or not the server should be quiet.
|
||||
Quiet bool
|
||||
// DisableCredentialsLogin defines whether or not the server should disable credential login.
|
||||
DisableCredentialsLogin bool `json:"disableCredentialLogin"`
|
||||
// OIDCIssuerURL is the URL of the OIDC issuer.
|
||||
OIDCIssuerURL string
|
||||
// OIDCClientID is the client ID for OIDC.
|
||||
OIDCClientID string
|
||||
// OIDCClientSecret is the client secret for OIDC.
|
||||
OIDCClientSecret string
|
||||
// OIDCRedirectURL is the redirect URL for OIDC.
|
||||
OIDCRedirectURL string
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
|
@ -113,6 +125,7 @@ func NewConfig() *Config {
|
|||
DisableRegistration: DefaultDisableRegistration,
|
||||
Verbose: DefaultVerbose,
|
||||
Quiet: DefaultQuiet,
|
||||
DisableCredentialsLogin: DefaultDisableCredentialsLogin,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -165,6 +178,27 @@ func Validate(cfg *Config) error {
|
|||
cfg.Verbose = VerboseLevelQuiet
|
||||
}
|
||||
|
||||
// Either OIDC or credential login has to be enabled.
|
||||
if cfg.DisableCredentialsLogin && cfg.OIDCIssuerURL == "" {
|
||||
return errs.Errorf("either OIDC or credential login has to be enabled", errs.ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if cfg.OIDCIssuerURL != "" {
|
||||
// If OIDC is enabled, all OIDC fields have to be set.
|
||||
if cfg.OIDCClientID == "" || cfg.OIDCClientSecret == "" || cfg.OIDCRedirectURL == "" {
|
||||
return errs.Errorf("all OIDC fields have to be set", errs.ErrInvalidConfig)
|
||||
}
|
||||
|
||||
// OIDC fields must be valid
|
||||
if _, err := url.ParseRequestURI(cfg.OIDCIssuerURL); err != nil {
|
||||
return errs.Errorf(fmt.Sprintf("invalid OIDC issuer URL: %s", err), errs.ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if _, err := url.ParseRequestURI(cfg.OIDCRedirectURL); err != nil {
|
||||
return errs.Errorf(fmt.Sprintf("invalid OIDC redirect URL: %s", err), errs.ErrInvalidConfig)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,17 +3,17 @@ package config_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/config"
|
||||
config "git.maronato.dev/maronato/goshort/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDBTypeList(t *testing.T) {
|
||||
dbTypeList := []string{
|
||||
DBTypeMemory,
|
||||
DBTypeSQLite,
|
||||
config.DBTypeMemory,
|
||||
config.DBTypeSQLite,
|
||||
}
|
||||
|
||||
results := GetDBTypeList()
|
||||
results := config.GetDBTypeList()
|
||||
|
||||
assert.ElementsMatch(t, dbTypeList, results, "DBTypeList should be equal to the list of available DB types")
|
||||
}
|
||||
|
@ -21,17 +21,17 @@ func TestGetDBTypeList(t *testing.T) {
|
|||
func TestValidate(t *testing.T) {
|
||||
type expecttest struct {
|
||||
Debug *bool
|
||||
Verbose *VerboseLevel
|
||||
Verbose *config.VerboseLevel
|
||||
}
|
||||
|
||||
debugTrue := true
|
||||
debugFalse := false
|
||||
verboseQuiet := VerboseLevelQuiet
|
||||
verboseInfo := VerboseLevelInfo
|
||||
verboseQuiet := config.VerboseLevelQuiet
|
||||
verboseInfo := config.VerboseLevelInfo
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
config *Config
|
||||
config *config.Config
|
||||
wantErr bool
|
||||
expected expecttest
|
||||
}
|
||||
|
@ -39,9 +39,9 @@ func TestValidate(t *testing.T) {
|
|||
tests := []test{
|
||||
{
|
||||
name: "Valid config",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Prod: false,
|
||||
DBType: DBTypeSQLite,
|
||||
DBType: config.DBTypeSQLite,
|
||||
DBURL: "goshort.db",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
|
@ -51,6 +51,11 @@ func TestValidate(t *testing.T) {
|
|||
DisableRegistration: false,
|
||||
Verbose: 0,
|
||||
Quiet: false,
|
||||
DisableCredentialsLogin: false,
|
||||
OIDCIssuerURL: "https://example.com",
|
||||
OIDCClientID: "client-id",
|
||||
OIDCClientSecret: "client-secret",
|
||||
OIDCRedirectURL: "https://example.com/redirect",
|
||||
},
|
||||
expected: expecttest{
|
||||
Debug: &debugFalse,
|
||||
|
@ -60,12 +65,12 @@ func TestValidate(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Minimum config",
|
||||
config: &Config{},
|
||||
config: &config.Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Host",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Host: "invalid host",
|
||||
Port: "8080",
|
||||
},
|
||||
|
@ -73,7 +78,7 @@ func TestValidate(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Invalid Port",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Host: "localhost",
|
||||
Port: "invalid port",
|
||||
},
|
||||
|
@ -81,7 +86,7 @@ func TestValidate(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Invalid UI Port",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
UIPort: "invalid port",
|
||||
|
@ -90,35 +95,35 @@ func TestValidate(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Valid DB Type",
|
||||
config: &Config{
|
||||
DBType: DBTypeMemory,
|
||||
config: &config.Config{
|
||||
DBType: config.DBTypeMemory,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid DB Type",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
DBType: "invalid db type",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Verbose Level (negative)",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Verbose: -1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Verbose Level (too high)",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Verbose: 3,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Quiet mode sets Verbose to -1",
|
||||
config: &Config{
|
||||
config: &config.Config{
|
||||
Quiet: true,
|
||||
},
|
||||
expected: expecttest{
|
||||
|
@ -127,18 +132,74 @@ func TestValidate(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Verbose high sets Debug to true",
|
||||
config: &Config{
|
||||
Verbose: VerboseLevelDebug,
|
||||
config: &config.Config{
|
||||
Verbose: config.VerboseLevelDebug,
|
||||
},
|
||||
expected: expecttest{
|
||||
Debug: &debugTrue,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Fails of credentials are disabled and oidc is not configured",
|
||||
config: &config.Config{
|
||||
DisableCredentialsLogin: true,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Succeeds of credentials are disabled and oidc is configured",
|
||||
config: &config.Config{
|
||||
DisableCredentialsLogin: true,
|
||||
OIDCIssuerURL: "https://example.com",
|
||||
OIDCClientID: "client-id",
|
||||
OIDCClientSecret: "client-secret",
|
||||
OIDCRedirectURL: "https://example.com/redirect",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Succeeds if both credentials and oidc are enabled",
|
||||
config: &config.Config{
|
||||
OIDCIssuerURL: "https://example.com",
|
||||
OIDCClientID: "client-id",
|
||||
OIDCClientSecret: "client-secret",
|
||||
OIDCRedirectURL: "https://example.com/redirect",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Fails if issuer is not a valid URL",
|
||||
config: &config.Config{
|
||||
OIDCIssuerURL: "invalid-url",
|
||||
OIDCClientID: "client-id",
|
||||
OIDCClientSecret: "client-secret",
|
||||
OIDCRedirectURL: "https://example.com/redirect",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Fails if redirect is not a valid URL",
|
||||
config: &config.Config{
|
||||
OIDCIssuerURL: "https://example.com",
|
||||
OIDCClientID: "client-id",
|
||||
OIDCClientSecret: "client-secret",
|
||||
OIDCRedirectURL: "invalid-url",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Fails if oidc is partially configured",
|
||||
config: &config.Config{
|
||||
OIDCIssuerURL: "https://example.com",
|
||||
OIDCClientID: "client-id",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := Validate(tc.config)
|
||||
err := config.Validate(tc.config)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "Validate should return an error")
|
||||
|
|
|
@ -48,6 +48,12 @@ var (
|
|||
ErrStorageNotStarted = errors.New("storage not started")
|
||||
// ErrStorageStarted.
|
||||
ErrStorageStarted = errors.New("storage already started")
|
||||
// ErrOIDCStateCookieMissing.
|
||||
ErrOIDCStateCookieMissing = errors.New("OIDC state cookie missing")
|
||||
// ErrOIDCStateCookieInvalid.
|
||||
ErrOIDCStateCookieInvalid = errors.New("OIDC state cookie invalid")
|
||||
// ErrCredentialsLoginDisabled.
|
||||
ErrCredentialsLoginDisabled = errors.New("credentials login disabled")
|
||||
)
|
||||
|
||||
func Errorf(msg string, err error) error {
|
||||
|
|
|
@ -76,6 +76,8 @@ func (h *APIHandler) DeleteMe(w http.ResponseWriter, r *http.Request) {
|
|||
err := authmiddleware.DeleteAllUserSessions(ctx, user)
|
||||
if err != nil {
|
||||
server.RenderServerError(w, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
span.AddEvent("deleted all user sessions")
|
||||
|
@ -128,6 +130,9 @@ func (h *APIHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|||
case errors.Is(err, errs.ErrInvalidUser):
|
||||
// If the request was invalid, return bad request
|
||||
server.RenderBadRequest(w, r, err)
|
||||
case errors.Is(err, errs.ErrCredentialsLoginDisabled):
|
||||
// If credentials login is disabled, return forbidden
|
||||
server.RenderForbidden(w, r)
|
||||
default:
|
||||
// Else, server error
|
||||
server.RenderServerError(w, r, err)
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
package oidcserver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/server"
|
||||
authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth"
|
||||
oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc"
|
||||
userservice "git.maronato.dev/maronato/goshort/internal/service/user"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
"git.maronato.dev/maronato/goshort/internal/util/logging"
|
||||
oidcutil "git.maronato.dev/maronato/goshort/internal/util/oidc"
|
||||
"git.maronato.dev/maronato/goshort/internal/util/tracing"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
)
|
||||
|
||||
type OIDCHandler struct {
|
||||
oidc *oidcservice.OIDCService
|
||||
user *userservice.UserService
|
||||
}
|
||||
|
||||
func NewOIDCHandler(oidc *oidcservice.OIDCService, user *userservice.UserService) *OIDCHandler {
|
||||
return &OIDCHandler{
|
||||
oidc: oidc,
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OIDCHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
_, span := tracing.StartSpan(r.Context(), "oidc.Redirect")
|
||||
defer span.End()
|
||||
// Generate redirect parameters
|
||||
params := h.oidc.GetRedirectParams()
|
||||
// Redirect to OIDC provider
|
||||
oidcutil.DoRedirect(w, r, params)
|
||||
}
|
||||
|
||||
func (h *OIDCHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := tracing.StartSpan(r.Context(), "oidc.Callback")
|
||||
defer span.End()
|
||||
|
||||
l := logging.FromCtx(ctx)
|
||||
|
||||
// Validate the state
|
||||
err := oidcutil.ValidateRequestState(r)
|
||||
if err != nil {
|
||||
server.RenderBadRequest(w, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
span.AddEvent("State validated")
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
userInfo, err := h.oidc.CallbackExchange(ctx, code)
|
||||
if err != nil {
|
||||
server.RenderBadRequest(w, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
span.AddEvent("user info retrieved")
|
||||
span.SetAttributes(attribute.KeyValue{
|
||||
Key: attribute.Key("oidc.user.email"),
|
||||
Value: attribute.StringValue(userInfo.Email),
|
||||
})
|
||||
|
||||
// Check if user exists
|
||||
user, err := h.user.FindUser(ctx, userInfo.Email)
|
||||
if err != nil {
|
||||
// If the user does not exist, create it
|
||||
if errors.Is(err, errs.ErrUserDoesNotExist) {
|
||||
// Create the new user and disable their password login.
|
||||
user = &models.User{
|
||||
Username: userInfo.Email,
|
||||
}
|
||||
user.SetNoLoginPassword()
|
||||
|
||||
user, err = h.user.CreateUser(ctx, user)
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errs.ErrInvalidUser):
|
||||
server.RenderBadRequest(w, r, err)
|
||||
case errors.Is(err, errs.ErrRegistrationDisabled):
|
||||
l.Debug("failed to create user", "err", err)
|
||||
|
||||
server.RenderForbidden(w, r)
|
||||
default:
|
||||
l.Error("failed to create user", "error", err)
|
||||
server.RenderServerError(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
span.AddEvent("user created")
|
||||
} else {
|
||||
// If the error was something else, render a server error
|
||||
l.Error("failed to find user", "error", err)
|
||||
|
||||
server.RenderServerError(w, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Now that user is guaranteed to exist, log them in
|
||||
span.AddEvent("user found or created")
|
||||
|
||||
authmiddleware.LoginUser(ctx, user, r)
|
||||
|
||||
span.AddEvent("logged in user")
|
||||
|
||||
// Redirect to the home page
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package oidcserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func NewOIDCRouter(h *OIDCHandler) http.Handler {
|
||||
mux := chi.NewRouter()
|
||||
|
||||
mux.Get("/redirect", h.Redirect)
|
||||
mux.Get("/callback", h.Callback)
|
||||
|
||||
return mux
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package oidcservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
|
||||
"git.maronato.dev/maronato/goshort/internal/util/tracing"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// StateLength is the length of the state.
|
||||
StateLength = 16
|
||||
)
|
||||
|
||||
// OIDCService is the service that handles OIDC authentication.
|
||||
type OIDCService struct {
|
||||
cfg *config.Config
|
||||
provider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
}
|
||||
|
||||
// NewOIDCService creates a new OIDC service.
|
||||
func NewOIDCService(ctx context.Context, cfg *config.Config) *OIDCService {
|
||||
if cfg.OIDCIssuerURL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, cfg.OIDCIssuerURL)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
ClientSecret: cfg.OIDCClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: cfg.OIDCRedirectURL,
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
return &OIDCService{
|
||||
cfg: cfg,
|
||||
provider: provider,
|
||||
oauth2Config: &oauthConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectParams contains the parameters to redirect the user to the OIDC provider.
|
||||
type RedirectParams struct {
|
||||
State string
|
||||
URL string
|
||||
}
|
||||
|
||||
// GetRedirectParams returns the parameters to redirect the user to the OIDC provider.
|
||||
func (s *OIDCService) GetRedirectParams() *RedirectParams {
|
||||
state := randomutil.GenerateSecureToken(StateLength)
|
||||
|
||||
url := s.oauth2Config.AuthCodeURL(state)
|
||||
|
||||
return &RedirectParams{
|
||||
State: state,
|
||||
URL: url,
|
||||
}
|
||||
}
|
||||
|
||||
// CallbackExchange exchanges the code for a token and returns the user info.
|
||||
func (s *OIDCService) CallbackExchange(ctx context.Context, code string) (*oidc.UserInfo, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, "oidcservice.CallbackExchange")
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(attribute.KeyValue{
|
||||
Key: attribute.Key("oidc.code"),
|
||||
Value: attribute.StringValue(code),
|
||||
})
|
||||
|
||||
oauth2Token, err := s.oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to exchange code for token", err)
|
||||
}
|
||||
|
||||
span.AddEvent("Code exchanged for token")
|
||||
|
||||
userInfo, err := s.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to get user info", err)
|
||||
}
|
||||
|
||||
span.AddEvent("User info retrieved")
|
||||
span.SetAttributes(attribute.KeyValue{
|
||||
Key: attribute.Key("oidc.user.email"),
|
||||
Value: attribute.StringValue(userInfo.Email),
|
||||
})
|
||||
|
||||
return userInfo, nil
|
||||
}
|
|
@ -24,6 +24,7 @@ const (
|
|||
type UserService struct {
|
||||
db storage.Storage
|
||||
disableRegistration bool
|
||||
disableCredentialsLogin bool
|
||||
hasher passwords.PasswordHasher
|
||||
}
|
||||
|
||||
|
@ -31,6 +32,7 @@ func NewUserService(cfg *config.Config, db storage.Storage) *UserService {
|
|||
return &UserService{
|
||||
db: db,
|
||||
disableRegistration: cfg.DisableRegistration,
|
||||
disableCredentialsLogin: cfg.DisableCredentialsLogin,
|
||||
hasher: bcryptpasswords.NewBcryptHasher(),
|
||||
}
|
||||
}
|
||||
|
@ -81,6 +83,9 @@ func (s *UserService) DeleteUser(ctx context.Context, user *models.User) error {
|
|||
}
|
||||
|
||||
func (s *UserService) AuthenticateUser(ctx context.Context, username, password string) (user *models.User, err error) {
|
||||
if s.disableCredentialsLogin {
|
||||
return nil, errs.ErrCredentialsLoginDisabled
|
||||
}
|
||||
// Get user from storage
|
||||
user, err = s.FindUser(ctx, username)
|
||||
if err != nil {
|
||||
|
@ -95,7 +100,7 @@ func (s *UserService) AuthenticateUser(ctx context.Context, username, password s
|
|||
}
|
||||
|
||||
// Try to authenticate
|
||||
if password == "" || user.GetPasswordHash() == "" {
|
||||
if password == "" || user.GetPasswordHash() == "" || !user.CanLogin() {
|
||||
return nil, errs.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,11 @@ import (
|
|||
"git.maronato.dev/maronato/goshort/internal/util/passwords"
|
||||
)
|
||||
|
||||
const (
|
||||
// noLoginPassword is the password used when the user is not allowed to login using a password.
|
||||
noLoginPassword = "nologin"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
// ID is the user's ID.
|
||||
ID string `json:"id"`
|
||||
|
@ -43,7 +48,16 @@ func (u *User) SetPassword(hasher passwords.PasswordHasher, pass string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetNoLoginPassword sets the user's password to a no-login password.
|
||||
func (u *User) SetNoLoginPassword() {
|
||||
u.password = noLoginPassword
|
||||
}
|
||||
|
||||
// GetPasswordHash Returns the password hash to be saved in a storage.
|
||||
func (u *User) GetPasswordHash() string {
|
||||
return u.password
|
||||
}
|
||||
|
||||
func (u *User) CanLogin() bool {
|
||||
return u.password != noLoginPassword && u.password != ""
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package oidcutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
oidcservice "git.maronato.dev/maronato/goshort/internal/service/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
stateCookieName = "gs_oidc_state"
|
||||
stateCookieExpiration = time.Minute * 5
|
||||
)
|
||||
|
||||
// SetStateCookie sets the OIDC code parameter as a cookie.
|
||||
func SetStateCookie(w http.ResponseWriter, r *http.Request, state string) {
|
||||
c := &http.Cookie{
|
||||
Name: stateCookieName,
|
||||
Value: state,
|
||||
MaxAge: int(stateCookieExpiration.Seconds()),
|
||||
Secure: r.TLS != nil,
|
||||
HttpOnly: true,
|
||||
}
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
func ValidateRequestState(r *http.Request) error {
|
||||
state, err := r.Cookie(stateCookieName)
|
||||
if err != nil {
|
||||
return errs.ErrOIDCStateCookieMissing
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("state") != state.Value {
|
||||
return errs.ErrOIDCStateCookieInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoRedirect(w http.ResponseWriter, r *http.Request, params *oidcservice.RedirectParams) {
|
||||
SetStateCookie(w, r, params.State)
|
||||
http.Redirect(w, r, params.URL, http.StatusFound)
|
||||
}
|
Loading…
Reference in New Issue
Block a user