Clean up code reuse, consistency, and efficiency issues
Merge readExistingSerial and readExistingContent into a single readExisting function to eliminate duplicate file I/O. Extract dateBase helper to deduplicate serial formula between defaultSerial and bumpSerial. Cache hash results during collision check to avoid recomputing per member. Normalize error prefixes (remove "error:" from fmt.Errorf, add uniformly at print sites). Use filepath.Join instead of manual "/" concatenation. Replace trivial containsStr wrapper with strings.Contains. Simplify tokenize to a single return. Use writeTestFile and fixedTime helpers consistently in tests.
This commit is contained in:
81
catalog.go
81
catalog.go
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -29,12 +30,13 @@ func hashZoneName(zone string) string {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(zone))
|
||||
sum := h.Sum32()
|
||||
b := make([]byte, 4)
|
||||
b[0] = byte(sum >> 24)
|
||||
b[1] = byte(sum >> 16)
|
||||
b[2] = byte(sum >> 8)
|
||||
b[3] = byte(sum)
|
||||
return b32.EncodeToString(b)
|
||||
b := [4]byte{
|
||||
byte(sum >> 24),
|
||||
byte(sum >> 16),
|
||||
byte(sum >> 8),
|
||||
byte(sum),
|
||||
}
|
||||
return b32.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// generateCatalogZone builds the zone file content for a single catalog.
|
||||
@@ -76,20 +78,22 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria
|
||||
return sorted[i].Zone < sorted[j].Zone
|
||||
})
|
||||
|
||||
// Check for hash collisions
|
||||
hashToZone := make(map[string]string)
|
||||
// Check for hash collisions and cache results
|
||||
hashToZone := make(map[string]string, len(sorted))
|
||||
zoneHash := make(map[string]string, len(sorted))
|
||||
for _, entry := range sorted {
|
||||
h := hashZoneName(entry.Zone)
|
||||
if existing, ok := hashToZone[h]; ok && existing != entry.Zone {
|
||||
return "", fmt.Errorf("error: %s:%d: hash collision between %s and %s in catalog %q",
|
||||
return "", fmt.Errorf("%s:%d: hash collision between %s and %s in catalog %q",
|
||||
entry.File, entry.Line, existing, entry.Zone, catName)
|
||||
}
|
||||
hashToZone[h] = entry.Zone
|
||||
zoneHash[entry.Zone] = h
|
||||
}
|
||||
|
||||
// Member records
|
||||
for _, entry := range sorted {
|
||||
h := hashZoneName(entry.Zone)
|
||||
h := zoneHash[entry.Zone]
|
||||
|
||||
// PTR record
|
||||
ptrOwner := fmt.Sprintf("%s.zones.%s", h, origin)
|
||||
@@ -126,58 +130,48 @@ func generateCatalogZone(catName string, cfg *Config, members []ZoneEntry, seria
|
||||
return strings.Join(records, "\n") + "\n", nil
|
||||
}
|
||||
|
||||
// readExistingSerial reads an existing zone file and extracts the SOA serial.
|
||||
// Returns 0, nil if the file doesn't exist.
|
||||
func readExistingSerial(path string) (uint32, error) {
|
||||
f, err := os.Open(path)
|
||||
// readExisting reads an existing zone file and returns its content and SOA serial.
|
||||
// Returns ("", 0, nil) if the file doesn't exist.
|
||||
func readExisting(path string) (string, uint32, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
return "", 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("reading existing zone %s: %w", path, err)
|
||||
return "", 0, fmt.Errorf("reading existing zone %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
content := string(data)
|
||||
|
||||
zp := dns.NewZoneParser(f, "", path)
|
||||
zp := dns.NewZoneParser(strings.NewReader(content), "", path)
|
||||
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
|
||||
if soa, ok := rr.(*dns.SOA); ok {
|
||||
return soa.Serial, nil
|
||||
return content, soa.Serial, nil
|
||||
}
|
||||
}
|
||||
if err := zp.Err(); err != nil {
|
||||
return 0, fmt.Errorf("parsing existing zone %s: %w", path, err)
|
||||
return "", 0, fmt.Errorf("parsing existing zone %s: %w", path, err)
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
return content, 0, nil
|
||||
}
|
||||
|
||||
// readExistingContent reads the full content of an existing zone file.
|
||||
// Returns empty string if file doesn't exist.
|
||||
func readExistingContent(path string) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
// dateBase returns the YYYYMMDD00 serial base for the given time.
|
||||
func dateBase(now time.Time) uint32 {
|
||||
return uint32(now.Year())*1000000 +
|
||||
uint32(now.Month())*10000 +
|
||||
uint32(now.Day())*100
|
||||
}
|
||||
|
||||
// defaultSerial returns a serial for today with sequence 01: YYYYMMDD01.
|
||||
func defaultSerial(now time.Time) uint32 {
|
||||
return uint32(now.Year())*1000000 +
|
||||
uint32(now.Month())*10000 +
|
||||
uint32(now.Day())*100 + 1
|
||||
return dateBase(now) + 1
|
||||
}
|
||||
|
||||
// bumpSerial increments a serial. If same date, bumps the sequence number.
|
||||
// If different date, starts at YYYYMMDD01.
|
||||
// Returns error if sequence reaches 99 and needs another bump.
|
||||
func bumpSerial(old uint32, now time.Time) (uint32, error) {
|
||||
todayBase := uint32(now.Year())*1000000 +
|
||||
uint32(now.Month())*10000 +
|
||||
uint32(now.Day())*100
|
||||
todayBase := dateBase(now)
|
||||
|
||||
if old >= todayBase && old < todayBase+100 {
|
||||
// Same date, bump sequence
|
||||
@@ -196,10 +190,10 @@ func bumpSerial(old uint32, now time.Time) (uint32, error) {
|
||||
// Returns true if the file was written (changed), false if unchanged.
|
||||
func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir string, now time.Time) (bool, error) {
|
||||
catCfg := cfg.Catalogs[catName]
|
||||
outputPath := outputDir + "/" + catCfg.Zone + "zone"
|
||||
outputPath := filepath.Join(outputDir, catCfg.Zone+"zone")
|
||||
|
||||
// Read existing serial
|
||||
oldSerial, err := readExistingSerial(outputPath)
|
||||
// Read existing file (content + serial in one pass)
|
||||
existing, oldSerial, err := readExisting(outputPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -217,11 +211,6 @@ func processCatalog(catName string, cfg *Config, members []ZoneEntry, outputDir
|
||||
}
|
||||
|
||||
// Compare with existing file
|
||||
existing, err := readExistingContent(outputPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if content == existing {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -159,12 +161,12 @@ func TestGenerateCatalogZone(t *testing.T) {
|
||||
aIdx := -1
|
||||
bIdx := -1
|
||||
for i, line := range lines {
|
||||
if containsStr(line, "a.example.com.") {
|
||||
if strings.Contains(line, "a.example.com.") {
|
||||
if aIdx == -1 {
|
||||
aIdx = i
|
||||
}
|
||||
}
|
||||
if containsStr(line, "b.example.com.") {
|
||||
if strings.Contains(line, "b.example.com.") {
|
||||
if bIdx == -1 {
|
||||
bIdx = i
|
||||
}
|
||||
@@ -180,7 +182,7 @@ func TestGenerateCatalogZone(t *testing.T) {
|
||||
// a.example.com should have a group TXT
|
||||
foundGroup := false
|
||||
for _, line := range lines {
|
||||
if containsStr(line, "group.") && containsStr(line, "\"mygroup\"") {
|
||||
if strings.Contains(line, "group.") && strings.Contains(line, "\"mygroup\"") {
|
||||
foundGroup = true
|
||||
}
|
||||
}
|
||||
@@ -211,7 +213,7 @@ func TestGenerateCatalogZoneCOO(t *testing.T) {
|
||||
|
||||
foundCOO := false
|
||||
for _, line := range splitLines(content) {
|
||||
if containsStr(line, "coo.") && containsStr(line, "old.example.com.") {
|
||||
if strings.Contains(line, "coo.") && strings.Contains(line, "old.example.com.") {
|
||||
foundCOO = true
|
||||
}
|
||||
}
|
||||
@@ -220,69 +222,51 @@ func TestGenerateCatalogZoneCOO(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadExistingSerial(t *testing.T) {
|
||||
func TestReadExisting(t *testing.T) {
|
||||
t.Run("file does not exist", func(t *testing.T) {
|
||||
serial, err := readExistingSerial("/nonexistent/path.zone")
|
||||
content, serial, err := readExisting("/nonexistent/path.zone")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if serial != 0 {
|
||||
t.Errorf("serial = %d, want 0", serial)
|
||||
}
|
||||
if content != "" {
|
||||
t.Errorf("content = %q, want empty", content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid zone file", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := dir + "/test.zone"
|
||||
content := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n"
|
||||
writeTestFile(t, dir, "test.zone", content)
|
||||
zoneContent := "example.com. 0 IN SOA ns1.example.com. hostmaster.example.com. 2026030205 900 600 2147483646 0\nexample.com. 0 IN NS invalid.\n"
|
||||
writeTestFile(t, dir, "test.zone", zoneContent)
|
||||
|
||||
serial, err := readExistingSerial(path)
|
||||
content, serial, err := readExisting(filepath.Join(dir, "test.zone"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if serial != 2026030205 {
|
||||
t.Errorf("serial = %d, want 2026030205", serial)
|
||||
}
|
||||
if content != zoneContent {
|
||||
t.Errorf("content mismatch: got %q", content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zone file with no SOA", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := dir + "/test.zone"
|
||||
content := "example.com. 0 IN NS invalid.\n"
|
||||
writeTestFile(t, dir, "test.zone", content)
|
||||
zoneContent := "example.com. 0 IN NS invalid.\n"
|
||||
writeTestFile(t, dir, "test.zone", zoneContent)
|
||||
|
||||
serial, err := readExistingSerial(path)
|
||||
content, serial, err := readExisting(filepath.Join(dir, "test.zone"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if serial != 0 {
|
||||
t.Errorf("serial = %d, want 0 (no SOA found)", serial)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadExistingContent(t *testing.T) {
|
||||
t.Run("file does not exist", func(t *testing.T) {
|
||||
content, err := readExistingContent("/nonexistent/path.zone")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "" {
|
||||
t.Errorf("content = %q, want empty", content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("file exists", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeTestFile(t, dir, "test.zone", "hello\n")
|
||||
|
||||
content, err := readExistingContent(dir + "/test.zone")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "hello\n" {
|
||||
t.Errorf("content = %q, want %q", content, "hello\n")
|
||||
if content != zoneContent {
|
||||
t.Errorf("content mismatch: got %q", content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,7 +24,3 @@ func assertContains(t *testing.T, s, substr string) {
|
||||
t.Errorf("expected %q to contain %q", s, substr)
|
||||
}
|
||||
}
|
||||
|
||||
func containsStr(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
|
||||
19
input.go
19
input.go
@@ -54,7 +54,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
|
||||
// Split on whitespace first, then handle comma separation within tokens
|
||||
tokens := tokenize(line)
|
||||
if len(tokens) < 2 {
|
||||
return ZoneEntry{}, fmt.Errorf("error: %s:%d: expected zone name followed by at least one catalog name", file, lineNum)
|
||||
return ZoneEntry{}, fmt.Errorf("%s:%d: expected zone name followed by at least one catalog name", file, lineNum)
|
||||
}
|
||||
|
||||
entry := ZoneEntry{
|
||||
@@ -69,16 +69,16 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
|
||||
switch strings.ToLower(key) {
|
||||
case "group":
|
||||
if value == "" {
|
||||
return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty group value", file, lineNum)
|
||||
return ZoneEntry{}, fmt.Errorf("%s:%d: empty group value", file, lineNum)
|
||||
}
|
||||
entry.Group = value
|
||||
case "coo":
|
||||
if value == "" {
|
||||
return ZoneEntry{}, fmt.Errorf("error: %s:%d: empty coo value", file, lineNum)
|
||||
return ZoneEntry{}, fmt.Errorf("%s:%d: empty coo value", file, lineNum)
|
||||
}
|
||||
entry.COO = normalizeFQDN(value)
|
||||
default:
|
||||
return ZoneEntry{}, fmt.Errorf("error: %s:%d: unknown property %q", file, lineNum, key)
|
||||
return ZoneEntry{}, fmt.Errorf("%s:%d: unknown property %q", file, lineNum, key)
|
||||
}
|
||||
} else {
|
||||
// Bare name = catalog assignment
|
||||
@@ -87,7 +87,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
|
||||
}
|
||||
|
||||
if len(entry.Catalogs) == 0 {
|
||||
return ZoneEntry{}, fmt.Errorf("error: %s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone)
|
||||
return ZoneEntry{}, fmt.Errorf("%s:%d: no catalog assignment for zone %s", file, lineNum, entry.Zone)
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
@@ -95,10 +95,7 @@ func parseLine(line, file string, lineNum int) (ZoneEntry, error) {
|
||||
|
||||
// tokenize splits a line on whitespace and commas, stripping empty tokens.
|
||||
func tokenize(line string) []string {
|
||||
// Replace commas with spaces, then split on whitespace
|
||||
line = strings.ReplaceAll(line, ",", " ")
|
||||
fields := strings.Fields(line)
|
||||
return fields
|
||||
return strings.Fields(strings.ReplaceAll(line, ",", " "))
|
||||
}
|
||||
|
||||
func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, error) {
|
||||
@@ -110,14 +107,14 @@ func buildCatalogMembers(entries []ZoneEntry, cfg *Config) (CatalogMembers, erro
|
||||
for _, entry := range entries {
|
||||
for _, catName := range entry.Catalogs {
|
||||
if _, ok := cfg.Catalogs[catName]; !ok {
|
||||
return nil, fmt.Errorf("error: %s:%d: unknown catalog %q", entry.File, entry.Line, catName)
|
||||
return nil, fmt.Errorf("%s:%d: unknown catalog %q", entry.File, entry.Line, catName)
|
||||
}
|
||||
|
||||
if seen[catName] == nil {
|
||||
seen[catName] = make(map[string]int)
|
||||
}
|
||||
if prevLine, dup := seen[catName][entry.Zone]; dup {
|
||||
return nil, fmt.Errorf("error: %s:%d: zone %s already assigned to catalog %q at line %d",
|
||||
return nil, fmt.Errorf("%s:%d: zone %s already assigned to catalog %q at line %d",
|
||||
entry.File, entry.Line, entry.Zone, catName, prevLine)
|
||||
}
|
||||
seen[catName][entry.Zone] = entry.Line
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
@@ -206,9 +205,7 @@ zone.example.org catalog1, catalog2
|
||||
|
||||
test.example.net catalog1, group=internal
|
||||
`
|
||||
if err := os.WriteFile(inputPath, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writeTestFile(t, dir, "zones.txt", content)
|
||||
|
||||
members, err := parseInput(inputPath, cfg)
|
||||
if err != nil {
|
||||
@@ -240,9 +237,7 @@ func TestParseInputErrors(t *testing.T) {
|
||||
t.Run("invalid line in input", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "zones.txt")
|
||||
if err := os.WriteFile(path, []byte("zone-with-no-catalog\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writeTestFile(t, dir, "zones.txt", "zone-with-no-catalog\n")
|
||||
_, err := parseInput(path, cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid line")
|
||||
|
||||
10
main.go
10
main.go
@@ -41,7 +41,7 @@ func main() {
|
||||
|
||||
members, err := parseInput(inputFile, cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
fmt.Fprintf(os.Stderr, "error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -58,15 +58,15 @@ func main() {
|
||||
for _, catName := range catNames {
|
||||
changed, err := processCatalog(catName, cfg, members[catName], *outputDir, now)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
fmt.Fprintf(os.Stderr, "error: %s\n", err)
|
||||
hasErrors = true
|
||||
continue
|
||||
}
|
||||
catZone := cfg.Catalogs[catName].Zone
|
||||
zoneFile := cfg.Catalogs[catName].Zone + "zone"
|
||||
if changed {
|
||||
fmt.Fprintf(os.Stderr, "%s%s: updated\n", catZone, "zone")
|
||||
fmt.Fprintf(os.Stderr, "%s: updated\n", zoneFile)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s%s: unchanged\n", catZone, "zone")
|
||||
fmt.Fprintf(os.Stderr, "%s: unchanged\n", zoneFile)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIntegrationEndToEnd(t *testing.T) {
|
||||
@@ -213,7 +212,7 @@ soa:
|
||||
// Find PTR lines and verify order
|
||||
var ptrZones []string
|
||||
for _, line := range lines {
|
||||
if containsStr(line, "\tPTR\t") && !containsStr(line, "coo.") {
|
||||
if strings.Contains(line, "\tPTR\t") && !strings.Contains(line, "coo.") {
|
||||
// Extract the PTR target
|
||||
parts := strings.Split(line, "\t")
|
||||
ptrZones = append(ptrZones, parts[len(parts)-1])
|
||||
@@ -313,9 +312,7 @@ soa:
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown catalog")
|
||||
}
|
||||
if !containsStr(string(out), "unknown catalog") {
|
||||
t.Errorf("expected 'unknown catalog' in error output, got: %s", out)
|
||||
}
|
||||
assertContains(t, string(out), "unknown catalog")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -405,7 +402,7 @@ soa:
|
||||
assertContains(t, content2, "2026011601")
|
||||
|
||||
// Simulate not using the tool for a while, running on a much later date
|
||||
day3 := time.Date(2026, 6, 15, 0, 0, 0, 0, time.UTC)
|
||||
day3 := fixedTime(2026, 6, 15)
|
||||
members3 := []ZoneEntry{
|
||||
{Zone: "a.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 1},
|
||||
{Zone: "b.example.com.", Catalogs: []string{"cat1"}, File: "test", Line: 2},
|
||||
|
||||
Reference in New Issue
Block a user