diff options
author | Tyler Davis <tydavis@gmail.com> | 2022-01-06 04:50:20 +0000 |
---|---|---|
committer | Tyler Davis <tydavis@gmail.com> | 2022-01-06 04:50:20 +0000 |
commit | b6385e524ebbfe242f9d36eb4361f8568e6bf895 (patch) | |
tree | 90d136278e520af555f64479973c287f39719361 /vendor/github.com/miekg/dns/client.go | |
parent | a687ebabb6589ebb36a9c385f583a19ac462b831 (diff) | |
download | dnstracker-b6385e524ebbfe242f9d36eb4361f8568e6bf895.tar.gz dnstracker-b6385e524ebbfe242f9d36eb4361f8568e6bf895.zip |
Update module to 1.17 and update all deps
Diffstat (limited to 'vendor/github.com/miekg/dns/client.go')
-rw-r--r-- | vendor/github.com/miekg/dns/client.go | 86 |
1 files changed, 58 insertions, 28 deletions
diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go index aa2c49d..6bae3a1 100644 --- a/vendor/github.com/miekg/dns/client.go +++ b/vendor/github.com/miekg/dns/client.go @@ -82,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration { // Dial connects to the address on the named network. func (c *Client) Dial(address string) (conn *Conn, err error) { + return c.DialContext(context.Background(), address) +} + +// DialContext connects to the address on the named network, with a context.Context. +// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released. +func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) { // create a new dialer with the appropriate timeout var d net.Dialer if c.Dialer == nil { @@ -101,9 +107,17 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { if useTLS { network = strings.TrimSuffix(network, "-tls") + // TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases. + /* + tlsDialer := tls.Dialer{ + NetDialer: &d, + Config: c.TLSConfig, + } + conn.Conn, err = tlsDialer.DialContext(ctx, network, address) + */ conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) } else { - conn.Conn, err = d.Dial(network, address) + conn.Conn, err = d.DialContext(ctx, network, address) } if err != nil { return nil, err @@ -139,24 +153,34 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er // ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection // that will be used instead of creating a new one. // Usage pattern with a *dns.Client: +// // c := new(dns.Client) // // connection management logic goes here // // conn := c.Dial(address) // in, rtt, err := c.ExchangeWithConn(message, conn) // -// This allows users of the library to implement their own connection management, -// as opposed to Exchange, which will always use new connections and incur the added overhead -// that entails when using "tcp" and especially "tcp-tls" clients. +// This allows users of the library to implement their own connection management, +// as opposed to Exchange, which will always use new connections and incur the added overhead +// that entails when using "tcp" and especially "tcp-tls" clients. +// +// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to +// prevent one cancelation from canceling all outstanding requests. func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { + return c.exchangeWithConnContext(context.Background(), m, conn) +} + +func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - return c.exchange(m, conn) + return c.exchangeContext(ctx, m, conn) } q := m.Question[0] key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass) r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) { - return c.exchange(m, conn) + // When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries. + // Hence we ignore the context and use Background(). + return c.exchangeContext(context.Background(), m, conn) }) if r != nil && shared { r = r.Copy() @@ -165,8 +189,7 @@ func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration return r, rtt, err } -func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { - +func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { opt := m.IsEdns0() // If EDNS0 is used use that for size. if opt != nil && opt.UDPSize() >= MinMsgSize { @@ -177,15 +200,27 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro co.UDPSize = c.UDPSize } - co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider - t := time.Now() // write with the appropriate write timeout - co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) + t := time.Now() + writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout())) + readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout())) + if deadline, ok := ctx.Deadline(); ok { + if deadline.Before(writeDeadline) { + writeDeadline = deadline + } + if deadline.Before(readDeadline) { + readDeadline = deadline + } + } + co.SetWriteDeadline(writeDeadline) + co.SetReadDeadline(readDeadline) + + co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider + if err = co.WriteMsg(m); err != nil { return nil, 0, err } - co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) if _, ok := co.Conn.(net.PacketConn); ok { for { r, err = co.ReadMsg() @@ -340,11 +375,10 @@ func (co *Conn) Write(p []byte) (int, error) { return co.Conn.Write(p) } - l := make([]byte, 2) - binary.BigEndian.PutUint16(l, uint16(len(p))) - - n, err := (&net.Buffers{l, p}).WriteTo(co.Conn) - return int(n), err + msg := make([]byte, 2+len(p)) + binary.BigEndian.PutUint16(msg, uint16(len(p))) + copy(msg[2:], p) + return co.Conn.Write(msg) } // Return the appropriate timeout for a specific request @@ -380,7 +414,7 @@ func Dial(network, address string) (conn *Conn, err error) { func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) { client := Client{Net: "udp"} r, _, err = client.ExchangeContext(ctx, m, a) - // ignorint rtt to leave the original ExchangeContext API unchanged, but + // ignoring rtt to leave the original ExchangeContext API unchanged, but // this function will go away return r, err } @@ -436,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout // context, if present. If there is both a context deadline and a configured // timeout on the client, the earliest of the two takes effect. func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var timeout time.Duration - if deadline, ok := ctx.Deadline(); !ok { - timeout = 0 - } else { - timeout = time.Until(deadline) + conn, err := c.DialContext(ctx, a) + if err != nil { + return nil, 0, err } - // not passing the context to the underlying calls, as the API does not support - // context. For timeouts you should set up Client.Dialer and call Client.Exchange. - // TODO(tmthrgd,miekg): this is a race condition. - c.Dialer = &net.Dialer{Timeout: timeout} - return c.Exchange(m, a) + defer conn.Close() + + return c.exchangeWithConnContext(ctx, m, conn) } |