diff options
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r-- | vendor/github.com/miekg/dns/server.go | 779 |
1 files changed, 444 insertions, 335 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go index 685753f..30dfd41 100644 --- a/vendor/github.com/miekg/dns/server.go +++ b/vendor/github.com/miekg/dns/server.go @@ -3,23 +3,40 @@ package dns import ( - "bytes" + "context" "crypto/tls" "encoding/binary" + "errors" "io" "net" + "strings" "sync" "time" ) -// Maximum number of TCP queries before we close the socket. +// Default maximum number of TCP queries before we close the socket. const maxTCPQueries = 128 +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancelation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + // Handler is implemented by any value that implements ServeDNS. type Handler interface { ServeDNS(w ResponseWriter, r *Msg) } +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as DNS handlers. If f is a function +// with the appropriate signature, HandlerFunc(f) is a +// Handler object that calls f. +type HandlerFunc func(ResponseWriter, *Msg) + +// ServeDNS calls f(w, r). +func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { + f(w, r) +} + // A ResponseWriter interface is used by an DNS handler to // construct an DNS response. type ResponseWriter interface { @@ -42,49 +59,35 @@ type ResponseWriter interface { Hijack() } +// A ConnectionStater interface is used by a DNS Handler to access TLS connection state +// when available. +type ConnectionStater interface { + ConnectionState() *tls.ConnectionState +} + type response struct { + closed bool // connection has been closed hijacked bool // connection has been hijacked by handler - tsigStatus error tsigTimersOnly bool + tsigStatus error tsigRequestMAC string tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used + udp net.PacketConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right - remoteAddr net.Addr // address of the client + pcSession net.Addr // address to use when writing to a generic net.PacketConn writer Writer // writer to output the raw DNS bits } -// ServeMux is an DNS request multiplexer. It matches the -// zone name of each incoming request against a list of -// registered patterns add calls the handler for the pattern -// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning -// that queries for the DS record are redirected to the parent zone (if that -// is also registered), otherwise the child gets the query. -// ServeMux is also safe for concurrent access from multiple goroutines. -type ServeMux struct { - z map[string]Handler - m *sync.RWMutex -} - -// NewServeMux allocates and returns a new ServeMux. -func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} } - -// DefaultServeMux is the default ServeMux used by Serve. -var DefaultServeMux = NewServeMux() - -// The HandlerFunc type is an adapter to allow the use of -// ordinary functions as DNS handlers. If f is a function -// with the appropriate signature, HandlerFunc(f) is a -// Handler object that calls f. -type HandlerFunc func(ResponseWriter, *Msg) - -// ServeDNS calls f(w, r). -func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { - f(w, r) +// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets. +func handleRefused(w ResponseWriter, r *Msg) { + m := new(Msg) + m.SetRcode(r, RcodeRefused) + w.WriteMsg(m) } // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. +// Deprecated: This function is going away. func HandleFailed(w ResponseWriter, r *Msg) { m := new(Msg) m.SetRcode(r, RcodeServerFailure) @@ -92,8 +95,6 @@ func HandleFailed(w ResponseWriter, r *Msg) { w.WriteMsg(m) } -func failedHandler() Handler { return HandlerFunc(HandleFailed) } - // ListenAndServe Starts a server on address and network specified Invoke handler // for incoming queries. func ListenAndServe(addr string, network string, handler Handler) error { @@ -132,99 +133,6 @@ func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error { return server.ActivateAndServe() } -func (mux *ServeMux) match(q string, t uint16) Handler { - mux.m.RLock() - defer mux.m.RUnlock() - var handler Handler - b := make([]byte, len(q)) // worst case, one label of length q - off := 0 - end := false - for { - l := len(q[off:]) - for i := 0; i < l; i++ { - b[i] = q[off+i] - if b[i] >= 'A' && b[i] <= 'Z' { - b[i] |= ('a' - 'A') - } - } - if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key - if t != TypeDS { - return h - } - // Continue for DS to see if we have a parent too, if so delegeate to the parent - handler = h - } - off, end = NextLabel(q, off) - if end { - break - } - } - // Wildcard match, if we have found nothing try the root zone as a last resort. - if h, ok := mux.z["."]; ok { - return h - } - return handler -} - -// Handle adds a handler to the ServeMux for pattern. -func (mux *ServeMux) Handle(pattern string, handler Handler) { - if pattern == "" { - panic("dns: invalid pattern " + pattern) - } - mux.m.Lock() - mux.z[Fqdn(pattern)] = handler - mux.m.Unlock() -} - -// HandleFunc adds a handler function to the ServeMux for pattern. -func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { - mux.Handle(pattern, HandlerFunc(handler)) -} - -// HandleRemove deregistrars the handler specific for pattern from the ServeMux. -func (mux *ServeMux) HandleRemove(pattern string) { - if pattern == "" { - panic("dns: invalid pattern " + pattern) - } - mux.m.Lock() - delete(mux.z, Fqdn(pattern)) - mux.m.Unlock() -} - -// ServeDNS dispatches the request to the handler whose -// pattern most closely matches the request message. If DefaultServeMux -// is used the correct thing for DS queries is done: a possible parent -// is sought. -// If no handler is found a standard SERVFAIL message is returned -// If the request message does not have exactly one question in the -// question section a SERVFAIL is returned, unlesss Unsafe is true. -func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) { - var h Handler - if len(request.Question) < 1 { // allow more than one question - h = failedHandler() - } else { - if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil { - h = failedHandler() - } - } - h.ServeDNS(w, request) -} - -// Handle registers the handler with the given pattern -// in the DefaultServeMux. The documentation for -// ServeMux explains how patterns are matched. -func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } - -// HandleRemove deregisters the handle with the given pattern -// in the DefaultServeMux. -func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) } - -// HandleFunc registers the handler function with the given pattern -// in the DefaultServeMux. -func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { - DefaultServeMux.HandleFunc(pattern, handler) -} - // Writer writes raw DNS messages; each call to Write should send an entire message. type Writer interface { io.Writer @@ -240,22 +148,40 @@ type Reader interface { ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) } -// defaultReader is an adapter for the Server struct that implements the Reader interface -// using the readTCP and readUDP func of the embedded Server. +// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns. +type PacketConnReader interface { + Reader + + // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may + // alter connection properties, for example the read-deadline. + ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) +} + +// defaultReader is an adapter for the Server struct that implements the Reader and +// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs +// of the embedded Server. type defaultReader struct { *Server } -func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { +var _ PacketConnReader = defaultReader{} + +func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { return dr.readTCP(conn, timeout) } -func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { +func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { return dr.readUDP(conn, timeout) } +func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + return dr.readPacketConn(conn, timeout) +} + // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. // Implementations should never return a nil Reader. +// Readers should also implement the optional PacketConnReader interface. +// PacketConnReader is required to use a generic net.PacketConn. type DecorateReader func(Reader) Reader // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. @@ -287,87 +213,120 @@ type Server struct { IdleTimeout func() time.Duration // Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string - // Unsafe instructs the server to disregard any sanity checks and directly hand the message to - // the handler. It will specifically not check if the query has the QR bit not set. - Unsafe bool // If NotifyStartedFunc is set it is called once the server has started listening. NotifyStartedFunc func() // DecorateReader is optional, allows customization of the process that reads raw DNS messages. DecorateReader DecorateReader // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter + // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1). + MaxTCPQueries int + // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. + // It is only supported on go1.11+ and when using ListenAndServe. + ReusePort bool + // AcceptMsgFunc will check the incoming message and will reject it early in the process. + // By default DefaultMsgAcceptFunc will be used. + MsgAcceptFunc MsgAcceptFunc // Shutdown handling - lock sync.RWMutex - started bool + lock sync.RWMutex + started bool + shutdown chan struct{} + conns map[net.Conn]struct{} + + // A pool for UDP message buffers. + udpPool sync.Pool +} + +func (srv *Server) isStarted() bool { + srv.lock.RLock() + started := srv.started + srv.lock.RUnlock() + return started +} + +func makeUDPBuffer(size int) func() interface{} { + return func() interface{} { + return make([]byte, size) + } +} + +func (srv *Server) init() { + srv.shutdown = make(chan struct{}) + srv.conns = make(map[net.Conn]struct{}) + + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } + if srv.MsgAcceptFunc == nil { + srv.MsgAcceptFunc = DefaultMsgAcceptFunc + } + if srv.Handler == nil { + srv.Handler = DefaultServeMux + } + + srv.udpPool.New = makeUDPBuffer(srv.UDPSize) +} + +func unlockOnce(l sync.Locker) func() { + var once sync.Once + return func() { once.Do(l.Unlock) } } // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { + unlock := unlockOnce(&srv.lock) srv.lock.Lock() - defer srv.lock.Unlock() + defer unlock() + if srv.started { return &Error{err: "server already started"} } + addr := srv.Addr if addr == "" { addr = ":domain" } - if srv.UDPSize == 0 { - srv.UDPSize = MinMsgSize - } + + srv.init() + switch srv.Net { case "tcp", "tcp4", "tcp6": - a, err := net.ResolveTCPAddr(srv.Net, addr) - if err != nil { - return err - } - l, err := net.ListenTCP(srv.Net, a) + l, err := listenTCP(srv.Net, addr, srv.ReusePort) if err != nil { return err } srv.Listener = l srv.started = true - srv.lock.Unlock() - err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveTCP(l) case "tcp-tls", "tcp4-tls", "tcp6-tls": - network := "tcp" - if srv.Net == "tcp4-tls" { - network = "tcp4" - } else if srv.Net == "tcp6-tls" { - network = "tcp6" + if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { + return errors.New("dns: neither Certificates nor GetCertificate set in Config") } - - l, err := tls.Listen(network, addr, srv.TLSConfig) + network := strings.TrimSuffix(srv.Net, "-tls") + l, err := listenTCP(network, addr, srv.ReusePort) if err != nil { return err } + l = tls.NewListener(l, srv.TLSConfig) srv.Listener = l srv.started = true - srv.lock.Unlock() - err = srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveTCP(l) case "udp", "udp4", "udp6": - a, err := net.ResolveUDPAddr(srv.Net, addr) + l, err := listenUDP(srv.Net, addr, srv.ReusePort) if err != nil { return err } - l, err := net.ListenUDP(srv.Net, a) - if err != nil { - return err - } - if e := setUDPSocketOptions(l); e != nil { + u := l.(*net.UDPConn) + if e := setUDPSocketOptions(u); e != nil { return e } srv.PacketConn = l srv.started = true - srv.lock.Unlock() - err = srv.serveUDP(l) - srv.lock.Lock() // to satisfy the defer at the top - return err + unlock() + return srv.serveUDP(u) } return &Error{err: "bad network"} } @@ -375,36 +334,32 @@ func (srv *Server) ListenAndServe() error { // ActivateAndServe starts a nameserver with the PacketConn or Listener // configured in *Server. Its main use is to start a server from systemd. func (srv *Server) ActivateAndServe() error { + unlock := unlockOnce(&srv.lock) srv.lock.Lock() - defer srv.lock.Unlock() + defer unlock() + if srv.started { return &Error{err: "server already started"} } - pConn := srv.PacketConn - l := srv.Listener - if pConn != nil { - if srv.UDPSize == 0 { - srv.UDPSize = MinMsgSize - } + + srv.init() + + if srv.PacketConn != nil { // Check PacketConn interface's type is valid and value // is not nil - if t, ok := pConn.(*net.UDPConn); ok && t != nil { + if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil { if e := setUDPSocketOptions(t); e != nil { return e } - srv.started = true - srv.lock.Unlock() - e := srv.serveUDP(t) - srv.lock.Lock() // to satisfy the defer at the top - return e } + srv.started = true + unlock() + return srv.serveUDP(srv.PacketConn) } - if l != nil { + if srv.Listener != nil { srv.started = true - srv.lock.Unlock() - e := srv.serveTCP(l) - srv.lock.Lock() // to satisfy the defer at the top - return e + unlock() + return srv.serveTCP(srv.Listener) } return &Error{err: "bad listeners"} } @@ -412,34 +367,66 @@ func (srv *Server) ActivateAndServe() error { // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and // ActivateAndServe will return. func (srv *Server) Shutdown() error { + return srv.ShutdownContext(context.Background()) +} + +// ShutdownContext shuts down a server. After a call to ShutdownContext, +// ListenAndServe and ActivateAndServe will return. +// +// A context.Context may be passed to limit how long to wait for connections +// to terminate. +func (srv *Server) ShutdownContext(ctx context.Context) error { srv.lock.Lock() if !srv.started { srv.lock.Unlock() return &Error{err: "server not started"} } + srv.started = false - srv.lock.Unlock() if srv.PacketConn != nil { - srv.PacketConn.Close() + srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads } + if srv.Listener != nil { srv.Listener.Close() } - return nil + + for rw := range srv.conns { + rw.SetReadDeadline(aLongTimeAgo) // Unblock reads + } + + srv.lock.Unlock() + + if testShutdownNotify != nil { + testShutdownNotify.Broadcast() + } + + var ctxErr error + select { + case <-srv.shutdown: + case <-ctx.Done(): + ctxErr = ctx.Err() + } + + if srv.PacketConn != nil { + srv.PacketConn.Close() + } + + return ctxErr } +var testShutdownNotify *sync.Cond + // getReadTimeout is a helper func to use system timeout if server did not intend to change it. func (srv *Server) getReadTimeout() time.Duration { - rtimeout := dnsTimeout if srv.ReadTimeout != 0 { - rtimeout = srv.ReadTimeout + return srv.ReadTimeout } - return rtimeout + return dnsTimeout } // serveTCP starts a TCP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveTCP(l net.Listener) error { defer l.Close() @@ -447,203 +434,288 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) - } + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } - rtimeout := srv.getReadTimeout() - // deadline is not used here - for { + for srv.isStarted() { rw, err := l.Accept() - srv.lock.RLock() - if !srv.started { - srv.lock.RUnlock() - return nil - } - srv.lock.RUnlock() if err != nil { + if !srv.isStarted() { + return nil + } if neterr, ok := err.(net.Error); ok && neterr.Temporary() { continue } return err } - go func() { - m, err := reader.ReadTCP(rw, rtimeout) - if err != nil { - rw.Close() - return - } - srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) - }() + srv.lock.Lock() + // Track the connection to allow unblocking reads on shutdown. + srv.conns[rw] = struct{}{} + srv.lock.Unlock() + wg.Add(1) + go srv.serveTCPConn(&wg, rw) } + + return nil } // serveUDP starts a UDP listener for the server. -// Each request is handled in a separate goroutine. -func (srv *Server) serveUDP(l *net.UDPConn) error { +func (srv *Server) serveUDP(l net.PacketConn) error { defer l.Close() - if srv.NotifyStartedFunc != nil { - srv.NotifyStartedFunc() - } - - reader := Reader(&defaultReader{srv}) + reader := Reader(defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux + lUDP, isUDP := l.(*net.UDPConn) + readerPC, canPacketConn := reader.(PacketConnReader) + if !isUDP && !canPacketConn { + return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"} + } + + if srv.NotifyStartedFunc != nil { + srv.NotifyStartedFunc() } + + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() + rtimeout := srv.getReadTimeout() // deadline is not used here - for { - m, s, err := reader.ReadUDP(l, rtimeout) - srv.lock.RLock() - if !srv.started { - srv.lock.RUnlock() - return nil + for srv.isStarted() { + var ( + m []byte + sPC net.Addr + sUDP *SessionUDP + err error + ) + if isUDP { + m, sUDP, err = reader.ReadUDP(lUDP, rtimeout) + } else { + m, sPC, err = readerPC.ReadPacketConn(l, rtimeout) } - srv.lock.RUnlock() if err != nil { + if !srv.isStarted() { + return nil + } if netErr, ok := err.(net.Error); ok && netErr.Temporary() { continue } return err } if len(m) < headerSize { + if cap(m) == srv.UDPSize { + srv.udpPool.Put(m[:srv.UDPSize]) + } continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + wg.Add(1) + go srv.serveUDPPacket(&wg, m, l, sUDP, sPC) } + + return nil } -// Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +// Serve a new TCP connection. +func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { + w := &response{tsigSecret: srv.TsigSecret, tcp: rw} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { w.writer = w } - q := 0 // counter for the amount of TCP queries we get - - reader := Reader(&defaultReader{srv}) + reader := Reader(defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } -Redo: - req := new(Msg) - err := req.Unpack(m) - if err != nil { // Send a FormatError back - x := new(Msg) - x.SetRcodeFormatError(req) - w.WriteMsg(x) - goto Exit + + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() } - if !srv.Unsafe && req.Response { - goto Exit + + timeout := srv.getReadTimeout() + + limit := srv.MaxTCPQueries + if limit == 0 { + limit = maxTCPQueries } - w.tsigStatus = nil - if w.tsigSecret != nil { - if t := req.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.tsigSecret[secret]; !ok { - w.tsigStatus = ErrKeyAlg - } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) - w.tsigTimersOnly = false - w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ { + m, err := reader.ReadTCP(w.tcp, timeout) + if err != nil { + // TODO(tmthrgd): handle error + break + } + srv.serveDNS(m, w) + if w.closed { + break // Close() was called + } + if w.hijacked { + break // client will call Close() themselves } + // The first read uses the read timeout, the rest use the + // idle timeout. + timeout = idleTimeout } - h.ServeDNS(w, req) // Writes back to the client -Exit: - if w.tcp == nil { - return - } - // TODO(miek): make this number configurable? - if q > maxTCPQueries { // close socket after this many queries + if !w.hijacked { w.Close() - return } - if w.hijacked { - return // client calls Close() + srv.lock.Lock() + delete(srv.conns, w.tcp) + srv.lock.Unlock() + + wg.Done() +} + +// Serve a new UDP request. +func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession} + if srv.DecorateWriter != nil { + w.writer = srv.DecorateWriter(w) + } else { + w.writer = w } - if u != nil { // UDP, "close" and return - w.Close() + + srv.serveDNS(m, w) + wg.Done() +} + +func (srv *Server) serveDNS(m []byte, w *response) { + dh, off, err := unpackMsgHdr(m, 0) + if err != nil { + // Let client hang, they are sending crap; any reply can be used to amplify. return } - idleTimeout := tcpIdleTimeout - if srv.IdleTimeout != nil { - idleTimeout = srv.IdleTimeout() + + req := new(Msg) + req.setHdr(dh) + + switch action := srv.MsgAcceptFunc(dh); action { + case MsgAccept: + if req.unpack(dh, m, off) == nil { + break + } + + fallthrough + case MsgReject, MsgRejectNotImplemented: + opcode := req.Opcode + req.SetRcodeFormatError(req) + req.Zero = false + if action == MsgRejectNotImplemented { + req.Opcode = opcode + req.Rcode = RcodeNotImplemented + } + + // Are we allowed to delete any OPT records here? + req.Ns, req.Answer, req.Extra = nil, nil, nil + + w.WriteMsg(req) + fallthrough + case MsgIgnore: + if w.udp != nil && cap(m) == srv.UDPSize { + srv.udpPool.Put(m[:srv.UDPSize]) + } + + return } - m, err = reader.ReadTCP(w.tcp, idleTimeout) - if err == nil { - q++ - goto Redo + + w.tsigStatus = nil + if w.tsigSecret != nil { + if t := req.IsTsig(); t != nil { + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(m, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } + w.tsigTimersOnly = false + w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + } } - w.Close() - return + + if w.udp != nil && cap(m) == srv.UDPSize { + srv.udpPool.Put(m[:srv.UDPSize]) + } + + srv.Handler.ServeDNS(w, req) // Writes back to the client } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { - conn.SetReadDeadline(time.Now().Add(timeout)) - l := make([]byte, 2) - n, err := conn.Read(l) - if err != nil || n != 2 { - if err != nil { - return nil, err - } - return nil, ErrShortRead - } - length := binary.BigEndian.Uint16(l) - if length == 0 { - return nil, ErrShortRead + // If we race with ShutdownContext, the read deadline may + // have been set in the distant past to unblock the read + // below. We must not override it, otherwise we may block + // ShutdownContext. + srv.lock.RLock() + if srv.started { + conn.SetReadDeadline(time.Now().Add(timeout)) } - m := make([]byte, int(length)) - n, err = conn.Read(m[:int(length)]) - if err != nil || n == 0 { - if err != nil { - return nil, err - } - return nil, ErrShortRead + srv.lock.RUnlock() + + var length uint16 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + return nil, err } - i := n - for i < int(length) { - j, err := conn.Read(m[i:int(length)]) - if err != nil { - return nil, err - } - i += j + + m := make([]byte, length) + if _, err := io.ReadFull(conn, m); err != nil { + return nil, err } - n = i - m = m[:n] + return m, nil } func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { - conn.SetReadDeadline(time.Now().Add(timeout)) - m := make([]byte, srv.UDPSize) + srv.lock.RLock() + if srv.started { + // See the comment in readTCP above. + conn.SetReadDeadline(time.Now().Add(timeout)) + } + srv.lock.RUnlock() + + m := srv.udpPool.Get().([]byte) n, s, err := ReadFromSessionUDP(conn, m) if err != nil { + srv.udpPool.Put(m) return nil, nil, err } m = m[:n] return m, s, nil } +func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + srv.lock.RLock() + if srv.started { + // See the comment in readTCP above. + conn.SetReadDeadline(time.Now().Add(timeout)) + } + srv.lock.RUnlock() + + m := srv.udpPool.Get().([]byte) + n, addr, err := conn.ReadFrom(m) + if err != nil { + srv.udpPool.Put(m) + return nil, nil, err + } + m = m[:n] + return m, addr, nil +} + // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { + if w.closed { + return &Error{err: "WriteMsg called after Close"} + } + var data []byte if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { @@ -665,38 +737,56 @@ func (w *response) WriteMsg(m *Msg) (err error) { // Write implements the ResponseWriter.Write method. func (w *response) Write(m []byte) (int, error) { + if w.closed { + return 0, &Error{err: "Write called after Close"} + } + switch { case w.udp != nil: - n, err := WriteToSessionUDP(w.udp, m, w.udpSession) - return n, err - case w.tcp != nil: - lm := len(m) - if lm < 2 { - return 0, io.ErrShortBuffer + if u, ok := w.udp.(*net.UDPConn); ok { + return WriteToSessionUDP(u, m, w.udpSession) } - if lm > MaxMsgSize { + return w.udp.WriteTo(m, w.pcSession) + case w.tcp != nil: + if len(m) > MaxMsgSize { return 0, &Error{err: "message too large"} } - l := make([]byte, 2, 2+lm) - binary.BigEndian.PutUint16(l, uint16(lm)) - m = append(l, m...) - n, err := io.Copy(w.tcp, bytes.NewReader(m)) + l := make([]byte, 2) + binary.BigEndian.PutUint16(l, uint16(len(m))) + + n, err := (&net.Buffers{l, m}).WriteTo(w.tcp) return int(n), err + default: + panic("dns: internal error: udp and tcp both nil") } - panic("not reached") } // LocalAddr implements the ResponseWriter.LocalAddr method. func (w *response) LocalAddr() net.Addr { - if w.tcp != nil { + switch { + case w.udp != nil: + return w.udp.LocalAddr() + case w.tcp != nil: return w.tcp.LocalAddr() + default: + panic("dns: internal error: udp and tcp both nil") } - return w.udp.LocalAddr() } // RemoteAddr implements the ResponseWriter.RemoteAddr method. -func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } +func (w *response) RemoteAddr() net.Addr { + switch { + case w.udpSession != nil: + return w.udpSession.RemoteAddr() + case w.pcSession != nil: + return w.pcSession + case w.tcp != nil: + return w.tcp.RemoteAddr() + default: + panic("dns: internal error: udpSession, pcSession and tcp are all nil") + } +} // TsigStatus implements the ResponseWriter.TsigStatus method. func (w *response) TsigStatus() error { return w.tsigStatus } @@ -709,11 +799,30 @@ func (w *response) Hijack() { w.hijacked = true } // Close implements the ResponseWriter.Close method func (w *response) Close() error { - // Can't close the udp conn, as that is actually the listener. - if w.tcp != nil { - e := w.tcp.Close() - w.tcp = nil - return e + if w.closed { + return &Error{err: "connection already closed"} + } + w.closed = true + + switch { + case w.udp != nil: + // Can't close the udp conn, as that is actually the listener. + return nil + case w.tcp != nil: + return w.tcp.Close() + default: + panic("dns: internal error: udp and tcp both nil") + } +} + +// ConnectionState() implements the ConnectionStater.ConnectionState() interface. +func (w *response) ConnectionState() *tls.ConnectionState { + type tlsConnectionStater interface { + ConnectionState() tls.ConnectionState + } + if v, ok := w.tcp.(tlsConnectionStater); ok { + t := v.ConnectionState() + return &t } return nil } |