From 16bb373c60986c11f7949424a56251e5bb29306f Mon Sep 17 00:00:00 2001 From: Gustavo Maronato Date: Wed, 23 Aug 2023 19:30:12 -0300 Subject: [PATCH] short logs and better database access --- .dockerignore | 2 + .gitignore | 4 +- cmd/shared/shared.go | 4 +- frontend/src/components/ItemList.tsx | 22 +- frontend/src/components/ShortItem.tsx | 119 ++++++++- frontend/src/hooks/useLoaderItems.tsx | 3 +- frontend/src/hooks/useStats.ts | 168 ++++++++++++ frontend/src/pages/Index.tsx | 12 +- frontend/src/pages/ShortDetails.tsx | 2 +- frontend/src/pages/Shorts.tsx | 19 +- frontend/src/router.tsx | 2 +- frontend/src/types.ts | 14 +- internal/config/config.go | 6 +- internal/errs/errors.go | 6 +- internal/server/api/handler.go | 34 ++- internal/server/api/router.go | 5 +- internal/server/middleware/auth.go | 29 +- internal/server/server.go | 3 +- internal/server/short/handler.go | 9 +- internal/service/short/shortservice.go | 57 +++- internal/service/token/tokenservice.go | 28 +- internal/service/user/userservice.go | 2 +- internal/storage/bun/errors.go | 8 + internal/storage/bun/models.go | 111 +++++--- internal/storage/bun/storage.go | 357 ++++++++++++++++++------- internal/storage/memory/memory.go | 278 ++++++++++++++++--- internal/storage/models/shared.go | 12 + internal/storage/models/short.go | 12 +- internal/storage/models/shortlog.go | 18 ++ internal/storage/models/token.go | 4 +- internal/storage/models/user.go | 7 +- internal/storage/sqlite/sqlite.go | 68 ++++- internal/storage/storage.go | 11 + 33 files changed, 1167 insertions(+), 269 deletions(-) create mode 100644 frontend/src/hooks/useStats.ts create mode 100644 internal/storage/bun/errors.go create mode 100644 internal/storage/models/shared.go create mode 100644 internal/storage/models/shortlog.go diff --git a/.dockerignore b/.dockerignore index 7b8df1d..3abf069 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,5 +1,7 @@ # ---> DB *.db +*.db-shm +*.db-wal # ---> Go # If you prefer the allow list template instead of the deny list, see community template: diff --git a/.gitignore b/.gitignore index adee2d7..0e4390f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # ---> DB *.db +*.db-shm +*.db-wal # ---> Go # If you prefer the allow list template instead of the deny list, see community template: @@ -157,4 +159,4 @@ dist .yarn/unplugged .yarn/build-state.yml .yarn/install-state.gz -.pnp.* \ No newline at end of file +.pnp.* diff --git a/cmd/shared/shared.go b/cmd/shared/shared.go index 3424e70..e91c74f 100644 --- a/cmd/shared/shared.go +++ b/cmd/shared/shared.go @@ -9,7 +9,6 @@ import ( "git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/storage" - memorystorage "git.maronato.dev/maronato/goshort/internal/storage/memory" sqlitestorage "git.maronato.dev/maronato/goshort/internal/storage/sqlite" "github.com/peterbourgon/ff/v3" ) @@ -60,7 +59,8 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) { func InitStorage(cfg *config.Config) storage.Storage { switch cfg.DBType { case config.DBTypeMemory: - return memorystorage.NewMemoryStorage() + cfg.DBURL = ":memory:" + return sqlitestorage.NewSQLiteStorage(cfg) case config.DBTypeSQLite: return sqlitestorage.NewSQLiteStorage(cfg) default: diff --git a/frontend/src/components/ItemList.tsx b/frontend/src/components/ItemList.tsx index 6ab6f53..2c1631c 100644 --- a/frontend/src/components/ItemList.tsx +++ b/frontend/src/components/ItemList.tsx @@ -10,14 +10,22 @@ const ItemList = , K extends keyof T>({ idKey: K Item: FunctionComponent }) => { + console.log("list") return ( -
    - {items.map((item) => ( - - ))} -
+ <> +
    + {items.map((item) => ( + + ))} +
+ {items.length === 0 && ( +
+ Nothing here yet +
+ )} + ) } diff --git a/frontend/src/components/ShortItem.tsx b/frontend/src/components/ShortItem.tsx index 51aa2f7..8e8c97e 100644 --- a/frontend/src/components/ShortItem.tsx +++ b/frontend/src/components/ShortItem.tsx @@ -1,4 +1,4 @@ -import { FunctionComponent, useCallback } from "react" +import { FunctionComponent, memo, useCallback } from "react" import { ArrowRightIcon, @@ -6,6 +6,7 @@ import { } from "@heroicons/react/24/outline" import { useDelete } from "../hooks/useCRUD" +import { useShortLogs, useVisitMetrics } from "../hooks/useStats" import { Short } from "../types" import ItemBase from "./ItemBase" @@ -24,17 +25,14 @@ const ShortItem: FunctionComponent = ({ ...short }) => { // Handle deletion const [deleting, del] = useDelete() - const doDelete = useCallback( - () => del({ name: short.name }), - [del, short.name] - ) + const doDelete = useCallback(() => del({ id: short.id }), [del, short.id]) return ( + detailsPage={`/sht/${short.id}`}>
-

- 1000 views -

-

- Last viewed -

+
) } -export default ShortItem +export default memo(ShortItem, (prev, next) => prev.id === next.id) + +const ShortItemStats: FunctionComponent> = ({ id }) => { + const { loading, error, logs } = useShortLogs(id) + const { lastVisit, visits } = useVisitMetrics(logs) + + if (error) { + return ( +

+ Failed to load stats +

+ ) + } + + const lastViewedISO = lastVisit + ? new Date(lastVisit).toISOString() + : undefined + const lastViewedRelative = getRelativeTimeString(lastVisit) + + return ( + <> +

+ {loading && "Loading..."} + {loading || ( + <> + {visits} views + + )} +

+

+ {loading && "Loading..."} + {loading || ( + <> + Last viewed: + + + + + )} +

+ + ) +} + +const ShortItemStatsMemo = memo( + ShortItemStats, + (prev, next) => prev.id === next.id +) + +/** + * Convert a date to a relative time string, such as + * "a minute ago", "in 2 hours", "yesterday", "3 months ago", etc. + * using Intl.RelativeTimeFormat + * https://gist.github.com/LewisJEllis + */ +function getRelativeTimeString(date: Date | number | null): string { + if (!date) return "never" + // Allow dates or times to be passed + const timeMs = typeof date === "number" ? date : date.getTime() + + // Get the amount of seconds between the given date and now + const deltaSeconds = Math.round((timeMs - Date.now()) / 1000) + + // Array reprsenting one minute, hour, day, week, month, etc in seconds + const cutoffs = [ + 60, + 3600, + 86400, + 86400 * 7, + 86400 * 30, + 86400 * 365, + Infinity, + ] + + // Array equivalent to the above but in the string representation of the units + const units: Intl.RelativeTimeFormatUnit[] = [ + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + ] + + // Grab the ideal cutoff unit + const unitIndex = cutoffs.findIndex( + (cutoff) => cutoff > Math.abs(deltaSeconds) + ) + + // Get the divisor to divide from the seconds. E.g. if our unit is "day" our divisor + // is one day in seconds, so we can divide our seconds by this to get the # of days + const divisor = unitIndex ? cutoffs[unitIndex - 1] : 1 + + // Intl.RelativeTimeFormat do its magic + const rtf = new Intl.RelativeTimeFormat("en", { + numeric: "auto", + style: "narrow", + }) + return rtf.format(Math.floor(deltaSeconds / divisor), units[unitIndex]) +} diff --git a/frontend/src/hooks/useLoaderItems.tsx b/frontend/src/hooks/useLoaderItems.tsx index 4481180..e1e8e4c 100644 --- a/frontend/src/hooks/useLoaderItems.tsx +++ b/frontend/src/hooks/useLoaderItems.tsx @@ -6,7 +6,8 @@ type Item = Record export const useLoadedItems = () => { const sourceDefault = useMemo(() => [], []) - const items = (useLoaderData() ?? sourceDefault) as T[] + const data = useLoaderData() ?? sourceDefault + const items = useMemo(() => (Array.isArray(data) ? data : []), [data]) return items } diff --git a/frontend/src/hooks/useStats.ts b/frontend/src/hooks/useStats.ts new file mode 100644 index 0000000..d2a03e2 --- /dev/null +++ b/frontend/src/hooks/useStats.ts @@ -0,0 +1,168 @@ +import { useEffect, useMemo, useState } from "react" + +import { ShortLog } from "../types" +import fetchAPI from "../util/fetchAPI" + +export const useShortLogs = (shortID: string) => { + const [loading, setLoading] = useState(false) + const [error, setError] = useState("") + const [logs, setLogs] = useState(useMemo(() => [], [])) + + // Fetch logs on mount + useEffect(() => { + const fetchLogs = async () => { + // Reset logs and error + setLogs([]) + setError("") + // Fetch logs + const resp = await fetchAPI(`/shorts/${shortID}/logs`) + if (resp.ok) { + // If response is ok, set logs + if (resp.data.length > 0) { + // Sort logs by date, descending + resp.data.sort((a, b) => { + return b.createdAt.localeCompare(a.createdAt) + }) + // Set sorted logs + setLogs(resp.data) + } + } else { + // If response is not ok, set error + setError(resp.error) + } + // Now we're done loading + setLoading(false) + } + + if (shortID) { + // If we have an ID, fetch logs + setLoading(true) + fetchLogs() + } + }, [shortID]) + + return { loading, error, logs } +} + +const singleFilterOptions = ["equal", "greater", "lower", "notEqual"] as const +export type ShortLogSingleFilterOptions = (typeof singleFilterOptions)[number] + +const listFilterOptions = ["in", "notIn"] as const +export type ShortLogListFilterOptions = (typeof listFilterOptions)[number] + +type ShortLogKeys = keyof ShortLog + +type SingleFilter = { + [key in ShortLogSingleFilterOptions]?: ShortLog[T] | null +} +type ListFilter = { + [key in ShortLogListFilterOptions]?: (ShortLog[T] | null)[] +} + +export type ShortLogFilter = { + [key in ShortLogKeys]?: SingleFilter | ListFilter +} + +export const useFilteredShortLogs = ( + logs: ShortLog[], + filterList: ShortLogFilter[] +) => { + return useMemo(() => { + const start = performance.now() + // If there are no filters, return all logs + if (filterList.length === 0) return logs + // For each log + const result = logs.filter((log) => { + // If it matches any of the filters, return true + return filterList.some((filterSet) => { + // For each key in the filterSet + return Object.entries(filterSet).every(([key, filterMap]) => { + // If the key is not in the log, return false + if (!(key in log)) return false + // For each filter in the filterMap + return Object.entries(filterMap).every( + ([filterType, filterValue]) => { + const logValue = log[key as keyof ShortLog] + // If filterType is single + if ( + singleFilterOptions.includes( + filterType as unknown as ShortLogSingleFilterOptions + ) + ) { + switch (filterType) { + case "equal": + return logValue === filterValue + case "greater": + return logValue > filterValue + case "lower": + return logValue < filterValue + case "notEqual": + return logValue !== filterValue + } + } else if ( + listFilterOptions.includes( + filterType as unknown as ShortLogListFilterOptions + ) + ) { + // If filterType is list + switch (filterType) { + case "in": + return filterValue.includes(logValue) + case "notIn": + return !filterValue.includes(logValue) + } + } + // If we get here, + // the filterType is not valid + return false + } + ) + }) + }) + }) + console.debug( + `Filtered ${logs.length} logs in ${performance.now() - start}ms` + ) + + return result + }, [logs, filterList]) +} + +export const useVisitMetrics = (logs: ShortLog[]) => { + const [visits, setVisits] = useState(0) + const [uniqueVisits, setUniqueVisits] = useState(0) + const [lastVisit, setLastVisit] = useState(null) + + useEffect(() => { + // Reset visits + setUniqueVisits(0) + setLastVisit(null) + + // Set number of visits + setVisits(logs.length) + + // Skip work if there are no logs + if (logs.length === 0) return + + // Keep track of unique IPs + const uniqueIPs = new Set() + + let last = logs[0].createdAt + logs.forEach((log) => { + // If the log is older, update + if (log.createdAt.localeCompare(last) < 0) { + last = log.createdAt + } + // Add IP to set + uniqueIPs.add(log.ipAddress) + }) + + // Set unique visits + setUniqueVisits(uniqueIPs.size) + + // Set last visit + setLastVisit(new Date(last)) + }, [logs]) + + return { visits, uniqueVisits, lastVisit } +} diff --git a/frontend/src/pages/Index.tsx b/frontend/src/pages/Index.tsx index 9ce9c75..14d03c9 100644 --- a/frontend/src/pages/Index.tsx +++ b/frontend/src/pages/Index.tsx @@ -54,10 +54,10 @@ const useSubmitActions = ({ useEffect(() => { if (onDelete && formData && actionData && actionData.ok) { const data = actionData.data - const name = formData.get("name") - if (typeof data === "string" && name && typeof name === "string") { + const id = formData.get("id") + if (typeof data === "string" && id && typeof id === "string") { // Deleting - onDelete(name) + onDelete(id) } } }, [formData, actionData, onDelete]) @@ -136,8 +136,8 @@ const useRecentShorts = () => { onCreate: useCallback((short: Short) => { setRecentShorts((prev) => [short, ...prev]) }, []), - onDelete: useCallback((name: string) => { - setRecentShorts((prev) => prev.filter((short) => short.name !== name)) + onDelete: useCallback((id: string) => { + setRecentShorts((prev) => prev.filter((short) => short.id !== id)) }, []), }) @@ -213,7 +213,7 @@ export const Component: FunctionComponent = () => {
- +
) diff --git a/frontend/src/pages/ShortDetails.tsx b/frontend/src/pages/ShortDetails.tsx index 125c39d..680edcc 100644 --- a/frontend/src/pages/ShortDetails.tsx +++ b/frontend/src/pages/ShortDetails.tsx @@ -22,7 +22,7 @@ export const loader: LoaderFunction = async (args) => { const resp = await protectedLoader(args) if (resp) return resp - const data = await fetchAPI(`/shorts/${args.params.name}`) + const data = await fetchAPI(`/shorts/${args.params.id}`) if (data.ok) { return data.data } diff --git a/frontend/src/pages/Shorts.tsx b/frontend/src/pages/Shorts.tsx index aee8ce2..19ed33b 100644 --- a/frontend/src/pages/Shorts.tsx +++ b/frontend/src/pages/Shorts.tsx @@ -1,4 +1,4 @@ -import { FunctionComponent, useCallback } from "react" +import { useCallback } from "react" import { LoaderFunction, redirect } from "react-router-dom" @@ -16,19 +16,10 @@ export function Component() { useCallback((a, b) => a.name.localeCompare(b.name), []) ) - const Shorts: FunctionComponent = () => { - return - } - const NoShorts = () => { - return ( -
No shorts yet
- ) - } - return ( <>
- {items.length > 0 ? : } + ) } @@ -59,13 +50,13 @@ export const action = crudAction({ }) }, DELETE: async (formData) => { - const name = formData.get("name") as string | null + const id = formData.get("id") as string | null - if (!name) { + if (!id) { return { ok: false, error: "Invalid request" } } - return fetchAPI(`/shorts/${name}`, { + return fetchAPI(`/shorts/${id}`, { method: "DELETE", }) }, diff --git a/frontend/src/router.tsx b/frontend/src/router.tsx index b249d89..3345c24 100644 --- a/frontend/src/router.tsx +++ b/frontend/src/router.tsx @@ -52,7 +52,7 @@ export default createBrowserRouter([ }, { id: "shortDetails", - path: "/sht/:name", + path: "/sht/:id", lazy: () => import("./pages/ShortDetails"), }, { diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 07dbbfe..0c481c2 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -2,16 +2,28 @@ export type GenericItem = Record export type User = { username: string + createdAt: string } export type Short = { + id: string name: string url: string + createdAt: string +} + +export type ShortLog = { + id: string + shortID: string + ipAddress: string + userAgent: string + referer: string + createdAt: string } export type Session = { id: string - ip: string + ipAddress: string userAgent: string lastActivity: string createdAt: string diff --git a/internal/config/config.go b/internal/config/config.go index b1c0d30..96c1091 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,12 +83,12 @@ type Config struct { func Validate(cfg *Config) error { // Host and port have to be valid. if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil { - return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid host and/or port: %s", err)) + return errs.Errorf(fmt.Sprintf("invalid host and/or port: %s", err), errs.ErrInvalidConfig) } // UI port has to be valid. if cfg.UIPort != "" { if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.UIPort)); err != nil { - return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid UI port: %s", err)) + return errs.Errorf(fmt.Sprintf("invalid UI port: %s", err), errs.ErrInvalidConfig) } } if cfg.DBType != "" { @@ -101,7 +101,7 @@ func Validate(cfg *Config) error { } } if !valid { - return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid database type: %s", cfg.DBType)) + return errs.Errorf(fmt.Sprintf("invalid database type: %s", cfg.DBType), errs.ErrInvalidConfig) } } diff --git a/internal/errs/errors.go b/internal/errs/errors.go index 1bc825a..91bb96f 100644 --- a/internal/errs/errors.go +++ b/internal/errs/errors.go @@ -40,8 +40,10 @@ var ( ErrInvalidTokenID = errors.New("invalid token ID") // ErrInvalidTokenName ErrInvalidTokenName = errors.New("invalid token name") + // ErrDatabaseError + ErrDatabaseError = errors.New("database error") ) -func Error(err error, msg string) error { - return fmt.Errorf("%w: %s", err, msg) +func Errorf(msg string, err error) error { + return fmt.Errorf("%s: %w", msg, err) } diff --git a/internal/server/api/handler.go b/internal/server/api/handler.go index 09633f8..22ac1cd 100644 --- a/internal/server/api/handler.go +++ b/internal/server/api/handler.go @@ -194,7 +194,7 @@ func (h *APIHandler) CreateShort(w http.ResponseWriter, r *http.Request) { } // Set the user - short.User = user + short.UserID = &user.ID // Shorten URL newShort, err := h.shorts.Shorten(ctx, short) @@ -283,6 +283,28 @@ func (h *APIHandler) DeleteShort(w http.ResponseWriter, r *http.Request) { render.NoContent(w, r) } +func (h *APIHandler) ListShortLogs(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Find own short or respond + short, ok := h.findShortOrRespond(w, r) + if !ok { + return + } + + // Get logs + logs, err := h.shorts.ListLogs(ctx, short) + if err != nil { + server.RenderServerError(w, r, err) + + return + } + + // Render the response + render.Status(r, http.StatusOK) + render.JSON(w, r, logs) +} + func (h *APIHandler) ListSessions(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -451,11 +473,11 @@ func (h *APIHandler) findUserOrRespond(w http.ResponseWriter, r *http.Request) ( func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) (short *models.Short, ok bool) { ctx := r.Context() - // Get short name from request - name := chi.URLParam(r, "short") + // Get short id from request + id := chi.URLParam(r, "id") // Find short in storage - short, err := h.shorts.FindShort(ctx, name) + short, err := h.shorts.FindShortByID(ctx, id) if err != nil { // If the short doesn't exist or is invalid, return not found if errors.Is(err, errs.ErrShortDoesNotExist) || errors.Is(err, errs.ErrInvalidShort) { @@ -475,7 +497,7 @@ func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) // If the session user does not match the short's user, // return forbidden. - if user.Username != short.User.Username { + if user.ID != *short.UserID { server.RenderForbidden(w, r) return nil, false @@ -514,7 +536,7 @@ func (h *APIHandler) findTokenOrRespond(w http.ResponseWriter, r *http.Request) // If the session user does not match the token's user, // return NotFound. - if user.Username != token.User.Username { + if user.ID != *token.UserID { server.RenderNotFound(w, r) return nil, false diff --git a/internal/server/api/router.go b/internal/server/api/router.go index 84fe6f5..2105c6b 100644 --- a/internal/server/api/router.go +++ b/internal/server/api/router.go @@ -26,8 +26,9 @@ func NewAPIRouter(h *APIHandler) http.Handler { // Shorts routes r.Get("/shorts", h.ListShorts) r.Post("/shorts", h.CreateShort) - r.Get("/shorts/{short}", h.FindShort) - r.Delete("/shorts/{short}", h.DeleteShort) + r.Get("/shorts/{id}", h.FindShort) + r.Delete("/shorts/{id}", h.DeleteShort) + r.Get("/shorts/{id}/logs", h.ListShortLogs) // Sessions routes r.Get("/sessions", h.ListSessions) diff --git a/internal/server/middleware/auth.go b/internal/server/middleware/auth.go index af83600..9ab7594 100644 --- a/internal/server/middleware/auth.go +++ b/internal/server/middleware/auth.go @@ -15,11 +15,11 @@ import ( ) const ( - sessionUserKey = "user" + sessionUserKey = "u" sessionIPKey = "ip" - sessionUserAgentKey = "user_agent" - sessionLastActivityKey = "last_activity" - sessionCreatedAtKey = "created_at" + sessionUserAgentKey = "ua" + sessionLastActivityKey = "la" + sessionCreatedAtKey = "ca" tokenHeader = "Authorization" ) @@ -28,8 +28,8 @@ type userContextKey struct{} type AuthSessionData struct { // Username is the username of the user to whom the session belongs. Username string `json:"username"` - // IP is the last IP address used by with the session. - IP string `json:"ip"` + // IPAddress is the last IP address used by with the session. + IPAddress string `json:"ipAddress"` // UserAgent is the last User-Agent used with the session. UserAgent string `json:"userAgent"` // LastActivity is the last time the session was used. @@ -97,7 +97,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) { func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) *AuthSessionData { // Get data from session username := manager.GetString(sessionCtx, sessionUserKey) - ip := manager.GetString(sessionCtx, sessionIPKey) + ipAddress := manager.GetString(sessionCtx, sessionIPKey) userAgent := manager.GetString(sessionCtx, sessionUserAgentKey) lastActivity, err := time.Parse(time.RFC3339, manager.GetString(sessionCtx, sessionLastActivityKey)) if err != nil { @@ -111,7 +111,7 @@ func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) // Create new session data sessionData := &AuthSessionData{ Username: username, - IP: ip, + IPAddress: ipAddress, UserAgent: userAgent, LastActivity: lastActivity, CreatedAt: createdAt, @@ -124,12 +124,12 @@ func UpdateSession(ctx context.Context, r *http.Request) { manager := SessionManagerFromCtx(ctx) // Get data from request - ip := r.RemoteAddr + ipAddress := r.RemoteAddr userAgent := r.UserAgent() lastActivity := time.Now().Format(time.RFC3339) // Update session - manager.Put(ctx, sessionIPKey, ip) + manager.Put(ctx, sessionIPKey, ipAddress) manager.Put(ctx, sessionUserAgentKey, userAgent) manager.Put(ctx, sessionLastActivityKey, lastActivity) } @@ -298,14 +298,11 @@ func authenticateViaToken(r *http.Request, tokenService *tokenservice.TokenServi return nil, errs.ErrTokenMissing } - // Get token from storage - token, err := tokenService.FindToken(ctx, value) + // Get the token's user from storage + user, err = tokenService.FindTokenUserFromValue(ctx, value) if err != nil { - return nil, fmt.Errorf("error authenticating via token: %w", err) + return nil, errs.Errorf("could not authenticate user", err) } - // Get user from token - user = token.User - return user, nil } diff --git a/internal/server/server.go b/internal/server/server.go index 2a8d9fb..46b3f75 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -26,12 +26,13 @@ func NewServer(cfg *config.Config) *Server { mux := chi.NewRouter() // Register default middlewares + mux.Use(middleware.Recoverer) mux.Use(middleware.RequestID) mux.Use(middleware.RealIP) mux.Use(middleware.Logger) - mux.Use(middleware.Recoverer) mux.Use(servermiddleware.SessionManager(cfg)) mux.Use(middleware.Timeout(config.RequestTimeout)) + mux.Use(middleware.Compress(5, "application/json")) // Create the server srv := &http.Server{ diff --git a/internal/server/short/handler.go b/internal/server/short/handler.go index cb93e00..13598b0 100644 --- a/internal/server/short/handler.go +++ b/internal/server/short/handler.go @@ -2,6 +2,7 @@ package shortserver import ( "errors" + "fmt" "net/http" "git.maronato.dev/maronato/goshort/internal/errs" @@ -32,7 +33,13 @@ func (h *ShortHandler) FindShort(w http.ResponseWriter, r *http.Request) { switch { case err == nil: - // If there's no error, redirect to the URL + // If there's no error, log the access and redirect to the URL + err = h.service.LogShortAccess(ctx, short, r) + if err != nil { + // If there was an error logging the access, print a message and + // continue. + fmt.Printf("failed to log short access: %v\n", err) + } http.Redirect(w, r, short.URL, http.StatusSeeOther) case errors.Is(err, errs.ErrInvalidShort): // If the short name is invalid, do nothing and let the static handler diff --git a/internal/service/short/shortservice.go b/internal/service/short/shortservice.go index 1f2517e..a08a074 100644 --- a/internal/service/short/shortservice.go +++ b/internal/service/short/shortservice.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "net/url" "regexp" @@ -20,6 +21,8 @@ const ( MinShortLength = 4 // MaxShortLength is the maximum length of the short URL. MaxShortLength = 20 + // ShortIDLength is the length of the short ID. + ShortIDLength = 16 ) type ShortService struct { @@ -30,6 +33,33 @@ func NewShortService(db storage.Storage) *ShortService { return &ShortService{db: db} } +func (s *ShortService) LogShortAccess(ctx context.Context, short *models.Short, r *http.Request) error { + // Log the access + shortLog := &models.ShortLog{ + ShortID: short.ID, + IPAddress: r.RemoteAddr, + UserAgent: r.UserAgent(), + Referer: r.Referer(), + } + + err := s.db.CreateShortLog(ctx, shortLog) + if err != nil { + return fmt.Errorf("failed to log short access: %w", err) + } + + return nil +} + +func (s *ShortService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) { + // Get the logs from storage + logs, err := s.db.ListShortLogs(ctx, short) + if err != nil { + return nil, fmt.Errorf("failed to get short logs from storage: %w", err) + } + + return logs, nil +} + func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Short, error) { // Check if the short is valid err := ShortNameIsValid(name) @@ -37,11 +67,26 @@ func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Shor return &models.Short{}, fmt.Errorf("could not validate short: %w", err) } - // Get the URL from storage + // Get the short from storage short, err := s.db.FindShort(ctx, name) if err != nil { - return short, fmt.Errorf("could not get URL from storage: %w", err) + return short, fmt.Errorf("could not get short from storage: %w", err) + } + + return short, nil +} + +func (s *ShortService) FindShortByID(ctx context.Context, id string) (*models.Short, error) { + // Check if the ID is valid + if len(id) != ShortIDLength { + return &models.Short{}, errs.ErrInvalidShort + } + + // Get the short from storage + short, err := s.db.FindShortByID(ctx, id) + if err != nil { + return short, fmt.Errorf("could not get short from storage: %w", err) } return short, nil @@ -116,13 +161,13 @@ var shortRegex = regexp.MustCompile(fmt.Sprintf("^%s$", shortPattern)) func ShortNameIsValid(name string) error { if !shortRegex.MatchString(name) { - return errs.Error( - errs.ErrInvalidShort, + return errs.Errorf( fmt.Sprintf( "short must use only letters, numbers, underscores and dashes, and be between %d and %d characters long", MinShortLength, MaxShortLength, ), + errs.ErrInvalidShort, ) } @@ -132,11 +177,11 @@ func ShortNameIsValid(name string) error { func ShortURLIsValid(shortURL string) error { parsedURL, err := url.ParseRequestURI(shortURL) if err != nil { - return errs.Error(errs.ErrInvalidShort, "invalid URL") + return errs.Errorf("invalid URL", errs.ErrInvalidShort) } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return errs.Error(errs.ErrInvalidShort, "invalid URL scheme") + return errs.Errorf("invalid URL scheme", errs.ErrInvalidShort) } return nil diff --git a/internal/service/token/tokenservice.go b/internal/service/token/tokenservice.go index 43e3eed..e86c86b 100644 --- a/internal/service/token/tokenservice.go +++ b/internal/service/token/tokenservice.go @@ -8,7 +8,6 @@ import ( "git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage/models" - shortutil "git.maronato.dev/maronato/goshort/internal/util/short" tokenutil "git.maronato.dev/maronato/goshort/internal/util/token" ) @@ -60,7 +59,7 @@ func (s *TokenService) FindToken(ctx context.Context, value string) (*models.Tok // FindTokenByID finds a token in the storage using its ID. func (s *TokenService) FindTokenByID(ctx context.Context, id string) (*models.Token, error) { // Check if the ID is valid - if len(id) != TokenIDLength { + if len(id) != models.IDLength { return &models.Token{}, errs.ErrInvalidTokenID } @@ -86,12 +85,10 @@ func (s *TokenService) ListTokens(ctx context.Context, user *models.User) ([]*mo // CreateToken creates a new token for a user. func (s *TokenService) CreateToken(ctx context.Context, user *models.User) (*models.Token, error) { // Generate a new token - id := shortutil.GenerateRandomShort(TokenIDLength) token := &models.Token{ - ID: id, - Name: fmt.Sprintf("%s's token #%s", user.Username, id[:5]), - Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2), - User: user, + Name: fmt.Sprintf("%s's token", user.Username), + Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2), + UserID: &user.ID, } // Create the token in storage @@ -128,3 +125,20 @@ func (s *TokenService) ChangeTokenName(ctx context.Context, token *models.Token, return newToken, nil } + +// FindTokenUserFromValue finds the user that owns a token. +func (s *TokenService) FindTokenUserFromValue(ctx context.Context, value string) (*models.User, error) { + // Get the token from storage + token, err := s.FindToken(ctx, value) + if err != nil { + return &models.User{}, errs.Errorf("could not get token from storage", err) + } + + // Get the user from storage + user, err := s.db.FindUserByID(ctx, *token.UserID) + if err != nil { + return &models.User{}, errs.Errorf("could not get token user from storage", err) + } + + return user, nil +} diff --git a/internal/service/user/userservice.go b/internal/service/user/userservice.go index 1707d4f..1fc964b 100644 --- a/internal/service/user/userservice.go +++ b/internal/service/user/userservice.go @@ -131,7 +131,7 @@ func (s *UserService) SetPassword(ctx context.Context, user *models.User, newPas func UsernameIsValid(username string) error { if !UsernameRegex.MatchString(username) { - return errs.Error(errs.ErrInvalidUser, fmt.Sprintf("username must match %s", UsernameRegex)) + return errs.Errorf(fmt.Sprintf("username must match %s", UsernameRegex), errs.ErrInvalidUser) } return nil diff --git a/internal/storage/bun/errors.go b/internal/storage/bun/errors.go new file mode 100644 index 0000000..aa1a0da --- /dev/null +++ b/internal/storage/bun/errors.go @@ -0,0 +1,8 @@ +package bunstorage + +import "errors" + +var ( + // ErrDatabaseError is the error returned when a generic database error occurs. + ErrDatabaseError = errors.New("database error") +) diff --git a/internal/storage/bun/models.go b/internal/storage/bun/models.go index 762963b..72b4d29 100644 --- a/internal/storage/bun/models.go +++ b/internal/storage/bun/models.go @@ -7,42 +7,11 @@ import ( "github.com/uptrace/bun" ) -type ShortModel struct { - bun.BaseModel `bun:"table:shorts,alias:s" json:"-"` - - ID int64 `bun:",pk,autoincrement" json:"id"` - // Name is the short's name, which is also the path to access it - Name string `bun:",unique,notnull" json:"name"` - // URL is the URL that the short will redirect to - URL string `bun:",notnull" json:"url"` - // Deleted is whether the short is soft-deleted or not - Deleted bool `bun:",notnull,default:false" json:"-"` - // CreatedAt is when the short was created (initialized by the storage) - CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"` - // DeletedAt is when the short was deleted - DeletedAt time.Time `bun:",null" json:"-"` - - // UserID is the ID of the user that created the short - // This can be null if the short was deleted - UserID *int64 `json:"-"` - // User is the user that created the short - User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"` -} - -func (s *ShortModel) toShort() *models.Short { - return &models.Short{ - Name: s.Name, - URL: s.URL, - CreatedAt: s.CreatedAt, - User: s.User.toUser(), - } -} - type UserModel struct { bun.BaseModel `bun:"table:users,alias:u"` // ID is the primary key - ID int64 `bun:",pk,autoincrement" json:"id"` + ID string `bun:",pk,unique,notnull" json:"id"` // Username is the user's username Username string `bun:",unique,notnull" json:"username"` // Password is the user's password @@ -58,27 +27,60 @@ type UserModel struct { func (u *UserModel) toUser() *models.User { return models.NewAuthenticatableUser(&models.User{ + ID: u.ID, Username: u.Username, CreatedAt: u.CreatedAt, }, u.Password) } +type ShortModel struct { + bun.BaseModel `bun:"table:shorts,alias:s" json:"-"` + + ID string `bun:",pk,unique,notnull" json:"id"` + // Name is the short's name, which is also the path to access it + Name string `bun:",unique,notnull" json:"name"` + // URL is the URL that the short will redirect to + URL string `bun:",notnull" json:"url"` + // Deleted is whether the short is soft-deleted or not + Deleted bool `bun:",notnull,default:false" json:"-"` + // CreatedAt is when the short was created (initialized by the storage) + CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"` + // DeletedAt is when the short was deleted + DeletedAt time.Time `bun:",null" json:"-"` + + // UserID is the ID of the user that created the short + // This can be null if the short was deleted + UserID *string `json:"-"` + // User is the user that created the short + User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"` +} + +func (s *ShortModel) toShort() *models.Short { + return &models.Short{ + ID: s.ID, + Name: s.Name, + URL: s.URL, + CreatedAt: s.CreatedAt, + UserID: s.UserID, + } +} + type TokenModel struct { bun.BaseModel `bun:"table:tokens,alias:t"` // ID is the primary key - ID string `bun:",pk" json:"id"` + ID string `bun:",pk,unique,notnull" json:"id"` // Name is the user-friendly name of the token Name string `bun:",notnull" json:"name"` // Value is the actual token Value string `bun:",unique,notnull" json:"value"` - // UserID is the ID of the user that created the token - UserID *int64 `bun:",notnull" json:"-"` - // User is the user that created the token - User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"` - // CreatedAt is when the token was created (initialized by the storage) CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"` + + // UserID is the ID of the user that created the token + UserID *string `bun:",notnull" json:"-"` + // User is the user that created the token + User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"` } func (t *TokenModel) toToken() *models.Token { @@ -87,6 +89,37 @@ func (t *TokenModel) toToken() *models.Token { Name: t.Name, Value: t.Value, CreatedAt: t.CreatedAt, - User: t.User.toUser(), + UserID: t.UserID, + } +} + +type ShortLogModel struct { + bun.BaseModel `bun:"table:short_logs,alias:sl"` + + // ID is the primary key + ID string `bun:",pk,unique,notnull" json:"id"` + // IPAddress is the IP address of the client that accessed the short + IPAddress string `bun:"," json:"ipAddress"` + // UserAgent is the User-Agent of the client that accessed the short + UserAgent string `bun:"," json:"userAgent"` + // Referer is the referer of the client that accessed the short + Referer string `bun:"," json:"referer"` + // CreatedAt is when the short was accessed + CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"` + + // ShortID is the ID of the short that was accessed + ShortID string `bun:",notnull" json:"-"` + // Short is the short that was accessed + Short *ShortModel `bun:"rel:belongs-to,join:short_id=id" json:"short,omitempty"` +} + +func (sl *ShortLogModel) toShortLog() *models.ShortLog { + return &models.ShortLog{ + ID: sl.ID, + IPAddress: sl.IPAddress, + UserAgent: sl.UserAgent, + Referer: sl.Referer, + CreatedAt: sl.CreatedAt, + ShortID: sl.ShortID, } } diff --git a/internal/storage/bun/storage.go b/internal/storage/bun/storage.go index d68689e..bf21ccd 100644 --- a/internal/storage/bun/storage.go +++ b/internal/storage/bun/storage.go @@ -3,12 +3,10 @@ package bunstorage import ( "context" "database/sql" - "fmt" "time" "git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/errs" - "git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage/models" "github.com/uptrace/bun" @@ -16,11 +14,14 @@ import ( ) type BunStorage struct { - storage.Storage db *bun.DB + // StartHooks are ran after the database is connected and the tables are created. + StartHooks []func(ctx context.Context, db *bun.DB) error + // StopHooks are ran before the database is closed. + StopHook []func(ctx context.Context, db *bun.DB) error } -func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage { +func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage { if cfg.Debug { db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) } @@ -30,8 +31,18 @@ func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage { } } +// RegisterStartHook registers a hook to be ran after the database is connected and the tables are created. +func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) { + s.StartHooks = append(s.StartHooks, hook) +} + +// RegisterStopHook registers a hook to be ran before the database is closed. +func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) { + s.StopHook = append(s.StopHook, hook) +} + func (s *BunStorage) Start(ctx context.Context) error { - return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { _, err := tx.NewCreateTable(). IfNotExists(). @@ -39,7 +50,7 @@ func (s *BunStorage) Start(ctx context.Context) error { WithForeignKeys(). Exec(ctx) if err != nil { - return fmt.Errorf("failed to create users table: %w", err) + return errs.Errorf("failed to create users table", err) } _, err = tx.NewCreateTable(). @@ -48,7 +59,7 @@ func (s *BunStorage) Start(ctx context.Context) error { ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`). Exec(ctx) if err != nil { - return fmt.Errorf("failed to create shorts table: %w", err) + return errs.Errorf("failed to create shorts table", err) } _, err = tx.NewCreateTable(). @@ -56,9 +67,17 @@ func (s *BunStorage) Start(ctx context.Context) error { Model((*TokenModel)(nil)). ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`). Exec(ctx) - if err != nil { - return fmt.Errorf("failed to create tokens table: %w", err) + return errs.Errorf("failed to create tokens table", err) + } + + _, err = tx.NewCreateTable(). + IfNotExists(). + Model((*ShortLogModel)(nil)). + ForeignKey(`("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`). + Exec(ctx) + if err != nil { + return errs.Errorf("failed to create short logs table", err) } // shorts user_id index @@ -69,7 +88,7 @@ func (s *BunStorage) Start(ctx context.Context) error { Column("user_id"). Exec(ctx) if err != nil { - return fmt.Errorf("failed to create shorts user_id index: %w", err) + return errs.Errorf("failed to create shorts user_id index", err) } // tokens user_id index @@ -80,14 +99,47 @@ func (s *BunStorage) Start(ctx context.Context) error { Column("user_id"). Exec(ctx) if err != nil { - return fmt.Errorf("failed to create tokens user_id index: %w", err) + return errs.Errorf("failed to create tokens user_id index", err) + } + + // short_logs short_id index + _, err = tx.NewCreateIndex(). + IfNotExists(). + Model((*ShortLogModel)(nil)). + Index("idx_short_logs_short_id"). + Column("short_id"). + Exec(ctx) + if err != nil { + return errs.Errorf("failed to create short_logs short_id index", err) } return nil }) + + if err != nil { + return errs.Errorf("failed to start storage", err) + } + + // Run start hooks + for _, hook := range s.StartHooks { + err := hook(ctx, s.db) + if err != nil { + return errs.Errorf("failed to start storage", err) + } + } + + return nil } func (s *BunStorage) Stop(ctx context.Context) error { + // Run stop hooks + for _, hook := range s.StopHook { + err := hook(ctx, s.db) + if err != nil { + return errs.Errorf("failed to stop storage", err) + } + } + return s.db.Close() } @@ -97,71 +149,95 @@ func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short, err := s.db.NewSelect(). Model(shortModel). Where("name = ? and deleted = false", name). - Relation("User"). Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return nil, errs.ErrShortDoesNotExist + err = errs.ErrShortDoesNotExist } - return nil, fmt.Errorf("failed to find short: %w", err) + + return nil, errs.Errorf("failed to find short", err) + } + + return shortModel.toShort(), nil +} + +func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) { + shortModel := new(ShortModel) + err := s.db.NewSelect(). + Model(shortModel). + Where("id = ? and deleted = false", id). + Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + err = errs.ErrShortDoesNotExist + } + + return nil, errs.Errorf("failed to find short", err) } return shortModel.toShort(), nil } func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) { - // Get user from username - user, err := findUser(ctx, s.db, short.User.Username) + // Generate an ID that does not exist + tableName := s.db.NewSelect(). + Model((*ShortModel)(nil)). + GetTableName() + newID, err := createNewID(ctx, s.db, tableName, "id") if err != nil { - return nil, err + return nil, errs.Errorf("failed to create short", err) } shortModel := &ShortModel{ + ID: newID, Name: short.Name, URL: short.URL, - UserID: &user.ID, - User: user, + UserID: short.UserID, } _, err = s.db.NewInsert(). Model(shortModel). Exec(ctx) if err != nil { - return nil, errs.ErrShortExists + return nil, errs.Errorf("failed to create short", err) } return shortModel.toShort(), nil } func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error { - _, err := withShortDeleteUpdates( - s.db.NewUpdate(). - Model((*ShortModel)(nil)). - Where("name = ?", short.Name), - ).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to delete short: %w", err) - } + return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Delete short logs + _, err := tx.NewDelete(). + Model((*ShortLogModel)(nil)). + Where("short_id = ?", short.ID). + Exec(ctx) + if err != nil { + return errs.Errorf("failed to delete short", err) + } - return nil + // Delete short + _, err = withShortDeleteUpdates( + tx.NewUpdate(). + Model((*ShortModel)(nil)). + Where("id = ?", short.ID), + ).Exec(ctx) + if err != nil { + return errs.Errorf("failed to delete short", err) + } + + return nil + }) } func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) { - - // Get user ID from username - userID, err := findUserIDFromUsername(ctx, s.db, user.Username) - if err != nil { - return nil, err - } - shortModels := []*ShortModel{} - err = s.db.NewSelect(). + err := s.db.NewSelect(). Model(&shortModels). - Where("user_id = ? and deleted = false", userID). - Relation("User"). + Where("user_id = ? and deleted = false", user.ID). Scan(ctx) if err != nil { - return nil, fmt.Errorf("failed to list shorts: %w", err) + return nil, errs.Errorf("failed to list shorts", err) } shorts := []*models.Short{} @@ -171,6 +247,51 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode return shorts, nil } +func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error { + // Create new ID + tableName := s.db.NewSelect(). + Model((*ShortLogModel)(nil)). + GetTableName() + newID, err := createNewID(ctx, s.db, tableName, "id") + if err != nil { + return errs.Errorf("failed to create short log", err) + } + + shortLogModel := &ShortLogModel{ + ID: newID, + ShortID: shortLog.ShortID, + IPAddress: shortLog.IPAddress, + UserAgent: shortLog.UserAgent, + Referer: shortLog.Referer, + } + + _, err = s.db.NewInsert(). + Model(shortLogModel). + Exec(ctx) + if err != nil { + return errs.Errorf("failed to create short log", err) + } + + return nil +} + +func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) { + shortLogModels := []*ShortLogModel{} + err := s.db.NewSelect(). + Model(&shortLogModels). + Where("short_id = ?", short.ID). + Scan(ctx) + if err != nil { + return nil, errs.Errorf("failed to list short logs", err) + } + + shortLogs := []*models.ShortLog{} + for _, shortLogModel := range shortLogModels { + shortLogs = append(shortLogs, shortLogModel.toShortLog()) + } + return shortLogs, nil +} + func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) { user, err := findUser(ctx, s.db, username) if err != nil { @@ -180,21 +301,39 @@ func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.Use return user.toUser(), nil } +func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) { + user, err := findUserByID(ctx, s.db, id) + if err != nil { + return nil, err + } + + return user.toUser(), nil +} + func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) { + // Create a new ID + tableName := s.db.NewSelect(). + Model((*UserModel)(nil)). + GetTableName() + newID, err := createNewID(ctx, s.db, tableName, "id") + if err != nil { + return nil, errs.Errorf("failed to create user", err) + } + userModel := &UserModel{ + ID: newID, Username: user.Username, Password: user.GetPasswordHash(), } - _, err := s.db.NewInsert(). + _, err = s.db.NewInsert(). Model(userModel). Exec(ctx) if err != nil { - return nil, errs.ErrUserExists + return nil, errs.Errorf("failed to create user", err) } return userModel.toUser(), nil - } func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error { @@ -203,15 +342,15 @@ func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error { // Delete user shorts err := deleteUserShorts(ctx, tx, user) if err != nil { - return fmt.Errorf("failed to delete user: %w", err) + return errs.Errorf("failed to delete user", err) } // Delete user _, err = tx.NewDelete(). Model(user). - Where("username = ?", user.Username). + Where("id = ?", user.ID). Exec(ctx) if err != nil { - return fmt.Errorf("failed to delete user: %w", err) + return errs.Errorf("failed to delete user", err) } return nil @@ -224,13 +363,13 @@ func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token err := s.db.NewSelect(). Model(tokenModel). Where("value = ?", value). - Relation("User"). Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return nil, errs.ErrTokenDoesNotExist + err = errs.ErrTokenDoesNotExist } - return nil, fmt.Errorf("failed to find token: %w", err) + + return nil, errs.Errorf("failed to find token", err) } return tokenModel.toToken(), nil @@ -242,14 +381,14 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke err := s.db.NewSelect(). Model(tokenModel). - Where("t.id = ?", id). - Relation("User"). + Where("id = ?", id). Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return nil, errs.ErrTokenDoesNotExist + err = errs.ErrTokenDoesNotExist } - return nil, fmt.Errorf("failed to find token: %w", err) + + return nil, errs.Errorf("failed to find token", err) } return tokenModel.toToken(), nil @@ -259,19 +398,12 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) { tokenModels := []*TokenModel{} - // Get user ID from username - userID, err := findUserIDFromUsername(ctx, s.db, user.Username) - if err != nil { - return nil, err - } - - err = s.db.NewSelect(). + err := s.db.NewSelect(). Model(&tokenModels). - Where("user_id = ?", userID). - Relation("User"). + Where("user_id = ?", user.ID). Scan(ctx) if err != nil { - return nil, fmt.Errorf("failed to list tokens: %w", err) + return nil, errs.Errorf("failed to list tokens", err) } tokens := []*models.Token{} @@ -283,26 +415,27 @@ func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*mode } func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) { - - // Get user ID from username - user, err := findUser(ctx, s.db, token.User.Username) + // Create a new ID + tableName := s.db.NewSelect(). + Model((*TokenModel)(nil)). + GetTableName() + newID, err := createNewID(ctx, s.db, tableName, "id") if err != nil { - return nil, err + return nil, errs.Errorf("failed to create token", err) } tokenModel := &TokenModel{ - ID: token.ID, + ID: newID, Name: token.Name, Value: token.Value, - UserID: &user.ID, - User: user, + UserID: token.UserID, } _, err = s.db.NewInsert(). Model(tokenModel). Exec(ctx) if err != nil { - return nil, errs.ErrTokenExists + return nil, errs.Errorf("failed to create token", err) } return tokenModel.toToken(), nil @@ -314,12 +447,27 @@ func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error Where("id = ?", token.ID). Exec(ctx) if err != nil { - return fmt.Errorf("failed to delete token: %w", err) + return errs.Errorf("failed to delete token", err) } return nil } +func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) { + newToken := new(TokenModel) + + _, err := s.db.NewUpdate(). + Model((*TokenModel)(nil)). + Set("name = ?", name). + Where("id = ?", token.ID). + Exec(ctx, newToken) + if err != nil { + return nil, errs.Errorf("failed to change token name", err) + } + + return newToken.toToken(), nil +} + func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) { userModel := new(UserModel) @@ -329,31 +477,31 @@ func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, err Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return nil, errs.ErrUserDoesNotExist + err = errs.ErrUserDoesNotExist } - return nil, fmt.Errorf("failed to find user: %w", err) + + return nil, errs.Errorf("failed to find user", err) } return userModel, nil } -func findUserIDFromUsername(ctx context.Context, db bun.IDB, username string) (*int64, error) { - var userID int64 +func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) { + userModel := new(UserModel) err := db.NewSelect(). - Table("users"). - Column("id"). - Where("username = ?", username). - Scan(ctx, &userID) + Model(userModel). + Where("id = ?", id). + Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return nil, errs.ErrUserDoesNotExist + err = errs.ErrUserDoesNotExist } - return nil, fmt.Errorf("failed to get user ID: %w", err) + return nil, errs.Errorf("failed to find user", err) } - return &userID, nil + return userModel, nil } func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery { @@ -363,19 +511,46 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery { } func deleteUserShorts(ctx context.Context, db bun.IDB, user *models.User) error { - userID, err := findUserIDFromUsername(ctx, db, user.Username) - if err != nil { - return fmt.Errorf("failed to delete user shorts: %w", err) - } - - _, err = withShortDeleteUpdates( + _, err := withShortDeleteUpdates( db.NewUpdate(). Model((*ShortModel)(nil)). - Where("user_id = ?", userID), + Where("user_id = ?", user.ID), ).Exec(ctx) if err != nil { - return fmt.Errorf("failed to delete user shorts: %w", err) + return errs.Errorf("failed to delete user shorts", err) } return nil } + +func createNewID(ctx context.Context, db bun.IDB, table, col string) (string, error) { + var newID string + + // Generate an ID that does not exist + maxIters := 10 + for { + // Make sure we don't get stuck in an infinite loop + maxIters-- + if maxIters <= 0 { + return "", errs.Errorf("failed to create unique ID", ErrDatabaseError) + } + + // Create a new ID and check if it exists + newID = models.NewID() + count, err := db.NewSelect(). + Table(table). + Column(col). + Where(col+" = ?", newID). + Count(ctx) + if err != nil { + return "", errs.Errorf("failed to create unique ID", err) + } + + // If the ID does not exist, break the loop + if count == 0 { + break + } + } + + return newID, nil +} diff --git a/internal/storage/memory/memory.go b/internal/storage/memory/memory.go index 2b9bd7e..8e6f53e 100644 --- a/internal/storage/memory/memory.go +++ b/internal/storage/memory/memory.go @@ -3,9 +3,11 @@ package memorystorage import ( "context" "fmt" + "runtime" "sync" "time" + "git.maronato.dev/maronato/goshort/internal/config" "git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/storage" "git.maronato.dev/maronato/goshort/internal/storage/models" @@ -13,87 +15,170 @@ import ( // MemoryStorage is a storage that stores everything in memory. type MemoryStorage struct { + debug bool storage.Storage - shortMu sync.RWMutex - userMu sync.RWMutex - tokenMu sync.RWMutex - shortMap map[string]*models.Short - userMap map[string]*models.User - tokenMap map[string]*models.Token - tokenIDMap map[string]*models.Token + shortMu sync.RWMutex + shortIDMu sync.RWMutex + shortLogMu sync.RWMutex + userMu sync.RWMutex + userIDMu sync.RWMutex + tokenMu sync.RWMutex + tokenIDMu sync.RWMutex + shortMap map[string]*models.Short + shortIDMap map[string]*models.Short + shortLogMap map[string][]*models.ShortLog + userMap map[string]*models.User + userIDMap map[string]*models.User + tokenMap map[string]*models.Token + tokenIDMap map[string]*models.Token } // NewMemoryStorage creates a new MemoryStorage. -func NewMemoryStorage() storage.Storage { +func NewMemoryStorage(cfg *config.Config) *MemoryStorage { return &MemoryStorage{ - shortMap: make(map[string]*models.Short), - userMap: make(map[string]*models.User), - tokenMap: make(map[string]*models.Token), - tokenIDMap: make(map[string]*models.Token), + debug: cfg.Debug, + shortMap: make(map[string]*models.Short), + shortIDMap: make(map[string]*models.Short), + userMap: make(map[string]*models.User), + userIDMap: make(map[string]*models.User), + tokenMap: make(map[string]*models.Token), + tokenIDMap: make(map[string]*models.Token), + shortLogMap: make(map[string][]*models.ShortLog), + } +} + +// logPerformance is a helper function to log the performance of a function. +func (s *MemoryStorage) logPerformance() func() { + start := time.Now() + + return func() { + elapsed := time.Since(start) + if s.debug { + pc, _, _, ok := runtime.Caller(1) + if !ok { + return + } + + method := runtime.FuncForPC(pc).Name() + fmt.Printf("%s took %s\n", method, elapsed) + } } } // Start starts the storage. func (s *MemoryStorage) Start(ctx context.Context) error { + logPerf := s.logPerformance() + defer logPerf() + return nil } // Stop stops the storage. func (s *MemoryStorage) Stop(ctx context.Context) error { + logPerf := s.logPerformance() + defer logPerf() + return nil } // FindShort finds a short in the storage. func (s *MemoryStorage) FindShort(ctx context.Context, name string) (*models.Short, error) { + logPerf := s.logPerformance() + defer logPerf() + s.shortMu.RLock() defer s.shortMu.RUnlock() short, ok := s.shortMap[name] if !ok { - return short, errs.ErrShortDoesNotExist + return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist) + } + + return short, nil +} + +// FindShortByID finds a short in the storage. +func (s *MemoryStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) { + logPerf := s.logPerformance() + defer logPerf() + + s.shortIDMu.RLock() + defer s.shortIDMu.RUnlock() + + short, ok := s.shortIDMap[id] + if !ok { + return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist) } return short, nil } func (s *MemoryStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) { + logPerf := s.logPerformance() + defer logPerf() + s.shortMu.Lock() defer s.shortMu.Unlock() + s.shortIDMu.Lock() + defer s.shortIDMu.Unlock() _, ok := s.shortMap[short.Name] if ok { - return &models.Short{}, errs.ErrShortExists + return nil, errs.Errorf("failed to create short", errs.ErrShortExists) } short.CreatedAt = time.Now() + // Create a new ID + for { + short.ID = models.NewID() + + _, ok := s.shortIDMap[short.ID] + if !ok { + break + } + } + s.shortMap[short.Name] = short + s.shortIDMap[short.ID] = short return short, nil } func (s *MemoryStorage) DeleteShort(ctx context.Context, short *models.Short) error { + logPerf := s.logPerformance() + defer logPerf() + s.shortMu.Lock() defer s.shortMu.Unlock() + s.shortIDMu.Lock() + defer s.shortIDMu.Unlock() + s.shortLogMu.Lock() + defer s.shortLogMu.Unlock() - _, ok := s.shortMap[short.Name] + _, ok := s.shortIDMap[short.ID] if !ok { - return errs.ErrShortDoesNotExist + return errs.Errorf("failed to delete short", errs.ErrShortDoesNotExist) } delete(s.shortMap, short.Name) + delete(s.shortIDMap, short.ID) + delete(s.shortLogMap, short.ID) return nil } func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) { - s.shortMu.RLock() - defer s.shortMu.RUnlock() + logPerf := s.logPerformance() + defer logPerf() + + s.shortIDMu.RLock() + defer s.shortIDMu.RUnlock() shorts := []*models.Short{} - for _, short := range s.shortMap { - if short.User != nil && short.User.Username == user.Username { + for _, short := range s.shortIDMap { + if short.UserID != nil && *short.UserID == user.ID { shorts = append(shorts, short) } } @@ -101,108 +186,203 @@ func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*m return shorts, nil } +func (s *MemoryStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error { + logPerf := s.logPerformance() + defer logPerf() + + s.shortLogMu.Lock() + defer s.shortLogMu.Unlock() + + // Check if short exists + short, err := s.FindShortByID(ctx, shortLog.ShortID) + if err != nil { + return fmt.Errorf("could not create short log: %w", err) + } + + // Create + _, ok := s.shortLogMap[short.ID] + if !ok { + s.shortLogMap[shortLog.ShortID] = []*models.ShortLog{} + } + + shortLog.CreatedAt = time.Now() + + // Create a new ID (we can't really know if it's unique, but it's unlikely to happen) + shortLog.ID = models.NewID() + + s.shortLogMap[shortLog.ShortID] = append(s.shortLogMap[shortLog.ShortID], shortLog) + + return nil +} + +func (s *MemoryStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) { + logPerf := s.logPerformance() + defer logPerf() + + s.shortLogMu.RLock() + defer s.shortLogMu.RUnlock() + + shortLogs, ok := s.shortLogMap[short.ID] + if !ok { + return []*models.ShortLog{}, nil + } + + return shortLogs, nil +} + func (s *MemoryStorage) FindUser(ctx context.Context, username string) (*models.User, error) { + logPerf := s.logPerformance() + defer logPerf() + s.userMu.RLock() defer s.userMu.RUnlock() user, ok := s.userMap[username] if !ok { - return user, errs.ErrUserDoesNotExist + return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist) + } + + return user, nil +} + +func (s *MemoryStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) { + logPerf := s.logPerformance() + defer logPerf() + + s.userIDMu.RLock() + defer s.userIDMu.RUnlock() + + user, ok := s.userIDMap[id] + if !ok { + return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist) } return user, nil } func (s *MemoryStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) { + logPerf := s.logPerformance() + defer logPerf() + s.userMu.Lock() defer s.userMu.Unlock() + s.userIDMu.Lock() + defer s.userIDMu.Unlock() _, ok := s.userMap[user.Username] if ok { - return &models.User{}, errs.ErrUserExists + return nil, errs.Errorf("failed to create user", errs.ErrUserExists) } user.CreatedAt = time.Now() + // Create a new ID + for { + user.ID = models.NewID() + + _, ok := s.userIDMap[user.ID] + if !ok { + break + } + } + s.userMap[user.Username] = user + s.userIDMap[user.ID] = user return user, nil } func (s *MemoryStorage) DeleteUser(ctx context.Context, user *models.User) error { + logPerf := s.logPerformance() + defer logPerf() + s.userMu.Lock() defer s.userMu.Unlock() + s.userIDMu.Lock() + defer s.userIDMu.Unlock() + + errormsg := "failed to delete user" _, ok := s.userMap[user.Username] if !ok { - return errs.ErrUserDoesNotExist + return errs.Errorf(errormsg, errs.ErrUserDoesNotExist) } // Find all user shorts shorts, err := s.ListShorts(ctx, user) if err != nil { - return fmt.Errorf("could not list user shorts: %w", err) + return errs.Errorf(errormsg, errs.Errorf("could not list user shorts", err)) } // Delete all user shorts for _, short := range shorts { err = s.DeleteShort(ctx, short) if err != nil { - return fmt.Errorf("could not delete user short: %w", err) + return errs.Errorf(errormsg, errs.Errorf("could not delete user short", err)) } } // Find all user tokens tokens, err := s.ListTokens(ctx, user) if err != nil { - return fmt.Errorf("could not list user tokens: %w", err) + return errs.Errorf(errormsg, errs.Errorf("could not list user tokens", err)) } // Delete all user tokens for _, token := range tokens { err = s.DeleteToken(ctx, token) if err != nil { - return fmt.Errorf("could not delete user token: %w", err) + return errs.Errorf(errormsg, errs.Errorf("could not delete user token", err)) } } delete(s.userMap, user.Username) + delete(s.userIDMap, user.ID) return nil } func (s *MemoryStorage) FindToken(ctx context.Context, value string) (*models.Token, error) { + logPerf := s.logPerformance() + defer logPerf() + s.tokenMu.RLock() defer s.tokenMu.RUnlock() token, ok := s.tokenMap[value] if !ok { - return token, errs.ErrTokenDoesNotExist + return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist) } return token, nil } func (s *MemoryStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) { - s.tokenMu.RLock() - defer s.tokenMu.RUnlock() + logPerf := s.logPerformance() + defer logPerf() + + s.tokenIDMu.RLock() + defer s.tokenIDMu.RUnlock() token, ok := s.tokenIDMap[id] if !ok { - return token, errs.ErrTokenDoesNotExist + return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist) } return token, nil } func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) { - s.tokenMu.RLock() - defer s.tokenMu.RUnlock() + logPerf := s.logPerformance() + defer logPerf() + + s.tokenIDMu.RLock() + defer s.tokenIDMu.RUnlock() tokens := []*models.Token{} - for _, token := range s.tokenMap { - if token.User != nil && token.User.Username == user.Username { + for _, token := range s.tokenIDMap { + if token.UserID != nil && *token.UserID == user.ID { tokens = append(tokens, token) } } @@ -211,20 +391,35 @@ func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*m } func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) { + logPerf := s.logPerformance() + defer logPerf() + s.tokenMu.Lock() defer s.tokenMu.Unlock() + s.tokenIDMu.Lock() + defer s.tokenIDMu.Unlock() _, ok := s.tokenMap[token.Value] if ok { - return &models.Token{}, errs.ErrTokenExists + return nil, errs.Errorf("failed to create token", errs.ErrTokenExists) } _, ok = s.tokenIDMap[token.ID] if ok { - return &models.Token{}, errs.ErrTokenExists + return nil, errs.Errorf("failed to create token", errs.ErrTokenExists) } token.CreatedAt = time.Now() + // Create a new ID + for { + token.ID = models.NewID() + + _, ok := s.tokenIDMap[token.ID] + if !ok { + break + } + } + s.tokenMap[token.Value] = token s.tokenIDMap[token.ID] = token @@ -232,16 +427,21 @@ func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (* } func (s *MemoryStorage) DeleteToken(ctx context.Context, token *models.Token) error { + logPerf := s.logPerformance() + defer logPerf() + s.tokenMu.Lock() defer s.tokenMu.Unlock() + s.tokenIDMu.Lock() + defer s.tokenIDMu.Unlock() _, ok := s.tokenMap[token.Value] if !ok { - return errs.ErrTokenDoesNotExist + return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist) } _, ok = s.tokenIDMap[token.ID] if !ok { - return errs.ErrTokenDoesNotExist + return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist) } delete(s.tokenMap, token.Value) diff --git a/internal/storage/models/shared.go b/internal/storage/models/shared.go new file mode 100644 index 0000000..9897ea6 --- /dev/null +++ b/internal/storage/models/shared.go @@ -0,0 +1,12 @@ +package models + +import tokenutil "git.maronato.dev/maronato/goshort/internal/util/token" + +const ( + // IDLength is the length of IDs. + IDLength = 16 +) + +func NewID() string { + return tokenutil.GenerateSecureToken(IDLength / 2) +} diff --git a/internal/storage/models/short.go b/internal/storage/models/short.go index 03ad268..a00888a 100644 --- a/internal/storage/models/short.go +++ b/internal/storage/models/short.go @@ -2,14 +2,20 @@ package models import "time" +const ( + ShortIDLength = 16 +) + type Short struct { + // ID is the unique identifier of the short. + ID string `json:"id"` // Name is the shortened name of the URL. Name string `json:"name,omitempty"` // URL is the URL that is shortened. URL string `json:"url"` // CreatedAt is the time the short was created. - CreatedAt time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` - // User is the user that created the short. - User *User `json:"-"` + // UserID of the user that created the short. + UserID *string `json:"-"` } diff --git a/internal/storage/models/shortlog.go b/internal/storage/models/shortlog.go new file mode 100644 index 0000000..acc908b --- /dev/null +++ b/internal/storage/models/shortlog.go @@ -0,0 +1,18 @@ +package models + +import "time" + +type ShortLog struct { + // ID is the unique identifier of the short log. + ID string `json:"id"` + // ShortID is the ID of the short that was accessed. + ShortID string `json:"shortID"` + // IPAddress is the IP address of the client that accessed the short. + IPAddress string `json:"ipAddress"` + // UserAgent is the User-Agent of the client that accessed the short. + UserAgent string `json:"userAgent"` + // Referer is the referer of the client that accessed the short. + Referer string `json:"referer"` + // CreatedAt is the time the short was accessed. + CreatedAt time.Time `json:"createdAt,omitempty"` +} diff --git a/internal/storage/models/token.go b/internal/storage/models/token.go index cab0855..bbe9116 100644 --- a/internal/storage/models/token.go +++ b/internal/storage/models/token.go @@ -12,6 +12,6 @@ type Token struct { // CreatedAt is when the token was created (initialized by the storage) CreatedAt time.Time `json:"createdAt"` - // User is the user that created the token. - User *User `json:"-"` + // UserID of the user that created the token. + UserID *string `json:"-"` } diff --git a/internal/storage/models/user.go b/internal/storage/models/user.go index da946b6..bf13029 100644 --- a/internal/storage/models/user.go +++ b/internal/storage/models/user.go @@ -8,11 +8,15 @@ import ( ) type User struct { + // ID is the user's ID. + ID string `json:"id"` + // Username is the user's username. Username string `json:"username"` + // password is the user's password. password string `json:"-"` // CreatedAt is the time the user was created. - CreatedAt time.Time `json:"created_at,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` } // NewAuthenticatableUser is a elper function for storages that takes @@ -20,6 +24,7 @@ type User struct { func NewAuthenticatableUser(user *User, hashedPass string) *User { return &User{ + ID: user.ID, Username: user.Username, CreatedAt: user.CreatedAt, password: hashedPass, diff --git a/internal/storage/sqlite/sqlite.go b/internal/storage/sqlite/sqlite.go index ef7a503..3f62662 100644 --- a/internal/storage/sqlite/sqlite.go +++ b/internal/storage/sqlite/sqlite.go @@ -1,9 +1,12 @@ package sqlitestorage import ( + "context" "database/sql" + "time" "git.maronato.dev/maronato/goshort/internal/config" + "git.maronato.dev/maronato/goshort/internal/errs" "git.maronato.dev/maronato/goshort/internal/storage" bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun" "github.com/uptrace/bun" @@ -13,13 +16,74 @@ import ( // NewSQLiteStorage creates a new SQLite storage. func NewSQLiteStorage(cfg *config.Config) storage.Storage { - sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL) + // Create a new SQLite database with the following pragmas enabled: + // - journal_mode=WAL: Enables Write-Ahead Logging, which allows for concurrent reads and writes. + // - foreign_keys=ON: Enables foreign key constraints. + // - synchronous=NORMAL: Enables synchronous mode NORMAL + sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL+"?_pragma=foreign_keys(ON)&_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)") if err != nil { panic(err) } + // If running the DB in memory, make sure + // database/sql does not close idle connections. + // Otherwise, the database will be lost. + if cfg.DBType == config.DBTypeMemory { + sqldb.SetMaxIdleConns(1000) + sqldb.SetConnMaxLifetime(0) + } + db := bun.NewDB(sqldb, sqlitedialect.New()) // Use Bun as the storage implementation. - return bunstorage.NewBunStorage(cfg, db) + storage := bunstorage.NewBunStorage(cfg, db) + + // Config vacuuming of the database every hour. + ticker := time.NewTicker(time.Hour) + tickerCtx, tickerCancel := context.WithCancel(context.Background()) + // Start vacuuming the database when the storage is started. + storage.RegisterStartHook(func(ctx context.Context, activeDB *bun.DB) error { + // Vacuum now + _, err := activeDB.NewRaw("VACUUM").Exec(ctx) + if err != nil { + return errs.Errorf("failed to vacuum database: %w", err) + } + + // Vacuum every hour, until the context is canceled. + go func() { + for { + select { + case <-ticker.C: + _, err := activeDB.NewRaw("VACUUM").Exec(ctx) + if err != nil { + panic(errs.Errorf("failed to vacuum database: %w", err)) + } + case <-ctx.Done(): + return + case <-tickerCtx.Done(): + return + } + } + }() + + return nil + }) + // So some cleaning up and optimization when the storage is stopped. + storage.RegisterStopHook(func(ctx context.Context, activeDB *bun.DB) error { + // Stop the ticker and cancel the ticker context. + ticker.Stop() + tickerCancel() + // Run PRAGMA optimize to optimize the database. + // By this point, the passed context is likely cancelled, so we use the + // background context since otherwise the query will fail. + // This will delay the shutdown of the storage, but it'll be for a good cause. + // The user can force the shutdown by using SIGINT or SIGTERM. + _, err := activeDB.NewRaw("PRAGMA optimize").Exec(context.Background()) + if err != nil { + return errs.Errorf("failed to optimize database: %w", err) + } + return nil + }) + + return storage } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 40b961b..ba753f4 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -14,6 +14,8 @@ type Storage interface { // FindShort finds a short in the storage using its name. FindShort(ctx context.Context, name string) (*models.Short, error) + // FindShortByID finds a short in the storage using its ID. + FindShortByID(ctx context.Context, id string) (*models.Short, error) // FindShorts finds all shorts in the storage that belong to a user. ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) // CreateShort creates a short in the storage. @@ -21,10 +23,19 @@ type Storage interface { // DeleteShort deletes a short from the storage. DeleteShort(ctx context.Context, short *models.Short) error + // ShortLog Storage + + // CreateShortLog creates a short log in the storage. + CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error + // ListShortLogs finds all short logs in the storage that belong to a short. + ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) + // User Storage // FindUser finds a user in the storage using its username. FindUser(ctx context.Context, username string) (*models.User, error) + // FindUserByID finds a user in the storage using its ID. + FindUserByID(ctx context.Context, id string) (*models.User, error) // CreateUser creates a user in the storage. CreateUser(ctx context.Context, user *models.User) (*models.User, error) // DeleteUser deletes a user and all their shorts from the storage.