add more tests and stuff

This commit is contained in:
Gustavo Maronato 2023-09-06 12:38:27 -03:00
parent fe90938f51
commit 970990123e
Signed by: maronato
SSH Key Fingerprint: SHA256:2Gw7kwMz/As+2UkR1qQ/qYYhn+WNh3FGv6ozhoRrLcs
12 changed files with 663 additions and 39 deletions

View File

@ -104,7 +104,7 @@ func serveAPI(ctx context.Context, cfg *config.Config) error {
shortLogService := shortlogservice.NewShortLogService(storage)
// Start short log worker
stopWorker := shortLogService.StartWorker(ctx)
stopWorker, _ := shortLogService.StartWorker(ctx)
defer stopWorker()
// Create handlers

View File

@ -21,6 +21,8 @@ func Run() {
// Create the application-wide context, and
// implement graceful shutdown.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
trapSignalsCrossPlatform(cancel)
// Create the root command and register subcommands.

View File

@ -77,7 +77,7 @@ func exec(ctx context.Context, cfg *config.Config) error {
shortLogService := shortlogservice.NewShortLogService(storage)
// Start short log worker
stopWorker := shortLogService.StartWorker(ctx)
stopWorker, _ := shortLogService.StartWorker(ctx)
defer stopWorker()
// Create handlers

View File

@ -82,6 +82,20 @@ type Config struct {
DisableRegistration bool
}
func NewConfig() *Config {
return &Config{
Prod: DefaultProd,
Debug: DefaultDebug,
Host: DefaultHost,
Port: DefaultPort,
UIPort: DefaultUIPort,
DBType: DefaultDBType,
DBURL: DefaultDBURL,
SessionDuration: DefaultSessionDuration,
DisableRegistration: DefaultDisableRegistration,
}
}
func Validate(cfg *Config) error {
// Host and port have to be valid.
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil {

View File

@ -0,0 +1,270 @@
package apiserver_test
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"git.maronato.dev/maronato/goshort/internal/config"
apiserver "git.maronato.dev/maronato/goshort/internal/server/api"
servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware"
authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth"
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
userservice "git.maronato.dev/maronato/goshort/internal/service/user"
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
"git.maronato.dev/maronato/goshort/internal/storage/models"
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
_ "modernc.org/sqlite" // Import the SQLite driver.
)
func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage {
filename := randomutil.GenerateSecureToken(16)
url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1"
sqldb, err := sql.Open("sqlite", url)
if err != nil {
panic(err)
}
db := bun.NewDB(sqldb, sqlitedialect.New())
str := bunstorage.NewBunStorage(cfg, db)
err = str.Start(ctx)
if err != nil {
panic(err)
}
return str
}
func setup(ctx context.Context, cfg *config.Config) (
*apiserver.APIHandler,
*shortservice.ShortService,
*userservice.UserService,
*tokenservice.TokenService,
*shortlogservice.ShortLogService,
) {
str := newStorage(ctx, cfg)
shorts := shortservice.NewShortService(str)
users := userservice.NewUserService(cfg, str)
tokens := tokenservice.NewTokenService(str)
shortLogs := shortlogservice.NewShortLogService(str)
api := apiserver.NewAPIHandler(shorts, users, tokens, shortLogs)
return api, shorts, users, tokens, shortLogs
}
func makeRequestServer(handler http.HandlerFunc) http.Handler {
lf := servermiddleware.NewLogFormatter()
lf.DisableStackPrinting()
rlogger := middleware.RequestLogger(lf)
// Add middlewares
handlerF := middleware.Recoverer(handler)
handlerF = rlogger(handlerF)
return handlerF
}
func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
panic(err)
}
rr := httptest.NewRecorder()
return rr, req
}
func withRequestUser(ctx context.Context, users *userservice.UserService) (context.Context, *models.User) {
user := &models.User{
Username: "user@example.com",
}
err := users.SetPassword(ctx, user, "mynewpassword")
if err != nil {
panic(err)
}
newUser, err := users.CreateUser(ctx, user)
if err != nil {
panic(err)
}
return authmiddleware.WithUser(ctx, newUser), newUser
}
func TestEndpointMe(t *testing.T) {
ctx := context.Background()
cfg := config.NewConfig()
api, _, users, _, _ := setup(ctx, cfg)
ctxUser, user := withRequestUser(ctx, users)
server := makeRequestServer(api.Me)
t.Run("WithUser", func(t *testing.T) {
res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil)
server.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
expectedJSON, err := json.Marshal(user)
assert.Nil(t, err)
assert.JSONEq(t, string(expectedJSON), res.Body.String())
})
t.Run("WithoutUser", func(t *testing.T) {
res, req := makeRequestResponse(ctx, "GET", "/api/me", nil)
server.ServeHTTP(res, req)
assert.Equal(t, http.StatusInternalServerError, res.Code)
})
}
func BenchmarkEndpointMe(b *testing.B) {
ctx := context.Background()
cfg := config.NewConfig()
api, _, users, _, _ := setup(ctx, cfg)
ctxUser, _ := withRequestUser(ctx, users)
server := makeRequestServer(api.Me)
for i := 0; i < b.N; i++ {
res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil)
server.ServeHTTP(res, req)
assert.Equal(b, http.StatusOK, res.Code)
}
}
func TestEndpointCreateShort(t *testing.T) {
ctx := context.Background()
cfg := config.NewConfig()
api, shorts, users, _, _ := setup(ctx, cfg)
ctxUser, _ := withRequestUser(ctx, users)
server := makeRequestServer(api.CreateShort)
tests := []struct {
name string
ctx context.Context
body io.Reader
responseCode int
shortName string
}{
{
name: "Valid",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)),
responseCode: http.StatusCreated,
},
{
name: "InvalidURL",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "invalidurl"}`)),
responseCode: http.StatusBadRequest,
},
{
name: "InvalidJSON",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "https://example.com"`)),
responseCode: http.StatusBadRequest,
},
{
name: "WithoutUser",
ctx: ctx,
body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)),
responseCode: http.StatusInternalServerError,
},
{
name: "WithoutURL",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{}`)),
responseCode: http.StatusBadRequest,
},
{
name: "EmptyURL",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": ""}`)),
responseCode: http.StatusBadRequest,
},
{
name: "WithName",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)),
responseCode: http.StatusCreated,
shortName: "example",
},
{
name: "WithInvalidName",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "invalid name"}`)),
responseCode: http.StatusBadRequest,
},
{
name: "WithExistingName",
ctx: ctxUser,
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)),
responseCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, req := makeRequestResponse(tt.ctx, "POST", "/api/shorts", tt.body)
server.ServeHTTP(res, req)
assert.Equal(t, tt.responseCode, res.Code)
if res.Code == http.StatusCreated {
var short models.Short
err := json.Unmarshal(res.Body.Bytes(), &short)
assert.Nil(t, err)
_, err = shorts.FindShortByID(ctx, short.ID)
assert.Nil(t, err)
assert.Equal(t, "https://example.com", short.URL)
if tt.shortName != "" {
assert.Equal(t, tt.shortName, short.Name)
}
}
})
}
}
func BenchmarkEndpointCreateShort(b *testing.B) {
ctx := context.Background()
cfg := config.NewConfig()
api, _, users, _, _ := setup(ctx, cfg)
ctxUser, _ := withRequestUser(ctx, users)
server := makeRequestServer(api.CreateShort)
body := []byte(`{"url": "https://example.com"}`)
for i := 0; i < b.N; i++ {
res, req := makeRequestResponse(ctxUser, "POST", "/api/shorts", bytes.NewReader(body))
server.ServeHTTP(res, req)
assert.Equal(b, http.StatusCreated, res.Code)
}
}

View File

@ -33,7 +33,7 @@ func Auth(userService *userservice.UserService, tokenService *tokenservice.Token
}
// Add user to context
ctx = context.WithValue(ctx, userCtxKey{}, user)
ctx = WithUser(ctx, user)
// Register the auth method used
ctx = context.WithValue(ctx, authMethodCtxKey{}, authMethod)
@ -71,3 +71,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) {
return user, ok
}
func WithUser(ctx context.Context, user *models.User) context.Context {
return context.WithValue(ctx, userCtxKey{}, user)
}

View File

@ -9,15 +9,24 @@ import (
"github.com/go-chi/chi/v5/middleware"
)
type RequestLogFormatter struct{}
type RequestLogFormatter struct {
printStack bool
}
func NewLogFormatter() *RequestLogFormatter {
return &RequestLogFormatter{}
return &RequestLogFormatter{
printStack: true,
}
}
func (rl *RequestLogFormatter) DisableStackPrinting() {
rl.printStack = false
}
// RequestLogEntry is a struct that implements the LogEntry interface.
type RequestLogEntry struct {
l *slog.Logger
l *slog.Logger
printStack bool
}
func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry { //nolint:ireturn // This is the signature of the interface.
@ -47,7 +56,8 @@ func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry
l = l.With(requestGroup)
return &RequestLogEntry{
l: l,
l: l,
printStack: rl.printStack,
}
}
@ -71,6 +81,10 @@ func (le *RequestLogEntry) Write(status, bytes int, header http.Header, elapsed
}
}
func (le *RequestLogEntry) Panic(v interface{}, _ []byte) {
middleware.PrintPrettyStack(v)
func (le *RequestLogEntry) Panic(v interface{}, stack []byte) {
if le.printStack {
middleware.PrintPrettyStack(v)
} else {
le.l.Error("Panic", slog.Any("stack", stack))
}
}

View File

@ -0,0 +1,140 @@
package shortserver_test
import (
"context"
"database/sql"
"io"
"net/http"
"net/http/httptest"
"testing"
"git.maronato.dev/maronato/goshort/internal/config"
servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware"
shortserver "git.maronato.dev/maronato/goshort/internal/server/short"
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
"git.maronato.dev/maronato/goshort/internal/storage/models"
handlerutils "git.maronato.dev/maronato/goshort/internal/util/handler"
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
_ "modernc.org/sqlite" // Import the SQLite driver.
)
func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage {
filename := randomutil.GenerateSecureToken(16)
url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1"
sqldb, err := sql.Open("sqlite", url)
if err != nil {
panic(err)
}
db := bun.NewDB(sqldb, sqlitedialect.New())
str := bunstorage.NewBunStorage(cfg, db)
err = str.Start(ctx)
if err != nil {
panic(err)
}
return str
}
func setup(ctx context.Context, cfg *config.Config) (
*shortserver.ShortHandler,
*shortservice.ShortService,
*shortlogservice.ShortLogService,
) {
str := newStorage(ctx, cfg)
shorts := shortservice.NewShortService(str)
shortLogs := shortlogservice.NewShortLogService(str)
handler := shortserver.NewShortHandler(shorts, shortLogs)
return handler, shorts, shortLogs
}
func makeRequestServer(handler http.HandlerFunc) http.Handler {
mux := chi.NewRouter()
lf := servermiddleware.NewLogFormatter()
lf.DisableStackPrinting()
// Add middlewares
mux.Use(middleware.RequestLogger(lf))
mux.Use(middleware.Recoverer)
// Handle all requests
shortRouter := chi.NewRouter()
shortRouter.Get("/{short}", handler)
chained := handlerutils.NewChainedHandler(
shortRouter,
http.NotFoundHandler(),
)
mux.Mount("/", chained)
return mux
}
func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
panic(err)
}
rr := httptest.NewRecorder()
return rr, req
}
func TestEndpointFindShort(t *testing.T) {
ctx := context.Background()
cfg := config.NewConfig()
handler, shorts, shortLogs := setup(ctx, cfg)
stopWorker, debugCh := shortLogs.StartWorker(ctx)
defer stopWorker()
server := makeRequestServer(handler.FindShort)
short, err := shorts.Shorten(ctx, &models.Short{
URL: "https://example.com",
Name: "example",
})
if err != nil {
t.Error(err)
}
t.Run("Redirects", func(t *testing.T) {
res, req := makeRequestResponse(ctx, "GET", "/"+short.Name, nil)
server.ServeHTTP(res, req)
assert.Equal(t, http.StatusSeeOther, res.Code)
assert.Equal(t, short.URL, res.Header().Get("Location"))
// Check if the access was logged
<-debugCh
logs, err := shortLogs.ListLogs(ctx, short)
assert.Nil(t, err)
assert.Equal(t, 1, len(logs))
assert.Equal(t, short.ID, logs[0].ShortID)
})
t.Run("NotFound", func(t *testing.T) {
res, req := makeRequestResponse(ctx, "GET", "/other", nil)
server.ServeHTTP(res, req)
assert.Equal(t, http.StatusNotFound, res.Code)
})
}

View File

@ -17,24 +17,30 @@ const logChLength = 100
type ShortLogService struct {
db storage.Storage
logCh chan *models.ShortLog
// debugCh exposes the log channel to the tests
debugCh chan *models.ShortLog
}
func NewShortLogService(db storage.Storage) *ShortLogService {
logCh := make(chan *models.ShortLog, logChLength)
debugCh := make(chan *models.ShortLog)
return &ShortLogService{
db: db,
logCh: logCh,
db: db,
logCh: logCh,
debugCh: debugCh,
}
}
func (s *ShortLogService) StartWorker(ctx context.Context) context.CancelFunc {
workerCtx, cancel := context.WithCancel(ctx)
func (s *ShortLogService) StartWorker(ctx context.Context) (stopWorker context.CancelFunc, debugCh <-chan *models.ShortLog) {
workerCtx, stopWorker := context.WithCancel(ctx)
// Start the goroutine
go s.shortLogWorker(workerCtx)
return cancel
debugCh = s.debugCh
return stopWorker, debugCh
}
func (s *ShortLogService) shortLogWorker(ctx context.Context) {
@ -52,6 +58,13 @@ func (s *ShortLogService) shortLogWorker(ctx context.Context) {
l.Debug("writing short log to storage", slog.String("short_id", shortLog.ShortID))
err := s.db.CreateShortLog(ctx, shortLog)
select {
case s.debugCh <- shortLog:
default:
l.Warn("short log debug queue is full, dropping log")
}
if err != nil {
l.Warn("failed to log short access", "error", err)
}
@ -74,7 +87,11 @@ func (s *ShortLogService) LogShortAccess(ctx context.Context, short *models.Shor
l.Debug("adding short log to queue", slog.String("short_id", short.ID))
// Send the log to the channel
s.logCh <- shortLog
select {
case s.logCh <- shortLog:
default:
l.Warn("short log queue is full, dropping log")
}
}
func (s *ShortLogService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {

View File

@ -208,12 +208,9 @@ func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*mod
if short == nil {
return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort)
}
// Generate an ID that does not exist
tableName := s.db.NewSelect().
Model((*ShortModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName)
// Generate an ID that does not exist
newID, err := createNewID(ctx, s.db, (*ShortModel)(nil))
if err != nil {
return nil, errs.Errorf("failed to create short", err)
}
@ -294,11 +291,7 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
// Create new ID
tableName := s.db.NewSelect().
Model((*ShortLogModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName)
newID, err := createNewID(ctx, s.db, (*ShortLogModel)(nil))
if err != nil {
return errs.Errorf("failed to create short log", err)
}
@ -365,11 +358,7 @@ func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User,
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
// Create a new ID
tableName := s.db.NewSelect().
Model((*UserModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName)
newID, err := createNewID(ctx, s.db, (*UserModel)(nil))
if err != nil {
return nil, errs.Errorf("failed to create user", err)
}
@ -511,11 +500,7 @@ func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*mod
}
// Create a new ID
tableName := s.db.NewSelect().
Model((*TokenModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName)
newID, err := createNewID(ctx, s.db, (*TokenModel)(nil))
if err != nil {
return nil, errs.Errorf("failed to create token", err)
}
@ -611,12 +596,10 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
Set("user_id = ?", nil)
}
func createNewID(ctx context.Context, db bun.IDB, table string) (string, error) {
func createNewID(ctx context.Context, db bun.IDB, model interface{}) (string, error) {
var newID string
// Generate an ID that does not exist
maxIters := 10
for {
// Make sure we don't get stuck in an infinite loop
maxIters--
@ -628,7 +611,7 @@ func createNewID(ctx context.Context, db bun.IDB, table string) (string, error)
newID = models.NewID()
count, err := db.NewSelect().
Table(table).
Model(model).
Column("id").
Where("id = ?", newID).
Count(ctx)

View File

@ -54,6 +54,17 @@ func TestStorageInterface(t *testing.T) {
storagetesting.ITestComplete(t, getStg)
}
func BenchmarkStorageInterface(b *testing.B) {
getStg := func() storage.Storage {
db := newNewBunDBHelper()
bs := NewBunStorage(&config.Config{}, db)
return bs
}
storagetesting.IBenchmarkComplete(b, getStg)
}
func TestBunStorage_Start(t *testing.T) {
ctx := context.Background()
db := newNewBunDBHelper()

View File

@ -4,6 +4,7 @@ package storagetesting
import (
"context"
"errors"
"fmt"
"testing"
"time"
@ -19,6 +20,7 @@ var ErrPlaceholder = errors.New("placeholder error")
func ITestComplete(t *testing.T, getStg func() storage.Storage) {
t.Helper()
t.Parallel()
tests := []struct {
name string
@ -129,14 +131,54 @@ func ITestComplete(t *testing.T, getStg func() storage.Storage) {
}
}
func IBenchmarkComplete(b *testing.B, getStg func() storage.Storage) {
b.Helper()
tests := []struct {
name string
test func(b *testing.B, stg storage.Storage)
}{
{
name: "CreateShort",
test: IBenchmarkCreateShort,
},
{
name: "CreateShortLog",
test: IBenchmarkCreateShortLog,
},
{
name: "CreateUser",
test: IBenchmarkCreateUser,
},
{
name: "CreateToken",
test: IBenchmarkCreateToken,
},
{
name: "FindShort",
test: IBenchmarkFindShort,
},
}
for _, test := range tests {
b.Run(test.name, func(b *testing.B) {
stg := getStg()
_ = stg.Start(context.Background())
test.test(b, stg)
})
}
}
func ITestImplements(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
assert.Implements(t, (*storage.Storage)(nil), stg, "Should implement the storage.Storage interface")
}
func ITestStart(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -151,6 +193,7 @@ func ITestStart(t *testing.T, stg storage.Storage) {
func ITestStop(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -168,6 +211,7 @@ func ITestStop(t *testing.T, stg storage.Storage) {
func ITestPing(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -182,6 +226,7 @@ func ITestPing(t *testing.T, stg storage.Storage) {
func ITestCreateShort(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -303,6 +348,7 @@ func ITestCreateShort(t *testing.T, stg storage.Storage) {
func ITestFindShort(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -332,6 +378,7 @@ func ITestFindShort(t *testing.T, stg storage.Storage) {
func ITestFindShortByID(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -361,6 +408,7 @@ func ITestFindShortByID(t *testing.T, stg storage.Storage) {
func ITestDeleteShort(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -400,6 +448,7 @@ func ITestDeleteShort(t *testing.T, stg storage.Storage) {
func ITestListShorts(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -451,6 +500,7 @@ func ITestListShorts(t *testing.T, stg storage.Storage) {
func ITestCreateShortLog(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -497,6 +547,7 @@ func ITestCreateShortLog(t *testing.T, stg storage.Storage) {
func ITestListShortLogs(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -533,6 +584,7 @@ func ITestListShortLogs(t *testing.T, stg storage.Storage) {
func ITestCreateUser(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -613,6 +665,7 @@ func ITestCreateUser(t *testing.T, stg storage.Storage) {
func ITestFindUser(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -645,6 +698,7 @@ func ITestFindUser(t *testing.T, stg storage.Storage) {
func ITestFindUserByID(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -677,6 +731,7 @@ func ITestFindUserByID(t *testing.T, stg storage.Storage) {
func ITestDeleteUser(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -738,6 +793,7 @@ func ITestDeleteUser(t *testing.T, stg storage.Storage) {
func ITestCreateToken(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -820,6 +876,7 @@ func ITestCreateToken(t *testing.T, stg storage.Storage) {
func ITestFindToken(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -848,6 +905,7 @@ func ITestFindToken(t *testing.T, stg storage.Storage) {
func ITestFindTokenByID(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -876,6 +934,7 @@ func ITestFindTokenByID(t *testing.T, stg storage.Storage) {
func ITestDeleteToken(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -903,6 +962,7 @@ func ITestDeleteToken(t *testing.T, stg storage.Storage) {
func ITestListTokens(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -947,6 +1007,7 @@ func ITestListTokens(t *testing.T, stg storage.Storage) {
func ITestChangeTokenName(t *testing.T, stg storage.Storage) {
t.Helper()
t.Parallel()
ctx := context.Background()
@ -966,3 +1027,111 @@ func ITestChangeTokenName(t *testing.T, stg storage.Storage) {
assert.Equal(t, token.UserID, newToken.UserID, "Should return the same UserID")
assert.Equal(t, token.CreatedAt, newToken.CreatedAt, "Should return the same CreatedAt")
}
func IBenchmarkCreateShort(b *testing.B, stg storage.Storage) {
b.Helper()
ctx := context.Background()
userID := "user"
url := "https://example.com"
for i := 0; i < b.N; i++ {
_, err := stg.CreateShort(ctx, &models.Short{
Name: fmt.Sprintf("myshort%d", i),
URL: url,
UserID: &userID,
})
if err != nil {
b.Fatal(err)
}
}
}
func IBenchmarkCreateShortLog(b *testing.B, stg storage.Storage) {
b.Helper()
ctx := context.Background()
short, _ := stg.CreateShort(ctx, &models.Short{
Name: "myshort",
URL: "https://example.com",
})
ipAddress := "myip"
userAgent := "myuseragent"
referer := "myreferer"
for i := 0; i < b.N; i++ {
_ = stg.CreateShortLog(ctx, &models.ShortLog{
ShortID: short.ID,
IPAddress: ipAddress,
UserAgent: userAgent,
Referer: referer,
})
}
}
func IBenchmarkCreateUser(b *testing.B, stg storage.Storage) {
b.Helper()
ctx := context.Background()
for i := 0; i < b.N; i++ {
_, err := stg.CreateUser(ctx, &models.User{
Username: fmt.Sprintf("myusername%d", i),
})
if err != nil {
b.Fatal(err)
}
}
}
func IBenchmarkCreateToken(b *testing.B, stg storage.Storage) {
b.Helper()
ctx := context.Background()
userID := "user"
for i := 0; i < b.N; i++ {
_, err := stg.CreateToken(ctx, &models.Token{
Value: fmt.Sprintf("myvalue%d", i),
UserID: &userID,
})
if err != nil {
b.Fatal(err)
}
}
}
func IBenchmarkFindShort(b *testing.B, stg storage.Storage) {
b.Helper()
ctx := context.Background()
userID := "user"
short, err := stg.CreateShort(ctx, &models.Short{
Name: "myshort",
URL: "https://example.com",
UserID: &userID,
})
b.Run("Existing", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err = stg.FindShort(ctx, short.Name)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("Non-existing", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err = stg.FindShort(ctx, "non-existing")
if err != nil && !errors.Is(err, errs.ErrShortDoesNotExist) {
b.Fatal(err)
}
}
})
}