diff --git a/src/encoding/base64/base64.go b/src/encoding/base64/base64.go index c2116d8a343..d2efad4518b 100644 --- a/src/encoding/base64/base64.go +++ b/src/encoding/base64/base64.go @@ -23,6 +23,7 @@ type Encoding struct { encode [64]byte decodeMap [256]byte padChar rune + strict bool } const ( @@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding { return &enc } +// Strict creates a new encoding identical to enc except with +// strict decoding enabled. In this mode, the decoder requires that +// trailing padding bits are zero, as described in RFC 4648 section 3.5. +func (enc Encoding) Strict() *Encoding { + enc.strict = true + return &enc +} + // StdEncoding is the standard base64 encoding, as defined in // RFC 4648. var StdEncoding = NewEncoding(encodeStd) @@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { // Convert 4x 6bit source bytes into 3 bytes val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3]) + dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16) switch dlen { case 4: - dst[2] = byte(val >> 0) + dst[2] = dbuf[2] + dbuf[2] = 0 fallthrough case 3: - dst[1] = byte(val >> 8) + dst[1] = dbuf[1] + if enc.strict && dbuf[2] != 0 { + return n, end, CorruptInputError(si - 1) + } + dbuf[1] = 0 fallthrough case 2: - dst[0] = byte(val >> 16) + dst[0] = dbuf[0] + if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) { + return n, end, CorruptInputError(si - 2) + } } dst = dst[dinc:] n += dlen - 1 diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go index 19ddb92f644..e2e1d59f3c0 100644 --- a/src/encoding/base64/base64_test.go +++ b/src/encoding/base64/base64_test.go @@ -85,6 +85,11 @@ var encodingTests = []encodingTest{ {RawStdEncoding, rawRef}, {RawURLEncoding, rawUrlRef}, {funnyEncoding, funnyRef}, + {StdEncoding.Strict(), stdRef}, + {URLEncoding.Strict(), urlRef}, + {RawStdEncoding.Strict(), rawRef}, + {RawURLEncoding.Strict(), rawUrlRef}, + {funnyEncoding.Strict(), funnyRef}, } var bigtest = testpair{ @@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) { } } +func TestDecoderIssue15656(t *testing.T) { + _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + want := CorruptInputError(22) + if !reflect.DeepEqual(want, err) { + t.Errorf("Error = %v; want CorruptInputError(22)", err) + } + _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } + _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } +} + func BenchmarkEncodeToString(b *testing.B) { data := make([]byte, 8192) b.SetBytes(int64(len(data)))