compress/lzw: add Reset method to Reader and Writer

We add a Reset method which clears any internal state of an encoder
or a decoder to let it be reused again as a new Writer or Reader respectively.

We also export the encoder and decoder structs, renaming them
to be Reader and Writer, and we guarantee that the underlying types
from the constructors will always be Reader and Writer respectively.

Benchmark results by reusing the encoder:
on cpu: Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz

name                 time/op
Decoder/1e4-8          93.6µs ± 1%
Decoder/1e-Reuse4-8    87.7µs ± 1%
Decoder/1e5-8           877µs ± 1%
Decoder/1e-Reuse5-8     860µs ± 3%
Decoder/1e6-8          8.79ms ± 1%
Decoder/1e-Reuse6-8    8.82ms ± 4%
Encoder/1e4-8           168µs ± 2%
Encoder/1e-Reuse4-8     160µs ± 1%
Encoder/1e5-8          1.64ms ± 1%
Encoder/1e-Reuse5-8    1.61ms ± 2%
Encoder/1e6-8          16.2ms ± 6%
Encoder/1e-Reuse6-8    15.8ms ± 2%

name                 speed
Decoder/1e4-8         107MB/s ± 1%
Decoder/1e-Reuse4-8   114MB/s ± 1%
Decoder/1e5-8         114MB/s ± 1%
Decoder/1e-Reuse5-8   116MB/s ± 3%
Decoder/1e6-8         114MB/s ± 1%
Decoder/1e-Reuse6-8   113MB/s ± 5%
Encoder/1e4-8        59.7MB/s ± 2%
Encoder/1e-Reuse4-8  62.4MB/s ± 1%
Encoder/1e5-8        61.1MB/s ± 1%
Encoder/1e-Reuse5-8  62.0MB/s ± 2%
Encoder/1e6-8        61.7MB/s ± 5%
Encoder/1e-Reuse6-8  63.4MB/s ± 2%

name                 alloc/op
Decoder/1e4-8          21.8kB ± 0%
Decoder/1e-Reuse4-8     50.0B ± 0%
Decoder/1e5-8          21.8kB ± 0%
Decoder/1e-Reuse5-8     70.4B ± 2%
Decoder/1e6-8          21.9kB ± 0%
Decoder/1e-Reuse6-8      271B ± 3%
Encoder/1e4-8          77.9kB ± 0%
Encoder/1e-Reuse4-8    4.17kB ± 0%
Encoder/1e5-8          77.9kB ± 0%
Encoder/1e-Reuse5-8    4.27kB ± 0%
Encoder/1e6-8          77.9kB ± 0%
Encoder/1e-Reuse6-8    5.22kB ± 0%

name                 allocs/op
Decoder/1e4-8            2.00 ± 0%
Decoder/1e-Reuse4-8      1.00 ± 0%
Decoder/1e5-8            2.00 ± 0%
Decoder/1e-Reuse5-8      1.00 ± 0%
Decoder/1e6-8            2.00 ± 0%
Decoder/1e-Reuse6-8      1.00 ± 0%
Encoder/1e4-8            3.00 ± 0%
Encoder/1e-Reuse4-8      2.00 ± 0%
Encoder/1e5-8            3.00 ± 0%
Encoder/1e-Reuse5-8      2.00 ± 0%
Encoder/1e6-8            3.00 ± 0%
Encoder/1e-Reuse6-8      2.00 ± 0%

Fixes #26535

Change-Id: Icde613fea6234a5bdce95f1e49910f5687e30b22
Reviewed-on: https://go-review.googlesource.com/c/go/+/273667
Trust: Agniva De Sarker <agniva.quicksilver@gmail.com>
Trust: Joe Tsai <thebrokentoaster@gmail.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Agniva De Sarker 2020-11-27 11:25:00 +05:30 committed by Agniva De Sarker
parent 119d76d98e
commit 72b501cb03
4 changed files with 349 additions and 206 deletions

View file

@ -42,15 +42,15 @@ const (
flushBuffer = 1 << maxWidth
)
// decoder is the state from which the readXxx method converts a byte
// stream into a code stream.
type decoder struct {
// Reader is an io.Reader which can be used to read compressed data in the
// LZW format.
type Reader struct {
r io.ByteReader
bits uint32
nBits uint
width uint
read func(*decoder) (uint16, error) // readLSB or readMSB
litWidth int // width in bits of literal codes
read func(*Reader) (uint16, error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err error
// The first 1<<litWidth codes are literal codes.
@ -87,148 +87,158 @@ type decoder struct {
}
// readLSB returns the next code for "Least Significant Bits first" data.
func (d *decoder) readLSB() (uint16, error) {
for d.nBits < d.width {
x, err := d.r.ReadByte()
func (r *Reader) readLSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
d.bits |= uint32(x) << d.nBits
d.nBits += 8
r.bits |= uint32(x) << r.nBits
r.nBits += 8
}
code := uint16(d.bits & (1<<d.width - 1))
d.bits >>= d.width
d.nBits -= d.width
code := uint16(r.bits & (1<<r.width - 1))
r.bits >>= r.width
r.nBits -= r.width
return code, nil
}
// readMSB returns the next code for "Most Significant Bits first" data.
func (d *decoder) readMSB() (uint16, error) {
for d.nBits < d.width {
x, err := d.r.ReadByte()
func (r *Reader) readMSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
d.bits |= uint32(x) << (24 - d.nBits)
d.nBits += 8
r.bits |= uint32(x) << (24 - r.nBits)
r.nBits += 8
}
code := uint16(d.bits >> (32 - d.width))
d.bits <<= d.width
d.nBits -= d.width
code := uint16(r.bits >> (32 - r.width))
r.bits <<= r.width
r.nBits -= r.width
return code, nil
}
func (d *decoder) Read(b []byte) (int, error) {
// Read implements io.Reader, reading uncompressed bytes from its underlying Reader.
func (r *Reader) Read(b []byte) (int, error) {
for {
if len(d.toRead) > 0 {
n := copy(b, d.toRead)
d.toRead = d.toRead[n:]
if len(r.toRead) > 0 {
n := copy(b, r.toRead)
r.toRead = r.toRead[n:]
return n, nil
}
if d.err != nil {
return 0, d.err
if r.err != nil {
return 0, r.err
}
d.decode()
r.decode()
}
}
// decode decompresses bytes from r and leaves them in d.toRead.
// read specifies how to decode bytes into codes.
// litWidth is the width in bits of literal codes.
func (d *decoder) decode() {
func (r *Reader) decode() {
// Loop over the code stream, converting codes into decompressed bytes.
loop:
for {
code, err := d.read(d)
code, err := r.read(r)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
d.err = err
r.err = err
break
}
switch {
case code < d.clear:
case code < r.clear:
// We have a literal code.
d.output[d.o] = uint8(code)
d.o++
if d.last != decoderInvalidCode {
r.output[r.o] = uint8(code)
r.o++
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
d.suffix[d.hi] = uint8(code)
d.prefix[d.hi] = d.last
r.suffix[r.hi] = uint8(code)
r.prefix[r.hi] = r.last
}
case code == d.clear:
d.width = 1 + uint(d.litWidth)
d.hi = d.eof
d.overflow = 1 << d.width
d.last = decoderInvalidCode
case code == r.clear:
r.width = 1 + uint(r.litWidth)
r.hi = r.eof
r.overflow = 1 << r.width
r.last = decoderInvalidCode
continue
case code == d.eof:
d.err = io.EOF
case code == r.eof:
r.err = io.EOF
break loop
case code <= d.hi:
c, i := code, len(d.output)-1
if code == d.hi && d.last != decoderInvalidCode {
case code <= r.hi:
c, i := code, len(r.output)-1
if code == r.hi && r.last != decoderInvalidCode {
// code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code.
c = d.last
for c >= d.clear {
c = d.prefix[c]
c = r.last
for c >= r.clear {
c = r.prefix[c]
}
d.output[i] = uint8(c)
r.output[i] = uint8(c)
i--
c = d.last
c = r.last
}
// Copy the suffix chain into output and then write that to w.
for c >= d.clear {
d.output[i] = d.suffix[c]
for c >= r.clear {
r.output[i] = r.suffix[c]
i--
c = d.prefix[c]
c = r.prefix[c]
}
d.output[i] = uint8(c)
d.o += copy(d.output[d.o:], d.output[i:])
if d.last != decoderInvalidCode {
r.output[i] = uint8(c)
r.o += copy(r.output[r.o:], r.output[i:])
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
d.suffix[d.hi] = uint8(c)
d.prefix[d.hi] = d.last
r.suffix[r.hi] = uint8(c)
r.prefix[r.hi] = r.last
}
default:
d.err = errors.New("lzw: invalid code")
r.err = errors.New("lzw: invalid code")
break loop
}
d.last, d.hi = code, d.hi+1
if d.hi >= d.overflow {
if d.hi > d.overflow {
r.last, r.hi = code, r.hi+1
if r.hi >= r.overflow {
if r.hi > r.overflow {
panic("unreachable")
}
if d.width == maxWidth {
d.last = decoderInvalidCode
if r.width == maxWidth {
r.last = decoderInvalidCode
// Undo the d.hi++ a few lines above, so that (1) we maintain
// the invariant that d.hi < d.overflow, and (2) d.hi does not
// eventually overflow a uint16.
d.hi--
r.hi--
} else {
d.width++
d.overflow = 1 << d.width
r.width++
r.overflow = 1 << r.width
}
}
if d.o >= flushBuffer {
if r.o >= flushBuffer {
break
}
}
// Flush pending output.
d.toRead = d.output[:d.o]
d.o = 0
r.toRead = r.output[:r.o]
r.o = 0
}
var errClosed = errors.New("lzw: reader/writer is closed")
func (d *decoder) Close() error {
d.err = errClosed // in case any Reads come along
// Close closes the Reader and returns an error for any future read operation.
// It does not close the underlying io.Reader.
func (r *Reader) Close() error {
r.err = errClosed // in case any Reads come along
return nil
}
// Reset clears the Reader's state and allows it to be reused again
// as a new Reader.
func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
*r = Reader{}
r.init(src, order, litWidth)
}
// NewReader creates a new io.ReadCloser.
// Reads from the returned io.ReadCloser read and decompress data from r.
// If r does not also implement io.ByteReader,
@ -238,32 +248,43 @@ func (d *decoder) Close() error {
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. It must equal the litWidth
// used during compression.
//
// It is guaranteed that the underlying type of the returned io.ReadCloser
// is a *Reader.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
d := new(decoder)
return newReader(r, order, litWidth)
}
func newReader(src io.Reader, order Order, litWidth int) *Reader {
r := new(Reader)
r.init(src, order, litWidth)
return r
}
func (r *Reader) init(src io.Reader, order Order, litWidth int) {
switch order {
case LSB:
d.read = (*decoder).readLSB
r.read = (*Reader).readLSB
case MSB:
d.read = (*decoder).readMSB
r.read = (*Reader).readMSB
default:
d.err = errors.New("lzw: unknown order")
return d
r.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return d
r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
if br, ok := r.(io.ByteReader); ok {
d.r = br
} else {
d.r = bufio.NewReader(r)
}
d.litWidth = litWidth
d.width = 1 + uint(litWidth)
d.clear = uint16(1) << uint(litWidth)
d.eof, d.hi = d.clear+1, d.clear+1
d.overflow = uint16(1) << d.width
d.last = decoderInvalidCode
return d
br, ok := src.(io.ByteReader)
if !ok && src != nil {
br = bufio.NewReader(src)
}
r.r = br
r.litWidth = litWidth
r.width = 1 + uint(litWidth)
r.clear = uint16(1) << uint(litWidth)
r.eof, r.hi = r.clear+1, r.clear+1
r.overflow = uint16(1) << r.width
r.last = decoderInvalidCode
}

View file

@ -120,6 +120,53 @@ func TestReader(t *testing.T) {
}
}
func TestReaderReset(t *testing.T) {
var b bytes.Buffer
for _, tt := range lzwTests {
d := strings.Split(tt.desc, ";")
var order Order
switch d[1] {
case "LSB":
order = LSB
case "MSB":
order = MSB
default:
t.Errorf("%s: bad order %q", tt.desc, d[1])
}
litWidth, _ := strconv.Atoi(d[2])
rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
defer rc.Close()
b.Reset()
n, err := io.Copy(&b, rc)
b1 := b.Bytes()
if err != nil {
if err != tt.err {
t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
}
if err == io.ErrUnexpectedEOF {
// Even if the input is truncated, we should still return the
// partial decoded result.
if n == 0 || !strings.HasPrefix(tt.raw, b.String()) {
t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, b.String(), tt.raw)
}
}
continue
}
b.Reset()
rc.(*Reader).Reset(strings.NewReader(tt.compressed), order, litWidth)
n, err = io.Copy(&b, rc)
b2 := b.Bytes()
if err != nil {
t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, nil)
continue
}
if !bytes.Equal(b1, b2) {
t.Errorf("bytes read were not the same")
}
}
}
type devZero struct{}
func (devZero) Read(p []byte) (int, error) {
@ -131,7 +178,7 @@ func (devZero) Read(p []byte) (int, error) {
func TestHiCodeDoesNotOverflow(t *testing.T) {
r := NewReader(devZero{}, LSB, 8)
d := r.(*decoder)
d := r.(*Reader)
buf := make([]byte, 1024)
oldHi := uint16(0)
for i := 0; i < 100; i++ {
@ -226,28 +273,43 @@ func BenchmarkDecoder(b *testing.B) {
b.Fatalf("test file has no data")
}
getInputBuf := func(buf []byte, n int) []byte {
compressed := new(bytes.Buffer)
w := NewWriter(compressed, LSB, 8)
for i := 0; i < n; i += len(buf) {
if len(buf) > n-i {
buf = buf[:n-i]
}
w.Write(buf)
}
w.Close()
return compressed.Bytes()
}
for e := 4; e <= 6; e++ {
n := int(math.Pow10(e))
b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
b.StopTimer()
b.SetBytes(int64(n))
buf0 := buf
compressed := new(bytes.Buffer)
w := NewWriter(compressed, LSB, 8)
for i := 0; i < n; i += len(buf0) {
if len(buf0) > n-i {
buf0 = buf0[:n-i]
}
w.Write(buf0)
}
w.Close()
buf1 := compressed.Bytes()
buf0, compressed, w = nil, nil, nil
buf1 := getInputBuf(buf, n)
runtime.GC()
b.StartTimer()
for i := 0; i < b.N; i++ {
io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8))
}
})
b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
b.StopTimer()
b.SetBytes(int64(n))
buf1 := getInputBuf(buf, n)
runtime.GC()
b.StartTimer()
r := NewReader(bytes.NewReader(buf1), LSB, 8)
for i := 0; i < b.N; i++ {
io.Copy(io.Discard, r)
r.Close()
r.(*Reader).Reset(bytes.NewReader(buf1), LSB, 8)
}
})
}
}

View file

@ -17,19 +17,6 @@ type writer interface {
Flush() error
}
// An errWriteCloser is an io.WriteCloser that always returns a given error.
type errWriteCloser struct {
err error
}
func (e *errWriteCloser) Write([]byte) (int, error) {
return 0, e.err
}
func (e *errWriteCloser) Close() error {
return e.err
}
const (
// A code is a 12 bit value, stored as a uint32 when encoding to avoid
// type conversions when shifting bits.
@ -44,14 +31,15 @@ const (
invalidEntry = 0
)
// encoder is LZW compressor.
type encoder struct {
// Writer is an LZW compressor. It writes the compressed form of the data
// to an underlying writer (see NewWriter).
type Writer struct {
// w is the writer that compressed bytes are written to.
w writer
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
write func(*encoder, uint32) error
write func(*Writer, uint32) error
bits uint32
nBits uint
width uint
@ -63,7 +51,7 @@ type encoder struct {
// savedCode is the accumulated code at the end of the most recent Write
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
// err is the first error encountered during writing. Closing the encoder
// err is the first error encountered during writing. Closing the writer
// will make any future Write calls return errClosed
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
@ -74,80 +62,80 @@ type encoder struct {
}
// writeLSB writes the code c for "Least Significant Bits first" data.
func (e *encoder) writeLSB(c uint32) error {
e.bits |= c << e.nBits
e.nBits += e.width
for e.nBits >= 8 {
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
func (w *Writer) writeLSB(c uint32) error {
w.bits |= c << w.nBits
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
e.bits >>= 8
e.nBits -= 8
w.bits >>= 8
w.nBits -= 8
}
return nil
}
// writeMSB writes the code c for "Most Significant Bits first" data.
func (e *encoder) writeMSB(c uint32) error {
e.bits |= c << (32 - e.width - e.nBits)
e.nBits += e.width
for e.nBits >= 8 {
if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
func (w *Writer) writeMSB(c uint32) error {
w.bits |= c << (32 - w.width - w.nBits)
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
return err
}
e.bits <<= 8
e.nBits -= 8
w.bits <<= 8
w.nBits -= 8
}
return nil
}
// errOutOfCodes is an internal error that means that the encoder has run out
// errOutOfCodes is an internal error that means that the writer has run out
// of unused codes and a clear code needs to be sent next.
var errOutOfCodes = errors.New("lzw: out of codes")
// incHi increments e.hi and checks for both overflow and running out of
// unused codes. In the latter case, incHi sends a clear code, resets the
// encoder state and returns errOutOfCodes.
func (e *encoder) incHi() error {
e.hi++
if e.hi == e.overflow {
e.width++
e.overflow <<= 1
// writer state and returns errOutOfCodes.
func (w *Writer) incHi() error {
w.hi++
if w.hi == w.overflow {
w.width++
w.overflow <<= 1
}
if e.hi == maxCode {
clear := uint32(1) << e.litWidth
if err := e.write(e, clear); err != nil {
if w.hi == maxCode {
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return err
}
e.width = e.litWidth + 1
e.hi = clear + 1
e.overflow = clear << 1
for i := range e.table {
e.table[i] = invalidEntry
w.width = w.litWidth + 1
w.hi = clear + 1
w.overflow = clear << 1
for i := range w.table {
w.table[i] = invalidEntry
}
return errOutOfCodes
}
return nil
}
// Write writes a compressed representation of p to e's underlying writer.
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
// Write writes a compressed representation of p to w's underlying writer.
func (w *Writer) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
if len(p) == 0 {
return 0, nil
}
if maxLit := uint8(1<<e.litWidth - 1); maxLit != 0xff {
if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
for _, x := range p {
if x > maxLit {
e.err = errors.New("lzw: input byte too large for the litWidth")
return 0, e.err
w.err = errors.New("lzw: input byte too large for the litWidth")
return 0, w.err
}
}
}
n = len(p)
code := e.savedCode
code := w.savedCode
if code == invalidCode {
// The first code sent is always a literal code.
code, p = uint32(p[0]), p[1:]
@ -159,77 +147,84 @@ loop:
// If there is a hash table hit for this key then we continue the loop
// and do not emit a code yet.
hash := (key>>12 ^ key) & tableMask
for h, t := hash, e.table[hash]; t != invalidEntry; {
for h, t := hash, w.table[hash]; t != invalidEntry; {
if key == t>>12 {
code = t & maxCode
continue loop
}
h = (h + 1) & tableMask
t = e.table[h]
t = w.table[h]
}
// Otherwise, write the current code, and literal becomes the start of
// the next emitted code.
if e.err = e.write(e, code); e.err != nil {
return 0, e.err
if w.err = w.write(w, code); w.err != nil {
return 0, w.err
}
code = literal
// Increment e.hi, the next implied code. If we run out of codes, reset
// the encoder state (including clearing the hash table) and continue.
if err1 := e.incHi(); err1 != nil {
// the writer state (including clearing the hash table) and continue.
if err1 := w.incHi(); err1 != nil {
if err1 == errOutOfCodes {
continue
}
e.err = err1
return 0, e.err
w.err = err1
return 0, w.err
}
// Otherwise, insert key -> e.hi into the map that e.table represents.
for {
if e.table[hash] == invalidEntry {
e.table[hash] = (key << 12) | e.hi
if w.table[hash] == invalidEntry {
w.table[hash] = (key << 12) | w.hi
break
}
hash = (hash + 1) & tableMask
}
}
e.savedCode = code
w.savedCode = code
return n, nil
}
// Close closes the encoder, flushing any pending output. It does not close or
// flush e's underlying writer.
func (e *encoder) Close() error {
if e.err != nil {
if e.err == errClosed {
// Close closes the Writer, flushing any pending output. It does not close
// w's underlying writer.
func (w *Writer) Close() error {
if w.err != nil {
if w.err == errClosed {
return nil
}
return e.err
return w.err
}
// Make any future calls to Write return errClosed.
e.err = errClosed
w.err = errClosed
// Write the savedCode if valid.
if e.savedCode != invalidCode {
if err := e.write(e, e.savedCode); err != nil {
if w.savedCode != invalidCode {
if err := w.write(w, w.savedCode); err != nil {
return err
}
if err := e.incHi(); err != nil && err != errOutOfCodes {
if err := w.incHi(); err != nil && err != errOutOfCodes {
return err
}
}
// Write the eof code.
eof := uint32(1)<<e.litWidth + 1
if err := e.write(e, eof); err != nil {
eof := uint32(1)<<w.litWidth + 1
if err := w.write(w, eof); err != nil {
return err
}
// Write the final bits.
if e.nBits > 0 {
if e.order == MSB {
e.bits >>= 24
if w.nBits > 0 {
if w.order == MSB {
w.bits >>= 24
}
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
}
return e.w.Flush()
return w.w.Flush()
}
// Reset clears the Writer's state and allows it to be reused again
// as a new Writer.
func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
*w = Writer{}
w.init(dst, order, litWidth)
}
// NewWriter creates a new io.WriteCloser.
@ -238,32 +233,43 @@ func (e *encoder) Close() error {
// finished writing.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
//
// It is guaranteed that the underlying type of the returned io.WriteCloser
// is a *Writer.
func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
var write func(*encoder, uint32) error
return newWriter(w, order, litWidth)
}
func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
w := new(Writer)
w.init(dst, order, litWidth)
return w
}
func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
switch order {
case LSB:
write = (*encoder).writeLSB
w.write = (*Writer).writeLSB
case MSB:
write = (*encoder).writeMSB
w.write = (*Writer).writeMSB
default:
return &errWriteCloser{errors.New("lzw: unknown order")}
w.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
bw, ok := w.(writer)
if !ok {
bw = bufio.NewWriter(w)
bw, ok := dst.(writer)
if !ok && dst != nil {
bw = bufio.NewWriter(dst)
}
w.w = bw
lw := uint(litWidth)
return &encoder{
w: bw,
order: order,
write: write,
width: 1 + lw,
litWidth: lw,
hi: 1<<lw + 1,
overflow: 1 << (lw + 1),
savedCode: invalidCode,
}
w.order = order
w.width = 1 + lw
w.litWidth = lw
w.hi = 1<<lw + 1
w.overflow = 1 << (lw + 1)
w.savedCode = invalidCode
}

View file

@ -5,6 +5,7 @@
package lzw
import (
"bytes"
"fmt"
"internal/testenv"
"io"
@ -105,6 +106,50 @@ func TestWriter(t *testing.T) {
}
}
func TestWriterReset(t *testing.T) {
for _, order := range [...]Order{LSB, MSB} {
t.Run(fmt.Sprintf("Order %d", order), func(t *testing.T) {
for litWidth := 6; litWidth <= 8; litWidth++ {
t.Run(fmt.Sprintf("LitWidth %d", litWidth), func(t *testing.T) {
var data []byte
if litWidth == 6 {
data = []byte{1, 2, 3}
} else {
data = []byte(`lorem ipsum dolor sit amet`)
}
var buf bytes.Buffer
w := NewWriter(&buf, order, litWidth)
if _, err := w.Write(data); err != nil {
t.Errorf("write: %v: %v", string(data), err)
}
if err := w.Close(); err != nil {
t.Errorf("close: %v", err)
}
b1 := buf.Bytes()
buf.Reset()
w.(*Writer).Reset(&buf, order, litWidth)
if _, err := w.Write(data); err != nil {
t.Errorf("write: %v: %v", string(data), err)
}
if err := w.Close(); err != nil {
t.Errorf("close: %v", err)
}
b2 := buf.Bytes()
if !bytes.Equal(b1, b2) {
t.Errorf("bytes written were not same")
}
})
}
})
}
}
func TestWriterReturnValues(t *testing.T) {
w := NewWriter(io.Discard, LSB, 8)
n, err := w.Write([]byte("asdf"))
@ -152,5 +197,14 @@ func BenchmarkEncoder(b *testing.B) {
w.Close()
}
})
b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
b.SetBytes(int64(n))
w := NewWriter(io.Discard, LSB, 8)
for i := 0; i < b.N; i++ {
w.Write(buf1)
w.Close()
w.(*Writer).Reset(io.Discard, LSB, 8)
}
})
}
}