149 lines
3.9 KiB
Go
149 lines
3.9 KiB
Go
package userservice
|
|
|
|
import (
	"context"
	"fmt"
	"net/mail"
	"strings"
	"unicode/utf8"

	"git.maronato.dev/maronato/goshort/internal/config"
	"git.maronato.dev/maronato/goshort/internal/errs"
	"git.maronato.dev/maronato/goshort/internal/storage"
	"git.maronato.dev/maronato/goshort/internal/storage/models"
	"git.maronato.dev/maronato/goshort/internal/util/passwords"
	bcryptpasswords "git.maronato.dev/maronato/goshort/internal/util/passwords/bcrypt"
)
|
|
|
|
// Length bounds that SetPassword enforces on new passwords.
const (
	// MinPasswordLength is the shortest password length SetPassword accepts.
	MinPasswordLength = 8
	// MaxPasswordLength is the longest password length SetPassword accepts.
	MaxPasswordLength = 128
)
|
|
|
|
type UserService struct {
|
|
db storage.Storage
|
|
disableRegistration bool
|
|
hasher passwords.PasswordHasher
|
|
}
|
|
|
|
func NewUserService(cfg *config.Config, db storage.Storage) *UserService {
|
|
return &UserService{
|
|
db: db,
|
|
disableRegistration: cfg.DisableRegistration,
|
|
hasher: bcryptpasswords.NewBcryptHasher(),
|
|
}
|
|
}
|
|
|
|
func (s *UserService) FindUser(ctx context.Context, username string) (*models.User, error) {
|
|
// Check if the username is valid
|
|
err := UsernameIsValid(username)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not validate username: %w", err)
|
|
}
|
|
|
|
// Get the user from storage
|
|
user, err := s.db.FindUser(ctx, username)
|
|
if err != nil {
|
|
return user, fmt.Errorf("could not get user from storage: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
|
// Check for disabled registration
|
|
if s.disableRegistration {
|
|
return nil, errs.ErrRegistrationDisabled
|
|
}
|
|
|
|
// Check if the user is valid
|
|
err := UserIsValid(user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not validate user: %w", err)
|
|
}
|
|
|
|
newUser, err := s.db.CreateUser(ctx, user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not create user in storage: %w", err)
|
|
}
|
|
|
|
return newUser, nil
|
|
}
|
|
|
|
func (s *UserService) DeleteUser(ctx context.Context, user *models.User) error {
|
|
err := s.db.DeleteUser(ctx, user)
|
|
if err != nil {
|
|
return fmt.Errorf("could not delete user from storage: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) AuthenticateUser(ctx context.Context, username, password string) (user *models.User, err error) {
|
|
// Get user from storage
|
|
user, err = s.FindUser(ctx, username)
|
|
if err != nil {
|
|
// Waste time if the user is not found
|
|
// to mitigate timing attacks
|
|
wasteErr := s.hasher.WasteTime()
|
|
if wasteErr != nil {
|
|
return nil, fmt.Errorf("failed to authenticate: %w", wasteErr)
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
// Try to authenticate
|
|
if password == "" || user.GetPasswordHash() == "" {
|
|
return nil, errs.ErrFailedAuthentication
|
|
}
|
|
|
|
match, err := s.hasher.Verify(password, user.GetPasswordHash())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to authenticate user: %w", err)
|
|
} else if !match {
|
|
return nil, errs.ErrFailedAuthentication
|
|
}
|
|
|
|
// Success
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) SetPassword(_ context.Context, user *models.User, newPassword string) error {
|
|
// Check if the password is valid
|
|
if len(newPassword) < MinPasswordLength || len(newPassword) > MaxPasswordLength {
|
|
return fmt.Errorf("password must be between %d and %d characters long: %w", MinPasswordLength, MaxPasswordLength, errs.ErrInvalidUser)
|
|
}
|
|
// Set the new password
|
|
err := user.SetPassword(s.hasher, newPassword)
|
|
if err != nil {
|
|
return fmt.Errorf("could not set new password: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func UsernameIsValid(username string) error {
|
|
if !strings.Contains(username, "<") {
|
|
if _, err := mail.ParseAddress(username); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errs.Errorf("username must be a valid email address", errs.ErrInvalidUser)
|
|
}
|
|
|
|
func UserIsValid(user *models.User) error {
|
|
err := UsernameIsValid(user.Username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if user.GetPasswordHash() == "" {
|
|
return fmt.Errorf("missing password hash: %w", errs.ErrInvalidUser)
|
|
}
|
|
|
|
return nil
|
|
}
|