diff --git a/cmd/servercmd.go b/cmd/servercmd.go index e2be3a9..26cc480 100644 --- a/cmd/servercmd.go +++ b/cmd/servercmd.go @@ -2,10 +2,16 @@ package cmd import ( "context" + "errors" "fmt" - "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" "github.com/spf13/cobra" + "go.ntppool.org/common/logger" "go.ntppool.org/data-api/server" "golang.org/x/sync/errgroup" ) @@ -25,25 +31,36 @@ func (cli *CLI) serverCmd() *cobra.Command { } func (cli *CLI) serverCLI(cmd *cobra.Command, args []string) error { + log := logger.Setup() // cfg := cli.Config - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + g, ctx := errgroup.WithContext(ctx) + srv, err := server.NewServer(ctx, cfgFile) + if err != nil { + return fmt.Errorf("srv setup: %s", err) + } + g.Go(func() error { - srv, err := server.NewServer(ctx, cfgFile) - if err != nil { - return fmt.Errorf("srv setup: %s", err) - } return srv.Run() }) - err := g.Wait() - if err != nil { - log.Printf("server error: %s", err) + g.Go(func() error { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return srv.Shutdown(shutdownCtx) + }) + + err = g.Wait() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Error("server error", "err", err) } - cancel() - return err - + // don't tell cobra something went wrong as it'll just + // print usage information + return nil } diff --git a/server/server.go b/server/server.go index 780121e..02a9165 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "database/sql" + "errors" "fmt" "log/slog" "net/http" @@ -14,7 +15,6 @@ import ( slogecho "github.com/samber/slog-echo" "go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho" - otrace "go.opentelemetry.io/otel/trace" "go.ntppool.org/common/health" "go.ntppool.org/common/logger" @@ -31,8 +31,8 @@ type Server struct { ctx context.Context - metrics *metricsserver.Metrics - tracer otrace.Tracer + metrics *metricsserver.Metrics + tpShutdown []tracing.TpShutdownFunc } func NewServer(ctx context.Context, configFile string) (*Server, error) { @@ -52,7 +52,7 @@ func NewServer(ctx context.Context, configFile string) (*Server, error) { metrics: metricsserver.New(), } - err = tracing.InitTracer(ctx, &tracing.TracerConfig{ + tpShutdown, err := tracing.InitTracer(ctx, &tracing.TracerConfig{ ServiceName: "data-api", Environment: "", }) @@ -60,7 +60,8 @@ func NewServer(ctx context.Context, configFile string) (*Server, error) { return nil, err } - srv.tracer = tracing.Tracer() + srv.tpShutdown = append(srv.tpShutdown, tpShutdown) + // srv.tracer = tracing.Tracer() return srv, nil } @@ -84,6 +85,8 @@ func (srv *Server) Run() error { e.Use(otelecho.Middleware("data-api")) e.Use(slogecho.New(log)) + srv.tpShutdown = append(srv.tpShutdown, e.Shutdown) + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{ "http://localhost", "http://localhost:5173", "http://localhost:8080", @@ -114,6 +117,18 @@ func (srv *Server) Run() error { return g.Wait() } +func (srv *Server) Shutdown(ctx context.Context) error { + logger.Setup().Info("Shutting down") + errs := []error{} + for _, fn := range srv.tpShutdown { + err := fn(ctx) + if err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + func (srv *Server) userCountryData(c echo.Context) error { ctx := c.Request().Context()