short logs and better database access
This commit is contained in:
parent
b143221181
commit
16bb373c60
|
@ -1,5 +1,7 @@
|
|||
# ---> DB
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
|
||||
# ---> Go
|
||||
# If you prefer the allow list template instead of the deny list, see community template:
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# ---> DB
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
|
||||
# ---> Go
|
||||
# If you prefer the allow list template instead of the deny list, see community template:
|
||||
|
@ -157,4 +159,4 @@ dist
|
|||
.yarn/unplugged
|
||||
.yarn/build-state.yml
|
||||
.yarn/install-state.gz
|
||||
.pnp.*
|
||||
.pnp.*
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
memorystorage "git.maronato.dev/maronato/goshort/internal/storage/memory"
|
||||
sqlitestorage "git.maronato.dev/maronato/goshort/internal/storage/sqlite"
|
||||
"github.com/peterbourgon/ff/v3"
|
||||
)
|
||||
|
@ -60,7 +59,8 @@ func RegisterServerFlags(fs *flag.FlagSet, cfg *config.Config) {
|
|||
func InitStorage(cfg *config.Config) storage.Storage {
|
||||
switch cfg.DBType {
|
||||
case config.DBTypeMemory:
|
||||
return memorystorage.NewMemoryStorage()
|
||||
cfg.DBURL = ":memory:"
|
||||
return sqlitestorage.NewSQLiteStorage(cfg)
|
||||
case config.DBTypeSQLite:
|
||||
return sqlitestorage.NewSQLiteStorage(cfg)
|
||||
default:
|
||||
|
|
|
@ -10,14 +10,22 @@ const ItemList = <T extends Record<string, unknown>, K extends keyof T>({
|
|||
idKey: K
|
||||
Item: FunctionComponent<T>
|
||||
}) => {
|
||||
console.log("list")
|
||||
return (
|
||||
<ul
|
||||
role="list"
|
||||
className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden">
|
||||
{items.map((item) => (
|
||||
<Item {...item} key={item[idKey] as string} />
|
||||
))}
|
||||
</ul>
|
||||
<>
|
||||
<ul
|
||||
role="list"
|
||||
className="divide-y divide-gray-100 rounded-lg shadow-md overflow-hidden">
|
||||
{items.map((item) => (
|
||||
<Item {...item} key={item[idKey] as string} />
|
||||
))}
|
||||
</ul>
|
||||
{items.length === 0 && (
|
||||
<div className="text-center pt-5 text-xl font-light">
|
||||
Nothing here yet
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { FunctionComponent, useCallback } from "react"
|
||||
import { FunctionComponent, memo, useCallback } from "react"
|
||||
|
||||
import {
|
||||
ArrowRightIcon,
|
||||
|
@ -6,6 +6,7 @@ import {
|
|||
} from "@heroicons/react/24/outline"
|
||||
|
||||
import { useDelete } from "../hooks/useCRUD"
|
||||
import { useShortLogs, useVisitMetrics } from "../hooks/useStats"
|
||||
import { Short } from "../types"
|
||||
|
||||
import ItemBase from "./ItemBase"
|
||||
|
@ -24,17 +25,14 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
|
|||
|
||||
// Handle deletion
|
||||
const [deleting, del] = useDelete()
|
||||
const doDelete = useCallback(
|
||||
() => del({ name: short.name }),
|
||||
[del, short.name]
|
||||
)
|
||||
const doDelete = useCallback(() => del({ id: short.id }), [del, short.id])
|
||||
|
||||
return (
|
||||
<ItemBase
|
||||
copyString={shortNameURL}
|
||||
doDelete={doDelete}
|
||||
deleting={deleting}
|
||||
detailsPage={`/sht/${short.name}`}>
|
||||
detailsPage={`/sht/${short.id}`}>
|
||||
<div className="grid md:grid-cols-12 md:grid-rows-1 w-full grid-flow-row md:grid-flow-col">
|
||||
<div className="col-span-5 md:col-span-3 my-auto flex flex-col order-1">
|
||||
<a
|
||||
|
@ -69,16 +67,111 @@ const ShortItem: FunctionComponent<Short> = ({ ...short }) => {
|
|||
</a>
|
||||
</div>
|
||||
<div className="col-span-5 md:col-span-2 shrink-0 flex flex-col items-end my-auto order-2 md:order-4">
|
||||
<p className="text-sm font-bold leading-6 text-gray-900">
|
||||
<span className="text-green-500">1000</span> views
|
||||
</p>
|
||||
<p className="mt-1 text-xs leading-5 text-slate-400">
|
||||
Last viewed <time dateTime="2023-01-23T13:23Z">3h ago</time>
|
||||
</p>
|
||||
<ShortItemStatsMemo key={short.id} id={short.id} />
|
||||
</div>
|
||||
</div>
|
||||
</ItemBase>
|
||||
)
|
||||
}
|
||||
|
||||
export default ShortItem
|
||||
export default memo(ShortItem, (prev, next) => prev.id === next.id)
|
||||
|
||||
const ShortItemStats: FunctionComponent<Pick<Short, "id">> = ({ id }) => {
|
||||
const { loading, error, logs } = useShortLogs(id)
|
||||
const { lastVisit, visits } = useVisitMetrics(logs)
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<p className="text-red-500 text-center font-medium">
|
||||
Failed to load stats
|
||||
</p>
|
||||
)
|
||||
}
|
||||
|
||||
const lastViewedISO = lastVisit
|
||||
? new Date(lastVisit).toISOString()
|
||||
: undefined
|
||||
const lastViewedRelative = getRelativeTimeString(lastVisit)
|
||||
|
||||
return (
|
||||
<>
|
||||
<p className="text-sm font-bold leading-6 text-gray-900">
|
||||
{loading && "Loading..."}
|
||||
{loading || (
|
||||
<>
|
||||
<span className="text-green-500">{visits}</span> views
|
||||
</>
|
||||
)}
|
||||
</p>
|
||||
<p className="mt-1 text-xs leading-5 text-slate-400 flex flex-col text-right">
|
||||
{loading && "Loading..."}
|
||||
{loading || (
|
||||
<>
|
||||
<span className="-mb-1">Last viewed:</span>
|
||||
<span className="font-medium">
|
||||
<time dateTime={lastViewedISO}>{lastViewedRelative}</time>
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</p>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
const ShortItemStatsMemo = memo(
|
||||
ShortItemStats,
|
||||
(prev, next) => prev.id === next.id
|
||||
)
|
||||
|
||||
/**
|
||||
* Convert a date to a relative time string, such as
|
||||
* "a minute ago", "in 2 hours", "yesterday", "3 months ago", etc.
|
||||
* using Intl.RelativeTimeFormat
|
||||
* https://gist.github.com/LewisJEllis
|
||||
*/
|
||||
function getRelativeTimeString(date: Date | number | null): string {
|
||||
if (!date) return "never"
|
||||
// Allow dates or times to be passed
|
||||
const timeMs = typeof date === "number" ? date : date.getTime()
|
||||
|
||||
// Get the amount of seconds between the given date and now
|
||||
const deltaSeconds = Math.round((timeMs - Date.now()) / 1000)
|
||||
|
||||
// Array reprsenting one minute, hour, day, week, month, etc in seconds
|
||||
const cutoffs = [
|
||||
60,
|
||||
3600,
|
||||
86400,
|
||||
86400 * 7,
|
||||
86400 * 30,
|
||||
86400 * 365,
|
||||
Infinity,
|
||||
]
|
||||
|
||||
// Array equivalent to the above but in the string representation of the units
|
||||
const units: Intl.RelativeTimeFormatUnit[] = [
|
||||
"second",
|
||||
"minute",
|
||||
"hour",
|
||||
"day",
|
||||
"week",
|
||||
"month",
|
||||
"year",
|
||||
]
|
||||
|
||||
// Grab the ideal cutoff unit
|
||||
const unitIndex = cutoffs.findIndex(
|
||||
(cutoff) => cutoff > Math.abs(deltaSeconds)
|
||||
)
|
||||
|
||||
// Get the divisor to divide from the seconds. E.g. if our unit is "day" our divisor
|
||||
// is one day in seconds, so we can divide our seconds by this to get the # of days
|
||||
const divisor = unitIndex ? cutoffs[unitIndex - 1] : 1
|
||||
|
||||
// Intl.RelativeTimeFormat do its magic
|
||||
const rtf = new Intl.RelativeTimeFormat("en", {
|
||||
numeric: "auto",
|
||||
style: "narrow",
|
||||
})
|
||||
return rtf.format(Math.floor(deltaSeconds / divisor), units[unitIndex])
|
||||
}
|
||||
|
|
|
@ -6,7 +6,8 @@ type Item = Record<string, unknown>
|
|||
|
||||
export const useLoadedItems = <T extends Item>() => {
|
||||
const sourceDefault = useMemo(() => [], [])
|
||||
const items = (useLoaderData() ?? sourceDefault) as T[]
|
||||
const data = useLoaderData() ?? sourceDefault
|
||||
const items = useMemo<T[]>(() => (Array.isArray(data) ? data : []), [data])
|
||||
|
||||
return items
|
||||
}
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
import { useEffect, useMemo, useState } from "react"
|
||||
|
||||
import { ShortLog } from "../types"
|
||||
import fetchAPI from "../util/fetchAPI"
|
||||
|
||||
export const useShortLogs = (shortID: string) => {
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [error, setError] = useState("")
|
||||
const [logs, setLogs] = useState<ShortLog[]>(useMemo(() => [], []))
|
||||
|
||||
// Fetch logs on mount
|
||||
useEffect(() => {
|
||||
const fetchLogs = async () => {
|
||||
// Reset logs and error
|
||||
setLogs([])
|
||||
setError("")
|
||||
// Fetch logs
|
||||
const resp = await fetchAPI<ShortLog[]>(`/shorts/${shortID}/logs`)
|
||||
if (resp.ok) {
|
||||
// If response is ok, set logs
|
||||
if (resp.data.length > 0) {
|
||||
// Sort logs by date, descending
|
||||
resp.data.sort((a, b) => {
|
||||
return b.createdAt.localeCompare(a.createdAt)
|
||||
})
|
||||
// Set sorted logs
|
||||
setLogs(resp.data)
|
||||
}
|
||||
} else {
|
||||
// If response is not ok, set error
|
||||
setError(resp.error)
|
||||
}
|
||||
// Now we're done loading
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
if (shortID) {
|
||||
// If we have an ID, fetch logs
|
||||
setLoading(true)
|
||||
fetchLogs()
|
||||
}
|
||||
}, [shortID])
|
||||
|
||||
return { loading, error, logs }
|
||||
}
|
||||
|
||||
const singleFilterOptions = ["equal", "greater", "lower", "notEqual"] as const
|
||||
export type ShortLogSingleFilterOptions = (typeof singleFilterOptions)[number]
|
||||
|
||||
const listFilterOptions = ["in", "notIn"] as const
|
||||
export type ShortLogListFilterOptions = (typeof listFilterOptions)[number]
|
||||
|
||||
type ShortLogKeys = keyof ShortLog
|
||||
|
||||
type SingleFilter<T extends keyof ShortLog> = {
|
||||
[key in ShortLogSingleFilterOptions]?: ShortLog[T] | null
|
||||
}
|
||||
type ListFilter<T extends keyof ShortLog> = {
|
||||
[key in ShortLogListFilterOptions]?: (ShortLog[T] | null)[]
|
||||
}
|
||||
|
||||
export type ShortLogFilter = {
|
||||
[key in ShortLogKeys]?: SingleFilter<key> | ListFilter<key>
|
||||
}
|
||||
|
||||
export const useFilteredShortLogs = (
|
||||
logs: ShortLog[],
|
||||
filterList: ShortLogFilter[]
|
||||
) => {
|
||||
return useMemo(() => {
|
||||
const start = performance.now()
|
||||
// If there are no filters, return all logs
|
||||
if (filterList.length === 0) return logs
|
||||
// For each log
|
||||
const result = logs.filter((log) => {
|
||||
// If it matches any of the filters, return true
|
||||
return filterList.some((filterSet) => {
|
||||
// For each key in the filterSet
|
||||
return Object.entries(filterSet).every(([key, filterMap]) => {
|
||||
// If the key is not in the log, return false
|
||||
if (!(key in log)) return false
|
||||
// For each filter in the filterMap
|
||||
return Object.entries(filterMap).every(
|
||||
([filterType, filterValue]) => {
|
||||
const logValue = log[key as keyof ShortLog]
|
||||
// If filterType is single
|
||||
if (
|
||||
singleFilterOptions.includes(
|
||||
filterType as unknown as ShortLogSingleFilterOptions
|
||||
)
|
||||
) {
|
||||
switch (filterType) {
|
||||
case "equal":
|
||||
return logValue === filterValue
|
||||
case "greater":
|
||||
return logValue > filterValue
|
||||
case "lower":
|
||||
return logValue < filterValue
|
||||
case "notEqual":
|
||||
return logValue !== filterValue
|
||||
}
|
||||
} else if (
|
||||
listFilterOptions.includes(
|
||||
filterType as unknown as ShortLogListFilterOptions
|
||||
)
|
||||
) {
|
||||
// If filterType is list
|
||||
switch (filterType) {
|
||||
case "in":
|
||||
return filterValue.includes(logValue)
|
||||
case "notIn":
|
||||
return !filterValue.includes(logValue)
|
||||
}
|
||||
}
|
||||
// If we get here,
|
||||
// the filterType is not valid
|
||||
return false
|
||||
}
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
console.debug(
|
||||
`Filtered ${logs.length} logs in ${performance.now() - start}ms`
|
||||
)
|
||||
|
||||
return result
|
||||
}, [logs, filterList])
|
||||
}
|
||||
|
||||
export const useVisitMetrics = (logs: ShortLog[]) => {
|
||||
const [visits, setVisits] = useState(0)
|
||||
const [uniqueVisits, setUniqueVisits] = useState(0)
|
||||
const [lastVisit, setLastVisit] = useState<Date | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
// Reset visits
|
||||
setUniqueVisits(0)
|
||||
setLastVisit(null)
|
||||
|
||||
// Set number of visits
|
||||
setVisits(logs.length)
|
||||
|
||||
// Skip work if there are no logs
|
||||
if (logs.length === 0) return
|
||||
|
||||
// Keep track of unique IPs
|
||||
const uniqueIPs = new Set()
|
||||
|
||||
let last = logs[0].createdAt
|
||||
logs.forEach((log) => {
|
||||
// If the log is older, update
|
||||
if (log.createdAt.localeCompare(last) < 0) {
|
||||
last = log.createdAt
|
||||
}
|
||||
// Add IP to set
|
||||
uniqueIPs.add(log.ipAddress)
|
||||
})
|
||||
|
||||
// Set unique visits
|
||||
setUniqueVisits(uniqueIPs.size)
|
||||
|
||||
// Set last visit
|
||||
setLastVisit(new Date(last))
|
||||
}, [logs])
|
||||
|
||||
return { visits, uniqueVisits, lastVisit }
|
||||
}
|
|
@ -54,10 +54,10 @@ const useSubmitActions = ({
|
|||
useEffect(() => {
|
||||
if (onDelete && formData && actionData && actionData.ok) {
|
||||
const data = actionData.data
|
||||
const name = formData.get("name")
|
||||
if (typeof data === "string" && name && typeof name === "string") {
|
||||
const id = formData.get("id")
|
||||
if (typeof data === "string" && id && typeof id === "string") {
|
||||
// Deleting
|
||||
onDelete(name)
|
||||
onDelete(id)
|
||||
}
|
||||
}
|
||||
}, [formData, actionData, onDelete])
|
||||
|
@ -136,8 +136,8 @@ const useRecentShorts = () => {
|
|||
onCreate: useCallback((short: Short) => {
|
||||
setRecentShorts((prev) => [short, ...prev])
|
||||
}, []),
|
||||
onDelete: useCallback((name: string) => {
|
||||
setRecentShorts((prev) => prev.filter((short) => short.name !== name))
|
||||
onDelete: useCallback((id: string) => {
|
||||
setRecentShorts((prev) => prev.filter((short) => short.id !== id))
|
||||
}, []),
|
||||
})
|
||||
|
||||
|
@ -213,7 +213,7 @@ export const Component: FunctionComponent = () => {
|
|||
</Button>
|
||||
</form>
|
||||
<div className="mt-10">
|
||||
<ItemList items={recentShorts} Item={ShortItem} idKey="name" />
|
||||
<ItemList items={recentShorts} Item={ShortItem} idKey="id" />
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ export const loader: LoaderFunction = async (args) => {
|
|||
const resp = await protectedLoader(args)
|
||||
if (resp) return resp
|
||||
|
||||
const data = await fetchAPI<Short>(`/shorts/${args.params.name}`)
|
||||
const data = await fetchAPI<Short>(`/shorts/${args.params.id}`)
|
||||
if (data.ok) {
|
||||
return data.data
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { FunctionComponent, useCallback } from "react"
|
||||
import { useCallback } from "react"
|
||||
|
||||
import { LoaderFunction, redirect } from "react-router-dom"
|
||||
|
||||
|
@ -16,19 +16,10 @@ export function Component() {
|
|||
useCallback((a, b) => a.name.localeCompare(b.name), [])
|
||||
)
|
||||
|
||||
const Shorts: FunctionComponent = () => {
|
||||
return <ItemList items={items} Item={ShortItem} idKey="name" />
|
||||
}
|
||||
const NoShorts = () => {
|
||||
return (
|
||||
<div className="text-center pt-5 text-xl font-light">No shorts yet</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Header title="Shorts" />
|
||||
{items.length > 0 ? <Shorts /> : <NoShorts />}
|
||||
<ItemList items={items} Item={ShortItem} idKey="id" key="list" />
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
@ -59,13 +50,13 @@ export const action = crudAction({
|
|||
})
|
||||
},
|
||||
DELETE: async (formData) => {
|
||||
const name = formData.get("name") as string | null
|
||||
const id = formData.get("id") as string | null
|
||||
|
||||
if (!name) {
|
||||
if (!id) {
|
||||
return { ok: false, error: "Invalid request" }
|
||||
}
|
||||
|
||||
return fetchAPI<Short>(`/shorts/${name}`, {
|
||||
return fetchAPI<Short>(`/shorts/${id}`, {
|
||||
method: "DELETE",
|
||||
})
|
||||
},
|
||||
|
|
|
@ -52,7 +52,7 @@ export default createBrowserRouter([
|
|||
},
|
||||
{
|
||||
id: "shortDetails",
|
||||
path: "/sht/:name",
|
||||
path: "/sht/:id",
|
||||
lazy: () => import("./pages/ShortDetails"),
|
||||
},
|
||||
{
|
||||
|
|
|
@ -2,16 +2,28 @@ export type GenericItem = Record<string, unknown>
|
|||
|
||||
export type User = {
|
||||
username: string
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
export type Short = {
|
||||
id: string
|
||||
name: string
|
||||
url: string
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
export type ShortLog = {
|
||||
id: string
|
||||
shortID: string
|
||||
ipAddress: string
|
||||
userAgent: string
|
||||
referer: string
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
export type Session = {
|
||||
id: string
|
||||
ip: string
|
||||
ipAddress: string
|
||||
userAgent: string
|
||||
lastActivity: string
|
||||
createdAt: string
|
||||
|
|
|
@ -83,12 +83,12 @@ type Config struct {
|
|||
func Validate(cfg *Config) error {
|
||||
// Host and port have to be valid.
|
||||
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.Port)); err != nil {
|
||||
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid host and/or port: %s", err))
|
||||
return errs.Errorf(fmt.Sprintf("invalid host and/or port: %s", err), errs.ErrInvalidConfig)
|
||||
}
|
||||
// UI port has to be valid.
|
||||
if cfg.UIPort != "" {
|
||||
if _, err := url.ParseRequestURI("http://" + net.JoinHostPort(cfg.Host, cfg.UIPort)); err != nil {
|
||||
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid UI port: %s", err))
|
||||
return errs.Errorf(fmt.Sprintf("invalid UI port: %s", err), errs.ErrInvalidConfig)
|
||||
}
|
||||
}
|
||||
if cfg.DBType != "" {
|
||||
|
@ -101,7 +101,7 @@ func Validate(cfg *Config) error {
|
|||
}
|
||||
}
|
||||
if !valid {
|
||||
return errs.Error(errs.ErrInvalidConfig, fmt.Sprintf("invalid database type: %s", cfg.DBType))
|
||||
return errs.Errorf(fmt.Sprintf("invalid database type: %s", cfg.DBType), errs.ErrInvalidConfig)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,8 +40,10 @@ var (
|
|||
ErrInvalidTokenID = errors.New("invalid token ID")
|
||||
// ErrInvalidTokenName
|
||||
ErrInvalidTokenName = errors.New("invalid token name")
|
||||
// ErrDatabaseError
|
||||
ErrDatabaseError = errors.New("database error")
|
||||
)
|
||||
|
||||
func Error(err error, msg string) error {
|
||||
return fmt.Errorf("%w: %s", err, msg)
|
||||
func Errorf(msg string, err error) error {
|
||||
return fmt.Errorf("%s: %w", msg, err)
|
||||
}
|
||||
|
|
|
@ -194,7 +194,7 @@ func (h *APIHandler) CreateShort(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Set the user
|
||||
short.User = user
|
||||
short.UserID = &user.ID
|
||||
|
||||
// Shorten URL
|
||||
newShort, err := h.shorts.Shorten(ctx, short)
|
||||
|
@ -283,6 +283,28 @@ func (h *APIHandler) DeleteShort(w http.ResponseWriter, r *http.Request) {
|
|||
render.NoContent(w, r)
|
||||
}
|
||||
|
||||
func (h *APIHandler) ListShortLogs(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Find own short or respond
|
||||
short, ok := h.findShortOrRespond(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Get logs
|
||||
logs, err := h.shorts.ListLogs(ctx, short)
|
||||
if err != nil {
|
||||
server.RenderServerError(w, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Render the response
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(w, r, logs)
|
||||
}
|
||||
|
||||
func (h *APIHandler) ListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
|
@ -451,11 +473,11 @@ func (h *APIHandler) findUserOrRespond(w http.ResponseWriter, r *http.Request) (
|
|||
func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request) (short *models.Short, ok bool) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get short name from request
|
||||
name := chi.URLParam(r, "short")
|
||||
// Get short id from request
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
// Find short in storage
|
||||
short, err := h.shorts.FindShort(ctx, name)
|
||||
short, err := h.shorts.FindShortByID(ctx, id)
|
||||
if err != nil {
|
||||
// If the short doesn't exist or is invalid, return not found
|
||||
if errors.Is(err, errs.ErrShortDoesNotExist) || errors.Is(err, errs.ErrInvalidShort) {
|
||||
|
@ -475,7 +497,7 @@ func (h *APIHandler) findShortOrRespond(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
// If the session user does not match the short's user,
|
||||
// return forbidden.
|
||||
if user.Username != short.User.Username {
|
||||
if user.ID != *short.UserID {
|
||||
server.RenderForbidden(w, r)
|
||||
|
||||
return nil, false
|
||||
|
@ -514,7 +536,7 @@ func (h *APIHandler) findTokenOrRespond(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
// If the session user does not match the token's user,
|
||||
// return NotFound.
|
||||
if user.Username != token.User.Username {
|
||||
if user.ID != *token.UserID {
|
||||
server.RenderNotFound(w, r)
|
||||
|
||||
return nil, false
|
||||
|
|
|
@ -26,8 +26,9 @@ func NewAPIRouter(h *APIHandler) http.Handler {
|
|||
// Shorts routes
|
||||
r.Get("/shorts", h.ListShorts)
|
||||
r.Post("/shorts", h.CreateShort)
|
||||
r.Get("/shorts/{short}", h.FindShort)
|
||||
r.Delete("/shorts/{short}", h.DeleteShort)
|
||||
r.Get("/shorts/{id}", h.FindShort)
|
||||
r.Delete("/shorts/{id}", h.DeleteShort)
|
||||
r.Get("/shorts/{id}/logs", h.ListShortLogs)
|
||||
|
||||
// Sessions routes
|
||||
r.Get("/sessions", h.ListSessions)
|
||||
|
|
|
@ -15,11 +15,11 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
sessionUserKey = "user"
|
||||
sessionUserKey = "u"
|
||||
sessionIPKey = "ip"
|
||||
sessionUserAgentKey = "user_agent"
|
||||
sessionLastActivityKey = "last_activity"
|
||||
sessionCreatedAtKey = "created_at"
|
||||
sessionUserAgentKey = "ua"
|
||||
sessionLastActivityKey = "la"
|
||||
sessionCreatedAtKey = "ca"
|
||||
tokenHeader = "Authorization"
|
||||
)
|
||||
|
||||
|
@ -28,8 +28,8 @@ type userContextKey struct{}
|
|||
type AuthSessionData struct {
|
||||
// Username is the username of the user to whom the session belongs.
|
||||
Username string `json:"username"`
|
||||
// IP is the last IP address used by with the session.
|
||||
IP string `json:"ip"`
|
||||
// IPAddress is the last IP address used by with the session.
|
||||
IPAddress string `json:"ipAddress"`
|
||||
// UserAgent is the last User-Agent used with the session.
|
||||
UserAgent string `json:"userAgent"`
|
||||
// LastActivity is the last time the session was used.
|
||||
|
@ -97,7 +97,7 @@ func UserFromCtx(ctx context.Context) (*models.User, bool) {
|
|||
func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context) *AuthSessionData {
|
||||
// Get data from session
|
||||
username := manager.GetString(sessionCtx, sessionUserKey)
|
||||
ip := manager.GetString(sessionCtx, sessionIPKey)
|
||||
ipAddress := manager.GetString(sessionCtx, sessionIPKey)
|
||||
userAgent := manager.GetString(sessionCtx, sessionUserAgentKey)
|
||||
lastActivity, err := time.Parse(time.RFC3339, manager.GetString(sessionCtx, sessionLastActivityKey))
|
||||
if err != nil {
|
||||
|
@ -111,7 +111,7 @@ func SessionDataFromCtx(manager *scs.SessionManager, sessionCtx context.Context)
|
|||
// Create new session data
|
||||
sessionData := &AuthSessionData{
|
||||
Username: username,
|
||||
IP: ip,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LastActivity: lastActivity,
|
||||
CreatedAt: createdAt,
|
||||
|
@ -124,12 +124,12 @@ func UpdateSession(ctx context.Context, r *http.Request) {
|
|||
manager := SessionManagerFromCtx(ctx)
|
||||
|
||||
// Get data from request
|
||||
ip := r.RemoteAddr
|
||||
ipAddress := r.RemoteAddr
|
||||
userAgent := r.UserAgent()
|
||||
lastActivity := time.Now().Format(time.RFC3339)
|
||||
|
||||
// Update session
|
||||
manager.Put(ctx, sessionIPKey, ip)
|
||||
manager.Put(ctx, sessionIPKey, ipAddress)
|
||||
manager.Put(ctx, sessionUserAgentKey, userAgent)
|
||||
manager.Put(ctx, sessionLastActivityKey, lastActivity)
|
||||
}
|
||||
|
@ -298,14 +298,11 @@ func authenticateViaToken(r *http.Request, tokenService *tokenservice.TokenServi
|
|||
return nil, errs.ErrTokenMissing
|
||||
}
|
||||
|
||||
// Get token from storage
|
||||
token, err := tokenService.FindToken(ctx, value)
|
||||
// Get the token's user from storage
|
||||
user, err = tokenService.FindTokenUserFromValue(ctx, value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error authenticating via token: %w", err)
|
||||
return nil, errs.Errorf("could not authenticate user", err)
|
||||
}
|
||||
|
||||
// Get user from token
|
||||
user = token.User
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
|
|
@ -26,12 +26,13 @@ func NewServer(cfg *config.Config) *Server {
|
|||
mux := chi.NewRouter()
|
||||
|
||||
// Register default middlewares
|
||||
mux.Use(middleware.Recoverer)
|
||||
mux.Use(middleware.RequestID)
|
||||
mux.Use(middleware.RealIP)
|
||||
mux.Use(middleware.Logger)
|
||||
mux.Use(middleware.Recoverer)
|
||||
mux.Use(servermiddleware.SessionManager(cfg))
|
||||
mux.Use(middleware.Timeout(config.RequestTimeout))
|
||||
mux.Use(middleware.Compress(5, "application/json"))
|
||||
|
||||
// Create the server
|
||||
srv := &http.Server{
|
||||
|
|
|
@ -2,6 +2,7 @@ package shortserver
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
|
@ -32,7 +33,13 @@ func (h *ShortHandler) FindShort(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
switch {
|
||||
case err == nil:
|
||||
// If there's no error, redirect to the URL
|
||||
// If there's no error, log the access and redirect to the URL
|
||||
err = h.service.LogShortAccess(ctx, short, r)
|
||||
if err != nil {
|
||||
// If there was an error logging the access, print a message and
|
||||
// continue.
|
||||
fmt.Printf("failed to log short access: %v\n", err)
|
||||
}
|
||||
http.Redirect(w, r, short.URL, http.StatusSeeOther)
|
||||
case errors.Is(err, errs.ErrInvalidShort):
|
||||
// If the short name is invalid, do nothing and let the static handler
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
|
||||
|
@ -20,6 +21,8 @@ const (
|
|||
MinShortLength = 4
|
||||
// MaxShortLength is the maximum length of the short URL.
|
||||
MaxShortLength = 20
|
||||
// ShortIDLength is the length of the short ID.
|
||||
ShortIDLength = 16
|
||||
)
|
||||
|
||||
type ShortService struct {
|
||||
|
@ -30,6 +33,33 @@ func NewShortService(db storage.Storage) *ShortService {
|
|||
return &ShortService{db: db}
|
||||
}
|
||||
|
||||
func (s *ShortService) LogShortAccess(ctx context.Context, short *models.Short, r *http.Request) error {
|
||||
// Log the access
|
||||
shortLog := &models.ShortLog{
|
||||
ShortID: short.ID,
|
||||
IPAddress: r.RemoteAddr,
|
||||
UserAgent: r.UserAgent(),
|
||||
Referer: r.Referer(),
|
||||
}
|
||||
|
||||
err := s.db.CreateShortLog(ctx, shortLog)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to log short access: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ShortService) ListLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
|
||||
// Get the logs from storage
|
||||
logs, err := s.db.ListShortLogs(ctx, short)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get short logs from storage: %w", err)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Short, error) {
|
||||
// Check if the short is valid
|
||||
err := ShortNameIsValid(name)
|
||||
|
@ -37,11 +67,26 @@ func (s *ShortService) FindShort(ctx context.Context, name string) (*models.Shor
|
|||
return &models.Short{}, fmt.Errorf("could not validate short: %w", err)
|
||||
}
|
||||
|
||||
// Get the URL from storage
|
||||
// Get the short from storage
|
||||
short, err := s.db.FindShort(ctx, name)
|
||||
|
||||
if err != nil {
|
||||
return short, fmt.Errorf("could not get URL from storage: %w", err)
|
||||
return short, fmt.Errorf("could not get short from storage: %w", err)
|
||||
}
|
||||
|
||||
return short, nil
|
||||
}
|
||||
|
||||
func (s *ShortService) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
|
||||
// Check if the ID is valid
|
||||
if len(id) != ShortIDLength {
|
||||
return &models.Short{}, errs.ErrInvalidShort
|
||||
}
|
||||
|
||||
// Get the short from storage
|
||||
short, err := s.db.FindShortByID(ctx, id)
|
||||
if err != nil {
|
||||
return short, fmt.Errorf("could not get short from storage: %w", err)
|
||||
}
|
||||
|
||||
return short, nil
|
||||
|
@ -116,13 +161,13 @@ var shortRegex = regexp.MustCompile(fmt.Sprintf("^%s$", shortPattern))
|
|||
func ShortNameIsValid(name string) error {
|
||||
if !shortRegex.MatchString(name) {
|
||||
|
||||
return errs.Error(
|
||||
errs.ErrInvalidShort,
|
||||
return errs.Errorf(
|
||||
fmt.Sprintf(
|
||||
"short must use only letters, numbers, underscores and dashes, and be between %d and %d characters long",
|
||||
MinShortLength,
|
||||
MaxShortLength,
|
||||
),
|
||||
errs.ErrInvalidShort,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -132,11 +177,11 @@ func ShortNameIsValid(name string) error {
|
|||
func ShortURLIsValid(shortURL string) error {
|
||||
parsedURL, err := url.ParseRequestURI(shortURL)
|
||||
if err != nil {
|
||||
return errs.Error(errs.ErrInvalidShort, "invalid URL")
|
||||
return errs.Errorf("invalid URL", errs.ErrInvalidShort)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return errs.Error(errs.ErrInvalidShort, "invalid URL scheme")
|
||||
return errs.Errorf("invalid URL scheme", errs.ErrInvalidShort)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
shortutil "git.maronato.dev/maronato/goshort/internal/util/short"
|
||||
tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
|
||||
)
|
||||
|
||||
|
@ -60,7 +59,7 @@ func (s *TokenService) FindToken(ctx context.Context, value string) (*models.Tok
|
|||
// FindTokenByID finds a token in the storage using its ID.
|
||||
func (s *TokenService) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
|
||||
// Check if the ID is valid
|
||||
if len(id) != TokenIDLength {
|
||||
if len(id) != models.IDLength {
|
||||
return &models.Token{}, errs.ErrInvalidTokenID
|
||||
}
|
||||
|
||||
|
@ -86,12 +85,10 @@ func (s *TokenService) ListTokens(ctx context.Context, user *models.User) ([]*mo
|
|||
// CreateToken creates a new token for a user.
|
||||
func (s *TokenService) CreateToken(ctx context.Context, user *models.User) (*models.Token, error) {
|
||||
// Generate a new token
|
||||
id := shortutil.GenerateRandomShort(TokenIDLength)
|
||||
token := &models.Token{
|
||||
ID: id,
|
||||
Name: fmt.Sprintf("%s's token #%s", user.Username, id[:5]),
|
||||
Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2),
|
||||
User: user,
|
||||
Name: fmt.Sprintf("%s's token", user.Username),
|
||||
Value: TokenPrefix + tokenutil.GenerateSecureToken(TokenLength/2),
|
||||
UserID: &user.ID,
|
||||
}
|
||||
|
||||
// Create the token in storage
|
||||
|
@ -128,3 +125,20 @@ func (s *TokenService) ChangeTokenName(ctx context.Context, token *models.Token,
|
|||
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
// FindTokenUserFromValue finds the user that owns a token.
|
||||
func (s *TokenService) FindTokenUserFromValue(ctx context.Context, value string) (*models.User, error) {
|
||||
// Get the token from storage
|
||||
token, err := s.FindToken(ctx, value)
|
||||
if err != nil {
|
||||
return &models.User{}, errs.Errorf("could not get token from storage", err)
|
||||
}
|
||||
|
||||
// Get the user from storage
|
||||
user, err := s.db.FindUserByID(ctx, *token.UserID)
|
||||
if err != nil {
|
||||
return &models.User{}, errs.Errorf("could not get token user from storage", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
|
|
@ -131,7 +131,7 @@ func (s *UserService) SetPassword(ctx context.Context, user *models.User, newPas
|
|||
|
||||
func UsernameIsValid(username string) error {
|
||||
if !UsernameRegex.MatchString(username) {
|
||||
return errs.Error(errs.ErrInvalidUser, fmt.Sprintf("username must match %s", UsernameRegex))
|
||||
return errs.Errorf(fmt.Sprintf("username must match %s", UsernameRegex), errs.ErrInvalidUser)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
package bunstorage
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrDatabaseError is the error returned when a generic database error occurs.
|
||||
ErrDatabaseError = errors.New("database error")
|
||||
)
|
|
@ -7,42 +7,11 @@ import (
|
|||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type ShortModel struct {
|
||||
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
|
||||
|
||||
ID int64 `bun:",pk,autoincrement" json:"id"`
|
||||
// Name is the short's name, which is also the path to access it
|
||||
Name string `bun:",unique,notnull" json:"name"`
|
||||
// URL is the URL that the short will redirect to
|
||||
URL string `bun:",notnull" json:"url"`
|
||||
// Deleted is whether the short is soft-deleted or not
|
||||
Deleted bool `bun:",notnull,default:false" json:"-"`
|
||||
// CreatedAt is when the short was created (initialized by the storage)
|
||||
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
|
||||
// DeletedAt is when the short was deleted
|
||||
DeletedAt time.Time `bun:",null" json:"-"`
|
||||
|
||||
// UserID is the ID of the user that created the short
|
||||
// This can be null if the short was deleted
|
||||
UserID *int64 `json:"-"`
|
||||
// User is the user that created the short
|
||||
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShortModel) toShort() *models.Short {
|
||||
return &models.Short{
|
||||
Name: s.Name,
|
||||
URL: s.URL,
|
||||
CreatedAt: s.CreatedAt,
|
||||
User: s.User.toUser(),
|
||||
}
|
||||
}
|
||||
|
||||
type UserModel struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
|
||||
// ID is the primary key
|
||||
ID int64 `bun:",pk,autoincrement" json:"id"`
|
||||
ID string `bun:",pk,unique,notnull" json:"id"`
|
||||
// Username is the user's username
|
||||
Username string `bun:",unique,notnull" json:"username"`
|
||||
// Password is the user's password
|
||||
|
@ -58,27 +27,60 @@ type UserModel struct {
|
|||
|
||||
func (u *UserModel) toUser() *models.User {
|
||||
return models.NewAuthenticatableUser(&models.User{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
CreatedAt: u.CreatedAt,
|
||||
}, u.Password)
|
||||
}
|
||||
|
||||
type ShortModel struct {
|
||||
bun.BaseModel `bun:"table:shorts,alias:s" json:"-"`
|
||||
|
||||
ID string `bun:",pk,unique,notnull" json:"id"`
|
||||
// Name is the short's name, which is also the path to access it
|
||||
Name string `bun:",unique,notnull" json:"name"`
|
||||
// URL is the URL that the short will redirect to
|
||||
URL string `bun:",notnull" json:"url"`
|
||||
// Deleted is whether the short is soft-deleted or not
|
||||
Deleted bool `bun:",notnull,default:false" json:"-"`
|
||||
// CreatedAt is when the short was created (initialized by the storage)
|
||||
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
|
||||
// DeletedAt is when the short was deleted
|
||||
DeletedAt time.Time `bun:",null" json:"-"`
|
||||
|
||||
// UserID is the ID of the user that created the short
|
||||
// This can be null if the short was deleted
|
||||
UserID *string `json:"-"`
|
||||
// User is the user that created the short
|
||||
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShortModel) toShort() *models.Short {
|
||||
return &models.Short{
|
||||
ID: s.ID,
|
||||
Name: s.Name,
|
||||
URL: s.URL,
|
||||
CreatedAt: s.CreatedAt,
|
||||
UserID: s.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
type TokenModel struct {
|
||||
bun.BaseModel `bun:"table:tokens,alias:t"`
|
||||
|
||||
// ID is the primary key
|
||||
ID string `bun:",pk" json:"id"`
|
||||
ID string `bun:",pk,unique,notnull" json:"id"`
|
||||
// Name is the user-friendly name of the token
|
||||
Name string `bun:",notnull" json:"name"`
|
||||
// Value is the actual token
|
||||
Value string `bun:",unique,notnull" json:"value"`
|
||||
// UserID is the ID of the user that created the token
|
||||
UserID *int64 `bun:",notnull" json:"-"`
|
||||
// User is the user that created the token
|
||||
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
|
||||
|
||||
// CreatedAt is when the token was created (initialized by the storage)
|
||||
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
|
||||
|
||||
// UserID is the ID of the user that created the token
|
||||
UserID *string `bun:",notnull" json:"-"`
|
||||
// User is the user that created the token
|
||||
User *UserModel `bun:"rel:belongs-to,join:user_id=id" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (t *TokenModel) toToken() *models.Token {
|
||||
|
@ -87,6 +89,37 @@ func (t *TokenModel) toToken() *models.Token {
|
|||
Name: t.Name,
|
||||
Value: t.Value,
|
||||
CreatedAt: t.CreatedAt,
|
||||
User: t.User.toUser(),
|
||||
UserID: t.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
type ShortLogModel struct {
|
||||
bun.BaseModel `bun:"table:short_logs,alias:sl"`
|
||||
|
||||
// ID is the primary key
|
||||
ID string `bun:",pk,unique,notnull" json:"id"`
|
||||
// IPAddress is the IP address of the client that accessed the short
|
||||
IPAddress string `bun:"," json:"ipAddress"`
|
||||
// UserAgent is the User-Agent of the client that accessed the short
|
||||
UserAgent string `bun:"," json:"userAgent"`
|
||||
// Referer is the referer of the client that accessed the short
|
||||
Referer string `bun:"," json:"referer"`
|
||||
// CreatedAt is when the short was accessed
|
||||
CreatedAt time.Time `bun:",notnull,default:current_timestamp" json:"createdAt"`
|
||||
|
||||
// ShortID is the ID of the short that was accessed
|
||||
ShortID string `bun:",notnull" json:"-"`
|
||||
// Short is the short that was accessed
|
||||
Short *ShortModel `bun:"rel:belongs-to,join:short_id=id" json:"short,omitempty"`
|
||||
}
|
||||
|
||||
func (sl *ShortLogModel) toShortLog() *models.ShortLog {
|
||||
return &models.ShortLog{
|
||||
ID: sl.ID,
|
||||
IPAddress: sl.IPAddress,
|
||||
UserAgent: sl.UserAgent,
|
||||
Referer: sl.Referer,
|
||||
CreatedAt: sl.CreatedAt,
|
||||
ShortID: sl.ShortID,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,10 @@ package bunstorage
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
|
@ -16,11 +14,14 @@ import (
|
|||
)
|
||||
|
||||
type BunStorage struct {
|
||||
storage.Storage
|
||||
db *bun.DB
|
||||
// StartHooks are ran after the database is connected and the tables are created.
|
||||
StartHooks []func(ctx context.Context, db *bun.DB) error
|
||||
// StopHooks are ran before the database is closed.
|
||||
StopHook []func(ctx context.Context, db *bun.DB) error
|
||||
}
|
||||
|
||||
func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage {
|
||||
func NewBunStorage(cfg *config.Config, db *bun.DB) *BunStorage {
|
||||
if cfg.Debug {
|
||||
db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
|
||||
}
|
||||
|
@ -30,8 +31,18 @@ func NewBunStorage(cfg *config.Config, db *bun.DB) storage.Storage {
|
|||
}
|
||||
}
|
||||
|
||||
// RegisterStartHook registers a hook to be ran after the database is connected and the tables are created.
|
||||
func (s *BunStorage) RegisterStartHook(hook func(ctx context.Context, db *bun.DB) error) {
|
||||
s.StartHooks = append(s.StartHooks, hook)
|
||||
}
|
||||
|
||||
// RegisterStopHook registers a hook to be ran before the database is closed.
|
||||
func (s *BunStorage) RegisterStopHook(hook func(ctx context.Context, db *bun.DB) error) {
|
||||
s.StopHook = append(s.StopHook, hook)
|
||||
}
|
||||
|
||||
func (s *BunStorage) Start(ctx context.Context) error {
|
||||
return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
err := s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
|
||||
_, err := tx.NewCreateTable().
|
||||
IfNotExists().
|
||||
|
@ -39,7 +50,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
|
|||
WithForeignKeys().
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create users table: %w", err)
|
||||
return errs.Errorf("failed to create users table", err)
|
||||
}
|
||||
|
||||
_, err = tx.NewCreateTable().
|
||||
|
@ -48,7 +59,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
|
|||
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE SET NULL`).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create shorts table: %w", err)
|
||||
return errs.Errorf("failed to create shorts table", err)
|
||||
}
|
||||
|
||||
_, err = tx.NewCreateTable().
|
||||
|
@ -56,9 +67,17 @@ func (s *BunStorage) Start(ctx context.Context) error {
|
|||
Model((*TokenModel)(nil)).
|
||||
ForeignKey(`("user_id") REFERENCES "users" ("id") ON DELETE CASCADE`).
|
||||
Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tokens table: %w", err)
|
||||
return errs.Errorf("failed to create tokens table", err)
|
||||
}
|
||||
|
||||
_, err = tx.NewCreateTable().
|
||||
IfNotExists().
|
||||
Model((*ShortLogModel)(nil)).
|
||||
ForeignKey(`("short_id") REFERENCES "shorts" ("id") ON DELETE CASCADE`).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to create short logs table", err)
|
||||
}
|
||||
|
||||
// shorts user_id index
|
||||
|
@ -69,7 +88,7 @@ func (s *BunStorage) Start(ctx context.Context) error {
|
|||
Column("user_id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create shorts user_id index: %w", err)
|
||||
return errs.Errorf("failed to create shorts user_id index", err)
|
||||
}
|
||||
|
||||
// tokens user_id index
|
||||
|
@ -80,14 +99,47 @@ func (s *BunStorage) Start(ctx context.Context) error {
|
|||
Column("user_id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tokens user_id index: %w", err)
|
||||
return errs.Errorf("failed to create tokens user_id index", err)
|
||||
}
|
||||
|
||||
// short_logs short_id index
|
||||
_, err = tx.NewCreateIndex().
|
||||
IfNotExists().
|
||||
Model((*ShortLogModel)(nil)).
|
||||
Index("idx_short_logs_short_id").
|
||||
Column("short_id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to create short_logs short_id index", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to start storage", err)
|
||||
}
|
||||
|
||||
// Run start hooks
|
||||
for _, hook := range s.StartHooks {
|
||||
err := hook(ctx, s.db)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to start storage", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) Stop(ctx context.Context) error {
|
||||
// Run stop hooks
|
||||
for _, hook := range s.StopHook {
|
||||
err := hook(ctx, s.db)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to stop storage", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
|
@ -97,71 +149,95 @@ func (s *BunStorage) FindShort(ctx context.Context, name string) (*models.Short,
|
|||
err := s.db.NewSelect().
|
||||
Model(shortModel).
|
||||
Where("name = ? and deleted = false", name).
|
||||
Relation("User").
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errs.ErrShortDoesNotExist
|
||||
err = errs.ErrShortDoesNotExist
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find short: %w", err)
|
||||
|
||||
return nil, errs.Errorf("failed to find short", err)
|
||||
}
|
||||
|
||||
return shortModel.toShort(), nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
|
||||
shortModel := new(ShortModel)
|
||||
err := s.db.NewSelect().
|
||||
Model(shortModel).
|
||||
Where("id = ? and deleted = false", id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = errs.ErrShortDoesNotExist
|
||||
}
|
||||
|
||||
return nil, errs.Errorf("failed to find short", err)
|
||||
}
|
||||
|
||||
return shortModel.toShort(), nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
|
||||
// Get user from username
|
||||
user, err := findUser(ctx, s.db, short.User.Username)
|
||||
// Generate an ID that does not exist
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*ShortModel)(nil)).
|
||||
GetTableName()
|
||||
newID, err := createNewID(ctx, s.db, tableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Errorf("failed to create short", err)
|
||||
}
|
||||
|
||||
shortModel := &ShortModel{
|
||||
ID: newID,
|
||||
Name: short.Name,
|
||||
URL: short.URL,
|
||||
UserID: &user.ID,
|
||||
User: user,
|
||||
UserID: short.UserID,
|
||||
}
|
||||
|
||||
_, err = s.db.NewInsert().
|
||||
Model(shortModel).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errs.ErrShortExists
|
||||
return nil, errs.Errorf("failed to create short", err)
|
||||
}
|
||||
|
||||
return shortModel.toShort(), nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) DeleteShort(ctx context.Context, short *models.Short) error {
|
||||
_, err := withShortDeleteUpdates(
|
||||
s.db.NewUpdate().
|
||||
Model((*ShortModel)(nil)).
|
||||
Where("name = ?", short.Name),
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete short: %w", err)
|
||||
}
|
||||
return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Delete short logs
|
||||
_, err := tx.NewDelete().
|
||||
Model((*ShortLogModel)(nil)).
|
||||
Where("short_id = ?", short.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to delete short", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Delete short
|
||||
_, err = withShortDeleteUpdates(
|
||||
tx.NewUpdate().
|
||||
Model((*ShortModel)(nil)).
|
||||
Where("id = ?", short.ID),
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to delete short", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
|
||||
|
||||
// Get user ID from username
|
||||
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shortModels := []*ShortModel{}
|
||||
err = s.db.NewSelect().
|
||||
err := s.db.NewSelect().
|
||||
Model(&shortModels).
|
||||
Where("user_id = ? and deleted = false", userID).
|
||||
Relation("User").
|
||||
Where("user_id = ? and deleted = false", user.ID).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list shorts: %w", err)
|
||||
return nil, errs.Errorf("failed to list shorts", err)
|
||||
}
|
||||
|
||||
shorts := []*models.Short{}
|
||||
|
@ -171,6 +247,51 @@ func (s *BunStorage) ListShorts(ctx context.Context, user *models.User) ([]*mode
|
|||
return shorts, nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
|
||||
// Create new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*ShortLogModel)(nil)).
|
||||
GetTableName()
|
||||
newID, err := createNewID(ctx, s.db, tableName, "id")
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to create short log", err)
|
||||
}
|
||||
|
||||
shortLogModel := &ShortLogModel{
|
||||
ID: newID,
|
||||
ShortID: shortLog.ShortID,
|
||||
IPAddress: shortLog.IPAddress,
|
||||
UserAgent: shortLog.UserAgent,
|
||||
Referer: shortLog.Referer,
|
||||
}
|
||||
|
||||
_, err = s.db.NewInsert().
|
||||
Model(shortLogModel).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to create short log", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
|
||||
shortLogModels := []*ShortLogModel{}
|
||||
err := s.db.NewSelect().
|
||||
Model(&shortLogModels).
|
||||
Where("short_id = ?", short.ID).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to list short logs", err)
|
||||
}
|
||||
|
||||
shortLogs := []*models.ShortLog{}
|
||||
for _, shortLogModel := range shortLogModels {
|
||||
shortLogs = append(shortLogs, shortLogModel.toShortLog())
|
||||
}
|
||||
return shortLogs, nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
|
||||
user, err := findUser(ctx, s.db, username)
|
||||
if err != nil {
|
||||
|
@ -180,21 +301,39 @@ func (s *BunStorage) FindUser(ctx context.Context, username string) (*models.Use
|
|||
return user.toUser(), nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
|
||||
user, err := findUserByID(ctx, s.db, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user.toUser(), nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
||||
// Create a new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*UserModel)(nil)).
|
||||
GetTableName()
|
||||
newID, err := createNewID(ctx, s.db, tableName, "id")
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to create user", err)
|
||||
}
|
||||
|
||||
userModel := &UserModel{
|
||||
ID: newID,
|
||||
Username: user.Username,
|
||||
Password: user.GetPasswordHash(),
|
||||
}
|
||||
|
||||
_, err := s.db.NewInsert().
|
||||
_, err = s.db.NewInsert().
|
||||
Model(userModel).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errs.ErrUserExists
|
||||
return nil, errs.Errorf("failed to create user", err)
|
||||
}
|
||||
|
||||
return userModel.toUser(), nil
|
||||
|
||||
}
|
||||
|
||||
func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
|
||||
|
@ -203,15 +342,15 @@ func (s *BunStorage) DeleteUser(ctx context.Context, user *models.User) error {
|
|||
// Delete user shorts
|
||||
err := deleteUserShorts(ctx, tx, user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
return errs.Errorf("failed to delete user", err)
|
||||
}
|
||||
// Delete user
|
||||
_, err = tx.NewDelete().
|
||||
Model(user).
|
||||
Where("username = ?", user.Username).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
return errs.Errorf("failed to delete user", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -224,13 +363,13 @@ func (s *BunStorage) FindToken(ctx context.Context, value string) (*models.Token
|
|||
err := s.db.NewSelect().
|
||||
Model(tokenModel).
|
||||
Where("value = ?", value).
|
||||
Relation("User").
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errs.ErrTokenDoesNotExist
|
||||
err = errs.ErrTokenDoesNotExist
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find token: %w", err)
|
||||
|
||||
return nil, errs.Errorf("failed to find token", err)
|
||||
}
|
||||
|
||||
return tokenModel.toToken(), nil
|
||||
|
@ -242,14 +381,14 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
|
|||
|
||||
err := s.db.NewSelect().
|
||||
Model(tokenModel).
|
||||
Where("t.id = ?", id).
|
||||
Relation("User").
|
||||
Where("id = ?", id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errs.ErrTokenDoesNotExist
|
||||
err = errs.ErrTokenDoesNotExist
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find token: %w", err)
|
||||
|
||||
return nil, errs.Errorf("failed to find token", err)
|
||||
}
|
||||
|
||||
return tokenModel.toToken(), nil
|
||||
|
@ -259,19 +398,12 @@ func (s *BunStorage) FindTokenByID(ctx context.Context, id string) (*models.Toke
|
|||
func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
|
||||
tokenModels := []*TokenModel{}
|
||||
|
||||
// Get user ID from username
|
||||
userID, err := findUserIDFromUsername(ctx, s.db, user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.db.NewSelect().
|
||||
err := s.db.NewSelect().
|
||||
Model(&tokenModels).
|
||||
Where("user_id = ?", userID).
|
||||
Relation("User").
|
||||
Where("user_id = ?", user.ID).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list tokens: %w", err)
|
||||
return nil, errs.Errorf("failed to list tokens", err)
|
||||
}
|
||||
|
||||
tokens := []*models.Token{}
|
||||
|
@ -283,26 +415,27 @@ func (s *BunStorage) ListTokens(ctx context.Context, user *models.User) ([]*mode
|
|||
}
|
||||
|
||||
func (s *BunStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
|
||||
|
||||
// Get user ID from username
|
||||
user, err := findUser(ctx, s.db, token.User.Username)
|
||||
// Create a new ID
|
||||
tableName := s.db.NewSelect().
|
||||
Model((*TokenModel)(nil)).
|
||||
GetTableName()
|
||||
newID, err := createNewID(ctx, s.db, tableName, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.Errorf("failed to create token", err)
|
||||
}
|
||||
|
||||
tokenModel := &TokenModel{
|
||||
ID: token.ID,
|
||||
ID: newID,
|
||||
Name: token.Name,
|
||||
Value: token.Value,
|
||||
UserID: &user.ID,
|
||||
User: user,
|
||||
UserID: token.UserID,
|
||||
}
|
||||
|
||||
_, err = s.db.NewInsert().
|
||||
Model(tokenModel).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errs.ErrTokenExists
|
||||
return nil, errs.Errorf("failed to create token", err)
|
||||
}
|
||||
|
||||
return tokenModel.toToken(), nil
|
||||
|
@ -314,12 +447,27 @@ func (s *BunStorage) DeleteToken(ctx context.Context, token *models.Token) error
|
|||
Where("id = ?", token.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete token: %w", err)
|
||||
return errs.Errorf("failed to delete token", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BunStorage) ChangeTokenName(ctx context.Context, token *models.Token, name string) (*models.Token, error) {
|
||||
newToken := new(TokenModel)
|
||||
|
||||
_, err := s.db.NewUpdate().
|
||||
Model((*TokenModel)(nil)).
|
||||
Set("name = ?", name).
|
||||
Where("id = ?", token.ID).
|
||||
Exec(ctx, newToken)
|
||||
if err != nil {
|
||||
return nil, errs.Errorf("failed to change token name", err)
|
||||
}
|
||||
|
||||
return newToken.toToken(), nil
|
||||
}
|
||||
|
||||
func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, error) {
|
||||
userModel := new(UserModel)
|
||||
|
||||
|
@ -329,31 +477,31 @@ func findUser(ctx context.Context, db bun.IDB, username string) (*UserModel, err
|
|||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errs.ErrUserDoesNotExist
|
||||
err = errs.ErrUserDoesNotExist
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||
|
||||
return nil, errs.Errorf("failed to find user", err)
|
||||
}
|
||||
|
||||
return userModel, nil
|
||||
}
|
||||
|
||||
func findUserIDFromUsername(ctx context.Context, db bun.IDB, username string) (*int64, error) {
|
||||
var userID int64
|
||||
func findUserByID(ctx context.Context, db bun.IDB, id string) (*UserModel, error) {
|
||||
userModel := new(UserModel)
|
||||
|
||||
err := db.NewSelect().
|
||||
Table("users").
|
||||
Column("id").
|
||||
Where("username = ?", username).
|
||||
Scan(ctx, &userID)
|
||||
Model(userModel).
|
||||
Where("id = ?", id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, errs.ErrUserDoesNotExist
|
||||
err = errs.ErrUserDoesNotExist
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get user ID: %w", err)
|
||||
return nil, errs.Errorf("failed to find user", err)
|
||||
}
|
||||
|
||||
return &userID, nil
|
||||
return userModel, nil
|
||||
}
|
||||
|
||||
func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
|
||||
|
@ -363,19 +511,46 @@ func withShortDeleteUpdates(q *bun.UpdateQuery) *bun.UpdateQuery {
|
|||
}
|
||||
|
||||
func deleteUserShorts(ctx context.Context, db bun.IDB, user *models.User) error {
|
||||
userID, err := findUserIDFromUsername(ctx, db, user.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user shorts: %w", err)
|
||||
}
|
||||
|
||||
_, err = withShortDeleteUpdates(
|
||||
_, err := withShortDeleteUpdates(
|
||||
db.NewUpdate().
|
||||
Model((*ShortModel)(nil)).
|
||||
Where("user_id = ?", userID),
|
||||
Where("user_id = ?", user.ID),
|
||||
).Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user shorts: %w", err)
|
||||
return errs.Errorf("failed to delete user shorts", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createNewID(ctx context.Context, db bun.IDB, table, col string) (string, error) {
|
||||
var newID string
|
||||
|
||||
// Generate an ID that does not exist
|
||||
maxIters := 10
|
||||
for {
|
||||
// Make sure we don't get stuck in an infinite loop
|
||||
maxIters--
|
||||
if maxIters <= 0 {
|
||||
return "", errs.Errorf("failed to create unique ID", ErrDatabaseError)
|
||||
}
|
||||
|
||||
// Create a new ID and check if it exists
|
||||
newID = models.NewID()
|
||||
count, err := db.NewSelect().
|
||||
Table(table).
|
||||
Column(col).
|
||||
Where(col+" = ?", newID).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return "", errs.Errorf("failed to create unique ID", err)
|
||||
}
|
||||
|
||||
// If the ID does not exist, break the loop
|
||||
if count == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return newID, nil
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package memorystorage
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage/models"
|
||||
|
@ -13,87 +15,170 @@ import (
|
|||
|
||||
// MemoryStorage is a storage that stores everything in memory.
|
||||
type MemoryStorage struct {
|
||||
debug bool
|
||||
storage.Storage
|
||||
shortMu sync.RWMutex
|
||||
userMu sync.RWMutex
|
||||
tokenMu sync.RWMutex
|
||||
shortMap map[string]*models.Short
|
||||
userMap map[string]*models.User
|
||||
tokenMap map[string]*models.Token
|
||||
tokenIDMap map[string]*models.Token
|
||||
shortMu sync.RWMutex
|
||||
shortIDMu sync.RWMutex
|
||||
shortLogMu sync.RWMutex
|
||||
userMu sync.RWMutex
|
||||
userIDMu sync.RWMutex
|
||||
tokenMu sync.RWMutex
|
||||
tokenIDMu sync.RWMutex
|
||||
shortMap map[string]*models.Short
|
||||
shortIDMap map[string]*models.Short
|
||||
shortLogMap map[string][]*models.ShortLog
|
||||
userMap map[string]*models.User
|
||||
userIDMap map[string]*models.User
|
||||
tokenMap map[string]*models.Token
|
||||
tokenIDMap map[string]*models.Token
|
||||
}
|
||||
|
||||
// NewMemoryStorage creates a new MemoryStorage.
|
||||
func NewMemoryStorage() storage.Storage {
|
||||
func NewMemoryStorage(cfg *config.Config) *MemoryStorage {
|
||||
return &MemoryStorage{
|
||||
shortMap: make(map[string]*models.Short),
|
||||
userMap: make(map[string]*models.User),
|
||||
tokenMap: make(map[string]*models.Token),
|
||||
tokenIDMap: make(map[string]*models.Token),
|
||||
debug: cfg.Debug,
|
||||
shortMap: make(map[string]*models.Short),
|
||||
shortIDMap: make(map[string]*models.Short),
|
||||
userMap: make(map[string]*models.User),
|
||||
userIDMap: make(map[string]*models.User),
|
||||
tokenMap: make(map[string]*models.Token),
|
||||
tokenIDMap: make(map[string]*models.Token),
|
||||
shortLogMap: make(map[string][]*models.ShortLog),
|
||||
}
|
||||
}
|
||||
|
||||
// logPerformance is a helper function to log the performance of a function.
|
||||
func (s *MemoryStorage) logPerformance() func() {
|
||||
start := time.Now()
|
||||
|
||||
return func() {
|
||||
elapsed := time.Since(start)
|
||||
if s.debug {
|
||||
pc, _, _, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
method := runtime.FuncForPC(pc).Name()
|
||||
fmt.Printf("%s took %s\n", method, elapsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the storage.
|
||||
func (s *MemoryStorage) Start(ctx context.Context) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the storage.
|
||||
func (s *MemoryStorage) Stop(ctx context.Context) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindShort finds a short in the storage.
|
||||
func (s *MemoryStorage) FindShort(ctx context.Context, name string) (*models.Short, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortMu.RLock()
|
||||
defer s.shortMu.RUnlock()
|
||||
|
||||
short, ok := s.shortMap[name]
|
||||
if !ok {
|
||||
return short, errs.ErrShortDoesNotExist
|
||||
return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
|
||||
}
|
||||
|
||||
return short, nil
|
||||
}
|
||||
|
||||
// FindShortByID finds a short in the storage.
|
||||
func (s *MemoryStorage) FindShortByID(ctx context.Context, id string) (*models.Short, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortIDMu.RLock()
|
||||
defer s.shortIDMu.RUnlock()
|
||||
|
||||
short, ok := s.shortIDMap[id]
|
||||
if !ok {
|
||||
return nil, errs.Errorf("failed to find short", errs.ErrShortDoesNotExist)
|
||||
}
|
||||
|
||||
return short, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) CreateShort(ctx context.Context, short *models.Short) (*models.Short, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortMu.Lock()
|
||||
defer s.shortMu.Unlock()
|
||||
s.shortIDMu.Lock()
|
||||
defer s.shortIDMu.Unlock()
|
||||
|
||||
_, ok := s.shortMap[short.Name]
|
||||
if ok {
|
||||
return &models.Short{}, errs.ErrShortExists
|
||||
return nil, errs.Errorf("failed to create short", errs.ErrShortExists)
|
||||
}
|
||||
|
||||
short.CreatedAt = time.Now()
|
||||
|
||||
// Create a new ID
|
||||
for {
|
||||
short.ID = models.NewID()
|
||||
|
||||
_, ok := s.shortIDMap[short.ID]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.shortMap[short.Name] = short
|
||||
s.shortIDMap[short.ID] = short
|
||||
|
||||
return short, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) DeleteShort(ctx context.Context, short *models.Short) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortMu.Lock()
|
||||
defer s.shortMu.Unlock()
|
||||
s.shortIDMu.Lock()
|
||||
defer s.shortIDMu.Unlock()
|
||||
s.shortLogMu.Lock()
|
||||
defer s.shortLogMu.Unlock()
|
||||
|
||||
_, ok := s.shortMap[short.Name]
|
||||
_, ok := s.shortIDMap[short.ID]
|
||||
if !ok {
|
||||
return errs.ErrShortDoesNotExist
|
||||
return errs.Errorf("failed to delete short", errs.ErrShortDoesNotExist)
|
||||
}
|
||||
|
||||
delete(s.shortMap, short.Name)
|
||||
delete(s.shortIDMap, short.ID)
|
||||
delete(s.shortLogMap, short.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error) {
|
||||
s.shortMu.RLock()
|
||||
defer s.shortMu.RUnlock()
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortIDMu.RLock()
|
||||
defer s.shortIDMu.RUnlock()
|
||||
|
||||
shorts := []*models.Short{}
|
||||
|
||||
for _, short := range s.shortMap {
|
||||
if short.User != nil && short.User.Username == user.Username {
|
||||
for _, short := range s.shortIDMap {
|
||||
if short.UserID != nil && *short.UserID == user.ID {
|
||||
shorts = append(shorts, short)
|
||||
}
|
||||
}
|
||||
|
@ -101,108 +186,203 @@ func (s *MemoryStorage) ListShorts(ctx context.Context, user *models.User) ([]*m
|
|||
return shorts, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortLogMu.Lock()
|
||||
defer s.shortLogMu.Unlock()
|
||||
|
||||
// Check if short exists
|
||||
short, err := s.FindShortByID(ctx, shortLog.ShortID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create short log: %w", err)
|
||||
}
|
||||
|
||||
// Create
|
||||
_, ok := s.shortLogMap[short.ID]
|
||||
if !ok {
|
||||
s.shortLogMap[shortLog.ShortID] = []*models.ShortLog{}
|
||||
}
|
||||
|
||||
shortLog.CreatedAt = time.Now()
|
||||
|
||||
// Create a new ID (we can't really know if it's unique, but it's unlikely to happen)
|
||||
shortLog.ID = models.NewID()
|
||||
|
||||
s.shortLogMap[shortLog.ShortID] = append(s.shortLogMap[shortLog.ShortID], shortLog)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.shortLogMu.RLock()
|
||||
defer s.shortLogMu.RUnlock()
|
||||
|
||||
shortLogs, ok := s.shortLogMap[short.ID]
|
||||
if !ok {
|
||||
return []*models.ShortLog{}, nil
|
||||
}
|
||||
|
||||
return shortLogs, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) FindUser(ctx context.Context, username string) (*models.User, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.userMu.RLock()
|
||||
defer s.userMu.RUnlock()
|
||||
|
||||
user, ok := s.userMap[username]
|
||||
if !ok {
|
||||
return user, errs.ErrUserDoesNotExist
|
||||
return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) FindUserByID(ctx context.Context, id string) (*models.User, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.userIDMu.RLock()
|
||||
defer s.userIDMu.RUnlock()
|
||||
|
||||
user, ok := s.userIDMap[id]
|
||||
if !ok {
|
||||
return nil, errs.Errorf("failed to find user", errs.ErrUserDoesNotExist)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) CreateUser(ctx context.Context, user *models.User) (*models.User, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.userMu.Lock()
|
||||
defer s.userMu.Unlock()
|
||||
s.userIDMu.Lock()
|
||||
defer s.userIDMu.Unlock()
|
||||
|
||||
_, ok := s.userMap[user.Username]
|
||||
if ok {
|
||||
return &models.User{}, errs.ErrUserExists
|
||||
return nil, errs.Errorf("failed to create user", errs.ErrUserExists)
|
||||
}
|
||||
|
||||
user.CreatedAt = time.Now()
|
||||
|
||||
// Create a new ID
|
||||
for {
|
||||
user.ID = models.NewID()
|
||||
|
||||
_, ok := s.userIDMap[user.ID]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.userMap[user.Username] = user
|
||||
s.userIDMap[user.ID] = user
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) DeleteUser(ctx context.Context, user *models.User) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.userMu.Lock()
|
||||
defer s.userMu.Unlock()
|
||||
s.userIDMu.Lock()
|
||||
defer s.userIDMu.Unlock()
|
||||
|
||||
errormsg := "failed to delete user"
|
||||
|
||||
_, ok := s.userMap[user.Username]
|
||||
if !ok {
|
||||
return errs.ErrUserDoesNotExist
|
||||
return errs.Errorf(errormsg, errs.ErrUserDoesNotExist)
|
||||
}
|
||||
|
||||
// Find all user shorts
|
||||
shorts, err := s.ListShorts(ctx, user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not list user shorts: %w", err)
|
||||
return errs.Errorf(errormsg, errs.Errorf("could not list user shorts", err))
|
||||
}
|
||||
|
||||
// Delete all user shorts
|
||||
for _, short := range shorts {
|
||||
err = s.DeleteShort(ctx, short)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not delete user short: %w", err)
|
||||
return errs.Errorf(errormsg, errs.Errorf("could not delete user short", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Find all user tokens
|
||||
tokens, err := s.ListTokens(ctx, user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not list user tokens: %w", err)
|
||||
return errs.Errorf(errormsg, errs.Errorf("could not list user tokens", err))
|
||||
}
|
||||
|
||||
// Delete all user tokens
|
||||
for _, token := range tokens {
|
||||
err = s.DeleteToken(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not delete user token: %w", err)
|
||||
return errs.Errorf(errormsg, errs.Errorf("could not delete user token", err))
|
||||
}
|
||||
}
|
||||
|
||||
delete(s.userMap, user.Username)
|
||||
delete(s.userIDMap, user.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) FindToken(ctx context.Context, value string) (*models.Token, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.tokenMu.RLock()
|
||||
defer s.tokenMu.RUnlock()
|
||||
|
||||
token, ok := s.tokenMap[value]
|
||||
if !ok {
|
||||
return token, errs.ErrTokenDoesNotExist
|
||||
return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) FindTokenByID(ctx context.Context, id string) (*models.Token, error) {
|
||||
s.tokenMu.RLock()
|
||||
defer s.tokenMu.RUnlock()
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.tokenIDMu.RLock()
|
||||
defer s.tokenIDMu.RUnlock()
|
||||
|
||||
token, ok := s.tokenIDMap[id]
|
||||
if !ok {
|
||||
return token, errs.ErrTokenDoesNotExist
|
||||
return nil, errs.Errorf("failed to find token", errs.ErrTokenDoesNotExist)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*models.Token, error) {
|
||||
s.tokenMu.RLock()
|
||||
defer s.tokenMu.RUnlock()
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.tokenIDMu.RLock()
|
||||
defer s.tokenIDMu.RUnlock()
|
||||
|
||||
tokens := []*models.Token{}
|
||||
|
||||
for _, token := range s.tokenMap {
|
||||
if token.User != nil && token.User.Username == user.Username {
|
||||
for _, token := range s.tokenIDMap {
|
||||
if token.UserID != nil && *token.UserID == user.ID {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
|
@ -211,20 +391,35 @@ func (s *MemoryStorage) ListTokens(ctx context.Context, user *models.User) ([]*m
|
|||
}
|
||||
|
||||
func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*models.Token, error) {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.tokenMu.Lock()
|
||||
defer s.tokenMu.Unlock()
|
||||
s.tokenIDMu.Lock()
|
||||
defer s.tokenIDMu.Unlock()
|
||||
|
||||
_, ok := s.tokenMap[token.Value]
|
||||
if ok {
|
||||
return &models.Token{}, errs.ErrTokenExists
|
||||
return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
|
||||
}
|
||||
_, ok = s.tokenIDMap[token.ID]
|
||||
if ok {
|
||||
return &models.Token{}, errs.ErrTokenExists
|
||||
return nil, errs.Errorf("failed to create token", errs.ErrTokenExists)
|
||||
}
|
||||
|
||||
token.CreatedAt = time.Now()
|
||||
|
||||
// Create a new ID
|
||||
for {
|
||||
token.ID = models.NewID()
|
||||
|
||||
_, ok := s.tokenIDMap[token.ID]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.tokenMap[token.Value] = token
|
||||
s.tokenIDMap[token.ID] = token
|
||||
|
||||
|
@ -232,16 +427,21 @@ func (s *MemoryStorage) CreateToken(ctx context.Context, token *models.Token) (*
|
|||
}
|
||||
|
||||
func (s *MemoryStorage) DeleteToken(ctx context.Context, token *models.Token) error {
|
||||
logPerf := s.logPerformance()
|
||||
defer logPerf()
|
||||
|
||||
s.tokenMu.Lock()
|
||||
defer s.tokenMu.Unlock()
|
||||
s.tokenIDMu.Lock()
|
||||
defer s.tokenIDMu.Unlock()
|
||||
|
||||
_, ok := s.tokenMap[token.Value]
|
||||
if !ok {
|
||||
return errs.ErrTokenDoesNotExist
|
||||
return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
|
||||
}
|
||||
_, ok = s.tokenIDMap[token.ID]
|
||||
if !ok {
|
||||
return errs.ErrTokenDoesNotExist
|
||||
return errs.Errorf("failed to delete token", errs.ErrTokenDoesNotExist)
|
||||
}
|
||||
|
||||
delete(s.tokenMap, token.Value)
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
package models
|
||||
|
||||
import tokenutil "git.maronato.dev/maronato/goshort/internal/util/token"
|
||||
|
||||
const (
|
||||
// IDLength is the length of IDs.
|
||||
IDLength = 16
|
||||
)
|
||||
|
||||
func NewID() string {
|
||||
return tokenutil.GenerateSecureToken(IDLength / 2)
|
||||
}
|
|
@ -2,14 +2,20 @@ package models
|
|||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
ShortIDLength = 16
|
||||
)
|
||||
|
||||
type Short struct {
|
||||
// ID is the unique identifier of the short.
|
||||
ID string `json:"id"`
|
||||
// Name is the shortened name of the URL.
|
||||
Name string `json:"name,omitempty"`
|
||||
// URL is the URL that is shortened.
|
||||
URL string `json:"url"`
|
||||
// CreatedAt is the time the short was created.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
|
||||
// User is the user that created the short.
|
||||
User *User `json:"-"`
|
||||
// UserID of the user that created the short.
|
||||
UserID *string `json:"-"`
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
type ShortLog struct {
|
||||
// ID is the unique identifier of the short log.
|
||||
ID string `json:"id"`
|
||||
// ShortID is the ID of the short that was accessed.
|
||||
ShortID string `json:"shortID"`
|
||||
// IPAddress is the IP address of the client that accessed the short.
|
||||
IPAddress string `json:"ipAddress"`
|
||||
// UserAgent is the User-Agent of the client that accessed the short.
|
||||
UserAgent string `json:"userAgent"`
|
||||
// Referer is the referer of the client that accessed the short.
|
||||
Referer string `json:"referer"`
|
||||
// CreatedAt is the time the short was accessed.
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
}
|
|
@ -12,6 +12,6 @@ type Token struct {
|
|||
// CreatedAt is when the token was created (initialized by the storage)
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
// User is the user that created the token.
|
||||
User *User `json:"-"`
|
||||
// UserID of the user that created the token.
|
||||
UserID *string `json:"-"`
|
||||
}
|
||||
|
|
|
@ -8,11 +8,15 @@ import (
|
|||
)
|
||||
|
||||
type User struct {
|
||||
// ID is the user's ID.
|
||||
ID string `json:"id"`
|
||||
// Username is the user's username.
|
||||
Username string `json:"username"`
|
||||
// password is the user's password.
|
||||
password string `json:"-"`
|
||||
|
||||
// CreatedAt is the time the user was created.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
}
|
||||
|
||||
// NewAuthenticatableUser is a elper function for storages that takes
|
||||
|
@ -20,6 +24,7 @@ type User struct {
|
|||
func NewAuthenticatableUser(user *User, hashedPass string) *User {
|
||||
|
||||
return &User{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
CreatedAt: user.CreatedAt,
|
||||
password: hashedPass,
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package sqlitestorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.maronato.dev/maronato/goshort/internal/config"
|
||||
"git.maronato.dev/maronato/goshort/internal/errs"
|
||||
"git.maronato.dev/maronato/goshort/internal/storage"
|
||||
bunstorage "git.maronato.dev/maronato/goshort/internal/storage/bun"
|
||||
"github.com/uptrace/bun"
|
||||
|
@ -13,13 +16,74 @@ import (
|
|||
|
||||
// NewSQLiteStorage creates a new SQLite storage.
|
||||
func NewSQLiteStorage(cfg *config.Config) storage.Storage {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL)
|
||||
// Create a new SQLite database with the following pragmas enabled:
|
||||
// - journal_mode=WAL: Enables Write-Ahead Logging, which allows for concurrent reads and writes.
|
||||
// - foreign_keys=ON: Enables foreign key constraints.
|
||||
// - synchronous=NORMAL: Enables synchronous mode NORMAL
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, cfg.DBURL+"?_pragma=foreign_keys(ON)&_pragma=journal_mode(WAL)&_pragma=synchronous(NORMAL)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// If running the DB in memory, make sure
|
||||
// database/sql does not close idle connections.
|
||||
// Otherwise, the database will be lost.
|
||||
if cfg.DBType == config.DBTypeMemory {
|
||||
sqldb.SetMaxIdleConns(1000)
|
||||
sqldb.SetConnMaxLifetime(0)
|
||||
}
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
// Use Bun as the storage implementation.
|
||||
return bunstorage.NewBunStorage(cfg, db)
|
||||
storage := bunstorage.NewBunStorage(cfg, db)
|
||||
|
||||
// Config vacuuming of the database every hour.
|
||||
ticker := time.NewTicker(time.Hour)
|
||||
tickerCtx, tickerCancel := context.WithCancel(context.Background())
|
||||
// Start vacuuming the database when the storage is started.
|
||||
storage.RegisterStartHook(func(ctx context.Context, activeDB *bun.DB) error {
|
||||
// Vacuum now
|
||||
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to vacuum database: %w", err)
|
||||
}
|
||||
|
||||
// Vacuum every hour, until the context is canceled.
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
_, err := activeDB.NewRaw("VACUUM").Exec(ctx)
|
||||
if err != nil {
|
||||
panic(errs.Errorf("failed to vacuum database: %w", err))
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-tickerCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
})
|
||||
// So some cleaning up and optimization when the storage is stopped.
|
||||
storage.RegisterStopHook(func(ctx context.Context, activeDB *bun.DB) error {
|
||||
// Stop the ticker and cancel the ticker context.
|
||||
ticker.Stop()
|
||||
tickerCancel()
|
||||
// Run PRAGMA optimize to optimize the database.
|
||||
// By this point, the passed context is likely cancelled, so we use the
|
||||
// background context since otherwise the query will fail.
|
||||
// This will delay the shutdown of the storage, but it'll be for a good cause.
|
||||
// The user can force the shutdown by using SIGINT or SIGTERM.
|
||||
_, err := activeDB.NewRaw("PRAGMA optimize").Exec(context.Background())
|
||||
if err != nil {
|
||||
return errs.Errorf("failed to optimize database: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return storage
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@ type Storage interface {
|
|||
|
||||
// FindShort finds a short in the storage using its name.
|
||||
FindShort(ctx context.Context, name string) (*models.Short, error)
|
||||
// FindShortByID finds a short in the storage using its ID.
|
||||
FindShortByID(ctx context.Context, id string) (*models.Short, error)
|
||||
// FindShorts finds all shorts in the storage that belong to a user.
|
||||
ListShorts(ctx context.Context, user *models.User) ([]*models.Short, error)
|
||||
// CreateShort creates a short in the storage.
|
||||
|
@ -21,10 +23,19 @@ type Storage interface {
|
|||
// DeleteShort deletes a short from the storage.
|
||||
DeleteShort(ctx context.Context, short *models.Short) error
|
||||
|
||||
// ShortLog Storage
|
||||
|
||||
// CreateShortLog creates a short log in the storage.
|
||||
CreateShortLog(ctx context.Context, shortLog *models.ShortLog) error
|
||||
// ListShortLogs finds all short logs in the storage that belong to a short.
|
||||
ListShortLogs(ctx context.Context, short *models.Short) ([]*models.ShortLog, error)
|
||||
|
||||
// User Storage
|
||||
|
||||
// FindUser finds a user in the storage using its username.
|
||||
FindUser(ctx context.Context, username string) (*models.User, error)
|
||||
// FindUserByID finds a user in the storage using its ID.
|
||||
FindUserByID(ctx context.Context, id string) (*models.User, error)
|
||||
// CreateUser creates a user in the storage.
|
||||
CreateUser(ctx context.Context, user *models.User) (*models.User, error)
|
||||
// DeleteUser deletes a user and all their shorts from the storage.
|
||||
|
|
Loading…
Reference in New Issue
Block a user