encoding/xml: add (*Encoder).Close

Flush can not check for unclosed elements, as more data might be encoded
after Flush is called. Close implicitly calls Flush and also checks that
all opened elements are closed as well.

Fixes #53346

Change-Id: I889b9f5ae54e5dfabb9e6948d96c5ed7bc1110f9
Reviewed-on: https://go-review.googlesource.com/c/go/+/424777
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
Axel Wagner 2022-08-18 16:18:34 +02:00 committed by Gopher Robot
parent be9e2440a7
commit 7f632f76db
3 changed files with 133 additions and 7 deletions

1
api/next/53346.txt Normal file
View file

@ -0,0 +1 @@
pkg encoding/xml, method (*Encoder) Close() error #53346

View file

@ -8,6 +8,7 @@ import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
@ -78,7 +79,11 @@ const (
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(v any) ([]byte, error) {
var b bytes.Buffer
if err := NewEncoder(&b).Encode(v); err != nil {
enc := NewEncoder(&b)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
@ -129,6 +134,9 @@ func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
@ -139,7 +147,7 @@ type Encoder struct {
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{printer{Writer: bufio.NewWriter(w)}}
e := &Encoder{printer{w: bufio.NewWriter(w)}}
e.p.encoder = e
return e
}
@ -163,7 +171,7 @@ func (enc *Encoder) Encode(v any) error {
if err != nil {
return err
}
return enc.p.Flush()
return enc.p.w.Flush()
}
// EncodeElement writes the XML encoding of v to the stream,
@ -178,7 +186,7 @@ func (enc *Encoder) EncodeElement(v any, start StartElement) error {
if err != nil {
return err
}
return enc.p.Flush()
return enc.p.w.Flush()
}
var (
@ -224,7 +232,7 @@ func (enc *Encoder) EncodeToken(t Token) error {
case ProcInst:
// First token to be encoded which is also a ProcInst with target of xml
// is the xml declaration. The only ProcInst where target of xml is allowed.
if t.Target == "xml" && p.Buffered() != 0 {
if t.Target == "xml" && p.w.Buffered() != 0 {
return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
}
if !isNameString(t.Target) {
@ -297,11 +305,18 @@ func isValidDirective(dir Directive) bool {
// Flush flushes any buffered XML to the underlying writer.
// See the EncodeToken documentation for details about when it is necessary.
func (enc *Encoder) Flush() error {
return enc.p.Flush()
return enc.p.w.Flush()
}
// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (enc *Encoder) Close() error {
return enc.p.Close()
}
type printer struct {
*bufio.Writer
w *bufio.Writer
encoder *Encoder
seq int
indent string
@ -313,6 +328,8 @@ type printer struct {
attrPrefix map[string]string // map name space -> prefix
prefixes []string
tags []Name
closed bool
err error
}
// createAttrPrefix finds the name space prefix attribute to use for the given name space,
@ -961,6 +978,56 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
return p.cachedWriteError()
}
// Write implements io.Writer
func (p *printer) Write(b []byte) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.Write(b)
}
return n, p.err
}
// WriteString implements io.StringWriter
func (p *printer) WriteString(s string) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.WriteString(s)
}
return n, p.err
}
// WriteByte implements io.ByteWriter
func (p *printer) WriteByte(c byte) error {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
p.err = p.w.WriteByte(c)
}
return p.err
}
// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (p *printer) Close() error {
if p.closed {
return nil
}
p.closed = true
if err := p.w.Flush(); err != nil {
return err
}
if len(p.tags) > 0 {
return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
}
return nil
}
// return the bufio Writer's cached write error
func (p *printer) cachedWriteError() error {
_, err := p.Write(nil)

View file

@ -2531,3 +2531,61 @@ func TestMarshalZeroValue(t *testing.T) {
t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML)
}
}
var closeTests = []struct {
desc string
toks []Token
want string
err string
}{{
desc: "unclosed start element",
toks: []Token{
StartElement{Name{"", "foo"}, nil},
},
want: `<foo>`,
err: "unclosed tag <foo>",
}, {
desc: "closed element",
toks: []Token{
StartElement{Name{"", "foo"}, nil},
EndElement{Name{"", "foo"}},
},
want: `<foo></foo>`,
}, {
desc: "directive",
toks: []Token{
Directive("foo"),
},
want: `<!foo>`,
}}
func TestClose(t *testing.T) {
for _, tt := range closeTests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
var out strings.Builder
enc := NewEncoder(&out)
for j, tok := range tt.toks {
if err := enc.EncodeToken(tok); err != nil {
t.Fatalf("token #%d: %v", j, err)
}
}
err := enc.Close()
switch {
case tt.err != "" && err == nil:
t.Error(" expected error; got none")
case tt.err == "" && err != nil:
t.Errorf(" got error: %v", err)
case tt.err != "" && err != nil && tt.err != err.Error():
t.Errorf(" error mismatch; got %v, want %v", err, tt.err)
}
if got := out.String(); got != tt.want {
t.Errorf("\ngot %v\nwant %v", got, tt.want)
}
t.Log(enc.p.closed)
if err := enc.EncodeToken(Directive("foo")); err == nil {
t.Errorf("unexpected success when encoding after Close")
}
})
}
}