diff --git a/catalog.go b/catalog.go index 90c2c95..301eb42 100644 --- a/catalog.go +++ b/catalog.go @@ -5,6 +5,7 @@ import ( "fmt" "hash/fnv" "os" + "path/filepath" "sort" "strings" "time" @@ -29,12 +30,13 @@ func hashZoneName(zone string) string { h := fnv.New32a() h.Write([]byte(zone)) sum := h.Sum32() - b := make([]byte, 4) - b[0] = byte(sum >> 24) - b[1] = byte(sum >> 16) - b[2] = byte(sum >> 8) - b[3] = byte(sum) - return b32.EncodeToString(b) + b := [4]byte{ + byte(sum >> 24), + byte(sum >> 16), + byte(sum >> 8), + byte(sum), + } + return b32.EncodeToString(b[:]) } // generateCatalogZone builds the zone file content for a single catalog. @@ -76,20 +78,22 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria return sorted[i].Zone < sorted[j].Zone }) - // Check for hash collisions - hashToZone := make(map[string]string) + // Check for hash collisions and cache results + hashToZone := make(map[string]string, len(sorted)) + zoneHash := make(map[string]string, len(sorted)) for _, entry := range sorted { h := hashZoneName(entry.Zone) if existing, ok := hashToZone[h]; ok && existing != entry.Zone { - return "", fmt.Errorf("error: %s:%d: hash collision between %s and %s in catalog %q", + return "", fmt.Errorf("%s:%d: hash collision between %s and %s in catalog %q", entry.File, entry.Line, existing, entry.Zone, catName) } hashToZone[h] = entry.Zone + zoneHash[entry.Zone] = h } // Member records for _, entry := range sorted { - h := hashZoneName(entry.Zone) + h := zoneHash[entry.Zone] // PTR record ptrOwner := fmt.Sprintf("%s.zones.%s", h, origin) @@ -126,58 +130,48 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria return strings.Join(records, "\n") + "\n", nil } -// readExistingSerial reads an existing zone file and extracts the SOA serial. -// Returns 0, nil if the file doesn't exist. -func readExistingSerial(path string) (uint32, error) { - f, err := os.Open(path) +// readExisting reads an existing zone file and returns its content and SOA serial. +// Returns ("", 0, nil) if the file doesn't exist. +func readExisting(path string) (string, uint32, error) { + data, err := os.ReadFile(path) if os.IsNotExist(err) { - return 0, nil + return "", 0, nil } if err != nil { - return 0, fmt.Errorf("reading existing zone %s: %w", path, err) + return "", 0, fmt.Errorf("reading existing zone %s: %w", path, err) } - defer f.Close() + content := string(data) - zp := dns.NewZoneParser(f, "", path) + zp := dns.NewZoneParser(strings.NewReader(content), "", path) for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { if soa, ok := rr.(*dns.SOA); ok { - return soa.Serial, nil + return content, soa.Serial, nil } } if err := zp.Err(); err != nil { - return 0, fmt.Errorf("parsing existing zone %s: %w", path, err) + return "", 0, fmt.Errorf("parsing existing zone %s: %w", path, err) } - return 0, nil + return content, 0, nil } -// readExistingContent reads the full content of an existing zone file. -// Returns empty string if file doesn't exist. -func readExistingContent(path string) (string, error) { - data, err := os.ReadFile(path) - if os.IsNotExist(err) { - return "", nil - } - if err != nil { - return "", err - } - return string(data), nil +// dateBase returns the YYYYMMDD00 serial base for the given time. +func dateBase(now time.Time) uint32 { + return uint32(now.Year())*1000000 + + uint32(now.Month())*10000 + + uint32(now.Day())*100 } // defaultSerial returns a serial for today with sequence 01: YYYYMMDD01. func defaultSerial(now time.Time) uint32 { - return uint32(now.Year())*1000000 + - uint32(now.Month())*10000 + - uint32(now.Day())*100 + 1 + return dateBase(now) + 1 } // bumpSerial increments a serial. If same date, bumps the sequence number. // If different date, starts at YYYYMMDD01. // Returns error if sequence reaches 99 and needs another bump. func bumpSerial(old uint32, now time.Time) (uint32, error) { - todayBase := uint32(now.Year())*1000000 + - uint32(now.Month())*10000 + - uint32(now.Day())*100 + todayBase := dateBase(now) if old >= todayBase && old < todayBase+100 { // Same date, bump sequence @@ -196,10 +190,10 @@ func bumpSerial(old uint32, now time.Time) (uint32, error) { // Returns true if the file was written (changed), false if unchanged. func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir string, now time.Time) (bool, error) { catCfg := cfg.Catalogs[catName] - outputPath := outputDir + "/" + catCfg.Zone + "zone" + outputPath := filepath.Join(outputDir, catCfg.Zone+"zone") - // Read existing serial - oldSerial, err := readExistingSerial(outputPath) + // Read existing file (content + serial in one pass) + existing, oldSerial, err := readExisting(outputPath) if err != nil { return false, err } @@ -217,11 +211,6 @@ func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir } // Compare with existing file - existing, err := readExistingContent(outputPath) - if err != nil { - return false, err - } - if content == existing { return false, nil } diff --git a/catalog_test.go b/catalog_test.go index 72414e9..5daf8a1 100644 --- a/catalog_test.go +++ b/catalog_test.go @@ -1,6 +1,8 @@ package main import ( + "path/filepath" + "strings" "testing" ) @@ -159,12 +161,12 @@ func TestGenerateCatalogZone(t *testing.T) { aIdx := -1 bIdx := -1 for i, line := range lines { - if containsStr(line, "a.example.com.") { + if strings.Contains(line, "a.example.com.") { if aIdx == -1 { aIdx = i } } - if containsStr(line, "b.example.com.") { + if strings.Contains(line, "b.example.com.") { if bIdx == -1 { bIdx = i } @@ -180,7 +182,7 @@ func TestGenerateCatalogZone(t *testing.T) { // a.example.com should have a group TXT foundGroup := false for _, line := range lines { - if containsStr(line, "group.") && containsStr(line, "\"mygroup\"") { + if strings.Contains(line, "group.") && strings.Contains(line, "\"mygroup\"") { foundGroup = true } } @@ -211,7 +213,7 @@ func TestGenerateCatalogZoneCOO(t *testing.T) { foundCOO := false for _, line := range splitLines(content) { - if containsStr(line, "coo.") && containsStr(line, "old.example.com.") { + if strings.Contains(line, "coo.") && strings.Contains(line, "old.example.com.") { foundCOO = true } } @@ -220,69 +222,51 @@ func TestGenerateCatalogZoneCOO(t *testing.T) { } } -func TestReadExistingSerial(t *testing.T) { +func TestReadExisting(t *testing.T) { t.Run("file does not exist", func(t *testing.T) { - serial, err := readExistingSerial("/nonexistent/path.zone") + content, serial, err := readExisting("/nonexistent/path.zone") if err != nil { t.Fatalf("unexpected error: %v", err) } if serial != 0 { t.Errorf("serial = %d, want 0", serial) } + if content != "" { + t.Errorf("content = %q, want empty", content) + } }) t.Run("valid zone file", func(t *testing.T) { dir := t.TempDir() - path := dir + "/test.zone" - content := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n" - writeTestFile(t, dir, "test.zone", content) + zoneContent := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n" + writeTestFile(t, dir, "test.zone", zoneContent) - serial, err := readExistingSerial(path) + content, serial, err := readExisting(filepath.Join(dir, "test.zone")) if err != nil { t.Fatalf("unexpected error: %v", err) } if serial != 2026030205 { t.Errorf("serial = %d, want 2026030205", serial) } + if content != zoneContent { + t.Errorf("content mismatch: got %q", content) + } }) t.Run("zone file with no SOA", func(t *testing.T) { dir := t.TempDir() - path := dir + "/test.zone" - content := "example.com. 0 IN NS invalid.\n" - writeTestFile(t, dir, "test.zone", content) + zoneContent := "example.com. 0 IN NS invalid.\n" + writeTestFile(t, dir, "test.zone", zoneContent) - serial, err := readExistingSerial(path) + content, serial, err := readExisting(filepath.Join(dir, "test.zone")) if err != nil { t.Fatalf("unexpected error: %v", err) } if serial != 0 { t.Errorf("serial = %d, want 0 (no SOA found)", serial) } - }) -} - -func TestReadExistingContent(t *testing.T) { - t.Run("file does not exist", func(t *testing.T) { - content, err := readExistingContent("/nonexistent/path.zone") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if content != "" { - t.Errorf("content = %q, want empty", content) - } - }) - - t.Run("file exists", func(t *testing.T) { - dir := t.TempDir() - writeTestFile(t, dir, "test.zone", "hello\n") - - content, err := readExistingContent(dir + "/test.zone") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if content != "hello\n" { - t.Errorf("content = %q, want %q", content, "hello\n") + if content != zoneContent { + t.Errorf("content mismatch: got %q", content) } }) } diff --git a/helpers_test.go b/helpers_test.go index a5bcc6a..d9ff4e6 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -24,7 +24,3 @@ func assertContains(t *testing.T, s, substr string) { t.Errorf("expected %q to contain %q", s, substr) } } - -func containsStr(s, substr string) bool { - return strings.Contains(s, substr) -} diff --git a/input.go b/input.go index 8e5d7a4..650e5ec 100644 --- a/input.go +++ b/input.go @@ -54,7 +54,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) { // Split on whitespace first, then handle comma separation within tokens tokens := tokenize(line) if len(tokens) < 2 { - return ZoneEntry{}, fmt.Errorf("error: %s:%d: expected zone name followed by at least one catalog name", file, lineNum) + return ZoneEntry{}, fmt.Errorf("%s:%d: expected zone name followed by at least one catalog name", file, lineNum) } entry := ZoneEntry{ @@ -69,16 +69,16 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) { switch strings.ToLower(key) { case "group": if value == "" { - return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty group value", file, lineNum) + return ZoneEntry{}, fmt.Errorf("%s:%d: empty group value", file, lineNum) } entry.Group = value case "coo": if value == "" { - return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty coo value", file, lineNum) + return ZoneEntry{}, fmt.Errorf("%s:%d: empty coo value", file, lineNum) } entry.COO = normalizeFQDN(value) default: - return ZoneEntry{}, fmt.Errorf("error: %s:%d: unknown property %q", file, lineNum, key) + return ZoneEntry{}, fmt.Errorf("%s:%d: unknown property %q", file, lineNum, key) } } else { // Bare name = catalog assignment @@ -87,7 +87,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) { } if len(entry.Catalogs) == 0 { - return ZoneEntry{}, fmt.Errorf("error: %s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone) + return ZoneEntry{}, fmt.Errorf("%s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone) } return entry, nil @@ -95,10 +95,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) { // tokenize splits a line on whitespace and commas, stripping empty tokens. func tokenize(line string) []string { - // Replace commas with spaces, then split on whitespace - line = strings.ReplaceAll(line, ",", " ") - fields := strings.Fields(line) - return fields + return strings.Fields(strings.ReplaceAll(line, ",", " ")) } func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, error) { @@ -110,14 +107,14 @@ func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, erro for _, entry := range entries { for _, catName := range entry.Catalogs { if _, ok := cfg.Catalogs[catName]; !ok { - return nil, fmt.Errorf("error: %s:%d: unknown catalog %q", entry.File, entry.Line, catName) + return nil, fmt.Errorf("%s:%d: unknown catalog %q", entry.File, entry.Line, catName) } if seen[catName] == nil { seen[catName] = make(map[string]int) } if prevLine, dup := seen[catName][entry.Zone]; dup { - return nil, fmt.Errorf("error: %s:%d: zone %s already assigned to catalog %q at line %d", + return nil, fmt.Errorf("%s:%d: zone %s already assigned to catalog %q at line %d", entry.File, entry.Line, entry.Zone, catName, prevLine) } seen[catName][entry.Zone] = entry.Line diff --git a/input_test.go b/input_test.go index 5a46e3e..a0c3ca7 100644 --- a/input_test.go +++ b/input_test.go @@ -1,7 +1,6 @@ package main import ( - "os" "path/filepath" "testing" ) @@ -206,9 +205,7 @@ zone.example.org catalog1, catalog2 test.example.net catalog1, group=internal ` - if err := os.WriteFile(inputPath, []byte(content), 0o644); err != nil { - t.Fatal(err) - } + writeTestFile(t, dir, "zones.txt", content) members, err := parseInput(inputPath, cfg) if err != nil { @@ -240,9 +237,7 @@ func TestParseInputErrors(t *testing.T) { t.Run("invalid line in input", func(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "zones.txt") - if err := os.WriteFile(path, []byte("zone-with-no-catalog\n"), 0o644); err != nil { - t.Fatal(err) - } + writeTestFile(t, dir, "zones.txt", "zone-with-no-catalog\n") _, err := parseInput(path, cfg) if err == nil { t.Fatal("expected error for invalid line") diff --git a/main.go b/main.go index 44376be..f279a92 100644 --- a/main.go +++ b/main.go @@ -41,7 +41,7 @@ func main() { members, err := parseInput(inputFile, cfg) if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err) + fmt.Fprintf(os.Stderr, "error: %s\n", err) os.Exit(1) } @@ -58,15 +58,15 @@ func main() { for _, catName := range catNames { changed, err := processCatalog(catName, cfg, members[catName], *outputDir, now) if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err) + fmt.Fprintf(os.Stderr, "error: %s\n", err) hasErrors = true continue } - catZone := cfg.Catalogs[catName].Zone + zoneFile := cfg.Catalogs[catName].Zone + "zone" if changed { - fmt.Fprintf(os.Stderr, "%s%s: updated\n", catZone, "zone") + fmt.Fprintf(os.Stderr, "%s: updated\n", zoneFile) } else { - fmt.Fprintf(os.Stderr, "%s%s: unchanged\n", catZone, "zone") + fmt.Fprintf(os.Stderr, "%s: unchanged\n", zoneFile) } } diff --git a/main_test.go b/main_test.go index 074741c..82ff143 100644 --- a/main_test.go +++ b/main_test.go @@ -6,7 +6,6 @@ import ( "path/filepath" "strings" "testing" - "time" ) func TestIntegrationEndToEnd(t *testing.T) { @@ -213,7 +212,7 @@ soa: // Find PTR lines and verify order var ptrZones []string for _, line := range lines { - if containsStr(line, "\tPTR\t") && !containsStr(line, "coo.") { + if strings.Contains(line, "\tPTR\t") && !strings.Contains(line, "coo.") { // Extract the PTR target parts := strings.Split(line, "\t") ptrZones = append(ptrZones, parts[len(parts)-1]) @@ -313,9 +312,7 @@ soa: if err == nil { t.Error("expected error for unknown catalog") } - if !containsStr(string(out), "unknown catalog") { - t.Errorf("expected 'unknown catalog' in error output, got: %s", out) - } + assertContains(t, string(out), "unknown catalog") }) } @@ -405,7 +402,7 @@ soa: assertContains(t, content2, "2026011601") // Simulate not using the tool for a while, running on a much later date - day3 := time.Date(2026, 6, 15, 0, 0, 0, 0, time.UTC) + day3 := fixedTime(2026, 6, 15) members3 := []ZoneEntry{ {Zone: "a.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 1}, {Zone: "b.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 2},