net: keep waiting for valid DNS response until timeout

Prevents denial of service attacks from bogus UDP packets.

Fixes #13281.

Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b
Reviewed-on: https://go-review.googlesource.com/22126
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Matthew Dempsky 2016-04-15 19:19:58 -07:00
parent 9f1ccd647f
commit 3411d63219
4 changed files with 263 additions and 63 deletions

View file

@ -38,46 +38,67 @@ type dnsConn interface {
SetDeadline(time.Time) error
// readDNSResponse reads a DNS response message from the DNS
// transport endpoint and returns the received DNS response
// message.
readDNSResponse() (*dnsMsg, error)
// writeDNSQuery writes a DNS query message to the DNS
// connection endpoint.
writeDNSQuery(*dnsMsg) error
// dnsRoundTrip executes a single DNS transaction, returning a
// DNS response message for the provided DNS query message.
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
}
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 512) // see RFC 1035
n, err := c.Read(b)
if err != nil {
return nil, err
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
return msg, nil
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
return dnsRoundTripUDP(c, query)
}
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
// such as a *UDPConn.
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
return nil, errors.New("cannot marshal DNS message")
}
if _, err := c.Write(b); err != nil {
return err
return nil, err
}
b = make([]byte, 512) // see RFC 1035
for {
n, err := c.Read(b)
if err != nil {
return nil, err
}
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
continue
}
return resp, nil
}
return nil
}
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
return dnsRoundTripTCP(c, out)
}
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
// such as a *TCPConn.
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
l := len(b)
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return nil, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
}
l := int(b[0])<<8 | int(b[1])
l = int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
if err != nil {
return nil, err
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
return msg, nil
}
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
if !resp.IsResponseTo(query) {
return nil, errors.New("invalid DNS response")
}
l := uint16(len(b))
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return err
}
return nil
return resp, nil
}
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
c.SetDeadline(d)
}
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
if err := c.writeDNSQuery(&out); err != nil {
return nil, mapErr(err)
}
in, err := c.readDNSResponse()
in, err := c.dnsRoundTrip(&out)
if err != nil {
return nil, mapErr(err)
}
if in.id != out.id {
return nil, errors.New("DNS message ID mismatch")
}
if in.truncated { // see RFC 5966
continue
}

View file

@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
}
type fakeDNSConn struct {
// last query
qmu sync.Mutex // guards q
q *dnsMsg
// reply handler
rh func(*dnsMsg) (*dnsMsg, error)
}
@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error {
return nil
}
func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
f.qmu.Lock()
defer f.qmu.Unlock()
f.q = q
return nil
}
func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
f.qmu.Lock()
q := f.q
f.qmu.Unlock()
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
return f.rh(q)
}
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) {
const TestAddr uint32 = 0x80420001
c, s := Pipe()
go func() {
b := make([]byte, 512)
n, err := s.Read(b)
if err != nil {
t.Fatal(err)
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
t.Fatal("invalid DNS query")
}
s.Write([]byte("garbage DNS response packet"))
msg.response = true
msg.id++ // make invalid ID
b, ok := msg.Pack()
if !ok {
t.Fatal("failed to pack DNS response")
}
s.Write(b)
msg.id-- // restore original ID
msg.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: "www.example.com.",
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
},
A: TestAddr,
},
}
b, ok = msg.Pack()
if !ok {
t.Fatal("failed to pack DNS response")
}
s.Write(b)
}()
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
}
resp, err := dnsRoundTripUDP(c, msg)
if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err)
}
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
t.Error("got address %v, want %v", got, TestAddr)
}
}

View file

@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string {
}
return s
}
// IsResponseTo reports whether m is an acceptable response to query.
func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
if !m.response {
return false
}
if m.id != query.id {
return false
}
if len(m.question) != len(query.question) {
return false
}
for i, q := range m.question {
q2 := query.question[i]
if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
return false
}
}
return true
}

View file

@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
}
}
func TestIsResponseTo(t *testing.T) {
// Sample DNS query.
query := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
}
resp := query
resp.response = true
if !resp.IsResponseTo(&query) {
t.Error("got false, want true")
}
badResponses := []dnsMsg{
// Different ID.
{
dnsMsgHdr: dnsMsgHdr{
id: 43,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query name.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.google.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query type.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
// Different query class.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassCSNET,
},
},
},
// No questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
},
// Extra questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
{
Name: "www.golang.org.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
}
for i := range badResponses {
if badResponses[i].IsResponseTo(&query) {
t.Error("%v: got true, want false", i)
}
}
}
// Valid DNS SRV reply
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +