package logger

import (
	"bytes"
	"context"
	"log/slog"
	"slices"
	"strings"
	"sync"
)

// logfmt is a slog.Handler that renders every record (message and
// attributes, without the top-level time and level) in slog's key=value text
// format and forwards the record to next with that rendering as its message.
//
// The text handler writes into a single scratch buffer. Handlers derived via
// WithAttrs/WithGroup keep writing into the buffer of the handler they were
// derived from, because a TextHandler derived with WithAttrs/WithGroup keeps
// its original writer. For that reason buf, and the mutex guarding it, are
// shared by pointer across the whole family of derived handlers.
type logfmt struct {
	buf  *bytes.Buffer // scratch output of txt; shared with derived handlers
	mu   *sync.Mutex   // serializes use of buf; shared with derived handlers
	txt  slog.Handler  // renders records into buf
	next slog.Handler  // receives records carrying the rendered message
}

// newLogFmtHandler returns a handler that renders each record's message and
// attributes as key=value text and hands the result to next as the record's
// message. Level filtering is delegated entirely to next.
func newLogFmtHandler(next slog.Handler) slog.Handler {
	buf := new(bytes.Buffer)
	return &logfmt{
		buf:  buf,
		mu:   new(sync.Mutex),
		next: next,
		txt: slog.NewTextHandler(buf, &slog.HandlerOptions{
			// Drop the top-level time and level from the rendered text
			// (presumably because next emits its own).
			ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
				if len(groups) == 0 && (a.Key == slog.TimeKey || a.Key == slog.LevelKey) {
					return slog.Attr{}
				}
				return a
			},
		}),
	}
}

// Enabled reports whether next handles records at lvl; this handler adds no
// filtering of its own.
func (h *logfmt) Enabled(ctx context.Context, lvl slog.Level) bool {
	return h.next.Enabled(ctx, lvl)
}

// derive returns a handler using the given next and txt handlers. It shares
// buf and mu with h because txt, derived from h.txt, still writes into h's
// buffer.
func (h *logfmt) derive(next, txt slog.Handler) *logfmt {
	return &logfmt{buf: h.buf, mu: h.mu, next: next, txt: txt}
}

// WithAttrs returns a handler that adds attrs both to the rendered message
// and to the downstream handler. Each receives its own copy of attrs, since
// slog handlers may retain or modify the slice they are given.
func (h *logfmt) WithAttrs(attrs []slog.Attr) slog.Handler {
	return h.derive(
		h.next.WithAttrs(slices.Clone(attrs)),
		h.txt.WithAttrs(slices.Clone(attrs)),
	)
}

// WithGroup returns a handler that opens group g in both the rendered
// message and the downstream handler. An empty group name is a no-op.
func (h *logfmt) WithGroup(g string) slog.Handler {
	if g == "" {
		return h
	}
	return h.derive(h.next.WithGroup(g), h.txt.WithGroup(g))
}

// Handle renders r through txt, replaces r's message with that rendering and
// passes r on to next. Only the rendering runs under the shared lock. next is
// called after the lock is released, so a slow downstream handler does not
// serialize the other handlers sharing the buffer. A downstream handler that
// logs through this same logger also cannot deadlock on the non-reentrant
// mutex.
func (h *logfmt) Handle(ctx context.Context, r slog.Record) error {
	msg, err := h.render(ctx, r)
	if err != nil {
		return err
	}
	r.Message = msg
	return h.next.Handle(ctx, r)
}

// render formats r with txt and returns the text without its trailing
// newline. The shared buffer is always left empty for the next caller, even
// when txt fails.
func (h *logfmt) render(ctx context.Context, r slog.Record) (string, error) {
	h.mu.Lock()
	defer h.mu.Unlock()
	defer h.buf.Reset() // runs before Unlock (LIFO)

	if err := h.txt.Handle(ctx, r); err != nil {
		return "", err
	}
	// String copies the bytes, so the result stays valid after Reset.
	return strings.TrimSuffix(h.buf.String(), "\n"), nil
}