637 lines
15 KiB
Go
637 lines
15 KiB
Go
package bunstorage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.maronato.dev/maronato/goshort/internal/config"
|
|
"git.maronato.dev/maronato/goshort/internal/errs"
|
|
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
|
"github.com/uptrace/bun"
|
|
"github.com/uptrace/bun/extra/bundebug"
|
|
"github.com/uptrace/bun/extra/bunotel"
|
|
)
|
|
|
|
// BunStorage is a storage backend built on a *bun.DB. Start creates its
// tables and indexes and runs the start hooks; Stop runs the stop hooks and
// closes the database.
type BunStorage struct {
	// db is the bun database handle every query goes through.
	db *bun.DB

	// StartHooks are run, in order, by Start after the tables and indexes
	// have been created.
	StartHooks []func(ctx context.Context, db *bun.DB) error

	// StopHook holds the hooks run, in order, by Stop before the database is
	// closed. NOTE(review): the name is singular unlike StartHooks; renaming
	// the exported field would break callers that set it directly.
	StopHook []func(ctx context.Context, db *bun.DB) error

	// started is true if the storage has been started.
	started bool
}
|
|
|
|
func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
|
|
if cfg.Verbose >= config.VerboseLevelDebug {
|
|
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
|
|
}
|
|
|
|
// Enable tracing.
|
|
db.AddQueryHook(bunotel.NewQueryHook(
|
|
bunotel.WithDBName("goshort")),
|
|
)
|
|
|
|
return &BunStorage{
|
|
db: db,
|
|
started: false,
|
|
}
|
|
}
|
|
|
|
// RegisterStartHook registers a hook to be ran after the database is connected and the tables are created.
|
|
func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) {
|
|
s.StartHooks = append(s.StartHooks, hook)
|
|
}
|
|
|
|
// RegisterStopHook registers a hook to be ran before the database is closed.
|
|
func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) {
|
|
s.StopHook = append(s.StopHook, hook)
|
|
}
|
|
|
|
func (s *BunStorage) Start(ctx context.Context) error { //nolint:cyclop // This function is not that complex.
|
|
if s.started {
|
|
return errs.Errorf("failed to start storage", errs.ErrStorageStarted)
|
|
}
|
|
|
|
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
|
_, err := tx.NewCreateTable().
|
|
IfNotExists().
|
|
Model((*UserModel)(nil)).
|
|
WithForeignKeys().
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create users table", err)
|
|
}
|
|
|
|
_, err = tx.NewCreateTable().
|
|
IfNotExists().
|
|
Model((*ShortModel)(nil)).
|
|
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create shorts table", err)
|
|
}
|
|
|
|
_, err = tx.NewCreateTable().
|
|
IfNotExists().
|
|
Model((*TokenModel)(nil)).
|
|
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create tokens table", err)
|
|
}
|
|
|
|
_, err = tx.NewCreateTable().
|
|
IfNotExists().
|
|
Model((*ShortLogModel)(nil)).
|
|
ForeignKey(`("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create short logs table", err)
|
|
}
|
|
|
|
// shorts user_id index
|
|
_, err = tx.NewCreateIndex().
|
|
IfNotExists().
|
|
Model((*ShortModel)(nil)).
|
|
Index("idx_shorts_user_id").
|
|
Column("user_id").
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create shorts user_id index", err)
|
|
}
|
|
|
|
// tokens user_id index
|
|
_, err = tx.NewCreateIndex().
|
|
IfNotExists().
|
|
Model((*TokenModel)(nil)).
|
|
Index("idx_tokens_user_id").
|
|
Column("user_id").
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create tokens user_id index", err)
|
|
}
|
|
|
|
// short_logs short_id index
|
|
_, err = tx.NewCreateIndex().
|
|
IfNotExists().
|
|
Model((*ShortLogModel)(nil)).
|
|
Index("idx_short_logs_short_id").
|
|
Column("short_id").
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to create short_logs short_id index", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return errs.Errorf("failed to start storage", err)
|
|
}
|
|
|
|
// Run start hooks
|
|
for _, hook := range s.StartHooks {
|
|
err := hook(ctx, s.db)
|
|
if err != nil {
|
|
return errs.Errorf("failed to start storage", err)
|
|
}
|
|
}
|
|
|
|
s.started = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) Stop(ctx context.Context) error {
|
|
if !s.started {
|
|
return errs.Errorf("failed to stop storage", errs.ErrStorageNotStarted)
|
|
}
|
|
// Run stop hooks
|
|
for _, hook := range s.StopHook {
|
|
err := hook(ctx, s.db)
|
|
if err != nil {
|
|
return errs.Errorf("failed to stop storage", err)
|
|
}
|
|
}
|
|
|
|
if err := s.db.Close(); err != nil {
|
|
return errs.Errorf("failed to stop storage", err)
|
|
}
|
|
|
|
s.started = false
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) Ping(ctx context.Context) error {
|
|
if err := s.db.PingContext(ctx); err != nil {
|
|
return errs.Errorf("failed to ping database", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short, error) {
|
|
shortModel := new(ShortModel)
|
|
|
|
err := s.db.NewSelect().
|
|
Model(shortModel).
|
|
Where("name = ? and deleted = false", name).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrShortDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find short", err)
|
|
}
|
|
|
|
return shortModel.toShort(), nil
|
|
}
|
|
|
|
func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
|
|
shortModel := new(ShortModel)
|
|
|
|
err := s.db.NewSelect().
|
|
Model(shortModel).
|
|
Where("id = ? and deleted = false", id).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrShortDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find short", err)
|
|
}
|
|
|
|
return shortModel.toShort(), nil
|
|
}
|
|
|
|
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
|
|
if short == nil {
|
|
return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort)
|
|
}
|
|
|
|
// Generate an ID that does not exist
|
|
newID, err := createNewID(ctx, s.db, (*ShortModel)(nil))
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to create short", err)
|
|
}
|
|
|
|
shortModel := &ShortModel{
|
|
ID: newID,
|
|
Name: short.Name,
|
|
URL: short.URL,
|
|
UserID: short.UserID,
|
|
}
|
|
|
|
_, err = s.db.NewInsert().
|
|
Model(shortModel).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
|
err = errs.ErrShortExists
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to create short", err)
|
|
}
|
|
|
|
return shortModel.toShort(), nil
|
|
}
|
|
|
|
func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error {
|
|
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
|
// Delete short logs
|
|
_, err := tx.NewDelete().
|
|
Model((*ShortLogModel)(nil)).
|
|
Where("short_id = ?", short.ID).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete short", err)
|
|
}
|
|
|
|
// Delete short
|
|
_, err = withShortDeleteUpdates(
|
|
tx.NewUpdate().
|
|
Model((*ShortModel)(nil)).
|
|
Where("id = ? and deleted = false", short.ID),
|
|
).Exec(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrShortDoesNotExist
|
|
}
|
|
|
|
return errs.Errorf("failed to delete short", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete short", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
|
|
shortModels := []*ShortModel{}
|
|
|
|
err := s.db.NewSelect().
|
|
Model(&shortModels).
|
|
Where("user_id = ? and deleted = false", user.ID).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to list shorts", err)
|
|
}
|
|
|
|
shorts := make([]*models.Short, len(shortModels))
|
|
for i, shortModel := range shortModels {
|
|
shorts[i] = shortModel.toShort()
|
|
}
|
|
|
|
return shorts, nil
|
|
}
|
|
|
|
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
|
|
// Create new ID
|
|
newID, err := createNewID(ctx, s.db, (*ShortLogModel)(nil))
|
|
if err != nil {
|
|
return errs.Errorf("failed to create short log", err)
|
|
}
|
|
|
|
shortLogModel := &ShortLogModel{
|
|
ID: newID,
|
|
ShortID: shortLog.ShortID,
|
|
IPAddress: shortLog.IPAddress,
|
|
UserAgent: shortLog.UserAgent,
|
|
Referer: shortLog.Referer,
|
|
CreatedAt: shortLog.CreatedAt,
|
|
}
|
|
|
|
_, err = s.db.NewInsert().
|
|
Model(shortLogModel).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
|
err = errs.ErrShortLogExists
|
|
}
|
|
|
|
return errs.Errorf("failed to create short log", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
|
|
shortLogModels := []*ShortLogModel{}
|
|
|
|
err := s.db.NewSelect().
|
|
Model(&shortLogModels).
|
|
Where("short_id = ?", short.ID).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to list short logs", err)
|
|
}
|
|
|
|
shortLogs := make([]*models.ShortLog, len(shortLogModels))
|
|
for i, shortLogModel := range shortLogModels {
|
|
shortLogs[i] = shortLogModel.toShortLog()
|
|
}
|
|
|
|
return shortLogs, nil
|
|
}
|
|
|
|
func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
|
|
user, err := findUser(ctx, s.db, username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user.toUser(), nil
|
|
}
|
|
|
|
func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
|
|
user, err := findUserByID(ctx, s.db, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user.toUser(), nil
|
|
}
|
|
|
|
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
|
// Create a new ID
|
|
newID, err := createNewID(ctx, s.db, (*UserModel)(nil))
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to create user", err)
|
|
}
|
|
|
|
if user == nil {
|
|
return nil, errs.Errorf("failed to create user", errs.ErrInvalidUser)
|
|
}
|
|
|
|
userModel := &UserModel{
|
|
ID: newID,
|
|
Username: user.Username,
|
|
Password: user.GetPasswordHash(),
|
|
}
|
|
|
|
_, err = s.db.NewInsert().
|
|
Model(userModel).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
|
err = errs.ErrUserExists
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to create user", err)
|
|
}
|
|
|
|
return userModel.toUser(), nil
|
|
}
|
|
|
|
func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
|
|
// Delete user in transaction
|
|
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
|
// Delete user short logs
|
|
_, err := tx.NewDelete().
|
|
Model((*ShortLogModel)(nil)).
|
|
Where("short_id IN (?)", tx.NewSelect().
|
|
Model((*ShortModel)(nil)).
|
|
Column("id").
|
|
Where("user_id = ?", user.ID)).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete short logs from user's shorts", err)
|
|
}
|
|
|
|
// Delete user shorts
|
|
_, err = withShortDeleteUpdates(
|
|
tx.NewUpdate().
|
|
Model((*ShortModel)(nil)).
|
|
Where("user_id = ?", user.ID),
|
|
).Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete user shorts", err)
|
|
}
|
|
|
|
// Delete user tokens
|
|
_, err = tx.NewDelete().
|
|
Model((*TokenModel)(nil)).
|
|
Where("user_id = ?", user.ID).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete user", err)
|
|
}
|
|
|
|
// Delete user
|
|
_, err = tx.NewDelete().
|
|
Model(user).
|
|
Where("id = ?", user.ID).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete user", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete user", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token, error) {
|
|
tokenModel := new(TokenModel)
|
|
|
|
err := s.db.NewSelect().
|
|
Model(tokenModel).
|
|
Where("value = ?", value).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrTokenDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find token", err)
|
|
}
|
|
|
|
return tokenModel.toToken(), nil
|
|
}
|
|
|
|
func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
|
|
tokenModel := new(TokenModel)
|
|
|
|
err := s.db.NewSelect().
|
|
Model(tokenModel).
|
|
Where("id = ?", id).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrTokenDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find token", err)
|
|
}
|
|
|
|
return tokenModel.toToken(), nil
|
|
}
|
|
|
|
func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
|
|
tokenModels := []*TokenModel{}
|
|
|
|
err := s.db.NewSelect().
|
|
Model(&tokenModels).
|
|
Where("user_id = ?", user.ID).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to list tokens", err)
|
|
}
|
|
|
|
tokens := make([]*models.Token, len(tokenModels))
|
|
for i, tokenModel := range tokenModels {
|
|
tokens[i] = tokenModel.toToken()
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
|
|
if token == nil {
|
|
return nil, errs.Errorf("failed to create token", errs.ErrInvalidToken)
|
|
}
|
|
|
|
// Create a new ID
|
|
newID, err := createNewID(ctx, s.db, (*TokenModel)(nil))
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to create token", err)
|
|
}
|
|
|
|
tokenModel := &TokenModel{
|
|
ID: newID,
|
|
Name: token.Name,
|
|
Value: token.Value,
|
|
UserID: token.UserID,
|
|
}
|
|
|
|
_, err = s.db.NewInsert().
|
|
Model(tokenModel).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
|
err = errs.ErrTokenExists
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to create token", err)
|
|
}
|
|
|
|
return tokenModel.toToken(), nil
|
|
}
|
|
|
|
func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error {
|
|
_, err := s.db.NewDelete().
|
|
Model(token).
|
|
Where("id = ?", token.ID).
|
|
Exec(ctx)
|
|
if err != nil {
|
|
return errs.Errorf("failed to delete token", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) {
|
|
newToken := new(TokenModel)
|
|
|
|
_, err := s.db.NewUpdate().
|
|
Model((*TokenModel)(nil)).
|
|
Set("name = ?", name).
|
|
Where("id = ?", token.ID).
|
|
Returning("*").
|
|
Exec(ctx, newToken)
|
|
if err != nil {
|
|
return nil, errs.Errorf("failed to change token name", err)
|
|
}
|
|
|
|
return newToken.toToken(), nil
|
|
}
|
|
|
|
func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) {
|
|
userModel := new(UserModel)
|
|
|
|
err := db.NewSelect().
|
|
Model(userModel).
|
|
Where("username = ?", username).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrUserDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find user", err)
|
|
}
|
|
|
|
return userModel, nil
|
|
}
|
|
|
|
func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) {
|
|
userModel := new(UserModel)
|
|
|
|
err := db.NewSelect().
|
|
Model(userModel).
|
|
Where("id = ?", id).
|
|
Scan(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
err = errs.ErrUserDoesNotExist
|
|
}
|
|
|
|
return nil, errs.Errorf("failed to find user", err)
|
|
}
|
|
|
|
return userModel, nil
|
|
}
|
|
|
|
func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
|
|
return q.Set("deleted_at = ?", time.Now()).
|
|
Set("deleted = ?", true).
|
|
Set("user_id = ?", nil)
|
|
}
|
|
|
|
func createNewID(ctx context.Context, db bun.IDB, model interface{}) (string, error) {
|
|
var newID string
|
|
|
|
maxIters := 10
|
|
|
|
for {
|
|
// Make sure we don't get stuck in an infinite loop
|
|
maxIters--
|
|
if maxIters <= 0 {
|
|
return "", errs.Errorf("failed to create unique ID", errs.ErrDatabaseError)
|
|
}
|
|
|
|
// Create a new ID and check if it exists
|
|
newID = models.NewID()
|
|
|
|
count, err := db.NewSelect().
|
|
Model(model).
|
|
Column("id").
|
|
Where("id = ?", newID).
|
|
Count(ctx)
|
|
if err != nil {
|
|
return "", errs.Errorf("failed to create unique ID", err)
|
|
}
|
|
|
|
// If the ID does not exist, break the loop
|
|
if count == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return newID, nil
|
|
}
|