aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/xfr.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/xfr.go')
-rw-r--r--vendor/github.com/miekg/dns/xfr.go64
1 files changed, 35 insertions, 29 deletions
diff --git a/vendor/github.com/miekg/dns/xfr.go b/vendor/github.com/miekg/dns/xfr.go
index 5d0ff5c..43970e6 100644
--- a/vendor/github.com/miekg/dns/xfr.go
+++ b/vendor/github.com/miekg/dns/xfr.go
@@ -35,30 +35,36 @@ type Transfer struct {
// channel, err := transfer.In(message, master)
//
func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
+ switch q.Question[0].Qtype {
+ case TypeAXFR, TypeIXFR:
+ default:
+ return nil, &Error{"unsupported question type"}
+ }
+
timeout := dnsTimeout
if t.DialTimeout != 0 {
timeout = t.DialTimeout
}
+
if t.Conn == nil {
t.Conn, err = DialTimeout("tcp", a, timeout)
if err != nil {
return nil, err
}
}
+
if err := t.WriteMsg(q); err != nil {
return nil, err
}
+
env = make(chan *Envelope)
- go func() {
- if q.Question[0].Qtype == TypeAXFR {
- go t.inAxfr(q, env)
- return
- }
- if q.Question[0].Qtype == TypeIXFR {
- go t.inIxfr(q, env)
- return
- }
- }()
+ switch q.Question[0].Qtype {
+ case TypeAXFR:
+ go t.inAxfr(q, env)
+ case TypeIXFR:
+ go t.inIxfr(q, env)
+ }
+
return env, nil
}
@@ -111,7 +117,7 @@ func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
}
func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
- serial := uint32(0) // The first serial seen is the current server serial
+ var serial uint32 // The first serial seen is the current server serial
axfr := true
n := 0
qser := q.Ns[0].(*SOA).Serial
@@ -176,14 +182,17 @@ func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
//
// ch := make(chan *dns.Envelope)
// tr := new(dns.Transfer)
-// go tr.Out(w, r, ch)
+// var wg sync.WaitGroup
+// go func() {
+// tr.Out(w, r, ch)
+// wg.Done()
+// }()
// ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
// close(ch)
-// w.Hijack()
-// // w.Close() // Client closes connection
+// wg.Wait() // wait until everything is written out
+// w.Close() // close connection
//
-// The server is responsible for sending the correct sequence of RRs through the
-// channel ch.
+// The server is responsible for sending the correct sequence of RRs through the channel ch.
func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
for x := range ch {
r := new(Msg)
@@ -192,11 +201,14 @@ func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
r.Authoritative = true
// assume it fits TODO(miek): fix
r.Answer = append(r.Answer, x.RR...)
+ if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil {
+ r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix())
+ }
if err := w.WriteMsg(r); err != nil {
return err
}
+ w.TsigTimersOnly(true)
}
- w.TsigTimersOnly(true)
return nil
}
@@ -237,24 +249,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) {
if err != nil {
return err
}
- if _, err = t.Write(out); err != nil {
- return err
- }
- return nil
+ _, err = t.Write(out)
+ return err
}
func isSOAFirst(in *Msg) bool {
- if len(in.Answer) > 0 {
- return in.Answer[0].Header().Rrtype == TypeSOA
- }
- return false
+ return len(in.Answer) > 0 &&
+ in.Answer[0].Header().Rrtype == TypeSOA
}
func isSOALast(in *Msg) bool {
- if len(in.Answer) > 0 {
- return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
- }
- return false
+ return len(in.Answer) > 0 &&
+ in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
}
const errXFR = "bad xfr rcode: %d"