101 lines
2.6 KiB
Go
101 lines
2.6 KiB
Go
package oidcservice
|
|
|
|
import (
|
|
"context"
|
|
|
|
"git.maronato.dev/maronato/goshort/internal/config"
|
|
"git.maronato.dev/maronato/goshort/internal/errs"
|
|
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
|
|
"git.maronato.dev/maronato/goshort/internal/util/tracing"
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// StateLength is the length argument passed to randomutil.GenerateSecureToken
// when generating the OAuth2 "state" value for an authorization redirect.
const StateLength = 16
|
|
|
|
// OIDCService is the service that handles OIDC authentication.
|
|
type OIDCService struct {
|
|
cfg *config.Config
|
|
provider *oidc.Provider
|
|
oauth2Config *oauth2.Config
|
|
}
|
|
|
|
// NewOIDCService creates a new OIDC service from cfg.
//
// It returns nil when OIDC is not configured (cfg.OIDCIssuerURL is empty). It
// also returns nil when the provider cannot be created from the issuer URL.
// Callers that need to tell those two cases apart should use
// NewOIDCServiceWithError instead.
func NewOIDCService(ctx context.Context, cfg *config.Config) *OIDCService {
	svc, err := NewOIDCServiceWithError(ctx, cfg)
	if err != nil {
		// Preserve this constructor's existing contract: a provider
		// creation failure yields a nil service rather than an error.
		return nil
	}

	return svc
}

// NewOIDCServiceWithError is like NewOIDCService but reports why the service
// could not be created.
//
// It returns (nil, nil) when OIDC is disabled because cfg.OIDCIssuerURL is
// empty. It returns (nil, err) when oidc.NewProvider fails for the configured
// issuer. On success the returned service requests the "openid", "profile"
// and "email" scopes.
func NewOIDCServiceWithError(ctx context.Context, cfg *config.Config) (*OIDCService, error) {
	if cfg.OIDCIssuerURL == "" {
		// OIDC is optional: an empty issuer means the feature is turned off.
		return nil, nil
	}

	// oidc.NewProvider performs OIDC discovery against the issuer, so this can
	// fail for network or configuration reasons; surface that to the caller.
	provider, err := oidc.NewProvider(ctx, cfg.OIDCIssuerURL)
	if err != nil {
		return nil, errs.Errorf("failed to create OIDC provider", err)
	}

	return &OIDCService{
		cfg:      cfg,
		provider: provider,
		oauth2Config: &oauth2.Config{
			ClientID:     cfg.OIDCClientID,
			ClientSecret: cfg.OIDCClientSecret,
			Endpoint:     provider.Endpoint(),
			RedirectURL:  cfg.OIDCRedirectURL,
			Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
		},
	}, nil
}
|
|
|
|
// RedirectParams holds what a handler needs to send the user to the OIDC
// provider's authorization page.
type RedirectParams struct {
	// State is the random value embedded in URL. Presumably the caller
	// persists it and compares it on callback (anti-CSRF) — verify at call site.
	State string
	// URL is the provider authorization URL that carries State.
	URL string
}
|
|
|
|
// GetRedirectParams returns the parameters to redirect the user to the OIDC provider.
|
|
func (s *OIDCService) GetRedirectParams() *RedirectParams {
|
|
state := randomutil.GenerateSecureToken(StateLength)
|
|
|
|
url := s.oauth2Config.AuthCodeURL(state)
|
|
|
|
return &RedirectParams{
|
|
State: state,
|
|
URL: url,
|
|
}
|
|
}
|
|
|
|
// CallbackExchange handles the OIDC callback. It exchanges the authorization
// code for a token, then uses that token to fetch the user's info from the
// provider.
//
// It returns an error wrapping whichever step failed: the code exchange or
// the user-info request. On success the user's email is recorded on the
// tracing span.
func (s *OIDCService) CallbackExchange(ctx context.Context, code string) (*oidc.UserInfo, error) {
	ctx, span := tracing.StartSpan(ctx, "oidcservice.CallbackExchange")
	defer span.End()

	// The authorization code is a credential that can be redeemed for tokens,
	// so it is deliberately NOT recorded as a span attribute. Trace backends
	// are not secret stores, and leaking it there widens the attack surface.
	oauth2Token, err := s.oauth2Config.Exchange(ctx, code)
	if err != nil {
		return nil, errs.Errorf("failed to exchange code for token", err)
	}

	span.AddEvent("Code exchanged for token")

	userInfo, err := s.provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
	if err != nil {
		return nil, errs.Errorf("failed to get user info", err)
	}

	span.AddEvent("User info retrieved")
	span.SetAttributes(attribute.KeyValue{
		Key:   attribute.Key("oidc.user.email"),
		Value: attribute.StringValue(userInfo.Email),
	})

	return userInfo, nil
}
|