added lots of tests
This commit is contained in:
parent
1baa480318
commit
4d6600b383
|
@ -56,7 +56,6 @@ linters:
|
|||
- noctx
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- paralleltest
|
||||
- prealloc
|
||||
- predeclared
|
||||
- promlinter
|
||||
|
|
5
go.mod
5
go.mod
|
@ -18,6 +18,7 @@ require (
|
|||
|
||||
require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fatih/color v1.15.0 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
|
@ -26,7 +27,10 @@ require (
|
|||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/stretchr/testify v1.8.4 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
|
@ -34,6 +38,7 @@ require (
|
|||
golang.org/x/sys v0.11.0 // indirect
|
||||
golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
lukechampine.com/uint128 v1.2.0 // indirect
|
||||
modernc.org/cc/v3 v3.40.0 // indirect
|
||||
modernc.org/ccgo/v3 v3.16.13 // indirect
|
||||
|
|
7
go.sum
7
go.sum
|
@ -40,9 +40,16 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qq
|
|||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.1.14 h1:S5vvNnjEynJ0CvnrBOD7MIRW7q/WbtvFXrdfy0lddAM=
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
errs "git.maronato.dev/maronato/goshort/internal/errs"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
package config_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDBTypeList(t *testing.T) {
|
||||
dbTypeList := []string{
|
||||
DBTypeMemory,
|
||||
DBTypeSQLite,
|
||||
}
|
||||
|
||||
results := GetDBTypeList()
|
||||
|
||||
assert.ElementsMatch(t, dbTypeList, results, "DBTypeList should be equal to the list of available DB types")
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Valid config",
|
||||
config: &Config{
|
||||
Prod: false,
|
||||
DBType: DBTypeSQLite,
|
||||
DBURL: "goshort.db",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
UIPort: "3000",
|
||||
Debug: false,
|
||||
SessionDuration: 7,
|
||||
DisableRegistration: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Minimum config",
|
||||
config: &Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Host",
|
||||
config: &Config{
|
||||
Host: "invalid host",
|
||||
Port: "8080",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Port",
|
||||
config: &Config{
|
||||
Host: "localhost",
|
||||
Port: "invalid port",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid UI Port",
|
||||
config: &Config{
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
UIPort: "invalid port",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid DB Type",
|
||||
config: &Config{
|
||||
DBType: DBTypeMemory,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid DB Type",
|
||||
config: &Config{
|
||||
DBType: "invalid db type",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := Validate(tc.config)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "Validate should return an error")
|
||||
} else {
|
||||
assert.NoError(t, err, "Validate should not return an error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -44,6 +44,10 @@ var (
|
|||
ErrDatabaseError = errors.New("database error")
|
||||
// ErrShortLogExists.
|
||||
ErrShortLogExists = errors.New("short log already exists")
|
||||
// ErrStorageNotStarted.
|
||||
ErrStorageNotStarted = errors.New("storage not started")
|
||||
// ErrStorageStarted.
|
||||
ErrStorageStarted = errors.New("storage already started")
|
||||
)
|
||||
|
||||
func Errorf(msg string, err error) error {
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package errs_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
ErrMyError := errors.New("my error") //nolint:goerr113 // This is a test error
|
||||
|
||||
err := Errorf("my message", ErrMyError)
|
||||
|
||||
assert.ErrorIs(t, err, ErrMyError, "Errorf should return an error with the same error type as the one passed as argument")
|
||||
assert.Equal(t, "my message: my error", err.Error(), "Errorf should return an error with the same message as the one passed as argument")
|
||||
}
|
|
@ -20,6 +20,8 @@ type BunStorage struct {
|
|||
StartHooks []func(ctx context.Context, db *bun.DB) error
|
||||
// StopHooks are ran before the database is closed.
|
||||
StopHook []func(ctx context.Context, db *bun.DB) error
|
||||
// started is true if the storage has been started.
|
||||
started bool
|
||||
}
|
||||
|
||||
func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
|
||||
|
@ -28,7 +30,8 @@ func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
|
|||
}
|
||||
|
||||
return &BunStorage{
|
||||
db: db,
|
||||
db: db,
|
||||
started: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,6 +46,10 @@ func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB)
|
|||
}
|
||||
|
||||
func (s *BunStorage) Start(ctx context.Context) error { //nolint:cyclop // This function is not that complex.
|
||||
if s.started {
|
||||
return errs.Errorf("failed to start storage", errs.ErrStorageStarted)
|
||||
}
|
||||
|
||||
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
_, err := tx.NewCreateTable().
|
||||
IfNotExists().
|
||||
|
@ -127,10 +134,15 @@ func (s *BunStorage) Start(ctx context.Context) error { //nolint:cyclop // This
|
|||
}
|
||||
}
|
||||
|
||||
s.started = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) Stop(ctx context.Context) error {
|
||||
if !s.started {
|
||||
return errs.Errorf("failed to stop storage", errs.ErrStorageNotStarted)
|
||||
}
|
||||
// Run stop hooks
|
||||
for _, hook := range s.StopHook {
|
||||
err := hook(ctx, s.db)
|
||||
|
@ -143,6 +155,8 @@ func (s *BunStorage) Stop(ctx context.Context) error {
|
|||
return errs.Errorf("failed to stop storage", err)
|
||||
}
|
||||
|
||||
s.started = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -191,6 +205,9 @@ func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Shor
|
|||
}
|
||||
|
||||
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
|
||||
if short == nil {
|
||||
return nil, errs.Errorf("failed to create short", errs.ErrInvalidShort)
|
||||
}
|
||||
// Generate an ID that does not exist
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*ShortModel)(nil)).
|
||||
|
@ -237,9 +254,13 @@ func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error
|
|||
_, err = withShortDeleteUpdates(
|
||||
tx.NewUpdate().
|
||||
Model((*ShortModel)(nil)).
|
||||
Where("id = ?", short.ID),
|
||||
Where("id = ? and deleted = false", short.ID),
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = errs.ErrShortDoesNotExist
|
||||
}
|
||||
|
||||
return errs.Errorf("failed to delete short", err)
|
||||
}
|
||||
|
||||
|
@ -288,6 +309,7 @@ func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortL
|
|||
IPAddress: shortLog.IPAddress,
|
||||
UserAgent: shortLog.UserAgent,
|
||||
Referer: shortLog.Referer,
|
||||
CreatedAt: shortLog.CreatedAt,
|
||||
}
|
||||
|
||||
_, err = s.db.NewInsert().
|
||||
|
@ -352,6 +374,10 @@ func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models
|
|||
return nil, errs.Errorf("failed to create user", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, errs.Errorf("failed to create user", errs.ErrInvalidUser)
|
||||
}
|
||||
|
||||
userModel := &UserModel{
|
||||
ID: newID,
|
||||
Username: user.Username,
|
||||
|
@ -454,6 +480,10 @@ func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*mode
|
|||
}
|
||||
|
||||
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
|
||||
if token == nil {
|
||||
return nil, errs.Errorf("failed to create token", errs.ErrInvalidToken)
|
||||
}
|
||||
|
||||
// Create a new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*TokenModel)(nil)).
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
package bunstorage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
. "git.maronato.dev/maronato/goshort/internal/storage/bun"
|
||||
storagetesting "git.maronato.dev/maronato/goshort/internal/storage/testing"
|
||||
randomutil "git.maronato.dev/maronato/goshort/internal/util/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
_ "modernc.org/sqlite" // Import the SQLite driver.
|
||||
)
|
||||
|
||||
func newNewBunDBHelper() *bun.DB {
|
||||
filename := randomutil.GenerateSecureToken(16)
|
||||
url := "file:memdb_" + filename + "?cache=shared&mode=memory&_foreign_keys=1"
|
||||
|
||||
sqldb, err := sql.Open("sqlite", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return bun.NewDB(sqldb, sqlitedialect.New())
|
||||
}
|
||||
|
||||
func TestNewBunDB(t *testing.T) {
|
||||
db := newNewBunDBHelper()
|
||||
|
||||
assert.NotNil(t, db, "DB should not be nil")
|
||||
}
|
||||
|
||||
func TestNewBunStorage(t *testing.T) {
|
||||
db := newNewBunDBHelper()
|
||||
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
assert.NotNil(t, bs, "Storage should not be nil")
|
||||
assert.Implements(t, (*storage.Storage)(nil), bs, "Storage should implement the Storage interface")
|
||||
}
|
||||
|
||||
func TestStorageInterface(t *testing.T) {
|
||||
getStg := func() storage.Storage {
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
return bs
|
||||
}
|
||||
|
||||
storagetesting.ITestComplete(t, getStg)
|
||||
}
|
||||
|
||||
func TestBunStorage_Start(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
var tables []string
|
||||
|
||||
_, err := db.NewRaw(`SELECT name FROM sqlite_master WHERE type = "table"`).Exec(ctx, &tables)
|
||||
assert.Nil(t, err, "Should not return an error when querying for tables")
|
||||
|
||||
assert.Empty(t, tables, "Should not have any tables at first")
|
||||
|
||||
err = bs.Start(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when starting the storage")
|
||||
|
||||
_, err = db.NewRaw(`SELECT name FROM sqlite_master WHERE type = "table"`).Exec(ctx, &tables)
|
||||
assert.Nil(t, err, "Should not return an error when querying for tables")
|
||||
|
||||
assert.Len(t, tables, 4, "Should have 4 tables")
|
||||
|
||||
expectedTables := []string{
|
||||
"users",
|
||||
"shorts",
|
||||
"tokens",
|
||||
"short_logs",
|
||||
}
|
||||
for _, expectedTable := range expectedTables {
|
||||
assert.Containsf(t, tables, expectedTable, "Should have the %s table", expectedTable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBunStorage_Stop(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
_ = bs.Start(ctx)
|
||||
|
||||
err := db.Ping()
|
||||
assert.Nil(t, err, "Should not return an error when pinging an open database")
|
||||
|
||||
err = bs.Stop(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when stopping the storage")
|
||||
|
||||
err = db.Ping()
|
||||
assert.NotNil(t, err, "Should return an error when pinging a closed database")
|
||||
}
|
||||
|
||||
func TestBunStorage_RegisterStartHook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
var called bool
|
||||
|
||||
bs.RegisterStartHook(func(ctx context.Context, db *bun.DB) error {
|
||||
called = true
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.False(t, called, "Should not have called the hook yet")
|
||||
|
||||
err := bs.Start(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when starting the storage")
|
||||
|
||||
assert.True(t, called, "Should have called the hook")
|
||||
}
|
||||
|
||||
func TestBunStorage_RegisterStopHook(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := newNewBunDBHelper()
|
||||
bs := NewBunStorage(&config.Config{}, db)
|
||||
|
||||
_ = bs.Start(ctx)
|
||||
|
||||
var called bool
|
||||
|
||||
bs.RegisterStopHook(func(ctx context.Context, db *bun.DB) error {
|
||||
called = true
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.False(t, called, "Should not have called the hook yet")
|
||||
|
||||
err := bs.Stop(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when stopping the storage")
|
||||
|
||||
assert.True(t, called, "Should have called the hook")
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package models_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewID(t *testing.T) {
|
||||
id1 := NewID()
|
||||
|
||||
assert.Len(t, id1, IDLength, "NewID must generate an ID with length IDLength")
|
||||
|
||||
id2 := NewID()
|
||||
|
||||
assert.NotEqual(t, id1, id2, "NewID must genereate a different ID every time")
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package models_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
bcryptpasswords "git.maronato.dev/maronato/goshort/internal/util/passwords/bcrypt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAuthenticatableUser(t *testing.T) {
|
||||
u := &User{
|
||||
ID: "myid",
|
||||
Username: "myusername",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
assert.Empty(t, u.GetPasswordHash(), "Default password should be empty")
|
||||
|
||||
au := NewAuthenticatableUser(u, "myhash")
|
||||
|
||||
assert.NotSame(t, u, au, "User and Auth user should not be the same object")
|
||||
assert.Equal(t, u.ID, au.ID, "IDs should be the same")
|
||||
assert.Equal(t, u.Username, au.Username, "IDs should be the same")
|
||||
assert.Equal(t, u.CreatedAt, au.CreatedAt, "IDs should be the same")
|
||||
assert.Equal(t, au.GetPasswordHash(), "myhash", "Password hash must've been set")
|
||||
}
|
||||
|
||||
func TestGetPasswordHash(t *testing.T) {
|
||||
au := NewAuthenticatableUser(&User{}, "myhash")
|
||||
|
||||
assert.Equal(t, au.GetPasswordHash(), "myhash", "Password hash must've been set")
|
||||
}
|
||||
|
||||
func TestSetPassword(t *testing.T) {
|
||||
u := &User{}
|
||||
bh := bcryptpasswords.NewBcryptHasher()
|
||||
pass := "myrandompasswordx"
|
||||
|
||||
err := u.SetPassword(bh, pass)
|
||||
assert.Nil(t, err, "SetPassword must not return an error from a valid password and hasher")
|
||||
|
||||
assert.NotEqual(t, u.GetPasswordHash(), pass, "Password should not be equal to its hash")
|
||||
|
||||
v, _ := bh.Verify(pass, u.GetPasswordHash())
|
||||
|
||||
assert.True(t, v, "Password must be hashed correctly")
|
||||
}
|
|
@ -0,0 +1,935 @@
|
|||
//nolint:goconst // Not gonna use constants here
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
bcryptpasswords "git.maronato.dev/maronato/goshort/internal/util/passwords/bcrypt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
_ "modernc.org/sqlite" // Import the SQLite driver.
|
||||
)
|
||||
|
||||
var ErrPlaceholder = errors.New("placeholder error")
|
||||
|
||||
func ITestComplete(t *testing.T, getStg func() storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
test func(t *testing.T, stg storage.Storage)
|
||||
dontStart bool
|
||||
}{
|
||||
{
|
||||
name: "ImplementsInterface",
|
||||
test: ITestImplements,
|
||||
},
|
||||
// Lifecycle
|
||||
{
|
||||
name: "Start",
|
||||
test: ITestStart,
|
||||
dontStart: true,
|
||||
},
|
||||
{
|
||||
name: "Stop",
|
||||
test: ITestStop,
|
||||
dontStart: true,
|
||||
},
|
||||
{
|
||||
name: "Ping",
|
||||
test: ITestPing,
|
||||
},
|
||||
// Short storage.Storage
|
||||
{
|
||||
name: "CreateShort",
|
||||
test: ITestCreateShort,
|
||||
},
|
||||
{
|
||||
name: "FindShort",
|
||||
test: ITestFindShort,
|
||||
},
|
||||
{
|
||||
name: "FindShortByID",
|
||||
test: ITestFindShortByID,
|
||||
},
|
||||
{
|
||||
name: "DeleteShort",
|
||||
test: ITestDeleteShort,
|
||||
},
|
||||
{
|
||||
name: "ListShorts",
|
||||
test: ITestListShorts,
|
||||
},
|
||||
// ShortLog storage.Storage
|
||||
{
|
||||
name: "CreateShortLog",
|
||||
test: ITestCreateShortLog,
|
||||
},
|
||||
{
|
||||
name: "ListShortLogs",
|
||||
test: ITestListShortLogs,
|
||||
},
|
||||
// User storage.Storage
|
||||
{
|
||||
name: "CreateUser",
|
||||
test: ITestCreateUser,
|
||||
},
|
||||
{
|
||||
name: "FindUser",
|
||||
test: ITestFindUser,
|
||||
},
|
||||
{
|
||||
name: "FindUserByID",
|
||||
test: ITestFindUserByID,
|
||||
},
|
||||
{
|
||||
name: "DeleteUser",
|
||||
test: ITestDeleteUser,
|
||||
},
|
||||
// Token storage.Storage
|
||||
{
|
||||
name: "CreateToken",
|
||||
test: ITestCreateToken,
|
||||
},
|
||||
{
|
||||
name: "FindToken",
|
||||
test: ITestFindToken,
|
||||
},
|
||||
{
|
||||
name: "FindTokenByID",
|
||||
test: ITestFindTokenByID,
|
||||
},
|
||||
{
|
||||
name: "DeleteToken",
|
||||
test: ITestDeleteToken,
|
||||
},
|
||||
{
|
||||
name: "ListTokens",
|
||||
test: ITestListTokens,
|
||||
},
|
||||
{
|
||||
name: "ChangeTokenName",
|
||||
test: ITestChangeTokenName,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stg := getStg()
|
||||
if !test.dontStart {
|
||||
_ = stg.Start(context.Background())
|
||||
}
|
||||
test.test(t, stg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ITestImplements(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
assert.Implements(t, (*storage.Storage)(nil), stg, "Should implement the storage.Storage interface")
|
||||
}
|
||||
|
||||
func ITestStart(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := stg.Start(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when starting the storage")
|
||||
|
||||
err = stg.Start(ctx)
|
||||
assert.ErrorIs(t, err, errs.ErrStorageStarted, "Should return an error when starting an already started storage")
|
||||
|
||||
_ = stg.Stop(ctx)
|
||||
}
|
||||
|
||||
func ITestStop(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := stg.Stop(ctx)
|
||||
assert.ErrorIs(t, err, errs.ErrStorageNotStarted, "Should return an error when stopping a storage that wasn't started")
|
||||
|
||||
_ = stg.Start(ctx)
|
||||
|
||||
err = stg.Stop(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when stopping the storage")
|
||||
|
||||
err = stg.Stop(ctx)
|
||||
assert.ErrorIs(t, err, errs.ErrStorageNotStarted, "Should return an error when stopping an already stopped storage")
|
||||
}
|
||||
|
||||
func ITestPing(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := stg.Ping(ctx)
|
||||
assert.Nil(t, err, "Should not return an error when pinging the storage")
|
||||
|
||||
_ = stg.Stop(ctx)
|
||||
|
||||
err = stg.Ping(ctx)
|
||||
assert.NotNil(t, err, "Should return an error when pinging a closed storage")
|
||||
}
|
||||
|
||||
func ITestCreateShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userStr := "user"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
short *models.Short
|
||||
newShort *models.Short
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Should create a short",
|
||||
short: &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
newShort: &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should create a short with a user",
|
||||
short: &models.Short{
|
||||
Name: "myshort2",
|
||||
URL: "https://example.com",
|
||||
UserID: &userStr,
|
||||
},
|
||||
newShort: &models.Short{
|
||||
Name: "myshort2",
|
||||
URL: "https://example.com",
|
||||
UserID: &userStr,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should not use the given ID",
|
||||
short: &models.Short{
|
||||
ID: "myid",
|
||||
Name: "myshort3",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
newShort: &models.Short{
|
||||
Name: "myshort3",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should not use the given CreatedAt",
|
||||
short: &models.Short{
|
||||
CreatedAt: time.Now().Add(-time.Hour),
|
||||
Name: "myshort4",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
newShort: &models.Short{
|
||||
CreatedAt: time.Now().Add(-time.Hour),
|
||||
Name: "myshort4",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should return an error when the short is nil",
|
||||
short: nil,
|
||||
newShort: nil,
|
||||
err: errs.ErrInvalidShort,
|
||||
},
|
||||
{
|
||||
name: "Should return an error when using the same name",
|
||||
short: &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: nil,
|
||||
},
|
||||
newShort: nil,
|
||||
err: errs.ErrShortExists,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
newShort, err := stg.CreateShort(ctx, test.short)
|
||||
|
||||
if errors.Is(test.err, ErrPlaceholder) {
|
||||
assert.Error(t, err, "Should return an error")
|
||||
} else {
|
||||
assert.ErrorIs(t, err, test.err, "Should return the same error")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if test.short.ID != "" {
|
||||
assert.NotEqual(t, test.short.ID, newShort.ID, "Should not return the same ID")
|
||||
}
|
||||
|
||||
if !test.short.CreatedAt.IsZero() {
|
||||
assert.NotEqual(t, test.short.CreatedAt, newShort.CreatedAt, "Should not return the same CreatedAt")
|
||||
}
|
||||
|
||||
assert.NotZero(t, newShort.CreatedAt, "Should return a non-zero CreatedAt")
|
||||
assert.NotZero(t, newShort.ID, "Should return a non-zero ID")
|
||||
assert.Equal(t, test.newShort.Name, newShort.Name, "Should return the same Name")
|
||||
assert.Equal(t, test.newShort.URL, newShort.URL, "Should return the same URL")
|
||||
assert.Equal(t, test.newShort.UserID, newShort.UserID, "Should return the same UserID")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ITestFindShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
baseShort := &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
}
|
||||
|
||||
found, err := stg.FindShort(ctx, baseShort.Name)
|
||||
assert.ErrorIs(t, err, errs.ErrShortDoesNotExist, "Should return an error when the short doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the short")
|
||||
|
||||
short, _ := stg.CreateShort(ctx, baseShort)
|
||||
|
||||
found, err = stg.FindShort(ctx, baseShort.Name)
|
||||
assert.Nil(t, err, "Should not return an error when finding the short")
|
||||
assert.Equal(t, short, found, "Should return the same short")
|
||||
|
||||
_ = stg.DeleteShort(ctx, short)
|
||||
|
||||
found, err = stg.FindShort(ctx, baseShort.Name)
|
||||
assert.ErrorIs(t, err, errs.ErrShortDoesNotExist, "Should return an error when the short is deleted")
|
||||
assert.Nil(t, found, "Should not find the short")
|
||||
}
|
||||
|
||||
func ITestFindShortByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
baseShort := &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
}
|
||||
|
||||
found, err := stg.FindShortByID(ctx, "myid")
|
||||
assert.ErrorIs(t, err, errs.ErrShortDoesNotExist, "Should return an error when the short doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the short")
|
||||
|
||||
short, _ := stg.CreateShort(ctx, baseShort)
|
||||
|
||||
found, err = stg.FindShortByID(ctx, short.ID)
|
||||
assert.Nil(t, err, "Should not return an error when finding the short")
|
||||
assert.Equal(t, short, found, "Should return the same short")
|
||||
|
||||
_ = stg.DeleteShort(ctx, short)
|
||||
|
||||
found, err = stg.FindShort(ctx, baseShort.Name)
|
||||
assert.ErrorIs(t, err, errs.ErrShortDoesNotExist, "Should return an error when the short is deleted")
|
||||
assert.Nil(t, found, "Should not find the short")
|
||||
}
|
||||
|
||||
func ITestDeleteShort(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
newShort, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
})
|
||||
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: newShort.ID,
|
||||
})
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: newShort.ID,
|
||||
})
|
||||
|
||||
found, _ := stg.FindShort(ctx, "myshort")
|
||||
assert.NotNil(t, found, "Should find the short")
|
||||
|
||||
shortLogs, _ := stg.ListShortLogs(ctx, newShort)
|
||||
assert.Len(t, shortLogs, 2, "Should have 2 short logs") //nolint:gomnd // It's a test
|
||||
|
||||
err := stg.DeleteShort(ctx, newShort)
|
||||
assert.Nil(t, err, "Should not return an error when deleting the short")
|
||||
|
||||
found, err = stg.FindShort(ctx, "myshort")
|
||||
assert.ErrorIs(t, err, errs.ErrShortDoesNotExist, "Should return an error when finding the short")
|
||||
assert.Nil(t, found, "Should not find the short")
|
||||
|
||||
shortLogs, _ = stg.ListShortLogs(ctx, newShort)
|
||||
assert.Len(t, shortLogs, 0, "Should not have any short logs")
|
||||
|
||||
err = stg.DeleteShort(ctx, newShort)
|
||||
assert.Nil(t, err, "Should not return an error when deleting a deleted short")
|
||||
}
|
||||
|
||||
func ITestListShorts(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
userID2 := "user2"
|
||||
|
||||
short1, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
})
|
||||
|
||||
short2, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort2",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
})
|
||||
|
||||
deletedShort, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "deletedshort",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID,
|
||||
})
|
||||
_ = stg.DeleteShort(ctx, deletedShort)
|
||||
|
||||
short3, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort3",
|
||||
URL: "https://example.com",
|
||||
UserID: &userID2,
|
||||
})
|
||||
|
||||
shorts, err := stg.ListShorts(ctx, &models.User{
|
||||
ID: userID,
|
||||
})
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when listing the shorts")
|
||||
assert.Len(t, shorts, 2, "Should return 2 shorts") //nolint:gomnd // It's a test
|
||||
assert.Contains(t, shorts, short1, "Should return the first short")
|
||||
assert.Contains(t, shorts, short2, "Should return the second short")
|
||||
|
||||
shorts2, err := stg.ListShorts(ctx, &models.User{
|
||||
ID: userID2,
|
||||
})
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when listing the shorts")
|
||||
assert.Len(t, shorts2, 1, "Should return 1 short")
|
||||
assert.Contains(t, shorts2, short3, "Should return the third short")
|
||||
}
|
||||
|
||||
func ITestCreateShortLog(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
short, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
})
|
||||
|
||||
baseShortLog := &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
IPAddress: "myip",
|
||||
UserAgent: "myuseragent",
|
||||
Referer: "myreferer",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := stg.CreateShortLog(ctx, baseShortLog)
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when creating the short log")
|
||||
|
||||
err = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
})
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when creating a short log twice")
|
||||
|
||||
shortLogs, err := stg.ListShortLogs(ctx, short)
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when listing the short logs")
|
||||
assert.Len(t, shortLogs, 2, "Should return 2 short logs") //nolint:gomnd // It's a test
|
||||
|
||||
shortLog := shortLogs[0]
|
||||
|
||||
assert.Equal(t, baseShortLog.ShortID, shortLog.ShortID, "Should return the same ShortID")
|
||||
assert.Equal(t, baseShortLog.IPAddress, shortLog.IPAddress, "Should return the same IPAddress")
|
||||
assert.Equal(t, baseShortLog.UserAgent, shortLog.UserAgent, "Should return the same UserAgent")
|
||||
assert.Equal(t, baseShortLog.Referer, shortLog.Referer, "Should return the same Referer")
|
||||
assert.Equal(t, baseShortLog.CreatedAt.UTC(), shortLog.CreatedAt.UTC(), "Should return the same CreatedAt")
|
||||
assert.NotZero(t, shortLog.ID, "Should return a non-zero ID")
|
||||
|
||||
assert.NotZero(t, shortLogs[1].ID, "Should return a non-zero ID")
|
||||
assert.NotZero(t, shortLogs[1].CreatedAt, "Should return a non-zero CreatedAt")
|
||||
}
|
||||
|
||||
func ITestListShortLogs(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
short, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort",
|
||||
URL: "https://example.com",
|
||||
})
|
||||
|
||||
short2, _ := stg.CreateShort(ctx, &models.Short{
|
||||
Name: "myshort2",
|
||||
URL: "https://example.com",
|
||||
})
|
||||
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
})
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
})
|
||||
_ = stg.CreateShortLog(ctx, &models.ShortLog{
|
||||
ShortID: short2.ID,
|
||||
})
|
||||
|
||||
shortLogs, err := stg.ListShortLogs(ctx, short)
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when listing the short logs")
|
||||
assert.Len(t, shortLogs, 2, "Should return 2 short log") //nolint:gomnd // It's a test
|
||||
|
||||
shortLogs2, err := stg.ListShortLogs(ctx, short2)
|
||||
|
||||
assert.Nil(t, err, "Should not return an error when listing the short logs")
|
||||
assert.Len(t, shortLogs2, 1, "Should return 1 short log")
|
||||
}
|
||||
|
||||
func ITestCreateUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
bh := bcryptpasswords.NewBcryptHasher()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *models.User
|
||||
password string
|
||||
newUser *models.User
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Should create a user",
|
||||
user: &models.User{
|
||||
Username: "myusername",
|
||||
},
|
||||
password: "mypassword",
|
||||
newUser: &models.User{
|
||||
Username: "myusername",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should not use the given ID",
|
||||
user: &models.User{
|
||||
ID: "myid",
|
||||
Username: "myusername2",
|
||||
},
|
||||
newUser: &models.User{
|
||||
Username: "myusername2",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should error when the username was already taken",
|
||||
user: &models.User{
|
||||
Username: "myusername",
|
||||
},
|
||||
password: "mypassword",
|
||||
newUser: nil,
|
||||
err: errs.ErrUserExists,
|
||||
},
|
||||
{
|
||||
name: "Should error if the user is nil",
|
||||
user: nil,
|
||||
err: errs.ErrInvalidUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if test.password != "" {
|
||||
_ = test.user.SetPassword(bh, test.password)
|
||||
}
|
||||
newUser, err := stg.CreateUser(ctx, test.user)
|
||||
|
||||
if errors.Is(test.err, ErrPlaceholder) {
|
||||
assert.Error(t, err, "Should return an error")
|
||||
} else {
|
||||
assert.ErrorIs(t, err, test.err, "Should return the same error")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotZero(t, newUser.CreatedAt, "Should return a non-zero CreatedAt")
|
||||
assert.NotZero(t, newUser.ID, "Should return a non-zero ID")
|
||||
assert.Equal(t, test.newUser.Username, newUser.Username, "Should return the same Username")
|
||||
if test.password != "" {
|
||||
v, _ := bh.Verify(test.password, newUser.GetPasswordHash())
|
||||
assert.True(t, v, "Should have the same password")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ITestFindUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
bh := bcryptpasswords.NewBcryptHasher()
|
||||
|
||||
baseUser := &models.User{
|
||||
Username: "myusername",
|
||||
}
|
||||
_ = baseUser.SetPassword(bh, "mypassword")
|
||||
|
||||
found, err := stg.FindUser(ctx, baseUser.Username)
|
||||
assert.ErrorIs(t, err, errs.ErrUserDoesNotExist, "Should return an error when the user doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the user")
|
||||
|
||||
user, _ := stg.CreateUser(ctx, baseUser)
|
||||
|
||||
found, err = stg.FindUser(ctx, user.Username)
|
||||
assert.Nil(t, err, "Should not return an error when finding the user")
|
||||
assert.Equal(t, user, found, "Should return the same user")
|
||||
|
||||
v, _ := bh.Verify("mypassword", found.GetPasswordHash())
|
||||
assert.True(t, v, "Should have the same password")
|
||||
|
||||
_ = stg.DeleteUser(ctx, user)
|
||||
|
||||
found, err = stg.FindUser(ctx, user.Username)
|
||||
assert.ErrorIs(t, err, errs.ErrUserDoesNotExist, "Should return an error when the user is deleted")
|
||||
assert.Nil(t, found, "Should not find the user")
|
||||
}
|
||||
|
||||
func ITestFindUserByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
bh := bcryptpasswords.NewBcryptHasher()
|
||||
|
||||
baseUser := &models.User{
|
||||
Username: "myusername",
|
||||
}
|
||||
_ = baseUser.SetPassword(bh, "mypassword")
|
||||
|
||||
found, err := stg.FindUserByID(ctx, "my id")
|
||||
assert.ErrorIs(t, err, errs.ErrUserDoesNotExist, "Should return an error when the user doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the user")
|
||||
|
||||
user, _ := stg.CreateUser(ctx, baseUser)
|
||||
|
||||
found, err = stg.FindUserByID(ctx, user.ID)
|
||||
assert.Nil(t, err, "Should not return an error when finding the user")
|
||||
assert.Equal(t, user, found, "Should return the same user")
|
||||
|
||||
v, _ := bh.Verify("mypassword", found.GetPasswordHash())
|
||||
assert.True(t, v, "Should have the same password")
|
||||
|
||||
_ = stg.DeleteUser(ctx, user)
|
||||
|
||||
found, err = stg.FindUserByID(ctx, user.ID)
|
||||
assert.ErrorIs(t, err, errs.ErrUserDoesNotExist, "Should return an error when the user is deleted")
|
||||
assert.Nil(t, found, "Should not find the user")
|
||||
}
|
||||
|
||||
func ITestDeleteUser(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
bh := bcryptpasswords.NewBcryptHasher()
|
||||
|
||||
baseUser := &models.User{
|
||||
Username: "myusername",
|
||||
}
|
||||
_ = baseUser.SetPassword(bh, "mypassword")
|
||||
|
||||
user, _ := stg.CreateUser(ctx, baseUser)
|
||||
|
||||
found, _ := stg.FindUser(ctx, user.Username)
|
||||
assert.NotNil(t, found, "Should find the user")
|
||||
|
||||
err := stg.DeleteUser(ctx, user)
|
||||
assert.Nil(t, err, "Should not return an error when deleting the user")
|
||||
|
||||
found, err = stg.FindUser(ctx, user.Username)
|
||||
assert.ErrorIs(t, err, errs.ErrUserDoesNotExist, "Should return an error when finding the user")
|
||||
assert.Nil(t, found, "Should not find the user")
|
||||
|
||||
err = stg.DeleteUser(ctx, user)
|
||||
assert.Nil(t, err, "Should not return an error when deleting a deleted user")
|
||||
}
|
||||
|
||||
func ITestCreateToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token *models.Token
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Should create a token",
|
||||
token: &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should not use the given ID",
|
||||
token: &models.Token{
|
||||
ID: "myid",
|
||||
Value: "myvalue2",
|
||||
UserID: &userID,
|
||||
Name: "mytoken2",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should not use the given CreatedAt",
|
||||
token: &models.Token{
|
||||
CreatedAt: time.Now().Add(-time.Hour),
|
||||
Value: "myvalue3",
|
||||
UserID: &userID,
|
||||
Name: "mytoken3",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Should return an error when the token is nil",
|
||||
token: nil,
|
||||
err: errs.ErrInvalidToken,
|
||||
},
|
||||
{
|
||||
name: "Should return an error if the value was already taken",
|
||||
token: &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken4",
|
||||
},
|
||||
err: errs.ErrTokenExists,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
newTok, err := stg.CreateToken(ctx, test.token)
|
||||
|
||||
if errors.Is(test.err, ErrPlaceholder) {
|
||||
assert.Error(t, err, "Should return an error")
|
||||
} else {
|
||||
assert.ErrorIs(t, err, test.err, "Should return the same error")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotZero(t, newTok.CreatedAt, "Should return a non-zero CreatedAt")
|
||||
assert.NotZero(t, newTok.ID, "Should return a non-zero ID")
|
||||
assert.NotEqual(t, test.token.CreatedAt, newTok.CreatedAt, "Should not return the same CreatedAt")
|
||||
assert.NotEqual(t, test.token.ID, newTok.ID, "Should not return the same ID")
|
||||
assert.Equal(t, test.token.Value, newTok.Value, "Should return the same Value")
|
||||
assert.Equal(t, test.token.Name, newTok.Name, "Should return the same Name")
|
||||
assert.Equal(t, test.token.UserID, newTok.UserID, "Should return the same UserID")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ITestFindToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
found, err := stg.FindToken(ctx, "myvalue")
|
||||
assert.ErrorIs(t, err, errs.ErrTokenDoesNotExist, "Should return an error when the token doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the token")
|
||||
|
||||
token, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken",
|
||||
})
|
||||
|
||||
found, err = stg.FindToken(ctx, token.Value)
|
||||
assert.Nil(t, err, "Should not return an error when finding the token")
|
||||
assert.Equal(t, token, found, "Should return the same token")
|
||||
|
||||
_ = stg.DeleteToken(ctx, token)
|
||||
|
||||
found, err = stg.FindToken(ctx, token.Value)
|
||||
assert.ErrorIs(t, err, errs.ErrTokenDoesNotExist, "Should return an error when the token is deleted")
|
||||
assert.Nil(t, found, "Should not find the token")
|
||||
}
|
||||
|
||||
func ITestFindTokenByID(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
found, err := stg.FindTokenByID(ctx, "myid")
|
||||
assert.ErrorIs(t, err, errs.ErrTokenDoesNotExist, "Should return an error when the token doesn't exist")
|
||||
assert.Nil(t, found, "Should not find the token")
|
||||
|
||||
token, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken",
|
||||
})
|
||||
|
||||
found, err = stg.FindTokenByID(ctx, token.ID)
|
||||
assert.Nil(t, err, "Should not return an error when finding the token")
|
||||
assert.Equal(t, token, found, "Should return the same token")
|
||||
|
||||
_ = stg.DeleteToken(ctx, token)
|
||||
|
||||
found, err = stg.FindToken(ctx, token.Value)
|
||||
assert.ErrorIs(t, err, errs.ErrTokenDoesNotExist, "Should return an error when the token is deleted")
|
||||
assert.Nil(t, found, "Should not find the token")
|
||||
}
|
||||
|
||||
func ITestDeleteToken(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
newToken, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken",
|
||||
})
|
||||
|
||||
found, _ := stg.FindToken(ctx, newToken.Value)
|
||||
assert.NotNil(t, found, "Should find the token")
|
||||
|
||||
err := stg.DeleteToken(ctx, newToken)
|
||||
assert.Nil(t, err, "Should not return an error when deleting the token")
|
||||
|
||||
found, err = stg.FindToken(ctx, newToken.Value)
|
||||
assert.ErrorIs(t, err, errs.ErrTokenDoesNotExist, "Should return an error when finding the token")
|
||||
assert.Nil(t, found, "Should not find the token")
|
||||
|
||||
err = stg.DeleteToken(ctx, newToken)
|
||||
assert.Nil(t, err, "Should not return an error when deleting a deleted token")
|
||||
}
|
||||
|
||||
func ITestListTokens(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
userID2 := "user2"
|
||||
|
||||
token1, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
})
|
||||
token2, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue2",
|
||||
UserID: &userID,
|
||||
})
|
||||
deletedToken, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue3",
|
||||
UserID: &userID,
|
||||
})
|
||||
|
||||
_ = stg.DeleteToken(ctx, deletedToken)
|
||||
|
||||
token3, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue4",
|
||||
UserID: &userID2,
|
||||
})
|
||||
|
||||
tokens, err := stg.ListTokens(ctx, &models.User{
|
||||
ID: userID,
|
||||
})
|
||||
assert.Nil(t, err, "Should not return an error when listing the tokens")
|
||||
assert.Len(t, tokens, 2, "Should return 2 tokens") //nolint:gomnd // It's a test
|
||||
assert.Contains(t, tokens, token1, "Should return the first token")
|
||||
assert.Contains(t, tokens, token2, "Should return the second token")
|
||||
|
||||
tokens2, err := stg.ListTokens(ctx, &models.User{
|
||||
ID: userID2,
|
||||
})
|
||||
assert.Nil(t, err, "Should not return an error when listing the tokens")
|
||||
assert.Len(t, tokens2, 1, "Should return 1 token")
|
||||
assert.Contains(t, tokens2, token3, "Should return the third token")
|
||||
}
|
||||
|
||||
func ITestChangeTokenName(t *testing.T, stg storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
userID := "user"
|
||||
|
||||
token, _ := stg.CreateToken(ctx, &models.Token{
|
||||
Value: "myvalue",
|
||||
UserID: &userID,
|
||||
Name: "mytoken",
|
||||
})
|
||||
|
||||
newToken, err := stg.ChangeTokenName(ctx, token, "mytoken2")
|
||||
assert.Nil(t, err, "Should not return an error when changing the token name")
|
||||
assert.Equal(t, "mytoken2", newToken.Name, "Should return the new name")
|
||||
assert.Equal(t, token.ID, newToken.ID, "Should return the same ID")
|
||||
assert.Equal(t, token.Value, newToken.Value, "Should return the same value")
|
||||
assert.Equal(t, token.UserID, newToken.UserID, "Should return the same UserID")
|
||||
assert.Equal(t, token.CreatedAt, newToken.CreatedAt, "Should return the same CreatedAt")
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package handlerutils_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/util/handler"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChainHandler(t *testing.T) {
|
||||
var stopAt int
|
||||
|
||||
// Makes a chain handler factory that writes the status code whrn
|
||||
// it's n value matches the value of `stopAt`
|
||||
makeHandler := func(n int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if stopAt == n {
|
||||
w.WriteHeader(200 + n)
|
||||
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Create 3 handlers and chain them together
|
||||
h1 := makeHandler(1)
|
||||
h2 := makeHandler(2)
|
||||
h3 := makeHandler(3)
|
||||
|
||||
ch := NewChainedHandler(h1, h2, h3)
|
||||
|
||||
stopAt = 2
|
||||
|
||||
assert.HTTPStatusCode(t, ch.ServeHTTP, "GET", "/", nil, 202)
|
||||
|
||||
stopAt = 1
|
||||
|
||||
assert.HTTPStatusCode(t, ch.ServeHTTP, "GET", "/", nil, 201)
|
||||
|
||||
stopAt = 3
|
||||
|
||||
assert.HTTPStatusCode(t, ch.ServeHTTP, "GET", "/", nil, 203)
|
||||
|
||||
stopAt = 4
|
||||
// Since it isn't handled by anything, it returns 200 by default
|
||||
assert.HTTPStatusCode(t, ch.ServeHTTP, "GET", "/", nil, 200)
|
||||
}
|
|
@ -67,7 +67,7 @@ func (h *ArgonHasher) Hash(password string) (string, error) {
|
|||
return encodedHash, nil
|
||||
}
|
||||
|
||||
func (h *ArgonHasher) VerifyPassword(password, encodedHash string) (match bool, err error) {
|
||||
func (h *ArgonHasher) Verify(password, encodedHash string) (match bool, err error) {
|
||||
// Extract the parameters, salt and derived key from the encoded password
|
||||
// hash.
|
||||
param, salt, hash, err := h.decodeHash(encodedHash)
|
||||
|
@ -85,7 +85,7 @@ func (h *ArgonHasher) VerifyPassword(password, encodedHash string) (match bool,
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func (h *ArgonHasher) decodeHash(encodedHash string) (p argonParams, salt, hash []byte, err error) {
|
||||
func (h *ArgonHasher) decodeHash(encodedHash string) (p argonParams, salt, hash []byte, err error) { //nolint:cyclop // Not refactoring this
|
||||
algorithm, strSalt, strHash, params, err := passwords.DecodePasswordHash(encodedHash)
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode hash: %w", err)
|
||||
|
@ -97,15 +97,33 @@ func (h *ArgonHasher) decodeHash(encodedHash string) (p argonParams, salt, hash
|
|||
|
||||
var version uint32
|
||||
|
||||
paramsFmtString := "v=%s,m=%s,t=%s,p=%s"
|
||||
_, err = fmt.Sscanf(
|
||||
fmt.Sprintf(paramsFmtString, params["v"], params["m"], params["t"], params["p"]),
|
||||
paramsFmtString,
|
||||
&version,
|
||||
&p.memory,
|
||||
&p.iterations,
|
||||
&p.parallelism,
|
||||
)
|
||||
if versionStr, ok := params["v"]; ok {
|
||||
_, err := fmt.Sscanf(versionStr, "%d", &version)
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode version in hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if iterationsStr, ok := params["t"]; ok {
|
||||
_, err := fmt.Sscanf(iterationsStr, "%d", &p.iterations)
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode iterations in hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if memoryStr, ok := params["m"]; ok {
|
||||
_, err := fmt.Sscanf(memoryStr, "%d", &p.memory)
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode memory in hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if parallelismStr, ok := params["p"]; ok {
|
||||
_, err := fmt.Sscanf(parallelismStr, "%d", &p.parallelism)
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode parallelism in hash: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return p, nil, nil, fmt.Errorf("failed to decode hash: %w", err)
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package argonpasswords_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/util/passwords"
|
||||
. "git.maronato.dev/maronato/goshort/internal/util/passwords/argon"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewArgonHasher(t *testing.T) {
|
||||
ah := NewArgonHasher()
|
||||
assert.Implements(t, (*passwords.PasswordHasher)(nil), ah, "NewArgonHasher should return a PasswordHasher")
|
||||
}
|
||||
|
||||
func TestArgonHash(t *testing.T) {
|
||||
ah := NewArgonHasher()
|
||||
|
||||
hash, err := ah.Hash("mypassword")
|
||||
|
||||
assert.Nil(t, err, "ArgonHash should not return an error for a valid password")
|
||||
assert.NotEmpty(t, hash, "ArgonHash should return a non-empty hash for a valid password")
|
||||
|
||||
hash2, _ := ah.Hash("mypassword")
|
||||
|
||||
assert.NotEqual(t, hash, hash2, "ArgonHash should return different hashes for the same password via salting")
|
||||
}
|
||||
|
||||
func TestArgonVerify(t *testing.T) {
|
||||
ah := NewArgonHasher()
|
||||
|
||||
hash, _ := ah.Hash("mypassword")
|
||||
|
||||
pass, err := ah.Verify("mypassword", hash)
|
||||
|
||||
assert.Nil(t, err, "ArgonVerify should not return an error for a valid password")
|
||||
assert.True(t, pass, "ArgonVerify should return true for a valid password")
|
||||
|
||||
pass2, err := ah.Verify("mypassword2", hash)
|
||||
|
||||
assert.Nil(t, err, "ArgonVerify should not return an error for an invalid password")
|
||||
assert.False(t, pass2, "ArgonVerify should return false for an invalid password")
|
||||
|
||||
_, err = ah.Verify("mypassword", "invalidhash")
|
||||
|
||||
assert.ErrorIs(t, err, passwords.ErrInvalidHash, "ArgonVerify should return an error for an invalid hash")
|
||||
}
|
|
@ -38,7 +38,7 @@ func (h *BcryptHasher) Verify(password, encodedHash string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("could not compare password with hash: %w", err)
|
||||
return false, passwords.ErrInvalidHash
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package bcryptpasswords_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/util/passwords"
|
||||
. "git.maronato.dev/maronato/goshort/internal/util/passwords/bcrypt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewBcryptHasher(t *testing.T) {
|
||||
bh := NewBcryptHasher()
|
||||
assert.Implements(t, (*passwords.PasswordHasher)(nil), bh, "NewBcryptHasher should return a PasswordHasher")
|
||||
}
|
||||
|
||||
func TestBcryptHash(t *testing.T) {
|
||||
bh := NewBcryptHasher()
|
||||
|
||||
hash, err := bh.Hash("mypassword")
|
||||
|
||||
assert.Nil(t, err, "BcryptHash should not return an error for a valid password")
|
||||
assert.NotEmpty(t, hash, "BcryptHash should return a non-empty hash for a valid password")
|
||||
|
||||
hash2, _ := bh.Hash("mypassword")
|
||||
|
||||
assert.NotEqual(t, hash, hash2, "BcryptHash should return different hashes for the same password via salting")
|
||||
}
|
||||
|
||||
func TestBcryptVerify(t *testing.T) {
|
||||
bh := NewBcryptHasher()
|
||||
|
||||
hash, _ := bh.Hash("mypassword")
|
||||
|
||||
pass, err := bh.Verify("mypassword", hash)
|
||||
|
||||
assert.Nil(t, err, "BcryptVerify should not return an error for a valid password")
|
||||
assert.True(t, pass, "BcryptVerify should return true for a valid password")
|
||||
|
||||
pass2, err := bh.Verify("mypassword2", hash)
|
||||
|
||||
assert.Nil(t, err, "BcryptVerify should not return an error for an invalid password")
|
||||
assert.False(t, pass2, "BcryptVerify should return false for an invalid password")
|
||||
|
||||
_, err = bh.Verify("mypassword", "invalidhash")
|
||||
|
||||
assert.ErrorIs(t, err, passwords.ErrInvalidHash, "BcryptVerify should return an error for an invalid hash")
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package passwords_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/util/passwords"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenerateSalt(t *testing.T) {
|
||||
r1 := GenerateSalt(8)
|
||||
|
||||
assert.Len(t, r1, 8, "GenerateSalt should return a slice with the length passed as argument")
|
||||
|
||||
r2 := GenerateSalt(8)
|
||||
|
||||
assert.NotEqual(t, r1, r2, "GenerateSalt should return different results for different calls")
|
||||
}
|
||||
|
||||
func TestAreEqual(t *testing.T) {
|
||||
assert.True(t, AreEqual([]byte("abc"), []byte("abc")), "AreEqual should return true for equal byte slices")
|
||||
assert.False(t, AreEqual([]byte("abc"), []byte("def")), "AreEqual should return false for different byte slices")
|
||||
}
|
||||
|
||||
func TestEncodePasswordHash(t *testing.T) {
|
||||
r1 := EncodePasswordHash("myalg", "mysalt", "myhash", map[string]string{"mykey": "myvalue"})
|
||||
|
||||
assert.Equal(t, "myalg$mykey=myvalue$mysalt$myhash", r1, "EncodePasswordHash should return a string with the correct format")
|
||||
}
|
||||
|
||||
func TestDecodePasswordHash(t *testing.T) {
|
||||
alg, salt, hash, params, err := DecodePasswordHash("myalg$mykey=myvalue$mysalt$myhash")
|
||||
|
||||
assert.Nil(t, err, "DecodePasswordHash should not return an error for a valid hash")
|
||||
assert.Equal(t, "myalg", alg, "DecodePasswordHash should return the correct algorithm")
|
||||
assert.Equal(t, "mysalt", salt, "DecodePasswordHash should return the correct salt")
|
||||
assert.Equal(t, "myhash", hash, "DecodePasswordHash should return the correct hash")
|
||||
assert.Equal(t, map[string]string{"mykey": "myvalue"}, params, "DecodePasswordHash should return the correct params")
|
||||
|
||||
_, _, _, _, err = DecodePasswordHash("myalg$mykey=myvalue$mysalt") //nolint:dogsled // This is a test
|
||||
|
||||
assert.NotNil(t, err, "DecodePasswordHash should return an error for an invalid hash")
|
||||
}
|
|
@ -33,14 +33,15 @@ func GenerateFromCharset(charset string, n int) string {
|
|||
return string(result)
|
||||
}
|
||||
|
||||
func GenerateSecureToken(length int) string {
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
const (
|
||||
SecureTokenCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
ShortCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
)
|
||||
|
||||
return GenerateFromCharset(charset, length)
|
||||
func GenerateSecureToken(length int) string {
|
||||
return GenerateFromCharset(SecureTokenCharset, length)
|
||||
}
|
||||
|
||||
func GenerateRandomShort(n int) string {
|
||||
charset := "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
return GenerateFromCharset(charset, n)
|
||||
return GenerateFromCharset(ShortCharset, n)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
package randomutil_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "git.maronato.dev/maronato/goshort/internal/util/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenerateRandomBytes(t *testing.T) {
|
||||
r1, err := GenerateRandomBytes(10)
|
||||
if err != nil {
|
||||
t.Errorf("Error testing GenerateRandomBytes: %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, r1, 10, "GenerateRandomBytes should return a slice with the length passed as argument")
|
||||
|
||||
r2, _ := GenerateRandomBytes(10)
|
||||
|
||||
assert.NotEqual(t, r1, r2, "GenerateRandomBytes should return different results for different calls")
|
||||
}
|
||||
|
||||
func TestGenerateFromCharset(t *testing.T) {
|
||||
r1 := GenerateFromCharset("abc", 10)
|
||||
|
||||
assert.Len(t, r1, 10, "GenerateFromCharset should return a string with the length passed as argument")
|
||||
|
||||
for _, char := range r1 {
|
||||
assert.Contains(t, "abc", string(char), "GenerateFromCharset should return a string with the characters passed as argument")
|
||||
}
|
||||
|
||||
r2 := GenerateFromCharset("abc", 10)
|
||||
|
||||
assert.NotEqual(t, r1, r2, "GenerateFromCharset should return different results for different calls")
|
||||
}
|
||||
|
||||
func TestGenerateSecureToken(t *testing.T) {
|
||||
r1 := GenerateSecureToken(10)
|
||||
|
||||
assert.Len(t, r1, 10, "GenerateSecureToken should return a string with the length passed as argument")
|
||||
|
||||
for _, char := range r1 {
|
||||
assert.Contains(t, SecureTokenCharset, string(char), "GenerateSecureToken should return a token with the characters from SecureTokenCharset")
|
||||
}
|
||||
|
||||
r2 := GenerateSecureToken(10)
|
||||
|
||||
assert.NotEqual(t, r1, r2, "GenerateSecureToken should return different results for different calls")
|
||||
}
|
||||
|
||||
func TestGenerateRandomShort(t *testing.T) {
|
||||
r1 := GenerateRandomShort(10)
|
||||
|
||||
assert.Len(t, r1, 10, "GenerateRandomShort should return a string with the length passed as argument")
|
||||
|
||||
for _, char := range r1 {
|
||||
assert.Contains(t, ShortCharset, string(char), "GenerateRandomShort should return a short with the characters from ShortCharset")
|
||||
}
|
||||
|
||||
r2 := GenerateRandomShort(10)
|
||||
|
||||
assert.NotEqual(t, r1, r2, "GenerateRandomShort should return different results for different calls")
|
||||
}
|
Loading…
Reference in New Issue
Block a user