diff --git a/health/health_server.go b/health/health_server.go index e6bd8ec..8377ead 100644 --- a/health/health_server.go +++ b/health/health_server.go @@ -6,20 +6,39 @@ import ( "strconv" "time" + "go.ntppool.org/common/logger" "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" ) -// HealthCheckListener runs simple http server on the specified port for -// health check probes -func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error { - log.Info("Starting health listener", "port", port) +type Server struct { + log *slog.Logger + healthFn http.HandlerFunc +} + +func NewServer(healthFn http.HandlerFunc) *Server { + if healthFn == nil { + healthFn = basicHealth + } + srv := &Server{ + log: logger.Setup(), + healthFn: healthFn, + } + return srv +} + +func (srv *Server) SetLogger(log *slog.Logger) { + srv.log = log +} + +func (srv *Server) Listen(ctx context.Context, port int) error { + srv.log.Info("Starting health listener", "port", port) serveMux := http.NewServeMux() - serveMux.HandleFunc("/__health", basicHealth) + serveMux.HandleFunc("/__health", srv.healthFn) - srv := &http.Server{ + hsrv := &http.Server{ Addr: ":" + strconv.Itoa(port), ReadTimeout: 10 * time.Second, WriteTimeout: 20 * time.Second, @@ -30,9 +49,9 @@ func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - err := srv.ListenAndServe() + err := hsrv.ListenAndServe() if err != http.ErrServerClosed { - log.Warn("health check server done listening", "err", err) + srv.log.Warn("health check server done listening", "err", err) return err } return nil @@ -44,8 +63,8 @@ func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error defer cancel() g.Go(func() error { - if err := srv.Shutdown(ctx); err != nil { - log.Error("health check server shutdown failed", "err", err) + if err := hsrv.Shutdown(ctx); err != nil { + srv.log.Error("health check server shutdown failed", "err", err) return err } return nil @@ -54,6 +73,14 @@ func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error return g.Wait() } +// HealthCheckListener runs simple http server on the specified port for +// health check probes +func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error { + srv := NewServer(nil) + srv.SetLogger(log) + return srv.Listen(ctx, port) +} + func basicHealth(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte("ok"))