add more tests and stuff
This commit is contained in:
parent
fe90938f51
commit
970990123e
|
@ -104,7 +104,7 @@ func serveAPI(ctx context.Context, cfg *config.Config) error {
|
|||
shortLogService := shortlogservice.NewShortLogService(storage)
|
||||
|
||||
// Start short log worker
|
||||
stopWorker := shortLogService.StartWorker(ctx)
|
||||
stopWorker, _ := shortLogService.StartWorker(ctx)
|
||||
defer stopWorker()
|
||||
|
||||
// Create handlers
|
||||
|
|
|
@ -21,6 +21,8 @@ func Run() {
|
|||
// Create the application-wide context, and
|
||||
// implement graceful shutdown.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
trapSignalsCrossPlatform(cancel)
|
||||
|
||||
// Create the root command and register subcommands.
|
||||
|
|
|
@ -77,7 +77,7 @@ func exec(ctx context.Context, cfg *config.Config) error {
|
|||
shortLogService := shortlogservice.NewShortLogService(storage)
|
||||
|
||||
// Start short log worker
|
||||
stopWorker := shortLogService.StartWorker(ctx)
|
||||
stopWorker, _ := shortLogService.StartWorker(ctx)
|
||||
defer stopWorker()
|
||||
|
||||
// Create handlers
|
||||
|
|
|
@ -82,6 +82,20 @@ type Config struct {
|
|||
DisableRegistration bool
|
||||
}
|
||||
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
Prod: DefaultProd,
|
||||
Debug: DefaultDebug,
|
||||
Host: DefaultHost,
|
||||
Port: DefaultPort,
|
||||
UIPort: DefaultUIPort,
|
||||
DBType: DefaultDBType,
|
||||
DBURL: DefaultDBURL,
|
||||
SessionDuration: DefaultSessionDuration,
|
||||
DisableRegistration: DefaultDisableRegistration,
|
||||
}
|
||||
}
|
||||
|
||||
func Validate(cfg *Config) error {
|
||||
// Host and port have to be valid.
|
||||
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil {
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
package apiserver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
apiserver "git.maronato.dev/maronato/goshort/internal/server/api"
|
||||
servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware"
|
||||
authmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/auth"
|
||||
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
|
||||
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
|
||||
tokenservice "git.maronato.dev/maronato/goshort/internal/service/token"
|
||||
userservice "git.maronato.dev/maronato/goshort/internal/service/user"
|
||||
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
_ "modernc.org/sqlite" // Import the SQLite driver.
|
||||
)
|
||||
|
||||
func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage {
|
||||
filename := randomutil.GenerateSecureToken(16)
|
||||
url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1"
|
||||
|
||||
sqldb, err := sql.Open("sqlite", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
str := bunstorage.NewBunStorage(cfg, db)
|
||||
|
||||
err = str.Start(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
func setup(ctx context.Context, cfg *config.Config) (
|
||||
*apiserver.APIHandler,
|
||||
*shortservice.ShortService,
|
||||
*userservice.UserService,
|
||||
*tokenservice.TokenService,
|
||||
*shortlogservice.ShortLogService,
|
||||
) {
|
||||
str := newStorage(ctx, cfg)
|
||||
|
||||
shorts := shortservice.NewShortService(str)
|
||||
users := userservice.NewUserService(cfg, str)
|
||||
tokens := tokenservice.NewTokenService(str)
|
||||
shortLogs := shortlogservice.NewShortLogService(str)
|
||||
|
||||
api := apiserver.NewAPIHandler(shorts, users, tokens, shortLogs)
|
||||
|
||||
return api, shorts, users, tokens, shortLogs
|
||||
}
|
||||
|
||||
func makeRequestServer(handler http.HandlerFunc) http.Handler {
|
||||
lf := servermiddleware.NewLogFormatter()
|
||||
lf.DisableStackPrinting()
|
||||
|
||||
rlogger := middleware.RequestLogger(lf)
|
||||
|
||||
// Add middlewares
|
||||
handlerF := middleware.Recoverer(handler)
|
||||
handlerF = rlogger(handlerF)
|
||||
|
||||
return handlerF
|
||||
}
|
||||
|
||||
func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
return rr, req
|
||||
}
|
||||
|
||||
func withRequestUser(ctx context.Context, users *userservice.UserService) (context.Context, *models.User) {
|
||||
user := &models.User{
|
||||
Username: "user@example.com",
|
||||
}
|
||||
|
||||
err := users.SetPassword(ctx, user, "mynewpassword")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
newUser, err := users.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return authmiddleware.WithUser(ctx, newUser), newUser
|
||||
}
|
||||
|
||||
func TestEndpointMe(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := config.NewConfig()
|
||||
|
||||
api, _, users, _, _ := setup(ctx, cfg)
|
||||
ctxUser, user := withRequestUser(ctx, users)
|
||||
server := makeRequestServer(api.Me)
|
||||
|
||||
t.Run("WithUser", func(t *testing.T) {
|
||||
res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
|
||||
expectedJSON, err := json.Marshal(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.JSONEq(t, string(expectedJSON), res.Body.String())
|
||||
})
|
||||
|
||||
t.Run("WithoutUser", func(t *testing.T) {
|
||||
res, req := makeRequestResponse(ctx, "GET", "/api/me", nil)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, res.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkEndpointMe(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
cfg := config.NewConfig()
|
||||
|
||||
api, _, users, _, _ := setup(ctx, cfg)
|
||||
ctxUser, _ := withRequestUser(ctx, users)
|
||||
server := makeRequestServer(api.Me)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, req := makeRequestResponse(ctxUser, "GET", "/api/me", nil)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(b, http.StatusOK, res.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointCreateShort(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := config.NewConfig()
|
||||
|
||||
api, shorts, users, _, _ := setup(ctx, cfg)
|
||||
ctxUser, _ := withRequestUser(ctx, users)
|
||||
server := makeRequestServer(api.CreateShort)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
body io.Reader
|
||||
responseCode int
|
||||
shortName string
|
||||
}{
|
||||
{
|
||||
name: "Valid",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)),
|
||||
responseCode: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "InvalidURL",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "invalidurl"}`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "InvalidJSON",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com"`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "WithoutUser",
|
||||
ctx: ctx,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com"}`)),
|
||||
responseCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "WithoutURL",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{}`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "EmptyURL",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": ""}`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "WithName",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)),
|
||||
responseCode: http.StatusCreated,
|
||||
shortName: "example",
|
||||
},
|
||||
{
|
||||
name: "WithInvalidName",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "invalid name"}`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "WithExistingName",
|
||||
ctx: ctxUser,
|
||||
body: bytes.NewReader([]byte(`{"url": "https://example.com", "name": "example"}`)),
|
||||
responseCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
res, req := makeRequestResponse(tt.ctx, "POST", "/api/shorts", tt.body)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tt.responseCode, res.Code)
|
||||
|
||||
if res.Code == http.StatusCreated {
|
||||
var short models.Short
|
||||
err := json.Unmarshal(res.Body.Bytes(), &short)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = shorts.FindShortByID(ctx, short.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "https://example.com", short.URL)
|
||||
|
||||
if tt.shortName != "" {
|
||||
assert.Equal(t, tt.shortName, short.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEndpointCreateShort(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
cfg := config.NewConfig()
|
||||
|
||||
api, _, users, _, _ := setup(ctx, cfg)
|
||||
ctxUser, _ := withRequestUser(ctx, users)
|
||||
server := makeRequestServer(api.CreateShort)
|
||||
|
||||
body := []byte(`{"url": "https://example.com"}`)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, req := makeRequestResponse(ctxUser, "POST", "/api/shorts", bytes.NewReader(body))
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(b, http.StatusCreated, res.Code)
|
||||
}
|
||||
}
|
|
@ -33,7 +33,7 @@ func Auth(userService *userservice.UserService, tokenService *tokenservice.Token
|
|||
}
|
||||
|
||||
// Add user to context
|
||||
ctx = context.WithValue(ctx, userCtxKey{}, user)
|
||||
ctx = WithUser(ctx, user)
|
||||
// Register the auth method used
|
||||
ctx = context.WithValue(ctx, authMethodCtxKey{}, authMethod)
|
||||
|
||||
|
@ -71,3 +71,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) {
|
|||
|
||||
return user, ok
|
||||
}
|
||||
|
||||
func WithUser(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, userCtxKey{}, user)
|
||||
}
|
||||
|
|
|
@ -9,15 +9,24 @@ import (
|
|||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
type RequestLogFormatter struct{}
|
||||
type RequestLogFormatter struct {
|
||||
printStack bool
|
||||
}
|
||||
|
||||
func NewLogFormatter() *RequestLogFormatter {
|
||||
return &RequestLogFormatter{}
|
||||
return &RequestLogFormatter{
|
||||
printStack: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RequestLogFormatter) DisableStackPrinting() {
|
||||
rl.printStack = false
|
||||
}
|
||||
|
||||
// RequestLogEntry is a struct that implements the LogEntry interface.
|
||||
type RequestLogEntry struct {
|
||||
l *slog.Logger
|
||||
l *slog.Logger
|
||||
printStack bool
|
||||
}
|
||||
|
||||
func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry { //nolint:ireturn // This is the signature of the interface.
|
||||
|
@ -47,7 +56,8 @@ func (rl *RequestLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry
|
|||
l = l.With(requestGroup)
|
||||
|
||||
return &RequestLogEntry{
|
||||
l: l,
|
||||
l: l,
|
||||
printStack: rl.printStack,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,6 +81,10 @@ func (le *RequestLogEntry) Write(status, bytes int, header http.Header, elapsed
|
|||
}
|
||||
}
|
||||
|
||||
func (le *RequestLogEntry) Panic(v interface{}, _ []byte) {
|
||||
middleware.PrintPrettyStack(v)
|
||||
func (le *RequestLogEntry) Panic(v interface{}, stack []byte) {
|
||||
if le.printStack {
|
||||
middleware.PrintPrettyStack(v)
|
||||
} else {
|
||||
le.l.Error("Panic", slog.Any("stack", stack))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
package shortserver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware"
|
||||
shortserver "git.maronato.dev/maronato/goshort/internal/server/short"
|
||||
shortservice "git.maronato.dev/maronato/goshort/internal/service/short"
|
||||
shortlogservice "git.maronato.dev/maronato/goshort/internal/service/shortlog"
|
||||
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
handlerutils "git.maronato.dev/maronato/goshort/internal/util/handler"
|
||||
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
_ "modernc.org/sqlite" // Import the SQLite driver.
|
||||
)
|
||||
|
||||
func newStorage(ctx context.Context, cfg *config.Config) *bunstorage.BunStorage {
|
||||
filename := randomutil.GenerateSecureToken(16)
|
||||
url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1"
|
||||
|
||||
sqldb, err := sql.Open("sqlite", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
str := bunstorage.NewBunStorage(cfg, db)
|
||||
|
||||
err = str.Start(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
func setup(ctx context.Context, cfg *config.Config) (
|
||||
*shortserver.ShortHandler,
|
||||
*shortservice.ShortService,
|
||||
*shortlogservice.ShortLogService,
|
||||
) {
|
||||
str := newStorage(ctx, cfg)
|
||||
|
||||
shorts := shortservice.NewShortService(str)
|
||||
shortLogs := shortlogservice.NewShortLogService(str)
|
||||
|
||||
handler := shortserver.NewShortHandler(shorts, shortLogs)
|
||||
|
||||
return handler, shorts, shortLogs
|
||||
}
|
||||
|
||||
func makeRequestServer(handler http.HandlerFunc) http.Handler {
|
||||
mux := chi.NewRouter()
|
||||
lf := servermiddleware.NewLogFormatter()
|
||||
lf.DisableStackPrinting()
|
||||
|
||||
// Add middlewares
|
||||
mux.Use(middleware.RequestLogger(lf))
|
||||
mux.Use(middleware.Recoverer)
|
||||
|
||||
// Handle all requests
|
||||
|
||||
shortRouter := chi.NewRouter()
|
||||
shortRouter.Get("/{short}", handler)
|
||||
|
||||
chained := handlerutils.NewChainedHandler(
|
||||
shortRouter,
|
||||
http.NotFoundHandler(),
|
||||
)
|
||||
mux.Mount("/", chained)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func makeRequestResponse(ctx context.Context, method, url string, body io.Reader) (*httptest.ResponseRecorder, *http.Request) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
return rr, req
|
||||
}
|
||||
|
||||
func TestEndpointFindShort(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := config.NewConfig()
|
||||
|
||||
handler, shorts, shortLogs := setup(ctx, cfg)
|
||||
|
||||
stopWorker, debugCh := shortLogs.StartWorker(ctx)
|
||||
defer stopWorker()
|
||||
|
||||
server := makeRequestServer(handler.FindShort)
|
||||
|
||||
short, err := shorts.Shorten(ctx, &models.Short{
|
||||
URL: "https://example.com",
|
||||
Name: "example",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
t.Run("Redirects", func(t *testing.T) {
|
||||
res, req := makeRequestResponse(ctx, "GET", "/"+short.Name, nil)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusSeeOther, res.Code)
|
||||
|
||||
assert.Equal(t, short.URL, res.Header().Get("Location"))
|
||||
|
||||
// Check if the access was logged
|
||||
<-debugCh
|
||||
logs, err := shortLogs.ListLogs(ctx, short)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(logs))
|
||||
assert.Equal(t, short.ID, logs[0].ShortID)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
res, req := makeRequestResponse(ctx, "GET", "/other", nil)
|
||||
server.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, res.Code)
|
||||
})
|
||||
}
|
|
@ -17,24 +17,30 @@ const logChLength = 100
|
|||
type ShortLogService struct {
|
||||
db storage.Storage
|
||||
logCh chan *models.ShortLog
|
||||
// debugCh exposes the log channel to the tests
|
||||
debugCh chan *models.ShortLog
|
||||
}
|
||||
|
||||
func NewShortLogService(db storage.Storage) *ShortLogService {
|
||||
logCh := make(chan *models.ShortLog, logChLength)
|
||||
debugCh := make(chan *models.ShortLog)
|
||||
|
||||
return &ShortLogService{
|
||||
db: db,
|
||||
logCh: logCh,
|
||||
db: db,
|
||||
logCh: logCh,
|
||||
debugCh: debugCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShortLogService) StartWorker(ctx context.Context) context.CancelFunc {
|
||||
workerCtx, cancel := context.WithCancel(ctx)
|
||||
func (s *ShortLogService) StartWorker(ctx context.Context) (stopWorker context.CancelFunc, debugCh <-chan *models.ShortLog) {
|
||||
workerCtx, stopWorker := context.WithCancel(ctx)
|
||||
|
||||
// Start the goroutine
|
||||
go s.shortLogWorker(workerCtx)
|
||||
|
||||
return cancel
|
||||
debugCh = s.debugCh
|
||||
|
||||
return stopWorker, debugCh
|
||||
}
|
||||
|
||||
func (s *ShortLogService) shortLogWorker(ctx context.Context) {
|
||||
|
@ -52,6 +58,13 @@ func (s *ShortLogService) shortLogWorker(ctx context.Context) {
|
|||
l.Debug("writing short log to storage", slog.String("short_id", shortLog.ShortID))
|
||||
|
||||
err := s.db.CreateShortLog(ctx, shortLog)
|
||||
|
||||
select {
|
||||
case s.debugCh <- shortLog:
|
||||
default:
|
||||
l.Warn("short log debug queue is full, dropping log")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
l.Warn("failed to log short access", "error", err)
|
||||
}
|
||||
|
@ -74,7 +87,11 @@ func (s *ShortLogService) LogShortAccess(ctx context.Context, short *models.Shor
|
|||
l.Debug("adding short log to queue", slog.String("short_id", short.ID))
|
||||
|
||||
// Send the log to the channel
|
||||
s.logCh <- shortLog
|
||||
select {
|
||||
case s.logCh <- shortLog:
|
||||
default:
|
||||
l.Warn("short log queue is full, dropping log")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShortLogService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
|
||||
|
|
|
@ -208,12 +208,9 @@ func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*mod
|
|||
if short == nil {
|
||||
return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort)
|
||||
}
|
||||
// Generate an ID that does not exist
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*ShortModel)(nil)).
|
||||
GetTableName()
|
||||
|
||||
newID, err := createNewID(ctx, s.db, tableName)
|
||||
// Generate an ID that does not exist
|
||||
newID, err := createNewID(ctx, s.db, (*ShortModel)(nil))
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to create short", err)
|
||||
}
|
||||
|
@ -294,11 +291,7 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode
|
|||
|
||||
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
|
||||
// Create new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*ShortLogModel)(nil)).
|
||||
GetTableName()
|
||||
|
||||
newID, err := createNewID(ctx, s.db, tableName)
|
||||
newID, err := createNewID(ctx, s.db, (*ShortLogModel)(nil))
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to create short log", err)
|
||||
}
|
||||
|
@ -365,11 +358,7 @@ func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User,
|
|||
|
||||
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
||||
// Create a new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*UserModel)(nil)).
|
||||
GetTableName()
|
||||
|
||||
newID, err := createNewID(ctx, s.db, tableName)
|
||||
newID, err := createNewID(ctx, s.db, (*UserModel)(nil))
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to create user", err)
|
||||
}
|
||||
|
@ -511,11 +500,7 @@ func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*mod
|
|||
}
|
||||
|
||||
// Create a new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*TokenModel)(nil)).
|
||||
GetTableName()
|
||||
|
||||
newID, err := createNewID(ctx, s.db, tableName)
|
||||
newID, err := createNewID(ctx, s.db, (*TokenModel)(nil))
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to create token", err)
|
||||
}
|
||||
|
@ -611,12 +596,10 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
|
|||
Set("user_id = ?", nil)
|
||||
}
|
||||
|
||||
func createNewID(ctx context.Context, db bun.IDB, table string) (string, error) {
|
||||
func createNewID(ctx context.Context, db bun.IDB, model interface{}) (string, error) {
|
||||
var newID string
|
||||
|
||||
// Generate an ID that does not exist
|
||||
maxIters := 10
|
||||
|
||||
for {
|
||||
// Make sure we don't get stuck in an infinite loop
|
||||
maxIters--
|
||||
|
@ -628,7 +611,7 @@ func createNewID(ctx context.Context, db bun.IDB, table string) (string, error)
|
|||
newID = models.NewID()
|
||||
|
||||
count, err := db.NewSelect().
|
||||
Table(table).
|
||||
Model(model).
|
||||
Column("id").
|
||||
Where("id = ?", newID).
|
||||
Count(ctx)
|
||||
|
|
|
@ -54,6 +54,17 @@ func TestStorageInterface(t *testing.T) {
|
|||
storagetesting.ITestComplete(t, getStg)
|
||||
}
|
||||
|
||||
func BenchmarkStorageInterface(b *testing.B) {
|
||||
getStg := func() storage.Storage {
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
return bs
|
||||
}
|
||||
|
||||
storagetesting.IBenchmarkComplete(b, getStg)
|
||||
}
|
||||
|
||||
func TestBunStorage_Start(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := newNewBunDBHelper()
|
||||
|
|
|
@ -4,6 +4,7 @@ package storagetesting
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -19,6 +20,7 @@ var ErrPlaceholder = errors.New("placeholder error")
|
|||
|
||||
func ITestComplete(t *testing.T, getStg func() storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -129,14 +131,54 @@ func ITestComplete(t *testing.T, getStg func() storage.Storage) {
|
|||
}
|
||||
}
|
||||
|
||||
func IBenchmarkComplete(b *testing.B, getStg func() storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
test func(b *testing.B, stg storage.Storage)
|
||||
}{
|
||||
{
|
||||
name: "CreateShort",
|
||||
test: IBenchmarkCreateShort,
|
||||
},
|
||||
{
|
||||
name: "CreateShortLog",
|
||||
test: IBenchmarkCreateShortLog,
|
||||
},
|
||||
{
|
||||
name: "CreateUser",
|
||||
test: IBenchmarkCreateUser,
|
||||
},
|
||||
{
|
||||
name: "CreateToken",
|
||||
test: IBenchmarkCreateToken,
|
||||
},
|
||||
{
|
||||
name: "FindShort",
|
||||
test: IBenchmarkFindShort,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
b.Run(test.name, func(b *testing.B) {
|
||||
stg := getStg()
|
||||
_ = stg.Start(context.Background())
|
||||
test.test(b, stg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ITestImplements(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
assert.Implements(t, (*storage.Storage)(nil), stg, "Should implement the storage.Storage interface")
|
||||
}
|
||||
|
||||
func ITestStart(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -151,6 +193,7 @@ func ITestStart(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestStop(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -168,6 +211,7 @@ func ITestStop(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestPing(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -182,6 +226,7 @@ func ITestPing(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestCreateShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -303,6 +348,7 @@ func ITestCreateShort(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -332,6 +378,7 @@ func ITestFindShort(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindShortByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -361,6 +408,7 @@ func ITestFindShortByID(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestDeleteShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -400,6 +448,7 @@ func ITestDeleteShort(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestListShorts(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -451,6 +500,7 @@ func ITestListShorts(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestCreateShortLog(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -497,6 +547,7 @@ func ITestCreateShortLog(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestListShortLogs(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -533,6 +584,7 @@ func ITestListShortLogs(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestCreateUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -613,6 +665,7 @@ func ITestCreateUser(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -645,6 +698,7 @@ func ITestFindUser(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindUserByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -677,6 +731,7 @@ func ITestFindUserByID(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestDeleteUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -738,6 +793,7 @@ func ITestDeleteUser(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestCreateToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -820,6 +876,7 @@ func ITestCreateToken(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -848,6 +905,7 @@ func ITestFindToken(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestFindTokenByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -876,6 +934,7 @@ func ITestFindTokenByID(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestDeleteToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -903,6 +962,7 @@ func ITestDeleteToken(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestListTokens(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -947,6 +1007,7 @@ func ITestListTokens(t *testing.T, stg storage.Storage) {
|
|||
|
||||
func ITestChangeTokenName(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
@ -966,3 +1027,111 @@ func ITestChangeTokenName(t *testing.T, stg storage.Storage) {
|
|||
assert.Equal(t, token.UserID, newToken.UserID, "Should return the same UserID")
|
||||
assert.Equal(t, token.CreatedAt, newToken.CreatedAt, "Should return the same CreatedAt")
|
||||
}
|
||||
|
||||
func IBenchmarkCreateShort(b *testing.B, stg storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
url := "https://example.com"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := stg.CreateShort(ctx, &models.Short{
|
||||
Name: fmt.Sprintf("myshort%d", i),
|
||||
URL: url,
|
||||
UserID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IBenchmarkCreateShortLog(b *testing.B, stg storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
short, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
})
|
||||
ipAddress := "myip"
|
||||
userAgent := "myuseragent"
|
||||
referer := "myreferer"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
Referer: referer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func IBenchmarkCreateUser(b *testing.B, stg storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := stg.CreateUser(ctx, &models.User{
|
||||
Username: fmt.Sprintf("myusername%d", i),
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IBenchmarkCreateToken(b *testing.B, stg storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := stg.CreateToken(ctx, &models.Token{
|
||||
Value: fmt.Sprintf("myvalue%d", i),
|
||||
UserID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func IBenchmarkFindShort(b *testing.B, stg storage.Storage) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
short, err := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
})
|
||||
|
||||
b.Run("Existing", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = stg.FindShort(ctx, short.Name)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Non-existing", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = stg.FindShort(ctx, "non-existing")
|
||||
if err != nil && !errors.Is(err, errs.ErrShortDoesNotExist) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user