encoding/base64: This change modifies Go to take strict option when decoding base64

If strict option is enabled, when decoding, instead of skip the padding
bits, it will do strict check to enforce they are set to zero.

Fixes #15656

Change-Id: I869fb725a39cc9dde44dbc4ff0046446e7abc642
Reviewed-on: https://go-review.googlesource.com/24964
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Xuyang Kang 2016-07-17 00:23:56 -07:00 committed by Russ Cox
parent 56b5546b91
commit 87b1aaa37c
2 changed files with 42 additions and 3 deletions

View file

@ -23,6 +23,7 @@ type Encoding struct {
encode [64]byte
decodeMap [256]byte
padChar rune
strict bool
}
const (
@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding {
return &enc
}
// Strict creates a new encoding identical to enc except with
// strict decoding enabled. In this mode, the decoder requires that
// trailing padding bits are zero, as described in RFC 4648 section 3.5.
func (enc Encoding) Strict() *Encoding {
enc.strict = true
return &enc
}
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Convert 4x 6bit source bytes into 3 bytes
val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
switch dlen {
case 4:
dst[2] = byte(val >> 0)
dst[2] = dbuf[2]
dbuf[2] = 0
fallthrough
case 3:
dst[1] = byte(val >> 8)
dst[1] = dbuf[1]
if enc.strict && dbuf[2] != 0 {
return n, end, CorruptInputError(si - 1)
}
dbuf[1] = 0
fallthrough
case 2:
dst[0] = byte(val >> 16)
dst[0] = dbuf[0]
if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
return n, end, CorruptInputError(si - 2)
}
}
dst = dst[dinc:]
n += dlen - 1

View file

@ -85,6 +85,11 @@ var encodingTests = []encodingTest{
{RawStdEncoding, rawRef},
{RawURLEncoding, rawUrlRef},
{funnyEncoding, funnyRef},
{StdEncoding.Strict(), stdRef},
{URLEncoding.Strict(), urlRef},
{RawStdEncoding.Strict(), rawRef},
{RawURLEncoding.Strict(), rawUrlRef},
{funnyEncoding.Strict(), funnyRef},
}
var bigtest = testpair{
@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) {
}
}
func TestDecoderIssue15656(t *testing.T) {
_, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
want := CorruptInputError(22)
if !reflect.DeepEqual(want, err) {
t.Errorf("Error = %v; want CorruptInputError(22)", err)
}
_, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
if err != nil {
t.Errorf("Error = %v; want nil", err)
}
_, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
if err != nil {
t.Errorf("Error = %v; want nil", err)
}
}
func BenchmarkEncodeToString(b *testing.B) {
data := make([]byte, 8192)
b.SetBytes(int64(len(data)))