short logs and better database access

This commit is contained in:
Gustavo Maronato 2023-08-23 19:30:12 -03:00
parent b143221181
commit 16bb373c60
Signed by: maronato
SSH Key Fingerprint: SHA256:2Gw7kwMz/As+2UkR1qQ/qYYhn+WNh3FGv6ozhoRrLcs
33 changed files with 1167 additions and 269 deletions

View File

@ -1,5 +1,7 @@
# ---> DB # ---> DB
*.db *.db
*.db-shm
*.db-wal
# ---> Go # ---> Go
# If you prefer the allow list template instead of the deny list, see community template: # If you prefer the allow list template instead of the deny list, see community template:

4
.gitignore vendored
View File

@ -1,5 +1,7 @@
# ---> DB # ---> DB
*.db *.db
*.db-shm
*.db-wal
# ---> Go # ---> Go
# If you prefer the allow list template instead of the deny list, see community template: # If you prefer the allow list template instead of the deny list, see community template:
@ -157,4 +159,4 @@ dist
.yarn/unplugged .yarn/unplugged
.yarn/build-state.yml .yarn/build-state.yml
.yarn/install-state.gz .yarn/install-state.gz
.pnp.* .pnp.*

View File

@ -9,7 +9,6 @@ import (
"git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage"
memorystorage "git.maronato.dev/maronato/goshort/internal/storage/memory"
sqlitestorage "git.maronato.dev/maronato/goshort/internal/storage/sqlite" sqlitestorage "git.maronato.dev/maronato/goshort/internal/storage/sqlite"
"github.com/peterbourgon/ff/v3" "github.com/peterbourgon/ff/v3"
) )
@ -60,7 +59,8 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) {
func InitStorage(cfg *config.Config) storage.Storage { func InitStorage(cfg *config.Config) storage.Storage {
switch cfg.DBType { switch cfg.DBType {
case config.DBTypeMemory: case config.DBTypeMemory:
return memorystorage.NewMemoryStorage() cfg.DBURL = ":memory:"
return sqlitestorage.NewSQLiteStorage(cfg)
case config.DBTypeSQLite: case config.DBTypeSQLite:
return sqlitestorage.NewSQLiteStorage(cfg) return sqlitestorage.NewSQLiteStorage(cfg)
default: default:

View File

@ -10,14 +10,22 @@ const ItemList = <T extends Record<string, unknown>, K extends keyof T>({
idKey: K idKey: K
Item: FunctionComponent<T> Item: FunctionComponent<T>
}) => { }) => {
console.log("list")
return ( return (
<ul <>
role="list" <ul
className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden"> role="list"
{items.map((item) => ( className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden">
<Item {...item} key={item[idKey] as string} /> {items.map((item) => (
))} <Item {...item} key={item[idKey] as string} />
</ul> ))}
</ul>
{items.length === 0 && (
<div className="text-center pt-5 text-xl font-light">
Nothing here yet
</div>
)}
</>
) )
} }

View File

@ -1,4 +1,4 @@
import { FunctionComponent, useCallback } from "react" import { FunctionComponent, memo, useCallback } from "react"
import { import {
ArrowRightIcon, ArrowRightIcon,
@ -6,6 +6,7 @@ import {
} from "@heroicons/react/24/outline" } from "@heroicons/react/24/outline"
import { useDelete } from "../hooks/useCRUD" import { useDelete } from "../hooks/useCRUD"
import { useShortLogs, useVisitMetrics } from "../hooks/useStats"
import { Short } from "../types" import { Short } from "../types"
import ItemBase from "./ItemBase" import ItemBase from "./ItemBase"
@ -24,17 +25,14 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
// Handle deletion // Handle deletion
const [deleting, del] = useDelete() const [deleting, del] = useDelete()
const doDelete = useCallback( const doDelete = useCallback(() => del({ id: short.id }), [del, short.id])
() => del({ name: short.name }),
[del, short.name]
)
return ( return (
<ItemBase <ItemBase
copyString={shortNameURL} copyString={shortNameURL}
doDelete={doDelete} doDelete={doDelete}
deleting={deleting} deleting={deleting}
detailsPage={`/sht/${short.name}`}> detailsPage={`/sht/${short.id}`}>
<div className="grid md:grid-cols-12 md:grid-rows-1 w-full grid-flow-row md:grid-flow-col"> <div className="grid md:grid-cols-12 md:grid-rows-1 w-full grid-flow-row md:grid-flow-col">
<div className="col-span-5 md:col-span-3 my-auto flex flex-col order-1"> <div className="col-span-5 md:col-span-3 my-auto flex flex-col order-1">
<a <a
@ -69,16 +67,111 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
</a> </a>
</div> </div>
<div className="col-span-5 md:col-span-2 shrink-0 flex flex-col items-end my-auto order-2 md:order-4"> <div className="col-span-5 md:col-span-2 shrink-0 flex flex-col items-end my-auto order-2 md:order-4">
<p className="text-sm font-bold leading-6 text-gray-900"> <ShortItemStatsMemo key={short.id} id={short.id} />
<span className="text-green-500">1000</span> views
</p>
<p className="mt-1 text-xs leading-5 text-slate-400">
Last viewed <time dateTime="2023-01-23T13:23Z">3h ago</time>
</p>
</div> </div>
</div> </div>
</ItemBase> </ItemBase>
) )
} }
export default ShortItem export default memo(ShortItem, (prev, next) => prev.id === next.id)
const ShortItemStats: FunctionComponent<Pick<Short, "id">> = ({ id }) => {
const { loading, error, logs } = useShortLogs(id)
const { lastVisit, visits } = useVisitMetrics(logs)
if (error) {
return (
<p className="text-red-500 text-center font-medium">
Failed to load stats
</p>
)
}
const lastViewedISO = lastVisit
? new Date(lastVisit).toISOString()
: undefined
const lastViewedRelative = getRelativeTimeString(lastVisit)
return (
<>
<p className="text-sm font-bold leading-6 text-gray-900">
{loading && "Loading..."}
{loading || (
<>
<span className="text-green-500">{visits}</span> views
</>
)}
</p>
<p className="mt-1 text-xs leading-5 text-slate-400 flex flex-col text-right">
{loading && "Loading..."}
{loading || (
<>
<span className="-mb-1">Last viewed:</span>
<span className="font-medium">
<time dateTime={lastViewedISO}>{lastViewedRelative}</time>
</span>
</>
)}
</p>
</>
)
}
const ShortItemStatsMemo = memo(
ShortItemStats,
(prev, next) => prev.id === next.id
)
/**
* Convert a date to a relative time string, such as
* "a minute ago", "in 2 hours", "yesterday", "3 months ago", etc.
* using Intl.RelativeTimeFormat
* https://gist.github.com/LewisJEllis
*/
function getRelativeTimeString(date: Date | number | null): string {
if (!date) return "never"
// Allow dates or times to be passed
const timeMs = typeof date === "number" ? date : date.getTime()
// Get the amount of seconds between the given date and now
const deltaSeconds = Math.round((timeMs - Date.now()) / 1000)
// Array reprsenting one minute, hour, day, week, month, etc in seconds
const cutoffs = [
60,
3600,
86400,
86400 * 7,
86400 * 30,
86400 * 365,
Infinity,
]
// Array equivalent to the above but in the string representation of the units
const units: Intl.RelativeTimeFormatUnit[] = [
"second",
"minute",
"hour",
"day",
"week",
"month",
"year",
]
// Grab the ideal cutoff unit
const unitIndex = cutoffs.findIndex(
(cutoff) => cutoff > Math.abs(deltaSeconds)
)
// Get the divisor to divide from the seconds. E.g. if our unit is "day" our divisor
// is one day in seconds, so we can divide our seconds by this to get the # of days
const divisor = unitIndex ? cutoffs[unitIndex - 1] : 1
// Intl.RelativeTimeFormat do its magic
const rtf = new Intl.RelativeTimeFormat("en", {
numeric: "auto",
style: "narrow",
})
return rtf.format(Math.floor(deltaSeconds / divisor), units[unitIndex])
}

View File

@ -6,7 +6,8 @@ type Item = Record<string, unknown>
export const useLoadedItems = <T extends Item>() => { export const useLoadedItems = <T extends Item>() => {
const sourceDefault = useMemo(() => [], []) const sourceDefault = useMemo(() => [], [])
const items = (useLoaderData() ?? sourceDefault) as T[] const data = useLoaderData() ?? sourceDefault
const items = useMemo<T[]>(() => (Array.isArray(data) ? data : []), [data])
return items return items
} }

View File

@ -0,0 +1,168 @@
import { useEffect, useMemo, useState } from "react"
import { ShortLog } from "../types"
import fetchAPI from "../util/fetchAPI"
export const useShortLogs = (shortID: string) => {
const [loading, setLoading] = useState(false)
const [error, setError] = useState("")
const [logs, setLogs] = useState<ShortLog[]>(useMemo(() => [], []))
// Fetch logs on mount
useEffect(() => {
const fetchLogs = async () => {
// Reset logs and error
setLogs([])
setError("")
// Fetch logs
const resp = await fetchAPI<ShortLog[]>(`/shorts/${shortID}/logs`)
if (resp.ok) {
// If response is ok, set logs
if (resp.data.length > 0) {
// Sort logs by date, descending
resp.data.sort((a, b) => {
return b.createdAt.localeCompare(a.createdAt)
})
// Set sorted logs
setLogs(resp.data)
}
} else {
// If response is not ok, set error
setError(resp.error)
}
// Now we're done loading
setLoading(false)
}
if (shortID) {
// If we have an ID, fetch logs
setLoading(true)
fetchLogs()
}
}, [shortID])
return { loading, error, logs }
}
const singleFilterOptions = ["equal", "greater", "lower", "notEqual"] as const
export type ShortLogSingleFilterOptions = (typeof singleFilterOptions)[number]
const listFilterOptions = ["in", "notIn"] as const
export type ShortLogListFilterOptions = (typeof listFilterOptions)[number]
type ShortLogKeys = keyof ShortLog
type SingleFilter<T extends keyof ShortLog> = {
[key in ShortLogSingleFilterOptions]?: ShortLog[T] | null
}
type ListFilter<T extends keyof ShortLog> = {
[key in ShortLogListFilterOptions]?: (ShortLog[T] | null)[]
}
export type ShortLogFilter = {
[key in ShortLogKeys]?: SingleFilter<key> | ListFilter<key>
}
export const useFilteredShortLogs = (
logs: ShortLog[],
filterList: ShortLogFilter[]
) => {
return useMemo(() => {
const start = performance.now()
// If there are no filters, return all logs
if (filterList.length === 0) return logs
// For each log
const result = logs.filter((log) => {
// If it matches any of the filters, return true
return filterList.some((filterSet) => {
// For each key in the filterSet
return Object.entries(filterSet).every(([key, filterMap]) => {
// If the key is not in the log, return false
if (!(key in log)) return false
// For each filter in the filterMap
return Object.entries(filterMap).every(
([filterType, filterValue]) => {
const logValue = log[key as keyof ShortLog]
// If filterType is single
if (
singleFilterOptions.includes(
filterType as unknown as ShortLogSingleFilterOptions
)
) {
switch (filterType) {
case "equal":
return logValue === filterValue
case "greater":
return logValue > filterValue
case "lower":
return logValue < filterValue
case "notEqual":
return logValue !== filterValue
}
} else if (
listFilterOptions.includes(
filterType as unknown as ShortLogListFilterOptions
)
) {
// If filterType is list
switch (filterType) {
case "in":
return filterValue.includes(logValue)
case "notIn":
return !filterValue.includes(logValue)
}
}
// If we get here,
// the filterType is not valid
return false
}
)
})
})
})
console.debug(
`Filtered ${logs.length} logs in ${performance.now() - start}ms`
)
return result
}, [logs, filterList])
}
export const useVisitMetrics = (logs: ShortLog[]) => {
const [visits, setVisits] = useState(0)
const [uniqueVisits, setUniqueVisits] = useState(0)
const [lastVisit, setLastVisit] = useState<Date | null>(null)
useEffect(() => {
// Reset visits
setUniqueVisits(0)
setLastVisit(null)
// Set number of visits
setVisits(logs.length)
// Skip work if there are no logs
if (logs.length === 0) return
// Keep track of unique IPs
const uniqueIPs = new Set()
let last = logs[0].createdAt
logs.forEach((log) => {
// If the log is older, update
if (log.createdAt.localeCompare(last) < 0) {
last = log.createdAt
}
// Add IP to set
uniqueIPs.add(log.ipAddress)
})
// Set unique visits
setUniqueVisits(uniqueIPs.size)
// Set last visit
setLastVisit(new Date(last))
}, [logs])
return { visits, uniqueVisits, lastVisit }
}

View File

@ -54,10 +54,10 @@ const useSubmitActions = ({
useEffect(() => { useEffect(() => {
if (onDelete && formData && actionData && actionData.ok) { if (onDelete && formData && actionData && actionData.ok) {
const data = actionData.data const data = actionData.data
const name = formData.get("name") const id = formData.get("id")
if (typeof data === "string" && name && typeof name === "string") { if (typeof data === "string" && id && typeof id === "string") {
// Deleting // Deleting
onDelete(name) onDelete(id)
} }
} }
}, [formData, actionData, onDelete]) }, [formData, actionData, onDelete])
@ -136,8 +136,8 @@ const useRecentShorts = () => {
onCreate: useCallback((short: Short) => { onCreate: useCallback((short: Short) => {
setRecentShorts((prev) => [short, ...prev]) setRecentShorts((prev) => [short, ...prev])
}, []), }, []),
onDelete: useCallback((name: string) => { onDelete: useCallback((id: string) => {
setRecentShorts((prev) => prev.filter((short) => short.name !== name)) setRecentShorts((prev) => prev.filter((short) => short.id !== id))
}, []), }, []),
}) })
@ -213,7 +213,7 @@ export const Component: FunctionComponent = () => {
</Button> </Button>
</form> </form>
<div className="mt-10"> <div className="mt-10">
<ItemList items={recentShorts} Item={ShortItem} idKey="name" /> <ItemList items={recentShorts} Item={ShortItem} idKey="id" />
</div> </div>
</> </>
) )

View File

@ -22,7 +22,7 @@ export const loader: LoaderFunction = async (args) => {
const resp = await protectedLoader(args) const resp = await protectedLoader(args)
if (resp) return resp if (resp) return resp
const data = await fetchAPI<Short>(`/shorts/${args.params.name}`) const data = await fetchAPI<Short>(`/shorts/${args.params.id}`)
if (data.ok) { if (data.ok) {
return data.data return data.data
} }

View File

@ -1,4 +1,4 @@
import { FunctionComponent, useCallback } from "react" import { useCallback } from "react"
import { LoaderFunction, redirect } from "react-router-dom" import { LoaderFunction, redirect } from "react-router-dom"
@ -16,19 +16,10 @@ export function Component() {
useCallback((a, b) => a.name.localeCompare(b.name), []) useCallback((a, b) => a.name.localeCompare(b.name), [])
) )
const Shorts: FunctionComponent = () => {
return <ItemList items={items} Item={ShortItem} idKey="name" />
}
const NoShorts = () => {
return (
<div className="text-center pt-5 text-xl font-light">No shorts yet</div>
)
}
return ( return (
<> <>
<Header title="Shorts" /> <Header title="Shorts" />
{items.length > 0 ? <Shorts /> : <NoShorts />} <ItemList items={items} Item={ShortItem} idKey="id" key="list" />
</> </>
) )
} }
@ -59,13 +50,13 @@ export const action = crudAction({
}) })
}, },
DELETE: async (formData) => { DELETE: async (formData) => {
const name = formData.get("name") as string | null const id = formData.get("id") as string | null
if (!name) { if (!id) {
return { ok: false, error: "Invalid request" } return { ok: false, error: "Invalid request" }
} }
return fetchAPI<Short>(`/shorts/${name}`, { return fetchAPI<Short>(`/shorts/${id}`, {
method: "DELETE", method: "DELETE",
}) })
}, },

View File

@ -52,7 +52,7 @@ export default createBrowserRouter([
}, },
{ {
id: "shortDetails", id: "shortDetails",
path: "/sht/:name", path: "/sht/:id",
lazy: () => import("./pages/ShortDetails"), lazy: () => import("./pages/ShortDetails"),
}, },
{ {

View File

@ -2,16 +2,28 @@ export type GenericItem = Record<string, unknown>
export type User = { export type User = {
username: string username: string
createdAt: string
} }
export type Short = { export type Short = {
id: string
name: string name: string
url: string url: string
createdAt: string
}
export type ShortLog = {
id: string
shortID: string
ipAddress: string
userAgent: string
referer: string
createdAt: string
} }
export type Session = { export type Session = {
id: string id: string
ip: string ipAddress: string
userAgent: string userAgent: string
lastActivity: string lastActivity: string
createdAt: string createdAt: string

View File

@ -83,12 +83,12 @@ type Config struct {
func Validate(cfg *Config) error { func Validate(cfg *Config) error {
// Host and port have to be valid. // Host and port have to be valid.
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil { if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid host and/or port: %s", err)) return errs.Errorf(fmt.Sprintf("invalid host and/or port: %s", err), errs.ErrInvalidConfig)
} }
// UI port has to be valid. // UI port has to be valid.
if cfg.UIPort != "" { if cfg.UIPort != "" {
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.UIPort)); err != nil { if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.UIPort)); err != nil {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid UI port: %s", err)) return errs.Errorf(fmt.Sprintf("invalid UI port: %s", err), errs.ErrInvalidConfig)
} }
} }
if cfg.DBType != "" { if cfg.DBType != "" {
@ -101,7 +101,7 @@ func Validate(cfg *Config) error {
} }
} }
if !valid { if !valid {
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid database type: %s", cfg.DBType)) return errs.Errorf(fmt.Sprintf("invalid database type: %s", cfg.DBType), errs.ErrInvalidConfig)
} }
} }

View File

@ -40,8 +40,10 @@ var (
ErrInvalidTokenID = errors.New("invalid token ID") ErrInvalidTokenID = errors.New("invalid token ID")
// ErrInvalidTokenName // ErrInvalidTokenName
ErrInvalidTokenName = errors.New("invalid token name") ErrInvalidTokenName = errors.New("invalid token name")
// ErrDatabaseError
ErrDatabaseError = errors.New("database error")
) )
func Error(err error, msg string) error { func Errorf(msg string, err error) error {
return fmt.Errorf("%w: %s", err, msg) return fmt.Errorf("%s: %w", msg, err)
} }

View File

@ -194,7 +194,7 @@ func (h *APIHandler) CreateShort(w http.ResponseWriter, r *http.Request) {
} }
// Set the user // Set the user
short.User = user short.UserID = &user.ID
// Shorten URL // Shorten URL
newShort, err := h.shorts.Shorten(ctx, short) newShort, err := h.shorts.Shorten(ctx, short)
@ -283,6 +283,28 @@ func (h *APIHandler) DeleteShort(w http.ResponseWriter, r *http.Request) {
render.NoContent(w, r) render.NoContent(w, r)
} }
func (h *APIHandler) ListShortLogs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Find own short or respond
short, ok := h.findShortOrRespond(w, r)
if !ok {
return
}
// Get logs
logs, err := h.shorts.ListLogs(ctx, short)
if err != nil {
server.RenderServerError(w, r, err)
return
}
// Render the response
render.Status(r, http.StatusOK)
render.JSON(w, r, logs)
}
func (h *APIHandler) ListSessions(w http.ResponseWriter, r *http.Request) { func (h *APIHandler) ListSessions(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@ -451,11 +473,11 @@ func (h *APIHandler) findUserOrRespond(w http.ResponseWriter, r *http.Request) (
func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) (short *models.Short, ok bool) { func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) (short *models.Short, ok bool) {
ctx := r.Context() ctx := r.Context()
// Get short name from request // Get short id from request
name := chi.URLParam(r, "short") id := chi.URLParam(r, "id")
// Find short in storage // Find short in storage
short, err := h.shorts.FindShort(ctx, name) short, err := h.shorts.FindShortByID(ctx, id)
if err != nil { if err != nil {
// If the short doesn't exist or is invalid, return not found // If the short doesn't exist or is invalid, return not found
if errors.Is(err, errs.ErrShortDoesNotExist) || errors.Is(err, errs.ErrInvalidShort) { if errors.Is(err, errs.ErrShortDoesNotExist) || errors.Is(err, errs.ErrInvalidShort) {
@ -475,7 +497,7 @@ func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request)
// If the session user does not match the short's user, // If the session user does not match the short's user,
// return forbidden. // return forbidden.
if user.Username != short.User.Username { if user.ID != *short.UserID {
server.RenderForbidden(w, r) server.RenderForbidden(w, r)
return nil, false return nil, false
@ -514,7 +536,7 @@ func (h *APIHandler) findTokenOrRespond(w http.ResponseWriter, r *http.Request)
// If the session user does not match the token's user, // If the session user does not match the token's user,
// return NotFound. // return NotFound.
if user.Username != token.User.Username { if user.ID != *token.UserID {
server.RenderNotFound(w, r) server.RenderNotFound(w, r)
return nil, false return nil, false

View File

@ -26,8 +26,9 @@ func NewAPIRouter(h *APIHandler) http.Handler {
// Shorts routes // Shorts routes
r.Get("/shorts", h.ListShorts) r.Get("/shorts", h.ListShorts)
r.Post("/shorts", h.CreateShort) r.Post("/shorts", h.CreateShort)
r.Get("/shorts/{short}", h.FindShort) r.Get("/shorts/{id}", h.FindShort)
r.Delete("/shorts/{short}", h.DeleteShort) r.Delete("/shorts/{id}", h.DeleteShort)
r.Get("/shorts/{id}/logs", h.ListShortLogs)
// Sessions routes // Sessions routes
r.Get("/sessions", h.ListSessions) r.Get("/sessions", h.ListSessions)

View File

@ -15,11 +15,11 @@ import (
) )
const ( const (
sessionUserKey = "user" sessionUserKey = "u"
sessionIPKey = "ip" sessionIPKey = "ip"
sessionUserAgentKey = "user_agent" sessionUserAgentKey = "ua"
sessionLastActivityKey = "last_activity" sessionLastActivityKey = "la"
sessionCreatedAtKey = "created_at" sessionCreatedAtKey = "ca"
tokenHeader = "Authorization" tokenHeader = "Authorization"
) )
@ -28,8 +28,8 @@ type userContextKey struct{}
type AuthSessionData struct { type AuthSessionData struct {
// Username is the username of the user to whom the session belongs. // Username is the username of the user to whom the session belongs.
Username string `json:"username"` Username string `json:"username"`
// IP is the last IP address used by with the session. // IPAddress is the last IP address used by with the session.
IP string `json:"ip"` IPAddress string `json:"ipAddress"`
// UserAgent is the last User-Agent used with the session. // UserAgent is the last User-Agent used with the session.
UserAgent string `json:"userAgent"` UserAgent string `json:"userAgent"`
// LastActivity is the last time the session was used. // LastActivity is the last time the session was used.
@ -97,7 +97,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) {
func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) *AuthSessionData { func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) *AuthSessionData {
// Get data from session // Get data from session
username := manager.GetString(sessionCtx, sessionUserKey) username := manager.GetString(sessionCtx, sessionUserKey)
ip := manager.GetString(sessionCtx, sessionIPKey) ipAddress := manager.GetString(sessionCtx, sessionIPKey)
userAgent := manager.GetString(sessionCtx, sessionUserAgentKey) userAgent := manager.GetString(sessionCtx, sessionUserAgentKey)
lastActivity, err := time.Parse(time.RFC3339, manager.GetString(sessionCtx, sessionLastActivityKey)) lastActivity, err := time.Parse(time.RFC3339, manager.GetString(sessionCtx, sessionLastActivityKey))
if err != nil { if err != nil {
@ -111,7 +111,7 @@ func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context)
// Create new session data // Create new session data
sessionData := &AuthSessionData{ sessionData := &AuthSessionData{
Username: username, Username: username,
IP: ip, IPAddress: ipAddress,
UserAgent: userAgent, UserAgent: userAgent,
LastActivity: lastActivity, LastActivity: lastActivity,
CreatedAt: createdAt, CreatedAt: createdAt,
@ -124,12 +124,12 @@ func UpdateSession(ctx context.Context, r *http.Request) {
manager := SessionManagerFromCtx(ctx) manager := SessionManagerFromCtx(ctx)
// Get data from request // Get data from request
ip := r.RemoteAddr ipAddress := r.RemoteAddr
userAgent := r.UserAgent() userAgent := r.UserAgent()
lastActivity := time.Now().Format(time.RFC3339) lastActivity := time.Now().Format(time.RFC3339)
// Update session // Update session
manager.Put(ctx, sessionIPKey, ip) manager.Put(ctx, sessionIPKey, ipAddress)
manager.Put(ctx, sessionUserAgentKey, userAgent) manager.Put(ctx, sessionUserAgentKey, userAgent)
manager.Put(ctx, sessionLastActivityKey, lastActivity) manager.Put(ctx, sessionLastActivityKey, lastActivity)
} }
@ -298,14 +298,11 @@ func authenticateViaToken(r *http.Request, tokenService *tokenservice.TokenServi
return nil, errs.ErrTokenMissing return nil, errs.ErrTokenMissing
} }
// Get token from storage // Get the token's user from storage
token, err := tokenService.FindToken(ctx, value) user, err = tokenService.FindTokenUserFromValue(ctx, value)
if err != nil { if err != nil {
return nil, fmt.Errorf("error authenticating via token: %w", err) return nil, errs.Errorf("could not authenticate user", err)
} }
// Get user from token
user = token.User
return user, nil return user, nil
} }

View File

@ -26,12 +26,13 @@ func NewServer(cfg *config.Config) *Server {
mux := chi.NewRouter() mux := chi.NewRouter()
// Register default middlewares // Register default middlewares
mux.Use(middleware.Recoverer)
mux.Use(middleware.RequestID) mux.Use(middleware.RequestID)
mux.Use(middleware.RealIP) mux.Use(middleware.RealIP)
mux.Use(middleware.Logger) mux.Use(middleware.Logger)
mux.Use(middleware.Recoverer)
mux.Use(servermiddleware.SessionManager(cfg)) mux.Use(servermiddleware.SessionManager(cfg))
mux.Use(middleware.Timeout(config.RequestTimeout)) mux.Use(middleware.Timeout(config.RequestTimeout))
mux.Use(middleware.Compress(5, "application/json"))
// Create the server // Create the server
srv := &http.Server{ srv := &http.Server{

View File

@ -2,6 +2,7 @@ package shortserver
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/errs"
@ -32,7 +33,13 @@ func (h *ShortHandler) FindShort(w http.ResponseWriter, r *http.Request) {
switch { switch {
case err == nil: case err == nil:
// If there's no error, redirect to the URL // If there's no error, log the access and redirect to the URL
err = h.service.LogShortAccess(ctx, short, r)
if err != nil {
// If there was an error logging the access, print a message and
// continue.
fmt.Printf("failed to log short access: %v\n", err)
}
http.Redirect(w, r, short.URL, http.StatusSeeOther) http.Redirect(w, r, short.URL, http.StatusSeeOther)
case errors.Is(err, errs.ErrInvalidShort): case errors.Is(err, errs.ErrInvalidShort):
// If the short name is invalid, do nothing and let the static handler // If the short name is invalid, do nothing and let the static handler

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"regexp" "regexp"
@ -20,6 +21,8 @@ const (
MinShortLength = 4 MinShortLength = 4
// MaxShortLength is the maximum length of the short URL. // MaxShortLength is the maximum length of the short URL.
MaxShortLength = 20 MaxShortLength = 20
// ShortIDLength is the length of the short ID.
ShortIDLength = 16
) )
type ShortService struct { type ShortService struct {
@ -30,6 +33,33 @@ func NewShortService(db storage.Storage) *ShortService {
return &ShortService{db: db} return &ShortService{db: db}
} }
func (s *ShortService) LogShortAccess(ctx context.Context, short *models.Short, r *http.Request) error {
// Log the access
shortLog := &models.ShortLog{
ShortID: short.ID,
IPAddress: r.RemoteAddr,
UserAgent: r.UserAgent(),
Referer: r.Referer(),
}
err := s.db.CreateShortLog(ctx, shortLog)
if err != nil {
return fmt.Errorf("failed to log short access: %w", err)
}
return nil
}
func (s *ShortService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
// Get the logs from storage
logs, err := s.db.ListShortLogs(ctx, short)
if err != nil {
return nil, fmt.Errorf("failed to get short logs from storage: %w", err)
}
return logs, nil
}
func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Short, error) { func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Short, error) {
// Check if the short is valid // Check if the short is valid
err := ShortNameIsValid(name) err := ShortNameIsValid(name)
@ -37,11 +67,26 @@ func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Shor
return &models.Short{}, fmt.Errorf("could not validate short: %w", err) return &models.Short{}, fmt.Errorf("could not validate short: %w", err)
} }
// Get the URL from storage // Get the short from storage
short, err := s.db.FindShort(ctx, name) short, err := s.db.FindShort(ctx, name)
if err != nil { if err != nil {
return short, fmt.Errorf("could not get URL from storage: %w", err) return short, fmt.Errorf("could not get short from storage: %w", err)
}
return short, nil
}
func (s *ShortService) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
// Check if the ID is valid
if len(id) != ShortIDLength {
return &models.Short{}, errs.ErrInvalidShort
}
// Get the short from storage
short, err := s.db.FindShortByID(ctx, id)
if err != nil {
return short, fmt.Errorf("could not get short from storage: %w", err)
} }
return short, nil return short, nil
@ -116,13 +161,13 @@ var shortRegex = regexp.MustCompile(fmt.Sprintf("^%s$", shortPattern))
func ShortNameIsValid(name string) error { func ShortNameIsValid(name string) error {
if !shortRegex.MatchString(name) { if !shortRegex.MatchString(name) {
return errs.Error( return errs.Errorf(
errs.ErrInvalidShort,
fmt.Sprintf( fmt.Sprintf(
"short must use only letters, numbers, underscores and dashes, and be between %d and %d characters long", "short must use only letters, numbers, underscores and dashes, and be between %d and %d characters long",
MinShortLength, MinShortLength,
MaxShortLength, MaxShortLength,
), ),
errs.ErrInvalidShort,
) )
} }
@ -132,11 +177,11 @@ func ShortNameIsValid(name string) error {
func ShortURLIsValid(shortURL string) error { func ShortURLIsValid(shortURL string) error {
parsedURL, err := url.ParseRequestURI(shortURL) parsedURL, err := url.ParseRequestURI(shortURL)
if err != nil { if err != nil {
return errs.Error(errs.ErrInvalidShort, "invalid URL") return errs.Errorf("invalid URL", errs.ErrInvalidShort)
} }
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return errs.Error(errs.ErrInvalidShort, "invalid URL scheme") return errs.Errorf("invalid URL scheme", errs.ErrInvalidShort)
} }
return nil return nil

View File

@ -8,7 +8,6 @@ import (
"git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models" "git.maronato.dev/maronato/goshort/internal/storage/models"
shortutil "git.maronato.dev/maronato/goshort/internal/util/short"
tokenutil "git.maronato.dev/maronato/goshort/internal/util/token" tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
) )
@ -60,7 +59,7 @@ func (s *TokenService) FindToken(ctx context.Context, value string) (*models.Tok
// FindTokenByID finds a token in the storage using its ID. // FindTokenByID finds a token in the storage using its ID.
func (s *TokenService) FindTokenByID(ctx context.Context, id string) (*models.Token, error) { func (s *TokenService) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
// Check if the ID is valid // Check if the ID is valid
if len(id) != TokenIDLength { if len(id) != models.IDLength {
return &models.Token{}, errs.ErrInvalidTokenID return &models.Token{}, errs.ErrInvalidTokenID
} }
@ -86,12 +85,10 @@ func (s *TokenService) ListTokens(ctx context.Context, user *models.User) ([]*mo
// CreateToken creates a new token for a user. // CreateToken creates a new token for a user.
func (s *TokenService) CreateToken(ctx context.Context, user *models.User) (*models.Token, error) { func (s *TokenService) CreateToken(ctx context.Context, user *models.User) (*models.Token, error) {
// Generate a new token // Generate a new token
id := shortutil.GenerateRandomShort(TokenIDLength)
token := &models.Token{ token := &models.Token{
ID: id, Name: fmt.Sprintf("%s's token", user.Username),
Name: fmt.Sprintf("%s's token #%s", user.Username, id[:5]), Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2),
Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2), UserID: &user.ID,
User: user,
} }
// Create the token in storage // Create the token in storage
@ -128,3 +125,20 @@ func (s *TokenService) ChangeTokenName(ctx context.Context, token *models.Token,
return newToken, nil return newToken, nil
} }
// FindTokenUserFromValue finds the user that owns a token.
func (s *TokenService) FindTokenUserFromValue(ctx context.Context, value string) (*models.User, error) {
// Get the token from storage
token, err := s.FindToken(ctx, value)
if err != nil {
return &models.User{}, errs.Errorf("could not get token from storage", err)
}
// Get the user from storage
user, err := s.db.FindUserByID(ctx, *token.UserID)
if err != nil {
return &models.User{}, errs.Errorf("could not get token user from storage", err)
}
return user, nil
}

View File

@ -131,7 +131,7 @@ func (s *UserService) SetPassword(ctx context.Context, user *models.User, newPas
func UsernameIsValid(username string) error { func UsernameIsValid(username string) error {
if !UsernameRegex.MatchString(username) { if !UsernameRegex.MatchString(username) {
return errs.Error(errs.ErrInvalidUser, fmt.Sprintf("username must match %s", UsernameRegex)) return errs.Errorf(fmt.Sprintf("username must match %s", UsernameRegex), errs.ErrInvalidUser)
} }
return nil return nil

View File

@ -0,0 +1,8 @@
package bunstorage
import "errors"
var (
// ErrDatabaseError is the error returned when a generic database error occurs.
ErrDatabaseError = errors.New("database error")
)

View File

@ -7,42 +7,11 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type ShortModel struct {
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
ID int64 `bun:",pk,autoincrement" json:"id"`
// Name is the short's name, which is also the path to access it
Name string `bun:",unique,notnull" json:"name"`
// URL is the URL that the short will redirect to
URL string `bun:",notnull" json:"url"`
// Deleted is whether the short is soft-deleted or not
Deleted bool `bun:",notnull,default:false" json:"-"`
// CreatedAt is when the short was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// DeletedAt is when the short was deleted
DeletedAt time.Time `bun:",null" json:"-"`
// UserID is the ID of the user that created the short
// This can be null if the short was deleted
UserID *int64 `json:"-"`
// User is the user that created the short
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
}
func (s *ShortModel) toShort() *models.Short {
return &models.Short{
Name: s.Name,
URL: s.URL,
CreatedAt: s.CreatedAt,
User: s.User.toUser(),
}
}
type UserModel struct { type UserModel struct {
bun.BaseModel `bun:"table:users,alias:u"` bun.BaseModel `bun:"table:users,alias:u"`
// ID is the primary key // ID is the primary key
ID int64 `bun:",pk,autoincrement" json:"id"` ID string `bun:",pk,unique,notnull" json:"id"`
// Username is the user's username // Username is the user's username
Username string `bun:",unique,notnull" json:"username"` Username string `bun:",unique,notnull" json:"username"`
// Password is the user's password // Password is the user's password
@ -58,27 +27,60 @@ type UserModel struct {
func (u *UserModel) toUser() *models.User { func (u *UserModel) toUser() *models.User {
return models.NewAuthenticatableUser(&models.User{ return models.NewAuthenticatableUser(&models.User{
ID: u.ID,
Username: u.Username, Username: u.Username,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
}, u.Password) }, u.Password)
} }
type ShortModel struct {
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
ID string `bun:",pk,unique,notnull" json:"id"`
// Name is the short's name, which is also the path to access it
Name string `bun:",unique,notnull" json:"name"`
// URL is the URL that the short will redirect to
URL string `bun:",notnull" json:"url"`
// Deleted is whether the short is soft-deleted or not
Deleted bool `bun:",notnull,default:false" json:"-"`
// CreatedAt is when the short was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// DeletedAt is when the short was deleted
DeletedAt time.Time `bun:",null" json:"-"`
// UserID is the ID of the user that created the short
// This can be null if the short was deleted
UserID *string `json:"-"`
// User is the user that created the short
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
}
func (s *ShortModel) toShort() *models.Short {
return &models.Short{
ID: s.ID,
Name: s.Name,
URL: s.URL,
CreatedAt: s.CreatedAt,
UserID: s.UserID,
}
}
type TokenModel struct { type TokenModel struct {
bun.BaseModel `bun:"table:tokens,alias:t"` bun.BaseModel `bun:"table:tokens,alias:t"`
// ID is the primary key // ID is the primary key
ID string `bun:",pk" json:"id"` ID string `bun:",pk,unique,notnull" json:"id"`
// Name is the user-friendly name of the token // Name is the user-friendly name of the token
Name string `bun:",notnull" json:"name"` Name string `bun:",notnull" json:"name"`
// Value is the actual token // Value is the actual token
Value string `bun:",unique,notnull" json:"value"` Value string `bun:",unique,notnull" json:"value"`
// UserID is the ID of the user that created the token
UserID *int64 `bun:",notnull" json:"-"`
// User is the user that created the token
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
// CreatedAt is when the token was created (initialized by the storage) // CreatedAt is when the token was created (initialized by the storage)
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"` CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// UserID is the ID of the user that created the token
UserID *string `bun:",notnull" json:"-"`
// User is the user that created the token
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
} }
func (t *TokenModel) toToken() *models.Token { func (t *TokenModel) toToken() *models.Token {
@ -87,6 +89,37 @@ func (t *TokenModel) toToken() *models.Token {
Name: t.Name, Name: t.Name,
Value: t.Value, Value: t.Value,
CreatedAt: t.CreatedAt, CreatedAt: t.CreatedAt,
User: t.User.toUser(), UserID: t.UserID,
}
}
type ShortLogModel struct {
bun.BaseModel `bun:"table:short_logs,alias:sl"`
// ID is the primary key
ID string `bun:",pk,unique,notnull" json:"id"`
// IPAddress is the IP address of the client that accessed the short
IPAddress string `bun:"," json:"ipAddress"`
// UserAgent is the User-Agent of the client that accessed the short
UserAgent string `bun:"," json:"userAgent"`
// Referer is the referer of the client that accessed the short
Referer string `bun:"," json:"referer"`
// CreatedAt is when the short was accessed
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
// ShortID is the ID of the short that was accessed
ShortID string `bun:",notnull" json:"-"`
// Short is the short that was accessed
Short *ShortModel `bun:"rel:belongs-to,join:short_id=id" json:"short,omitempty"`
}
func (sl *ShortLogModel) toShortLog() *models.ShortLog {
return &models.ShortLog{
ID: sl.ID,
IPAddress: sl.IPAddress,
UserAgent: sl.UserAgent,
Referer: sl.Referer,
CreatedAt: sl.CreatedAt,
ShortID: sl.ShortID,
} }
} }

View File

@ -3,12 +3,10 @@ package bunstorage
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models" "git.maronato.dev/maronato/goshort/internal/storage/models"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -16,11 +14,14 @@ import (
) )
type BunStorage struct { type BunStorage struct {
storage.Storage
db *bun.DB db *bun.DB
// StartHooks are ran after the database is connected and the tables are created.
StartHooks []func(ctx context.Context, db *bun.DB) error
// StopHooks are ran before the database is closed.
StopHook []func(ctx context.Context, db *bun.DB) error
} }
func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage { func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
if cfg.Debug { if cfg.Debug {
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
} }
@ -30,8 +31,18 @@ func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage {
} }
} }
// RegisterStartHook registers a hook to be ran after the database is connected and the tables are created.
func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) {
s.StartHooks = append(s.StartHooks, hook)
}
// RegisterStopHook registers a hook to be ran before the database is closed.
func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) {
s.StopHook = append(s.StopHook, hook)
}
func (s *BunStorage) Start(ctx context.Context) error { func (s *BunStorage) Start(ctx context.Context) error {
return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
_, err := tx.NewCreateTable(). _, err := tx.NewCreateTable().
IfNotExists(). IfNotExists().
@ -39,7 +50,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
WithForeignKeys(). WithForeignKeys().
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create users table: %w", err) return errs.Errorf("failed to create users table", err)
} }
_, err = tx.NewCreateTable(). _, err = tx.NewCreateTable().
@ -48,7 +59,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`). ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create shorts table: %w", err) return errs.Errorf("failed to create shorts table", err)
} }
_, err = tx.NewCreateTable(). _, err = tx.NewCreateTable().
@ -56,9 +67,17 @@ func (s *BunStorage) Start(ctx context.Context) error {
Model((*TokenModel)(nil)). Model((*TokenModel)(nil)).
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`). ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create tokens table: %w", err) return errs.Errorf("failed to create tokens table", err)
}
_, err = tx.NewCreateTable().
IfNotExists().
Model((*ShortLogModel)(nil)).
ForeignKey(`("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`).
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short logs table", err)
} }
// shorts user_id index // shorts user_id index
@ -69,7 +88,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
Column("user_id"). Column("user_id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create shorts user_id index: %w", err) return errs.Errorf("failed to create shorts user_id index", err)
} }
// tokens user_id index // tokens user_id index
@ -80,14 +99,47 @@ func (s *BunStorage) Start(ctx context.Context) error {
Column("user_id"). Column("user_id").
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create tokens user_id index: %w", err) return errs.Errorf("failed to create tokens user_id index", err)
}
// short_logs short_id index
_, err = tx.NewCreateIndex().
IfNotExists().
Model((*ShortLogModel)(nil)).
Index("idx_short_logs_short_id").
Column("short_id").
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short_logs short_id index", err)
} }
return nil return nil
}) })
if err != nil {
return errs.Errorf("failed to start storage", err)
}
// Run start hooks
for _, hook := range s.StartHooks {
err := hook(ctx, s.db)
if err != nil {
return errs.Errorf("failed to start storage", err)
}
}
return nil
} }
func (s *BunStorage) Stop(ctx context.Context) error { func (s *BunStorage) Stop(ctx context.Context) error {
// Run stop hooks
for _, hook := range s.StopHook {
err := hook(ctx, s.db)
if err != nil {
return errs.Errorf("failed to stop storage", err)
}
}
return s.db.Close() return s.db.Close()
} }
@ -97,71 +149,95 @@ func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short,
err := s.db.NewSelect(). err := s.db.NewSelect().
Model(shortModel). Model(shortModel).
Where("name = ? and deleted = false", name). Where("name = ? and deleted = false", name).
Relation("User").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, errs.ErrShortDoesNotExist err = errs.ErrShortDoesNotExist
} }
return nil, fmt.Errorf("failed to find short: %w", err)
return nil, errs.Errorf("failed to find short", err)
}
return shortModel.toShort(), nil
}
func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
shortModel := new(ShortModel)
err := s.db.NewSelect().
Model(shortModel).
Where("id = ? and deleted = false", id).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
err = errs.ErrShortDoesNotExist
}
return nil, errs.Errorf("failed to find short", err)
} }
return shortModel.toShort(), nil return shortModel.toShort(), nil
} }
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) { func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
// Get user from username // Generate an ID that does not exist
user, err := findUser(ctx, s.db, short.User.Username) tableName := s.db.NewSelect().
Model((*ShortModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil { if err != nil {
return nil, err return nil, errs.Errorf("failed to create short", err)
} }
shortModel := &ShortModel{ shortModel := &ShortModel{
ID: newID,
Name: short.Name, Name: short.Name,
URL: short.URL, URL: short.URL,
UserID: &user.ID, UserID: short.UserID,
User: user,
} }
_, err = s.db.NewInsert(). _, err = s.db.NewInsert().
Model(shortModel). Model(shortModel).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errs.ErrShortExists return nil, errs.Errorf("failed to create short", err)
} }
return shortModel.toShort(), nil return shortModel.toShort(), nil
} }
func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error { func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error {
_, err := withShortDeleteUpdates( return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
s.db.NewUpdate(). // Delete short logs
Model((*ShortModel)(nil)). _, err := tx.NewDelete().
Where("name = ?", short.Name), Model((*ShortLogModel)(nil)).
).Exec(ctx) Where("short_id = ?", short.ID).
if err != nil { Exec(ctx)
return fmt.Errorf("failed to delete short: %w", err) if err != nil {
} return errs.Errorf("failed to delete short", err)
}
return nil // Delete short
_, err = withShortDeleteUpdates(
tx.NewUpdate().
Model((*ShortModel)(nil)).
Where("id = ?", short.ID),
).Exec(ctx)
if err != nil {
return errs.Errorf("failed to delete short", err)
}
return nil
})
} }
func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) { func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
// Get user ID from username
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
if err != nil {
return nil, err
}
shortModels := []*ShortModel{} shortModels := []*ShortModel{}
err = s.db.NewSelect(). err := s.db.NewSelect().
Model(&shortModels). Model(&shortModels).
Where("user_id = ? and deleted = false", userID). Where("user_id = ? and deleted = false", user.ID).
Relation("User").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list shorts: %w", err) return nil, errs.Errorf("failed to list shorts", err)
} }
shorts := []*models.Short{} shorts := []*models.Short{}
@ -171,6 +247,51 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode
return shorts, nil return shorts, nil
} }
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
// Create new ID
tableName := s.db.NewSelect().
Model((*ShortLogModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return errs.Errorf("failed to create short log", err)
}
shortLogModel := &ShortLogModel{
ID: newID,
ShortID: shortLog.ShortID,
IPAddress: shortLog.IPAddress,
UserAgent: shortLog.UserAgent,
Referer: shortLog.Referer,
}
_, err = s.db.NewInsert().
Model(shortLogModel).
Exec(ctx)
if err != nil {
return errs.Errorf("failed to create short log", err)
}
return nil
}
func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
shortLogModels := []*ShortLogModel{}
err := s.db.NewSelect().
Model(&shortLogModels).
Where("short_id = ?", short.ID).
Scan(ctx)
if err != nil {
return nil, errs.Errorf("failed to list short logs", err)
}
shortLogs := []*models.ShortLog{}
for _, shortLogModel := range shortLogModels {
shortLogs = append(shortLogs, shortLogModel.toShortLog())
}
return shortLogs, nil
}
func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) { func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
user, err := findUser(ctx, s.db, username) user, err := findUser(ctx, s.db, username)
if err != nil { if err != nil {
@ -180,21 +301,39 @@ func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.Use
return user.toUser(), nil return user.toUser(), nil
} }
func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
user, err := findUserByID(ctx, s.db, id)
if err != nil {
return nil, err
}
return user.toUser(), nil
}
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) { func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
// Create a new ID
tableName := s.db.NewSelect().
Model((*UserModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil {
return nil, errs.Errorf("failed to create user", err)
}
userModel := &UserModel{ userModel := &UserModel{
ID: newID,
Username: user.Username, Username: user.Username,
Password: user.GetPasswordHash(), Password: user.GetPasswordHash(),
} }
_, err := s.db.NewInsert(). _, err = s.db.NewInsert().
Model(userModel). Model(userModel).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errs.ErrUserExists return nil, errs.Errorf("failed to create user", err)
} }
return userModel.toUser(), nil return userModel.toUser(), nil
} }
func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error { func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
@ -203,15 +342,15 @@ func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
// Delete user shorts // Delete user shorts
err := deleteUserShorts(ctx, tx, user) err := deleteUserShorts(ctx, tx, user)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete user: %w", err) return errs.Errorf("failed to delete user", err)
} }
// Delete user // Delete user
_, err = tx.NewDelete(). _, err = tx.NewDelete().
Model(user). Model(user).
Where("username = ?", user.Username). Where("id = ?", user.ID).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete user: %w", err) return errs.Errorf("failed to delete user", err)
} }
return nil return nil
@ -224,13 +363,13 @@ func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token
err := s.db.NewSelect(). err := s.db.NewSelect().
Model(tokenModel). Model(tokenModel).
Where("value = ?", value). Where("value = ?", value).
Relation("User").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, errs.ErrTokenDoesNotExist err = errs.ErrTokenDoesNotExist
} }
return nil, fmt.Errorf("failed to find token: %w", err)
return nil, errs.Errorf("failed to find token", err)
} }
return tokenModel.toToken(), nil return tokenModel.toToken(), nil
@ -242,14 +381,14 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
err := s.db.NewSelect(). err := s.db.NewSelect().
Model(tokenModel). Model(tokenModel).
Where("t.id = ?", id). Where("id = ?", id).
Relation("User").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, errs.ErrTokenDoesNotExist err = errs.ErrTokenDoesNotExist
} }
return nil, fmt.Errorf("failed to find token: %w", err)
return nil, errs.Errorf("failed to find token", err)
} }
return tokenModel.toToken(), nil return tokenModel.toToken(), nil
@ -259,19 +398,12 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) { func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
tokenModels := []*TokenModel{} tokenModels := []*TokenModel{}
// Get user ID from username err := s.db.NewSelect().
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
if err != nil {
return nil, err
}
err = s.db.NewSelect().
Model(&tokenModels). Model(&tokenModels).
Where("user_id = ?", userID). Where("user_id = ?", user.ID).
Relation("User").
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list tokens: %w", err) return nil, errs.Errorf("failed to list tokens", err)
} }
tokens := []*models.Token{} tokens := []*models.Token{}
@ -283,26 +415,27 @@ func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*mode
} }
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) { func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
// Create a new ID
// Get user ID from username tableName := s.db.NewSelect().
user, err := findUser(ctx, s.db, token.User.Username) Model((*TokenModel)(nil)).
GetTableName()
newID, err := createNewID(ctx, s.db, tableName, "id")
if err != nil { if err != nil {
return nil, err return nil, errs.Errorf("failed to create token", err)
} }
tokenModel := &TokenModel{ tokenModel := &TokenModel{
ID: token.ID, ID: newID,
Name: token.Name, Name: token.Name,
Value: token.Value, Value: token.Value,
UserID: &user.ID, UserID: token.UserID,
User: user,
} }
_, err = s.db.NewInsert(). _, err = s.db.NewInsert().
Model(tokenModel). Model(tokenModel).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return nil, errs.ErrTokenExists return nil, errs.Errorf("failed to create token", err)
} }
return tokenModel.toToken(), nil return tokenModel.toToken(), nil
@ -314,12 +447,27 @@ func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error
Where("id = ?", token.ID). Where("id = ?", token.ID).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete token: %w", err) return errs.Errorf("failed to delete token", err)
} }
return nil return nil
} }
func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) {
newToken := new(TokenModel)
_, err := s.db.NewUpdate().
Model((*TokenModel)(nil)).
Set("name = ?", name).
Where("id = ?", token.ID).
Exec(ctx, newToken)
if err != nil {
return nil, errs.Errorf("failed to change token name", err)
}
return newToken.toToken(), nil
}
func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) { func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) {
userModel := new(UserModel) userModel := new(UserModel)
@ -329,31 +477,31 @@ func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, err
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, errs.ErrUserDoesNotExist err = errs.ErrUserDoesNotExist
} }
return nil, fmt.Errorf("failed to find user: %w", err)
return nil, errs.Errorf("failed to find user", err)
} }
return userModel, nil return userModel, nil
} }
func findUserIDFromUsername(ctx context.Context, db bun.IDB, username string) (*int64, error) { func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) {
var userID int64 userModel := new(UserModel)
err := db.NewSelect(). err := db.NewSelect().
Table("users"). Model(userModel).
Column("id"). Where("id = ?", id).
Where("username = ?", username). Scan(ctx)
Scan(ctx, &userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, errs.ErrUserDoesNotExist err = errs.ErrUserDoesNotExist
} }
return nil, fmt.Errorf("failed to get user ID: %w", err) return nil, errs.Errorf("failed to find user", err)
} }
return &userID, nil return userModel, nil
} }
func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery { func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
@ -363,19 +511,46 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
} }
func deleteUserShorts(ctx context.Context, db bun.IDB, user *models.User) error { func deleteUserShorts(ctx context.Context, db bun.IDB, user *models.User) error {
userID, err := findUserIDFromUsername(ctx, db, user.Username) _, err := withShortDeleteUpdates(
if err != nil {
return fmt.Errorf("failed to delete user shorts: %w", err)
}
_, err = withShortDeleteUpdates(
db.NewUpdate(). db.NewUpdate().
Model((*ShortModel)(nil)). Model((*ShortModel)(nil)).
Where("user_id = ?", userID), Where("user_id = ?", user.ID),
).Exec(ctx) ).Exec(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete user shorts: %w", err) return errs.Errorf("failed to delete user shorts", err)
} }
return nil return nil
} }
func createNewID(ctx context.Context, db bun.IDB, table, col string) (string, error) {
var newID string
// Generate an ID that does not exist
maxIters := 10
for {
// Make sure we don't get stuck in an infinite loop
maxIters--
if maxIters <= 0 {
return "", errs.Errorf("failed to create unique ID", ErrDatabaseError)
}
// Create a new ID and check if it exists
newID = models.NewID()
count, err := db.NewSelect().
Table(table).
Column(col).
Where(col+" = ?", newID).
Count(ctx)
if err != nil {
return "", errs.Errorf("failed to create unique ID", err)
}
// If the ID does not exist, break the loop
if count == 0 {
break
}
}
return newID, nil
}

View File

@ -3,9 +3,11 @@ package memorystorage
import ( import (
"context" "context"
"fmt" "fmt"
"runtime"
"sync" "sync"
"time" "time"
"git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage"
"git.maronato.dev/maronato/goshort/internal/storage/models" "git.maronato.dev/maronato/goshort/internal/storage/models"
@ -13,87 +15,170 @@ import (
// MemoryStorage is a storage that stores everything in memory. // MemoryStorage is a storage that stores everything in memory.
type MemoryStorage struct { type MemoryStorage struct {
debug bool
storage.Storage storage.Storage
shortMu sync.RWMutex shortMu sync.RWMutex
userMu sync.RWMutex shortIDMu sync.RWMutex
tokenMu sync.RWMutex shortLogMu sync.RWMutex
shortMap map[string]*models.Short userMu sync.RWMutex
userMap map[string]*models.User userIDMu sync.RWMutex
tokenMap map[string]*models.Token tokenMu sync.RWMutex
tokenIDMap map[string]*models.Token tokenIDMu sync.RWMutex
shortMap map[string]*models.Short
shortIDMap map[string]*models.Short
shortLogMap map[string][]*models.ShortLog
userMap map[string]*models.User
userIDMap map[string]*models.User
tokenMap map[string]*models.Token
tokenIDMap map[string]*models.Token
} }
// NewMemoryStorage creates a new MemoryStorage. // NewMemoryStorage creates a new MemoryStorage.
func NewMemoryStorage() storage.Storage { func NewMemoryStorage(cfg *config.Config) *MemoryStorage {
return &MemoryStorage{ return &MemoryStorage{
shortMap: make(map[string]*models.Short), debug: cfg.Debug,
userMap: make(map[string]*models.User), shortMap: make(map[string]*models.Short),
tokenMap: make(map[string]*models.Token), shortIDMap: make(map[string]*models.Short),
tokenIDMap: make(map[string]*models.Token), userMap: make(map[string]*models.User),
userIDMap: make(map[string]*models.User),
tokenMap: make(map[string]*models.Token),
tokenIDMap: make(map[string]*models.Token),
shortLogMap: make(map[string][]*models.ShortLog),
}
}
// logPerformance is a helper function to log the performance of a function.
func (s *MemoryStorage) logPerformance() func() {
start := time.Now()
return func() {
elapsed := time.Since(start)
if s.debug {
pc, _, _, ok := runtime.Caller(1)
if !ok {
return
}
method := runtime.FuncForPC(pc).Name()
fmt.Printf("%s took %s\n", method, elapsed)
}
} }
} }
// Start starts the storage. // Start starts the storage.
func (s *MemoryStorage) Start(ctx context.Context) error { func (s *MemoryStorage) Start(ctx context.Context) error {
logPerf := s.logPerformance()
defer logPerf()
return nil return nil
} }
// Stop stops the storage. // Stop stops the storage.
func (s *MemoryStorage) Stop(ctx context.Context) error { func (s *MemoryStorage) Stop(ctx context.Context) error {
logPerf := s.logPerformance()
defer logPerf()
return nil return nil
} }
// FindShort finds a short in the storage. // FindShort finds a short in the storage.
func (s *MemoryStorage) FindShort(ctx context.Context, name string) (*models.Short, error) { func (s *MemoryStorage) FindShort(ctx context.Context, name string) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.RLock() s.shortMu.RLock()
defer s.shortMu.RUnlock() defer s.shortMu.RUnlock()
short, ok := s.shortMap[name] short, ok := s.shortMap[name]
if !ok { if !ok {
return short, errs.ErrShortDoesNotExist return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
}
return short, nil
}
// FindShortByID finds a short in the storage.
func (s *MemoryStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortIDMu.RLock()
defer s.shortIDMu.RUnlock()
short, ok := s.shortIDMap[id]
if !ok {
return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
} }
return short, nil return short, nil
} }
func (s *MemoryStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) { func (s *MemoryStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.Lock() s.shortMu.Lock()
defer s.shortMu.Unlock() defer s.shortMu.Unlock()
s.shortIDMu.Lock()
defer s.shortIDMu.Unlock()
_, ok := s.shortMap[short.Name] _, ok := s.shortMap[short.Name]
if ok { if ok {
return &models.Short{}, errs.ErrShortExists return nil, errs.Errorf("failed to create short", errs.ErrShortExists)
} }
short.CreatedAt = time.Now() short.CreatedAt = time.Now()
// Create a new ID
for {
short.ID = models.NewID()
_, ok := s.shortIDMap[short.ID]
if !ok {
break
}
}
s.shortMap[short.Name] = short s.shortMap[short.Name] = short
s.shortIDMap[short.ID] = short
return short, nil return short, nil
} }
func (s *MemoryStorage) DeleteShort(ctx context.Context, short *models.Short) error { func (s *MemoryStorage) DeleteShort(ctx context.Context, short *models.Short) error {
logPerf := s.logPerformance()
defer logPerf()
s.shortMu.Lock() s.shortMu.Lock()
defer s.shortMu.Unlock() defer s.shortMu.Unlock()
s.shortIDMu.Lock()
defer s.shortIDMu.Unlock()
s.shortLogMu.Lock()
defer s.shortLogMu.Unlock()
_, ok := s.shortMap[short.Name] _, ok := s.shortIDMap[short.ID]
if !ok { if !ok {
return errs.ErrShortDoesNotExist return errs.Errorf("failed to delete short", errs.ErrShortDoesNotExist)
} }
delete(s.shortMap, short.Name) delete(s.shortMap, short.Name)
delete(s.shortIDMap, short.ID)
delete(s.shortLogMap, short.ID)
return nil return nil
} }
func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) { func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
s.shortMu.RLock() logPerf := s.logPerformance()
defer s.shortMu.RUnlock() defer logPerf()
s.shortIDMu.RLock()
defer s.shortIDMu.RUnlock()
shorts := []*models.Short{} shorts := []*models.Short{}
for _, short := range s.shortMap { for _, short := range s.shortIDMap {
if short.User != nil && short.User.Username == user.Username { if short.UserID != nil && *short.UserID == user.ID {
shorts = append(shorts, short) shorts = append(shorts, short)
} }
} }
@ -101,108 +186,203 @@ func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*m
return shorts, nil return shorts, nil
} }
func (s *MemoryStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
logPerf := s.logPerformance()
defer logPerf()
s.shortLogMu.Lock()
defer s.shortLogMu.Unlock()
// Check if short exists
short, err := s.FindShortByID(ctx, shortLog.ShortID)
if err != nil {
return fmt.Errorf("could not create short log: %w", err)
}
// Create
_, ok := s.shortLogMap[short.ID]
if !ok {
s.shortLogMap[shortLog.ShortID] = []*models.ShortLog{}
}
shortLog.CreatedAt = time.Now()
// Create a new ID (we can't really know if it's unique, but it's unlikely to happen)
shortLog.ID = models.NewID()
s.shortLogMap[shortLog.ShortID] = append(s.shortLogMap[shortLog.ShortID], shortLog)
return nil
}
func (s *MemoryStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
logPerf := s.logPerformance()
defer logPerf()
s.shortLogMu.RLock()
defer s.shortLogMu.RUnlock()
shortLogs, ok := s.shortLogMap[short.ID]
if !ok {
return []*models.ShortLog{}, nil
}
return shortLogs, nil
}
func (s *MemoryStorage) FindUser(ctx context.Context, username string) (*models.User, error) { func (s *MemoryStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.RLock() s.userMu.RLock()
defer s.userMu.RUnlock() defer s.userMu.RUnlock()
user, ok := s.userMap[username] user, ok := s.userMap[username]
if !ok { if !ok {
return user, errs.ErrUserDoesNotExist return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
}
return user, nil
}
func (s *MemoryStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userIDMu.RLock()
defer s.userIDMu.RUnlock()
user, ok := s.userIDMap[id]
if !ok {
return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
} }
return user, nil return user, nil
} }
func (s *MemoryStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) { func (s *MemoryStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.Lock() s.userMu.Lock()
defer s.userMu.Unlock() defer s.userMu.Unlock()
s.userIDMu.Lock()
defer s.userIDMu.Unlock()
_, ok := s.userMap[user.Username] _, ok := s.userMap[user.Username]
if ok { if ok {
return &models.User{}, errs.ErrUserExists return nil, errs.Errorf("failed to create user", errs.ErrUserExists)
} }
user.CreatedAt = time.Now() user.CreatedAt = time.Now()
// Create a new ID
for {
user.ID = models.NewID()
_, ok := s.userIDMap[user.ID]
if !ok {
break
}
}
s.userMap[user.Username] = user s.userMap[user.Username] = user
s.userIDMap[user.ID] = user
return user, nil return user, nil
} }
func (s *MemoryStorage) DeleteUser(ctx context.Context, user *models.User) error { func (s *MemoryStorage) DeleteUser(ctx context.Context, user *models.User) error {
logPerf := s.logPerformance()
defer logPerf()
s.userMu.Lock() s.userMu.Lock()
defer s.userMu.Unlock() defer s.userMu.Unlock()
s.userIDMu.Lock()
defer s.userIDMu.Unlock()
errormsg := "failed to delete user"
_, ok := s.userMap[user.Username] _, ok := s.userMap[user.Username]
if !ok { if !ok {
return errs.ErrUserDoesNotExist return errs.Errorf(errormsg, errs.ErrUserDoesNotExist)
} }
// Find all user shorts // Find all user shorts
shorts, err := s.ListShorts(ctx, user) shorts, err := s.ListShorts(ctx, user)
if err != nil { if err != nil {
return fmt.Errorf("could not list user shorts: %w", err) return errs.Errorf(errormsg, errs.Errorf("could not list user shorts", err))
} }
// Delete all user shorts // Delete all user shorts
for _, short := range shorts { for _, short := range shorts {
err = s.DeleteShort(ctx, short) err = s.DeleteShort(ctx, short)
if err != nil { if err != nil {
return fmt.Errorf("could not delete user short: %w", err) return errs.Errorf(errormsg, errs.Errorf("could not delete user short", err))
} }
} }
// Find all user tokens // Find all user tokens
tokens, err := s.ListTokens(ctx, user) tokens, err := s.ListTokens(ctx, user)
if err != nil { if err != nil {
return fmt.Errorf("could not list user tokens: %w", err) return errs.Errorf(errormsg, errs.Errorf("could not list user tokens", err))
} }
// Delete all user tokens // Delete all user tokens
for _, token := range tokens { for _, token := range tokens {
err = s.DeleteToken(ctx, token) err = s.DeleteToken(ctx, token)
if err != nil { if err != nil {
return fmt.Errorf("could not delete user token: %w", err) return errs.Errorf(errormsg, errs.Errorf("could not delete user token", err))
} }
} }
delete(s.userMap, user.Username) delete(s.userMap, user.Username)
delete(s.userIDMap, user.ID)
return nil return nil
} }
func (s *MemoryStorage) FindToken(ctx context.Context, value string) (*models.Token, error) { func (s *MemoryStorage) FindToken(ctx context.Context, value string) (*models.Token, error) {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.RLock() s.tokenMu.RLock()
defer s.tokenMu.RUnlock() defer s.tokenMu.RUnlock()
token, ok := s.tokenMap[value] token, ok := s.tokenMap[value]
if !ok { if !ok {
return token, errs.ErrTokenDoesNotExist return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
} }
return token, nil return token, nil
} }
func (s *MemoryStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) { func (s *MemoryStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
s.tokenMu.RLock() logPerf := s.logPerformance()
defer s.tokenMu.RUnlock() defer logPerf()
s.tokenIDMu.RLock()
defer s.tokenIDMu.RUnlock()
token, ok := s.tokenIDMap[id] token, ok := s.tokenIDMap[id]
if !ok { if !ok {
return token, errs.ErrTokenDoesNotExist return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
} }
return token, nil return token, nil
} }
func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) { func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
s.tokenMu.RLock() logPerf := s.logPerformance()
defer s.tokenMu.RUnlock() defer logPerf()
s.tokenIDMu.RLock()
defer s.tokenIDMu.RUnlock()
tokens := []*models.Token{} tokens := []*models.Token{}
for _, token := range s.tokenMap { for _, token := range s.tokenIDMap {
if token.User != nil && token.User.Username == user.Username { if token.UserID != nil && *token.UserID == user.ID {
tokens = append(tokens, token) tokens = append(tokens, token)
} }
} }
@ -211,20 +391,35 @@ func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*m
} }
func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) { func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.Lock() s.tokenMu.Lock()
defer s.tokenMu.Unlock() defer s.tokenMu.Unlock()
s.tokenIDMu.Lock()
defer s.tokenIDMu.Unlock()
_, ok := s.tokenMap[token.Value] _, ok := s.tokenMap[token.Value]
if ok { if ok {
return &models.Token{}, errs.ErrTokenExists return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
} }
_, ok = s.tokenIDMap[token.ID] _, ok = s.tokenIDMap[token.ID]
if ok { if ok {
return &models.Token{}, errs.ErrTokenExists return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
} }
token.CreatedAt = time.Now() token.CreatedAt = time.Now()
// Create a new ID
for {
token.ID = models.NewID()
_, ok := s.tokenIDMap[token.ID]
if !ok {
break
}
}
s.tokenMap[token.Value] = token s.tokenMap[token.Value] = token
s.tokenIDMap[token.ID] = token s.tokenIDMap[token.ID] = token
@ -232,16 +427,21 @@ func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*
} }
func (s *MemoryStorage) DeleteToken(ctx context.Context, token *models.Token) error { func (s *MemoryStorage) DeleteToken(ctx context.Context, token *models.Token) error {
logPerf := s.logPerformance()
defer logPerf()
s.tokenMu.Lock() s.tokenMu.Lock()
defer s.tokenMu.Unlock() defer s.tokenMu.Unlock()
s.tokenIDMu.Lock()
defer s.tokenIDMu.Unlock()
_, ok := s.tokenMap[token.Value] _, ok := s.tokenMap[token.Value]
if !ok { if !ok {
return errs.ErrTokenDoesNotExist return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
} }
_, ok = s.tokenIDMap[token.ID] _, ok = s.tokenIDMap[token.ID]
if !ok { if !ok {
return errs.ErrTokenDoesNotExist return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
} }
delete(s.tokenMap, token.Value) delete(s.tokenMap, token.Value)

View File

@ -0,0 +1,12 @@
package models
import tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
const (
// IDLength is the length of IDs.
IDLength = 16
)
func NewID() string {
return tokenutil.GenerateSecureToken(IDLength / 2)
}

View File

@ -2,14 +2,20 @@ package models
import "time" import "time"
const (
ShortIDLength = 16
)
type Short struct { type Short struct {
// ID is the unique identifier of the short.
ID string `json:"id"`
// Name is the shortened name of the URL. // Name is the shortened name of the URL.
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// URL is the URL that is shortened. // URL is the URL that is shortened.
URL string `json:"url"` URL string `json:"url"`
// CreatedAt is the time the short was created. // CreatedAt is the time the short was created.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"createdAt,omitempty"`
// User is the user that created the short. // UserID of the user that created the short.
User *User `json:"-"` UserID *string `json:"-"`
} }

View File

@ -0,0 +1,18 @@
package models
import "time"
type ShortLog struct {
// ID is the unique identifier of the short log.
ID string `json:"id"`
// ShortID is the ID of the short that was accessed.
ShortID string `json:"shortID"`
// IPAddress is the IP address of the client that accessed the short.
IPAddress string `json:"ipAddress"`
// UserAgent is the User-Agent of the client that accessed the short.
UserAgent string `json:"userAgent"`
// Referer is the referer of the client that accessed the short.
Referer string `json:"referer"`
// CreatedAt is the time the short was accessed.
CreatedAt time.Time `json:"createdAt,omitempty"`
}

View File

@ -12,6 +12,6 @@ type Token struct {
// CreatedAt is when the token was created (initialized by the storage) // CreatedAt is when the token was created (initialized by the storage)
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
// User is the user that created the token. // UserID of the user that created the token.
User *User `json:"-"` UserID *string `json:"-"`
} }

View File

@ -8,11 +8,15 @@ import (
) )
type User struct { type User struct {
// ID is the user's ID.
ID string `json:"id"`
// Username is the user's username.
Username string `json:"username"` Username string `json:"username"`
// password is the user's password.
password string `json:"-"` password string `json:"-"`
// CreatedAt is the time the user was created. // CreatedAt is the time the user was created.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"createdAt,omitempty"`
} }
// NewAuthenticatableUser is a elper function for storages that takes // NewAuthenticatableUser is a elper function for storages that takes
@ -20,6 +24,7 @@ type User struct {
func NewAuthenticatableUser(user *User, hashedPass string) *User { func NewAuthenticatableUser(user *User, hashedPass string) *User {
return &User{ return &User{
ID: user.ID,
Username: user.Username, Username: user.Username,
CreatedAt: user.CreatedAt, CreatedAt: user.CreatedAt,
password: hashedPass, password: hashedPass,

View File

@ -1,9 +1,12 @@
package sqlitestorage package sqlitestorage
import ( import (
"context"
"database/sql" "database/sql"
"time"
"git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/config"
"git.maronato.dev/maronato/goshort/internal/errs"
"git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage"
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun" bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -13,13 +16,74 @@ import (
// NewSQLiteStorage creates a new SQLite storage. // NewSQLiteStorage creates a new SQLite storage.
func NewSQLiteStorage(cfg *config.Config) storage.Storage { func NewSQLiteStorage(cfg *config.Config) storage.Storage {
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL) // Create a new SQLite database with the following pragmas enabled:
// - journal_mode=WAL: Enables Write-Ahead Logging, which allows for concurrent reads and writes.
// - foreign_keys=ON: Enables foreign key constraints.
// - synchronous=NORMAL: Enables synchronous mode NORMAL
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL+"?_pragma=foreign_keys(ON)&_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)")
if err != nil { if err != nil {
panic(err) panic(err)
} }
// If running the DB in memory, make sure
// database/sql does not close idle connections.
// Otherwise, the database will be lost.
if cfg.DBType == config.DBTypeMemory {
sqldb.SetMaxIdleConns(1000)
sqldb.SetConnMaxLifetime(0)
}
db := bun.NewDB(sqldb, sqlitedialect.New()) db := bun.NewDB(sqldb, sqlitedialect.New())
// Use Bun as the storage implementation. // Use Bun as the storage implementation.
return bunstorage.NewBunStorage(cfg, db) storage := bunstorage.NewBunStorage(cfg, db)
// Config vacuuming of the database every hour.
ticker := time.NewTicker(time.Hour)
tickerCtx, tickerCancel := context.WithCancel(context.Background())
// Start vacuuming the database when the storage is started.
storage.RegisterStartHook(func(ctx context.Context, activeDB *bun.DB) error {
// Vacuum now
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
if err != nil {
return errs.Errorf("failed to vacuum database: %w", err)
}
// Vacuum every hour, until the context is canceled.
go func() {
for {
select {
case <-ticker.C:
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
if err != nil {
panic(errs.Errorf("failed to vacuum database: %w", err))
}
case <-ctx.Done():
return
case <-tickerCtx.Done():
return
}
}
}()
return nil
})
// So some cleaning up and optimization when the storage is stopped.
storage.RegisterStopHook(func(ctx context.Context, activeDB *bun.DB) error {
// Stop the ticker and cancel the ticker context.
ticker.Stop()
tickerCancel()
// Run PRAGMA optimize to optimize the database.
// By this point, the passed context is likely cancelled, so we use the
// background context since otherwise the query will fail.
// This will delay the shutdown of the storage, but it'll be for a good cause.
// The user can force the shutdown by using SIGINT or SIGTERM.
_, err := activeDB.NewRaw("PRAGMA optimize").Exec(context.Background())
if err != nil {
return errs.Errorf("failed to optimize database: %w", err)
}
return nil
})
return storage
} }

View File

@ -14,6 +14,8 @@ type Storage interface {
// FindShort finds a short in the storage using its name. // FindShort finds a short in the storage using its name.
FindShort(ctx context.Context, name string) (*models.Short, error) FindShort(ctx context.Context, name string) (*models.Short, error)
// FindShortByID finds a short in the storage using its ID.
FindShortByID(ctx context.Context, id string) (*models.Short, error)
// FindShorts finds all shorts in the storage that belong to a user. // FindShorts finds all shorts in the storage that belong to a user.
ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error)
// CreateShort creates a short in the storage. // CreateShort creates a short in the storage.
@ -21,10 +23,19 @@ type Storage interface {
// DeleteShort deletes a short from the storage. // DeleteShort deletes a short from the storage.
DeleteShort(ctx context.Context, short *models.Short) error DeleteShort(ctx context.Context, short *models.Short) error
// ShortLog Storage
// CreateShortLog creates a short log in the storage.
CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error
// ListShortLogs finds all short logs in the storage that belong to a short.
ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error)
// User Storage // User Storage
// FindUser finds a user in the storage using its username. // FindUser finds a user in the storage using its username.
FindUser(ctx context.Context, username string) (*models.User, error) FindUser(ctx context.Context, username string) (*models.User, error)
// FindUserByID finds a user in the storage using its ID.
FindUserByID(ctx context.Context, id string) (*models.User, error)
// CreateUser creates a user in the storage. // CreateUser creates a user in the storage.
CreateUser(ctx context.Context, user *models.User) (*models.User, error) CreateUser(ctx context.Context, user *models.User) (*models.User, error)
// DeleteUser deletes a user and all their shorts from the storage. // DeleteUser deletes a user and all their shorts from the storage.