diff options
Diffstat (limited to 'vendor/github.com/caddyserver/certmagic/dnsutil.go')
| -rw-r--r-- | vendor/github.com/caddyserver/certmagic/dnsutil.go | 68 |
1 files changed, 42 insertions, 26 deletions
diff --git a/vendor/github.com/caddyserver/certmagic/dnsutil.go b/vendor/github.com/caddyserver/certmagic/dnsutil.go index 81fc192..bd008b0 100644 --- a/vendor/github.com/caddyserver/certmagic/dnsutil.go +++ b/vendor/github.com/caddyserver/certmagic/dnsutil.go @@ -1,6 +1,7 @@ package certmagic import ( + "context" "errors" "fmt" "net" @@ -18,21 +19,24 @@ import ( // // It has been modified. -// findZoneByFQDN determines the zone apex for the given fqdn by recursing -// up the domain labels until the nameserver returns a SOA record in the -// answer section. The logger must be non-nil. -func findZoneByFQDN(logger *zap.Logger, fqdn string, nameservers []string) (string, error) { +// FindZoneByFQDN determines the zone apex for the given fully-qualified +// domain name (FQDN) by recursing up the domain labels until the nameserver +// returns a SOA record in the answer section. The logger must be non-nil. +// +// EXPERIMENTAL: This API was previously unexported, and may be changed or +// unexported again in the future. Do not rely on it at this time. +func FindZoneByFQDN(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (string, error) { if !strings.HasSuffix(fqdn, ".") { fqdn += "." } - soa, err := lookupSoaByFqdn(logger, fqdn, nameservers) + soa, err := lookupSoaByFqdn(ctx, logger, fqdn, nameservers) if err != nil { return "", err } return soa.zone, nil } -func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) { +func lookupSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) { logger = logger.Named("soa_lookup") if !strings.HasSuffix(fqdn, ".") { @@ -42,13 +46,17 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so fqdnSOACacheMu.Lock() defer fqdnSOACacheMu.Unlock() + if err := ctx.Err(); err != nil { + return nil, err + } + // prefer cached version if fresh if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() { logger.Debug("using cached SOA result", zap.String("entry", ent.zone)) return ent, nil } - ent, err := fetchSoaByFqdn(logger, fqdn, nameservers) + ent, err := fetchSoaByFqdn(ctx, logger, fqdn, nameservers) if err != nil { return nil, err } @@ -66,15 +74,19 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so return ent, nil } -func fetchSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) { +func fetchSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) { var err error var in *dns.Msg labelIndexes := dns.Split(fqdn) for _, index := range labelIndexes { + if err := ctx.Err(); err != nil { + return nil, err + } + domain := fqdn[index:] - in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true) + in, err = dnsQuery(ctx, domain, dns.TypeSOA, nameservers, true) if err != nil { continue } @@ -122,12 +134,12 @@ func dnsMsgContainsCNAME(msg *dns.Msg) bool { return false } -func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { +func dnsQuery(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) { m := createDNSMsg(fqdn, rtype, recursive) var in *dns.Msg var err error for _, ns := range nameservers { - in, err = sendDNSQuery(m, ns) + in, err = sendDNSQuery(ctx, m, ns) if err == nil && len(in.Answer) > 0 { break } @@ -147,16 +159,16 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg { return m } -func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) { +func sendDNSQuery(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) { udp := &dns.Client{Net: "udp", Timeout: dnsTimeout} - in, _, err := udp.Exchange(m, ns) + in, _, err := udp.ExchangeContext(ctx, m, ns) // two kinds of errors we can handle by retrying with TCP: // truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639 truncated := in != nil && in.Truncated timeoutErr := err != nil && strings.Contains(err.Error(), "timeout") if truncated || timeoutErr { tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout} - in, _, err = tcp.Exchange(m, ns) + in, _, err = tcp.ExchangeContext(ctx, m, ns) } return in, err } @@ -205,7 +217,8 @@ func systemOrDefaultNameservers(path string, defaults []string) []string { return config.Servers } -// populateNameserverPorts ensures that all nameservers have a port number. +// populateNameserverPorts ensures that all nameservers have a port number +// If not, the the default DNS server port of 53 will be appended. func populateNameserverPorts(servers []string) { for i := range servers { _, port, _ := net.SplitHostPort(servers[i]) @@ -216,7 +229,7 @@ func populateNameserverPorts(servers []string) { } // checkDNSPropagation checks if the expected record has been propagated to all authoritative nameservers. -func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) { +func checkDNSPropagation(ctx context.Context, logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) { logger = logger.Named("propagation") if !strings.HasSuffix(fqdn, ".") { @@ -227,7 +240,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect // dereference (follow) a CNAME record if we are targeting a CNAME record // itself if recType != dns.TypeCNAME { - r, err := dnsQuery(fqdn, recType, resolvers, true) + r, err := dnsQuery(ctx, fqdn, recType, resolvers, true) if err != nil { return false, fmt.Errorf("CNAME dns query: %v", err) } @@ -237,7 +250,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect } if checkAuthoritativeServers { - authoritativeServers, err := lookupNameservers(logger, fqdn, resolvers) + authoritativeServers, err := lookupNameservers(ctx, logger, fqdn, resolvers) if err != nil { return false, fmt.Errorf("looking up authoritative nameservers: %v", err) } @@ -246,13 +259,13 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect } logger.Debug("checking authoritative nameservers", zap.Strings("resolvers", resolvers)) - return checkAuthoritativeNss(fqdn, recType, expectedValue, resolvers) + return checkAuthoritativeNss(ctx, fqdn, recType, expectedValue, resolvers) } // checkAuthoritativeNss queries each of the given nameservers for the expected record. -func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) { +func checkAuthoritativeNss(ctx context.Context, fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) { for _, ns := range nameservers { - r, err := dnsQuery(fqdn, recType, []string{ns}, true) + r, err := dnsQuery(ctx, fqdn, recType, []string{ns}, true) if err != nil { return false, fmt.Errorf("querying authoritative nameservers: %v", err) } @@ -293,15 +306,15 @@ func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, na } // lookupNameservers returns the authoritative nameservers for the given fqdn. -func lookupNameservers(logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) { +func lookupNameservers(ctx context.Context, logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) { var authoritativeNss []string - zone, err := findZoneByFQDN(logger, fqdn, resolvers) + zone, err := FindZoneByFQDN(ctx, logger, fqdn, resolvers) if err != nil { return nil, fmt.Errorf("could not determine the zone for '%s': %w", fqdn, err) } - r, err := dnsQuery(zone, dns.TypeNS, resolvers, true) + r, err := dnsQuery(ctx, zone, dns.TypeNS, resolvers, true) if err != nil { return nil, fmt.Errorf("querying NS resolver for zone '%s' recursively: %v", zone, err) } @@ -330,11 +343,14 @@ func updateDomainWithCName(r *dns.Msg, fqdn string) string { return fqdn } -// recursiveNameservers are used to pre-check DNS propagation. It +// RecursiveNameservers are used to pre-check DNS propagation. It // picks user-configured nameservers (custom) OR the defaults // obtained from resolv.conf and defaultNameservers if none is // configured and ensures that all server addresses have a port value. -func recursiveNameservers(custom []string) []string { +// +// EXPERIMENTAL: This API was previously unexported, and may be +// be unexported again in the future. Do not rely on it at this time. +func RecursiveNameservers(custom []string) []string { var servers []string if len(custom) == 0 { servers = systemOrDefaultNameservers(defaultResolvConf, defaultNameservers) |
