From 87b1aaa37cefec8deacdf9c3c30d26015bdfb00b Mon Sep 17 00:00:00 2001 From: Xuyang Kang Date: Sun, 17 Jul 2016 00:23:56 -0700 Subject: [PATCH] encoding/base64: This change modifies Go to take strict option when decoding base64 If strict option is enabled, when decoding, instead of skip the padding bits, it will do strict check to enforce they are set to zero. Fixes #15656 Change-Id: I869fb725a39cc9dde44dbc4ff0046446e7abc642 Reviewed-on: https://go-review.googlesource.com/24964 Reviewed-by: Russ Cox Run-TryBot: Russ Cox TryBot-Result: Gobot Gobot --- src/encoding/base64/base64.go | 24 +++++++++++++++++++++--- src/encoding/base64/base64_test.go | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/src/encoding/base64/base64.go b/src/encoding/base64/base64.go index c2116d8a343..d2efad4518b 100644 --- a/src/encoding/base64/base64.go +++ b/src/encoding/base64/base64.go @@ -23,6 +23,7 @@ type Encoding struct { encode [64]byte decodeMap [256]byte padChar rune + strict bool } const ( @@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding { return &enc } +// Strict creates a new encoding identical to enc except with +// strict decoding enabled. In this mode, the decoder requires that +// trailing padding bits are zero, as described in RFC 4648 section 3.5. +func (enc Encoding) Strict() *Encoding { + enc.strict = true + return &enc +} + // StdEncoding is the standard base64 encoding, as defined in // RFC 4648. var StdEncoding = NewEncoding(encodeStd) @@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { // Convert 4x 6bit source bytes into 3 bytes val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3]) + dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16) switch dlen { case 4: - dst[2] = byte(val >> 0) + dst[2] = dbuf[2] + dbuf[2] = 0 fallthrough case 3: - dst[1] = byte(val >> 8) + dst[1] = dbuf[1] + if enc.strict && dbuf[2] != 0 { + return n, end, CorruptInputError(si - 1) + } + dbuf[1] = 0 fallthrough case 2: - dst[0] = byte(val >> 16) + dst[0] = dbuf[0] + if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) { + return n, end, CorruptInputError(si - 2) + } } dst = dst[dinc:] n += dlen - 1 diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go index 19ddb92f644..e2e1d59f3c0 100644 --- a/src/encoding/base64/base64_test.go +++ b/src/encoding/base64/base64_test.go @@ -85,6 +85,11 @@ var encodingTests = []encodingTest{ {RawStdEncoding, rawRef}, {RawURLEncoding, rawUrlRef}, {funnyEncoding, funnyRef}, + {StdEncoding.Strict(), stdRef}, + {URLEncoding.Strict(), urlRef}, + {RawStdEncoding.Strict(), rawRef}, + {RawURLEncoding.Strict(), rawUrlRef}, + {funnyEncoding.Strict(), funnyRef}, } var bigtest = testpair{ @@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) { } } +func TestDecoderIssue15656(t *testing.T) { + _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + want := CorruptInputError(22) + if !reflect.DeepEqual(want, err) { + t.Errorf("Error = %v; want CorruptInputError(22)", err) + } + _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } + _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } +} + func BenchmarkEncodeToString(b *testing.B) { data := make([]byte, 8192) b.SetBytes(int64(len(data)))