refactor and add tests
Some checks failed
Go / checks (push) Failing after 52s

This commit is contained in:
Gustavo Maronato 2023-09-20 01:26:38 -03:00
parent 6bbfbad1d0
commit f96dda4af2
Signed by: maronato
SSH Key Fingerprint: SHA256:2Gw7kwMz/As+2UkR1qQ/qYYhn+WNh3FGv6ozhoRrLcs
26 changed files with 2180 additions and 613 deletions

148
README.md
View File

@ -0,0 +1,148 @@
# Finger
Webfinger server written in Go.
## Features
- 🍰 Easy YAML configuration
- 🪶 Single 8MB binary / 0% idle CPU / 4MB idle RAM
- ⚡️ Sub millisecond responses at 10,000 request per second
- 🐳 10MB Docker image
## Install
Via `go install`:
```bash
go install git.maronato.dev/maronato/finger@latest
```
Via Docker:
```bash
docker run --name finger /
-p 8080:8080 /
git.maronato.dev/maronato/finger
```
## Usage
If you installed it using `go install`, run
```bash
finger serve
```
To start the server on port `8080`. Your resources will be queryable via `locahost:8080/.well-known/webfinger?resource=<your-resource>`
If you're using Docker, the use the same command in the install section.
By default, no resources will be exposed. You can create resources via a `fingers.yml` file. It should contain a collection of resources as keys and their attributes as their objects.
Some default URN aliases are provided via the built-in mapping ([`urns.yml`](./urns.yml)). You can replace that with your own or use URNs directly in the `fingers.yml` file.
Here's an example:
```yaml
# fingers.yml
# Resources go in the root of the file. Email address will have the acct:
# prefix added automatically.
alice@example.com:
# "avatar" is an alias of "http://webfinger.net/rel/avatar"
# (see urns.yml for more)
avatar: "https://example.com/alice-pic"
# If the value is a URI, it'll be exposed as a webfinger link
openid: "https://sso.example.com/"
# If the value of the attribute is not a URI, it will be exposed as a
# webfinger property
name: "Alice Doe"
# You can also specify URN's directly instead of the aliases
http://webfinger.net/rel/profile-page: "https://example.com/user/alice"
bob@example.com:
name: Bob Foo
openid: "https://sso.example.com/"
# Resources can also be URIs
https://example.com/user/charlie:
name: Charlie Baz
profile: https://example.com/user/charlie
```
### Example queries
<details>
<summary><b>Query Alice</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=acct:alice@example.com</pre></summary>
```json
{
"subject": "acct:alice@example.com",
"links": [
{
"rel": "avatar",
"href": "https://example.com/alice-pic"
},
{
"rel": "openid",
"href": "https://sso.example.com/"
},
{
"rel": "name",
"href": "Alice Doe"
},
{
"rel": "http://webfinger.net/rel/profile-page",
"href": "https://example.com/user/alice"
}
]
}
```
</details>
<details>
<summary><b>Query Bob</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=acct:bob@example.com</pre></summary>
```json
{
"subject": "acct:bob@example.com",
"links": [
{
"rel": "name",
"href": "Bob Foo"
},
{
"rel": "openid",
"href": "https://sso.example.com/"
}
]
}
```
</details>
<details>
<summary><b>Query Charlie</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=https://example.com/user/charlie</pre></summary>
```JSON
{
"subject": "https://example.com/user/charlie",
"links": [
{
"rel": "name",
"href": "Charlie Baz"
},
{
"rel": "profile",
"href": "https://example.com/user/charlie"
}
]
}
```
</details>
## Configs
Here are the config options available. You can change them via command line flags or environment variables:
| CLI flag | Env variable | Default | Description |
| -------- | ------------ | ------- | ----------- |
| fdsfds | gsfgfs | fgfsdgf | gdfsgdf |

98
cmd/cmd.go Normal file
View File

@ -0,0 +1,98 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"syscall"
"git.maronato.dev/maronato/finger/internal/config"
"github.com/peterbourgon/ff/v4"
"github.com/peterbourgon/ff/v4/ffhelp"
)
func Run(version string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Allow graceful shutdown
trapSignalsCrossPlatform(cancel)
cfg := &config.Config{}
// Create a new root command
subcommands := []*ff.Command{
newServerCmd(cfg),
newHealthcheckCmd(cfg),
}
cmd := newRootCmd(version, cfg, subcommands)
// Parse and run
if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil {
if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) {
fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd))
return nil
}
return fmt.Errorf("error running command: %w", err)
}
return nil
}
// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37
// trapSignalsCrossPlatform captures SIGINT or interrupt (depending
// on the OS), which initiates a graceful shutdown. A second SIGINT
// or interrupt will forcefully exit the process immediately.
func trapSignalsCrossPlatform(cancel context.CancelFunc) {
go func() {
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGINT)
for i := 0; true; i++ {
<-shutdown
if i > 0 {
fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout
os.Exit(1)
}
fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout
cancel()
}
}()
}
// NewRootCmd parses the command line flags and returns a config.Config struct.
func newRootCmd(version string, cfg *config.Config, subcommands []*ff.Command) *ff.Command {
fs := ff.NewFlagSet(appName)
for _, cmd := range subcommands {
cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs)
}
cmd := &ff.Command{
Name: appName,
Usage: fmt.Sprintf("%s <command> [flags]", appName),
ShortHelp: fmt.Sprintf("(%s) A webfinger server", version),
Flags: fs,
Subcommands: subcommands,
}
// Use 0.0.0.0 as the default host if on docker
defaultHost := "localhost"
if os.Getenv("ENV_DOCKER") == "true" {
defaultHost = "0.0.0.0"
}
fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging")
fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on")
fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on")
fs.StringVar(&cfg.URNPath, 'u', "urn-file", "urns.yml", "Path to the URNs file")
fs.StringVar(&cfg.FingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file")
return cmd
}

53
cmd/healthcheck.go Normal file
View File

@ -0,0 +1,53 @@
package cmd
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
"git.maronato.dev/maronato/finger/internal/config"
"github.com/peterbourgon/ff/v4"
)
func newHealthcheckCmd(cfg *config.Config) *ff.Command {
return &ff.Command{
Name: "healthcheck",
Usage: "healthcheck [flags]",
ShortHelp: "Check if the server is running",
Exec: func(ctx context.Context, args []string) error {
// Create a new client
client := &http.Client{
Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant
}
// Create a new request
reqURL := url.URL{
Scheme: "http",
Host: cfg.GetAddr(),
Path: "/healthz",
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody)
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
// Send the request
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("error sending request: %w", err)
}
defer resp.Body.Close()
// Check the response
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error
}
return nil
},
}
}

49
cmd/serve.go Normal file
View File

@ -0,0 +1,49 @@
package cmd
import (
"context"
"fmt"
"os"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/server"
"git.maronato.dev/maronato/finger/internal/webfinger"
"github.com/peterbourgon/ff/v4"
)
const appName = "finger"
func newServerCmd(cfg *config.Config) *ff.Command {
return &ff.Command{
Name: "serve",
Usage: "serve [flags]",
ShortHelp: "Start the webfinger server",
Exec: func(ctx context.Context, args []string) error {
// Create a logger and add it to the context
l := log.NewLogger(os.Stderr, cfg)
ctx = log.WithLogger(ctx, l)
// Read the webfinger files
r := webfinger.NewFingerReader()
err := r.ReadFiles(cfg)
if err != nil {
return fmt.Errorf("error reading finger files: %w", err)
}
webfingers, err := r.ReadFingerFile(ctx)
if err != nil {
return fmt.Errorf("error parsing finger files: %w", err)
}
l.Info(fmt.Sprintf("Loaded %d webfingers", len(webfingers)))
// Start the server
if err := server.StartServer(ctx, cfg, webfingers); err != nil {
return fmt.Errorf("error running server: %w", err)
}
return nil
},
}
}

1
go.mod
View File

@ -4,7 +4,6 @@ go 1.21.0
require (
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
golang.org/x/sync v0.3.0
gopkg.in/yaml.v3 v3.0.1
)

2
go.sum
View File

@ -2,8 +2,6 @@ github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3 h1:fpyiFVEJvxIFljxM4l5ANSk/UGlM1gyU+hPAr9jhB7M=
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3/go.mod h1:H/13DK46DKXy7EaIxPhk2Y0EC8aubKm35nBjBe8AAGc=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

67
internal/config/config.go Normal file
View File

@ -0,0 +1,67 @@
package config
import (
"errors"
"fmt"
"net"
"net/url"
)
const (
// DefaultHost is the default host to listen on.
DefaultHost = "localhost"
// DefaultPort is the default port to listen on.
DefaultPort = "8080"
// DefaultURNPath is the default file path to the URN alias file.
DefaultURNPath = "urns.yml"
// DefaultFingerPath is the default file path to the webfinger definition file.
DefaultFingerPath = "finger.yml"
)
// ErrInvalidConfig is returned when the config is invalid.
var ErrInvalidConfig = errors.New("invalid config")
type Config struct {
Debug bool
Host string
Port string
URNPath string
FingerPath string
}
func NewConfig() *Config {
return &Config{
Host: DefaultHost,
Port: DefaultPort,
URNPath: DefaultURNPath,
FingerPath: DefaultFingerPath,
}
}
func (c *Config) GetAddr() string {
return net.JoinHostPort(c.Host, c.Port)
}
func (c *Config) Validate() error {
if c.Host == "" {
return fmt.Errorf("%w: host is empty", ErrInvalidConfig)
}
if c.Port == "" {
return fmt.Errorf("%w: port is empty", ErrInvalidConfig)
}
if _, err := url.Parse(c.GetAddr()); err != nil {
return fmt.Errorf("%w: %w", ErrInvalidConfig, err)
}
if c.URNPath == "" {
return fmt.Errorf("%w: urn path is empty", ErrInvalidConfig)
}
if c.FingerPath == "" {
return fmt.Errorf("%w: finger path is empty", ErrInvalidConfig)
}
return nil
}

View File

@ -0,0 +1,124 @@
package config_test
import (
"testing"
"git.maronato.dev/maronato/finger/internal/config"
)
func TestConfig_GetAddr(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg *config.Config
want string
}{
{
name: "default",
cfg: config.NewConfig(),
want: "localhost:8080",
},
{
name: "custom",
cfg: &config.Config{
Host: "example.com",
Port: "1234",
},
want: "example.com:1234",
},
}
for _, tt := range tests {
tc := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := tc.cfg.GetAddr()
if got != tc.want {
t.Errorf("Config.GetAddr() = %v, want %v", got, tc.want)
}
})
}
}
func TestConfig_Validate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg *config.Config
wantErr bool
}{
{
name: "default",
cfg: config.NewConfig(),
wantErr: false,
},
{
name: "empty host",
cfg: &config.Config{
Host: "",
Port: "1234",
},
wantErr: true,
},
{
name: "empty port",
cfg: &config.Config{
Host: "example.com",
Port: "",
},
wantErr: true,
},
{
name: "invalid addr",
cfg: &config.Config{
Host: "example.com",
Port: "invalid",
},
wantErr: true,
},
{
name: "empty urn path",
cfg: &config.Config{
Host: "example.com",
Port: "1234",
URNPath: "",
},
wantErr: true,
},
{
name: "empty finger path",
cfg: &config.Config{
Host: "example.com",
Port: "1234",
URNPath: "urns.yml",
FingerPath: "",
},
wantErr: true,
},
{
name: "valid",
cfg: &config.Config{
Host: "example.com",
Port: "1234",
URNPath: "urns.yml",
FingerPath: "finger.yml",
},
wantErr: false,
},
}
for _, tt := range tests {
tc := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tc.cfg.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tc.wantErr)
}
})
}
}

42
internal/log/log.go Normal file
View File

@ -0,0 +1,42 @@
package log
import (
"context"
"io"
"log/slog"
"git.maronato.dev/maronato/finger/internal/config"
)
type loggerCtxKey struct{}
// NewLogger creates a new logger with the given debug level.
func NewLogger(w io.Writer, cfg *config.Config) *slog.Logger {
level := slog.LevelInfo
addSource := false
if cfg.Debug {
level = slog.LevelDebug
addSource = true
}
return slog.New(
slog.NewJSONHandler(w, &slog.HandlerOptions{
Level: level,
AddSource: addSource,
}),
)
}
func FromContext(ctx context.Context) *slog.Logger {
l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger)
if !ok {
panic("logger not found in context")
}
return l
}
func WithLogger(ctx context.Context, l *slog.Logger) context.Context {
return context.WithValue(ctx, loggerCtxKey{}, l)
}

95
internal/log/log_test.go Normal file
View File

@ -0,0 +1,95 @@
package log_test
import (
"context"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
)
func assertPanic(t *testing.T, f func()) {
t.Helper()
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
// Call the function
f()
}
func TestNewLogger(t *testing.T) {
t.Parallel()
t.Run("defaults to info level", func(t *testing.T) {
t.Parallel()
cfg := config.NewConfig()
w := &strings.Builder{}
l := log.NewLogger(w, cfg)
// It shouldn't log debug messages
l.Debug("test")
if w.String() != "" {
t.Error("logger logged debug message")
}
// It should log info messages
l.Info("test")
if w.String() == "" {
t.Error("logger did not log info message")
}
})
t.Run("logs debug messages if debug is enabled", func(t *testing.T) {
t.Parallel()
cfg := config.NewConfig()
cfg.Debug = true
w := &strings.Builder{}
l := log.NewLogger(w, cfg)
// It should log debug messages
l.Debug("test")
if w.String() == "" {
t.Error("logger did not log debug message")
}
})
}
func TestFromContext(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(nil, cfg)
t.Run("panics if no logger in context", func(t *testing.T) {
t.Parallel()
assertPanic(t, func() {
log.FromContext(ctx)
})
})
t.Run("returns logger from context", func(t *testing.T) {
t.Parallel()
ctx = log.WithLogger(ctx, l)
l2 := log.FromContext(ctx)
if l2 == nil {
t.Error("logger is nil")
}
})
}

View File

@ -0,0 +1,44 @@
package middleware
import (
"log/slog"
"net/http"
"time"
"git.maronato.dev/maronato/finger/internal/log"
)
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := log.FromContext(ctx)
start := time.Now()
// Wrap the response writer
wrapped := WrapResponseWriter(w)
// Call the next handler
next.ServeHTTP(wrapped, r)
status := wrapped.Status()
// Log the request
lg := l.With(
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", status),
slog.String("remote", r.RemoteAddr),
slog.Duration("duration", time.Since(start)),
)
switch {
case status >= http.StatusInternalServerError:
lg.Error("Server error")
case status >= http.StatusBadRequest:
lg.Info("Client error")
default:
lg.Info("Request completed")
}
})
}

View File

@ -0,0 +1,44 @@
package middleware_test
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/middleware"
)
func TestRequestLogger(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
stdout := &strings.Builder{}
l := log.NewLogger(stdout, cfg)
ctx = log.WithLogger(ctx, l)
w := httptest.NewRecorder()
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
if stdout.String() != "" {
t.Error("logger logged before request")
}
middleware.RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Error("status is not 200")
}
if stdout.String() == "" {
t.Error("logger did not log request")
}
}

View File

@ -0,0 +1,27 @@
package middleware
import (
"log/slog"
"net/http"
"git.maronato.dev/maronato/finger/internal/log"
)
func Recoverer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := log.FromContext(ctx)
defer func() {
err := recover()
if err != nil {
l.Error("Panic", slog.Any("error", err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}()
next.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,76 @@
package middleware_test
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/middleware"
)
func assertNoPanic(t *testing.T, f func()) {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Error("function panicked")
}
}()
f()
}
func TestRecoverer(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
t.Run("handles panics", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("test")
}))
assertNoPanic(t, func() {
h.ServeHTTP(w, r)
})
if w.Code != http.StatusInternalServerError {
t.Error("status is not 500")
}
if w.Body.String() != "Internal Server Error\n" {
t.Error("response body is not 'Internal Server Error'")
}
})
t.Run("handles successful requests", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
assertNoPanic(t, func() {
h.ServeHTTP(w, r)
})
if w.Code != http.StatusOK {
t.Error("status is not 200")
}
})
}

View File

@ -0,0 +1,42 @@
package middleware
import (
"fmt"
"net/http"
)
type ResponseWrapper struct {
http.ResponseWriter
status int
}
func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper {
return &ResponseWrapper{w, 0}
}
func (w *ResponseWrapper) WriteHeader(code int) {
w.status = code
w.ResponseWriter.WriteHeader(code)
}
func (w *ResponseWrapper) Status() int {
return w.status
}
func (w *ResponseWrapper) Write(b []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
size, err := w.ResponseWriter.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing response: %w", err)
}
return size, nil
}
func (w *ResponseWrapper) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@ -0,0 +1,97 @@
package middleware_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"git.maronato.dev/maronato/finger/internal/middleware"
)
func TestWrapResponseWriter(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
wrapped := middleware.WrapResponseWriter(w)
if wrapped == nil {
t.Error("wrapper is nil")
}
}
func TestResponseWrapper_Status(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
wrapped := middleware.WrapResponseWriter(w)
if wrapped.Status() != 0 {
t.Error("status is not 0")
}
wrapped.WriteHeader(http.StatusOK)
if wrapped.Status() != http.StatusOK {
t.Error("status is not 200")
}
}
type FailWriter struct{}
func (w *FailWriter) Write(b []byte) (int, error) {
return 0, fmt.Errorf("error")
}
func (w *FailWriter) Header() http.Header {
return http.Header{}
}
func (w *FailWriter) WriteHeader(_ int) {}
func TestResponseWrapper_Write(t *testing.T) {
t.Parallel()
t.Run("writes success messages", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
wrapped := middleware.WrapResponseWriter(w)
size, err := wrapped.Write([]byte("test"))
if err != nil {
t.Errorf("error writing response: %v", err)
}
if size != 4 {
t.Error("size is not 4")
}
if wrapped.Status() != http.StatusOK {
t.Error("status is not 200")
}
})
t.Run("returns error on fail write", func(t *testing.T) {
t.Parallel()
w := &FailWriter{}
wrapped := middleware.WrapResponseWriter(w)
_, err := wrapped.Write([]byte("test"))
if err == nil {
t.Error("error is nil")
}
})
}
func TestResponseWrapper_Unwrap(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
wrapped := middleware.WrapResponseWriter(w)
if wrapped.Unwrap() != w {
t.Error("unwrapped response is not the same")
}
}

View File

@ -0,0 +1,13 @@
package server
import (
"net/http"
"git.maronato.dev/maronato/finger/internal/config"
)
func HealthCheckHandler(_ *config.Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
}

View File

@ -0,0 +1,40 @@
package server_test
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/server"
)
func TestHealthcheckHandler(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Create a new request
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/healthz", http.NoBody)
// Create a new recorder
rec := httptest.NewRecorder()
// Create a new handler
h := server.HealthCheckHandler(cfg)
// Serve the request
h.ServeHTTP(rec, req)
// Check the status code
if rec.Code != http.StatusOK {
t.Errorf("expected status code %d, got %d", http.StatusOK, rec.Code)
}
}

100
internal/server/server.go Normal file
View File

@ -0,0 +1,100 @@
package server
import (
"context"
"fmt"
"log/slog"
"net"
"net/http"
"time"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/middleware"
"git.maronato.dev/maronato/finger/internal/webfinger"
"golang.org/x/sync/errgroup"
)
const (
// ReadTimeout is the maximum duration for reading the entire
// request, including the body.
ReadTimeout = 5 * time.Second
// WriteTimeout is the maximum duration before timing out
// writes of the response.
WriteTimeout = 10 * time.Second
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alives are enabled.
IdleTimeout = 30 * time.Second
// ReadHeaderTimeout is the amount of time allowed to read
// request headers.
ReadHeaderTimeout = 2 * time.Second
// RequestTimeout is the maximum duration for the entire
// request.
RequestTimeout = 7 * 24 * time.Hour
)
func StartServer(ctx context.Context, cfg *config.Config, webfingers webfinger.WebFingers) error {
l := log.FromContext(ctx)
// Create the server mux
mux := http.NewServeMux()
mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webfingers))
mux.Handle("/healthz", HealthCheckHandler(cfg))
// Create a new server
srv := &http.Server{
Addr: cfg.GetAddr(),
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
Handler: middleware.RequestLogger(
middleware.Recoverer(
http.TimeoutHandler(mux, RequestTimeout, "request timed out"),
),
),
ReadHeaderTimeout: ReadHeaderTimeout,
ReadTimeout: ReadTimeout,
WriteTimeout: WriteTimeout,
IdleTimeout: IdleTimeout,
}
// Create the errorgroup that will manage the server execution
eg, egCtx := errgroup.WithContext(ctx)
// Start the server
eg.Go(func() error {
l.Info("Starting server", slog.String("addr", srv.Addr))
// Use the global context for the server
srv.BaseContext = func(_ net.Listener) context.Context {
return egCtx
}
return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup
})
// Gracefully shutdown the server when the context is done
eg.Go(func() error {
// Wait for the context to be done
<-egCtx.Done()
l.Info("Shutting down server")
// Disable the cancel since we don't wan't to force
// the server to shutdown if the context is canceled.
noCancelCtx := context.WithoutCancel(egCtx)
return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup
})
// Log when the server is fully shutdown
srv.RegisterOnShutdown(func() {
l.Info("Server shutdown complete")
})
// Wait for the server to exit and check for errors that
// are not caused by the context being canceled.
if err := eg.Wait(); err != nil && ctx.Err() == nil {
return fmt.Errorf("server exited with error: %w", err)
}
return nil
}

View File

@ -0,0 +1,206 @@
package server_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"testing"
"time"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/server"
"git.maronato.dev/maronato/finger/internal/webfinger"
)
func getPortGenerator() func() int {
lock := &sync.Mutex{}
port := 8080
return func() int {
lock.Lock()
defer lock.Unlock()
port++
return port
}
}
func TestStartServer(t *testing.T) {
t.Parallel()
portGenerator := getPortGenerator()
t.Run("starts and shuts down", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Use a new port
cfg.Port = fmt.Sprint(portGenerator())
// Start the server
err := server.StartServer(ctx, cfg, nil)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
})
t.Run("fails to start", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Use a new port
cfg.Port = fmt.Sprint(portGenerator())
// Use invalid host
cfg.Host = "google.com"
// Start the server
err := server.StartServer(ctx, cfg, nil)
if err == nil {
t.Errorf("expected error, got nil")
}
})
t.Run("serves webfinger", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Use a new port
cfg.Port = fmt.Sprint(portGenerator())
resource := "acct:user@example.com"
webfingers := webfinger.WebFingers{
resource: &webfinger.WebFinger{
Subject: resource,
Properties: map[string]string{
"http://webfinger.net/rel/name": "John Doe",
},
},
}
go func() {
// Start the server
err := server.StartServer(ctx, cfg, webfingers)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}()
// Wait for the server to start
time.Sleep(time.Millisecond * 50)
// Create a new client
c := http.Client{}
// Create a new request
r, _ := http.NewRequestWithContext(ctx,
http.MethodGet,
"http://"+cfg.GetAddr()+"/.well-known/webfinger?resource=acct:user@example.com",
http.NoBody,
)
// Send the request
resp, err := c.Do(r)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
defer resp.Body.Close()
// Check the status code
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
// Check the response body
fingerGot := &webfinger.WebFinger{}
// Decode the response body
if err := json.NewDecoder(resp.Body).Decode(fingerGot); err != nil {
t.Errorf("error decoding json: %v", err)
}
// Check the response body
fingerWant := webfingers[resource]
if !reflect.DeepEqual(fingerGot, fingerWant) {
t.Errorf("expected %v, got %v", fingerWant, fingerGot)
}
})
t.Run("serves healthcheck", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Use a new port
cfg.Port = fmt.Sprint(portGenerator())
go func() {
// Start the server
err := server.StartServer(ctx, cfg, nil)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
}()
// Wait for the server to start
time.Sleep(time.Millisecond * 50)
// Create a new client
c := http.Client{}
// Create a new request
r, _ := http.NewRequestWithContext(ctx,
http.MethodGet,
"http://"+cfg.GetAddr()+"/healthz",
http.NoBody,
)
// Send the request
resp, err := c.Do(r)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
defer resp.Body.Close()
// Check the status code
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
})
}

View File

@ -0,0 +1,59 @@
package server
import (
"encoding/json"
"net/http"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/webfinger"
)
func WebfingerHandler(_ *config.Config, webfingers webfinger.WebFingers) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := log.FromContext(ctx)
// Only handle GET requests
if r.Method != http.MethodGet {
l.Debug("Method not allowed")
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get the query params
q := r.URL.Query()
// Get the resource
resource := q.Get("resource")
if resource == "" {
l.Debug("No resource provided")
http.Error(w, "No resource provided", http.StatusBadRequest)
return
}
// Get and validate resource
finger, ok := webfingers[resource]
if !ok {
l.Debug("Resource not found")
http.Error(w, "Resource not found", http.StatusNotFound)
return
}
// Set the content type
w.Header().Set("Content-Type", "application/jrd+json")
// Write the response
if err := json.NewEncoder(w).Encode(finger); err != nil {
l.Debug("Error encoding json")
http.Error(w, "Error encoding json", http.StatusInternalServerError)
return
}
l.Debug("Webfinger request successful")
})
}

View File

@ -0,0 +1,149 @@
package server_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/server"
"git.maronato.dev/maronato/finger/internal/webfinger"
)
func TestWebfingerHandler(t *testing.T) {
t.Parallel()
webfingers := webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Links: []webfinger.Link{
{
Rel: "http://webfinger.net/rel/profile-page",
Href: "https://example.com/user",
},
},
Properties: map[string]string{
"http://webfinger.net/rel/name": "John Doe",
},
},
"acct:other@example.com": {
Subject: "acct:other@example.com",
Properties: map[string]string{
"http://webfinger.net/rel/name": "Jane Doe",
},
},
"https://example.com/user": {
Subject: "https://example.com/user",
Properties: map[string]string{
"http://webfinger.net/rel/name": "John Baz",
},
},
}
tests := []struct {
name string
resource string
wantCode int
alternateMethod string
}{
{
name: "valid resource",
resource: "acct:user@example.com",
wantCode: http.StatusOK,
},
{
name: "other valid resource",
resource: "acct:other@example.com",
wantCode: http.StatusOK,
},
{
name: "url resource",
resource: "https://example.com/user",
wantCode: http.StatusOK,
},
{
name: "resource missing acct:",
resource: "user@example.com",
wantCode: http.StatusNotFound,
},
{
name: "resource missing",
resource: "",
wantCode: http.StatusBadRequest,
},
{
name: "invalid method",
resource: "acct:user@example.com",
wantCode: http.StatusMethodNotAllowed,
alternateMethod: http.MethodPost,
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
// Create a new request
r, _ := http.NewRequestWithContext(ctx, tc.alternateMethod, "/.well-known/webfinger?resource="+tc.resource, http.NoBody)
// Create a new response
w := httptest.NewRecorder()
// Create a new handler
h := server.WebfingerHandler(cfg, webfingers)
// Serve the request
h.ServeHTTP(w, r)
// Check the status code
if w.Code != tc.wantCode {
t.Errorf("expected status code %d, got %d", tc.wantCode, w.Code)
}
// If the status code is 200, check the response body
if tc.wantCode == http.StatusOK {
// Check the content type
if w.Header().Get("Content-Type") != "application/jrd+json" {
t.Errorf("expected content type %s, got %s", "application/jrd+json", w.Header().Get("Content-Type"))
}
fingerWant := webfingers[tc.resource]
fingerGot := &webfinger.WebFinger{}
// Decode the response body
if err := json.NewDecoder(w.Body).Decode(fingerGot); err != nil {
t.Errorf("error decoding json: %v", err)
}
// Sort links
sort.Slice(fingerGot.Links, func(i, j int) bool {
return fingerGot.Links[i].Rel < fingerGot.Links[j].Rel
})
sort.Slice(fingerWant.Links, func(i, j int) bool {
return fingerWant.Links[i].Rel < fingerWant.Links[j].Rel
})
// Check the response body
if !reflect.DeepEqual(fingerGot, fingerWant) {
t.Errorf("expected body %v, got %v", fingerWant, fingerGot)
}
}
})
}
}

View File

@ -0,0 +1,160 @@
package webfinger
import (
"context"
"fmt"
"log/slog"
"net/mail"
"net/url"
"os"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"gopkg.in/yaml.v3"
)
type Link struct {
Rel string `json:"rel"`
Href string `json:"href,omitempty"`
}
type WebFinger struct {
Subject string `json:"subject"`
Links []Link `json:"links,omitempty"`
Properties map[string]string `json:"properties,omitempty"`
}
type WebFingers map[string]*WebFinger
type (
URNMap = map[string]string
RawFingersMap = map[string]map[string]string
)
type FingerReader struct {
URNSFile []byte
FingersFile []byte
}
func NewFingerReader() *FingerReader {
return &FingerReader{}
}
func (f *FingerReader) ReadFiles(cfg *config.Config) error {
// Read URNs file
file, err := os.ReadFile(cfg.URNPath)
if err != nil {
return fmt.Errorf("error opening URNs file: %w", err)
}
f.URNSFile = file
// Read fingers file
file, err = os.ReadFile(cfg.FingerPath)
if err != nil {
return fmt.Errorf("error opening fingers file: %w", err)
}
f.FingersFile = file
return nil
}
func (f *FingerReader) ParseFingers(ctx context.Context, urns URNMap, rawFingers RawFingersMap) (WebFingers, error) {
l := log.FromContext(ctx)
webfingers := make(WebFingers)
// Parse the webfinger file
for k, v := range rawFingers {
resource := k
// Remove leading acct: if present
if len(k) > 5 && resource[:5] == "acct:" {
resource = resource[5:]
}
// The key must be a URL or email address
if _, err := mail.ParseAddress(resource); err != nil {
if _, err := url.ParseRequestURI(resource); err != nil {
return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err)
}
} else {
// Add acct: back to the key if it is an email address
resource = fmt.Sprintf("acct:%s", resource)
}
// Create a new webfinger
webfinger := &WebFinger{
Subject: resource,
}
// Parse the fields
for field, value := range v {
fieldUrn := field
// If the key is present in the URNs file, use the value
if _, ok := urns[field]; ok {
fieldUrn = urns[field]
}
// If the value is a valid URI, add it to the links
if _, err := url.ParseRequestURI(value); err == nil {
webfinger.Links = append(webfinger.Links, Link{
Rel: fieldUrn,
Href: value,
})
} else {
// Otherwise add it to the properties
if webfinger.Properties == nil {
webfinger.Properties = make(map[string]string)
}
webfinger.Properties[fieldUrn] = value
}
}
// Add the webfinger to the map
webfingers[resource] = webfinger
}
l.Debug("Webfinger map built successfully", slog.Int("number", len(webfingers)), slog.Any("data", webfingers))
return webfingers, nil
}
func (f *FingerReader) ReadFingerFile(ctx context.Context) (WebFingers, error) {
l := log.FromContext(ctx)
urnMap := make(URNMap)
fingerData := make(RawFingersMap)
// Parse the URNs file
if err := yaml.Unmarshal(f.URNSFile, &urnMap); err != nil {
return nil, fmt.Errorf("error unmarshalling URNs file: %w", err)
}
// The URNs file must be a map of strings to valid URLs
for _, v := range urnMap {
if _, err := url.ParseRequestURI(v); err != nil {
return nil, fmt.Errorf("error parsing URN URIs: %w", err)
}
}
l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap))
// Parse the fingers file
if err := yaml.Unmarshal(f.FingersFile, &fingerData); err != nil {
return nil, fmt.Errorf("error unmarshalling fingers file: %w", err)
}
l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData))
// Parse raw data
webfingers, err := f.ParseFingers(ctx, urnMap, fingerData)
if err != nil {
return nil, fmt.Errorf("error parsing raw fingers: %w", err)
}
return webfingers, nil
}

View File

@ -0,0 +1,444 @@
package webfinger_test
import (
"context"
"encoding/json"
"os"
"reflect"
"sort"
"strings"
"testing"
"git.maronato.dev/maronato/finger/internal/config"
"git.maronato.dev/maronato/finger/internal/log"
"git.maronato.dev/maronato/finger/internal/webfinger"
)
func newTempFile(t *testing.T, content string) (name string, remove func()) {
t.Helper()
f, err := os.CreateTemp("", "finger-test")
if err != nil {
t.Fatalf("error creating temp file: %v", err)
}
_, err = f.WriteString(content)
if err != nil {
t.Fatalf("error writing to temp file: %v", err)
}
return f.Name(), func() {
err = os.Remove(f.Name())
if err != nil {
t.Fatalf("error removing temp file: %v", err)
}
}
}
func TestNewFingerReader(t *testing.T) {
t.Parallel()
f := webfinger.NewFingerReader()
if f == nil {
t.Errorf("NewFingerReader() = %v, want: %v", f, nil)
}
}
func TestFingerReader_ReadFiles(t *testing.T) {
t.Parallel()
tests := []struct {
name string
urnsContent string
fingersContent string
useURNFile bool
useFingerFile bool
wantErr bool
}{
{
name: "reads files",
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
fingersContent: "user@example.com:\n name: John Doe",
useURNFile: true,
useFingerFile: true,
wantErr: false,
},
{
name: "errors on missing URNs file",
urnsContent: "invalid",
fingersContent: "user@example.com:\n name: John Doe",
useURNFile: false,
useFingerFile: true,
wantErr: true,
},
{
name: "errors on missing fingers file",
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
fingersContent: "invalid",
useFingerFile: false,
useURNFile: true,
wantErr: true,
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := config.NewConfig()
urnsFileName, urnsCleanup := newTempFile(t, tc.urnsContent)
defer urnsCleanup()
fingersFileName, fingersCleanup := newTempFile(t, tc.fingersContent)
defer fingersCleanup()
if !tc.useURNFile {
cfg.URNPath = "invalid"
} else {
cfg.URNPath = urnsFileName
}
if !tc.useFingerFile {
cfg.FingerPath = "invalid"
} else {
cfg.FingerPath = fingersFileName
}
f := webfinger.NewFingerReader()
err := f.ReadFiles(cfg)
if err != nil {
if !tc.wantErr {
t.Errorf("ReadFiles() error = %v", err)
}
return
} else if tc.wantErr {
t.Errorf("ReadFiles() error = %v, wantErr %v", err, tc.wantErr)
}
if !reflect.DeepEqual(f.URNSFile, []byte(tc.urnsContent)) {
t.Errorf("ReadFiles() URNsFile = %v, want: %v", f.URNSFile, tc.urnsContent)
}
if !reflect.DeepEqual(f.FingersFile, []byte(tc.fingersContent)) {
t.Errorf("ReadFiles() FingersFile = %v, want: %v", f.FingersFile, tc.fingersContent)
}
})
}
}
func TestParseFingers(t *testing.T) {
t.Parallel()
tests := []struct {
name string
rawFingers webfinger.RawFingersMap
want webfinger.WebFingers
wantErr bool
}{
{
name: "parses links",
rawFingers: webfinger.RawFingersMap{
"user@example.com": {
"profile": "https://example.com/profile",
"invalidalias": "https://example.com/invalidalias",
"https://something": "https://somethingelse",
},
},
want: webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Links: []webfinger.Link{
{
Rel: "https://schema/profile",
Href: "https://example.com/profile",
},
{
Rel: "invalidalias",
Href: "https://example.com/invalidalias",
},
{
Rel: "https://something",
Href: "https://somethingelse",
},
},
},
},
wantErr: false,
},
{
name: "parses properties",
rawFingers: webfinger.RawFingersMap{
"user@example.com": {
"name": "John Doe",
"invalidalias": "value1",
"https://mylink": "value2",
},
},
want: webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Properties: map[string]string{
"https://schema/name": "John Doe",
"invalidalias": "value1",
"https://mylink": "value2",
},
},
},
wantErr: false,
},
{
name: "accepts acct: prefix",
rawFingers: webfinger.RawFingersMap{
"acct:user@example.com": {
"name": "John Doe",
},
},
want: webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Properties: map[string]string{
"https://schema/name": "John Doe",
},
},
},
wantErr: false,
},
{
name: "accepts urls as resource",
rawFingers: webfinger.RawFingersMap{
"https://example.com": {
"name": "John Doe",
},
},
want: webfinger.WebFingers{
"https://example.com": {
Subject: "https://example.com",
Properties: map[string]string{
"https://schema/name": "John Doe",
},
},
},
wantErr: false,
},
{
name: "accepts multiple resources",
rawFingers: webfinger.RawFingersMap{
"user@example.com": {
"name": "John Doe",
},
"other@example.com": {
"name": "Jane Doe",
},
},
want: webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Properties: map[string]string{
"https://schema/name": "John Doe",
},
},
"acct:other@example.com": {
Subject: "acct:other@example.com",
Properties: map[string]string{
"https://schema/name": "Jane Doe",
},
},
},
wantErr: false,
},
{
name: "errors on invalid resource",
rawFingers: webfinger.RawFingersMap{
"invalid": {
"name": "John Doe",
},
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Create a urn map
urns := webfinger.URNMap{
"name": "https://schema/name",
"profile": "https://schema/profile",
}
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
f := webfinger.NewFingerReader()
got, err := f.ParseFingers(ctx, urns, tc.rawFingers)
if (err != nil) != tc.wantErr {
t.Errorf("ParseFingers() error = %v, wantErr %v", err, tc.wantErr)
return
}
// Sort links to make it easier to compare
for _, v := range got {
for range v.Links {
sort.Slice(v.Links, func(i, j int) bool {
return v.Links[i].Rel < v.Links[j].Rel
})
}
}
for _, v := range tc.want {
for range v.Links {
sort.Slice(v.Links, func(i, j int) bool {
return v.Links[i].Rel < v.Links[j].Rel
})
}
}
if !reflect.DeepEqual(got, tc.want) {
// Unmarshal the structs to JSON to make it easier to print
gotstr := &strings.Builder{}
gotenc := json.NewEncoder(gotstr)
wantstr := &strings.Builder{}
wantenc := json.NewEncoder(wantstr)
_ = gotenc.Encode(got)
_ = wantenc.Encode(tc.want)
t.Errorf("ParseFingers() got = \n%s want: \n%s", gotstr.String(), wantstr.String())
}
})
}
}
func TestReadFingerFile(t *testing.T) {
t.Parallel()
tests := []struct {
name string
urnsContent string
fingersContent string
wantURN webfinger.URNMap
wantFinger webfinger.RawFingersMap
returns *webfinger.WebFingers
wantErr bool
}{
{
name: "reads files",
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
fingersContent: "user@example.com:\n name: John Doe",
wantURN: webfinger.URNMap{
"name": "https://schema/name",
"profile": "https://schema/profile",
},
wantFinger: webfinger.RawFingersMap{
"user@example.com": {
"name": "John Doe",
},
},
returns: &webfinger.WebFingers{
"acct:user@example.com": {
Subject: "acct:user@example.com",
Properties: map[string]string{
"https://schema/name": "John Doe",
},
},
},
wantErr: false,
},
{
name: "uses custom URNs",
urnsContent: "favorite_food: https://schema/favorite_food",
fingersContent: "user@example.com:\n favorite_food: Apple",
wantURN: webfinger.URNMap{
"favorite_food": "https://schema/favorite_food",
},
wantFinger: webfinger.RawFingersMap{
"user@example.com": {
"https://schema/favorite_food": "Apple",
},
},
wantErr: false,
},
{
name: "errors on invalid URNs file",
urnsContent: "invalid",
fingersContent: "user@example.com:\n name: John Doe",
wantURN: webfinger.URNMap{},
wantFinger: webfinger.RawFingersMap{},
wantErr: true,
},
{
name: "errors on invalid fingers file",
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
fingersContent: "invalid",
wantURN: webfinger.URNMap{},
wantFinger: webfinger.RawFingersMap{},
wantErr: true,
},
{
name: "errors on invalid URNs values",
urnsContent: "name: invalid",
fingersContent: "user@example.com:\n name: John Doe",
wantURN: webfinger.URNMap{},
wantFinger: webfinger.RawFingersMap{},
wantErr: true,
},
{
name: "errors on invalid fingers values",
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
fingersContent: "invalid:\n name: John Doe",
wantURN: webfinger.URNMap{},
wantFinger: webfinger.RawFingersMap{},
wantErr: true,
},
}
for _, tt := range tests {
tc := tt
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
cfg := config.NewConfig()
l := log.NewLogger(&strings.Builder{}, cfg)
ctx = log.WithLogger(ctx, l)
f := webfinger.NewFingerReader()
f.FingersFile = []byte(tc.fingersContent)
f.URNSFile = []byte(tc.urnsContent)
got, err := f.ReadFingerFile(ctx)
if err != nil {
if !tc.wantErr {
t.Errorf("ReadFingerFile() error = %v", err)
}
return
} else if tc.wantErr {
t.Errorf("ReadFingerFile() error = %v, wantErr %v", err, tc.wantErr)
}
if tc.returns != nil && !reflect.DeepEqual(got, *tc.returns) {
t.Errorf("ReadFingerFile() got = %v, want: %v", got, *tc.returns)
}
})
}
}

553
main.go
View File

@ -1,566 +1,19 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/mail"
"net/url"
"os"
"os/signal"
"syscall"
"time"
"github.com/peterbourgon/ff/v4"
"github.com/peterbourgon/ff/v4/ffhelp"
"golang.org/x/exp/slog"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v3"
"git.maronato.dev/maronato/finger/cmd"
)
const appName = "finger"
// Version of the application.
// Version of the app.
var version = "dev"
func main() {
// Run the server
if err := Run(); err != nil {
if err := cmd.Run(version); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
}
func Run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Allow graceful shutdown
trapSignalsCrossPlatform(cancel)
cfg := &Config{}
// Create a new root command
subcommands := []*ff.Command{
NewServerCmd(cfg),
NewHealthcheckCmd(cfg),
}
cmd := NewRootCmd(cfg, subcommands)
// Parse and run
if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil {
if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) {
fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd))
return nil
}
return fmt.Errorf("error running command: %w", err)
}
return nil
}
func NewServerCmd(cfg *Config) *ff.Command {
return &ff.Command{
Name: "serve",
Usage: "serve [flags]",
ShortHelp: "Start the webfinger server",
Exec: func(ctx context.Context, args []string) error {
// Create a logger and add it to the context
l := NewLogger(cfg)
ctx = WithLogger(ctx, l)
// Parse the webfinger files
fingermap, err := ParseFingerFile(ctx, cfg)
if err != nil {
return fmt.Errorf("error parsing finger files: %w", err)
}
l.Info(fmt.Sprintf("Loaded %d webfingers", len(fingermap)))
// Start the server
if err := StartServer(ctx, cfg, fingermap); err != nil {
return fmt.Errorf("error running server: %w", err)
}
return nil
},
}
}
func NewHealthcheckCmd(cfg *Config) *ff.Command {
return &ff.Command{
Name: "healthcheck",
Usage: "healthcheck [flags]",
ShortHelp: "Check if the server is running",
Exec: func(ctx context.Context, args []string) error {
// Create a new client
client := &http.Client{
Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant
}
// Create a new request
reqURL := url.URL{
Scheme: "http",
Host: net.JoinHostPort(cfg.Host, cfg.Port),
Path: "/healthz",
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody)
if err != nil {
return fmt.Errorf("error creating request: %w", err)
}
// Send the request
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("error sending request: %w", err)
}
defer resp.Body.Close()
// Check the response
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error
}
return nil
},
}
}
type loggerCtxKey struct{}
// NewLogger creates a new logger with the given debug level.
func NewLogger(cfg *Config) *slog.Logger {
level := slog.LevelInfo
addSource := false
if cfg.Debug {
level = slog.LevelDebug
addSource = true
}
return slog.New(
slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: addSource,
}),
)
}
func LoggerFromContext(ctx context.Context) *slog.Logger {
l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger)
if !ok {
panic("logger not found in context")
}
return l
}
func WithLogger(ctx context.Context, l *slog.Logger) context.Context {
return context.WithValue(ctx, loggerCtxKey{}, l)
}
// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37
// trapSignalsCrossPlatform captures SIGINT or interrupt (depending
// on the OS), which initiates a graceful shutdown. A second SIGINT
// or interrupt will forcefully exit the process immediately.
func trapSignalsCrossPlatform(cancel context.CancelFunc) {
go func() {
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGINT)
for i := 0; true; i++ {
<-shutdown
if i > 0 {
fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout
os.Exit(1)
}
fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout
cancel()
}
}()
}
type Config struct {
Debug bool
Host string
Port string
urnPath string
fingerPath string
}
// NewRootCmd parses the command line flags and returns a Config struct.
func NewRootCmd(cfg *Config, subcommands []*ff.Command) *ff.Command {
fs := ff.NewFlagSet(appName)
for _, cmd := range subcommands {
cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs)
}
cmd := &ff.Command{
Name: appName,
Usage: fmt.Sprintf("%s <command> [flags]", appName),
ShortHelp: fmt.Sprintf("(%s) A webfinger server", version),
Flags: fs,
Subcommands: subcommands,
}
// Use 0.0.0.0 as the default host if on docker
defaultHost := "localhost"
if os.Getenv("ENV_DOCKER") == "true" {
defaultHost = "0.0.0.0"
}
fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging")
fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on")
fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on")
fs.StringVar(&cfg.urnPath, 'u', "urn-file", "urns.yml", "Path to the URNs file")
fs.StringVar(&cfg.fingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file")
return cmd
}
type Link struct {
Rel string `json:"rel"`
Href string `json:"href,omitempty"`
}
type WebFinger struct {
Subject string `json:"subject"`
Links []Link `json:"links,omitempty"`
Properties map[string]string `json:"properties,omitempty"`
}
type WebFingerMap map[string]*WebFinger
func ParseFingerFile(ctx context.Context, cfg *Config) (WebFingerMap, error) {
l := LoggerFromContext(ctx)
urnMap := make(map[string]string)
fingerData := make(map[string]map[string]string)
fingermap := make(WebFingerMap)
// Read URNs file
file, err := os.ReadFile(cfg.urnPath)
if err != nil {
return nil, fmt.Errorf("error opening URNs file: %w", err)
}
if err := yaml.Unmarshal(file, &urnMap); err != nil {
return nil, fmt.Errorf("error unmarshalling URNs file: %w", err)
}
// The URNs file must be a map of strings to valid URLs
for _, v := range urnMap {
if _, err := url.Parse(v); err != nil {
return nil, fmt.Errorf("error parsing URN URIs: %w", err)
}
}
l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap))
// Read webfingers file
file, err = os.ReadFile(cfg.fingerPath)
if err != nil {
return nil, fmt.Errorf("error opening fingers file: %w", err)
}
if err := yaml.Unmarshal(file, &fingerData); err != nil {
return nil, fmt.Errorf("error unmarshalling fingers file: %w", err)
}
l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData))
// Parse the webfinger file
for k, v := range fingerData {
resource := k
// Remove leading acct: if present
if len(k) > 5 && resource[:5] == "acct:" {
resource = resource[5:]
}
// The key must be a URL or email address
if _, err := mail.ParseAddress(resource); err != nil {
if _, err := url.Parse(resource); err != nil {
return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err)
}
} else {
// Add acct: back to the key if it is an email address
resource = fmt.Sprintf("acct:%s", resource)
}
// Create a new webfinger
webfinger := &WebFinger{
Subject: resource,
}
// Parse the fields
for field, value := range v {
fieldUrn := field
// If the key is present in the URNs file, use the value
if _, ok := urnMap[field]; ok {
fieldUrn = urnMap[field]
}
// If the value is a valid URI, add it to the links
if _, err := url.Parse(value); err == nil {
webfinger.Links = append(webfinger.Links, Link{
Rel: fieldUrn,
Href: value,
})
} else {
// Otherwise add it to the properties
if webfinger.Properties == nil {
webfinger.Properties = make(map[string]string)
}
webfinger.Properties[fieldUrn] = value
}
}
// Add the webfinger to the map
fingermap[resource] = webfinger
}
l.Debug("Webfinger map built successfully", slog.Int("number", len(fingermap)), slog.Any("data", fingermap))
return fingermap, nil
}
func WebfingerHandler(_ *Config, webmap WebFingerMap) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := LoggerFromContext(ctx)
// Only handle GET requests
if r.Method != http.MethodGet {
l.Debug("Method not allowed")
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get the query params
q := r.URL.Query()
// Get the resource
resource := q.Get("resource")
if resource == "" {
l.Debug("No resource provided")
http.Error(w, "No resource provided", http.StatusBadRequest)
return
}
// Get and validate resource
webfinger, ok := webmap[resource]
if !ok {
l.Debug("Resource not found")
http.Error(w, "Resource not found", http.StatusNotFound)
return
}
// Set the content type
w.Header().Set("Content-Type", "application/jrd+json")
// Write the response
if err := json.NewEncoder(w).Encode(webfinger); err != nil {
l.Debug("Error encoding json")
http.Error(w, "Error encoding json", http.StatusInternalServerError)
return
}
l.Debug("Webfinger request successful")
})
}
func HealthCheckHandler(_ *Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
}
type ResponseWrapper struct {
http.ResponseWriter
status int
}
func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper {
return &ResponseWrapper{w, 0}
}
func (w *ResponseWrapper) WriteHeader(code int) {
w.status = code
w.ResponseWriter.WriteHeader(code)
}
func (w *ResponseWrapper) Status() int {
return w.status
}
func (w *ResponseWrapper) Write(b []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
size, err := w.ResponseWriter.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing response: %w", err)
}
return size, nil
}
func (w *ResponseWrapper) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := LoggerFromContext(ctx)
start := time.Now()
// Wrap the response writer
wrapped := WrapResponseWriter(w)
// Call the next handler
next.ServeHTTP(wrapped, r)
status := wrapped.Status()
// Log the request
lg := l.With(
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", status),
slog.String("remote", r.RemoteAddr),
slog.Duration("duration", time.Since(start)),
)
switch {
case status >= http.StatusInternalServerError:
lg.Error("Server error")
case status >= http.StatusBadRequest:
lg.Info("Client error")
default:
lg.Info("Request completed")
}
})
}
const (
// ReadTimeout is the maximum duration for reading the entire
// request, including the body.
ReadTimeout = 5 * time.Second
// WriteTimeout is the maximum duration before timing out
// writes of the response.
WriteTimeout = 10 * time.Second
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alives are enabled.
IdleTimeout = 30 * time.Second
// ReadHeaderTimeout is the amount of time allowed to read
// request headers.
ReadHeaderTimeout = 2 * time.Second
// RequestTimeout is the maximum duration for the entire
// request.
RequestTimeout = 7 * 24 * time.Hour
)
func StartServer(ctx context.Context, cfg *Config, webmap WebFingerMap) error {
l := LoggerFromContext(ctx)
// Create the server mux
mux := http.NewServeMux()
mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webmap))
mux.Handle("/healthz", HealthCheckHandler(cfg))
// Create a new server
srv := &http.Server{
Addr: net.JoinHostPort(cfg.Host, cfg.Port),
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
Handler: LoggingMiddleware(
RecoveryHandler(
http.TimeoutHandler(mux, RequestTimeout, "request timed out"),
),
),
ReadHeaderTimeout: ReadHeaderTimeout,
ReadTimeout: ReadTimeout,
WriteTimeout: WriteTimeout,
IdleTimeout: IdleTimeout,
}
// Create the errorgroup that will manage the server execution
eg, egCtx := errgroup.WithContext(ctx)
// Start the server
eg.Go(func() error {
l.Info("Starting server", slog.String("addr", srv.Addr))
// Use the global context for the server
srv.BaseContext = func(_ net.Listener) context.Context {
return egCtx
}
return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup
})
// Gracefully shutdown the server when the context is done
eg.Go(func() error {
// Wait for the context to be done
<-egCtx.Done()
l.Info("Shutting down server")
// Disable the cancel since we don't wan't to force
// the server to shutdown if the context is canceled.
noCancelCtx := context.WithoutCancel(egCtx)
return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup
})
srv.RegisterOnShutdown(func() {
l.Info("Server shutdown complete")
})
// Ignore the error if the context was canceled
if err := eg.Wait(); err != nil && ctx.Err() == nil {
return fmt.Errorf("server exited with error: %w", err)
}
return nil
}
func RecoveryHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := LoggerFromContext(ctx)
defer func() {
err := recover()
if err != nil {
l.Error("Panic", slog.Any("error", err))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}()
next.ServeHTTP(w, r)
})
}

View File

@ -1,60 +0,0 @@
package main_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
finger "git.maronato.dev/maronato/finger"
)
func BenchmarkGetWebfinger(b *testing.B) {
ctx := context.Background()
cfg := &finger.Config{}
l := finger.NewLogger(cfg)
ctx = finger.WithLogger(ctx, l)
resource := "acct:user@example.com"
webmap := finger.WebFingerMap{
resource: {
Subject: resource,
Links: []finger.Link{
{
Rel: "http://webfinger.net/rel/avatar",
Href: "https://example.com/avatar.png",
},
},
Properties: map[string]string{
"example": "value",
},
},
"acct:other": {
Subject: "acct:other",
Links: []finger.Link{
{
Rel: "http://webfinger.net/rel/avatar",
Href: "https://example.com/avatar.png",
},
},
Properties: map[string]string{
"example": "value",
},
},
}
handler := finger.WebfingerHandler(&finger.Config{}, webmap)
r, _ := http.NewRequestWithContext(
ctx,
http.MethodGet,
fmt.Sprintf("/.well-known/webfinger?resource=%s", resource),
http.NoBody,
)
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
}
}