Make TextIOWrapper's seek/tell work properly with stateful decoders;

document and rename things to make seek/tell workings a little clearer.

Add a weird decoder for testing TextIOWrapper's seek/tell methods.

Document the getstate/setstate protocol conventions for IncrementalDecoders.
This commit is contained in:
Ka-Ping Yee 2008-03-18 04:51:32 +00:00
parent b5dc90b5fa
commit f44c7e8996
3 changed files with 367 additions and 109 deletions

View file

@ -237,7 +237,7 @@ class IncrementalDecoder(object):
"""
def __init__(self, errors='strict'):
"""
Creates a IncrementalDecoder instance.
Create an IncrementalDecoder instance.
The IncrementalDecoder may use different error handling schemes by
providing the errors keyword argument. See the module docstring
@ -247,28 +247,35 @@ def __init__(self, errors='strict'):
def decode(self, input, final=False):
"""
Decodes input and returns the resulting object.
Decode input and return the resulting object.
"""
raise NotImplementedError
def reset(self):
"""
Resets the decoder to the initial state.
Reset the decoder to the initial state.
"""
def getstate(self):
"""
Return the current state of the decoder. This must be a
(buffered_input, additional_state_info) tuple. By convention,
additional_state_info should represent the state of the decoder
WITHOUT yet having processed the contents of buffered_input.
Return the current state of the decoder.
This must be a (buffered_input, additional_state_info) tuple.
buffered_input must be a bytes object containing bytes that
were passed to decode() that have not yet been converted.
additional_state_info must be a non-negative integer
representing the state of the decoder WITHOUT yet having
processed the contents of buffered_input. In the initial state
and after reset(), getstate() must return (b"", 0).
"""
return (b"", 0)
def setstate(self, state):
"""
Set the current state of the decoder. state must have been
returned by getstate().
Set the current state of the decoder.
state must have been returned by getstate(). The effect of
setstate((b"", 0)) must be equivalent to reset().
"""
class BufferedIncrementalDecoder(IncrementalDecoder):

280
Lib/io.py
View file

@ -802,11 +802,10 @@ def peek(self, n=0):
return self._read_buf
def read1(self, n):
"""Reads up to n bytes.
"""Reads up to n bytes, with at most one read() system call.
Returns up to n bytes. If at least one byte is buffered,
we only return buffered bytes. Otherwise, we do one
raw read.
Returns up to n bytes. If at least one byte is buffered, we
only return buffered bytes. Otherwise, we do one raw read.
"""
if n <= 0:
return b""
@ -1180,10 +1179,24 @@ def __init__(self, buffer, encoding=None, errors=None, newline=None,
self._writenl = newline or os.linesep
self._encoder = None
self._decoder = None
self._pending = ""
self._snapshot = None
self._decoded_text = "" # buffer for text produced by decoder
self._snapshot = None # info for reconstructing decoder state
self._seekable = self._telling = self.buffer.seekable()
# A word about _snapshot. This attribute is either None, or a tuple
# (decoder_state, input_chunk, decoded_chars) where decoder_state is
# the second (integer) item of the decoder state, input_chunk is the
# chunk of bytes fed to the decoder from the snapshot point (any bytes
# the decoder already had buffered, followed by the bytes that were
# read), and decoded_chars is the number of characters rendered by the
# decoder after feeding it those bytes.
# We use this to reconstruct intermediate decoder states in tell().
# Naming convention:
# - integer variables ending in "_bytes" count input bytes
# - integer variables ending in "_chars" count decoded characters
def __repr__(self):
return '<TIOW %x>' % id(self)
@property
def encoding(self):
return self._encoding
@ -1196,13 +1209,6 @@ def errors(self):
def line_buffering(self):
return self._line_buffering
# A word about _snapshot. This attribute is either None, or a
# tuple (decoder_state, readahead, pending) where decoder_state is
# the second (integer) item of the decoder state, readahead is the
# chunk of bytes that was read, and pending is the characters that
# were rendered by the decoder after feeding it those bytes. We
# use this to reconstruct intermediate decoder states in tell().
def seekable(self):
return self._seekable
@ -1262,126 +1268,199 @@ def _get_decoder(self):
return decoder
def _read_chunk(self):
"""
Read and decode the next chunk of data from the BufferedReader.
Return a tuple of two elements: all the bytes that were read, and
the decoded string produced by the decoder. (The entire input
chunk is sent to the decoder, but some of it may remain buffered
in the decoder, yet to be converted.)
"""
if self._decoder is None:
raise ValueError("no decoder")
if not self._telling:
readahead = self.buffer.read1(self._CHUNK_SIZE)
pending = self._decoder.decode(readahead, not readahead)
return readahead, pending
decoder_buffer, decoder_state = self._decoder.getstate()
readahead = self.buffer.read1(self._CHUNK_SIZE)
pending = self._decoder.decode(readahead, not readahead)
self._snapshot = (decoder_state, decoder_buffer + readahead, pending)
return readahead, pending
# No one should call tell(), so don't bother taking a snapshot.
input_chunk = self.buffer.read1(self._CHUNK_SIZE)
eof = not input_chunk
decoded = self._decoder.decode(input_chunk, eof)
return (input_chunk, decoded)
def _encode_decoder_state(self, ds, pos):
x = 0
for i in bytes(ds):
x = x<<8 | i
return (x<<64) | pos
# The cookie returned by tell() cannot include the contents of
# the decoder's buffer, so we need to snapshot a point in the
# input where the decoder has nothing in its input buffer.
def _decode_decoder_state(self, pos):
x, pos = divmod(pos, 1<<64)
if not x:
return None, pos
b = b""
while x:
b.append(x&0xff)
x >>= 8
return str(b[::-1]), pos
dec_buffer, dec_flags = self._decoder.getstate()
# The state tuple returned by getstate() contains the decoder's
# input buffer and an integer representing any other state. Thus,
# there is a valid snapshot point len(dec_buffer) bytes ago in
# the input, with the state tuple (b'', dec_flags).
input_chunk = self.buffer.read1(self._CHUNK_SIZE)
eof = not input_chunk
decoded = self._decoder.decode(input_chunk, eof)
# At the snapshot point len(dec_buffer) bytes ago, the next input
# to be passed to the decoder is dec_buffer + input_chunk. Save
# len(decoded) so that later, tell() can figure out how much
# decoded data has been used up by TextIOWrapper.read().
self._snapshot = (dec_flags, dec_buffer + input_chunk, len(decoded))
return (input_chunk, decoded)
def _encode_tell_cookie(self, position, dec_flags=0,
feed_bytes=0, need_eof=0, skip_chars=0):
# The meaning of a tell() cookie is: seek to position, set the
# decoder flags to dec_flags, read feed_bytes bytes, feed them
# into the decoder with need_eof as the EOF flag, then skip
# skip_chars characters of the decoded result. For most simple
# decoders, this should often just be the position.
return (position | (dec_flags<<64) | (feed_bytes<<128) |
(skip_chars<<192) | bool(need_eof)<<256)
def _decode_tell_cookie(self, bigint):
rest, position = divmod(bigint, 1<<64)
rest, dec_flags = divmod(rest, 1<<64)
rest, feed_bytes = divmod(rest, 1<<64)
need_eof, skip_chars = divmod(rest, 1<<64)
return position, dec_flags, feed_bytes, need_eof, skip_chars
def tell(self):
if not self._seekable:
raise IOError("Underlying stream is not seekable")
raise IOError("underlying stream is not seekable")
if not self._telling:
raise IOError("Telling position disabled by next() call")
raise IOError("telling position disabled by next() call")
self.flush()
position = self.buffer.tell()
decoder = self._decoder
if decoder is None or self._snapshot is None:
if self._pending:
raise ValueError("pending data")
if self._decoded_text:
# This should never happen.
raise AssertionError("pending decoded text")
return position
decoder_state, readahead, pending = self._snapshot
position -= len(readahead)
needed = len(pending) - len(self._pending)
if not needed:
return self._encode_decoder_state(decoder_state, position)
# Skip backward to the snapshot point (see _read_chunk).
dec_flags, next_input, decoded_chars = self._snapshot
position -= len(next_input)
# How many decoded characters have been consumed since the snapshot?
skip_chars = decoded_chars - len(self._decoded_text)
if skip_chars == 0:
# We haven't moved from the snapshot point.
return self._encode_tell_cookie(position, dec_flags)
# Walk the decoder forward, one byte at a time, to find the minimum
# input necessary to give us the decoded characters we need to skip.
# As we go, look for the "safe point" nearest to the current location
# (i.e. a point where the decoder has nothing buffered, so we can
# safely start from there when trying to return to this location).
saved_state = decoder.getstate()
try:
decoder.setstate((b"", decoder_state))
n = 0
bb = bytearray(1)
for i, bb[0] in enumerate(readahead):
n += len(decoder.decode(bb))
if n >= needed:
decoder_buffer, decoder_state = decoder.getstate()
return self._encode_decoder_state(
decoder_state,
position + (i+1) - len(decoder_buffer) - (n - needed))
raise IOError("Can't reconstruct logical file position")
decoder.setstate((b"", dec_flags))
fed_bytes = 0
decoded_chars = 0
need_eof = 0
last_safe_point = (dec_flags, 0, 0)
next_byte = bytearray(1)
for next_byte[0] in next_input:
decoded = decoder.decode(next_byte)
fed_bytes += 1
decoded_chars += len(decoded)
dec_buffer, dec_flags = decoder.getstate()
if not dec_buffer and decoded_chars <= skip_chars:
# Decoder buffer is empty, so it's safe to start from here.
last_safe_point = (dec_flags, fed_bytes, decoded_chars)
if decoded_chars >= skip_chars:
break
else:
# We didn't get enough decoded data; send EOF to get more.
decoded = decoder.decode(b"", True)
decoded_chars += len(decoded)
need_eof = 1
if decoded_chars < skip_chars:
raise IOError("can't reconstruct logical file position")
# Advance the starting position to the last safe point.
dec_flags, safe_fed_bytes, safe_decoded_chars = last_safe_point
position += safe_fed_bytes
fed_bytes -= safe_fed_bytes
skip_chars -= safe_decoded_chars
return self._encode_tell_cookie(
position, dec_flags, fed_bytes, need_eof, skip_chars)
finally:
decoder.setstate(saved_state)
def seek(self, pos, whence=0):
def seek(self, cookie, whence=0):
if not self._seekable:
raise IOError("Underlying stream is not seekable")
if whence == 1:
if pos != 0:
raise IOError("Can't do nonzero cur-relative seeks")
pos = self.tell()
raise IOError("underlying stream is not seekable")
if whence == 1: # seek relative to current position
if cookie != 0:
raise IOError("can't do nonzero cur-relative seeks")
# Seeking to the current position should attempt to
# sync the underlying buffer with the current position.
whence = 0
if whence == 2:
if pos != 0:
raise IOError("Can't do nonzero end-relative seeks")
cookie = self.tell()
if whence == 2: # seek relative to end of file
if cookie != 0:
raise IOError("can't do nonzero end-relative seeks")
self.flush()
pos = self.buffer.seek(0, 2)
position = self.buffer.seek(0, 2)
self._decoded_text = ""
self._snapshot = None
self._pending = ""
if self._decoder:
self._decoder.reset()
return pos
return position
if whence != 0:
raise ValueError("Invalid whence (%r, should be 0, 1 or 2)" %
raise ValueError("invalid whence (%r, should be 0, 1 or 2)" %
(whence,))
if pos < 0:
raise ValueError("Negative seek position %r" % (pos,))
if cookie < 0:
raise ValueError("negative seek position %r" % (cookie,))
self.flush()
orig_pos = pos
ds, pos = self._decode_decoder_state(pos)
if not ds:
self.buffer.seek(pos)
self._snapshot = None
self._pending = ""
if self._decoder:
self._decoder.reset()
return pos
decoder = self._decoder or self._get_decoder()
decoder.set_state(("", ds))
self.buffer.seek(pos)
self._snapshot = (ds, b"", "")
self._pending = ""
self._decoder = decoder
return orig_pos
# Seek back to the snapshot point.
position, dec_flags, feed_bytes, need_eof, skip_chars = \
self._decode_tell_cookie(cookie)
self.buffer.seek(position)
self._decoded_text = ""
self._snapshot = None
if self._decoder or dec_flags or feed_bytes or need_eof:
# Restore the decoder flags to their values from the snapshot.
self._decoder = self._decoder or self._get_decoder()
self._decoder.setstate((b"", dec_flags))
if feed_bytes or need_eof:
# Feed feed_bytes bytes to the decoder.
input_chunk = self.buffer.read(feed_bytes)
decoded = self._decoder.decode(input_chunk, need_eof)
if len(decoded) < skip_chars:
raise IOError("can't restore logical file position")
# Skip skip_chars of the decoded characters.
self._decoded_text = decoded[skip_chars:]
# Restore the snapshot.
self._snapshot = (dec_flags, input_chunk, len(decoded))
return cookie
def read(self, n=None):
if n is None:
n = -1
decoder = self._decoder or self._get_decoder()
res = self._pending
result = self._decoded_text
if n < 0:
res += decoder.decode(self.buffer.read(), True)
self._pending = ""
result += decoder.decode(self.buffer.read(), True)
self._decoded_text = ""
self._snapshot = None
return res
return result
else:
while len(res) < n:
readahead, pending = self._read_chunk()
res += pending
if not readahead:
while len(result) < n:
input_chunk, decoded = self._read_chunk()
result += decoded
if not input_chunk:
break
self._pending = res[n:]
return res[:n]
self._decoded_text = result[n:]
return result[:n]
def __next__(self):
self._telling = False
@ -1400,10 +1479,11 @@ def readline(self, limit=None):
line = self.readline()
if len(line) <= limit:
return line
line, self._pending = line[:limit], line[limit:] + self._pending
line, self._decoded_text = \
line[:limit], line[limit:] + self._decoded_text
return line
line = self._pending
line = self._decoded_text
start = 0
decoder = self._decoder or self._get_decoder()
@ -1467,11 +1547,11 @@ def readline(self, limit=None):
line += more_line
else:
# end of file
self._pending = ''
self._decoded_text = ''
self._snapshot = None
return line
self._pending = line[endpos:]
self._decoded_text = line[endpos:]
return line[:endpos]
@property

View file

@ -8,6 +8,7 @@
from itertools import chain
from test import test_support
import codecs
import io # The module under test
@ -486,6 +487,122 @@ def testSeekAndTell(self):
self.assertEquals(b"fl", rw.read(11))
self.assertRaises(TypeError, rw.seek, 0.0)
# To fully exercise seek/tell, the StatefulIncrementalDecoder has these
# properties:
# - A single output character can correspond to many bytes of input.
# - The number of input bytes to complete the character can be
# undetermined until the last input byte is received.
# - The number of input bytes can vary depending on previous input.
# - A single input byte can correspond to many characters of output.
# - The number of output characters can be undetermined until the
# last input byte is received.
# - The number of output characters can vary depending on previous input.
class StatefulIncrementalDecoder(codecs.IncrementalDecoder):
    """
    For testing seek/tell behavior with a stateful, buffering decoder.

    Input is a sequence of words.  Words may be fixed-length (length set
    by input) or variable-length (period-terminated).  In variable-length
    mode, extra periods are ignored.  Possible words are:
      - 'i' followed by a number sets the input length, I (maximum 99).
        When I is set to 0, words are period-terminated.
      - 'o' followed by a number sets the output length, O (maximum 99).
      - Any other word is converted into a word followed by a period on
        the output.  The output word consists of the input word truncated
        or padded out with hyphens to make its length equal to O.  If O
        is 0, the word is output verbatim without truncating or padding.
    I and O are initially set to 1.  A new value of I or O applies to the
    words that follow it.  EOF also terminates the last word.
    """

    def __init__(self, errors='strict'):
        # Initialise the decoder base class (not IncrementalEncoder).
        codecs.IncrementalDecoder.__init__(self, errors)
        self.reset()

    def __repr__(self):
        return '<SID %x>' % id(self)

    def reset(self):
        """Return to the initial state: I = O = 1 and nothing buffered."""
        self.i = 1
        self.o = 1
        self.buffer = bytearray()

    def getstate(self):
        """Return (buffered_bytes, flags) with flags == 0 after reset()."""
        i, o = self.i ^ 1, self.o ^ 1  # so that flags = 0 after reset()
        return bytes(self.buffer), i*100 + o

    def setstate(self, state):
        """Restore a state tuple previously produced by getstate()."""
        buffer, io = state
        self.buffer = bytearray(buffer)
        i, o = divmod(io, 100)
        self.i, self.o = i ^ 1, o ^ 1

    def decode(self, input, final=False):
        """Feed bytes to the decoder and return the text produced so far.

        If final is true, any word still in the buffer is terminated.
        """
        output = ''
        for b in input:
            if self.i == 0:  # variable-length, terminated with period
                if b == ord('.'):
                    if self.buffer:
                        output += self.process_word()
                else:
                    self.buffer.append(b)
            else:  # fixed-length, terminate after self.i bytes
                self.buffer.append(b)
                if len(self.buffer) == self.i:
                    output += self.process_word()
        if final and self.buffer:  # EOF terminates the last word
            output += self.process_word()
        return output

    def process_word(self):
        """Consume the buffered word, returning any text it produces.

        'i'/'o' words update the input/output length and produce no
        text; any other word is emitted, padded or truncated to O.
        """
        output = ''
        if self.buffer[0] == ord('i'):
            self.i = min(99, int(self.buffer[1:] or 0))  # set input length
        elif self.buffer[0] == ord('o'):
            self.o = min(99, int(self.buffer[1:] or 0))  # set output length
        else:
            output = self.buffer.decode('ascii')
            if len(output) < self.o:
                output += '-'*self.o  # pad out with hyphens
            if self.o:
                output = output[:self.o]  # truncate to output length
            output += '.'
        self.buffer = bytearray()
        return output
class StatefulIncrementalDecoderTest(unittest.TestCase):
    """
    Make sure the StatefulIncrementalDecoder actually works.
    """

    # Each entry is (input bytes, final flag passed to decode(), expected text).
    test_cases = [
        # I=1 fixed-length mode
        (b'abcd', False, 'a.b.c.d.'),
        # I=0, O=0, variable-length mode
        (b'oiabcd', True, 'abcd.'),
        # I=0, O=0, variable-length mode, should ignore extra periods
        (b'oi...abcd...', True, 'abcd.'),
        # I=0, O=6
        (b'i.o6.xyz.', False, 'xyz---.'),
        # I=2, O=6
        (b'i.i2.o6xyz', True, 'xy----.z-----.'),
        # I=0, O=3
        (b'i.o3.x.xyz.toolong.', False, 'x--.xyz.too.'),
        # I=6, O=3
        (b'i.o3.i6.abcdefghijklmnop', True, 'abc.ghi.mno.')
    ]

    def testDecoder(self):
        # Try a few one-shot test cases.
        for input, eof, output in self.test_cases:
            d = StatefulIncrementalDecoder()
            self.assertEqual(d.decode(input, eof), output)

        # Also test an unfinished decode, followed by forcing EOF.
        d = StatefulIncrementalDecoder()
        self.assertEqual(d.decode(b'oiabcd'), '')
        self.assertEqual(d.decode(b'', True), 'abcd.')
class TextIOWrapperTest(unittest.TestCase):
@ -765,6 +882,60 @@ def testSeekingToo(self):
f.readline()
f.tell()
def testSeekAndTell(self):
    """Test seek/tell using the StatefulIncrementalDecoder."""

    def lookupTestDecoder(name):
        # Codec search function: answers only for 'test_decoder', and
        # only while self.codecEnabled is set by this test.
        if self.codecEnabled and name == 'test_decoder':
            return codecs.CodecInfo(
                name='test_decoder', encode=None, decode=None,
                incrementalencoder=None,
                streamreader=None, streamwriter=None,
                incrementaldecoder=StatefulIncrementalDecoder)
        return None

    def testSeekAndTellWithData(data, min_pos=0):
        """Tell/seek to various points within a data stream and ensure
        that the decoded data returned by read() is consistent."""
        # Context managers make sure the file is closed even when an
        # assertion below fails.
        with io.open(test_support.TESTFN, 'wb') as f:
            f.write(data)
        with io.open(test_support.TESTFN, encoding='test_decoder') as f:
            decoded = f.read()

        for i in range(min_pos, len(decoded) + 1):  # seek positions
            for j in [1, 5, len(decoded) - i]:  # read lengths
                with io.open(test_support.TESTFN,
                             encoding='test_decoder') as f:
                    self.assertEqual(f.read(i), decoded[:i])
                    cookie = f.tell()
                    self.assertEqual(f.read(j), decoded[i:i + j])
                    f.seek(cookie)
                    self.assertEqual(f.read(), decoded[i:])

    # Register a special incremental decoder for testing.
    codecs.register(lookupTestDecoder)
    self.codecEnabled = 1

    # Run the tests.
    try:
        # Try each test case.
        for input, _, _ in StatefulIncrementalDecoderTest.test_cases:
            testSeekAndTellWithData(input)

        # Position each test case so that it crosses a chunk boundary.
        CHUNK_SIZE = io.TextIOWrapper._CHUNK_SIZE
        for input, _, _ in StatefulIncrementalDecoderTest.test_cases:
            offset = CHUNK_SIZE - len(input)//2
            prefix = b'.'*offset
            # Don't bother seeking into the prefix (takes too long).
            min_pos = offset*2
            testSeekAndTellWithData(prefix + input, min_pos)
    finally:
        # Ensure our test decoder won't interfere with subsequent tests.
        self.codecEnabled = 0
def testEncodedWrites(self):
data = "1234567890"
tests = ("utf-16",