package health

import (
	"context"
	"log/slog"
	"net/http"
	"strconv"
	"time"

	"go.ntppool.org/common/logger"
	"golang.org/x/sync/errgroup"
)

// Server is a small HTTP server that exposes a health check endpoint.
// Create one with NewServer and start it with Listen.
type Server struct {
	// log receives the listener's start, stop and failure messages.
	log      *slog.Logger
	// healthFn is the handler that Listen registers at /__health.
	healthFn http.HandlerFunc
}

// NewServer builds a Server that answers health probes with healthFn.
// Passing nil selects basicHealth as the handler. The server's logger is
// initialised from logger.Setup(); use SetLogger to replace it.
func NewServer(healthFn http.HandlerFunc) *Server {
	handler := healthFn
	if handler == nil {
		handler = basicHealth
	}

	return &Server{
		healthFn: handler,
		log:      logger.Setup(),
	}
}

// SetLogger replaces the logger used by the server.
//
// A nil logger is ignored and the current logger is kept: Listen calls
// methods on the logger unconditionally, and calling them on a nil
// *slog.Logger would panic.
func (srv *Server) SetLogger(log *slog.Logger) {
	if log == nil {
		return
	}
	srv.log = log
}

// Listen serves the health handler at /__health on the given TCP port and
// blocks until the server has stopped.
//
// The server stops when ctx is canceled or when listening fails. It is then
// shut down with a grace period of up to two seconds for active connections
// to finish.
//
// The returned error is whatever the errgroup reports from the serving and
// shutdown goroutines; http.ErrServerClosed from the listener is treated as
// a clean stop and reported as nil.
func (srv *Server) Listen(ctx context.Context, port int) error {
	srv.log.Info("Starting health listener", "port", port)

	serveMux := http.NewServeMux()

	serveMux.HandleFunc("/__health", srv.healthFn)

	hsrv := &http.Server{
		Addr:         ":" + strconv.Itoa(port),
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 20 * time.Second,
		IdleTimeout:  120 * time.Second,
		Handler:      serveMux,
	}

	// gctx is done when the caller cancels ctx or when the serving
	// goroutine below returns an error (e.g. the port is already in use).
	g, gctx := errgroup.WithContext(ctx)

	g.Go(func() error {
		err := hsrv.ListenAndServe()
		if err != http.ErrServerClosed {
			srv.log.Warn("health check server done listening", "err", err)
			return err
		}
		return nil
	})

	<-gctx.Done()

	// gctx is already canceled at this point, so the shutdown deadline must
	// not be derived from it: Shutdown would see a done context and give up
	// immediately instead of waiting for active connections to drain.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	g.Go(func() error {
		if err := hsrv.Shutdown(shutdownCtx); err != nil {
			srv.log.Error("health check server shutdown failed", "err", err)
			return err
		}
		return nil
	})

	return g.Wait()
}

// HealthCheckListener is a convenience wrapper that creates a Server with
// the default health handler, installs log as its logger and runs Listen
// on the specified port, returning whatever Listen returns.
func HealthCheckListener(ctx context.Context, port int, log *slog.Logger) error {
	server := NewServer(nil)
	server.SetLogger(log)

	return server.Listen(ctx, port)
}

// basicHealth is the default health handler. It answers every request with
// status 200 and the body "ok"; the request itself is not inspected.
func basicHealth(w http.ResponseWriter, r *http.Request) {
	const body = "ok"

	w.WriteHeader(http.StatusOK)
	// The write result is intentionally ignored: there is nothing useful to
	// do if the probe client has already gone away.
	_, _ = w.Write([]byte(body))
}