109 lines
2.8 KiB
Go
109 lines
2.8 KiB
Go
package server
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"

	"git.maronato.dev/maronato/goshort/internal/config"
	servermiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware"
	sessionmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/session"
	tracingmiddleware "git.maronato.dev/maronato/goshort/internal/server/middleware/tracing"
	"git.maronato.dev/maronato/goshort/internal/util/logging"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"golang.org/x/sync/errgroup"
)
|
|
|
|
const (
	// compressionRatio is passed as the first argument to chi's
	// middleware.Compress when the router is built; chi treats that argument
	// as the compression level.
	compressionRatio = 5
)
|
|
|
|
type Server struct {
|
|
srv *http.Server
|
|
Mux *chi.Mux
|
|
}
|
|
|
|
func NewServer(cfg *config.Config) *Server {
|
|
// Parse the address
|
|
addr := net.JoinHostPort(cfg.Host, cfg.Port)
|
|
|
|
// Create the mux
|
|
mux := chi.NewRouter()
|
|
|
|
// Register default middlewares
|
|
mux.Use(middleware.RequestID)
|
|
mux.Use(middleware.RealIP)
|
|
mux.Use(middleware.StripSlashes)
|
|
|
|
// Register tracing middleware
|
|
mux.Use(tracingmiddleware.Tracer())
|
|
|
|
// Register logging middleware
|
|
requestLogger := servermiddleware.NewLogFormatter(cfg)
|
|
mux.Use(middleware.RequestLogger(requestLogger))
|
|
|
|
// Register secondary middlewares
|
|
mux.Use(middleware.Recoverer)
|
|
mux.Use(sessionmiddleware.SessionManager(cfg))
|
|
mux.Use(middleware.Timeout(config.RequestTimeout))
|
|
mux.Use(middleware.Compress(compressionRatio, "application/json"))
|
|
|
|
// Create the server
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: mux,
|
|
ReadHeaderTimeout: config.ReadHeaderTimeout,
|
|
ReadTimeout: config.ReadTimeout,
|
|
WriteTimeout: config.WriteTimeout,
|
|
IdleTimeout: config.IdleTimeout,
|
|
}
|
|
|
|
return &Server{
|
|
srv: srv,
|
|
Mux: mux,
|
|
}
|
|
}
|
|
|
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
|
l := logging.FromCtx(ctx)
|
|
|
|
// Create the errorgroup that will manage the server execution
|
|
eg, egCtx := errgroup.WithContext(ctx)
|
|
|
|
// Start the server
|
|
eg.Go(func() error {
|
|
l.Info("Starting server", slog.String("addr", s.srv.Addr))
|
|
|
|
// Use the global context for the server
|
|
s.srv.BaseContext = func(_ net.Listener) context.Context {
|
|
return egCtx
|
|
}
|
|
|
|
return s.srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup
|
|
})
|
|
// Gracefully shutdown the server when the context is done
|
|
eg.Go(func() error {
|
|
// Wait for the context to be done
|
|
<-egCtx.Done()
|
|
|
|
l.Info("Shutting down server")
|
|
// Disable the cancel since we don't wan't to force
|
|
// the server to shutdown if the context is canceled.
|
|
noCancelCtx := context.WithoutCancel(egCtx)
|
|
|
|
return s.srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup
|
|
})
|
|
|
|
s.srv.RegisterOnShutdown(func() {
|
|
l.Info("Server shutdown complete")
|
|
})
|
|
|
|
// Ignore the error if the context was canceled
|
|
if err := eg.Wait(); err != nil && ctx.Err() == nil {
|
|
return fmt.Errorf("server exited with error: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|