94 lines
2.6 KiB
Go
94 lines
2.6 KiB
Go
package authmiddleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
|
|
userservice "git.maronato.dev/maronato/goshort/internal/service/user"
|
|
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
|
"git.maronato.dev/maronato/goshort/internal/util/tracing"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
)
|
|
|
|
// userCtxKey is the private context key under which WithUser stores the
// authenticated *models.User and from which UserFromCtx reads it. An
// unexported empty struct type cannot collide with keys from other packages.
type userCtxKey struct{}
|
|
|
|
func Auth(userService *userservice.UserService, tokenService *tokenservice.TokenService,
|
|
) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := tracing.StartSpan(r.Context(), "authmiddleware.Auth")
|
|
defer span.End()
|
|
|
|
// Set default span attributes
|
|
span.SetAttributes(attribute.String("auth_method", "none"))
|
|
|
|
// Authenticate user
|
|
authMethod := TokenAuth
|
|
user, err := authenticateViaToken(r, tokenService)
|
|
|
|
span.AddEvent("queried token auth")
|
|
|
|
if err != nil {
|
|
// Failed to authenticate via token. Try to authenticate via session.
|
|
authMethod = SessionAuth
|
|
user, err = authenticateUserViaSession(r, userService)
|
|
|
|
span.AddEvent("queried session auth")
|
|
|
|
if err != nil {
|
|
// Failed to authenticate via session. Call the next handler.
|
|
next.ServeHTTP(w, r)
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// Add user to context
|
|
ctx = WithUser(ctx, user)
|
|
// Register the auth method used
|
|
ctx = context.WithValue(ctx, authMethodCtxKey{}, authMethod)
|
|
|
|
// Set span attributes
|
|
span.SetAttributes(attribute.String("auth_method", string(authMethod)))
|
|
span.SetAttributes(attribute.String("user_id", user.ID))
|
|
|
|
// Call the next handler
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// AuthRequired is a middleware that requires authentication for the current request.
|
|
func AuthRequired(authMethods ...AuthMethod) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
if !authMethodAllowed(ctx, authMethods) {
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
if _, ok := UserFromCtx(ctx); !ok {
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
func UserFromCtx(ctx context.Context) (*models.User, bool) {
|
|
user, ok := ctx.Value(userCtxKey{}).(*models.User)
|
|
|
|
return user, ok
|
|
}
|
|
|
|
func WithUser(ctx context.Context, user *models.User) context.Context {
|
|
return context.WithValue(ctx, userCtxKey{}, user)
|
|
}
|