Added testing, restructured error handling.

This commit is contained in:
Christian Joergensen 2018-10-05 11:47:11 +02:00
parent 6191b1368e
commit 4c5495f2a7
2 changed files with 287 additions and 37 deletions

89
main.go
View File

@ -20,22 +20,30 @@ var conf = flag.String("config", "/etc/dnssec-checks", "Configuration file")
var resolver = flag.String("resolver", "8.8.8.8:53", "Resolver to use")
var timeout = flag.Duration("timeout", 10*time.Second, "Timeout for network operations")
var dnsClient *dns.Client
type Records struct {
Zone string
Record string
Type string
}
type Logger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
}
type Exporter struct {
Records []Records
records *prometheus.GaugeVec
valid *prometheus.GaugeVec
resolver string
dnsClient *dns.Client
logger Logger
}
func NewDNSSECExporter() *Exporter {
func NewDNSSECExporter(dnsClient *dns.Client, resolver string, logger Logger) *Exporter {
return &Exporter{
records: prometheus.NewGaugeVec(
prometheus.GaugeOpts{
@ -63,6 +71,9 @@ func NewDNSSECExporter() *Exporter {
"type",
},
),
dnsClient: dnsClient,
resolver: resolver,
logger: logger,
}
}
@ -82,7 +93,17 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
rec := rec
go func() {
e.collectRecord(rec.Zone, rec.Record, rec.Type)
valid, exp := e.collectRecord(rec.Zone, rec.Record, rec.Type)
e.valid.WithLabelValues(
rec.Zone, rec.Record, rec.Type,
).Set(map[bool]float64{true: 1}[valid])
e.records.WithLabelValues(
rec.Zone, rec.Record, rec.Type,
).Set(float64(time.Until(exp)/time.Hour) / 24)
wg.Done()
}()
@ -95,16 +116,16 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
}
func (e *Exporter) collectRecord(zone, record, recordType string) {
func (e *Exporter) collectRecord(zone, record, recordType string) (valid bool, exp time.Time) {
// Start by finding the DNSKEY
msg := &dns.Msg{}
msg.SetQuestion(fmt.Sprintf("%s.", zone), dns.TypeDNSKEY)
response, _, err := dnsClient.Exchange(msg, *resolver)
response, _, err := e.dnsClient.Exchange(msg, e.resolver)
if err != nil {
log.Printf("while looking up DNSKEY for %v: %v", zone, err)
e.logger.Printf("while looking up DNSKEY for %v: %v", zone, err)
return
}
@ -118,7 +139,7 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
}
if len(keys) == 0 {
log.Printf("didn't find DNSKEY for %v", zone)
e.logger.Printf("didn't find DNSKEY for %v", zone)
}
// Now lookup the signature
@ -126,9 +147,9 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
msg = &dns.Msg{}
msg.SetQuestion(hostname(zone, record), dns.TypeRRSIG)
response, _, err = dnsClient.Exchange(msg, *resolver)
response, _, err = e.dnsClient.Exchange(msg, e.resolver)
if err != nil {
log.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err)
e.logger.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err)
return
}
@ -136,35 +157,28 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
var key *dns.DNSKEY
for _, rr := range response.Answer {
if rrsig, ok := rr.(*dns.RRSIG); ok {
if rrsig, ok := rr.(*dns.RRSIG); ok &&
rrsig.TypeCovered == dns.StringToType[recordType] &&
keys[rrsig.KeyTag] != nil {
if rrsig.TypeCovered == dns.StringToType[recordType] &&
keys[rrsig.KeyTag] != nil {
sig = rrsig
key = keys[rrsig.KeyTag]
break
}
sig = rrsig
key = keys[rrsig.KeyTag]
break
}
}
if sig == nil {
log.Printf("didn't find RRSIG for %v covering type %v", hostname(zone, record), recordType)
e.logger.Printf("didn't find RRSIG for %v covering type %v matching a tag of a DNSKEY", hostname(zone, record), recordType)
return
}
exp := time.Unix(int64(sig.Expiration), 0)
exp = time.Unix(int64(sig.Expiration), 0)
if exp.IsZero() {
log.Print("zero exp")
e.logger.Printf("zero exp for RRSIG for %v covering type %v", hostname(zone, record), recordType)
return
}
e.records.WithLabelValues(
zone, record, recordType,
).Set(float64(time.Until(exp)/time.Hour) / 24)
// Finally, lookup the records to validate
if key == nil {
@ -175,19 +189,20 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
msg = &dns.Msg{}
msg.SetQuestion(hostname(zone, record), dns.StringToType[recordType])
response, _, err = dnsClient.Exchange(msg, *resolver)
response, _, err = e.dnsClient.Exchange(msg, e.resolver)
if err != nil {
log.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err)
e.logger.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err)
return
}
if err := sig.Verify(key, response.Answer); err == nil {
e.valid.WithLabelValues(zone, record, recordType).Set(1)
valid = true
} else {
log.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err)
e.valid.WithLabelValues(zone, record, recordType).Set(0)
e.logger.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err)
}
return
}
func hostname(zone, record string) string {
@ -204,17 +219,17 @@ func main() {
flag.Parse()
dnsClient = &dns.Client{
Net: "tcp",
Timeout: *timeout,
}
f, err := os.Open(*conf)
if err != nil {
log.Fatalf("couldn't open configuration file: %v", err)
}
exporter := NewDNSSECExporter()
logger := log.New(os.Stderr, "", log.LstdFlags)
exporter := NewDNSSECExporter(&dns.Client{
Net: "tcp",
Timeout: *timeout,
}, *resolver, logger)
if err := toml.NewDecoder(f).Decode(exporter); err != nil {
log.Fatalf("couldn't parse configuration file: %v", err)

235
main_test.go Normal file
View File

@ -0,0 +1,235 @@
package main
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"io/ioutil"
"log"
"net"
"testing"
"time"
"github.com/miekg/dns"
)
type opts struct {
signed time.Time
expires time.Time
privkey crypto.PrivateKey
}
func nullLogger() *log.Logger {
return log.New(ioutil.Discard, "", log.LstdFlags)
}
func runServer(t *testing.T, opts opts) (string, func()) {
if opts.signed.IsZero() {
opts.signed = time.Now().Add(-time.Hour)
}
if opts.expires.IsZero() {
opts.expires = time.Now().Add(14 * 24 * time.Hour)
}
dnskey := &dns.DNSKEY{
Algorithm: dns.ECDSAP256SHA256,
Flags: dns.ZONE,
Protocol: 3,
}
privkey, err := dnskey.Generate(256)
if err != nil {
t.Fatalf("couldn't generate private key: %v", err)
}
if opts.privkey != nil {
privkey = opts.privkey
}
h := dns.NewServeMux()
h.HandleFunc("example.org.", func(rw dns.ResponseWriter, msg *dns.Msg) {
q := msg.Question[0]
soa := &dns.SOA{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
Ttl: 3600,
},
Ns: "ns1.example.org.",
Mbox: "test.example.org.",
Serial: 1,
Refresh: 14400,
Retry: 3600,
Expire: 7200,
Minttl: 60,
}
switch q.Qtype {
case dns.TypeDNSKEY:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeDNSKEY,
Class: dns.ClassINET,
Ttl: 3600,
}
answer := &dns.DNSKEY{
Hdr: rrHeader,
Algorithm: dnskey.Algorithm,
Flags: dnskey.Flags,
Protocol: dnskey.Protocol,
PublicKey: dnskey.PublicKey,
}
msg.Answer = append(msg.Answer, answer)
case dns.TypeRRSIG:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeRRSIG,
Class: dns.ClassINET,
Ttl: 3600,
}
answer := &dns.RRSIG{
Hdr: rrHeader,
TypeCovered: dns.TypeSOA,
Algorithm: dnskey.Algorithm,
Labels: uint8(dns.CountLabel(q.Name)),
OrigTtl: 3600,
Expiration: uint32(opts.expires.Unix()),
Inception: uint32(opts.signed.Unix()),
KeyTag: dnskey.KeyTag(),
SignerName: q.Name,
}
if err := answer.Sign(privkey.(*ecdsa.PrivateKey), []dns.RR{soa}); err != nil {
t.Fatalf("couldn't sign SOA record: %v", err)
}
msg.Answer = append(msg.Answer, answer)
case dns.TypeSOA:
msg.Answer = append(msg.Answer, soa)
}
rw.WriteMsg(msg)
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen failed: %v", err)
}
server := &dns.Server{
Listener: ln,
Handler: h,
}
go func() {
server.ActivateAndServe()
}()
done := make(chan bool)
go func() {
<-done
server.Shutdown()
ln.Close()
}()
return ln.Addr().String(), func() {
done <- true
}
}
func TestCollectionOK(t *testing.T) {
addr, cancel := runServer(t, opts{})
defer cancel()
e := NewDNSSECExporter(&dns.Client{
Net: "tcp",
Timeout: 1 * time.Second,
}, addr, nullLogger())
valid, exp := e.collectRecord("example.org", "@", "SOA")
if !valid {
t.Fatal("expected record to be valid")
}
if exp.Before(time.Now()) {
t.Fatalf("expected expiration to be in the future, was: %v", exp)
}
}
func TestCollectionExpired(t *testing.T) {
addr, cancel := runServer(t, opts{
signed: time.Now().Add(14 * 24 * time.Hour),
expires: time.Now().Add(-time.Hour),
})
defer cancel()
e := NewDNSSECExporter(&dns.Client{
Net: "tcp",
Timeout: 1 * time.Second,
}, addr, nullLogger())
valid, exp := e.collectRecord("example.org", "@", "SOA")
if !valid {
t.Fatal("expected record to be valid")
}
if exp.After(time.Now()) {
t.Fatalf("expected expiration to be in the past, was: %v", exp)
}
}
func TestCollectionInvalid(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("couldn't generate fake private key: %v", err)
}
addr, cancel := runServer(t, opts{
privkey: priv,
})
defer cancel()
e := NewDNSSECExporter(&dns.Client{
Net: "tcp",
Timeout: 1 * time.Second,
}, addr, nullLogger())
valid, exp := e.collectRecord("example.org", "@", "SOA")
if valid {
t.Fatal("expected record to be invalid")
}
if exp.Before(time.Now()) {
t.Fatalf("expected expiration to be in the future, was: %v", exp)
}
}