summaryrefslogtreecommitdiff
path: root/vendor/github.com/caddyserver/certmagic/dnsutil.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/caddyserver/certmagic/dnsutil.go')
-rw-r--r--vendor/github.com/caddyserver/certmagic/dnsutil.go68
1 files changed, 42 insertions, 26 deletions
diff --git a/vendor/github.com/caddyserver/certmagic/dnsutil.go b/vendor/github.com/caddyserver/certmagic/dnsutil.go
index 81fc192..bd008b0 100644
--- a/vendor/github.com/caddyserver/certmagic/dnsutil.go
+++ b/vendor/github.com/caddyserver/certmagic/dnsutil.go
@@ -1,6 +1,7 @@
package certmagic
import (
+ "context"
"errors"
"fmt"
"net"
@@ -18,21 +19,24 @@ import (
//
// It has been modified.
-// findZoneByFQDN determines the zone apex for the given fqdn by recursing
-// up the domain labels until the nameserver returns a SOA record in the
-// answer section. The logger must be non-nil.
-func findZoneByFQDN(logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
+// FindZoneByFQDN determines the zone apex for the given fully-qualified
+// domain name (FQDN) by recursing up the domain labels until the nameserver
+// returns a SOA record in the answer section. The logger must be non-nil.
+//
+// EXPERIMENTAL: This API was previously unexported, and may be changed or
+// unexported again in the future. Do not rely on it at this time.
+func FindZoneByFQDN(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
- soa, err := lookupSoaByFqdn(logger, fqdn, nameservers)
+ soa, err := lookupSoaByFqdn(ctx, logger, fqdn, nameservers)
if err != nil {
return "", err
}
return soa.zone, nil
}
-func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
+func lookupSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
logger = logger.Named("soa_lookup")
if !strings.HasSuffix(fqdn, ".") {
@@ -42,13 +46,17 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
fqdnSOACacheMu.Lock()
defer fqdnSOACacheMu.Unlock()
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
// prefer cached version if fresh
if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() {
logger.Debug("using cached SOA result", zap.String("entry", ent.zone))
return ent, nil
}
- ent, err := fetchSoaByFqdn(logger, fqdn, nameservers)
+ ent, err := fetchSoaByFqdn(ctx, logger, fqdn, nameservers)
if err != nil {
return nil, err
}
@@ -66,15 +74,19 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
return ent, nil
}
-func fetchSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
+func fetchSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error
var in *dns.Msg
labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
domain := fqdn[index:]
- in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
+ in, err = dnsQuery(ctx, domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}
@@ -122,12 +134,12 @@ func dnsMsgContainsCNAME(msg *dns.Msg) bool {
return false
}
-func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
+func dnsQuery(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
var in *dns.Msg
var err error
for _, ns := range nameservers {
- in, err = sendDNSQuery(m, ns)
+ in, err = sendDNSQuery(ctx, m, ns)
if err == nil && len(in.Answer) > 0 {
break
}
@@ -147,16 +159,16 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
return m
}
-func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
+func sendDNSQuery(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
- in, _, err := udp.Exchange(m, ns)
+ in, _, err := udp.ExchangeContext(ctx, m, ns)
// two kinds of errors we can handle by retrying with TCP:
// truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639
truncated := in != nil && in.Truncated
timeoutErr := err != nil && strings.Contains(err.Error(), "timeout")
if truncated || timeoutErr {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
- in, _, err = tcp.Exchange(m, ns)
+ in, _, err = tcp.ExchangeContext(ctx, m, ns)
}
return in, err
}
@@ -205,7 +217,8 @@ func systemOrDefaultNameservers(path string, defaults []string) []string {
return config.Servers
}
-// populateNameserverPorts ensures that all nameservers have a port number.
+// populateNameserverPorts ensures that all nameservers have a port number
+// If not, the the default DNS server port of 53 will be appended.
func populateNameserverPorts(servers []string) {
for i := range servers {
_, port, _ := net.SplitHostPort(servers[i])
@@ -216,7 +229,7 @@ func populateNameserverPorts(servers []string) {
}
// checkDNSPropagation checks if the expected record has been propagated to all authoritative nameservers.
-func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
+func checkDNSPropagation(ctx context.Context, logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
logger = logger.Named("propagation")
if !strings.HasSuffix(fqdn, ".") {
@@ -227,7 +240,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
// dereference (follow) a CNAME record if we are targeting a CNAME record
// itself
if recType != dns.TypeCNAME {
- r, err := dnsQuery(fqdn, recType, resolvers, true)
+ r, err := dnsQuery(ctx, fqdn, recType, resolvers, true)
if err != nil {
return false, fmt.Errorf("CNAME dns query: %v", err)
}
@@ -237,7 +250,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
}
if checkAuthoritativeServers {
- authoritativeServers, err := lookupNameservers(logger, fqdn, resolvers)
+ authoritativeServers, err := lookupNameservers(ctx, logger, fqdn, resolvers)
if err != nil {
return false, fmt.Errorf("looking up authoritative nameservers: %v", err)
}
@@ -246,13 +259,13 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
}
logger.Debug("checking authoritative nameservers", zap.Strings("resolvers", resolvers))
- return checkAuthoritativeNss(fqdn, recType, expectedValue, resolvers)
+ return checkAuthoritativeNss(ctx, fqdn, recType, expectedValue, resolvers)
}
// checkAuthoritativeNss queries each of the given nameservers for the expected record.
-func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
+func checkAuthoritativeNss(ctx context.Context, fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
- r, err := dnsQuery(fqdn, recType, []string{ns}, true)
+ r, err := dnsQuery(ctx, fqdn, recType, []string{ns}, true)
if err != nil {
return false, fmt.Errorf("querying authoritative nameservers: %v", err)
}
@@ -293,15 +306,15 @@ func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, na
}
// lookupNameservers returns the authoritative nameservers for the given fqdn.
-func lookupNameservers(logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
+func lookupNameservers(ctx context.Context, logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
var authoritativeNss []string
- zone, err := findZoneByFQDN(logger, fqdn, resolvers)
+ zone, err := FindZoneByFQDN(ctx, logger, fqdn, resolvers)
if err != nil {
return nil, fmt.Errorf("could not determine the zone for '%s': %w", fqdn, err)
}
- r, err := dnsQuery(zone, dns.TypeNS, resolvers, true)
+ r, err := dnsQuery(ctx, zone, dns.TypeNS, resolvers, true)
if err != nil {
return nil, fmt.Errorf("querying NS resolver for zone '%s' recursively: %v", zone, err)
}
@@ -330,11 +343,14 @@ func updateDomainWithCName(r *dns.Msg, fqdn string) string {
return fqdn
}
-// recursiveNameservers are used to pre-check DNS propagation. It
+// RecursiveNameservers are used to pre-check DNS propagation. It
// picks user-configured nameservers (custom) OR the defaults
// obtained from resolv.conf and defaultNameservers if none is
// configured and ensures that all server addresses have a port value.
-func recursiveNameservers(custom []string) []string {
+//
+// EXPERIMENTAL: This API was previously unexported, and may be
+// be unexported again in the future. Do not rely on it at this time.
+func RecursiveNameservers(custom []string) []string {
var servers []string
if len(custom) == 0 {
servers = systemOrDefaultNameservers(defaultResolvConf, defaultNameservers)