aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r--vendor/github.com/miekg/dns/server.go779
1 files changed, 444 insertions, 335 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go
index 685753f..30dfd41 100644
--- a/vendor/github.com/miekg/dns/server.go
+++ b/vendor/github.com/miekg/dns/server.go
@@ -3,23 +3,40 @@
package dns
import (
- "bytes"
+ "context"
"crypto/tls"
"encoding/binary"
+ "errors"
"io"
"net"
+ "strings"
"sync"
"time"
)
-// Maximum number of TCP queries before we close the socket.
+// Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
+// aLongTimeAgo is a non-zero time, far in the past, used for
+// immediate cancelation of network operations.
+var aLongTimeAgo = time.Unix(1, 0)
+
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
}
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as DNS handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler object that calls f.
+type HandlerFunc func(ResponseWriter, *Msg)
+
+// ServeDNS calls f(w, r).
+func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
+ f(w, r)
+}
+
// A ResponseWriter interface is used by an DNS handler to
// construct an DNS response.
type ResponseWriter interface {
@@ -42,49 +59,35 @@ type ResponseWriter interface {
Hijack()
}
+// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
+// when available.
+type ConnectionStater interface {
+ ConnectionState() *tls.ConnectionState
+}
+
type response struct {
+ closed bool // connection has been closed
hijacked bool // connection has been hijacked by handler
- tsigStatus error
tsigTimersOnly bool
+ tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
- udp *net.UDPConn // i/o connection if UDP was used
+ udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
- remoteAddr net.Addr // address of the client
+ pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}
-// ServeMux is an DNS request multiplexer. It matches the
-// zone name of each incoming request against a list of
-// registered patterns add calls the handler for the pattern
-// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
-// that queries for the DS record are redirected to the parent zone (if that
-// is also registered), otherwise the child gets the query.
-// ServeMux is also safe for concurrent access from multiple goroutines.
-type ServeMux struct {
- z map[string]Handler
- m *sync.RWMutex
-}
-
-// NewServeMux allocates and returns a new ServeMux.
-func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
-
-// DefaultServeMux is the default ServeMux used by Serve.
-var DefaultServeMux = NewServeMux()
-
-// The HandlerFunc type is an adapter to allow the use of
-// ordinary functions as DNS handlers. If f is a function
-// with the appropriate signature, HandlerFunc(f) is a
-// Handler object that calls f.
-type HandlerFunc func(ResponseWriter, *Msg)
-
-// ServeDNS calls f(w, r).
-func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
- f(w, r)
+// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
+func handleRefused(w ResponseWriter, r *Msg) {
+ m := new(Msg)
+ m.SetRcode(r, RcodeRefused)
+ w.WriteMsg(m)
}
// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
+// Deprecated: This function is going away.
func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
@@ -92,8 +95,6 @@ func HandleFailed(w ResponseWriter, r *Msg) {
w.WriteMsg(m)
}
-func failedHandler() Handler { return HandlerFunc(HandleFailed) }
-
// ListenAndServe Starts a server on address and network specified Invoke handler
// for incoming queries.
func ListenAndServe(addr string, network string, handler Handler) error {
@@ -132,99 +133,6 @@ func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
return server.ActivateAndServe()
}
-func (mux *ServeMux) match(q string, t uint16) Handler {
- mux.m.RLock()
- defer mux.m.RUnlock()
- var handler Handler
- b := make([]byte, len(q)) // worst case, one label of length q
- off := 0
- end := false
- for {
- l := len(q[off:])
- for i := 0; i < l; i++ {
- b[i] = q[off+i]
- if b[i] >= 'A' && b[i] <= 'Z' {
- b[i] |= ('a' - 'A')
- }
- }
- if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
- if t != TypeDS {
- return h
- }
- // Continue for DS to see if we have a parent too, if so delegeate to the parent
- handler = h
- }
- off, end = NextLabel(q, off)
- if end {
- break
- }
- }
- // Wildcard match, if we have found nothing try the root zone as a last resort.
- if h, ok := mux.z["."]; ok {
- return h
- }
- return handler
-}
-
-// Handle adds a handler to the ServeMux for pattern.
-func (mux *ServeMux) Handle(pattern string, handler Handler) {
- if pattern == "" {
- panic("dns: invalid pattern " + pattern)
- }
- mux.m.Lock()
- mux.z[Fqdn(pattern)] = handler
- mux.m.Unlock()
-}
-
-// HandleFunc adds a handler function to the ServeMux for pattern.
-func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
- mux.Handle(pattern, HandlerFunc(handler))
-}
-
-// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
-func (mux *ServeMux) HandleRemove(pattern string) {
- if pattern == "" {
- panic("dns: invalid pattern " + pattern)
- }
- mux.m.Lock()
- delete(mux.z, Fqdn(pattern))
- mux.m.Unlock()
-}
-
-// ServeDNS dispatches the request to the handler whose
-// pattern most closely matches the request message. If DefaultServeMux
-// is used the correct thing for DS queries is done: a possible parent
-// is sought.
-// If no handler is found a standard SERVFAIL message is returned
-// If the request message does not have exactly one question in the
-// question section a SERVFAIL is returned, unlesss Unsafe is true.
-func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
- var h Handler
- if len(request.Question) < 1 { // allow more than one question
- h = failedHandler()
- } else {
- if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
- h = failedHandler()
- }
- }
- h.ServeDNS(w, request)
-}
-
-// Handle registers the handler with the given pattern
-// in the DefaultServeMux. The documentation for
-// ServeMux explains how patterns are matched.
-func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
-
-// HandleRemove deregisters the handle with the given pattern
-// in the DefaultServeMux.
-func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
-
-// HandleFunc registers the handler function with the given pattern
-// in the DefaultServeMux.
-func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
- DefaultServeMux.HandleFunc(pattern, handler)
-}
-
// Writer writes raw DNS messages; each call to Write should send an entire message.
type Writer interface {
io.Writer
@@ -240,22 +148,40 @@ type Reader interface {
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
-// defaultReader is an adapter for the Server struct that implements the Reader interface
-// using the readTCP and readUDP func of the embedded Server.
+// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
+type PacketConnReader interface {
+ Reader
+
+ // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
+ // alter connection properties, for example the read-deadline.
+ ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
+}
+
+// defaultReader is an adapter for the Server struct that implements the Reader and
+// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
+// of the embedded Server.
type defaultReader struct {
*Server
}
-func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
+var _ PacketConnReader = defaultReader{}
+
+func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
}
-func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
+func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
return dr.readUDP(conn, timeout)
}
+func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+ return dr.readPacketConn(conn, timeout)
+}
+
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// Implementations should never return a nil Reader.
+// Readers should also implement the optional PacketConnReader interface.
+// PacketConnReader is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
@@ -287,87 +213,120 @@ type Server struct {
IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string
- // Unsafe instructs the server to disregard any sanity checks and directly hand the message to
- // the handler. It will specifically not check if the query has the QR bit not set.
- Unsafe bool
// If NotifyStartedFunc is set it is called once the server has started listening.
NotifyStartedFunc func()
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter
+ // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
+ MaxTCPQueries int
+ // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
+ // It is only supported on go1.11+ and when using ListenAndServe.
+ ReusePort bool
+ // AcceptMsgFunc will check the incoming message and will reject it early in the process.
+ // By default DefaultMsgAcceptFunc will be used.
+ MsgAcceptFunc MsgAcceptFunc
// Shutdown handling
- lock sync.RWMutex
- started bool
+ lock sync.RWMutex
+ started bool
+ shutdown chan struct{}
+ conns map[net.Conn]struct{}
+
+ // A pool for UDP message buffers.
+ udpPool sync.Pool
+}
+
+func (srv *Server) isStarted() bool {
+ srv.lock.RLock()
+ started := srv.started
+ srv.lock.RUnlock()
+ return started
+}
+
+func makeUDPBuffer(size int) func() interface{} {
+ return func() interface{} {
+ return make([]byte, size)
+ }
+}
+
+func (srv *Server) init() {
+ srv.shutdown = make(chan struct{})
+ srv.conns = make(map[net.Conn]struct{})
+
+ if srv.UDPSize == 0 {
+ srv.UDPSize = MinMsgSize
+ }
+ if srv.MsgAcceptFunc == nil {
+ srv.MsgAcceptFunc = DefaultMsgAcceptFunc
+ }
+ if srv.Handler == nil {
+ srv.Handler = DefaultServeMux
+ }
+
+ srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
+}
+
+func unlockOnce(l sync.Locker) func() {
+ var once sync.Once
+ return func() { once.Do(l.Unlock) }
}
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
+ unlock := unlockOnce(&srv.lock)
srv.lock.Lock()
- defer srv.lock.Unlock()
+ defer unlock()
+
if srv.started {
return &Error{err: "server already started"}
}
+
addr := srv.Addr
if addr == "" {
addr = ":domain"
}
- if srv.UDPSize == 0 {
- srv.UDPSize = MinMsgSize
- }
+
+ srv.init()
+
switch srv.Net {
case "tcp", "tcp4", "tcp6":
- a, err := net.ResolveTCPAddr(srv.Net, addr)
- if err != nil {
- return err
- }
- l, err := net.ListenTCP(srv.Net, a)
+ l, err := listenTCP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
srv.Listener = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveTCP(l)
case "tcp-tls", "tcp4-tls", "tcp6-tls":
- network := "tcp"
- if srv.Net == "tcp4-tls" {
- network = "tcp4"
- } else if srv.Net == "tcp6-tls" {
- network = "tcp6"
+ if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
+ return errors.New("dns: neither Certificates nor GetCertificate set in Config")
}
-
- l, err := tls.Listen(network, addr, srv.TLSConfig)
+ network := strings.TrimSuffix(srv.Net, "-tls")
+ l, err := listenTCP(network, addr, srv.ReusePort)
if err != nil {
return err
}
+ l = tls.NewListener(l, srv.TLSConfig)
srv.Listener = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveTCP(l)
case "udp", "udp4", "udp6":
- a, err := net.ResolveUDPAddr(srv.Net, addr)
+ l, err := listenUDP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
- l, err := net.ListenUDP(srv.Net, a)
- if err != nil {
- return err
- }
- if e := setUDPSocketOptions(l); e != nil {
+ u := l.(*net.UDPConn)
+ if e := setUDPSocketOptions(u); e != nil {
return e
}
srv.PacketConn = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveUDP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveUDP(u)
}
return &Error{err: "bad network"}
}
@@ -375,36 +334,32 @@ func (srv *Server) ListenAndServe() error {
// ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error {
+ unlock := unlockOnce(&srv.lock)
srv.lock.Lock()
- defer srv.lock.Unlock()
+ defer unlock()
+
if srv.started {
return &Error{err: "server already started"}
}
- pConn := srv.PacketConn
- l := srv.Listener
- if pConn != nil {
- if srv.UDPSize == 0 {
- srv.UDPSize = MinMsgSize
- }
+
+ srv.init()
+
+ if srv.PacketConn != nil {
// Check PacketConn interface's type is valid and value
// is not nil
- if t, ok := pConn.(*net.UDPConn); ok && t != nil {
+ if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
if e := setUDPSocketOptions(t); e != nil {
return e
}
- srv.started = true
- srv.lock.Unlock()
- e := srv.serveUDP(t)
- srv.lock.Lock() // to satisfy the defer at the top
- return e
}
+ srv.started = true
+ unlock()
+ return srv.serveUDP(srv.PacketConn)
}
- if l != nil {
+ if srv.Listener != nil {
srv.started = true
- srv.lock.Unlock()
- e := srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return e
+ unlock()
+ return srv.serveTCP(srv.Listener)
}
return &Error{err: "bad listeners"}
}
@@ -412,34 +367,66 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
// ActivateAndServe will return.
func (srv *Server) Shutdown() error {
+ return srv.ShutdownContext(context.Background())
+}
+
+// ShutdownContext shuts down a server. After a call to ShutdownContext,
+// ListenAndServe and ActivateAndServe will return.
+//
+// A context.Context may be passed to limit how long to wait for connections
+// to terminate.
+func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock()
if !srv.started {
srv.lock.Unlock()
return &Error{err: "server not started"}
}
+
srv.started = false
- srv.lock.Unlock()
if srv.PacketConn != nil {
- srv.PacketConn.Close()
+ srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
+
if srv.Listener != nil {
srv.Listener.Close()
}
- return nil
+
+ for rw := range srv.conns {
+ rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
+ }
+
+ srv.lock.Unlock()
+
+ if testShutdownNotify != nil {
+ testShutdownNotify.Broadcast()
+ }
+
+ var ctxErr error
+ select {
+ case <-srv.shutdown:
+ case <-ctx.Done():
+ ctxErr = ctx.Err()
+ }
+
+ if srv.PacketConn != nil {
+ srv.PacketConn.Close()
+ }
+
+ return ctxErr
}
+var testShutdownNotify *sync.Cond
+
// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration {
- rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
- rtimeout = srv.ReadTimeout
+ return srv.ReadTimeout
}
- return rtimeout
+ return dnsTimeout
}
// serveTCP starts a TCP listener for the server.
-// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()
@@ -447,203 +434,288 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
- reader := Reader(&defaultReader{srv})
- if srv.DecorateReader != nil {
- reader = srv.DecorateReader(reader)
- }
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(srv.shutdown)
+ }()
- handler := srv.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
- rtimeout := srv.getReadTimeout()
- // deadline is not used here
- for {
+ for srv.isStarted() {
rw, err := l.Accept()
- srv.lock.RLock()
- if !srv.started {
- srv.lock.RUnlock()
- return nil
- }
- srv.lock.RUnlock()
if err != nil {
+ if !srv.isStarted() {
+ return nil
+ }
if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
continue
}
return err
}
- go func() {
- m, err := reader.ReadTCP(rw, rtimeout)
- if err != nil {
- rw.Close()
- return
- }
- srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
- }()
+ srv.lock.Lock()
+ // Track the connection to allow unblocking reads on shutdown.
+ srv.conns[rw] = struct{}{}
+ srv.lock.Unlock()
+ wg.Add(1)
+ go srv.serveTCPConn(&wg, rw)
}
+
+ return nil
}
// serveUDP starts a UDP listener for the server.
-// Each request is handled in a separate goroutine.
-func (srv *Server) serveUDP(l *net.UDPConn) error {
+func (srv *Server) serveUDP(l net.PacketConn) error {
defer l.Close()
- if srv.NotifyStartedFunc != nil {
- srv.NotifyStartedFunc()
- }
-
- reader := Reader(&defaultReader{srv})
+ reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
- handler := srv.Handler
- if handler == nil {
- handler = DefaultServeMux
+ lUDP, isUDP := l.(*net.UDPConn)
+ readerPC, canPacketConn := reader.(PacketConnReader)
+ if !isUDP && !canPacketConn {
+ return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
+ }
+
+ if srv.NotifyStartedFunc != nil {
+ srv.NotifyStartedFunc()
}
+
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(srv.shutdown)
+ }()
+
rtimeout := srv.getReadTimeout()
// deadline is not used here
- for {
- m, s, err := reader.ReadUDP(l, rtimeout)
- srv.lock.RLock()
- if !srv.started {
- srv.lock.RUnlock()
- return nil
+ for srv.isStarted() {
+ var (
+ m []byte
+ sPC net.Addr
+ sUDP *SessionUDP
+ err error
+ )
+ if isUDP {
+ m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
+ } else {
+ m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
}
- srv.lock.RUnlock()
if err != nil {
+ if !srv.isStarted() {
+ return nil
+ }
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
continue
}
return err
}
if len(m) < headerSize {
+ if cap(m) == srv.UDPSize {
+ srv.udpPool.Put(m[:srv.UDPSize])
+ }
continue
}
- go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
+ wg.Add(1)
+ go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
}
+
+ return nil
}
-// Serve a new connection.
-func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
- w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
+// Serve a new TCP connection.
+func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
+ w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
- q := 0 // counter for the amount of TCP queries we get
-
- reader := Reader(&defaultReader{srv})
+ reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
-Redo:
- req := new(Msg)
- err := req.Unpack(m)
- if err != nil { // Send a FormatError back
- x := new(Msg)
- x.SetRcodeFormatError(req)
- w.WriteMsg(x)
- goto Exit
+
+ idleTimeout := tcpIdleTimeout
+ if srv.IdleTimeout != nil {
+ idleTimeout = srv.IdleTimeout()
}
- if !srv.Unsafe && req.Response {
- goto Exit
+
+ timeout := srv.getReadTimeout()
+
+ limit := srv.MaxTCPQueries
+ if limit == 0 {
+ limit = maxTCPQueries
}
- w.tsigStatus = nil
- if w.tsigSecret != nil {
- if t := req.IsTsig(); t != nil {
- secret := t.Hdr.Name
- if _, ok := w.tsigSecret[secret]; !ok {
- w.tsigStatus = ErrKeyAlg
- }
- w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
- w.tsigTimersOnly = false
- w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
+ for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
+ m, err := reader.ReadTCP(w.tcp, timeout)
+ if err != nil {
+ // TODO(tmthrgd): handle error
+ break
+ }
+ srv.serveDNS(m, w)
+ if w.closed {
+ break // Close() was called
+ }
+ if w.hijacked {
+ break // client will call Close() themselves
}
+ // The first read uses the read timeout, the rest use the
+ // idle timeout.
+ timeout = idleTimeout
}
- h.ServeDNS(w, req) // Writes back to the client
-Exit:
- if w.tcp == nil {
- return
- }
- // TODO(miek): make this number configurable?
- if q > maxTCPQueries { // close socket after this many queries
+ if !w.hijacked {
w.Close()
- return
}
- if w.hijacked {
- return // client calls Close()
+ srv.lock.Lock()
+ delete(srv.conns, w.tcp)
+ srv.lock.Unlock()
+
+ wg.Done()
+}
+
+// Serve a new UDP request.
+func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
+ w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
+ if srv.DecorateWriter != nil {
+ w.writer = srv.DecorateWriter(w)
+ } else {
+ w.writer = w
}
- if u != nil { // UDP, "close" and return
- w.Close()
+
+ srv.serveDNS(m, w)
+ wg.Done()
+}
+
+func (srv *Server) serveDNS(m []byte, w *response) {
+ dh, off, err := unpackMsgHdr(m, 0)
+ if err != nil {
+ // Let client hang, they are sending crap; any reply can be used to amplify.
return
}
- idleTimeout := tcpIdleTimeout
- if srv.IdleTimeout != nil {
- idleTimeout = srv.IdleTimeout()
+
+ req := new(Msg)
+ req.setHdr(dh)
+
+ switch action := srv.MsgAcceptFunc(dh); action {
+ case MsgAccept:
+ if req.unpack(dh, m, off) == nil {
+ break
+ }
+
+ fallthrough
+ case MsgReject, MsgRejectNotImplemented:
+ opcode := req.Opcode
+ req.SetRcodeFormatError(req)
+ req.Zero = false
+ if action == MsgRejectNotImplemented {
+ req.Opcode = opcode
+ req.Rcode = RcodeNotImplemented
+ }
+
+ // Are we allowed to delete any OPT records here?
+ req.Ns, req.Answer, req.Extra = nil, nil, nil
+
+ w.WriteMsg(req)
+ fallthrough
+ case MsgIgnore:
+ if w.udp != nil && cap(m) == srv.UDPSize {
+ srv.udpPool.Put(m[:srv.UDPSize])
+ }
+
+ return
}
- m, err = reader.ReadTCP(w.tcp, idleTimeout)
- if err == nil {
- q++
- goto Redo
+
+ w.tsigStatus = nil
+ if w.tsigSecret != nil {
+ if t := req.IsTsig(); t != nil {
+ if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
+ w.tsigStatus = TsigVerify(m, secret, "", false)
+ } else {
+ w.tsigStatus = ErrSecret
+ }
+ w.tsigTimersOnly = false
+ w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
+ }
}
- w.Close()
- return
+
+ if w.udp != nil && cap(m) == srv.UDPSize {
+ srv.udpPool.Put(m[:srv.UDPSize])
+ }
+
+ srv.Handler.ServeDNS(w, req) // Writes back to the client
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
- conn.SetReadDeadline(time.Now().Add(timeout))
- l := make([]byte, 2)
- n, err := conn.Read(l)
- if err != nil || n != 2 {
- if err != nil {
- return nil, err
- }
- return nil, ErrShortRead
- }
- length := binary.BigEndian.Uint16(l)
- if length == 0 {
- return nil, ErrShortRead
+ // If we race with ShutdownContext, the read deadline may
+ // have been set in the distant past to unblock the read
+ // below. We must not override it, otherwise we may block
+ // ShutdownContext.
+ srv.lock.RLock()
+ if srv.started {
+ conn.SetReadDeadline(time.Now().Add(timeout))
}
- m := make([]byte, int(length))
- n, err = conn.Read(m[:int(length)])
- if err != nil || n == 0 {
- if err != nil {
- return nil, err
- }
- return nil, ErrShortRead
+ srv.lock.RUnlock()
+
+ var length uint16
+ if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
+ return nil, err
}
- i := n
- for i < int(length) {
- j, err := conn.Read(m[i:int(length)])
- if err != nil {
- return nil, err
- }
- i += j
+
+ m := make([]byte, length)
+ if _, err := io.ReadFull(conn, m); err != nil {
+ return nil, err
}
- n = i
- m = m[:n]
+
return m, nil
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
- conn.SetReadDeadline(time.Now().Add(timeout))
- m := make([]byte, srv.UDPSize)
+ srv.lock.RLock()
+ if srv.started {
+ // See the comment in readTCP above.
+ conn.SetReadDeadline(time.Now().Add(timeout))
+ }
+ srv.lock.RUnlock()
+
+ m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m)
if err != nil {
+ srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
return m, s, nil
}
+func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
+ srv.lock.RLock()
+ if srv.started {
+ // See the comment in readTCP above.
+ conn.SetReadDeadline(time.Now().Add(timeout))
+ }
+ srv.lock.RUnlock()
+
+ m := srv.udpPool.Get().([]byte)
+ n, addr, err := conn.ReadFrom(m)
+ if err != nil {
+ srv.udpPool.Put(m)
+ return nil, nil, err
+ }
+ m = m[:n]
+ return m, addr, nil
+}
+
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
+ if w.closed {
+ return &Error{err: "WriteMsg called after Close"}
+ }
+
var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
@@ -665,38 +737,56 @@ func (w *response) WriteMsg(m *Msg) (err error) {
// Write implements the ResponseWriter.Write method.
func (w *response) Write(m []byte) (int, error) {
+ if w.closed {
+ return 0, &Error{err: "Write called after Close"}
+ }
+
switch {
case w.udp != nil:
- n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
- return n, err
- case w.tcp != nil:
- lm := len(m)
- if lm < 2 {
- return 0, io.ErrShortBuffer
+ if u, ok := w.udp.(*net.UDPConn); ok {
+ return WriteToSessionUDP(u, m, w.udpSession)
}
- if lm > MaxMsgSize {
+ return w.udp.WriteTo(m, w.pcSession)
+ case w.tcp != nil:
+ if len(m) > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
- l := make([]byte, 2, 2+lm)
- binary.BigEndian.PutUint16(l, uint16(lm))
- m = append(l, m...)
- n, err := io.Copy(w.tcp, bytes.NewReader(m))
+ l := make([]byte, 2)
+ binary.BigEndian.PutUint16(l, uint16(len(m)))
+
+ n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
return int(n), err
+ default:
+ panic("dns: internal error: udp and tcp both nil")
}
- panic("not reached")
}
// LocalAddr implements the ResponseWriter.LocalAddr method.
func (w *response) LocalAddr() net.Addr {
- if w.tcp != nil {
+ switch {
+ case w.udp != nil:
+ return w.udp.LocalAddr()
+ case w.tcp != nil:
return w.tcp.LocalAddr()
+ default:
+ panic("dns: internal error: udp and tcp both nil")
}
- return w.udp.LocalAddr()
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
-func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
+func (w *response) RemoteAddr() net.Addr {
+ switch {
+ case w.udpSession != nil:
+ return w.udpSession.RemoteAddr()
+ case w.pcSession != nil:
+ return w.pcSession
+ case w.tcp != nil:
+ return w.tcp.RemoteAddr()
+ default:
+ panic("dns: internal error: udpSession, pcSession and tcp are all nil")
+ }
+}
// TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus }
@@ -709,11 +799,30 @@ func (w *response) Hijack() { w.hijacked = true }
// Close implements the ResponseWriter.Close method
func (w *response) Close() error {
- // Can't close the udp conn, as that is actually the listener.
- if w.tcp != nil {
- e := w.tcp.Close()
- w.tcp = nil
- return e
+ if w.closed {
+ return &Error{err: "connection already closed"}
+ }
+ w.closed = true
+
+ switch {
+ case w.udp != nil:
+ // Can't close the udp conn, as that is actually the listener.
+ return nil
+ case w.tcp != nil:
+ return w.tcp.Close()
+ default:
+ panic("dns: internal error: udp and tcp both nil")
+ }
+}
+
+// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
+func (w *response) ConnectionState() *tls.ConnectionState {
+ type tlsConnectionStater interface {
+ ConnectionState() tls.ConnectionState
+ }
+ if v, ok := w.tcp.(tlsConnectionStater); ok {
+ t := v.ConnectionState()
+ return &t
}
return nil
}