diff --git a/src/compress/lzw/reader.go b/src/compress/lzw/reader.go index f08021190c..952870a56a 100644 --- a/src/compress/lzw/reader.go +++ b/src/compress/lzw/reader.go @@ -42,15 +42,15 @@ const ( flushBuffer = 1 << maxWidth ) -// decoder is the state from which the readXxx method converts a byte -// stream into a code stream. -type decoder struct { +// Reader is an io.Reader which can be used to read compressed data in the +// LZW format. +type Reader struct { r io.ByteReader bits uint32 nBits uint width uint - read func(*decoder) (uint16, error) // readLSB or readMSB - litWidth int // width in bits of literal codes + read func(*Reader) (uint16, error) // readLSB or readMSB + litWidth int // width in bits of literal codes err error // The first 1<>= d.width - d.nBits -= d.width + code := uint16(r.bits & (1<>= r.width + r.nBits -= r.width return code, nil } // readMSB returns the next code for "Most Significant Bits first" data. -func (d *decoder) readMSB() (uint16, error) { - for d.nBits < d.width { - x, err := d.r.ReadByte() +func (r *Reader) readMSB() (uint16, error) { + for r.nBits < r.width { + x, err := r.r.ReadByte() if err != nil { return 0, err } - d.bits |= uint32(x) << (24 - d.nBits) - d.nBits += 8 + r.bits |= uint32(x) << (24 - r.nBits) + r.nBits += 8 } - code := uint16(d.bits >> (32 - d.width)) - d.bits <<= d.width - d.nBits -= d.width + code := uint16(r.bits >> (32 - r.width)) + r.bits <<= r.width + r.nBits -= r.width return code, nil } -func (d *decoder) Read(b []byte) (int, error) { +// Read implements io.Reader, reading uncompressed bytes from its underlying Reader. +func (r *Reader) Read(b []byte) (int, error) { for { - if len(d.toRead) > 0 { - n := copy(b, d.toRead) - d.toRead = d.toRead[n:] + if len(r.toRead) > 0 { + n := copy(b, r.toRead) + r.toRead = r.toRead[n:] return n, nil } - if d.err != nil { - return 0, d.err + if r.err != nil { + return 0, r.err } - d.decode() + r.decode() } } // decode decompresses bytes from r and leaves them in d.toRead. // read specifies how to decode bytes into codes. // litWidth is the width in bits of literal codes. -func (d *decoder) decode() { +func (r *Reader) decode() { // Loop over the code stream, converting codes into decompressed bytes. loop: for { - code, err := d.read(d) + code, err := r.read(r) if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } - d.err = err + r.err = err break } switch { - case code < d.clear: + case code < r.clear: // We have a literal code. - d.output[d.o] = uint8(code) - d.o++ - if d.last != decoderInvalidCode { + r.output[r.o] = uint8(code) + r.o++ + if r.last != decoderInvalidCode { // Save what the hi code expands to. - d.suffix[d.hi] = uint8(code) - d.prefix[d.hi] = d.last + r.suffix[r.hi] = uint8(code) + r.prefix[r.hi] = r.last } - case code == d.clear: - d.width = 1 + uint(d.litWidth) - d.hi = d.eof - d.overflow = 1 << d.width - d.last = decoderInvalidCode + case code == r.clear: + r.width = 1 + uint(r.litWidth) + r.hi = r.eof + r.overflow = 1 << r.width + r.last = decoderInvalidCode continue - case code == d.eof: - d.err = io.EOF + case code == r.eof: + r.err = io.EOF break loop - case code <= d.hi: - c, i := code, len(d.output)-1 - if code == d.hi && d.last != decoderInvalidCode { + case code <= r.hi: + c, i := code, len(r.output)-1 + if code == r.hi && r.last != decoderInvalidCode { // code == hi is a special case which expands to the last expansion // followed by the head of the last expansion. To find the head, we walk // the prefix chain until we find a literal code. - c = d.last - for c >= d.clear { - c = d.prefix[c] + c = r.last + for c >= r.clear { + c = r.prefix[c] } - d.output[i] = uint8(c) + r.output[i] = uint8(c) i-- - c = d.last + c = r.last } // Copy the suffix chain into output and then write that to w. - for c >= d.clear { - d.output[i] = d.suffix[c] + for c >= r.clear { + r.output[i] = r.suffix[c] i-- - c = d.prefix[c] + c = r.prefix[c] } - d.output[i] = uint8(c) - d.o += copy(d.output[d.o:], d.output[i:]) - if d.last != decoderInvalidCode { + r.output[i] = uint8(c) + r.o += copy(r.output[r.o:], r.output[i:]) + if r.last != decoderInvalidCode { // Save what the hi code expands to. - d.suffix[d.hi] = uint8(c) - d.prefix[d.hi] = d.last + r.suffix[r.hi] = uint8(c) + r.prefix[r.hi] = r.last } default: - d.err = errors.New("lzw: invalid code") + r.err = errors.New("lzw: invalid code") break loop } - d.last, d.hi = code, d.hi+1 - if d.hi >= d.overflow { - if d.hi > d.overflow { + r.last, r.hi = code, r.hi+1 + if r.hi >= r.overflow { + if r.hi > r.overflow { panic("unreachable") } - if d.width == maxWidth { - d.last = decoderInvalidCode + if r.width == maxWidth { + r.last = decoderInvalidCode // Undo the d.hi++ a few lines above, so that (1) we maintain // the invariant that d.hi < d.overflow, and (2) d.hi does not // eventually overflow a uint16. - d.hi-- + r.hi-- } else { - d.width++ - d.overflow = 1 << d.width + r.width++ + r.overflow = 1 << r.width } } - if d.o >= flushBuffer { + if r.o >= flushBuffer { break } } // Flush pending output. - d.toRead = d.output[:d.o] - d.o = 0 + r.toRead = r.output[:r.o] + r.o = 0 } var errClosed = errors.New("lzw: reader/writer is closed") -func (d *decoder) Close() error { - d.err = errClosed // in case any Reads come along +// Close closes the Reader and returns an error for any future read operation. +// It does not close the underlying io.Reader. +func (r *Reader) Close() error { + r.err = errClosed // in case any Reads come along return nil } +// Reset clears the Reader's state and allows it to be reused again +// as a new Reader. +func (r *Reader) Reset(src io.Reader, order Order, litWidth int) { + *r = Reader{} + r.init(src, order, litWidth) +} + // NewReader creates a new io.ReadCloser. // Reads from the returned io.ReadCloser read and decompress data from r. // If r does not also implement io.ByteReader, @@ -238,32 +248,43 @@ func (d *decoder) Close() error { // The number of bits to use for literal codes, litWidth, must be in the // range [2,8] and is typically 8. It must equal the litWidth // used during compression. +// +// It is guaranteed that the underlying type of the returned io.ReadCloser +// is a *Reader. func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser { - d := new(decoder) + return newReader(r, order, litWidth) +} + +func newReader(src io.Reader, order Order, litWidth int) *Reader { + r := new(Reader) + r.init(src, order, litWidth) + return r +} + +func (r *Reader) init(src io.Reader, order Order, litWidth int) { switch order { case LSB: - d.read = (*decoder).readLSB + r.read = (*Reader).readLSB case MSB: - d.read = (*decoder).readMSB + r.read = (*Reader).readMSB default: - d.err = errors.New("lzw: unknown order") - return d + r.err = errors.New("lzw: unknown order") + return } if litWidth < 2 || 8 < litWidth { - d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth) - return d + r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth) + return } - if br, ok := r.(io.ByteReader); ok { - d.r = br - } else { - d.r = bufio.NewReader(r) - } - d.litWidth = litWidth - d.width = 1 + uint(litWidth) - d.clear = uint16(1) << uint(litWidth) - d.eof, d.hi = d.clear+1, d.clear+1 - d.overflow = uint16(1) << d.width - d.last = decoderInvalidCode - return d + br, ok := src.(io.ByteReader) + if !ok && src != nil { + br = bufio.NewReader(src) + } + r.r = br + r.litWidth = litWidth + r.width = 1 + uint(litWidth) + r.clear = uint16(1) << uint(litWidth) + r.eof, r.hi = r.clear+1, r.clear+1 + r.overflow = uint16(1) << r.width + r.last = decoderInvalidCode } diff --git a/src/compress/lzw/reader_test.go b/src/compress/lzw/reader_test.go index d1eb76d042..9a2a477302 100644 --- a/src/compress/lzw/reader_test.go +++ b/src/compress/lzw/reader_test.go @@ -120,6 +120,53 @@ func TestReader(t *testing.T) { } } +func TestReaderReset(t *testing.T) { + var b bytes.Buffer + for _, tt := range lzwTests { + d := strings.Split(tt.desc, ";") + var order Order + switch d[1] { + case "LSB": + order = LSB + case "MSB": + order = MSB + default: + t.Errorf("%s: bad order %q", tt.desc, d[1]) + } + litWidth, _ := strconv.Atoi(d[2]) + rc := NewReader(strings.NewReader(tt.compressed), order, litWidth) + defer rc.Close() + b.Reset() + n, err := io.Copy(&b, rc) + b1 := b.Bytes() + if err != nil { + if err != tt.err { + t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err) + } + if err == io.ErrUnexpectedEOF { + // Even if the input is truncated, we should still return the + // partial decoded result. + if n == 0 || !strings.HasPrefix(tt.raw, b.String()) { + t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, b.String(), tt.raw) + } + } + continue + } + + b.Reset() + rc.(*Reader).Reset(strings.NewReader(tt.compressed), order, litWidth) + n, err = io.Copy(&b, rc) + b2 := b.Bytes() + if err != nil { + t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, nil) + continue + } + if !bytes.Equal(b1, b2) { + t.Errorf("bytes read were not the same") + } + } +} + type devZero struct{} func (devZero) Read(p []byte) (int, error) { @@ -131,7 +178,7 @@ func (devZero) Read(p []byte) (int, error) { func TestHiCodeDoesNotOverflow(t *testing.T) { r := NewReader(devZero{}, LSB, 8) - d := r.(*decoder) + d := r.(*Reader) buf := make([]byte, 1024) oldHi := uint16(0) for i := 0; i < 100; i++ { @@ -226,28 +273,43 @@ func BenchmarkDecoder(b *testing.B) { b.Fatalf("test file has no data") } + getInputBuf := func(buf []byte, n int) []byte { + compressed := new(bytes.Buffer) + w := NewWriter(compressed, LSB, 8) + for i := 0; i < n; i += len(buf) { + if len(buf) > n-i { + buf = buf[:n-i] + } + w.Write(buf) + } + w.Close() + return compressed.Bytes() + } + for e := 4; e <= 6; e++ { n := int(math.Pow10(e)) b.Run(fmt.Sprint("1e", e), func(b *testing.B) { b.StopTimer() b.SetBytes(int64(n)) - buf0 := buf - compressed := new(bytes.Buffer) - w := NewWriter(compressed, LSB, 8) - for i := 0; i < n; i += len(buf0) { - if len(buf0) > n-i { - buf0 = buf0[:n-i] - } - w.Write(buf0) - } - w.Close() - buf1 := compressed.Bytes() - buf0, compressed, w = nil, nil, nil + buf1 := getInputBuf(buf, n) runtime.GC() b.StartTimer() for i := 0; i < b.N; i++ { io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8)) } }) + b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) { + b.StopTimer() + b.SetBytes(int64(n)) + buf1 := getInputBuf(buf, n) + runtime.GC() + b.StartTimer() + r := NewReader(bytes.NewReader(buf1), LSB, 8) + for i := 0; i < b.N; i++ { + io.Copy(io.Discard, r) + r.Close() + r.(*Reader).Reset(bytes.NewReader(buf1), LSB, 8) + } + }) } } diff --git a/src/compress/lzw/writer.go b/src/compress/lzw/writer.go index 6ddb335f31..552bdc2ce1 100644 --- a/src/compress/lzw/writer.go +++ b/src/compress/lzw/writer.go @@ -17,19 +17,6 @@ type writer interface { Flush() error } -// An errWriteCloser is an io.WriteCloser that always returns a given error. -type errWriteCloser struct { - err error -} - -func (e *errWriteCloser) Write([]byte) (int, error) { - return 0, e.err -} - -func (e *errWriteCloser) Close() error { - return e.err -} - const ( // A code is a 12 bit value, stored as a uint32 when encoding to avoid // type conversions when shifting bits. @@ -44,14 +31,15 @@ const ( invalidEntry = 0 ) -// encoder is LZW compressor. -type encoder struct { +// Writer is an LZW compressor. It writes the compressed form of the data +// to an underlying writer (see NewWriter). +type Writer struct { // w is the writer that compressed bytes are written to. w writer // order, write, bits, nBits and width are the state for // converting a code stream into a byte stream. order Order - write func(*encoder, uint32) error + write func(*Writer, uint32) error bits uint32 nBits uint width uint @@ -63,7 +51,7 @@ type encoder struct { // savedCode is the accumulated code at the end of the most recent Write // call. It is equal to invalidCode if there was no such call. savedCode uint32 - // err is the first error encountered during writing. Closing the encoder + // err is the first error encountered during writing. Closing the writer // will make any future Write calls return errClosed err error // table is the hash table from 20-bit keys to 12-bit values. Each table @@ -74,80 +62,80 @@ type encoder struct { } // writeLSB writes the code c for "Least Significant Bits first" data. -func (e *encoder) writeLSB(c uint32) error { - e.bits |= c << e.nBits - e.nBits += e.width - for e.nBits >= 8 { - if err := e.w.WriteByte(uint8(e.bits)); err != nil { +func (w *Writer) writeLSB(c uint32) error { + w.bits |= c << w.nBits + w.nBits += w.width + for w.nBits >= 8 { + if err := w.w.WriteByte(uint8(w.bits)); err != nil { return err } - e.bits >>= 8 - e.nBits -= 8 + w.bits >>= 8 + w.nBits -= 8 } return nil } // writeMSB writes the code c for "Most Significant Bits first" data. -func (e *encoder) writeMSB(c uint32) error { - e.bits |= c << (32 - e.width - e.nBits) - e.nBits += e.width - for e.nBits >= 8 { - if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil { +func (w *Writer) writeMSB(c uint32) error { + w.bits |= c << (32 - w.width - w.nBits) + w.nBits += w.width + for w.nBits >= 8 { + if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil { return err } - e.bits <<= 8 - e.nBits -= 8 + w.bits <<= 8 + w.nBits -= 8 } return nil } -// errOutOfCodes is an internal error that means that the encoder has run out +// errOutOfCodes is an internal error that means that the writer has run out // of unused codes and a clear code needs to be sent next. var errOutOfCodes = errors.New("lzw: out of codes") // incHi increments e.hi and checks for both overflow and running out of // unused codes. In the latter case, incHi sends a clear code, resets the -// encoder state and returns errOutOfCodes. -func (e *encoder) incHi() error { - e.hi++ - if e.hi == e.overflow { - e.width++ - e.overflow <<= 1 +// writer state and returns errOutOfCodes. +func (w *Writer) incHi() error { + w.hi++ + if w.hi == w.overflow { + w.width++ + w.overflow <<= 1 } - if e.hi == maxCode { - clear := uint32(1) << e.litWidth - if err := e.write(e, clear); err != nil { + if w.hi == maxCode { + clear := uint32(1) << w.litWidth + if err := w.write(w, clear); err != nil { return err } - e.width = e.litWidth + 1 - e.hi = clear + 1 - e.overflow = clear << 1 - for i := range e.table { - e.table[i] = invalidEntry + w.width = w.litWidth + 1 + w.hi = clear + 1 + w.overflow = clear << 1 + for i := range w.table { + w.table[i] = invalidEntry } return errOutOfCodes } return nil } -// Write writes a compressed representation of p to e's underlying writer. -func (e *encoder) Write(p []byte) (n int, err error) { - if e.err != nil { - return 0, e.err +// Write writes a compressed representation of p to w's underlying writer. +func (w *Writer) Write(p []byte) (n int, err error) { + if w.err != nil { + return 0, w.err } if len(p) == 0 { return 0, nil } - if maxLit := uint8(1< maxLit { - e.err = errors.New("lzw: input byte too large for the litWidth") - return 0, e.err + w.err = errors.New("lzw: input byte too large for the litWidth") + return 0, w.err } } } n = len(p) - code := e.savedCode + code := w.savedCode if code == invalidCode { // The first code sent is always a literal code. code, p = uint32(p[0]), p[1:] @@ -159,77 +147,84 @@ loop: // If there is a hash table hit for this key then we continue the loop // and do not emit a code yet. hash := (key>>12 ^ key) & tableMask - for h, t := hash, e.table[hash]; t != invalidEntry; { + for h, t := hash, w.table[hash]; t != invalidEntry; { if key == t>>12 { code = t & maxCode continue loop } h = (h + 1) & tableMask - t = e.table[h] + t = w.table[h] } // Otherwise, write the current code, and literal becomes the start of // the next emitted code. - if e.err = e.write(e, code); e.err != nil { - return 0, e.err + if w.err = w.write(w, code); w.err != nil { + return 0, w.err } code = literal // Increment e.hi, the next implied code. If we run out of codes, reset - // the encoder state (including clearing the hash table) and continue. - if err1 := e.incHi(); err1 != nil { + // the writer state (including clearing the hash table) and continue. + if err1 := w.incHi(); err1 != nil { if err1 == errOutOfCodes { continue } - e.err = err1 - return 0, e.err + w.err = err1 + return 0, w.err } // Otherwise, insert key -> e.hi into the map that e.table represents. for { - if e.table[hash] == invalidEntry { - e.table[hash] = (key << 12) | e.hi + if w.table[hash] == invalidEntry { + w.table[hash] = (key << 12) | w.hi break } hash = (hash + 1) & tableMask } } - e.savedCode = code + w.savedCode = code return n, nil } -// Close closes the encoder, flushing any pending output. It does not close or -// flush e's underlying writer. -func (e *encoder) Close() error { - if e.err != nil { - if e.err == errClosed { +// Close closes the Writer, flushing any pending output. It does not close +// w's underlying writer. +func (w *Writer) Close() error { + if w.err != nil { + if w.err == errClosed { return nil } - return e.err + return w.err } // Make any future calls to Write return errClosed. - e.err = errClosed + w.err = errClosed // Write the savedCode if valid. - if e.savedCode != invalidCode { - if err := e.write(e, e.savedCode); err != nil { + if w.savedCode != invalidCode { + if err := w.write(w, w.savedCode); err != nil { return err } - if err := e.incHi(); err != nil && err != errOutOfCodes { + if err := w.incHi(); err != nil && err != errOutOfCodes { return err } } // Write the eof code. - eof := uint32(1)< 0 { - if e.order == MSB { - e.bits >>= 24 + if w.nBits > 0 { + if w.order == MSB { + w.bits >>= 24 } - if err := e.w.WriteByte(uint8(e.bits)); err != nil { + if err := w.w.WriteByte(uint8(w.bits)); err != nil { return err } } - return e.w.Flush() + return w.w.Flush() +} + +// Reset clears the Writer's state and allows it to be reused again +// as a new Writer. +func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) { + *w = Writer{} + w.init(dst, order, litWidth) } // NewWriter creates a new io.WriteCloser. @@ -238,32 +233,43 @@ func (e *encoder) Close() error { // finished writing. // The number of bits to use for literal codes, litWidth, must be in the // range [2,8] and is typically 8. Input bytes must be less than 1<