aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/msg.go
diff options
context:
space:
mode:
authorTyler Davis <tydavis@gmail.com>2021-02-15 20:47:30 +0000
committerTyler Davis <tydavis@gmail.com>2021-02-15 20:47:30 +0000
commita687ebabb6589ebb36a9c385f583a19ac462b831 (patch)
tree4112f2272dfe6df7f106819c1381ab59d7ea5d2f /vendor/github.com/miekg/dns/msg.go
parentf22b6da3c7964a23d93269b6c5de9f322c3837a8 (diff)
downloaddnstracker-a687ebabb6589ebb36a9c385f583a19ac462b831.tar.gz
dnstracker-a687ebabb6589ebb36a9c385f583a19ac462b831.zip
Update go modules for 1.15
Diffstat (limited to 'vendor/github.com/miekg/dns/msg.go')
-rw-r--r--vendor/github.com/miekg/dns/msg.go835
1 files changed, 439 insertions, 396 deletions
diff --git a/vendor/github.com/miekg/dns/msg.go b/vendor/github.com/miekg/dns/msg.go
index 975dde7..1728a98 100644
--- a/vendor/github.com/miekg/dns/msg.go
+++ b/vendor/github.com/miekg/dns/msg.go
@@ -9,21 +9,41 @@
package dns
//go:generate go run msg_generate.go
-//go:generate go run compress_generate.go
import (
- crand "crypto/rand"
+ "crypto/rand"
"encoding/binary"
"fmt"
"math/big"
- "math/rand"
"strconv"
- "sync"
+ "strings"
)
const (
maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer
maxDomainNameWireOctets = 255 // See RFC 1035 section 2.3.4
+
+ // This is the maximum number of compression pointers that should occur in a
+ // semantically valid message. Each label in a domain name must be at least one
+ // octet and is separated by a period. The root label won't be represented by a
+ // compression pointer to a compression pointer, hence the -2 to exclude the
+ // smallest valid root label.
+ //
+ // It is possible to construct a valid message that has more compression pointers
+ // than this, and still doesn't loop, by pointing to a previous pointer. This is
+ // not something a well written implementation should ever do, so we leave them
+ // to trip the maximum compression pointer check.
+ maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2
+
+ // This is the maximum length of a domain name in presentation format. The
+ // maximum wire length of a domain name is 255 octets (see above), with the
+ // maximum label length being 63. The wire format requires one extra byte over
+ // the presentation format, reducing the number of octets by 1. Each label in
+ // the name will be separated by a single period, with each octet in the label
+ // expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
+ // length, then the final label can only be 61 octets long to not exceed the
+ // maximum allowed wire length.
+ maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)
// Errors defined in this package.
@@ -46,59 +66,28 @@ var (
ErrRRset error = &Error{err: "bad rrset"}
ErrSecret error = &Error{err: "no secrets defined"}
ErrShortRead error = &Error{err: "short read"}
- ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
- ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
- ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
- ErrTruncated error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
+ ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
+ ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
+ ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
)
-// Id by default, returns a 16 bits random number to be used as a
-// message id. The random provided should be good enough. This being a
-// variable the function can be reassigned to a custom function.
-// For instance, to make it return a static value:
+// Id by default returns a 16-bit random number to be used as a message id. The
+// number is drawn from a cryptographically secure random number generator.
+// This being a variable the function can be reassigned to a custom function.
+// For instance, to make it return a static value for testing:
//
// dns.Id = func() uint16 { return 3 }
var Id = id
-var (
- idLock sync.Mutex
- idRand *rand.Rand
-)
-
// id returns a 16 bits random number to be used as a
// message id. The random provided should be good enough.
func id() uint16 {
- idLock.Lock()
-
- if idRand == nil {
- // This (partially) works around
- // https://github.com/golang/go/issues/11833 by only
- // seeding idRand upon the first call to id.
-
- var seed int64
- var buf [8]byte
-
- if _, err := crand.Read(buf[:]); err == nil {
- seed = int64(binary.LittleEndian.Uint64(buf[:]))
- } else {
- seed = rand.Int63()
- }
-
- idRand = rand.New(rand.NewSource(seed))
+ var output uint16
+ err := binary.Read(rand.Reader, binary.BigEndian, &output)
+ if err != nil {
+ panic("dns: reading random id failed: " + err.Error())
}
-
- // The call to idRand.Uint32 must be within the
- // mutex lock because *rand.Rand is not safe for
- // concurrent use.
- //
- // There is no added performance overhead to calling
- // idRand.Uint32 inside a mutex lock over just
- // calling rand.Uint32 as the global math/rand rng
- // is internally protected by a sync.Mutex.
- id := uint16(idRand.Uint32())
-
- idLock.Unlock()
- return id
+ return output
}
// MsgHdr is a a manually-unpacked version of (id, bits).
@@ -151,7 +140,7 @@ var RcodeToString = map[int]string{
RcodeFormatError: "FORMERR",
RcodeServerFailure: "SERVFAIL",
RcodeNameError: "NXDOMAIN",
- RcodeNotImplemented: "NOTIMPL",
+ RcodeNotImplemented: "NOTIMP",
RcodeRefused: "REFUSED",
RcodeYXDomain: "YXDOMAIN", // See RFC 2136
RcodeYXRrset: "YXRRSET",
@@ -169,6 +158,39 @@ var RcodeToString = map[int]string{
RcodeBadCookie: "BADCOOKIE",
}
+// compressionMap is used to allow a more efficient compression map
+// to be used for internal packDomainName calls without changing the
+// signature or functionality of public API.
+//
+// In particular, map[string]uint16 uses 25% less per-entry memory
+// than does map[string]int.
+type compressionMap struct {
+ ext map[string]int // external callers
+ int map[string]uint16 // internal callers
+}
+
+func (m compressionMap) valid() bool {
+ return m.int != nil || m.ext != nil
+}
+
+func (m compressionMap) insert(s string, pos int) {
+ if m.ext != nil {
+ m.ext[s] = pos
+ } else {
+ m.int[s] = uint16(pos)
+ }
+}
+
+func (m compressionMap) find(s string) (int, bool) {
+ if m.ext != nil {
+ pos, ok := m.ext[s]
+ return pos, ok
+ }
+
+ pos, ok := m.int[s]
+ return int(pos), ok
+}
+
// Domain names are a sequence of counted strings
// split at the dots. They end with a zero-length string.
@@ -177,143 +199,156 @@ var RcodeToString = map[int]string{
// map needs to hold a mapping between domain names and offsets
// pointing into msg.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
- off1, _, err = packDomainName(s, msg, off, compression, compress)
- return
+ return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
}
-func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) {
- // special case if msg == nil
- lenmsg := 256
- if msg != nil {
- lenmsg = len(msg)
- }
+func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
+ // XXX: A logical copy of this function exists in IsDomainName and
+ // should be kept in sync with this function.
+
ls := len(s)
if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
- return off, 0, nil
- }
- // If not fully qualified, error out, but only if msg == nil #ugly
- switch {
- case msg == nil:
- if s[ls-1] != '.' {
- s += "."
- ls++
- }
- case msg != nil:
- if s[ls-1] != '.' {
- return lenmsg, 0, ErrFqdn
- }
+ return off, nil
}
+
+ // If not fully qualified, error out.
+ if !IsFqdn(s) {
+ return len(msg), ErrFqdn
+ }
+
// Each dot ends a segment of the name.
// We trade each dot byte for a length byte.
// Except for escaped dots (\.), which are normal dots.
// There is also a trailing zero.
// Compression
- nameoffset := -1
pointer := -1
+
// Emit sequence of counted strings, chopping at dots.
- begin := 0
- bs := []byte(s)
- roBs, bsFresh, escapedDot := s, true, false
+ var (
+ begin int
+ compBegin int
+ compOff int
+ bs []byte
+ wasDot bool
+ )
+loop:
for i := 0; i < ls; i++ {
- if bs[i] == '\\' {
- for j := i; j < ls-1; j++ {
- bs[j] = bs[j+1]
+ var c byte
+ if bs == nil {
+ c = s[i]
+ } else {
+ c = bs[i]
+ }
+
+ switch c {
+ case '\\':
+ if off+1 > len(msg) {
+ return len(msg), ErrBuf
}
- ls--
- if off+1 > lenmsg {
- return lenmsg, labels, ErrBuf
+
+ if bs == nil {
+ bs = []byte(s)
}
+
// check for \DDD
- if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
- bs[i] = dddToByte(bs[i:])
- for j := i + 1; j < ls-2; j++ {
- bs[j] = bs[j+2]
- }
- ls -= 2
+ if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
+ bs[i] = dddToByte(bs[i+1:])
+ copy(bs[i+1:ls-3], bs[i+4:])
+ ls -= 3
+ compOff += 3
+ } else {
+ copy(bs[i:ls-1], bs[i+1:])
+ ls--
+ compOff++
}
- escapedDot = bs[i] == '.'
- bsFresh = false
- continue
- }
- if bs[i] == '.' {
- if i > 0 && bs[i-1] == '.' && !escapedDot {
+ wasDot = false
+ case '.':
+ if wasDot {
// two dots back to back is not legal
- return lenmsg, labels, ErrRdata
+ return len(msg), ErrRdata
}
- if i-begin >= 1<<6 { // top two bits of length must be clear
- return lenmsg, labels, ErrRdata
+ wasDot = true
+
+ labelLen := i - begin
+ if labelLen >= 1<<6 { // top two bits of length must be clear
+ return len(msg), ErrRdata
}
+
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
- if off+1 > lenmsg {
- return lenmsg, labels, ErrBuf
- }
- if msg != nil {
- msg[off] = byte(i - begin)
- }
- offset := off
- off++
- for j := begin; j < i; j++ {
- if off+1 > lenmsg {
- return lenmsg, labels, ErrBuf
- }
- if msg != nil {
- msg[off] = bs[j]
- }
- off++
- }
- if compress && !bsFresh {
- roBs = string(bs)
- bsFresh = true
+ if off+1+labelLen > len(msg) {
+ return len(msg), ErrBuf
}
+
// Don't try to compress '.'
- // We should only compress when compress it true, but we should also still pick
+ // We should only compress when compress is true, but we should also still pick
// up names that can be used for *future* compression(s).
- if compression != nil && roBs[begin:] != "." {
- if p, ok := compression[roBs[begin:]]; !ok {
- // Only offsets smaller than this can be used.
- if offset < maxCompressionOffset {
- compression[roBs[begin:]] = offset
- }
- } else {
+ if compression.valid() && !isRootLabel(s, bs, begin, ls) {
+ if p, ok := compression.find(s[compBegin:]); ok {
// The first hit is the longest matching dname
// keep the pointer offset we get back and store
// the offset of the current name, because that's
// where we need to insert the pointer later
// If compress is true, we're allowed to compress this dname
- if pointer == -1 && compress {
- pointer = p // Where to point to
- nameoffset = offset // Where to point from
- break
+ if compress {
+ pointer = p // Where to point to
+ break loop
}
+ } else if off < maxCompressionOffset {
+ // Only offsets smaller than maxCompressionOffset can be used.
+ compression.insert(s[compBegin:], off)
}
}
- labels++
+
+ // The following is covered by the length check above.
+ msg[off] = byte(labelLen)
+
+ if bs == nil {
+ copy(msg[off+1:], s[begin:i])
+ } else {
+ copy(msg[off+1:], bs[begin:i])
+ }
+ off += 1 + labelLen
+
begin = i + 1
+ compBegin = begin + compOff
+ default:
+ wasDot = false
}
- escapedDot = false
}
+
// Root label is special
- if len(bs) == 1 && bs[0] == '.' {
- return off, labels, nil
+ if isRootLabel(s, bs, 0, ls) {
+ return off, nil
}
+
// If we did compression and we find something add the pointer here
if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in
- // if msg == nil, we will never do compression
- binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000))
- off = nameoffset + 1
- goto End
+ binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
+ return off + 2, nil
}
- if msg != nil && off < len(msg) {
+
+ if off < len(msg) {
msg[off] = 0
}
-End:
- off++
- return off, labels, nil
+
+ return off + 1, nil
+}
+
+// isRootLabel returns whether s or bs, from off to end, is the root
+// label ".".
+//
+// If bs is nil, s will be checked, otherwise bs will be checked.
+func isRootLabel(s string, bs []byte, off, end int) bool {
+ if bs == nil {
+ return s[off:end] == "."
+ }
+
+ return end-off == 1 && bs[off] == '.'
}
// Unpack a domain name.
@@ -330,12 +365,16 @@ End:
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while.
-// UnpackDomainName unpacks a domain name into a string.
+// UnpackDomainName unpacks a domain name into a string. It returns
+// the name, the new offset into msg and any error that occurred.
+//
+// When an error is encountered, the unpacked name will be discarded
+// and len(msg) will be returned as the offset.
func UnpackDomainName(msg []byte, off int) (string, int, error) {
- s := make([]byte, 0, 64)
+ s := make([]byte, 0, maxDomainNamePresentationLength)
off1 := 0
lenmsg := len(msg)
- maxLen := maxDomainNameWireOctets
+ budget := maxDomainNameWireOctets
ptr := 0 // number of pointers followed
Loop:
for {
@@ -354,30 +393,17 @@ Loop:
if off+c > lenmsg {
return "", lenmsg, ErrBuf
}
- for j := off; j < off+c; j++ {
- switch b := msg[j]; b {
- case '.', '(', ')', ';', ' ', '@':
- fallthrough
- case '"', '\\':
+ budget -= c + 1 // +1 for the label separator
+ if budget <= 0 {
+ return "", lenmsg, ErrLongDomain
+ }
+ for _, b := range msg[off : off+c] {
+ if isDomainNameLabelSpecial(b) {
s = append(s, '\\', b)
- // presentation-format \X escapes add an extra byte
- maxLen++
- default:
- if b < 32 || b >= 127 { // unprintable, use \DDD
- var buf [3]byte
- bufs := strconv.AppendInt(buf[:0], int64(b), 10)
- s = append(s, '\\')
- for i := 0; i < 3-len(bufs); i++ {
- s = append(s, '0')
- }
- for _, r := range bufs {
- s = append(s, r)
- }
- // presentation-format \DDD escapes add 3 extra bytes
- maxLen += 3
- } else {
- s = append(s, b)
- }
+ } else if b < ' ' || b > '~' {
+ s = append(s, escapeByte(b)...)
+ } else {
+ s = append(s, b)
}
}
s = append(s, '.')
@@ -396,7 +422,7 @@ Loop:
if ptr == 0 {
off1 = off
}
- if ptr++; ptr > 10 {
+ if ptr++; ptr > maxCompressionPointers {
return "", lenmsg, &Error{err: "too many compression pointers"}
}
// pointer should guarantee that it advances and points forwards at least
@@ -412,10 +438,7 @@ Loop:
off1 = off
}
if len(s) == 0 {
- s = []byte(".")
- } else if len(s) >= maxLen {
- // error if the name is too long, but don't throw it away
- return string(s), lenmsg, ErrLongDomain
+ return ".", off1, nil
}
return string(s), off1, nil
}
@@ -429,11 +452,11 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
return offset, nil
}
var err error
- for i := range txt {
- if len(txt[i]) > len(tmp) {
+ for _, s := range txt {
+ if len(s) > len(tmp) {
return offset, ErrBuf
}
- offset, err = packTxtString(txt[i], msg, offset, tmp)
+ offset, err = packTxtString(s, msg, offset, tmp)
if err != nil {
return offset, err
}
@@ -512,7 +535,7 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
off = off0
var s string
for off < len(msg) && err == nil {
- s, off, err = unpackTxtString(msg, off)
+ s, off, err = unpackString(msg, off)
if err == nil {
ss = append(ss, s)
}
@@ -520,43 +543,16 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
return
}
-func unpackTxtString(msg []byte, offset int) (string, int, error) {
- if offset+1 > len(msg) {
- return "", offset, &Error{err: "overflow unpacking txt"}
- }
- l := int(msg[offset])
- if offset+l+1 > len(msg) {
- return "", offset, &Error{err: "overflow unpacking txt"}
- }
- s := make([]byte, 0, l)
- for _, b := range msg[offset+1 : offset+1+l] {
- switch b {
- case '"', '\\':
- s = append(s, '\\', b)
- default:
- if b < 32 || b > 127 { // unprintable
- var buf [3]byte
- bufs := strconv.AppendInt(buf[:0], int64(b), 10)
- s = append(s, '\\')
- for i := 0; i < 3-len(bufs); i++ {
- s = append(s, '0')
- }
- for _, r := range bufs {
- s = append(s, r)
- }
- } else {
- s = append(s, b)
- }
- }
- }
- offset += 1 + l
- return string(s), offset, nil
-}
-
// Helpers for dealing with escaped bytes
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
func dddToByte(s []byte) byte {
+ _ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
+ return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
+}
+
+func dddStringToByte(s string) byte {
+ _ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
@@ -574,19 +570,38 @@ func intToBytes(i *big.Int, length int) []byte {
// PackRR packs a resource record rr into msg[off:].
// See PackDomainName for documentation about the compression.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
+ headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
+ if err == nil {
+ // packRR no longer sets the Rdlength field on the rr, but
+ // callers might be expecting it so we set it here.
+ rr.Header().Rdlength = uint16(off1 - headerEnd)
+ }
+ return off1, err
+}
+
+func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
if rr == nil {
- return len(msg), &Error{err: "nil rr"}
+ return len(msg), len(msg), &Error{err: "nil rr"}
+ }
+
+ headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
+ if err != nil {
+ return headerEnd, len(msg), err
}
- off1, err = rr.pack(msg, off, compression, compress)
+ off1, err = rr.pack(msg, headerEnd, compression, compress)
if err != nil {
- return len(msg), err
+ return headerEnd, len(msg), err
}
- // TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
- if rawSetRdlength(msg, off, off1) {
- return off1, nil
+
+ rdlength := off1 - headerEnd
+ if int(uint16(rdlength)) != rdlength { // overflow
+ return headerEnd, len(msg), ErrRdata
}
- return off, ErrRdata
+
+ // The RDLENGTH field is the last field in the header and we set it here.
+ binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
+ return headerEnd, off1, nil
}
// UnpackRR unpacks msg[off:] into an RR.
@@ -595,17 +610,42 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
if err != nil {
return nil, len(msg), err
}
- end := off + int(h.Rdlength)
- if fn, known := typeToUnpack[h.Rrtype]; !known {
- rr, off, err = unpackRFC3597(h, msg, off)
+ return UnpackRRWithHeader(h, msg, off)
+}
+
+// UnpackRRWithHeader unpacks the record type specific payload given an existing
+// RR_Header.
+func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
+ if newFn, ok := TypeToRR[h.Rrtype]; ok {
+ rr = newFn()
+ *rr.Header() = h
} else {
- rr, off, err = fn(h, msg, off)
+ rr = &RFC3597{Hdr: h}
+ }
+
+ if off < 0 || off > len(msg) {
+ return &h, off, &Error{err: "bad off"}
+ }
+
+ end := off + int(h.Rdlength)
+ if end < off || end > len(msg) {
+ return &h, end, &Error{err: "bad rdlength"}
+ }
+
+ if noRdata(h) {
+ return rr, off, nil
+ }
+
+ off, err = rr.unpack(msg, off)
+ if err != nil {
+ return nil, end, err
}
if off != end {
return &h, end, &Error{err: "bad rdlength"}
}
- return rr, off, err
+
+ return rr, off, nil
}
// unpackRRslice unpacks msg[off:] into an []RR.
@@ -623,7 +663,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error)
}
// If offset does not increase anymore, l is a lie
if off1 == off {
- l = i
break
}
dst = append(dst, r)
@@ -684,35 +723,37 @@ func (dns *Msg) Pack() (msg []byte, err error) {
return dns.PackBuffer(nil)
}
-// PackBuffer packs a Msg, using the given buffer buf. If buf is too small
-// a new buffer is allocated.
+// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
- // We use a similar function in tsig.go's stripTsig.
- var (
- dh Header
- compression map[string]int
- )
-
- if dns.Compress {
- compression = make(map[string]int) // Compression pointer mappings
+ // If this message can't be compressed, avoid filling the
+ // compression map and creating garbage.
+ if dns.Compress && dns.isCompressible() {
+ compression := make(map[string]uint16) // Compression pointer mappings.
+ return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
}
+ return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
+}
+
+// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
+func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
if dns.Rcode < 0 || dns.Rcode > 0xFFF {
return nil, ErrRcode
}
- if dns.Rcode > 0xF {
- // Regular RCODE field is 4 bits
- opt := dns.IsEdns0()
- if opt == nil {
- return nil, ErrExtendedRcode
- }
- opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
- dns.Rcode &= 0xF
+
+ // Set extended rcode unconditionally if we have an opt, this will allow
+ // reseting the extended rcode bits if they need to.
+ if opt := dns.IsEdns0(); opt != nil {
+ opt.SetExtendedRcode(uint16(dns.Rcode))
+ } else if dns.Rcode > 0xF {
+ // If Rcode is an extended one and opt is nil, error out.
+ return nil, ErrExtendedRcode
}
// Convert convenient Msg into wire-like Header.
+ var dh Header
dh.Id = dns.Id
- dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
+ dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
if dns.Response {
dh.Bits |= _QR
}
@@ -738,50 +779,44 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
dh.Bits |= _CD
}
- // Prepare variable sized arrays.
- question := dns.Question
- answer := dns.Answer
- ns := dns.Ns
- extra := dns.Extra
-
- dh.Qdcount = uint16(len(question))
- dh.Ancount = uint16(len(answer))
- dh.Nscount = uint16(len(ns))
- dh.Arcount = uint16(len(extra))
+ dh.Qdcount = uint16(len(dns.Question))
+ dh.Ancount = uint16(len(dns.Answer))
+ dh.Nscount = uint16(len(dns.Ns))
+ dh.Arcount = uint16(len(dns.Extra))
// We need the uncompressed length here, because we first pack it and then compress it.
msg = buf
- uncompressedLen := compressedLen(dns, false)
+ uncompressedLen := msgLenWithCompressionMap(dns, nil)
if packLen := uncompressedLen + 1; len(msg) < packLen {
msg = make([]byte, packLen)
}
// Pack it in: header and then the pieces.
off := 0
- off, err = dh.pack(msg, off, compression, dns.Compress)
+ off, err = dh.pack(msg, off, compression, compress)
if err != nil {
return nil, err
}
- for i := 0; i < len(question); i++ {
- off, err = question[i].pack(msg, off, compression, dns.Compress)
+ for _, r := range dns.Question {
+ off, err = r.pack(msg, off, compression, compress)
if err != nil {
return nil, err
}
}
- for i := 0; i < len(answer); i++ {
- off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
+ for _, r := range dns.Answer {
+ _, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
- for i := 0; i < len(ns); i++ {
- off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
+ for _, r := range dns.Ns {
+ _, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
- for i := 0; i < len(extra); i++ {
- off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
+ for _, r := range dns.Extra {
+ _, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
@@ -789,28 +824,7 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
return msg[:off], nil
}
-// Unpack unpacks a binary message to a Msg structure.
-func (dns *Msg) Unpack(msg []byte) (err error) {
- var (
- dh Header
- off int
- )
- if dh, off, err = unpackMsgHdr(msg, off); err != nil {
- return err
- }
-
- dns.Id = dh.Id
- dns.Response = (dh.Bits & _QR) != 0
- dns.Opcode = int(dh.Bits>>11) & 0xF
- dns.Authoritative = (dh.Bits & _AA) != 0
- dns.Truncated = (dh.Bits & _TC) != 0
- dns.RecursionDesired = (dh.Bits & _RD) != 0
- dns.RecursionAvailable = (dh.Bits & _RA) != 0
- dns.Zero = (dh.Bits & _Z) != 0
- dns.AuthenticatedData = (dh.Bits & _AD) != 0
- dns.CheckingDisabled = (dh.Bits & _CD) != 0
- dns.Rcode = int(dh.Bits & 0xF)
-
+func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
// If we are at the end of the message we should return *just* the
// header. This can still be useful to the caller. 9.9.9.9 sends these
// when responding with REFUSED for instance.
@@ -829,8 +843,6 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
var q Question
q, off, err = unpackQuestion(msg, off)
if err != nil {
- // Even if Truncated is set, we only will set ErrTruncated if we
- // actually got the questions
return err
}
if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
@@ -854,16 +866,29 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
// The header counts might have been wrong so we need to update it
dh.Arcount = uint16(len(dns.Extra))
+ // Set extended Rcode
+ if opt := dns.IsEdns0(); opt != nil {
+ dns.Rcode |= opt.ExtendedRcode()
+ }
+
if off != len(msg) {
// TODO(miek) make this an error?
// use PackOpt to let people tell how detailed the error reporting should be?
// println("dns: extra bytes in dns packet", off, "<", len(msg))
- } else if dns.Truncated {
- // Whether we ran into a an error or not, we want to return that it
- // was truncated
- err = ErrTruncated
}
return err
+
+}
+
+// Unpack unpacks a binary message to a Msg structure.
+func (dns *Msg) Unpack(msg []byte) (err error) {
+ dh, off, err := unpackMsgHdr(msg, 0)
+ if err != nil {
+ return err
+ }
+
+ dns.setHdr(dh)
+ return dns.unpack(dh, msg, off)
}
// Convert a complete message to a string with dig-like output.
@@ -878,138 +903,148 @@ func (dns *Msg) String() string {
s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
if len(dns.Question) > 0 {
s += "\n;; QUESTION SECTION:\n"
- for i := 0; i < len(dns.Question); i++ {
- s += dns.Question[i].String() + "\n"
+ for _, r := range dns.Question {
+ s += r.String() + "\n"
}
}
if len(dns.Answer) > 0 {
s += "\n;; ANSWER SECTION:\n"
- for i := 0; i < len(dns.Answer); i++ {
- if dns.Answer[i] != nil {
- s += dns.Answer[i].String() + "\n"
+ for _, r := range dns.Answer {
+ if r != nil {
+ s += r.String() + "\n"
}
}
}
if len(dns.Ns) > 0 {
s += "\n;; AUTHORITY SECTION:\n"
- for i := 0; i < len(dns.Ns); i++ {
- if dns.Ns[i] != nil {
- s += dns.Ns[i].String() + "\n"
+ for _, r := range dns.Ns {
+ if r != nil {
+ s += r.String() + "\n"
}
}
}
if len(dns.Extra) > 0 {
s += "\n;; ADDITIONAL SECTION:\n"
- for i := 0; i < len(dns.Extra); i++ {
- if dns.Extra[i] != nil {
- s += dns.Extra[i].String() + "\n"
+ for _, r := range dns.Extra {
+ if r != nil {
+ s += r.String() + "\n"
}
}
}
return s
}
+// isCompressible returns whether the msg may be compressible.
+func (dns *Msg) isCompressible() bool {
+ // If we only have one question, there is nothing we can ever compress.
+ return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
+ len(dns.Ns) > 0 || len(dns.Extra) > 0
+}
+
// Len returns the message length when in (un)compressed wire format.
// If dns.Compress is true compression it is taken into account. Len()
// is provided to be a faster way to get the size of the resulting packet,
// than packing it, measuring the size and discarding the buffer.
-func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) }
-
-// compressedLen returns the message length when in compressed wire format
-// when compress is true, otherwise the uncompressed length is returned.
-func compressedLen(dns *Msg, compress bool) int {
- // We always return one more than needed.
- l := 12 // Message header is always 12 bytes
- if compress {
- compression := map[string]int{}
- for _, r := range dns.Question {
- l += r.len()
- compressionLenHelper(compression, r.Name)
- }
- l += compressionLenSlice(compression, dns.Answer)
- l += compressionLenSlice(compression, dns.Ns)
- l += compressionLenSlice(compression, dns.Extra)
- } else {
- for _, r := range dns.Question {
- l += r.len()
- }
- for _, r := range dns.Answer {
- if r != nil {
- l += r.len()
- }
- }
- for _, r := range dns.Ns {
- if r != nil {
- l += r.len()
- }
- }
- for _, r := range dns.Extra {
- if r != nil {
- l += r.len()
- }
- }
+func (dns *Msg) Len() int {
+ // If this message can't be compressed, avoid filling the
+ // compression map and creating garbage.
+ if dns.Compress && dns.isCompressible() {
+ compression := make(map[string]struct{})
+ return msgLenWithCompressionMap(dns, compression)
}
- return l
+
+ return msgLenWithCompressionMap(dns, nil)
}
-func compressionLenSlice(c map[string]int, rs []RR) int {
- var l int
- for _, r := range rs {
- if r == nil {
- continue
+func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
+ l := headerSize
+
+ for _, r := range dns.Question {
+ l += r.len(l, compression)
+ }
+ for _, r := range dns.Answer {
+ if r != nil {
+ l += r.len(l, compression)
}
- l += r.len()
- k, ok := compressionLenSearch(c, r.Header().Name)
- if ok {
- l += 1 - k
+ }
+ for _, r := range dns.Ns {
+ if r != nil {
+ l += r.len(l, compression)
}
- compressionLenHelper(c, r.Header().Name)
- k, ok = compressionLenSearchType(c, r)
- if ok {
- l += 1 - k
+ }
+ for _, r := range dns.Extra {
+ if r != nil {
+ l += r.len(l, compression)
}
- compressionLenHelperType(c, r)
}
+
return l
}
-// Put the parts of the name in the compression map.
-func compressionLenHelper(c map[string]int, s string) {
- pref := ""
- lbs := Split(s)
- for j := len(lbs) - 1; j >= 0; j-- {
- pref = s[lbs[j]:]
- if _, ok := c[pref]; !ok {
- c[pref] = len(pref)
+func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
+ if s == "" || s == "." {
+ return 1
+ }
+
+ escaped := strings.Contains(s, "\\")
+
+ if compression != nil && (compress || off < maxCompressionOffset) {
+ // compressionLenSearch will insert the entry into the compression
+ // map if it doesn't contain it.
+ if l, ok := compressionLenSearch(compression, s, off); ok && compress {
+ if escaped {
+ return escapedNameLen(s[:l]) + 2
+ }
+
+ return l + 2
}
}
+
+ if escaped {
+ return escapedNameLen(s) + 1
+ }
+
+ return len(s) + 1
}
-// Look for each part in the compression map and returns its length,
-// keep on searching so we get the longest match.
-func compressionLenSearch(c map[string]int, s string) (int, bool) {
- off := 0
- end := false
- if s == "" { // don't bork on bogus data
- return 0, false
+func escapedNameLen(s string) int {
+ nameLen := len(s)
+ for i := 0; i < len(s); i++ {
+ if s[i] != '\\' {
+ continue
+ }
+
+ if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
+ nameLen -= 3
+ i += 3
+ } else {
+ nameLen--
+ i++
+ }
}
- for {
+
+ return nameLen
+}
+
+func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
+ for off, end := 0, false; !end; off, end = NextLabel(s, off) {
if _, ok := c[s[off:]]; ok {
- return len(s[off:]), true
+ return off, true
}
- if end {
- break
+
+ if msgOff+off < maxCompressionOffset {
+ c[s[off:]] = struct{}{}
}
- off, end = NextLabel(s, off)
}
+
return 0, false
}
// Copy returns a new RR which is a deep-copy of r.
-func Copy(r RR) RR { r1 := r.copy(); return r1 }
+func Copy(r RR) RR { return r.copy() }
// Len returns the length (in octets) of the uncompressed RR in wire format.
-func Len(r RR) int { return r.len() }
+func Len(r RR) int { return r.len(0, nil) }
// Copy returns a new *Msg which is a deep-copy of dns.
func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
@@ -1025,40 +1060,27 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
}
rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
- var rri int
+ r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
+ r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
+ r1.Extra = rrArr[:0:len(dns.Extra)]
- if len(dns.Answer) > 0 {
- rrbegin := rri
- for i := 0; i < len(dns.Answer); i++ {
- rrArr[rri] = dns.Answer[i].copy()
- rri++
- }
- r1.Answer = rrArr[rrbegin:rri:rri]
+ for _, r := range dns.Answer {
+ r1.Answer = append(r1.Answer, r.copy())
}
- if len(dns.Ns) > 0 {
- rrbegin := rri
- for i := 0; i < len(dns.Ns); i++ {
- rrArr[rri] = dns.Ns[i].copy()
- rri++
- }
- r1.Ns = rrArr[rrbegin:rri:rri]
+ for _, r := range dns.Ns {
+ r1.Ns = append(r1.Ns, r.copy())
}
- if len(dns.Extra) > 0 {
- rrbegin := rri
- for i := 0; i < len(dns.Extra); i++ {
- rrArr[rri] = dns.Extra[i].copy()
- rri++
- }
- r1.Extra = rrArr[rrbegin:rri:rri]
+ for _, r := range dns.Extra {
+ r1.Extra = append(r1.Extra, r.copy())
}
return r1
}
-func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
- off, err := PackDomainName(q.Name, msg, off, compression, compress)
+func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
+ off, err := packDomainName(q.Name, msg, off, compression, compress)
if err != nil {
return off, err
}
@@ -1099,7 +1121,7 @@ func unpackQuestion(msg []byte, off int) (Question, int, error) {
return q, off, err
}
-func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
+func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := packUint16(dh.Id, msg, off)
if err != nil {
return off, err
@@ -1121,7 +1143,10 @@ func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress
return off, err
}
off, err = packUint16(dh.Arcount, msg, off)
- return off, err
+ if err != nil {
+ return off, err
+ }
+ return off, nil
}
func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
@@ -1150,5 +1175,23 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
return dh, off, err
}
dh.Arcount, off, err = unpackUint16(msg, off)
- return dh, off, err
+ if err != nil {
+ return dh, off, err
+ }
+ return dh, off, nil
+}
+
+// setHdr set the header in the dns using the binary data in dh.
+func (dns *Msg) setHdr(dh Header) {
+ dns.Id = dh.Id
+ dns.Response = dh.Bits&_QR != 0
+ dns.Opcode = int(dh.Bits>>11) & 0xF
+ dns.Authoritative = dh.Bits&_AA != 0
+ dns.Truncated = dh.Bits&_TC != 0
+ dns.RecursionDesired = dh.Bits&_RD != 0
+ dns.RecursionAvailable = dh.Bits&_RA != 0
+ dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
+ dns.AuthenticatedData = dh.Bits&_AD != 0
+ dns.CheckingDisabled = dh.Bits&_CD != 0
+ dns.Rcode = int(dh.Bits & 0xF)
}