goshort/internal/server/middleware/auth/auth.go
Gustavo Maronato be39b22ace
Some checks failed
Check / checks (push) Failing after 3m48s
lint
2024-03-09 05:53:28 -05:00

94 lines
2.6 KiB
Go

package authmiddleware
import (
"context"
"net/http"
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
userservice "git.maronato.dev/maronato/goshort/internal/service/user"
"git.maronato.dev/maronato/goshort/internal/storage/models"
"git.maronato.dev/maronato/goshort/internal/util/tracing"
"go.opentelemetry.io/otel/attribute"
)
type (
	// userCtxKey is the private context key under which WithUser stores the
	// authenticated user. Because it is an unexported type, code outside this
	// package cannot construct the key, so the stored value cannot collide with
	// keys set by other packages.
	userCtxKey struct{}
)
// Auth returns middleware that tries to identify the caller of each request.
//
// It first attempts token authentication using tokenService and, if that
// fails, falls back to session authentication using userService. On success
// the user (see UserFromCtx) and the AuthMethod that identified them are
// stored in the request context before calling next. If both attempts fail,
// the request is passed to next without a user; AuthRequired can be used to
// reject such requests.
//
// The authentication lookups and the downstream handler all run inside an
// "authmiddleware.Auth" tracing span, which records the auth method used
// ("none" when unauthenticated) and, on success, the user ID.
func Auth(userService *userservice.UserService, tokenService *tokenservice.TokenService,
) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := tracing.StartSpan(r.Context(), "authmiddleware.Auth")
			defer span.End()

			// Bind the request to the span's context up front so the token and
			// session lookups, and the next handler on the unauthenticated path,
			// are traced as children of this span, not only on success.
			r = r.WithContext(ctx)

			// Default until one of the methods succeeds.
			span.SetAttributes(attribute.String("auth_method", "none"))

			// Try token auth first.
			authMethod := TokenAuth
			user, err := authenticateViaToken(r, tokenService)
			span.AddEvent("queried token auth")

			if err != nil {
				// Token auth failed; fall back to session auth.
				authMethod = SessionAuth
				user, err = authenticateUserViaSession(r, userService)
				span.AddEvent("queried session auth")

				if err != nil {
					// Neither method identified the caller: continue without a
					// user in the context.
					next.ServeHTTP(w, r)
					return
				}
			}

			// Record who was authenticated and how.
			ctx = WithUser(ctx, user)
			ctx = context.WithValue(ctx, authMethodCtxKey{}, authMethod)

			span.SetAttributes(attribute.String("auth_method", string(authMethod)))
			span.SetAttributes(attribute.String("user_id", user.ID))

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// AuthRequired returns middleware that lets a request through only when
// authMethodAllowed accepts the request context for the given authMethods and
// a user is present in that context (see UserFromCtx). Otherwise it replies
// with 401 Unauthorized and does not call next.
func AuthRequired(authMethods ...AuthMethod) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// Check the auth method first and only then the user, mirroring
			// two sequential guards.
			authorized := authMethodAllowed(ctx, authMethods)
			if authorized {
				_, authorized = UserFromCtx(ctx)
			}

			if !authorized {
				const status = http.StatusUnauthorized
				http.Error(w, http.StatusText(status), status)
				return
			}

			next.ServeHTTP(w, r)
		}

		return http.HandlerFunc(handle)
	}
}
// UserFromCtx returns the user stored in ctx by WithUser.
//
// The boolean is true only when a non-nil *models.User is present. A nil
// pointer stored through WithUser is reported as absent, so a caller that
// checks the boolean (as AuthRequired does) can safely dereference the
// returned user.
func UserFromCtx(ctx context.Context) (*models.User, bool) {
	user, ok := ctx.Value(userCtxKey{}).(*models.User)
	// A typed nil passes the type assertion; treat it as "no user".
	if !ok || user == nil {
		return nil, false
	}
	return user, true
}
// WithUser returns a child of ctx that carries user under this package's
// private key; retrieve it with UserFromCtx.
func WithUser(ctx context.Context, user *models.User) context.Context {
	var key userCtxKey
	derived := context.WithValue(ctx, key, user)
	return derived
}