diff options
Diffstat (limited to 'vendor/github.com/miekg/dns/client.go')
-rw-r--r-- | vendor/github.com/miekg/dns/client.go | 306 |
1 files changed, 125 insertions, 181 deletions
diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go index 856b169..aa2c49d 100644 --- a/vendor/github.com/miekg/dns/client.go +++ b/vendor/github.com/miekg/dns/client.go @@ -3,26 +3,27 @@ package dns // A client implementation. import ( - "bytes" "context" "crypto/tls" "encoding/binary" + "fmt" "io" "net" "strings" "time" ) -const dnsTimeout time.Duration = 2 * time.Second -const tcpIdleTimeout time.Duration = 8 * time.Second +const ( + dnsTimeout time.Duration = 2 * time.Second + tcpIdleTimeout time.Duration = 8 * time.Second +) // A Conn represents a connection to a DNS server. type Conn struct { net.Conn // a net.Conn holding the connection UDPSize uint16 // minimum receive buffer for UDP messages TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) - rtt time.Duration - t time.Time + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. tsigRequestMAC string } @@ -34,12 +35,13 @@ type Client struct { Dialer *net.Dialer // a net.Dialer used to set local address, timeouts and more // Timeout is a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout, // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and - // Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext) + // Client.Dialer) or context.Context.Deadline (see ExchangeContext) Timeout time.Duration DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight } @@ -83,33 +85,22 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { // create a new dialer with the appropriate timeout var d net.Dialer if c.Dialer == nil { - d = net.Dialer{} + d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())} } else { - d = net.Dialer(*c.Dialer) - } - d.Timeout = c.getTimeoutForRequest(c.writeTimeout()) - - network := "udp" - useTLS := false - - switch c.Net { - case "tcp-tls": - network = "tcp" - useTLS = true - case "tcp4-tls": - network = "tcp4" - useTLS = true - case "tcp6-tls": - network = "tcp6" - useTLS = true - default: - if c.Net != "" { - network = c.Net - } + d = *c.Dialer + } + + network := c.Net + if network == "" { + network = "udp" } + useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls") + conn = new(Conn) if useTLS { + network = strings.TrimSuffix(network, "-tls") + conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) } else { conn.Conn, err = d.Dial(network, address) @@ -117,6 +108,7 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { if err != nil { return nil, err } + conn.UDPSize = c.UDPSize return conn, nil } @@ -135,36 +127,45 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { // To specify a local address or a timeout, the caller has to set the `Client.Dialer` // attribute appropriately func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { - if !c.SingleInflight { - return c.exchange(m, address) - } + co, err := c.Dial(address) - t := "nop" - if t1, ok := TypeToString[m.Question[0].Qtype]; ok { - t = t1 + if err != nil { + return nil, 0, err } - cl := "nop" - if cl1, ok := ClassToString[m.Question[0].Qclass]; ok { - cl = cl1 + defer co.Close() + return c.ExchangeWithConn(m, co) +} + +// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection +// that will be used instead of creating a new one. +// Usage pattern with a *dns.Client: +// c := new(dns.Client) +// // connection management logic goes here +// +// conn := c.Dial(address) +// in, rtt, err := c.ExchangeWithConn(message, conn) +// +// This allows users of the library to implement their own connection management, +// as opposed to Exchange, which will always use new connections and incur the added overhead +// that entails when using "tcp" and especially "tcp-tls" clients. +func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { + if !c.SingleInflight { + return c.exchange(m, conn) } - r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { - return c.exchange(m, address) + + q := m.Question[0] + key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass) + r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) { + return c.exchange(m, conn) }) if r != nil && shared { r = r.Copy() } + return r, rtt, err } -func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var co *Conn - - co, err = c.Dial(a) - - if err != nil { - return nil, 0, err - } - defer co.Close() +func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { opt := m.IsEdns0() // If EDNS0 is used use that for size. @@ -176,19 +177,32 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro co.UDPSize = c.UDPSize } - co.TsigSecret = c.TsigSecret + co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider + t := time.Now() // write with the appropriate write timeout - co.SetWriteDeadline(time.Now().Add(c.getTimeoutForRequest(c.writeTimeout()))) + co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) if err = co.WriteMsg(m); err != nil { return nil, 0, err } co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) - r, err = co.ReadMsg() - if err == nil && r.Id != m.Id { - err = ErrId + if _, ok := co.Conn.(net.PacketConn); ok { + for { + r, err = co.ReadMsg() + // Ignore replies with mismatched IDs because they might be + // responses to earlier queries that timed out. + if err != nil || r.Id == m.Id { + break + } + } + } else { + r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } } - return r, co.rtt, err + rtt = time.Since(t) + return r, rtt, err } // ReadMsg reads a message from the connection co. @@ -210,11 +224,15 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } if t := m.IsTsig(); t != nil { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return m, ErrSecret + if co.TsigProvider != nil { + err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false) + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } return m, err } @@ -229,26 +247,21 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { err error ) - switch t := co.Conn.(type) { - case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - - // First two bytes specify the length of the entire message. - l, err := tcpMsgLen(r) - if err != nil { - return nil, err - } - p = make([]byte, l) - n, err = tcpRead(r, p) - co.rtt = time.Since(co.t) - default: + if _, ok := co.Conn.(net.PacketConn); ok { if co.UDPSize > MinMsgSize { p = make([]byte, co.UDPSize) } else { p = make([]byte, MinMsgSize) } n, err = co.Read(p) - co.rtt = time.Since(co.t) + } else { + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { + return nil, err + } + + p = make([]byte, length) + n, err = io.ReadFull(co.Conn, p) } if err != nil { @@ -268,78 +281,26 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { return p, err } -// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length. -func tcpMsgLen(t io.Reader) (int, error) { - p := []byte{0, 0} - n, err := t.Read(p) - if err != nil { - return 0, err - } - - // As seen with my local router/switch, returns 1 byte on the above read, - // resulting a a ShortRead. Just write it out (instead of loop) and read the - // other byte. - if n == 1 { - n1, err := t.Read(p[1:]) - if err != nil { - return 0, err - } - n += n1 - } - - if n != 2 { - return 0, ErrShortRead - } - l := binary.BigEndian.Uint16(p) - if l == 0 { - return 0, ErrShortRead - } - return int(l), nil -} - -// tcpRead calls TCPConn.Read enough times to fill allocated buffer. -func tcpRead(t io.Reader, p []byte) (int, error) { - n, err := t.Read(p) - if err != nil { - return n, err - } - for n < len(p) { - j, err := t.Read(p[n:]) - if err != nil { - return n, err - } - n += j - } - return n, err -} - // Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { return 0, ErrConnEmpty } - if len(p) < 2 { - return 0, io.ErrShortBuffer + + if _, ok := co.Conn.(net.PacketConn); ok { + // UDP connection + return co.Conn.Read(p) } - switch t := co.Conn.(type) { - case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - l, err := tcpMsgLen(r) - if err != nil { - return 0, err - } - if l > len(p) { - return int(l), io.ErrShortBuffer - } - return tcpRead(r, p[:l]) + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { + return 0, err } - // UDP connection - n, err = co.Conn.Read(p) - if err != nil { - return n, err + if int(length) > len(p) { + return 0, io.ErrShortBuffer } - return n, err + + return io.ReadFull(co.Conn, p[:length]) } // WriteMsg sends a message through the connection co. @@ -349,10 +310,14 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return ErrSecret + if co.TsigProvider != nil { + out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false) + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) // Set for the next read, although only used in zone transfers co.tsigRequestMAC = mac } else { @@ -361,34 +326,25 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { if err != nil { return err } - co.t = time.Now() - if _, err = co.Write(out); err != nil { - return err - } - return nil + _, err = co.Write(out) + return err } // Write implements the net.Conn Write method. -func (co *Conn) Write(p []byte) (n int, err error) { - switch t := co.Conn.(type) { - case *net.TCPConn, *tls.Conn: - w := t.(io.Writer) - - lp := len(p) - if lp < 2 { - return 0, io.ErrShortBuffer - } - if lp > MaxMsgSize { - return 0, &Error{err: "message too large"} - } - l := make([]byte, 2, lp+2) - binary.BigEndian.PutUint16(l, uint16(lp)) - p = append(l, p...) - n, err := io.Copy(w, bytes.NewReader(p)) - return int(n), err - } - n, err = co.Conn.Write(p) - return n, err +func (co *Conn) Write(p []byte) (int, error) { + if len(p) > MaxMsgSize { + return 0, &Error{err: "message too large"} + } + + if _, ok := co.Conn.(net.PacketConn); ok { + return co.Conn.Write(p) + } + + l := make([]byte, 2) + binary.BigEndian.PutUint16(l, uint16(len(p))) + + n, err := (&net.Buffers{l, p}).WriteTo(co.Conn) + return int(n), err } // Return the appropriate timeout for a specific request @@ -431,7 +387,7 @@ func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) // ExchangeConn performs a synchronous query. It sends the message m via the connection // c and waits for a reply. The connection c is not closed by ExchangeConn. -// This function is going away, but can easily be mimicked: +// Deprecated: This function is going away, but can easily be mimicked: // // co := &dns.Conn{Conn: c} // c is your net.Conn // co.WriteMsg(m) @@ -455,11 +411,7 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { // DialTimeout acts like Dial but takes a timeout. func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} - conn, err = client.Dial(address) - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // DialWithTLS connects to the address on the named network with TLS. @@ -468,12 +420,7 @@ func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, er network += "-tls" } client := Client{Net: network, TLSConfig: tlsConfig} - conn, err = client.Dial(address) - - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. @@ -482,11 +429,7 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout network += "-tls" } client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} - conn, err = client.Dial(address) - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // ExchangeContext acts like Exchange, but honors the deadline on the provided @@ -497,10 +440,11 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, if deadline, ok := ctx.Deadline(); !ok { timeout = 0 } else { - timeout = deadline.Sub(time.Now()) + timeout = time.Until(deadline) } // not passing the context to the underlying calls, as the API does not support // context. For timeouts you should set up Client.Dialer and call Client.Exchange. + // TODO(tmthrgd,miekg): this is a race condition. c.Dialer = &net.Dialer{Timeout: timeout} return c.Exchange(m, a) } |