aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/client.go
diff options
context:
space:
mode:
authorTyler Davis <tydavis@gmail.com>2022-01-06 04:50:20 +0000
committerTyler Davis <tydavis@gmail.com>2022-01-06 04:50:20 +0000
commitb6385e524ebbfe242f9d36eb4361f8568e6bf895 (patch)
tree90d136278e520af555f64479973c287f39719361 /vendor/github.com/miekg/dns/client.go
parenta687ebabb6589ebb36a9c385f583a19ac462b831 (diff)
downloaddnstracker-b6385e524ebbfe242f9d36eb4361f8568e6bf895.tar.gz
dnstracker-b6385e524ebbfe242f9d36eb4361f8568e6bf895.zip
Update module to 1.17 and update all deps
Diffstat (limited to 'vendor/github.com/miekg/dns/client.go')
-rw-r--r--vendor/github.com/miekg/dns/client.go86
1 files changed, 58 insertions, 28 deletions
diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go
index aa2c49d..6bae3a1 100644
--- a/vendor/github.com/miekg/dns/client.go
+++ b/vendor/github.com/miekg/dns/client.go
@@ -82,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration {
// Dial connects to the address on the named network.
func (c *Client) Dial(address string) (conn *Conn, err error) {
+ return c.DialContext(context.Background(), address)
+}
+
+// DialContext connects to the address on the named network, with a context.Context.
+// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released.
+func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
// create a new dialer with the appropriate timeout
var d net.Dialer
if c.Dialer == nil {
@@ -101,9 +107,17 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
if useTLS {
network = strings.TrimSuffix(network, "-tls")
+ // TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases.
+ /*
+ tlsDialer := tls.Dialer{
+ NetDialer: &d,
+ Config: c.TLSConfig,
+ }
+ conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
+ */
conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
} else {
- conn.Conn, err = d.Dial(network, address)
+ conn.Conn, err = d.DialContext(ctx, network, address)
}
if err != nil {
return nil, err
@@ -139,24 +153,34 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
// that will be used instead of creating a new one.
// Usage pattern with a *dns.Client:
+//
// c := new(dns.Client)
// // connection management logic goes here
//
// conn := c.Dial(address)
// in, rtt, err := c.ExchangeWithConn(message, conn)
//
-// This allows users of the library to implement their own connection management,
-// as opposed to Exchange, which will always use new connections and incur the added overhead
-// that entails when using "tcp" and especially "tcp-tls" clients.
+// This allows users of the library to implement their own connection management,
+// as opposed to Exchange, which will always use new connections and incur the added overhead
+// that entails when using "tcp" and especially "tcp-tls" clients.
+//
+// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
+// prevent one cancelation from canceling all outstanding requests.
func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
+ return c.exchangeWithConnContext(context.Background(), m, conn)
+}
+
+func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
- return c.exchange(m, conn)
+ return c.exchangeContext(ctx, m, conn)
}
q := m.Question[0]
key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
- return c.exchange(m, conn)
+ // When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries.
+ // Hence we ignore the context and use Background().
+ return c.exchangeContext(context.Background(), m, conn)
})
if r != nil && shared {
r = r.Copy()
@@ -165,8 +189,7 @@ func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration
return r, rtt, err
}
-func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
-
+func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
@@ -177,15 +200,27 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro
co.UDPSize = c.UDPSize
}
- co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
- t := time.Now()
// write with the appropriate write timeout
- co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
+ t := time.Now()
+ writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
+ readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
+ if deadline, ok := ctx.Deadline(); ok {
+ if deadline.Before(writeDeadline) {
+ writeDeadline = deadline
+ }
+ if deadline.Before(readDeadline) {
+ readDeadline = deadline
+ }
+ }
+ co.SetWriteDeadline(writeDeadline)
+ co.SetReadDeadline(readDeadline)
+
+ co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
+
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
- co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout())))
if _, ok := co.Conn.(net.PacketConn); ok {
for {
r, err = co.ReadMsg()
@@ -340,11 +375,10 @@ func (co *Conn) Write(p []byte) (int, error) {
return co.Conn.Write(p)
}
- l := make([]byte, 2)
- binary.BigEndian.PutUint16(l, uint16(len(p)))
-
- n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
- return int(n), err
+ msg := make([]byte, 2+len(p))
+ binary.BigEndian.PutUint16(msg, uint16(len(p)))
+ copy(msg[2:], p)
+ return co.Conn.Write(msg)
}
// Return the appropriate timeout for a specific request
@@ -380,7 +414,7 @@ func Dial(network, address string) (conn *Conn, err error) {
func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
client := Client{Net: "udp"}
r, _, err = client.ExchangeContext(ctx, m, a)
- // ignorint rtt to leave the original ExchangeContext API unchanged, but
+ // ignoring rtt to leave the original ExchangeContext API unchanged, but
// this function will go away
return r, err
}
@@ -436,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
// context, if present. If there is both a context deadline and a configured
// timeout on the client, the earliest of the two takes effect.
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
- var timeout time.Duration
- if deadline, ok := ctx.Deadline(); !ok {
- timeout = 0
- } else {
- timeout = time.Until(deadline)
+ conn, err := c.DialContext(ctx, a)
+ if err != nil {
+ return nil, 0, err
}
- // not passing the context to the underlying calls, as the API does not support
- // context. For timeouts you should set up Client.Dialer and call Client.Exchange.
- // TODO(tmthrgd,miekg): this is a race condition.
- c.Dialer = &net.Dialer{Timeout: timeout}
- return c.Exchange(m, a)
+ defer conn.Close()
+
+ return c.exchangeWithConnContext(ctx, m, conn)
}