Clean up code reuse, consistency, and efficiency issues

Merge readExistingSerial and readExistingContent into a single
readExisting function to eliminate duplicate file I/O. Extract dateBase
helper to deduplicate serial formula between defaultSerial and
bumpSerial. Cache hash results during collision check to avoid
recomputing per member. Normalize error prefixes (remove "error:" from
fmt.Errorf, add uniformly at print sites). Use filepath.Join instead of
manual "/" concatenation. Replace trivial containsStr wrapper with
strings.Contains. Simplify tokenize to a single return. Use writeTestFile
and fixedTime helpers consistently in tests.
This commit is contained in:
2026-03-01 17:38:26 -08:00
parent 1f2f39f40c
commit 0a460b975d
7 changed files with 75 additions and 117 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -29,12 +30,13 @@ func hashZoneName(zone string) string {
h := fnv.New32a() h := fnv.New32a()
h.Write([]byte(zone)) h.Write([]byte(zone))
sum := h.Sum32() sum := h.Sum32()
b := make([]byte, 4) b := [4]byte{
b[0] = byte(sum >> 24) byte(sum >> 24),
b[1] = byte(sum >> 16) byte(sum >> 16),
b[2] = byte(sum >> 8) byte(sum >> 8),
b[3] = byte(sum) byte(sum),
return b32.EncodeToString(b) }
return b32.EncodeToString(b[:])
} }
// generateCatalogZone builds the zone file content for a single catalog. // generateCatalogZone builds the zone file content for a single catalog.
@@ -76,20 +78,22 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria
return sorted[i].Zone < sorted[j].Zone return sorted[i].Zone < sorted[j].Zone
}) })
// Check for hash collisions // Check for hash collisions and cache results
hashToZone := make(map[string]string) hashToZone := make(map[string]string, len(sorted))
zoneHash := make(map[string]string, len(sorted))
for _, entry := range sorted { for _, entry := range sorted {
h := hashZoneName(entry.Zone) h := hashZoneName(entry.Zone)
if existing, ok := hashToZone[h]; ok && existing != entry.Zone { if existing, ok := hashToZone[h]; ok && existing != entry.Zone {
return "", fmt.Errorf("error: %s:%d: hash collision between %s and %s in catalog %q", return "", fmt.Errorf("%s:%d: hash collision between %s and %s in catalog %q",
entry.File, entry.Line, existing, entry.Zone, catName) entry.File, entry.Line, existing, entry.Zone, catName)
} }
hashToZone[h] = entry.Zone hashToZone[h] = entry.Zone
zoneHash[entry.Zone] = h
} }
// Member records // Member records
for _, entry := range sorted { for _, entry := range sorted {
h := hashZoneName(entry.Zone) h := zoneHash[entry.Zone]
// PTR record // PTR record
ptrOwner := fmt.Sprintf("%s.zones.%s", h, origin) ptrOwner := fmt.Sprintf("%s.zones.%s", h, origin)
@@ -126,58 +130,48 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria
return strings.Join(records, "\n") + "\n", nil return strings.Join(records, "\n") + "\n", nil
} }
// readExistingSerial reads an existing zone file and extracts the SOA serial. // readExisting reads an existing zone file and returns its content and SOA serial.
// Returns 0, nil if the file doesn't exist. // Returns ("", 0, nil) if the file doesn't exist.
func readExistingSerial(path string) (uint32, error) { func readExisting(path string) (string, uint32, error) {
f, err := os.Open(path) data, err := os.ReadFile(path)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return 0, nil return "", 0, nil
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("reading existing zone %s: %w", path, err) return "", 0, fmt.Errorf("reading existing zone %s: %w", path, err)
} }
defer f.Close() content := string(data)
zp := dns.NewZoneParser(f, "", path) zp := dns.NewZoneParser(strings.NewReader(content), "", path)
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
if soa, ok := rr.(*dns.SOA); ok { if soa, ok := rr.(*dns.SOA); ok {
return soa.Serial, nil return content, soa.Serial, nil
} }
} }
if err := zp.Err(); err != nil { if err := zp.Err(); err != nil {
return 0, fmt.Errorf("parsing existing zone %s: %w", path, err) return "", 0, fmt.Errorf("parsing existing zone %s: %w", path, err)
} }
return 0, nil return content, 0, nil
} }
// readExistingContent reads the full content of an existing zone file. // dateBase returns the YYYYMMDD00 serial base for the given time.
// Returns empty string if file doesn't exist. func dateBase(now time.Time) uint32 {
func readExistingContent(path string) (string, error) { return uint32(now.Year())*1000000 +
data, err := os.ReadFile(path) uint32(now.Month())*10000 +
if os.IsNotExist(err) { uint32(now.Day())*100
return "", nil
}
if err != nil {
return "", err
}
return string(data), nil
} }
// defaultSerial returns a serial for today with sequence 01: YYYYMMDD01. // defaultSerial returns a serial for today with sequence 01: YYYYMMDD01.
func defaultSerial(now time.Time) uint32 { func defaultSerial(now time.Time) uint32 {
return uint32(now.Year())*1000000 + return dateBase(now) + 1
uint32(now.Month())*10000 +
uint32(now.Day())*100 + 1
} }
// bumpSerial increments a serial. If same date, bumps the sequence number. // bumpSerial increments a serial. If same date, bumps the sequence number.
// If different date, starts at YYYYMMDD01. // If different date, starts at YYYYMMDD01.
// Returns error if sequence reaches 99 and needs another bump. // Returns error if sequence reaches 99 and needs another bump.
func bumpSerial(old uint32, now time.Time) (uint32, error) { func bumpSerial(old uint32, now time.Time) (uint32, error) {
todayBase := uint32(now.Year())*1000000 + todayBase := dateBase(now)
uint32(now.Month())*10000 +
uint32(now.Day())*100
if old >= todayBase && old < todayBase+100 { if old >= todayBase && old < todayBase+100 {
// Same date, bump sequence // Same date, bump sequence
@@ -196,10 +190,10 @@ func bumpSerial(old uint32, now time.Time) (uint32, error) {
// Returns true if the file was written (changed), false if unchanged. // Returns true if the file was written (changed), false if unchanged.
func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir string, now time.Time) (bool, error) { func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir string, now time.Time) (bool, error) {
catCfg := cfg.Catalogs[catName] catCfg := cfg.Catalogs[catName]
outputPath := outputDir + "/" + catCfg.Zone + "zone" outputPath := filepath.Join(outputDir, catCfg.Zone+"zone")
// Read existing serial // Read existing file (content + serial in one pass)
oldSerial, err := readExistingSerial(outputPath) existing, oldSerial, err := readExisting(outputPath)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -217,11 +211,6 @@ func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir
} }
// Compare with existing file // Compare with existing file
existing, err := readExistingContent(outputPath)
if err != nil {
return false, err
}
if content == existing { if content == existing {
return false, nil return false, nil
} }

View File

@@ -1,6 +1,8 @@
package main package main
import ( import (
"path/filepath"
"strings"
"testing" "testing"
) )
@@ -159,12 +161,12 @@ func TestGenerateCatalogZone(t *testing.T) {
aIdx := -1 aIdx := -1
bIdx := -1 bIdx := -1
for i, line := range lines { for i, line := range lines {
if containsStr(line, "a.example.com.") { if strings.Contains(line, "a.example.com.") {
if aIdx == -1 { if aIdx == -1 {
aIdx = i aIdx = i
} }
} }
if containsStr(line, "b.example.com.") { if strings.Contains(line, "b.example.com.") {
if bIdx == -1 { if bIdx == -1 {
bIdx = i bIdx = i
} }
@@ -180,7 +182,7 @@ func TestGenerateCatalogZone(t *testing.T) {
// a.example.com should have a group TXT // a.example.com should have a group TXT
foundGroup := false foundGroup := false
for _, line := range lines { for _, line := range lines {
if containsStr(line, "group.") && containsStr(line, "\"mygroup\"") { if strings.Contains(line, "group.") && strings.Contains(line, "\"mygroup\"") {
foundGroup = true foundGroup = true
} }
} }
@@ -211,7 +213,7 @@ func TestGenerateCatalogZoneCOO(t *testing.T) {
foundCOO := false foundCOO := false
for _, line := range splitLines(content) { for _, line := range splitLines(content) {
if containsStr(line, "coo.") && containsStr(line, "old.example.com.") { if strings.Contains(line, "coo.") && strings.Contains(line, "old.example.com.") {
foundCOO = true foundCOO = true
} }
} }
@@ -220,69 +222,51 @@ func TestGenerateCatalogZoneCOO(t *testing.T) {
} }
} }
func TestReadExistingSerial(t *testing.T) { func TestReadExisting(t *testing.T) {
t.Run("file does not exist", func(t *testing.T) { t.Run("file does not exist", func(t *testing.T) {
serial, err := readExistingSerial("/nonexistent/path.zone") content, serial, err := readExisting("/nonexistent/path.zone")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if serial != 0 { if serial != 0 {
t.Errorf("serial = %d, want 0", serial) t.Errorf("serial = %d, want 0", serial)
} }
if content != "" {
t.Errorf("content = %q, want empty", content)
}
}) })
t.Run("valid zone file", func(t *testing.T) { t.Run("valid zone file", func(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
path := dir + "/test.zone" zoneContent := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n"
content := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n" writeTestFile(t, dir, "test.zone", zoneContent)
writeTestFile(t, dir, "test.zone", content)
serial, err := readExistingSerial(path) content, serial, err := readExisting(filepath.Join(dir, "test.zone"))
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if serial != 2026030205 { if serial != 2026030205 {
t.Errorf("serial = %d, want 2026030205", serial) t.Errorf("serial = %d, want 2026030205", serial)
} }
if content != zoneContent {
t.Errorf("content mismatch: got %q", content)
}
}) })
t.Run("zone file with no SOA", func(t *testing.T) { t.Run("zone file with no SOA", func(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
path := dir + "/test.zone" zoneContent := "example.com. 0 IN NS invalid.\n"
content := "example.com. 0 IN NS invalid.\n" writeTestFile(t, dir, "test.zone", zoneContent)
writeTestFile(t, dir, "test.zone", content)
serial, err := readExistingSerial(path) content, serial, err := readExisting(filepath.Join(dir, "test.zone"))
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if serial != 0 { if serial != 0 {
t.Errorf("serial = %d, want 0 (no SOA found)", serial) t.Errorf("serial = %d, want 0 (no SOA found)", serial)
} }
}) if content != zoneContent {
} t.Errorf("content mismatch: got %q", content)
func TestReadExistingContent(t *testing.T) {
t.Run("file does not exist", func(t *testing.T) {
content, err := readExistingContent("/nonexistent/path.zone")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "" {
t.Errorf("content = %q, want empty", content)
}
})
t.Run("file exists", func(t *testing.T) {
dir := t.TempDir()
writeTestFile(t, dir, "test.zone", "hello\n")
content, err := readExistingContent(dir + "/test.zone")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "hello\n" {
t.Errorf("content = %q, want %q", content, "hello\n")
} }
}) })
} }

View File

@@ -24,7 +24,3 @@ func assertContains(t *testing.T, s, substr string) {
t.Errorf("expected %q to contain %q", s, substr) t.Errorf("expected %q to contain %q", s, substr)
} }
} }
func containsStr(s, substr string) bool {
return strings.Contains(s, substr)
}

View File

@@ -54,7 +54,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
// Split on whitespace first, then handle comma separation within tokens // Split on whitespace first, then handle comma separation within tokens
tokens := tokenize(line) tokens := tokenize(line)
if len(tokens) < 2 { if len(tokens) < 2 {
return ZoneEntry{}, fmt.Errorf("error: %s:%d: expected zone name followed by at least one catalog name", file, lineNum) return ZoneEntry{}, fmt.Errorf("%s:%d: expected zone name followed by at least one catalog name", file, lineNum)
} }
entry := ZoneEntry{ entry := ZoneEntry{
@@ -69,16 +69,16 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
switch strings.ToLower(key) { switch strings.ToLower(key) {
case "group": case "group":
if value == "" { if value == "" {
return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty group value", file, lineNum) return ZoneEntry{}, fmt.Errorf("%s:%d: empty group value", file, lineNum)
} }
entry.Group = value entry.Group = value
case "coo": case "coo":
if value == "" { if value == "" {
return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty coo value", file, lineNum) return ZoneEntry{}, fmt.Errorf("%s:%d: empty coo value", file, lineNum)
} }
entry.COO = normalizeFQDN(value) entry.COO = normalizeFQDN(value)
default: default:
return ZoneEntry{}, fmt.Errorf("error: %s:%d: unknown property %q", file, lineNum, key) return ZoneEntry{}, fmt.Errorf("%s:%d: unknown property %q", file, lineNum, key)
} }
} else { } else {
// Bare name = catalog assignment // Bare name = catalog assignment
@@ -87,7 +87,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
} }
if len(entry.Catalogs) == 0 { if len(entry.Catalogs) == 0 {
return ZoneEntry{}, fmt.Errorf("error: %s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone) return ZoneEntry{}, fmt.Errorf("%s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone)
} }
return entry, nil return entry, nil
@@ -95,10 +95,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
// tokenize splits a line on whitespace and commas, stripping empty tokens. // tokenize splits a line on whitespace and commas, stripping empty tokens.
func tokenize(line string) []string { func tokenize(line string) []string {
// Replace commas with spaces, then split on whitespace return strings.Fields(strings.ReplaceAll(line, ",", " "))
line = strings.ReplaceAll(line, ",", " ")
fields := strings.Fields(line)
return fields
} }
func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, error) { func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, error) {
@@ -110,14 +107,14 @@ func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, erro
for _, entry := range entries { for _, entry := range entries {
for _, catName := range entry.Catalogs { for _, catName := range entry.Catalogs {
if _, ok := cfg.Catalogs[catName]; !ok { if _, ok := cfg.Catalogs[catName]; !ok {
return nil, fmt.Errorf("error: %s:%d: unknown catalog %q", entry.File, entry.Line, catName) return nil, fmt.Errorf("%s:%d: unknown catalog %q", entry.File, entry.Line, catName)
} }
if seen[catName] == nil { if seen[catName] == nil {
seen[catName] = make(map[string]int) seen[catName] = make(map[string]int)
} }
if prevLine, dup := seen[catName][entry.Zone]; dup { if prevLine, dup := seen[catName][entry.Zone]; dup {
return nil, fmt.Errorf("error: %s:%d: zone %s already assigned to catalog %q at line %d", return nil, fmt.Errorf("%s:%d: zone %s already assigned to catalog %q at line %d",
entry.File, entry.Line, entry.Zone, catName, prevLine) entry.File, entry.Line, entry.Zone, catName, prevLine)
} }
seen[catName][entry.Zone] = entry.Line seen[catName][entry.Zone] = entry.Line

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
) )
@@ -206,9 +205,7 @@ zone.example.org catalog1, catalog2
test.example.net catalog1, group=internal test.example.net catalog1, group=internal
` `
if err := os.WriteFile(inputPath, []byte(content), 0o644); err != nil { writeTestFile(t, dir, "zones.txt", content)
t.Fatal(err)
}
members, err := parseInput(inputPath, cfg) members, err := parseInput(inputPath, cfg)
if err != nil { if err != nil {
@@ -240,9 +237,7 @@ func TestParseInputErrors(t *testing.T) {
t.Run("invalid line in input", func(t *testing.T) { t.Run("invalid line in input", func(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
path := filepath.Join(dir, "zones.txt") path := filepath.Join(dir, "zones.txt")
if err := os.WriteFile(path, []byte("zone-with-no-catalog\n"), 0o644); err != nil { writeTestFile(t, dir, "zones.txt", "zone-with-no-catalog\n")
t.Fatal(err)
}
_, err := parseInput(path, cfg) _, err := parseInput(path, cfg)
if err == nil { if err == nil {
t.Fatal("expected error for invalid line") t.Fatal("expected error for invalid line")

10
main.go
View File

@@ -41,7 +41,7 @@ func main() {
members, err := parseInput(inputFile, cfg) members, err := parseInput(inputFile, cfg)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "error: %s\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -58,15 +58,15 @@ func main() {
for _, catName := range catNames { for _, catName := range catNames {
changed, err := processCatalog(catName, cfg, members[catName], *outputDir, now) changed, err := processCatalog(catName, cfg, members[catName], *outputDir, now)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "error: %s\n", err)
hasErrors = true hasErrors = true
continue continue
} }
catZone := cfg.Catalogs[catName].Zone zoneFile := cfg.Catalogs[catName].Zone + "zone"
if changed { if changed {
fmt.Fprintf(os.Stderr, "%s%s: updated\n", catZone, "zone") fmt.Fprintf(os.Stderr, "%s: updated\n", zoneFile)
} else { } else {
fmt.Fprintf(os.Stderr, "%s%s: unchanged\n", catZone, "zone") fmt.Fprintf(os.Stderr, "%s: unchanged\n", zoneFile)
} }
} }

View File

@@ -6,7 +6,6 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"time"
) )
func TestIntegrationEndToEnd(t *testing.T) { func TestIntegrationEndToEnd(t *testing.T) {
@@ -213,7 +212,7 @@ soa:
// Find PTR lines and verify order // Find PTR lines and verify order
var ptrZones []string var ptrZones []string
for _, line := range lines { for _, line := range lines {
if containsStr(line, "\tPTR\t") && !containsStr(line, "coo.") { if strings.Contains(line, "\tPTR\t") && !strings.Contains(line, "coo.") {
// Extract the PTR target // Extract the PTR target
parts := strings.Split(line, "\t") parts := strings.Split(line, "\t")
ptrZones = append(ptrZones, parts[len(parts)-1]) ptrZones = append(ptrZones, parts[len(parts)-1])
@@ -313,9 +312,7 @@ soa:
if err == nil { if err == nil {
t.Error("expected error for unknown catalog") t.Error("expected error for unknown catalog")
} }
if !containsStr(string(out), "unknown catalog") { assertContains(t, string(out), "unknown catalog")
t.Errorf("expected 'unknown catalog' in error output, got: %s", out)
}
}) })
} }
@@ -405,7 +402,7 @@ soa:
assertContains(t, content2, "2026011601") assertContains(t, content2, "2026011601")
// Simulate not using the tool for a while, running on a much later date // Simulate not using the tool for a while, running on a much later date
day3 := time.Date(2026, 6, 15, 0, 0, 0, 0, time.UTC) day3 := fixedTime(2026, 6, 15)
members3 := []ZoneEntry{ members3 := []ZoneEntry{
{Zone: "a.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 1}, {Zone: "a.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 1},
{Zone: "b.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 2}, {Zone: "b.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 2},