diff --git a/main.go b/main.go index 23a3531..391d8d5 100644 --- a/main.go +++ b/main.go @@ -20,22 +20,30 @@ var conf = flag.String("config", "/etc/dnssec-checks", "Configuration file") var resolver = flag.String("resolver", "8.8.8.8:53", "Resolver to use") var timeout = flag.Duration("timeout", 10*time.Second, "Timeout for network operations") -var dnsClient *dns.Client - type Records struct { Zone string Record string Type string } +type Logger interface { + Print(v ...interface{}) + Printf(format string, v ...interface{}) +} + type Exporter struct { Records []Records records *prometheus.GaugeVec valid *prometheus.GaugeVec + + resolver string + dnsClient *dns.Client + + logger Logger } -func NewDNSSECExporter() *Exporter { +func NewDNSSECExporter(dnsClient *dns.Client, resolver string, logger Logger) *Exporter { return &Exporter{ records: prometheus.NewGaugeVec( prometheus.GaugeOpts{ @@ -63,6 +71,9 @@ func NewDNSSECExporter() *Exporter { "type", }, ), + dnsClient: dnsClient, + resolver: resolver, + logger: logger, } } @@ -82,7 +93,17 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) { rec := rec go func() { - e.collectRecord(rec.Zone, rec.Record, rec.Type) + + valid, exp := e.collectRecord(rec.Zone, rec.Record, rec.Type) + + e.valid.WithLabelValues( + rec.Zone, rec.Record, rec.Type, + ).Set(map[bool]float64{true: 1}[valid]) + + e.records.WithLabelValues( + rec.Zone, rec.Record, rec.Type, + ).Set(float64(time.Until(exp)/time.Hour) / 24) + wg.Done() }() @@ -95,16 +116,16 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) { } -func (e *Exporter) collectRecord(zone, record, recordType string) { +func (e *Exporter) collectRecord(zone, record, recordType string) (valid bool, exp time.Time) { // Start by finding the DNSKEY msg := &dns.Msg{} msg.SetQuestion(fmt.Sprintf("%s.", zone), dns.TypeDNSKEY) - response, _, err := dnsClient.Exchange(msg, *resolver) + response, _, err := e.dnsClient.Exchange(msg, e.resolver) if err != nil { - log.Printf("while looking up DNSKEY for %v: %v", zone, err) + e.logger.Printf("while looking up DNSKEY for %v: %v", zone, err) return } @@ -118,7 +139,7 @@ func (e *Exporter) collectRecord(zone, record, recordType string) { } if len(keys) == 0 { - log.Printf("didn't find DNSKEY for %v", zone) + e.logger.Printf("didn't find DNSKEY for %v", zone) } // Now lookup the signature @@ -126,9 +147,9 @@ func (e *Exporter) collectRecord(zone, record, recordType string) { msg = &dns.Msg{} msg.SetQuestion(hostname(zone, record), dns.TypeRRSIG) - response, _, err = dnsClient.Exchange(msg, *resolver) + response, _, err = e.dnsClient.Exchange(msg, e.resolver) if err != nil { - log.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err) + e.logger.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err) return } @@ -136,35 +157,28 @@ func (e *Exporter) collectRecord(zone, record, recordType string) { var key *dns.DNSKEY for _, rr := range response.Answer { - if rrsig, ok := rr.(*dns.RRSIG); ok { + if rrsig, ok := rr.(*dns.RRSIG); ok && + rrsig.TypeCovered == dns.StringToType[recordType] && + keys[rrsig.KeyTag] != nil { - if rrsig.TypeCovered == dns.StringToType[recordType] && - keys[rrsig.KeyTag] != nil { - - sig = rrsig - key = keys[rrsig.KeyTag] - break - - } + sig = rrsig + key = keys[rrsig.KeyTag] + break } } if sig == nil { - log.Printf("didn't find RRSIG for %v covering type %v", hostname(zone, record), recordType) + e.logger.Printf("didn't find RRSIG for %v covering type %v matching a tag of a DNSKEY", hostname(zone, record), recordType) return } - exp := time.Unix(int64(sig.Expiration), 0) + exp = time.Unix(int64(sig.Expiration), 0) if exp.IsZero() { - log.Print("zero exp") + e.logger.Printf("zero exp for RRSIG for %v covering type %v", hostname(zone, record), recordType) return } - e.records.WithLabelValues( - zone, record, recordType, - ).Set(float64(time.Until(exp)/time.Hour) / 24) - // Finally, lookup the records to validate if key == nil { @@ -175,19 +189,20 @@ func (e *Exporter) collectRecord(zone, record, recordType string) { msg = &dns.Msg{} msg.SetQuestion(hostname(zone, record), dns.StringToType[recordType]) - response, _, err = dnsClient.Exchange(msg, *resolver) + response, _, err = e.dnsClient.Exchange(msg, e.resolver) if err != nil { - log.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err) + e.logger.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err) return } if err := sig.Verify(key, response.Answer); err == nil { - e.valid.WithLabelValues(zone, record, recordType).Set(1) + valid = true } else { - log.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err) - e.valid.WithLabelValues(zone, record, recordType).Set(0) + e.logger.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err) } + return + } func hostname(zone, record string) string { @@ -204,17 +219,17 @@ func main() { flag.Parse() - dnsClient = &dns.Client{ - Net: "tcp", - Timeout: *timeout, - } - f, err := os.Open(*conf) if err != nil { log.Fatalf("couldn't open configuration file: %v", err) } - exporter := NewDNSSECExporter() + logger := log.New(os.Stderr, "", log.LstdFlags) + + exporter := NewDNSSECExporter(&dns.Client{ + Net: "tcp", + Timeout: *timeout, + }, *resolver, logger) if err := toml.NewDecoder(f).Decode(exporter); err != nil { log.Fatalf("couldn't parse configuration file: %v", err) diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..919a9b3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,235 @@ +package main + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io/ioutil" + "log" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +type opts struct { + signed time.Time + expires time.Time + privkey crypto.PrivateKey +} + +func nullLogger() *log.Logger { + return log.New(ioutil.Discard, "", log.LstdFlags) +} + +func runServer(t *testing.T, opts opts) (string, func()) { + + if opts.signed.IsZero() { + opts.signed = time.Now().Add(-time.Hour) + } + + if opts.expires.IsZero() { + opts.expires = time.Now().Add(14 * 24 * time.Hour) + } + + dnskey := &dns.DNSKEY{ + Algorithm: dns.ECDSAP256SHA256, + Flags: dns.ZONE, + Protocol: 3, + } + + privkey, err := dnskey.Generate(256) + if err != nil { + t.Fatalf("couldn't generate private key: %v", err) + } + + if opts.privkey != nil { + privkey = opts.privkey + } + + h := dns.NewServeMux() + h.HandleFunc("example.org.", func(rw dns.ResponseWriter, msg *dns.Msg) { + + q := msg.Question[0] + + soa := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 3600, + }, + Ns: "ns1.example.org.", + Mbox: "test.example.org.", + Serial: 1, + Refresh: 14400, + Retry: 3600, + Expire: 7200, + Minttl: 60, + } + + switch q.Qtype { + + case dns.TypeDNSKEY: + + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeDNSKEY, + Class: dns.ClassINET, + Ttl: 3600, + } + + answer := &dns.DNSKEY{ + Hdr: rrHeader, + Algorithm: dnskey.Algorithm, + Flags: dnskey.Flags, + Protocol: dnskey.Protocol, + PublicKey: dnskey.PublicKey, + } + + msg.Answer = append(msg.Answer, answer) + + case dns.TypeRRSIG: + + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeRRSIG, + Class: dns.ClassINET, + Ttl: 3600, + } + + answer := &dns.RRSIG{ + Hdr: rrHeader, + TypeCovered: dns.TypeSOA, + Algorithm: dnskey.Algorithm, + Labels: uint8(dns.CountLabel(q.Name)), + OrigTtl: 3600, + Expiration: uint32(opts.expires.Unix()), + Inception: uint32(opts.signed.Unix()), + KeyTag: dnskey.KeyTag(), + SignerName: q.Name, + } + + if err := answer.Sign(privkey.(*ecdsa.PrivateKey), []dns.RR{soa}); err != nil { + t.Fatalf("couldn't sign SOA record: %v", err) + } + + msg.Answer = append(msg.Answer, answer) + + case dns.TypeSOA: + + msg.Answer = append(msg.Answer, soa) + + } + + rw.WriteMsg(msg) + + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + + server := &dns.Server{ + Listener: ln, + Handler: h, + } + + go func() { + server.ActivateAndServe() + }() + + done := make(chan bool) + + go func() { + <-done + server.Shutdown() + ln.Close() + }() + + return ln.Addr().String(), func() { + done <- true + } + +} + +func TestCollectionOK(t *testing.T) { + + addr, cancel := runServer(t, opts{}) + defer cancel() + + e := NewDNSSECExporter(&dns.Client{ + Net: "tcp", + Timeout: 1 * time.Second, + }, addr, nullLogger()) + + valid, exp := e.collectRecord("example.org", "@", "SOA") + + if !valid { + t.Fatal("expected record to be valid") + } + + if exp.Before(time.Now()) { + t.Fatalf("expected expiration to be in the future, was: %v", exp) + } + +} + +func TestCollectionExpired(t *testing.T) { + + addr, cancel := runServer(t, opts{ + signed: time.Now().Add(14 * 24 * time.Hour), + expires: time.Now().Add(-time.Hour), + }) + + defer cancel() + + e := NewDNSSECExporter(&dns.Client{ + Net: "tcp", + Timeout: 1 * time.Second, + }, addr, nullLogger()) + + valid, exp := e.collectRecord("example.org", "@", "SOA") + + if !valid { + t.Fatal("expected record to be valid") + } + + if exp.After(time.Now()) { + t.Fatalf("expected expiration to be in the past, was: %v", exp) + } + +} + +func TestCollectionInvalid(t *testing.T) { + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("couldn't generate fake private key: %v", err) + } + + addr, cancel := runServer(t, opts{ + privkey: priv, + }) + + defer cancel() + + e := NewDNSSECExporter(&dns.Client{ + Net: "tcp", + Timeout: 1 * time.Second, + }, addr, nullLogger()) + + valid, exp := e.collectRecord("example.org", "@", "SOA") + + if valid { + t.Fatal("expected record to be invalid") + } + + if exp.Before(time.Now()) { + t.Fatalf("expected expiration to be in the future, was: %v", exp) + } + +}