Added testing, restructured error handling.
This commit is contained in:
parent
6191b1368e
commit
4c5495f2a7
89
main.go
89
main.go
@ -20,22 +20,30 @@ var conf = flag.String("config", "/etc/dnssec-checks", "Configuration file")
|
||||
var resolver = flag.String("resolver", "8.8.8.8:53", "Resolver to use")
|
||||
var timeout = flag.Duration("timeout", 10*time.Second, "Timeout for network operations")
|
||||
|
||||
var dnsClient *dns.Client
|
||||
|
||||
type Records struct {
|
||||
Zone string
|
||||
Record string
|
||||
Type string
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Print(v ...interface{})
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
type Exporter struct {
|
||||
Records []Records
|
||||
|
||||
records *prometheus.GaugeVec
|
||||
valid *prometheus.GaugeVec
|
||||
|
||||
resolver string
|
||||
dnsClient *dns.Client
|
||||
|
||||
logger Logger
|
||||
}
|
||||
|
||||
func NewDNSSECExporter() *Exporter {
|
||||
func NewDNSSECExporter(dnsClient *dns.Client, resolver string, logger Logger) *Exporter {
|
||||
return &Exporter{
|
||||
records: prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
@ -63,6 +71,9 @@ func NewDNSSECExporter() *Exporter {
|
||||
"type",
|
||||
},
|
||||
),
|
||||
dnsClient: dnsClient,
|
||||
resolver: resolver,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +93,17 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
|
||||
rec := rec
|
||||
|
||||
go func() {
|
||||
e.collectRecord(rec.Zone, rec.Record, rec.Type)
|
||||
|
||||
valid, exp := e.collectRecord(rec.Zone, rec.Record, rec.Type)
|
||||
|
||||
e.valid.WithLabelValues(
|
||||
rec.Zone, rec.Record, rec.Type,
|
||||
).Set(map[bool]float64{true: 1}[valid])
|
||||
|
||||
e.records.WithLabelValues(
|
||||
rec.Zone, rec.Record, rec.Type,
|
||||
).Set(float64(time.Until(exp)/time.Hour) / 24)
|
||||
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
@ -95,16 +116,16 @@ func (e *Exporter) Collect(ch chan<- prometheus.Metric) {
|
||||
|
||||
}
|
||||
|
||||
func (e *Exporter) collectRecord(zone, record, recordType string) {
|
||||
func (e *Exporter) collectRecord(zone, record, recordType string) (valid bool, exp time.Time) {
|
||||
|
||||
// Start by finding the DNSKEY
|
||||
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(fmt.Sprintf("%s.", zone), dns.TypeDNSKEY)
|
||||
|
||||
response, _, err := dnsClient.Exchange(msg, *resolver)
|
||||
response, _, err := e.dnsClient.Exchange(msg, e.resolver)
|
||||
if err != nil {
|
||||
log.Printf("while looking up DNSKEY for %v: %v", zone, err)
|
||||
e.logger.Printf("while looking up DNSKEY for %v: %v", zone, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -118,7 +139,7 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
log.Printf("didn't find DNSKEY for %v", zone)
|
||||
e.logger.Printf("didn't find DNSKEY for %v", zone)
|
||||
}
|
||||
|
||||
// Now lookup the signature
|
||||
@ -126,9 +147,9 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
|
||||
msg = &dns.Msg{}
|
||||
msg.SetQuestion(hostname(zone, record), dns.TypeRRSIG)
|
||||
|
||||
response, _, err = dnsClient.Exchange(msg, *resolver)
|
||||
response, _, err = e.dnsClient.Exchange(msg, e.resolver)
|
||||
if err != nil {
|
||||
log.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err)
|
||||
e.logger.Printf("while looking up RRSIG for %v: %v", hostname(zone, record), err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -136,35 +157,28 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
|
||||
var key *dns.DNSKEY
|
||||
|
||||
for _, rr := range response.Answer {
|
||||
if rrsig, ok := rr.(*dns.RRSIG); ok {
|
||||
if rrsig, ok := rr.(*dns.RRSIG); ok &&
|
||||
rrsig.TypeCovered == dns.StringToType[recordType] &&
|
||||
keys[rrsig.KeyTag] != nil {
|
||||
|
||||
if rrsig.TypeCovered == dns.StringToType[recordType] &&
|
||||
keys[rrsig.KeyTag] != nil {
|
||||
|
||||
sig = rrsig
|
||||
key = keys[rrsig.KeyTag]
|
||||
break
|
||||
|
||||
}
|
||||
sig = rrsig
|
||||
key = keys[rrsig.KeyTag]
|
||||
break
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if sig == nil {
|
||||
log.Printf("didn't find RRSIG for %v covering type %v", hostname(zone, record), recordType)
|
||||
e.logger.Printf("didn't find RRSIG for %v covering type %v matching a tag of a DNSKEY", hostname(zone, record), recordType)
|
||||
return
|
||||
}
|
||||
|
||||
exp := time.Unix(int64(sig.Expiration), 0)
|
||||
exp = time.Unix(int64(sig.Expiration), 0)
|
||||
if exp.IsZero() {
|
||||
log.Print("zero exp")
|
||||
e.logger.Printf("zero exp for RRSIG for %v covering type %v", hostname(zone, record), recordType)
|
||||
return
|
||||
}
|
||||
|
||||
e.records.WithLabelValues(
|
||||
zone, record, recordType,
|
||||
).Set(float64(time.Until(exp)/time.Hour) / 24)
|
||||
|
||||
// Finally, lookup the records to validate
|
||||
|
||||
if key == nil {
|
||||
@ -175,19 +189,20 @@ func (e *Exporter) collectRecord(zone, record, recordType string) {
|
||||
msg = &dns.Msg{}
|
||||
msg.SetQuestion(hostname(zone, record), dns.StringToType[recordType])
|
||||
|
||||
response, _, err = dnsClient.Exchange(msg, *resolver)
|
||||
response, _, err = e.dnsClient.Exchange(msg, e.resolver)
|
||||
if err != nil {
|
||||
log.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err)
|
||||
e.logger.Printf("while looking up RRSet for %v type %v: %v", hostname(zone, record), recordType, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sig.Verify(key, response.Answer); err == nil {
|
||||
e.valid.WithLabelValues(zone, record, recordType).Set(1)
|
||||
valid = true
|
||||
} else {
|
||||
log.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err)
|
||||
e.valid.WithLabelValues(zone, record, recordType).Set(0)
|
||||
e.logger.Printf("verify error for %v type %v): %v", hostname(zone, record), recordType, err)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func hostname(zone, record string) string {
|
||||
@ -204,17 +219,17 @@ func main() {
|
||||
|
||||
flag.Parse()
|
||||
|
||||
dnsClient = &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: *timeout,
|
||||
}
|
||||
|
||||
f, err := os.Open(*conf)
|
||||
if err != nil {
|
||||
log.Fatalf("couldn't open configuration file: %v", err)
|
||||
}
|
||||
|
||||
exporter := NewDNSSECExporter()
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
|
||||
exporter := NewDNSSECExporter(&dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: *timeout,
|
||||
}, *resolver, logger)
|
||||
|
||||
if err := toml.NewDecoder(f).Decode(exporter); err != nil {
|
||||
log.Fatalf("couldn't parse configuration file: %v", err)
|
||||
|
235
main_test.go
Normal file
235
main_test.go
Normal file
@ -0,0 +1,235 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type opts struct {
|
||||
signed time.Time
|
||||
expires time.Time
|
||||
privkey crypto.PrivateKey
|
||||
}
|
||||
|
||||
func nullLogger() *log.Logger {
|
||||
return log.New(ioutil.Discard, "", log.LstdFlags)
|
||||
}
|
||||
|
||||
func runServer(t *testing.T, opts opts) (string, func()) {
|
||||
|
||||
if opts.signed.IsZero() {
|
||||
opts.signed = time.Now().Add(-time.Hour)
|
||||
}
|
||||
|
||||
if opts.expires.IsZero() {
|
||||
opts.expires = time.Now().Add(14 * 24 * time.Hour)
|
||||
}
|
||||
|
||||
dnskey := &dns.DNSKEY{
|
||||
Algorithm: dns.ECDSAP256SHA256,
|
||||
Flags: dns.ZONE,
|
||||
Protocol: 3,
|
||||
}
|
||||
|
||||
privkey, err := dnskey.Generate(256)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't generate private key: %v", err)
|
||||
}
|
||||
|
||||
if opts.privkey != nil {
|
||||
privkey = opts.privkey
|
||||
}
|
||||
|
||||
h := dns.NewServeMux()
|
||||
h.HandleFunc("example.org.", func(rw dns.ResponseWriter, msg *dns.Msg) {
|
||||
|
||||
q := msg.Question[0]
|
||||
|
||||
soa := &dns.SOA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
},
|
||||
Ns: "ns1.example.org.",
|
||||
Mbox: "test.example.org.",
|
||||
Serial: 1,
|
||||
Refresh: 14400,
|
||||
Retry: 3600,
|
||||
Expire: 7200,
|
||||
Minttl: 60,
|
||||
}
|
||||
|
||||
switch q.Qtype {
|
||||
|
||||
case dns.TypeDNSKEY:
|
||||
|
||||
rrHeader := dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeDNSKEY,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
}
|
||||
|
||||
answer := &dns.DNSKEY{
|
||||
Hdr: rrHeader,
|
||||
Algorithm: dnskey.Algorithm,
|
||||
Flags: dnskey.Flags,
|
||||
Protocol: dnskey.Protocol,
|
||||
PublicKey: dnskey.PublicKey,
|
||||
}
|
||||
|
||||
msg.Answer = append(msg.Answer, answer)
|
||||
|
||||
case dns.TypeRRSIG:
|
||||
|
||||
rrHeader := dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeRRSIG,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
}
|
||||
|
||||
answer := &dns.RRSIG{
|
||||
Hdr: rrHeader,
|
||||
TypeCovered: dns.TypeSOA,
|
||||
Algorithm: dnskey.Algorithm,
|
||||
Labels: uint8(dns.CountLabel(q.Name)),
|
||||
OrigTtl: 3600,
|
||||
Expiration: uint32(opts.expires.Unix()),
|
||||
Inception: uint32(opts.signed.Unix()),
|
||||
KeyTag: dnskey.KeyTag(),
|
||||
SignerName: q.Name,
|
||||
}
|
||||
|
||||
if err := answer.Sign(privkey.(*ecdsa.PrivateKey), []dns.RR{soa}); err != nil {
|
||||
t.Fatalf("couldn't sign SOA record: %v", err)
|
||||
}
|
||||
|
||||
msg.Answer = append(msg.Answer, answer)
|
||||
|
||||
case dns.TypeSOA:
|
||||
|
||||
msg.Answer = append(msg.Answer, soa)
|
||||
|
||||
}
|
||||
|
||||
rw.WriteMsg(msg)
|
||||
|
||||
})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
|
||||
server := &dns.Server{
|
||||
Listener: ln,
|
||||
Handler: h,
|
||||
}
|
||||
|
||||
go func() {
|
||||
server.ActivateAndServe()
|
||||
}()
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
<-done
|
||||
server.Shutdown()
|
||||
ln.Close()
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
done <- true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCollectionOK(t *testing.T) {
|
||||
|
||||
addr, cancel := runServer(t, opts{})
|
||||
defer cancel()
|
||||
|
||||
e := NewDNSSECExporter(&dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 1 * time.Second,
|
||||
}, addr, nullLogger())
|
||||
|
||||
valid, exp := e.collectRecord("example.org", "@", "SOA")
|
||||
|
||||
if !valid {
|
||||
t.Fatal("expected record to be valid")
|
||||
}
|
||||
|
||||
if exp.Before(time.Now()) {
|
||||
t.Fatalf("expected expiration to be in the future, was: %v", exp)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCollectionExpired(t *testing.T) {
|
||||
|
||||
addr, cancel := runServer(t, opts{
|
||||
signed: time.Now().Add(14 * 24 * time.Hour),
|
||||
expires: time.Now().Add(-time.Hour),
|
||||
})
|
||||
|
||||
defer cancel()
|
||||
|
||||
e := NewDNSSECExporter(&dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 1 * time.Second,
|
||||
}, addr, nullLogger())
|
||||
|
||||
valid, exp := e.collectRecord("example.org", "@", "SOA")
|
||||
|
||||
if !valid {
|
||||
t.Fatal("expected record to be valid")
|
||||
}
|
||||
|
||||
if exp.After(time.Now()) {
|
||||
t.Fatalf("expected expiration to be in the past, was: %v", exp)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCollectionInvalid(t *testing.T) {
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't generate fake private key: %v", err)
|
||||
}
|
||||
|
||||
addr, cancel := runServer(t, opts{
|
||||
privkey: priv,
|
||||
})
|
||||
|
||||
defer cancel()
|
||||
|
||||
e := NewDNSSECExporter(&dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: 1 * time.Second,
|
||||
}, addr, nullLogger())
|
||||
|
||||
valid, exp := e.collectRecord("example.org", "@", "SOA")
|
||||
|
||||
if valid {
|
||||
t.Fatal("expected record to be invalid")
|
||||
}
|
||||
|
||||
if exp.Before(time.Now()) {
|
||||
t.Fatalf("expected expiration to be in the future, was: %v", exp)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user