goshort/internal/storage/bun/storage.go
Gustavo Maronato 8ac3eef33a
Some checks failed
Build / build (push) Has been cancelled
implemented tracing
2023-09-21 16:19:17 -03:00

637 lines
15 KiB
Go

package bunstorage
import (
"context"
"database/sql"
"errors"
"strings"
"time"
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage/models"
"github.com/uptrace/bun"
"github.com/uptrace/bun/extra/bundebug"
"github.com/uptrace/bun/extra/bunotel"
)
// BunStorage implements goshort's storage on top of a *bun.DB.
// Construct it with NewBunStorage and call Start before issuing queries that
// depend on the schema, since Start is what creates the tables and indexes.
type BunStorage struct {
	db *bun.DB
	// StartHooks are run by Start, in registration order, after the tables and
	// indexes have been created.
	StartHooks []func(ctx context.Context, db *bun.DB) error
	// StopHook holds the hooks run by Stop, in registration order, before the
	// database is closed. (Singular name despite being a slice; it is exported,
	// so it is kept as-is for compatibility.)
	StopHook []func(ctx context.Context, db *bun.DB) error
	// started is true after a successful Start and false again after a
	// successful Stop.
	started bool
}
// NewBunStorage wraps db in a BunStorage that has not been started yet.
//
// When cfg.Verbose is at least config.VerboseLevelDebug, a verbose bundebug
// query hook is installed. An OpenTelemetry query hook tagged with the
// database name "goshort" is always installed.
func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
	debugEnabled := cfg.Verbose >= config.VerboseLevelDebug
	if debugEnabled {
		db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
	}

	// Enable tracing.
	tracingHook := bunotel.NewQueryHook(bunotel.WithDBName("goshort"))
	db.AddQueryHook(tracingHook)

	storage := &BunStorage{db: db}
	storage.started = false

	return storage
}
// RegisterStartHook appends hook to the list that Start runs, in registration
// order, once the tables and indexes exist.
func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) {
	existing := s.StartHooks
	s.StartHooks = append(existing, hook)
}
// RegisterStopHook appends hook to the list that Stop runs, in registration
// order, before the database is closed.
func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) {
	existing := s.StopHook
	s.StopHook = append(existing, hook)
}
// Start prepares the storage for use.
//
// It first creates (if missing) the users, shorts, tokens and short_logs
// tables plus their lookup indexes, all inside a single transaction via
// RunInTx. It then runs the registered start hooks in registration order
// against s.db. The storage is marked as started only when every step
// succeeds. Calling Start on a storage that is already started returns a
// wrapped errs.ErrStorageStarted.
func (s *BunStorage) Start(ctx context.Context) error {
	if s.started {
		return errs.Errorf("failed to start storage", errs.ErrStorageStarted)
	}

	if schemaErr := s.db.RunInTx(ctx, nil, createStorageSchema); schemaErr != nil {
		return errs.Errorf("failed to start storage", schemaErr)
	}

	// Run start hooks
	for _, startHook := range s.StartHooks {
		if hookErr := startHook(ctx, s.db); hookErr != nil {
			return errs.Errorf("failed to start storage", hookErr)
		}
	}

	s.started = true

	return nil
}

// createStorageSchema creates every table and index used by BunStorage, in the
// order listed (referenced tables before the tables that point at them), using
// tx. Each failure is wrapped with a message naming the table or index.
func createStorageSchema(ctx context.Context, tx bun.Tx) error {
	tables := []struct {
		model      interface{}
		foreignKey string
		errMsg     string
	}{
		// An empty foreignKey means WithForeignKeys() is applied instead of
		// an explicit FOREIGN KEY clause.
		{(*UserModel)(nil), "", "failed to create users table"},
		{(*ShortModel)(nil), `("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`, "failed to create shorts table"},
		{(*TokenModel)(nil), `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`, "failed to create tokens table"},
		{(*ShortLogModel)(nil), `("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`, "failed to create short logs table"},
	}

	for _, table := range tables {
		createTable := tx.NewCreateTable().
			IfNotExists().
			Model(table.model)

		switch table.foreignKey {
		case "":
			createTable = createTable.WithForeignKeys()
		default:
			createTable = createTable.ForeignKey(table.foreignKey)
		}

		if _, err := createTable.Exec(ctx); err != nil {
			return errs.Errorf(table.errMsg, err)
		}
	}

	indexes := []struct {
		model  interface{}
		name   string
		column string
		errMsg string
	}{
		{(*ShortModel)(nil), "idx_shorts_user_id", "user_id", "failed to create shorts user_id index"},
		{(*TokenModel)(nil), "idx_tokens_user_id", "user_id", "failed to create tokens user_id index"},
		{(*ShortLogModel)(nil), "idx_short_logs_short_id", "short_id", "failed to create short_logs short_id index"},
	}

	for _, index := range indexes {
		_, err := tx.NewCreateIndex().
			IfNotExists().
			Model(index.model).
			Index(index.name).
			Column(index.column).
			Exec(ctx)
		if err != nil {
			return errs.Errorf(index.errMsg, err)
		}
	}

	return nil
}
// Stop runs the registered stop hooks in registration order and then closes
// the underlying database, marking the storage as no longer started.
//
// Stop returns a wrapped errs.ErrStorageNotStarted if the storage is not
// started. If a hook or the close fails, that error is returned (wrapped) and
// the storage stays marked as started.
// NOTE(review): a failing hook returns before db.Close, leaving the DB open.
func (s *BunStorage) Stop(ctx context.Context) error {
	if !s.started {
		return errs.Errorf("failed to stop storage", errs.ErrStorageNotStarted)
	}

	// Run stop hooks
	for i := range s.StopHook {
		if hookErr := s.StopHook[i](ctx, s.db); hookErr != nil {
			return errs.Errorf("failed to stop storage", hookErr)
		}
	}

	closeErr := s.db.Close()
	if closeErr != nil {
		return errs.Errorf("failed to stop storage", closeErr)
	}

	s.started = false

	return nil
}
// Ping verifies that the database is reachable, using ctx for the round trip.
func (s *BunStorage) Ping(ctx context.Context) error {
	pingErr := s.db.PingContext(ctx)
	if pingErr == nil {
		return nil
	}

	return errs.Errorf("failed to ping database", pingErr)
}
// FindShort looks up the non-deleted short with the given name.
// It returns a wrapped errs.ErrShortDoesNotExist when no such short exists.
func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short, error) {
	var row ShortModel

	err := s.db.NewSelect().
		Model(&row).
		Where("name = ? and deleted = false", name).
		Scan(ctx)

	switch {
	case err == nil:
		return row.toShort(), nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrShortDoesNotExist
	}

	return nil, errs.Errorf("failed to find short", err)
}
// FindShortByID looks up the non-deleted short with the given ID.
// It returns a wrapped errs.ErrShortDoesNotExist when no such short exists.
func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
	var row ShortModel

	err := s.db.NewSelect().
		Model(&row).
		Where("id = ? and deleted = false", id).
		Scan(ctx)

	switch {
	case err == nil:
		return row.toShort(), nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrShortDoesNotExist
	}

	return nil, errs.Errorf("failed to find short", err)
}
// CreateShort inserts a new short built from short's Name, URL and UserID
// under a freshly generated unique ID, and returns the stored short.
//
// It returns a wrapped errs.ErrInvalidShort when short is nil, and a wrapped
// errs.ErrShortExists when the insert violates a unique constraint.
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
	if short == nil {
		return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort)
	}

	// Generate an ID that does not exist
	newID, err := createNewID(ctx, s.db, (*ShortModel)(nil))
	if err != nil {
		return nil, errs.Errorf("failed to create short", err)
	}

	shortModel := &ShortModel{
		ID:     newID,
		Name:   short.Name,
		URL:    short.URL,
		UserID: short.UserID,
	}

	_, err = s.db.NewInsert().
		Model(shortModel).
		Exec(ctx)
	if err != nil {
		// Match case-insensitively on "unique constraint" so the check works
		// for SQLite ("UNIQUE constraint failed: ...") as well as Postgres
		// ("duplicate key value violates unique constraint ...").
		if strings.Contains(strings.ToLower(err.Error()), "unique constraint") {
			err = errs.ErrShortExists
		}

		return nil, errs.Errorf("failed to create short", err)
	}

	return shortModel.toShort(), nil
}
// DeleteShort deletes all logs of short and soft-deletes the short itself in a
// single transaction.
//
// The short row is kept, but it is marked deleted, stamped with deleted_at and
// detached from its owner (user_id set to NULL).
//
// It returns:
//   - a wrapped errs.ErrInvalidShort when short is nil;
//   - a wrapped errs.ErrShortDoesNotExist when no non-deleted short with
//     short.ID exists. In that case the transaction is rolled back.
func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error {
	if short == nil {
		return errs.Errorf("failed to delete short", errs.ErrInvalidShort)
	}

	err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
		// Delete short logs
		_, err := tx.NewDelete().
			Model((*ShortLogModel)(nil)).
			Where("short_id = ?", short.ID).
			Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete short", err)
		}

		// Soft-delete the short.
		res, err := withShortDeleteUpdates(
			tx.NewUpdate().
				Model((*ShortModel)(nil)).
				Where("id = ? and deleted = false", short.ID),
		).Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete short", err)
		}

		// An UPDATE that matches nothing does not fail with sql.ErrNoRows, so
		// the affected row count is the only way to detect a missing short.
		affected, err := res.RowsAffected()
		if err != nil {
			return errs.Errorf("failed to delete short", err)
		}
		if affected == 0 {
			// Returning an error rolls back the log deletion above.
			return errs.Errorf("failed to delete short", errs.ErrShortDoesNotExist)
		}

		return nil
	})
	if err != nil {
		return errs.Errorf("failed to delete short", err)
	}

	return nil
}
// ListShorts returns every non-deleted short owned by user. On success the
// returned slice is non-nil (empty when the user owns no shorts).
func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
	rows := make([]*ShortModel, 0)

	err := s.db.NewSelect().
		Model(&rows).
		Where("user_id = ? and deleted = false", user.ID).
		Scan(ctx)
	if err != nil {
		return nil, errs.Errorf("failed to list shorts", err)
	}

	result := make([]*models.Short, 0, len(rows))
	for _, row := range rows {
		result = append(result, row.toShort())
	}

	return result, nil
}
// CreateShortLog stores a new access log entry for a short under a freshly
// generated unique ID. It copies ShortID, IPAddress, UserAgent, Referer and
// CreatedAt from shortLog.
//
// It returns a wrapped errs.ErrShortLogExists when the insert violates a
// unique constraint.
// NOTE(review): a nil shortLog panics. No dedicated "invalid short log"
// sentinel is visible here to report it with.
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
	// Create new ID
	newID, err := createNewID(ctx, s.db, (*ShortLogModel)(nil))
	if err != nil {
		return errs.Errorf("failed to create short log", err)
	}

	shortLogModel := &ShortLogModel{
		ID:        newID,
		ShortID:   shortLog.ShortID,
		IPAddress: shortLog.IPAddress,
		UserAgent: shortLog.UserAgent,
		Referer:   shortLog.Referer,
		CreatedAt: shortLog.CreatedAt,
	}

	_, err = s.db.NewInsert().
		Model(shortLogModel).
		Exec(ctx)
	if err != nil {
		// Match case-insensitively on "unique constraint" so the check works
		// for SQLite ("UNIQUE constraint failed: ...") as well as Postgres
		// ("duplicate key value violates unique constraint ...").
		if strings.Contains(strings.ToLower(err.Error()), "unique constraint") {
			err = errs.ErrShortLogExists
		}

		return errs.Errorf("failed to create short log", err)
	}

	return nil
}
// ListShortLogs returns every log entry recorded for short. On success the
// returned slice is non-nil (empty when there are no logs).
func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
	rows := make([]*ShortLogModel, 0)

	err := s.db.NewSelect().
		Model(&rows).
		Where("short_id = ?", short.ID).
		Scan(ctx)
	if err != nil {
		return nil, errs.Errorf("failed to list short logs", err)
	}

	logs := make([]*models.ShortLog, 0, len(rows))
	for _, row := range rows {
		logs = append(logs, row.toShortLog())
	}

	return logs, nil
}
// FindUser returns the user with the given username. Errors from the lookup,
// including a wrapped errs.ErrUserDoesNotExist, are returned as-is.
func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
	found, err := findUser(ctx, s.db, username)
	if err == nil {
		return found.toUser(), nil
	}

	return nil, err
}
// FindUserByID returns the user with the given ID. Errors from the lookup,
// including a wrapped errs.ErrUserDoesNotExist, are returned as-is.
func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
	found, err := findUserByID(ctx, s.db, id)
	if err == nil {
		return found.toUser(), nil
	}

	return nil, err
}
// CreateUser inserts a new user with user's Username and password hash under a
// freshly generated unique ID, and returns the stored user.
//
// It returns a wrapped errs.ErrInvalidUser when user is nil, and a wrapped
// errs.ErrUserExists when the insert violates a unique constraint.
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
	// Validate input before touching the database; generating an ID costs a
	// query.
	if user == nil {
		return nil, errs.Errorf("failed to create user", errs.ErrInvalidUser)
	}

	// Create a new ID
	newID, err := createNewID(ctx, s.db, (*UserModel)(nil))
	if err != nil {
		return nil, errs.Errorf("failed to create user", err)
	}

	userModel := &UserModel{
		ID:       newID,
		Username: user.Username,
		Password: user.GetPasswordHash(),
	}

	_, err = s.db.NewInsert().
		Model(userModel).
		Exec(ctx)
	if err != nil {
		// Match case-insensitively on "unique constraint" so the check works
		// for SQLite ("UNIQUE constraint failed: ...") as well as Postgres
		// ("duplicate key value violates unique constraint ...").
		if strings.Contains(strings.ToLower(err.Error()), "unique constraint") {
			err = errs.ErrUserExists
		}

		return nil, errs.Errorf("failed to create user", err)
	}

	return userModel.toUser(), nil
}
// DeleteUser removes user and what it owns in a single transaction, in this
// order:
//  1. delete the logs of every short owned by the user;
//  2. soft-delete the user's shorts (mark deleted and clear user_id);
//  3. delete the user's tokens;
//  4. delete the user row.
//
// It returns a wrapped errs.ErrInvalidUser when user is nil. Deleting a user
// ID that matches no row is not reported as an error.
func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
	if user == nil {
		return errs.Errorf("failed to delete user", errs.ErrInvalidUser)
	}

	// Delete user in transaction
	err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
		// Delete user short logs
		_, err := tx.NewDelete().
			Model((*ShortLogModel)(nil)).
			Where("short_id IN (?)", tx.NewSelect().
				Model((*ShortModel)(nil)).
				Column("id").
				Where("user_id = ?", user.ID)).
			Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete short logs from user's shorts", err)
		}

		// Delete user shorts
		_, err = withShortDeleteUpdates(
			tx.NewUpdate().
				Model((*ShortModel)(nil)).
				Where("user_id = ?", user.ID),
		).Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete user shorts", err)
		}

		// Delete user tokens
		_, err = tx.NewDelete().
			Model((*TokenModel)(nil)).
			Where("user_id = ?", user.ID).
			Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete user", err)
		}

		// Delete the user row. Use the storage-layer UserModel (the model
		// CreateUser inserts with) rather than the domain *models.User. That
		// way the statement targets the users table regardless of how bun
		// would map the domain type.
		_, err = tx.NewDelete().
			Model((*UserModel)(nil)).
			Where("id = ?", user.ID).
			Exec(ctx)
		if err != nil {
			return errs.Errorf("failed to delete user", err)
		}

		return nil
	})
	if err != nil {
		return errs.Errorf("failed to delete user", err)
	}

	return nil
}
// FindToken looks up the token whose value equals value.
// It returns a wrapped errs.ErrTokenDoesNotExist when there is none.
func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token, error) {
	var row TokenModel

	err := s.db.NewSelect().
		Model(&row).
		Where("value = ?", value).
		Scan(ctx)

	switch {
	case err == nil:
		return row.toToken(), nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrTokenDoesNotExist
	}

	return nil, errs.Errorf("failed to find token", err)
}
// FindTokenByID looks up the token with the given ID.
// It returns a wrapped errs.ErrTokenDoesNotExist when there is none.
func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
	var row TokenModel

	err := s.db.NewSelect().
		Model(&row).
		Where("id = ?", id).
		Scan(ctx)

	switch {
	case err == nil:
		return row.toToken(), nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrTokenDoesNotExist
	}

	return nil, errs.Errorf("failed to find token", err)
}
// ListTokens returns every token owned by user. On success the returned slice
// is non-nil (empty when the user has no tokens).
func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
	rows := make([]*TokenModel, 0)

	err := s.db.NewSelect().
		Model(&rows).
		Where("user_id = ?", user.ID).
		Scan(ctx)
	if err != nil {
		return nil, errs.Errorf("failed to list tokens", err)
	}

	result := make([]*models.Token, 0, len(rows))
	for _, row := range rows {
		result = append(result, row.toToken())
	}

	return result, nil
}
// CreateToken inserts a new token built from token's Name, Value and UserID
// under a freshly generated unique ID, and returns the stored token.
//
// It returns a wrapped errs.ErrInvalidToken when token is nil, and a wrapped
// errs.ErrTokenExists when the insert violates a unique constraint.
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
	if token == nil {
		return nil, errs.Errorf("failed to create token", errs.ErrInvalidToken)
	}

	// Create a new ID
	newID, err := createNewID(ctx, s.db, (*TokenModel)(nil))
	if err != nil {
		return nil, errs.Errorf("failed to create token", err)
	}

	tokenModel := &TokenModel{
		ID:     newID,
		Name:   token.Name,
		Value:  token.Value,
		UserID: token.UserID,
	}

	_, err = s.db.NewInsert().
		Model(tokenModel).
		Exec(ctx)
	if err != nil {
		// Match case-insensitively on "unique constraint" so the check works
		// for SQLite ("UNIQUE constraint failed: ...") as well as Postgres
		// ("duplicate key value violates unique constraint ...").
		if strings.Contains(strings.ToLower(err.Error()), "unique constraint") {
			err = errs.ErrTokenExists
		}

		return nil, errs.Errorf("failed to create token", err)
	}

	return tokenModel.toToken(), nil
}
// DeleteToken permanently deletes the token with token.ID. Deleting an ID that
// matches no row is not reported as an error.
//
// It returns a wrapped errs.ErrInvalidToken when token is nil.
func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error {
	if token == nil {
		return errs.Errorf("failed to delete token", errs.ErrInvalidToken)
	}

	// Target the storage-layer TokenModel (the model CreateToken inserts with)
	// rather than the domain *models.Token.
	_, err := s.db.NewDelete().
		Model((*TokenModel)(nil)).
		Where("id = ?", token.ID).
		Exec(ctx)
	if err != nil {
		return errs.Errorf("failed to delete token", err)
	}

	return nil
}
// ChangeTokenName sets the name of the token with token.ID to name and returns
// the updated row, as read back through RETURNING.
//
// It returns a wrapped errs.ErrInvalidToken when token is nil, and a wrapped
// errs.ErrTokenDoesNotExist when no token has that ID.
func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) {
	if token == nil {
		return nil, errs.Errorf("failed to change token name", errs.ErrInvalidToken)
	}

	newToken := new(TokenModel)

	res, err := s.db.NewUpdate().
		Model((*TokenModel)(nil)).
		Set("name = ?", name).
		Where("id = ?", token.ID).
		Returning("*").
		Exec(ctx, newToken)
	if err != nil {
		// Scanning RETURNING output into a single struct may surface
		// sql.ErrNoRows when nothing matched.
		if errors.Is(err, sql.ErrNoRows) {
			err = errs.ErrTokenDoesNotExist
		}

		return nil, errs.Errorf("failed to change token name", err)
	}

	// Without this check an unknown ID would yield a zero-value token.
	affected, err := res.RowsAffected()
	if err != nil {
		return nil, errs.Errorf("failed to change token name", err)
	}
	if affected == 0 {
		return nil, errs.Errorf("failed to change token name", errs.ErrTokenDoesNotExist)
	}

	return newToken.toToken(), nil
}
// findUser loads the UserModel whose username equals username, using db (which
// may be a *bun.DB or a transaction). It returns a wrapped
// errs.ErrUserDoesNotExist when no such user exists.
func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) {
	found := &UserModel{}

	err := db.NewSelect().
		Model(found).
		Where("username = ?", username).
		Scan(ctx)

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrUserDoesNotExist
	}

	return nil, errs.Errorf("failed to find user", err)
}
// findUserByID loads the UserModel with the given ID, using db (which may be a
// *bun.DB or a transaction). It returns a wrapped errs.ErrUserDoesNotExist when
// no such user exists.
func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) {
	found := &UserModel{}

	err := db.NewSelect().
		Model(found).
		Where("id = ?", id).
		Scan(ctx)

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		err = errs.ErrUserDoesNotExist
	}

	return nil, errs.Errorf("failed to find user", err)
}
// withShortDeleteUpdates adds the SET clauses that soft-delete shorts to q:
// deleted_at is set to the current time, deleted to true, and user_id to NULL.
// The caller supplies the table and WHERE clause.
func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
	deletedAt := time.Now()

	q = q.Set("deleted_at = ?", deletedAt)
	q = q.Set("deleted = ?", true)

	return q.Set("user_id = ?", nil)
}
// createNewID returns a new ID from models.NewID that is not yet used as the
// "id" of model's table, checked through db.
//
// It makes up to 10 attempts. If every candidate collides, it gives up with a
// wrapped errs.ErrDatabaseError. Query errors are returned wrapped.
func createNewID(ctx context.Context, db bun.IDB, model interface{}) (string, error) {
	// The original decrement-then-check loop allowed only 9 attempts; count
	// attempts explicitly so the bound means what it says.
	const maxAttempts = 10

	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Create a new ID and check if it exists
		newID := models.NewID()

		count, err := db.NewSelect().
			Model(model).
			Column("id").
			Where("id = ?", newID).
			Count(ctx)
		if err != nil {
			return "", errs.Errorf("failed to create unique ID", err)
		}

		if count == 0 {
			return newID, nil
		}
	}

	// Make sure we don't get stuck in an infinite loop
	return "", errs.Errorf("failed to create unique ID", errs.ErrDatabaseError)
}