From 970990123e9a065f60532bc2e6bc8be2b8cde253 Mon Sep 17 00:00:00 2001 From: Gustavo Maronato Date: Wed, 6 Sep 2023 12:38:27 -0300 Subject: [PATCH] add more tests and stuff --- cmd/dev/dev.go | 2 +- cmd/main.go | 2 + cmd/serve/serve.go | 2 +- internal/config/config.go | 14 + internal/server/api/handler_test.go | 270 +++++++++++++++++++ internal/server/middleware/auth/auth.go | 6 +- internal/server/middleware/logging.go | 26 +- internal/server/short/handler_test.go | 140 ++++++++++ internal/service/shortlog/shortlogservice.go | 29 +- internal/storage/bun/storage.go | 31 +-- internal/storage/bun/storage_test.go | 11 + internal/storage/testing/storagetesting.go | 169 ++++++++++++ 12 files changed, 663 insertions(+), 39 deletions(-) create mode 100644 internal/server/api/handler_test.go create mode 100644 internal/server/short/handler_test.go diff --git a/cmd/dev/dev.go b/cmd/dev/dev.go index 75c818d..ea09572 100644 --- a/cmd/dev/dev.go +++ b/cmd/dev/dev.go @@ -104,7 +104,7 @@ func serveAPI(ctx context.Context, cfg *config.Config) error { shortLogService := shortlogservice.NewShortLogService(storage) // Start short log worker - stopWorker := shortLogService.StartWorker(ctx) + stopWorker, _ := shortLogService.StartWorker(ctx) defer stopWorker() // Create handlers diff --git a/cmd/main.go b/cmd/main.go index e29dbd3..65eb4b4 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -21,6 +21,8 @@ func Run() { // Create the application-wide context, and // implement graceful shutdown. ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + trapSignalsCrossPlatform(cancel) // Create the root command and register subcommands. diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index a1ea11a..f5650dd 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -77,7 +77,7 @@ func exec(ctx context.Context, cfg *config.Config) error { shortLogService := shortlogservice.NewShortLogService(storage) // Start short log worker - stopWorker := shortLogService.StartWorker(ctx) + stopWorker, _ := shortLogService.StartWorker(ctx) defer stopWorker() // Create handlers diff --git a/internal/config/config.go b/internal/config/config.go index f0c5931..189e4d1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -82,6 +82,20 @@ type Config struct { DisableRegistration bool } +func NewConfig() *Config { + return &Config{ + Prod: DefaultProd, + Debug: DefaultDebug, + Host: DefaultHost, + Port: DefaultPort, + UIPort: DefaultUIPort, + DBType: DefaultDBType, + DBURL: DefaultDBURL, + SessionDuration: DefaultSessionDuration, + DisableRegistration: DefaultDisableRegistration, + } +} + func Validate(cfg *Config) error { // Host and port have to be valid. if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil { diff --git a/internal/server/api/handler_test.go b/internal/server/api/handler_test.go new file mode 100644 index 0000000..d866085 --- /dev/null +++ b/internal/server/api/handler_test.go @@ -0,0 +1,270 @@ +package apiserver_test + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.maronato.dev/maronato/goshort/internal/config" + apiserver "git.maronato.dev/maronato/goshort/internal/server/api" + servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware" + authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth" + shortservice "git.maronato.dev/maronato/goshort/internal/service/short" + shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog" + tokenservice "git.maronato.dev/maronato/goshort/internal/service/token" + userservice "git.maronato.dev/maronato/goshort/internal/service/user" + bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun" + "git.maronato.dev/maronato/goshort/internal/storage/models" + randomutil "git.maronato.dev/maronato/goshort/internal/util/random" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + _ "modernc.org/sqlite" // Import the SQLite driver. +) + +func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage { + filename := randomutil.GenerateSecureToken(16) + url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1" + + sqldb, err := sql.Open("sqlite", url) + if err != nil { + panic(err) + } + + db := bun.NewDB(sqldb, sqlitedialect.New()) + + str := bunstorage.NewBunStorage(cfg, db) + + err = str.Start(ctx) + if err != nil { + panic(err) + } + + return str +} + +func setup(ctx context.Context, cfg *config.Config) ( + *apiserver.APIHandler, + *shortservice.ShortService, + *userservice.UserService, + *tokenservice.TokenService, + *shortlogservice.ShortLogService, +) { + str := newStorage(ctx, cfg) + + shorts := shortservice.NewShortService(str) + users := userservice.NewUserService(cfg, str) + tokens := tokenservice.NewTokenService(str) + shortLogs := shortlogservice.NewShortLogService(str) + + api := apiserver.NewAPIHandler(shorts, users, tokens, shortLogs) + + return api, shorts, users, tokens, shortLogs +} + +func makeRequestServer(handler http.HandlerFunc) http.Handler { + lf := servermiddleware.NewLogFormatter() + lf.DisableStackPrinting() + + rlogger := middleware.RequestLogger(lf) + + // Add middlewares + handlerF := middleware.Recoverer(handler) + handlerF = rlogger(handlerF) + + return handlerF +} + +func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + panic(err) + } + + rr := httptest.NewRecorder() + + return rr, req +} + +func withRequestUser(ctx context.Context, users *userservice.UserService) (context.Context, *models.User) { + user := &models.User{ + Username: "user@example.com", + } + + err := users.SetPassword(ctx, user, "mynewpassword") + if err != nil { + panic(err) + } + + newUser, err := users.CreateUser(ctx, user) + if err != nil { + panic(err) + } + + return authmiddleware.WithUser(ctx, newUser), newUser +} + +func TestEndpointMe(t *testing.T) { + ctx := context.Background() + cfg := config.NewConfig() + + api, _, users, _, _ := setup(ctx, cfg) + ctxUser, user := withRequestUser(ctx, users) + server := makeRequestServer(api.Me) + + t.Run("WithUser", func(t *testing.T) { + res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil) + server.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + + expectedJSON, err := json.Marshal(user) + assert.Nil(t, err) + + assert.JSONEq(t, string(expectedJSON), res.Body.String()) + }) + + t.Run("WithoutUser", func(t *testing.T) { + res, req := makeRequestResponse(ctx, "GET", "/api/me", nil) + server.ServeHTTP(res, req) + + assert.Equal(t, http.StatusInternalServerError, res.Code) + }) +} + +func BenchmarkEndpointMe(b *testing.B) { + ctx := context.Background() + cfg := config.NewConfig() + + api, _, users, _, _ := setup(ctx, cfg) + ctxUser, _ := withRequestUser(ctx, users) + server := makeRequestServer(api.Me) + + for i := 0; i < b.N; i++ { + res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil) + server.ServeHTTP(res, req) + + assert.Equal(b, http.StatusOK, res.Code) + } +} + +func TestEndpointCreateShort(t *testing.T) { + ctx := context.Background() + cfg := config.NewConfig() + + api, shorts, users, _, _ := setup(ctx, cfg) + ctxUser, _ := withRequestUser(ctx, users) + server := makeRequestServer(api.CreateShort) + + tests := []struct { + name string + ctx context.Context + body io.Reader + responseCode int + shortName string + }{ + { + name: "Valid", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)), + responseCode: http.StatusCreated, + }, + { + name: "InvalidURL", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "invalidurl"}`)), + responseCode: http.StatusBadRequest, + }, + { + name: "InvalidJSON", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "https://example.com"`)), + responseCode: http.StatusBadRequest, + }, + { + name: "WithoutUser", + ctx: ctx, + body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)), + responseCode: http.StatusInternalServerError, + }, + { + name: "WithoutURL", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{}`)), + responseCode: http.StatusBadRequest, + }, + { + name: "EmptyURL", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": ""}`)), + responseCode: http.StatusBadRequest, + }, + { + name: "WithName", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)), + responseCode: http.StatusCreated, + shortName: "example", + }, + { + name: "WithInvalidName", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "invalid name"}`)), + responseCode: http.StatusBadRequest, + }, + { + name: "WithExistingName", + ctx: ctxUser, + body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)), + responseCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, req := makeRequestResponse(tt.ctx, "POST", "/api/shorts", tt.body) + server.ServeHTTP(res, req) + + assert.Equal(t, tt.responseCode, res.Code) + + if res.Code == http.StatusCreated { + var short models.Short + err := json.Unmarshal(res.Body.Bytes(), &short) + assert.Nil(t, err) + + _, err = shorts.FindShortByID(ctx, short.ID) + assert.Nil(t, err) + + assert.Equal(t, "https://example.com", short.URL) + + if tt.shortName != "" { + assert.Equal(t, tt.shortName, short.Name) + } + } + }) + } +} + +func BenchmarkEndpointCreateShort(b *testing.B) { + ctx := context.Background() + cfg := config.NewConfig() + + api, _, users, _, _ := setup(ctx, cfg) + ctxUser, _ := withRequestUser(ctx, users) + server := makeRequestServer(api.CreateShort) + + body := []byte(`{"url": "https://example.com"}`) + + for i := 0; i < b.N; i++ { + res, req := makeRequestResponse(ctxUser, "POST", "/api/shorts", bytes.NewReader(body)) + server.ServeHTTP(res, req) + + assert.Equal(b, http.StatusCreated, res.Code) + } +} diff --git a/internal/server/middleware/auth/auth.go b/internal/server/middleware/auth/auth.go index f90e06d..66f7a90 100644 --- a/internal/server/middleware/auth/auth.go +++ b/internal/server/middleware/auth/auth.go @@ -33,7 +33,7 @@ func Auth(userService *userservice.UserService, tokenService *tokenservice.Token } // Add user to context - ctx = context.WithValue(ctx, userCtxKey{}, user) + ctx = WithUser(ctx, user) // Register the auth method used ctx = context.WithValue(ctx, authMethodCtxKey{}, authMethod) @@ -71,3 +71,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) { return user, ok } + +func WithUser(ctx context.Context, user *models.User) context.Context { + return context.WithValue(ctx, userCtxKey{}, user) +} diff --git a/internal/server/middleware/logging.go b/internal/server/middleware/logging.go index 41c6c0d..6e6c413 100644 --- a/internal/server/middleware/logging.go +++ b/internal/server/middleware/logging.go @@ -9,15 +9,24 @@ import ( "github.com/go-chi/chi/v5/middleware" ) -type RequestLogFormatter struct{} +type RequestLogFormatter struct { + printStack bool +} func NewLogFormatter() *RequestLogFormatter { - return &RequestLogFormatter{} + return &RequestLogFormatter{ + printStack: true, + } +} + +func (rl *RequestLogFormatter) DisableStackPrinting() { + rl.printStack = false } // RequestLogEntry is a struct that implements the LogEntry interface. type RequestLogEntry struct { - l *slog.Logger + l *slog.Logger + printStack bool } func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry { //nolint:ireturn // This is the signature of the interface. @@ -47,7 +56,8 @@ func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry l = l.With(requestGroup) return &RequestLogEntry{ - l: l, + l: l, + printStack: rl.printStack, } } @@ -71,6 +81,10 @@ func (le *RequestLogEntry) Write(status, bytes int, header http.Header, elapsed } } -func (le *RequestLogEntry) Panic(v interface{}, _ []byte) { - middleware.PrintPrettyStack(v) +func (le *RequestLogEntry) Panic(v interface{}, stack []byte) { + if le.printStack { + middleware.PrintPrettyStack(v) + } else { + le.l.Error("Panic", slog.Any("stack", stack)) + } } diff --git a/internal/server/short/handler_test.go b/internal/server/short/handler_test.go new file mode 100644 index 0000000..cc964b4 --- /dev/null +++ b/internal/server/short/handler_test.go @@ -0,0 +1,140 @@ +package shortserver_test + +import ( + "context" + "database/sql" + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.maronato.dev/maronato/goshort/internal/config" + servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware" + shortserver "git.maronato.dev/maronato/goshort/internal/server/short" + shortservice "git.maronato.dev/maronato/goshort/internal/service/short" + shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog" + bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun" + "git.maronato.dev/maronato/goshort/internal/storage/models" + handlerutils "git.maronato.dev/maronato/goshort/internal/util/handler" + randomutil "git.maronato.dev/maronato/goshort/internal/util/random" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + _ "modernc.org/sqlite" // Import the SQLite driver. +) + +func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage { + filename := randomutil.GenerateSecureToken(16) + url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1" + + sqldb, err := sql.Open("sqlite", url) + if err != nil { + panic(err) + } + + db := bun.NewDB(sqldb, sqlitedialect.New()) + + str := bunstorage.NewBunStorage(cfg, db) + + err = str.Start(ctx) + if err != nil { + panic(err) + } + + return str +} + +func setup(ctx context.Context, cfg *config.Config) ( + *shortserver.ShortHandler, + *shortservice.ShortService, + *shortlogservice.ShortLogService, +) { + str := newStorage(ctx, cfg) + + shorts := shortservice.NewShortService(str) + shortLogs := shortlogservice.NewShortLogService(str) + + handler := shortserver.NewShortHandler(shorts, shortLogs) + + return handler, shorts, shortLogs +} + +func makeRequestServer(handler http.HandlerFunc) http.Handler { + mux := chi.NewRouter() + lf := servermiddleware.NewLogFormatter() + lf.DisableStackPrinting() + + // Add middlewares + mux.Use(middleware.RequestLogger(lf)) + mux.Use(middleware.Recoverer) + + // Handle all requests + + shortRouter := chi.NewRouter() + shortRouter.Get("/{short}", handler) + + chained := handlerutils.NewChainedHandler( + shortRouter, + http.NotFoundHandler(), + ) + mux.Mount("/", chained) + + return mux +} + +func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + panic(err) + } + + rr := httptest.NewRecorder() + + return rr, req +} + +func TestEndpointFindShort(t *testing.T) { + ctx := context.Background() + cfg := config.NewConfig() + + handler, shorts, shortLogs := setup(ctx, cfg) + + stopWorker, debugCh := shortLogs.StartWorker(ctx) + defer stopWorker() + + server := makeRequestServer(handler.FindShort) + + short, err := shorts.Shorten(ctx, &models.Short{ + URL: "https://example.com", + Name: "example", + }) + if err != nil { + t.Error(err) + } + + t.Run("Redirects", func(t *testing.T) { + res, req := makeRequestResponse(ctx, "GET", "/"+short.Name, nil) + server.ServeHTTP(res, req) + + assert.Equal(t, http.StatusSeeOther, res.Code) + + assert.Equal(t, short.URL, res.Header().Get("Location")) + + // Check if the access was logged + <-debugCh + logs, err := shortLogs.ListLogs(ctx, short) + assert.Nil(t, err) + + assert.Equal(t, 1, len(logs)) + assert.Equal(t, short.ID, logs[0].ShortID) + }) + + t.Run("NotFound", func(t *testing.T) { + res, req := makeRequestResponse(ctx, "GET", "/other", nil) + server.ServeHTTP(res, req) + + assert.Equal(t, http.StatusNotFound, res.Code) + }) +} diff --git a/internal/service/shortlog/shortlogservice.go b/internal/service/shortlog/shortlogservice.go index 537e862..634e52a 100644 --- a/internal/service/shortlog/shortlogservice.go +++ b/internal/service/shortlog/shortlogservice.go @@ -17,24 +17,30 @@ const logChLength = 100 type ShortLogService struct { db storage.Storage logCh chan *models.ShortLog + // debugCh exposes the log channel to the tests + debugCh chan *models.ShortLog } func NewShortLogService(db storage.Storage) *ShortLogService { logCh := make(chan *models.ShortLog, logChLength) + debugCh := make(chan *models.ShortLog) return &ShortLogService{ - db: db, - logCh: logCh, + db: db, + logCh: logCh, + debugCh: debugCh, } } -func (s *ShortLogService) StartWorker(ctx context.Context) context.CancelFunc { - workerCtx, cancel := context.WithCancel(ctx) +func (s *ShortLogService) StartWorker(ctx context.Context) (stopWorker context.CancelFunc, debugCh <-chan *models.ShortLog) { + workerCtx, stopWorker := context.WithCancel(ctx) // Start the goroutine go s.shortLogWorker(workerCtx) - return cancel + debugCh = s.debugCh + + return stopWorker, debugCh } func (s *ShortLogService) shortLogWorker(ctx context.Context) { @@ -52,6 +58,13 @@ func (s *ShortLogService) shortLogWorker(ctx context.Context) { l.Debug("writing short log to storage", slog.String("short_id", shortLog.ShortID)) err := s.db.CreateShortLog(ctx, shortLog) + + select { + case s.debugCh <- shortLog: + default: + l.Warn("short log debug queue is full, dropping log") + } + if err != nil { l.Warn("failed to log short access", "error", err) } @@ -74,7 +87,11 @@ func (s *ShortLogService) LogShortAccess(ctx context.Context, short *models.Shor l.Debug("adding short log to queue", slog.String("short_id", short.ID)) // Send the log to the channel - s.logCh <- shortLog + select { + case s.logCh <- shortLog: + default: + l.Warn("short log queue is full, dropping log") + } } func (s *ShortLogService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) { diff --git a/internal/storage/bun/storage.go b/internal/storage/bun/storage.go index fee3cb3..b776b4a 100644 --- a/internal/storage/bun/storage.go +++ b/internal/storage/bun/storage.go @@ -208,12 +208,9 @@ func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*mod if short == nil { return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort) } - // Generate an ID that does not exist - tableName := s.db.NewSelect(). - Model((*ShortModel)(nil)). - GetTableName() - newID, err := createNewID(ctx, s.db, tableName) + // Generate an ID that does not exist + newID, err := createNewID(ctx, s.db, (*ShortModel)(nil)) if err != nil { return nil, errs.Errorf("failed to create short", err) } @@ -294,11 +291,7 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error { // Create new ID - tableName := s.db.NewSelect(). - Model((*ShortLogModel)(nil)). - GetTableName() - - newID, err := createNewID(ctx, s.db, tableName) + newID, err := createNewID(ctx, s.db, (*ShortLogModel)(nil)) if err != nil { return errs.Errorf("failed to create short log", err) } @@ -365,11 +358,7 @@ func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) { // Create a new ID - tableName := s.db.NewSelect(). - Model((*UserModel)(nil)). - GetTableName() - - newID, err := createNewID(ctx, s.db, tableName) + newID, err := createNewID(ctx, s.db, (*UserModel)(nil)) if err != nil { return nil, errs.Errorf("failed to create user", err) } @@ -511,11 +500,7 @@ func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*mod } // Create a new ID - tableName := s.db.NewSelect(). - Model((*TokenModel)(nil)). - GetTableName() - - newID, err := createNewID(ctx, s.db, tableName) + newID, err := createNewID(ctx, s.db, (*TokenModel)(nil)) if err != nil { return nil, errs.Errorf("failed to create token", err) } @@ -611,12 +596,10 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery { Set("user_id = ?", nil) } -func createNewID(ctx context.Context, db bun.IDB, table string) (string, error) { +func createNewID(ctx context.Context, db bun.IDB, model interface{}) (string, error) { var newID string - // Generate an ID that does not exist maxIters := 10 - for { // Make sure we don't get stuck in an infinite loop maxIters-- @@ -628,7 +611,7 @@ func createNewID(ctx context.Context, db bun.IDB, table string) (string, error) newID = models.NewID() count, err := db.NewSelect(). - Table(table). + Model(model). Column("id"). Where("id = ?", newID). Count(ctx) diff --git a/internal/storage/bun/storage_test.go b/internal/storage/bun/storage_test.go index 86f83ab..c590695 100644 --- a/internal/storage/bun/storage_test.go +++ b/internal/storage/bun/storage_test.go @@ -54,6 +54,17 @@ func TestStorageInterface(t *testing.T) { storagetesting.ITestComplete(t, getStg) } +func BenchmarkStorageInterface(b *testing.B) { + getStg := func() storage.Storage { + db := newNewBunDBHelper() + bs := NewBunStorage(&config.Config{}, db) + + return bs + } + + storagetesting.IBenchmarkComplete(b, getStg) +} + func TestBunStorage_Start(t *testing.T) { ctx := context.Background() db := newNewBunDBHelper() diff --git a/internal/storage/testing/storagetesting.go b/internal/storage/testing/storagetesting.go index 33dc6c2..e7c74c2 100644 --- a/internal/storage/testing/storagetesting.go +++ b/internal/storage/testing/storagetesting.go @@ -4,6 +4,7 @@ package storagetesting import ( "context" "errors" + "fmt" "testing" "time" @@ -19,6 +20,7 @@ var ErrPlaceholder = errors.New("placeholder error") func ITestComplete(t *testing.T, getStg func() storage.Storage) { t.Helper() + t.Parallel() tests := []struct { name string @@ -129,14 +131,54 @@ func ITestComplete(t *testing.T, getStg func() storage.Storage) { } } +func IBenchmarkComplete(b *testing.B, getStg func() storage.Storage) { + b.Helper() + + tests := []struct { + name string + test func(b *testing.B, stg storage.Storage) + }{ + { + name: "CreateShort", + test: IBenchmarkCreateShort, + }, + { + name: "CreateShortLog", + test: IBenchmarkCreateShortLog, + }, + { + name: "CreateUser", + test: IBenchmarkCreateUser, + }, + { + name: "CreateToken", + test: IBenchmarkCreateToken, + }, + { + name: "FindShort", + test: IBenchmarkFindShort, + }, + } + + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + stg := getStg() + _ = stg.Start(context.Background()) + test.test(b, stg) + }) + } +} + func ITestImplements(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() assert.Implements(t, (*storage.Storage)(nil), stg, "Should implement the storage.Storage interface") } func ITestStart(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -151,6 +193,7 @@ func ITestStart(t *testing.T, stg storage.Storage) { func ITestStop(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -168,6 +211,7 @@ func ITestStop(t *testing.T, stg storage.Storage) { func ITestPing(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -182,6 +226,7 @@ func ITestPing(t *testing.T, stg storage.Storage) { func ITestCreateShort(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -303,6 +348,7 @@ func ITestCreateShort(t *testing.T, stg storage.Storage) { func ITestFindShort(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -332,6 +378,7 @@ func ITestFindShort(t *testing.T, stg storage.Storage) { func ITestFindShortByID(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -361,6 +408,7 @@ func ITestFindShortByID(t *testing.T, stg storage.Storage) { func ITestDeleteShort(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -400,6 +448,7 @@ func ITestDeleteShort(t *testing.T, stg storage.Storage) { func ITestListShorts(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -451,6 +500,7 @@ func ITestListShorts(t *testing.T, stg storage.Storage) { func ITestCreateShortLog(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -497,6 +547,7 @@ func ITestCreateShortLog(t *testing.T, stg storage.Storage) { func ITestListShortLogs(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -533,6 +584,7 @@ func ITestListShortLogs(t *testing.T, stg storage.Storage) { func ITestCreateUser(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -613,6 +665,7 @@ func ITestCreateUser(t *testing.T, stg storage.Storage) { func ITestFindUser(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -645,6 +698,7 @@ func ITestFindUser(t *testing.T, stg storage.Storage) { func ITestFindUserByID(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -677,6 +731,7 @@ func ITestFindUserByID(t *testing.T, stg storage.Storage) { func ITestDeleteUser(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -738,6 +793,7 @@ func ITestDeleteUser(t *testing.T, stg storage.Storage) { func ITestCreateToken(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -820,6 +876,7 @@ func ITestCreateToken(t *testing.T, stg storage.Storage) { func ITestFindToken(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -848,6 +905,7 @@ func ITestFindToken(t *testing.T, stg storage.Storage) { func ITestFindTokenByID(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -876,6 +934,7 @@ func ITestFindTokenByID(t *testing.T, stg storage.Storage) { func ITestDeleteToken(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -903,6 +962,7 @@ func ITestDeleteToken(t *testing.T, stg storage.Storage) { func ITestListTokens(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -947,6 +1007,7 @@ func ITestListTokens(t *testing.T, stg storage.Storage) { func ITestChangeTokenName(t *testing.T, stg storage.Storage) { t.Helper() + t.Parallel() ctx := context.Background() @@ -966,3 +1027,111 @@ func ITestChangeTokenName(t *testing.T, stg storage.Storage) { assert.Equal(t, token.UserID, newToken.UserID, "Should return the same UserID") assert.Equal(t, token.CreatedAt, newToken.CreatedAt, "Should return the same CreatedAt") } + +func IBenchmarkCreateShort(b *testing.B, stg storage.Storage) { + b.Helper() + + ctx := context.Background() + + userID := "user" + url := "https://example.com" + + for i := 0; i < b.N; i++ { + _, err := stg.CreateShort(ctx, &models.Short{ + Name: fmt.Sprintf("myshort%d", i), + URL: url, + UserID: &userID, + }) + if err != nil { + b.Fatal(err) + } + } +} + +func IBenchmarkCreateShortLog(b *testing.B, stg storage.Storage) { + b.Helper() + + ctx := context.Background() + + short, _ := stg.CreateShort(ctx, &models.Short{ + Name: "myshort", + URL: "https://example.com", + }) + ipAddress := "myip" + userAgent := "myuseragent" + referer := "myreferer" + + for i := 0; i < b.N; i++ { + _ = stg.CreateShortLog(ctx, &models.ShortLog{ + ShortID: short.ID, + IPAddress: ipAddress, + UserAgent: userAgent, + Referer: referer, + }) + } +} + +func IBenchmarkCreateUser(b *testing.B, stg storage.Storage) { + b.Helper() + + ctx := context.Background() + + for i := 0; i < b.N; i++ { + _, err := stg.CreateUser(ctx, &models.User{ + Username: fmt.Sprintf("myusername%d", i), + }) + if err != nil { + b.Fatal(err) + } + } +} + +func IBenchmarkCreateToken(b *testing.B, stg storage.Storage) { + b.Helper() + + ctx := context.Background() + + userID := "user" + + for i := 0; i < b.N; i++ { + _, err := stg.CreateToken(ctx, &models.Token{ + Value: fmt.Sprintf("myvalue%d", i), + UserID: &userID, + }) + if err != nil { + b.Fatal(err) + } + } +} + +func IBenchmarkFindShort(b *testing.B, stg storage.Storage) { + b.Helper() + + ctx := context.Background() + + userID := "user" + + short, err := stg.CreateShort(ctx, &models.Short{ + Name: "myshort", + URL: "https://example.com", + UserID: &userID, + }) + + b.Run("Existing", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err = stg.FindShort(ctx, short.Name) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("Non-existing", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err = stg.FindShort(ctx, "non-existing") + if err != nil && !errors.Is(err, errs.ErrShortDoesNotExist) { + b.Fatal(err) + } + } + }) +}