diff options
Diffstat (limited to 'vendor/github.com/miekg/dns/xfr.go')
-rw-r--r-- | vendor/github.com/miekg/dns/xfr.go | 64 |
1 files changed, 35 insertions, 29 deletions
diff --git a/vendor/github.com/miekg/dns/xfr.go b/vendor/github.com/miekg/dns/xfr.go index 5d0ff5c..43970e6 100644 --- a/vendor/github.com/miekg/dns/xfr.go +++ b/vendor/github.com/miekg/dns/xfr.go @@ -35,30 +35,36 @@ type Transfer struct { // channel, err := transfer.In(message, master) // func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { + switch q.Question[0].Qtype { + case TypeAXFR, TypeIXFR: + default: + return nil, &Error{"unsupported question type"} + } + timeout := dnsTimeout if t.DialTimeout != 0 { timeout = t.DialTimeout } + if t.Conn == nil { t.Conn, err = DialTimeout("tcp", a, timeout) if err != nil { return nil, err } } + if err := t.WriteMsg(q); err != nil { return nil, err } + env = make(chan *Envelope) - go func() { - if q.Question[0].Qtype == TypeAXFR { - go t.inAxfr(q, env) - return - } - if q.Question[0].Qtype == TypeIXFR { - go t.inIxfr(q, env) - return - } - }() + switch q.Question[0].Qtype { + case TypeAXFR: + go t.inAxfr(q, env) + case TypeIXFR: + go t.inIxfr(q, env) + } + return env, nil } @@ -111,7 +117,7 @@ func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) { } func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) { - serial := uint32(0) // The first serial seen is the current server serial + var serial uint32 // The first serial seen is the current server serial axfr := true n := 0 qser := q.Ns[0].(*SOA).Serial @@ -176,14 +182,17 @@ func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) { // // ch := make(chan *dns.Envelope) // tr := new(dns.Transfer) -// go tr.Out(w, r, ch) +// var wg sync.WaitGroup +// go func() { +// tr.Out(w, r, ch) +// wg.Done() +// }() // ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}} // close(ch) -// w.Hijack() -// // w.Close() // Client closes connection +// wg.Wait() // wait until everything is written out +// w.Close() // close connection // -// The server is responsible for sending the correct sequence of RRs through the -// channel ch. +// The server is responsible for sending the correct sequence of RRs through the channel ch. func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { for x := range ch { r := new(Msg) @@ -192,11 +201,14 @@ func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { r.Authoritative = true // assume it fits TODO(miek): fix r.Answer = append(r.Answer, x.RR...) + if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil { + r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix()) + } if err := w.WriteMsg(r); err != nil { return err } + w.TsigTimersOnly(true) } - w.TsigTimersOnly(true) return nil } @@ -237,24 +249,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) { if err != nil { return err } - if _, err = t.Write(out); err != nil { - return err - } - return nil + _, err = t.Write(out) + return err } func isSOAFirst(in *Msg) bool { - if len(in.Answer) > 0 { - return in.Answer[0].Header().Rrtype == TypeSOA - } - return false + return len(in.Answer) > 0 && + in.Answer[0].Header().Rrtype == TypeSOA } func isSOALast(in *Msg) bool { - if len(in.Answer) > 0 { - return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA - } - return false + return len(in.Answer) > 0 && + in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA } const errXFR = "bad xfr rcode: %d" |