common/ekko/ekko.go

159 lines
3.6 KiB
Go

package ekko
import (
"context"
"fmt"
"net"
"net/http"
"time"
"github.com/labstack/echo-contrib/echoprometheus"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
slogecho "github.com/samber/slog-echo"
"go.ntppool.org/common/logger"
"go.ntppool.org/common/version"
"go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)
// New returns an Ekko for the service identified by name.
//
// The server defaults to a 60s write timeout and a 30s read-header timeout.
// Each option is then applied in order, so options can override those
// defaults. The name is what setup uses for the OpenTelemetry middleware,
// the Prometheus subsystem and the Server response header. The returned
// error is currently always nil.
func New(name string, options ...func(*Ekko)) (*Ekko, error) {
	ek := &Ekko{
		name:              name,
		writeTimeout:      60 * time.Second,
		readHeaderTimeout: 30 * time.Second,
	}
	for _, o := range options {
		o(ek)
	}
	return ek, nil
}
// SetupEcho builds and configures the Echo instance (server timeouts,
// middleware and routes) exactly as Start would, but does not start the
// HTTP server. It is only intended for use in tests.
func (ek *Ekko) SetupEcho(ctx context.Context) (*echo.Echo, error) {
	e, err := ek.setup(ctx)
	return e, err
}
// Start builds the Echo instance and serves HTTP on ek.port, blocking until
// the server stops.
//
// Two goroutines share an errgroup:
//   - one listens and serves, treating http.ErrServerClosed as a clean exit;
//   - one waits for the group context to end (ctx cancelled, or the server
//     goroutine failing) and then shuts the server down gracefully, giving
//     it at most five seconds.
//
// Start returns the first non-nil error from either goroutine, e.g. a
// listen failure or a shutdown that exceeded its deadline.
func (ek *Ekko) Start(ctx context.Context) error {
	log := logger.Setup()

	e, err := ek.setup(ctx)
	if err != nil {
		return err
	}

	group, groupCtx := errgroup.WithContext(ctx)

	group.Go(func() error {
		srv := e.Server
		srv.Addr = fmt.Sprintf(":%d", ek.port)
		log.Info("server starting", "port", ek.port)
		if serveErr := srv.ListenAndServe(); serveErr != http.ErrServerClosed {
			return serveErr
		}
		return nil
	})

	group.Go(func() error {
		<-groupCtx.Done()
		// Use a fresh context: groupCtx is already done at this point.
		stopCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
		defer stop()
		return e.Shutdown(stopCtx)
	})

	return group.Wait()
}
// setup constructs the Echo instance shared by Start and SetupEcho.
//
// It configures the embedded http.Server (read-header and write timeouts,
// plus a BaseContext so every connection's context derives from ctx), the
// client-IP extractor and the middleware chain. It then registers routes via
// ek.routeFn when one is set. Middleware is registered in this order:
// OpenTelemetry, panic recovery, request logging, optional Prometheus
// metrics, gzip, security headers, span enrichment, and the Server header.
//
// An error from ek.routeFn is returned wrapped, with a nil *echo.Echo.
func (ek *Ekko) setup(ctx context.Context) (*echo.Echo, error) {
	log := logger.Setup()

	e := echo.New()
	e.Server.ReadHeaderTimeout = ek.readHeaderTimeout
	e.Server.WriteTimeout = ek.writeTimeout
	e.Server.BaseContext = func(_ net.Listener) context.Context {
		return ctx
	}

	// Only trust X-Forwarded-For hops added by loopback or private-network
	// peers; link-local peers are not trusted.
	trustOptions := []echo.TrustOption{
		echo.TrustLoopback(true),
		echo.TrustLinkLocal(false),
		echo.TrustPrivateNet(true),
	}
	e.IPExtractor = echo.ExtractIPFromXFFHeader(trustOptions...)

	// Tracing is registered first so later middleware (including the span
	// enrichment below) sees the request span in the request context.
	if ek.otelmiddleware == nil {
		e.Use(otelecho.Middleware(ek.name))
	} else {
		e.Use(ek.otelmiddleware)
	}

	e.Use(middleware.RecoverWithConfig(middleware.RecoverConfig{
		LogErrorFunc: func(c echo.Context, err error, stack []byte) error {
			log.ErrorContext(c.Request().Context(), err.Error(), "stack", string(stack))
			// NOTE(review): the stack is also printed raw to stdout, presumably
			// so it stays readable outside the structured log; confirm this
			// duplicate output is still wanted.
			fmt.Println(string(stack))
			// Returning err hands the recovered panic to Echo's error handler.
			return err
		},
	}))

	e.Use(slogecho.NewWithConfig(log,
		slogecho.Config{
			WithTraceID: false, // done by logger already
			Filters:     ek.logFilters,
		},
	))

	if ek.prom != nil {
		e.Use(echoprometheus.NewMiddlewareWithConfig(echoprometheus.MiddlewareConfig{
			Subsystem:  ek.name,
			Registerer: ek.prom,
		}))
	}

	e.Use(middleware.Gzip())
	e.Use(middleware.Secure())

	// Enrich the active span with client/URL details and return the trace ID
	// to the client.
	e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			request := c.Request()
			span := trace.SpanFromContext(request.Context())
			if span.IsRecording() {
				attrs := []attribute.KeyValue{
					attribute.String("http.real_ip", c.RealIP()),
					// url.path is the path component only. RequestURI would
					// also carry the query string, which is recorded
					// separately as url.query.
					attribute.String("url.path", request.URL.Path),
				}
				if q := c.QueryString(); len(q) > 0 {
					attrs = append(attrs, attribute.String("url.query", q))
				}
				span.SetAttributes(attrs...)
				// NOTE(review): this is only the trace ID, not a W3C
				// traceparent value ("version-traceid-spanid-flags"); kept
				// as-is since clients may already depend on the bare ID.
				c.Response().Header().Set("Traceparent", span.SpanContext().TraceID().String())
			}
			return next(c)
		}
	})

	// The Server header value is computed once, when the middleware is built.
	e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
		vinfo := version.VersionInfo()
		v := ek.name + "/" + vinfo.Version + "+" + vinfo.GitRevShort
		return func(c echo.Context) error {
			c.Response().Header().Set(echo.HeaderServer, v)
			return next(c)
		}
	})

	if ek.routeFn != nil {
		if err := ek.routeFn(e); err != nil {
			return nil, fmt.Errorf("setting up routes: %w", err)
		}
	}
	return e, nil
}