short logs and better database access

This commit is contained in:
Gustavo Maronato 2023-08-23 19:30:12 -03:00
parent b143221181
commit 16bb373c60
Signed by: maronato
SSH Key Fingerprint: SHA256:2Gw7kwMz/As+2UkR1qQ/qYYhn+WNh3FGv6ozhoRrLcs
33 changed files with 1167 additions and 269 deletions

View File

@ -1,5 +1,7 @@
# ---> DB
*.db
*.db-shm
*.db-wal
# ---> Go
# If you prefer the allow list template instead of the deny list, see community template:

4
.gitignore vendored
View File

@ -1,5 +1,7 @@
# ---> DB
*.db
*.db-shm
*.db-wal
# ---> Go
# If you prefer the allow list template instead of the deny list, see community template:
@ -157,4 +159,4 @@ dist
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
.pnp.*

View File

@ -9,7 +9,6 @@ import (
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/storage"
memorystorage "git.maronato.dev/maronato/goshort/internal/storage/memory"
sqlitestorage "git.maronato.dev/maronato/goshort/internal/storage/sqlite"
"github.com/peterbourgon/ff/v3"
)
@ -60,7 +59,8 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) {
func InitStorage(cfg *config.Config) storage.Storage {
switch cfg.DBType {
case config.DBTypeMemory:
return memorystorage.NewMemoryStorage()
cfg.DBURL = ":memory:"
return sqlitestorage.NewSQLiteStorage(cfg)
case config.DBTypeSQLite:
return sqlitestorage.NewSQLiteStorage(cfg)
default:

View File

@ -10,14 +10,22 @@ const ItemList = <T extends Record<string, unknown>, K extends keyof T>({
idKey: K
Item: FunctionComponent<T>
}) => {
console.log("list")
return (
<ul
role="list"
className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden">
{items.map((item) => (
<Item {...item} key={item[idKey] as string} />
))}
</ul>
<>
<ul
role="list"
className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden">
{items.map((item) => (
<Item {...item} key={item[idKey] as string} />
))}
</ul>
{items.length === 0 && (
<div className="text-center pt-5 text-xl font-light">
Nothing here yet
</div>
)}
</>
)
}

View File

@ -1,4 +1,4 @@
import { FunctionComponent, useCallback } from "react"
import { FunctionComponent, memo, useCallback } from "react"
import {
ArrowRightIcon,
@ -6,6 +6,7 @@ import {
} from "@heroicons/react/24/outline"
import { useDelete } from "../hooks/useCRUD"
import { useShortLogs, useVisitMetrics } from "../hooks/useStats"
import { Short } from "../types"
import ItemBase from "./ItemBase"
@ -24,17 +25,14 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
// Handle deletion
const [deleting, del] = useDelete()
const doDelete = useCallback(
() => del({ name: short.name }),
[del, short.name]
)
const doDelete = useCallback(() => del({ id: short.id }), [del, short.id])
return (
<ItemBase
copyString={shortNameURL}
doDelete={doDelete}
deleting={deleting}
detailsPage={`/sht/${short.name}`}>
detailsPage={`/sht/${short.id}`}>
<div className="grid md:grid-cols-12 md:grid-rows-1 w-full grid-flow-row md:grid-flow-col">
<div className="col-span-5 md:col-span-3 my-auto flex flex-col order-1">
<a
@ -69,16 +67,111 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
</a>
</div>
<div className="col-span-5 md:col-span-2 shrink-0 flex flex-col items-end my-auto order-2 md:order-4">
<p className="text-sm font-bold leading-6 text-gray-900">
<span className="text-green-500">1000</span> views
</p>
<p className="mt-1 text-xs leading-5 text-slate-400">
Last viewed <time dateTime="2023-01-23T13:23Z">3h ago</time>
</p>
<ShortItemStatsMemo key={short.id} id={short.id} />
</div>
</div>
</ItemBase>
)
}
export default ShortItem
export default memo(ShortItem, (prev, next) => prev.id === next.id)
const ShortItemStats: FunctionComponent<Pick<Short, "id">> = ({ id }) => {
const { loading, error, logs } = useShortLogs(id)
const { lastVisit, visits } = useVisitMetrics(logs)
if (error) {
return (
<p className="text-red-500 text-center font-medium">
Failed to load stats
</p>
)
}
const lastViewedISO = lastVisit
? new Date(lastVisit).toISOString()
: undefined
const lastViewedRelative = getRelativeTimeString(lastVisit)
return (
<>
<p className="text-sm font-bold leading-6 text-gray-900">
{loading && "Loading..."}
{loading || (
<>
<span className="text-green-500">{visits}</span> views
</>
)}
</p>
<p className="mt-1 text-xs leading-5 text-slate-400 flex flex-col text-right">
{loading && "Loading..."}
{loading || (
<>
<span className="-mb-1">Last viewed:</span>
<span className="font-medium">
<time dateTime={lastViewedISO}>{lastViewedRelative}</time>
</span>
</>
)}
</p>
</>
)
}
const ShortItemStatsMemo = memo(
ShortItemStats,
(prev, next) => prev.id === next.id
)
/**
* Convert a date to a relative time string, such as
* "a minute ago", "in 2 hours", "yesterday", "3 months ago", etc.
* using Intl.RelativeTimeFormat
* https://gist.github.com/LewisJEllis
*/
function getRelativeTimeString(date: Date | number | null): string {
if (!date) return "never"
// Allow dates or times to be passed
const timeMs = typeof date === "number" ? date : date.getTime()
// Get the amount of seconds between the given date and now
const deltaSeconds = Math.round((timeMs - Date.now()) / 1000)
// Array reprsenting one minute, hour, day, week, month, etc in seconds
const cutoffs = [
60,
3600,
86400,
86400 * 7,
86400 * 30,
86400 * 365,
Infinity,
]
// Array equivalent to the above but in the string representation of the units
const units: Intl.RelativeTimeFormatUnit[] = [
"second",
"minute",
"hour",
"day",
"week",
"month",
"year",
]
// Grab the ideal cutoff unit
const unitIndex = cutoffs.findIndex(
(cutoff) => cutoff > Math.abs(deltaSeconds)
)
// Get the divisor to divide from the seconds. E.g. if our unit is "day" our divisor
// is one day in seconds, so we can divide our seconds by this to get the # of days
const divisor = unitIndex ? cutoffs[unitIndex - 1] : 1
// Intl.RelativeTimeFormat do its magic
const rtf = new Intl.RelativeTimeFormat("en", {
numeric: "auto",
style: "narrow",
})
return rtf.format(Math.floor(deltaSeconds / divisor), units[unitIndex])
}

View File

@ -6,7 +6,8 @@ type Item = Record<string, unknown>
export const useLoadedItems = <T extends Item>() => {
const sourceDefault = useMemo(() => [], [])
const items = (useLoaderData() ?? sourceDefault) as T[]
const data = useLoaderData() ?? sourceDefault
const items = useMemo<T[]>(() => (Array.isArray(data) ? data : []), [data])
return items
}

View File

@ -0,0 +1,168 @@
import { useEffect, useMemo, useState } from "react"
import { ShortLog } from "../types"
import fetchAPI from "../util/fetchAPI"
export const useShortLogs = (shortID: string) => {
const [loading, setLoading] = useState(false)
const [error, setError] = useState("")
const [logs, setLogs] = useState<ShortLog[]>(useMemo(() => [], []))
// Fetch logs on mount
useEffect(() => {
const fetchLogs = async () => {
// Reset logs and error
setLogs([])
setError("")
// Fetch logs
const resp = await fetchAPI<ShortLog[]>(`/shorts/${shortID}/logs`)
if (resp.ok) {
// If response is ok, set logs
if (resp.data.length > 0) {
// Sort logs by date, descending
resp.data.sort((a, b) => {
return b.createdAt.localeCompare(a.createdAt)
})
// Set sorted logs
setLogs(resp.data)
}
} else {
// If response is not ok, set error
setError(resp.error)
}
// Now we're done loading
setLoading(false)
}
if (shortID) {
// If we have an ID, fetch logs
setLoading(true)
fetchLogs()
}
}, [shortID])
return { loading, error, logs }
}
const singleFilterOptions = ["equal", "greater", "lower", "notEqual"] as const
export type ShortLogSingleFilterOptions = (typeof singleFilterOptions)[number]
const listFilterOptions = ["in", "notIn"] as const
export type ShortLogListFilterOptions = (typeof listFilterOptions)[number]
type ShortLogKeys = keyof ShortLog
type SingleFilter<T extends keyof ShortLog> = {
[key in ShortLogSingleFilterOptions]?: ShortLog[T] | null
}
type ListFilter<T extends keyof ShortLog> = {
[key in ShortLogListFilterOptions]?: (ShortLog[T] | null)[]
}
export type ShortLogFilter = {
[key in ShortLogKeys]?: SingleFilter<key> | ListFilter<key>
}
export const useFilteredShortLogs = (
logs: ShortLog[],
filterList: ShortLogFilter[]
) => {
return useMemo(() => {
const start = performance.now()
// If there are no filters, return all logs
if (filterList.length === 0) return logs
// For each log
const result = logs.filter((log) => {
// If it matches any of the filters, return true
return filterList.some((filterSet) => {
// For each key in the filterSet
return Object.entries(filterSet).every(([key, filterMap]) => {
// If the key is not in the log, return false
if (!(key in log)) return false
// For each filter in the filterMap
return Object.entries(filterMap).every(
([filterType, filterValue]) => {
const logValue = log[key as keyof ShortLog]
// If filterType is single
if (
singleFilterOptions.includes(
filterType as unknown as ShortLogSingleFilterOptions
)
) {
switch (filterType) {
case "equal":
return logValue === filterValue
case "greater":
return logValue > filterValue
case "lower":
return logValue < filterValue
case "notEqual":
return logValue !== filterValue
}
} else if (
listFilterOptions.includes(
filterType as unknown as ShortLogListFilterOptions
)
) {
// If filterType is list
switch (filterType) {
case "in":
return filterValue.includes(logValue)
case "notIn":
return !filterValue.includes(logValue)
}
}
// If we get here,
// the filterType is not valid
return false
}
)
})
})
})
console.debug(
`Filtered ${logs.length} logs in ${performance.now() - start}ms`
)
return result
}, [logs, filterList])
}
export const useVisitMetrics = (logs: ShortLog[]) => {
const [visits, setVisits] = useState(0)
const [uniqueVisits, setUniqueVisits] = useState(0)
const [lastVisit, setLastVisit] = useState<Date | null>(null)
useEffect(() => {
// Reset visits
setUniqueVisits(0)
setLastVisit(null)
// Set number of visits
setVisits(logs.length)
// Skip work if there are no logs
if (logs.length === 0) return
// Keep track of unique IPs
const uniqueIPs = new Set()
let last = logs[0].createdAt
logs.forEach((log) => {
// If the log is older, update
if (log.createdAt.localeCompare(last) < 0) {
last = log.createdAt
}
// Add IP to set
uniqueIPs.add(log.ipAddress)
})
// Set unique visits
setUniqueVisits(uniqueIPs.size)
// Set last visit
setLastVisit(new Date(last))
}, [logs])
return { visits, uniqueVisits, lastVisit }
}

View File

@ -54,10 +54,10 @@ const useSubmitActions = ({
useEffect(() => {
if (onDelete && formData && actionData && actionData.ok) {
const data = actionData.data
const name = formData.get("name")
if (typeof data === "string" && name && typeof name === "string") {
const id = formData.get("id")
if (typeof data === "string" && id && typeof id === "string") {
// Deleting
onDelete(name)
onDelete(id)
}
}
}, [formData, actionData, onDelete])
@ -136,8 +136,8 @@ const useRecentShorts = () => {
onCreate: useCallback((short: Short) => {
setRecentShorts((prev) => [short, ...prev])
}, []),
onDelete: useCallback((name: string) => {
setRecentShorts((prev) => prev.filter((short) => short.name !== name))
onDelete: useCallback((id: string) => {
setRecentShorts((prev) => prev.filter((short) => short.id !== id))
}, []),
})
@ -213,7 +213,7 @@ export const Component: FunctionComponent = () => {
</Button>
</form>
<div className="mt-10">
<ItemList items={recentShorts} Item={ShortItem} idKey="name" />
<ItemList items={recentShorts} Item={ShortItem} idKey="id" />
</div>
</>
)

View File

@ -22,7 +22,7 @@ export const loader: LoaderFunction = async (args) => {
const resp = await protectedLoader(args)
if (resp) return resp
const data = await fetchAPI<Short>(`/shorts/${args.params.name}`)
const data = await fetchAPI<Short>(`/shorts/${args.params.id}`)
if (data.ok) {
return data.data
}

View File

@ -1,4 +1,4 @@
import { FunctionComponent, useCallback } from "react"
import { useCallback } from "react"
import { LoaderFunction, redirect } from "react-router-dom"
@ -16,19 +16,10 @@ export function Component() {
useCallback((a, b) => a.name.localeCompare(b.name), [])
)
const Shorts: FunctionComponent = () => {
return <ItemList items={items} Item={ShortItem} idKey="name" />
}
const NoShorts = () => {
return (
<div className="text-center pt-5 text-xl font-light">No shorts yet</div>
)
}
return (
<>
<Header title="Shorts" />
{items.length > 0 ? <Shorts /> : <NoShorts />}
<ItemList items={items} Item={ShortItem} idKey="id" key="list" />
</>
)
}
@ -59,13 +50,13 @@ export const action = crudAction({
})
},
DELETE: async (formData) => {
const name = formData.get("name") as string | null
const id = formData.get("id") as string | null
if (!name) {
if (!id) {
return { ok: false, error: "Invalid request" }
}
return fetchAPI<Short>(`/shorts/${name}`, {
return fetchAPI<Short>(`/shorts/${id}`, {
method: "DELETE",
})
},

View File

@ -52,7 +52,7 @@ export default createBrowserRouter([
},
{
id: "shortDetails",
path: "/sht/:name",
path: "/sht/:id",
lazy: () => import("./pages/ShortDetails"),
},
{

View File

@ -2,16 +2,28 @@ export type GenericItem = Record<string, unknown>
export type User = {
username: string
createdAt: string
}
export type Short = {
id: string
name: string
url: string
createdAt: string
}
export type ShortLog = {
id: string
shortID: string
ipAddress: string
userAgent: string
referer: string
createdAt: string
}
export type Session = {
id: string
ip: string
ipAddress: string
userAgent: string
lastActivity: string
createdAt: string

View File

@ -83,12 +83,12 @@ type Config struct {
func Validate(cfg *Config) error {
// Host and port have to be valid.
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid host and/or port: %s", err))
return errs.Errorf(fmt.Sprintf("invalid host and/or port: %s", err), errs.ErrInvalidConfig)
}
// UI port has to be valid.
if cfg.UIPort != "" {
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.UIPort)); err != nil {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid UI port: %s", err))
return errs.Errorf(fmt.Sprintf("invalid UI port: %s", err), errs.ErrInvalidConfig)
}
}
if cfg.DBType != "" {
@ -101,7 +101,7 @@ func Validate(cfg *Config) error {
}
}
if !valid {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid database type: %s", cfg.DBType))
return errs.Errorf(fmt.Sprintf("invalid database type: %s", cfg.DBType), errs.ErrInvalidConfig)
}
}

View File

@ -40,8 +40,10 @@ var (
ErrInvalidTokenID = errors.New("invalid token ID")
// ErrInvalidTokenName
ErrInvalidTokenName = errors.New("invalid token name")
// ErrDatabaseError
ErrDatabaseError = errors.New("database error")
)
func Error(err error, msg string) error {
return fmt.Errorf("%w: %s", err, msg)
func Errorf(msg string, err error) error {
return fmt.Errorf("%s: %w", msg, err)
}

View File

@ -194,7 +194,7 @@ func (h *APIHandler) CreateShort(w http.ResponseWriter, r *http.Request) {
}
// Set the user
short.User = user
short.UserID = &user.ID
// Shorten URL
newShort, err := h.shorts.Shorten(ctx, short)
@ -283,6 +283,28 @@ func (h *APIHandler) DeleteShort(w http.ResponseWriter, r *http.Request) {
render.NoContent(w, r)
}
func (h *APIHandler) ListShortLogs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Find own short or respond
short, ok := h.findShortOrRespond(w, r)
if !ok {
return
}
// Get logs
logs, err := h.shorts.ListLogs(ctx, short)
if err != nil {
server.RenderServerError(w, r, err)
return
}
// Render the response
render.Status(r, http.StatusOK)
render.JSON(w, r, logs)
}
func (h *APIHandler) ListSessions(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -451,11 +473,11 @@ func (h *APIHandler) findUserOrRespond(w http.ResponseWriter, r *http.Request) (
func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) (short *models.Short, ok bool) {
ctx := r.Context()
// Get short name from request
name := chi.URLParam(r, "short")
// Get short id from request
id := chi.URLParam(r, "id")
// Find short in storage
short, err := h.shorts.FindShort(ctx, name)
short, err := h.shorts.FindShortByID(ctx, id)
if err != nil {
// If the short doesn't exist or is invalid, return not found
if errors.Is(err, errs.ErrShortDoesNotExist) || errors.Is(err, errs.ErrInvalidShort) {
@ -475,7 +497,7 @@ func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request)
// If the session user does not match the short's user,
// return forbidden.
if user.Username != short.User.Username {
if user.ID != *short.UserID {
server.RenderForbidden(w, r)
return nil, false
@ -514,7 +536,7 @@ func (h *APIHandler) findTokenOrRespond(w http.ResponseWriter, r *http.Request)
// If the session user does not match the token's user,
// return NotFound.
if user.Username != token.User.Username {
if user.ID != *token.UserID {
server.RenderNotFound(w, r)
return nil, false

View File

@ -26,8 +26,9 @@ func NewAPIRouter(h *APIHandler) http.Handler {
// Shorts routes
r.Get("/shorts", h.ListShorts)
r.Post("/shorts", h.CreateShort)
r.Get("/shorts/{short}", h.FindShort)
r.Delete("/shorts/{short}", h.DeleteShort)
r.Get("/shorts/{id}", h.FindShort)
r.Delete("/shorts/{id}", h.DeleteShort)
r.Get("/shorts/{id}/logs", h.ListShortLogs)
// Sessions routes
r.Get("/sessions", h.ListSessions)

View File

@ -15,11 +15,11 @@ import (
)
const (
sessionUserKey = "user"
sessionUserKey = "u"
sessionIPKey = "ip"
sessionUserAgentKey = "user_agent"
sessionLastActivityKey = "last_activity"
sessionCreatedAtKey = "created_at"
sessionUserAgentKey = "ua"
sessionLastActivityKey = "la"
sessionCreatedAtKey = "ca"
tokenHeader = "Authorization"
)
@ -28,8 +28,8 @@ type userContextKey struct{}
type AuthSessionData struct {
// Username is the username of the user to whom the session belongs.
Username string `json:"username"`
// IP is the last IP address used by with the session.
IP string `json:"ip"`
// IPAddress is the last IP address used by with the session.
IPAddress string `json:"ipAddress"`
// UserAgent is the last User-Agent used with the session.
UserAgent string `json:"userAgent"`
// LastActivity is the last time the session was used.
@ -97,7 +97,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) {
func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) *AuthSessionData {
// Get data from session
username := manager.GetString(sessionCtx, sessionUserKey)
ip := manager.GetString(sessionCtx, sessionIPKey)
ipAddress := manager.GetString(sessionCtx, sessionIPKey)
userAgent := manager.GetString(sessionCtx, sessionUserAgentKey)
lastActivity, err := time.Parse(time.RFC3339, manager.GetString(sessionCtx, sessionLastActivityKey))
if err != nil {
@ -111,7 +111,7 @@ func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context)
// Create new session data
sessionData := &AuthSessionData{
Username: username,
IP: ip,
IPAddress: ipAddress,
UserAgent: userAgent,
LastActivity: lastActivity,
CreatedAt: createdAt,
@ -124,12 +124,12 @@ func UpdateSession(ctx context.Context, r *http.Request) {
manager := SessionManagerFromCtx(ctx)
// Get data from request
ip := r.RemoteAddr
ipAddress := r.RemoteAddr
userAgent := r.UserAgent()
lastActivity := time.Now().Format(time.RFC3339)
// Update session
manager.Put(ctx, sessionIPKey, ip)
manager.Put(ctx, sessionIPKey, ipAddress)
manager.Put(ctx, sessionUserAgentKey, userAgent)
manager.Put(ctx, sessionLastActivityKey, lastActivity)
}
@ -298,14 +298,11 @@ func authenticateViaToken(r *http.Request, tokenService *tokenservice.TokenServi
return nil, errs.ErrTokenMissing
}
// Get token from storage
token, err := tokenService.FindToken(ctx, value)
// Get the token's user from storage
user, err = tokenService.FindTokenUserFromValue(ctx, value)
if err != nil {
return nil, fmt.Errorf("error authenticating via token: %w", err)
return nil, errs.Errorf("could not authenticate user", err)
}
// Get user from token
user = token.User
return user, nil
}

View File

@ -26,12 +26,13 @@ func NewServer(cfg *config.Config) *Server {
mux := chi.NewRouter()
// Register default middlewares
mux.Use(middleware.Recoverer)
mux.Use(middleware.RequestID)
mux.Use(middleware.RealIP)
mux.Use(middleware.Logger)
mux.Use(middleware.Recoverer)
mux.Use(servermiddleware.SessionManager(cfg))
mux.Use(middleware.Timeout(config.RequestTimeout))
mux.Use(middleware.Compress(5, "application/json"))
// Create the server
srv := &http.Server{

View File

@ -2,6 +2,7 @@ package shortserver
import (
"errors"
"fmt"
"net/http"
"git.maronato.dev/maronato/goshort/internal/errs"
@ -32,7 +33,13 @@ func (h *ShortHandler) FindShort(w http.ResponseWriter, r *http.Request) {
switch {
case err == nil:
// If there's no error, redirect to the URL
// If there's no error, log the access and redirect to the URL
err = h.service.LogShortAccess(ctx, short, r)
if err != nil {
// If there was an error logging the access, print a message and
// continue.
fmt.Printf("failed to log short access: %v\n", err)
}
http.Redirect(w, r, short.URL, http.StatusSeeOther)
case errors.Is(err, errs.ErrInvalidShort):
// If the short name is invalid, do nothing and let the static handler

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
@ -20,6 +21,8 @@ const (
MinShortLength = 4
// MaxShortLength is the maximum length of the short URL.
MaxShortLength = 20
// ShortIDLength is the length of the short ID.
ShortIDLength = 16
)
type ShortService struct {
@ -30,6 +33,33 @@ func NewShortService(db storage.Storage) *ShortService {
return &ShortService{db: db}
}
func (s *ShortService) LogShortAccess(ctx context.Context, short *models.Short, r *http.Request) error {
// Log the access
shortLog := &models.ShortLog{
ShortID: short.ID,
IPAddress: r.RemoteAddr,
UserAgent: r.UserAgent(),
Referer: r.Referer(),
}
err := s.db.CreateShortLog(ctx, shortLog)
if err != nil {
return fmt.Errorf("failed to log short access: %w", err)
}
return nil
}
func (s *ShortService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
// Get the logs from storage
logs, err := s.db.ListShortLogs(ctx, short)
if err != nil {
return nil, fmt.Errorf("failed to get short logs from storage: %w", err)
}
return logs, nil
}
func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Short, error) {
// Check if the short is valid
err := ShortNameIsValid(name)
@ -37,11 +67,26 @@ func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Shor
return &models.Short{}, fmt.Errorf("could not validate short: %w", err)
}
// Get the URL from storage
// Get the short from storage
short, err := s.db.FindShort(ctx, name)
if err != nil {
return short, fmt.Errorf("could not get URL from storage: %w", err)
return short, fmt.Errorf("could not get short from storage: %w", err)
}
return short, nil
}
func (s *ShortService) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
// Check if the ID is valid
if len(id) != ShortIDLength {
return &models.Short{}, errs.ErrInvalidShort
}
// Get the short from storage
short, err := s.db.FindShortByID(ctx, id)
if err != nil {
return short, fmt.Errorf("could not get short from storage: %w", err)
}
return short, nil
@ -116,13 +161,13 @@ var shortRegex = regexp.MustCompile(fmt.Sprintf("^%s$", shortPattern))
func ShortNameIsValid(name string) error {
if !shortRegex.MatchString(name) {
return errs.Error(
errs.ErrInvalidShort,
return errs.Errorf(
fmt.Sprintf(
"short must use only letters, numbers, underscores and dashes, and be between %d and %d characters long",
MinShortLength,
MaxShortLength,
),
errs.ErrInvalidShort,
)
}
@ -132,11 +177,11 @@ func ShortNameIsValid(name string) error {
func ShortURLIsValid(shortURL string) error {
parsedURL, err := url.ParseRequestURI(shortURL)
if err != nil {
return errs.Error(errs.ErrInvalidShort, "invalid URL")
return errs.Errorf("invalid URL", errs.ErrInvalidShort)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return errs.Error(errs.ErrInvalidShort, "invalid URL scheme")
return errs.Errorf("invalid URL scheme", errs.ErrInvalidShort)
}
return nil

View File

@ -8,7 +8,6 @@ import (
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models"
shortutil "git.maronato.dev/maronato/goshort/internal/util/short"
tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
)
@ -60,7 +59,7 @@ func (s *TokenService) FindToken(ctx context.Context, value string) (*models.Tok
// FindTokenByID finds a token in the storage using its ID.
func (s *TokenService) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
// Check if the ID is valid
if len(id) != TokenIDLength {
if len(id) != models.IDLength {
return &models.Token{}, errs.ErrInvalidTokenID
}
@ -86,12 +85,10 @@ func (s *TokenService) ListTokens(ctx context.Context, user *models.User) ([]*mo
// CreateToken creates a new token for a user.
func (s *TokenService) CreateToken(ctx context.Context, user *models.User) (*models.Token, error) {
// Generate a new token
id := shortutil.GenerateRandomShort(TokenIDLength)
token := &models.Token{
ID: id,
Name: fmt.Sprintf("%s's token #%s", user.Username, id[:5]),
Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2),
User: user,
Name: fmt.Sprintf("%s's token", user.Username),
Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2),
UserID: &user.ID,
}
// Create the token in storage
@ -128,3 +125,20 @@ func (s *TokenService) ChangeTokenName(ctx context.Context, token *models.Token,
return newToken, nil
}
// FindTokenUserFromValue finds the user that owns a token.
func (s *TokenService) FindTokenUserFromValue(ctx context.Context, value string) (*models.User, error) {
// Get the token from storage
token, err := s.FindToken(ctx, value)
if err != nil {
return &models.User{}, errs.Errorf("could not get token from storage", err)
}
// Get the user from storage
user, err := s.db.FindUserByID(ctx, *token.UserID)
if err != nil {
return &models.User{}, errs.Errorf("could not get token user from storage", err)
}
return user, nil
}

View File

@ -131,7 +131,7 @@ func (s *UserService) SetPassword(ctx context.Context, user *models.User, newPas
func UsernameIsValid(username string) error {
if !UsernameRegex.MatchString(username) {
return errs.Error(errs.ErrInvalidUser, fmt.Sprintf("username must match %s", UsernameRegex))
return errs.Errorf(fmt.Sprintf("username must match %s", UsernameRegex), errs.ErrInvalidUser)
}
return nil

View File

@ -0,0 +1,8 @@
package bunstorage
import "errors"
var (
// ErrDatabaseError is the error returned when a generic database error occurs.
ErrDatabaseError = errors.New("database error")
)

View File

@ -7,42 +7,11 @@ import (
"github.com/uptrace/bun"
)
type ShortModel struct {
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
ID int64 `bun:",pk,autoincrement" json:"id"`
// Name is the short's name, which is also the path to access it
Name string `bun:",unique,notnull" json:"name"`
// URL is the URL that the short will redirect to
URL string `bun:",notnull" json:"url"`
// Deleted is whether the short is soft-deleted or not
Deleted bool `bun:",notnull,default:false" json:"-"`
// CreatedAt is when the short was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// DeletedAt is when the short was deleted
DeletedAt time.Time `bun:",null" json:"-"`
// UserID is the ID of the user that created the short
// This can be null if the short was deleted
UserID *int64 `json:"-"`
// User is the user that created the short
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
}
func (s *ShortModel) toShort() *models.Short {
return &models.Short{
Name: s.Name,
URL: s.URL,
CreatedAt: s.CreatedAt,
User: s.User.toUser(),
}
}
type UserModel struct {
bun.BaseModel `bun:"table:users,alias:u"`
// ID is the primary key
ID int64 `bun:",pk,autoincrement" json:"id"`
ID string `bun:",pk,unique,notnull" json:"id"`
// Username is the user's username
Username string `bun:",unique,notnull" json:"username"`
// Password is the user's password
@ -58,27 +27,60 @@ type UserModel struct {
func (u *UserModel) toUser() *models.User {
return models.NewAuthenticatableUser(&models.User{
ID: u.ID,
Username: u.Username,
CreatedAt: u.CreatedAt,
}, u.Password)
}
type ShortModel struct {
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
ID string `bun:",pk,unique,notnull" json:"id"`
// Name is the short's name, which is also the path to access it
Name string `bun:",unique,notnull" json:"name"`
// URL is the URL that the short will redirect to
URL string `bun:",notnull" json:"url"`
// Deleted is whether the short is soft-deleted or not
Deleted bool `bun:",notnull,default:false" json:"-"`
// CreatedAt is when the short was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// DeletedAt is when the short was deleted
DeletedAt time.Time `bun:",null" json:"-"`
// UserID is the ID of the user that created the short
// This can be null if the short was deleted
UserID *string `json:"-"`
// User is the user that created the short
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
}
func (s *ShortModel) toShort() *models.Short {
return &models.Short{
ID: s.ID,
Name: s.Name,
URL: s.URL,
CreatedAt: s.CreatedAt,
UserID: s.UserID,
}
}
type TokenModel struct {
bun.BaseModel `bun:"table:tokens,alias:t"`
// ID is the primary key
ID string `bun:",pk" json:"id"`
ID string `bun:",pk,unique,notnull" json:"id"`
// Name is the user-friendly name of the token
Name string `bun:",notnull" json:"name"`
// Value is the actual token
Value string `bun:",unique,notnull" json:"value"`
// UserID is the ID of the user that created the token
UserID *int64 `bun:",notnull" json:"-"`
// User is the user that created the token
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
// CreatedAt is when the token was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// UserID is the ID of the user that created the token
UserID *string `bun:",notnull" json:"-"`
// User is the user that created the token
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
}
func (t *TokenModel) toToken() *models.Token {
@ -87,6 +89,37 @@ func (t *TokenModel) toToken() *models.Token {
Name: t.Name,
Value: t.Value,
CreatedAt: t.CreatedAt,
User: t.User.toUser(),
UserID: t.UserID,
}
}
type ShortLogModel struct {
bun.BaseModel `bun:"table:short_logs,alias:sl"`
// ID is the primary key
ID string `bun:",pk,unique,notnull" json:"id"`
// IPAddress is the IP address of the client that accessed the short
IPAddress string `bun:"," json:"ipAddress"`
// UserAgent is the User-Agent of the client that accessed the short
UserAgent string `bun:"," json:"userAgent"`
// Referer is the referer of the client that accessed the short
Referer string `bun:"," json:"referer"`
// CreatedAt is when the short was accessed
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// ShortID is the ID of the short that was accessed
ShortID string `bun:",notnull" json:"-"`
// Short is the short that was accessed
Short *ShortModel `bun:"rel:belongs-to,join:short_id=id" json:"short,omitempty"`
}
func (sl *ShortLogModel) toShortLog() *models.ShortLog {
return &models.ShortLog{
ID: sl.ID,
IPAddress: sl.IPAddress,
UserAgent: sl.UserAgent,
Referer: sl.Referer,
CreatedAt: sl.CreatedAt,
ShortID: sl.ShortID,
}
}

View File

@ -3,12 +3,10 @@ package bunstorage
import (
"context"
"database/sql"
"fmt"
"time"
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models"
"github.com/uptrace/bun"
@ -16,11 +14,14 @@ import (
)
type BunStorage struct {
storage.Storage
db *bun.DB
// StartHooks are ran after the database is connected and the tables are created.
StartHooks []func(ctx context.Context, db *bun.DB) error
// StopHooks are ran before the database is closed.
StopHook []func(ctx context.Context, db *bun.DB) error
}
func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage {
func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
if cfg.Debug {
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
}
@ -30,8 +31,18 @@ func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage {
}
}
// RegisterStartHook registers a hook to be ran after the database is connected and the tables are created.
func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) {
s.StartHooks = append(s.StartHooks, hook)
}
// RegisterStopHook registers a hook to be ran before the database is closed.
func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) {
s.StopHook = append(s.StopHook, hook)
}
func (s *BunStorage) Start(ctx context.Context) error {
return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
_, err := tx.NewCreateTable().
IfNotExists().
@ -39,7 +50,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
WithForeignKeys().
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create users table: %w", err)
return errs.Errorf("failed to create users table", err)
}
_, err = tx.NewCreateTable().
@ -48,7 +59,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create shorts table: %w", err)
return errs.Errorf("failed to create shorts table", err)
}
_, err = tx.NewCreateTable().
@ -56,9 +67,17 @@ func (s *BunStorage) Start(ctx context.Context) error {
Model((*TokenModel)(nil)).
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create tokens table: %w", err)
return errs.Errorf("failed to create tokens table", err)
}
_, err = tx.NewCreateTable().
IfNotExists().
Model((*ShortLogModel)(nil)).
ForeignKey(`("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`).
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short logs table", err)
}
// shorts user_id index
@ -69,7 +88,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
Column("user_id").
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create shorts user_id index: %w", err)
return errs.Errorf("failed to create shorts user_id index", err)
}
// tokens user_id index
@ -80,14 +99,47 @@ func (s *BunStorage) Start(ctx context.Context) error {
Column("user_id").
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to create tokens user_id index: %w", err)
return errs.Errorf("failed to create tokens user_id index", err)
}
// short_logs short_id index
_, err = tx.NewCreateIndex().
IfNotExists().
Model((*ShortLogModel)(nil)).
Index("idx_short_logs_short_id").
Column("short_id").
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short_logs short_id index", err)
}
return nil
})
if err != nil {
return errs.Errorf("failed to start storage", err)
}
// Run start hooks
for _, hook := range s.StartHooks {
err := hook(ctx, s.db)
if err != nil {
return errs.Errorf("failed to start storage", err)
}
}
return nil
}
func (s *BunStorage) Stop(ctx context.Context) error {
// Run stop hooks
for _, hook := range s.StopHook {
err := hook(ctx, s.db)
if err != nil {
return errs.Errorf("failed to stop storage", err)
}
}
return s.db.Close()
}
@ -97,71 +149,95 @@ func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short,
err := s.db.NewSelect().
Model(shortModel).
Where("name = ? and deleted = false", name).
Relation("User").
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, errs.ErrShortDoesNotExist
err = errs.ErrShortDoesNotExist
}
return nil, fmt.Errorf("failed to find short: %w", err)
return nil, errs.Errorf("failed to find short", err)
}
return shortModel.toShort(), nil
}
func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
shortModel := new(ShortModel)
err := s.db.NewSelect().
Model(shortModel).
Where("id = ? and deleted = false", id).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
err = errs.ErrShortDoesNotExist
}
return nil, errs.Errorf("failed to find short", err)
}
return shortModel.toShort(), nil
}
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
// Get user from username
user, err := findUser(ctx, s.db, short.User.Username)
// Generate an ID that does not exist
tableName := s.db.NewSelect().
Model((*ShortModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return nil, err
return nil, errs.Errorf("failed to create short", err)
}
shortModel := &ShortModel{
ID: newID,
Name: short.Name,
URL: short.URL,
UserID: &user.ID,
User: user,
UserID: short.UserID,
}
_, err = s.db.NewInsert().
Model(shortModel).
Exec(ctx)
if err != nil {
return nil, errs.ErrShortExists
return nil, errs.Errorf("failed to create short", err)
}
return shortModel.toShort(), nil
}
func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error {
_, err := withShortDeleteUpdates(
s.db.NewUpdate().
Model((*ShortModel)(nil)).
Where("name = ?", short.Name),
).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete short: %w", err)
}
return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete short logs
_, err := tx.NewDelete().
Model((*ShortLogModel)(nil)).
Where("short_id = ?", short.ID).
Exec(ctx)
if err != nil {
return errs.Errorf("failed to delete short", err)
}
return nil
// Delete short
_, err = withShortDeleteUpdates(
tx.NewUpdate().
Model((*ShortModel)(nil)).
Where("id = ?", short.ID),
).Exec(ctx)
if err != nil {
return errs.Errorf("failed to delete short", err)
}
return nil
})
}
func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
// Get user ID from username
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
if err != nil {
return nil, err
}
shortModels := []*ShortModel{}
err = s.db.NewSelect().
err := s.db.NewSelect().
Model(&shortModels).
Where("user_id = ? and deleted = false", userID).
Relation("User").
Where("user_id = ? and deleted = false", user.ID).
Scan(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list shorts: %w", err)
return nil, errs.Errorf("failed to list shorts", err)
}
shorts := []*models.Short{}
@ -171,6 +247,51 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode
return shorts, nil
}
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
// Create new ID
tableName := s.db.NewSelect().
Model((*ShortLogModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return errs.Errorf("failed to create short log", err)
}
shortLogModel := &ShortLogModel{
ID: newID,
ShortID: shortLog.ShortID,
IPAddress: shortLog.IPAddress,
UserAgent: shortLog.UserAgent,
Referer: shortLog.Referer,
}
_, err = s.db.NewInsert().
Model(shortLogModel).
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short log", err)
}
return nil
}
func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
shortLogModels := []*ShortLogModel{}
err := s.db.NewSelect().
Model(&shortLogModels).
Where("short_id = ?", short.ID).
Scan(ctx)
if err != nil {
return nil, errs.Errorf("failed to list short logs", err)
}
shortLogs := []*models.ShortLog{}
for _, shortLogModel := range shortLogModels {
shortLogs = append(shortLogs, shortLogModel.toShortLog())
}
return shortLogs, nil
}
func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
user, err := findUser(ctx, s.db, username)
if err != nil {
@ -180,21 +301,39 @@ func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.Use
return user.toUser(), nil
}
func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
user, err := findUserByID(ctx, s.db, id)
if err != nil {
return nil, err
}
return user.toUser(), nil
}
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
// Create a new ID
tableName := s.db.NewSelect().
Model((*UserModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return nil, errs.Errorf("failed to create user", err)
}
userModel := &UserModel{
ID: newID,
Username: user.Username,
Password: user.GetPasswordHash(),
}
_, err := s.db.NewInsert().
_, err = s.db.NewInsert().
Model(userModel).
Exec(ctx)
if err != nil {
return nil, errs.ErrUserExists
return nil, errs.Errorf("failed to create user", err)
}
return userModel.toUser(), nil
}
func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
@ -203,15 +342,15 @@ func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
// Delete user shorts
err := deleteUserShorts(ctx, tx, user)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
return errs.Errorf("failed to delete user", err)
}
// Delete user
_, err = tx.NewDelete().
Model(user).
Where("username = ?", user.Username).
Where("id = ?", user.ID).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
return errs.Errorf("failed to delete user", err)
}
return nil
@ -224,13 +363,13 @@ func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token
err := s.db.NewSelect().
Model(tokenModel).
Where("value = ?", value).
Relation("User").
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, errs.ErrTokenDoesNotExist
err = errs.ErrTokenDoesNotExist
}
return nil, fmt.Errorf("failed to find token: %w", err)
return nil, errs.Errorf("failed to find token", err)
}
return tokenModel.toToken(), nil
@ -242,14 +381,14 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
err := s.db.NewSelect().
Model(tokenModel).
Where("t.id = ?", id).
Relation("User").
Where("id = ?", id).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, errs.ErrTokenDoesNotExist
err = errs.ErrTokenDoesNotExist
}
return nil, fmt.Errorf("failed to find token: %w", err)
return nil, errs.Errorf("failed to find token", err)
}
return tokenModel.toToken(), nil
@ -259,19 +398,12 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
tokenModels := []*TokenModel{}
// Get user ID from username
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
if err != nil {
return nil, err
}
err = s.db.NewSelect().
err := s.db.NewSelect().
Model(&tokenModels).
Where("user_id = ?", userID).
Relation("User").
Where("user_id = ?", user.ID).
Scan(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list tokens: %w", err)
return nil, errs.Errorf("failed to list tokens", err)
}
tokens := []*models.Token{}
@ -283,26 +415,27 @@ func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*mode
}
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
// Get user ID from username
user, err := findUser(ctx, s.db, token.User.Username)
// Create a new ID
tableName := s.db.NewSelect().
Model((*TokenModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return nil, err
return nil, errs.Errorf("failed to create token", err)
}
tokenModel := &TokenModel{
ID: token.ID,
ID: newID,
Name: token.Name,
Value: token.Value,
UserID: &user.ID,
User: user,
UserID: token.UserID,
}
_, err = s.db.NewInsert().
Model(tokenModel).
Exec(ctx)
if err != nil {
return nil, errs.ErrTokenExists
return nil, errs.Errorf("failed to create token", err)
}
return tokenModel.toToken(), nil
@ -314,12 +447,27 @@ func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error
Where("id = ?", token.ID).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete token: %w", err)
return errs.Errorf("failed to delete token", err)
}
return nil
}
func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) {
newToken := new(TokenModel)
_, err := s.db.NewUpdate().
Model((*TokenModel)(nil)).
Set("name = ?", name).
Where("id = ?", token.ID).
Exec(ctx, newToken)
if err != nil {
return nil, errs.Errorf("failed to change token name", err)
}
return newToken.toToken(), nil
}
func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) {
userModel := new(UserModel)
@ -329,31 +477,31 @@ func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, err
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, errs.ErrUserDoesNotExist
err = errs.ErrUserDoesNotExist
}
return nil, fmt.Errorf("failed to find user: %w", err)
return nil, errs.Errorf("failed to find user", err)
}
return userModel, nil
}
func findUserIDFromUsername(ctx context.Context, db bun.IDB, username string) (*int64, error) {
var userID int64
func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) {
userModel := new(UserModel)
err := db.NewSelect().
Table("users").
Column("id").
Where("username = ?", username).
Scan(ctx, &userID)
Model(userModel).
Where("id = ?", id).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, errs.ErrUserDoesNotExist
err = errs.ErrUserDoesNotExist
}
return nil, fmt.Errorf("failed to get user ID: %w", err)
return nil, errs.Errorf("failed to find user", err)
}
return &userID, nil
return userModel, nil
}
func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
@ -363,19 +511,46 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
}
func deleteUserShorts(ctx context.Context, db bun.IDB, user *models.User) error {
userID, err := findUserIDFromUsername(ctx, db, user.Username)
if err != nil {
return fmt.Errorf("failed to delete user shorts: %w", err)
}
_, err = withShortDeleteUpdates(
_, err := withShortDeleteUpdates(
db.NewUpdate().
Model((*ShortModel)(nil)).
Where("user_id = ?", userID),
Where("user_id = ?", user.ID),
).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user shorts: %w", err)
return errs.Errorf("failed to delete user shorts", err)
}
return nil
}
func createNewID(ctx context.Context, db bun.IDB, table, col string) (string, error) {
var newID string
// Generate an ID that does not exist
maxIters := 10
for {
// Make sure we don't get stuck in an infinite loop
maxIters--
if maxIters <= 0 {
return "", errs.Errorf("failed to create unique ID", ErrDatabaseError)
}
// Create a new ID and check if it exists
newID = models.NewID()
count, err := db.NewSelect().
Table(table).
Column(col).
Where(col+" = ?", newID).
Count(ctx)
if err != nil {
return "", errs.Errorf("failed to create unique ID", err)
}
// If the ID does not exist, break the loop
if count == 0 {
break
}
}
return newID, nil
}

View File

@ -3,9 +3,11 @@ package memorystorage
import (
"context"
"fmt"
"runtime"
"sync"
"time"
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models"
@ -13,87 +15,170 @@ import (
// MemoryStorage is a storage that stores everything in memory.
type MemoryStorage struct {
debug bool
storage.Storage
shortMu sync.RWMutex
userMu sync.RWMutex
tokenMu sync.RWMutex
shortMap map[string]*models.Short
userMap map[string]*models.User
tokenMap map[string]*models.Token
tokenIDMap map[string]*models.Token
shortMu sync.RWMutex
shortIDMu sync.RWMutex
shortLogMu sync.RWMutex
userMu sync.RWMutex
userIDMu sync.RWMutex
tokenMu sync.RWMutex
tokenIDMu sync.RWMutex
shortMap map[string]*models.Short
shortIDMap map[string]*models.Short
shortLogMap map[string][]*models.ShortLog
userMap map[string]*models.User
userIDMap map[string]*models.User
tokenMap map[string]*models.Token
tokenIDMap map[string]*models.Token
}
// NewMemoryStorage creates a new MemoryStorage.
func NewMemoryStorage() storage.Storage {
func NewMemoryStorage(cfg *config.Config) *MemoryStorage {
return &MemoryStorage{
shortMap: make(map[string]*models.Short),
userMap: make(map[string]*models.User),
tokenMap: make(map[string]*models.Token),
tokenIDMap: make(map[string]*models.Token),
debug: cfg.Debug,
shortMap: make(map[string]*models.Short),
shortIDMap: make(map[string]*models.Short),
userMap: make(map[string]*models.User),
userIDMap: make(map[string]*models.User),
tokenMap: make(map[string]*models.Token),
tokenIDMap: make(map[string]*models.Token),
shortLogMap: make(map[string][]*models.ShortLog),
}
}
// logPerformance is a helper function to log the performance of a function.
func (s *MemoryStorage) logPerformance() func() {
start := time.Now()
return func() {
elapsed := time.Since(start)
if s.debug {
pc, _, _, ok := runtime.Caller(1)
if !ok {
return
}
method := runtime.FuncForPC(pc).Name()
fmt.Printf("%s took %s\n", method, elapsed)
}
}
}
// Start starts the storage.
func (s *MemoryStorage) Start(ctx context.Context) error {
logPerf := s.logPerformance()
defer logPerf()
return nil
}
// Stop stops the storage.
func (s *MemoryStorage) Stop(ctx context.Context) error {
logPerf := s.logPerformance()
defer logPerf()
return nil
}
// FindShort finds a short in the storage.
func (s *MemoryStorage) FindShort(ctx context.Context, name string) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.RLock()
defer s.shortMu.RUnlock()
short, ok := s.shortMap[name]
if !ok {
return short, errs.ErrShortDoesNotExist
return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
}
return short, nil
}
// FindShortByID finds a short in the storage.
func (s *MemoryStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortIDMu.RLock()
defer s.shortIDMu.RUnlock()
short, ok := s.shortIDMap[id]
if !ok {
return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
}
return short, nil
}
func (s *MemoryStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.Lock()
defer s.shortMu.Unlock()
s.shortIDMu.Lock()
defer s.shortIDMu.Unlock()
_, ok := s.shortMap[short.Name]
if ok {
return &models.Short{}, errs.ErrShortExists
return nil, errs.Errorf("failed to create short", errs.ErrShortExists)
}
short.CreatedAt = time.Now()
// Create a new ID
for {
short.ID = models.NewID()
_, ok := s.shortIDMap[short.ID]
if !ok {
break
}
}
s.shortMap[short.Name] = short
s.shortIDMap[short.ID] = short
return short, nil
}
func (s *MemoryStorage) DeleteShort(ctx context.Context, short *models.Short) error {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.Lock()
defer s.shortMu.Unlock()
s.shortIDMu.Lock()
defer s.shortIDMu.Unlock()
s.shortLogMu.Lock()
defer s.shortLogMu.Unlock()
_, ok := s.shortMap[short.Name]
_, ok := s.shortIDMap[short.ID]
if !ok {
return errs.ErrShortDoesNotExist
return errs.Errorf("failed to delete short", errs.ErrShortDoesNotExist)
}
delete(s.shortMap, short.Name)
delete(s.shortIDMap, short.ID)
delete(s.shortLogMap, short.ID)
return nil
}
func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
s.shortMu.RLock()
defer s.shortMu.RUnlock()
logPerf := s.logPerformance()
defer logPerf()
s.shortIDMu.RLock()
defer s.shortIDMu.RUnlock()
shorts := []*models.Short{}
for _, short := range s.shortMap {
if short.User != nil && short.User.Username == user.Username {
for _, short := range s.shortIDMap {
if short.UserID != nil && *short.UserID == user.ID {
shorts = append(shorts, short)
}
}
@ -101,108 +186,203 @@ func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*m
return shorts, nil
}
func (s *MemoryStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
logPerf := s.logPerformance()
defer logPerf()
s.shortLogMu.Lock()
defer s.shortLogMu.Unlock()
// Check if short exists
short, err := s.FindShortByID(ctx, shortLog.ShortID)
if err != nil {
return fmt.Errorf("could not create short log: %w", err)
}
// Create
_, ok := s.shortLogMap[short.ID]
if !ok {
s.shortLogMap[shortLog.ShortID] = []*models.ShortLog{}
}
shortLog.CreatedAt = time.Now()
// Create a new ID (we can't really know if it's unique, but it's unlikely to happen)
shortLog.ID = models.NewID()
s.shortLogMap[shortLog.ShortID] = append(s.shortLogMap[shortLog.ShortID], shortLog)
return nil
}
func (s *MemoryStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortLogMu.RLock()
defer s.shortLogMu.RUnlock()
shortLogs, ok := s.shortLogMap[short.ID]
if !ok {
return []*models.ShortLog{}, nil
}
return shortLogs, nil
}
func (s *MemoryStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.RLock()
defer s.userMu.RUnlock()
user, ok := s.userMap[username]
if !ok {
return user, errs.ErrUserDoesNotExist
return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
}
return user, nil
}
func (s *MemoryStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userIDMu.RLock()
defer s.userIDMu.RUnlock()
user, ok := s.userIDMap[id]
if !ok {
return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
}
return user, nil
}
func (s *MemoryStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.Lock()
defer s.userMu.Unlock()
s.userIDMu.Lock()
defer s.userIDMu.Unlock()
_, ok := s.userMap[user.Username]
if ok {
return &models.User{}, errs.ErrUserExists
return nil, errs.Errorf("failed to create user", errs.ErrUserExists)
}
user.CreatedAt = time.Now()
// Create a new ID
for {
user.ID = models.NewID()
_, ok := s.userIDMap[user.ID]
if !ok {
break
}
}
s.userMap[user.Username] = user
s.userIDMap[user.ID] = user
return user, nil
}
func (s *MemoryStorage) DeleteUser(ctx context.Context, user *models.User) error {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.Lock()
defer s.userMu.Unlock()
s.userIDMu.Lock()
defer s.userIDMu.Unlock()
errormsg := "failed to delete user"
_, ok := s.userMap[user.Username]
if !ok {
return errs.ErrUserDoesNotExist
return errs.Errorf(errormsg, errs.ErrUserDoesNotExist)
}
// Find all user shorts
shorts, err := s.ListShorts(ctx, user)
if err != nil {
return fmt.Errorf("could not list user shorts: %w", err)
return errs.Errorf(errormsg, errs.Errorf("could not list user shorts", err))
}
// Delete all user shorts
for _, short := range shorts {
err = s.DeleteShort(ctx, short)
if err != nil {
return fmt.Errorf("could not delete user short: %w", err)
return errs.Errorf(errormsg, errs.Errorf("could not delete user short", err))
}
}
// Find all user tokens
tokens, err := s.ListTokens(ctx, user)
if err != nil {
return fmt.Errorf("could not list user tokens: %w", err)
return errs.Errorf(errormsg, errs.Errorf("could not list user tokens", err))
}
// Delete all user tokens
for _, token := range tokens {
err = s.DeleteToken(ctx, token)
if err != nil {
return fmt.Errorf("could not delete user token: %w", err)
return errs.Errorf(errormsg, errs.Errorf("could not delete user token", err))
}
}
delete(s.userMap, user.Username)
delete(s.userIDMap, user.ID)
return nil
}
func (s *MemoryStorage) FindToken(ctx context.Context, value string) (*models.Token, error) {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.RLock()
defer s.tokenMu.RUnlock()
token, ok := s.tokenMap[value]
if !ok {
return token, errs.ErrTokenDoesNotExist
return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
}
return token, nil
}
func (s *MemoryStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
s.tokenMu.RLock()
defer s.tokenMu.RUnlock()
logPerf := s.logPerformance()
defer logPerf()
s.tokenIDMu.RLock()
defer s.tokenIDMu.RUnlock()
token, ok := s.tokenIDMap[id]
if !ok {
return token, errs.ErrTokenDoesNotExist
return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
}
return token, nil
}
func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
s.tokenMu.RLock()
defer s.tokenMu.RUnlock()
logPerf := s.logPerformance()
defer logPerf()
s.tokenIDMu.RLock()
defer s.tokenIDMu.RUnlock()
tokens := []*models.Token{}
for _, token := range s.tokenMap {
if token.User != nil && token.User.Username == user.Username {
for _, token := range s.tokenIDMap {
if token.UserID != nil && *token.UserID == user.ID {
tokens = append(tokens, token)
}
}
@ -211,20 +391,35 @@ func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*m
}
func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.Lock()
defer s.tokenMu.Unlock()
s.tokenIDMu.Lock()
defer s.tokenIDMu.Unlock()
_, ok := s.tokenMap[token.Value]
if ok {
return &models.Token{}, errs.ErrTokenExists
return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
}
_, ok = s.tokenIDMap[token.ID]
if ok {
return &models.Token{}, errs.ErrTokenExists
return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
}
token.CreatedAt = time.Now()
// Create a new ID
for {
token.ID = models.NewID()
_, ok := s.tokenIDMap[token.ID]
if !ok {
break
}
}
s.tokenMap[token.Value] = token
s.tokenIDMap[token.ID] = token
@ -232,16 +427,21 @@ func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*
}
func (s *MemoryStorage) DeleteToken(ctx context.Context, token *models.Token) error {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.Lock()
defer s.tokenMu.Unlock()
s.tokenIDMu.Lock()
defer s.tokenIDMu.Unlock()
_, ok := s.tokenMap[token.Value]
if !ok {
return errs.ErrTokenDoesNotExist
return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
}
_, ok = s.tokenIDMap[token.ID]
if !ok {
return errs.ErrTokenDoesNotExist
return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
}
delete(s.tokenMap, token.Value)

View File

@ -0,0 +1,12 @@
package models
import tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
const (
// IDLength is the length of IDs.
IDLength = 16
)
func NewID() string {
return tokenutil.GenerateSecureToken(IDLength / 2)
}

View File

@ -2,14 +2,20 @@ package models
import "time"
const (
ShortIDLength = 16
)
type Short struct {
// ID is the unique identifier of the short.
ID string `json:"id"`
// Name is the shortened name of the URL.
Name string `json:"name,omitempty"`
// URL is the URL that is shortened.
URL string `json:"url"`
// CreatedAt is the time the short was created.
CreatedAt time.Time `json:"created_at,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"`
// User is the user that created the short.
User *User `json:"-"`
// UserID of the user that created the short.
UserID *string `json:"-"`
}

View File

@ -0,0 +1,18 @@
package models
import "time"
type ShortLog struct {
// ID is the unique identifier of the short log.
ID string `json:"id"`
// ShortID is the ID of the short that was accessed.
ShortID string `json:"shortID"`
// IPAddress is the IP address of the client that accessed the short.
IPAddress string `json:"ipAddress"`
// UserAgent is the User-Agent of the client that accessed the short.
UserAgent string `json:"userAgent"`
// Referer is the referer of the client that accessed the short.
Referer string `json:"referer"`
// CreatedAt is the time the short was accessed.
CreatedAt time.Time `json:"createdAt,omitempty"`
}

View File

@ -12,6 +12,6 @@ type Token struct {
// CreatedAt is when the token was created (initialized by the storage)
CreatedAt time.Time `json:"createdAt"`
// User is the user that created the token.
User *User `json:"-"`
// UserID of the user that created the token.
UserID *string `json:"-"`
}

View File

@ -8,11 +8,15 @@ import (
)
type User struct {
// ID is the user's ID.
ID string `json:"id"`
// Username is the user's username.
Username string `json:"username"`
// password is the user's password.
password string `json:"-"`
// CreatedAt is the time the user was created.
CreatedAt time.Time `json:"created_at,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"`
}
// NewAuthenticatableUser is a elper function for storages that takes
@ -20,6 +24,7 @@ type User struct {
func NewAuthenticatableUser(user *User, hashedPass string) *User {
return &User{
ID: user.ID,
Username: user.Username,
CreatedAt: user.CreatedAt,
password: hashedPass,

View File

@ -1,9 +1,12 @@
package sqlitestorage
import (
"context"
"database/sql"
"time"
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage"
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
"github.com/uptrace/bun"
@ -13,13 +16,74 @@ import (
// NewSQLiteStorage creates a new SQLite storage.
func NewSQLiteStorage(cfg *config.Config) storage.Storage {
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL)
// Create a new SQLite database with the following pragmas enabled:
// - journal_mode=WAL: Enables Write-Ahead Logging, which allows for concurrent reads and writes.
// - foreign_keys=ON: Enables foreign key constraints.
// - synchronous=NORMAL: Enables synchronous mode NORMAL
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL+"?_pragma=foreign_keys(ON)&_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)")
if err != nil {
panic(err)
}
// If running the DB in memory, make sure
// database/sql does not close idle connections.
// Otherwise, the database will be lost.
if cfg.DBType == config.DBTypeMemory {
sqldb.SetMaxIdleConns(1000)
sqldb.SetConnMaxLifetime(0)
}
db := bun.NewDB(sqldb, sqlitedialect.New())
// Use Bun as the storage implementation.
return bunstorage.NewBunStorage(cfg, db)
storage := bunstorage.NewBunStorage(cfg, db)
// Config vacuuming of the database every hour.
ticker := time.NewTicker(time.Hour)
tickerCtx, tickerCancel := context.WithCancel(context.Background())
// Start vacuuming the database when the storage is started.
storage.RegisterStartHook(func(ctx context.Context, activeDB *bun.DB) error {
// Vacuum now
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
if err != nil {
return errs.Errorf("failed to vacuum database: %w", err)
}
// Vacuum every hour, until the context is canceled.
go func() {
for {
select {
case <-ticker.C:
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
if err != nil {
panic(errs.Errorf("failed to vacuum database: %w", err))
}
case <-ctx.Done():
return
case <-tickerCtx.Done():
return
}
}
}()
return nil
})
// So some cleaning up and optimization when the storage is stopped.
storage.RegisterStopHook(func(ctx context.Context, activeDB *bun.DB) error {
// Stop the ticker and cancel the ticker context.
ticker.Stop()
tickerCancel()
// Run PRAGMA optimize to optimize the database.
// By this point, the passed context is likely cancelled, so we use the
// background context since otherwise the query will fail.
// This will delay the shutdown of the storage, but it'll be for a good cause.
// The user can force the shutdown by using SIGINT or SIGTERM.
_, err := activeDB.NewRaw("PRAGMA optimize").Exec(context.Background())
if err != nil {
return errs.Errorf("failed to optimize database: %w", err)
}
return nil
})
return storage
}

View File

@ -14,6 +14,8 @@ type Storage interface {
// FindShort finds a short in the storage using its name.
FindShort(ctx context.Context, name string) (*models.Short, error)
// FindShortByID finds a short in the storage using its ID.
FindShortByID(ctx context.Context, id string) (*models.Short, error)
// FindShorts finds all shorts in the storage that belong to a user.
ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error)
// CreateShort creates a short in the storage.
@ -21,10 +23,19 @@ type Storage interface {
// DeleteShort deletes a short from the storage.
DeleteShort(ctx context.Context, short *models.Short) error
// ShortLog Storage
// CreateShortLog creates a short log in the storage.
CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error
// ListShortLogs finds all short logs in the storage that belong to a short.
ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error)
// User Storage
// FindUser finds a user in the storage using its username.
FindUser(ctx context.Context, username string) (*models.User, error)
// FindUserByID finds a user in the storage using its ID.
FindUserByID(ctx context.Context, id string) (*models.User, error)
// CreateUser creates a user in the storage.
CreateUser(ctx context.Context, user *models.User) (*models.User, error)
// DeleteUser deletes a user and all their shorts from the storage.