Make TextIOWrapper's seek/tell work properly with stateful decoders;

document and rename things to make seek/tell workings a little clearer.

Add a weird decoder for testing TextIOWrapper's seek/tell methods.

Document the getstate/setstate protocol conventions for IncrementalDecoders.
This commit is contained in:
Ka-Ping Yee 2008-03-18 04:51:32 +00:00
parent b5dc90b5fa
commit f44c7e8996
3 changed files with 367 additions and 109 deletions

View file

@ -237,7 +237,7 @@ class IncrementalDecoder(object):
""" """
def __init__(self, errors='strict'): def __init__(self, errors='strict'):
""" """
Creates a IncrementalDecoder instance. Create a IncrementalDecoder instance.
The IncrementalDecoder may use different error handling schemes by The IncrementalDecoder may use different error handling schemes by
providing the errors keyword argument. See the module docstring providing the errors keyword argument. See the module docstring
@ -247,28 +247,35 @@ def __init__(self, errors='strict'):
def decode(self, input, final=False): def decode(self, input, final=False):
""" """
Decodes input and returns the resulting object. Decode input and returns the resulting object.
""" """
raise NotImplementedError raise NotImplementedError
def reset(self): def reset(self):
""" """
Resets the decoder to the initial state. Reset the decoder to the initial state.
""" """
def getstate(self): def getstate(self):
""" """
Return the current state of the decoder. This must be a Return the current state of the decoder.
(buffered_input, additional_state_info) tuple. By convention,
additional_state_info should represent the state of the decoder This must be a (buffered_input, additional_state_info) tuple.
WITHOUT yet having processed the contents of buffered_input. buffered_input must be a bytes object containing bytes that
were passed to decode() that have not yet been converted.
additional_state_info must be a non-negative integer
representing the state of the decoder WITHOUT yet having
processed the contents of buffered_input. In the initial state
and after reset(), getstate() must return (b"", 0).
""" """
return (b"", 0) return (b"", 0)
def setstate(self, state): def setstate(self, state):
""" """
Set the current state of the decoder. state must have been Set the current state of the decoder.
returned by getstate().
state must have been returned by getstate(). The effect of
setstate((b"", 0)) must be equivalent to reset().
""" """
class BufferedIncrementalDecoder(IncrementalDecoder): class BufferedIncrementalDecoder(IncrementalDecoder):

278
Lib/io.py
View file

@ -802,11 +802,10 @@ def peek(self, n=0):
return self._read_buf return self._read_buf
def read1(self, n): def read1(self, n):
"""Reads up to n bytes. """Reads up to n bytes, with at most one read() system call.
Returns up to n bytes. If at least one byte is buffered, Returns up to n bytes. If at least one byte is buffered, we
we only return buffered bytes. Otherwise, we do one only return buffered bytes. Otherwise, we do one raw read.
raw read.
""" """
if n <= 0: if n <= 0:
return b"" return b""
@ -1180,10 +1179,24 @@ def __init__(self, buffer, encoding=None, errors=None, newline=None,
self._writenl = newline or os.linesep self._writenl = newline or os.linesep
self._encoder = None self._encoder = None
self._decoder = None self._decoder = None
self._pending = "" self._decoded_text = "" # buffer for text produced by decoder
self._snapshot = None self._snapshot = None # info for reconstructing decoder state
self._seekable = self._telling = self.buffer.seekable() self._seekable = self._telling = self.buffer.seekable()
# A word about _snapshot. This attribute is either None, or a tuple
# (decoder_state, input_chunk, decoded_chars) where decoder_state is
# the second (integer) item of the decoder state, input_chunk is the
# chunk of bytes that was read, and decoded_chars is the number of
# characters rendered by the decoder after feeding it those bytes.
# We use this to reconstruct intermediate decoder states in tell().
# Naming convention:
# - integer variables ending in "_bytes" count input bytes
# - integer variables ending in "_chars" count decoded characters
def __repr__(self):
return '<TIOW %x>' % id(self)
@property @property
def encoding(self): def encoding(self):
return self._encoding return self._encoding
@ -1196,13 +1209,6 @@ def errors(self):
def line_buffering(self): def line_buffering(self):
return self._line_buffering return self._line_buffering
# A word about _snapshot. This attribute is either None, or a
# tuple (decoder_state, readahead, pending) where decoder_state is
# the second (integer) item of the decoder state, readahead is the
# chunk of bytes that was read, and pending is the characters that
# were rendered by the decoder after feeding it those bytes. We
# use this to reconstruct intermediate decoder states in tell().
def seekable(self): def seekable(self):
return self._seekable return self._seekable
@ -1262,126 +1268,199 @@ def _get_decoder(self):
return decoder return decoder
def _read_chunk(self): def _read_chunk(self):
"""
Read and decode the next chunk of data from the BufferedReader.
Return a tuple of two elements: all the bytes that were read, and
the decoded string produced by the decoder. (The entire input
chunk is sent to the decoder, but some of it may remain buffered
in the decoder, yet to be converted.)
"""
if self._decoder is None: if self._decoder is None:
raise ValueError("no decoder") raise ValueError("no decoder")
if not self._telling: if not self._telling:
readahead = self.buffer.read1(self._CHUNK_SIZE) # No one should call tell(), so don't bother taking a snapshot.
pending = self._decoder.decode(readahead, not readahead) input_chunk = self.buffer.read1(self._CHUNK_SIZE)
return readahead, pending eof = not input_chunk
decoder_buffer, decoder_state = self._decoder.getstate() decoded = self._decoder.decode(input_chunk, eof)
readahead = self.buffer.read1(self._CHUNK_SIZE) return (input_chunk, decoded)
pending = self._decoder.decode(readahead, not readahead)
self._snapshot = (decoder_state, decoder_buffer + readahead, pending)
return readahead, pending
def _encode_decoder_state(self, ds, pos): # The cookie returned by tell() cannot include the contents of
x = 0 # the decoder's buffer, so we need to snapshot a point in the
for i in bytes(ds): # input where the decoder has nothing in its input buffer.
x = x<<8 | i
return (x<<64) | pos
def _decode_decoder_state(self, pos): dec_buffer, dec_flags = self._decoder.getstate()
x, pos = divmod(pos, 1<<64) # The state tuple returned by getstate() contains the decoder's
if not x: # input buffer and an integer representing any other state. Thus,
return None, pos # there is a valid snapshot point len(decoder_buffer) bytes ago in
b = b"" # the input, with the state tuple (b'', decoder_state).
while x:
b.append(x&0xff) input_chunk = self.buffer.read1(self._CHUNK_SIZE)
x >>= 8 eof = not input_chunk
return str(b[::-1]), pos decoded = self._decoder.decode(input_chunk, eof)
# At the snapshot point len(dec_buffer) bytes ago, the next input
# to be passed to the decoder is dec_buffer + input_chunk. Save
# len(decoded) so that later, tell() can figure out how much
# decoded data has been used up by TextIOWrapper.read().
self._snapshot = (dec_flags, dec_buffer + input_chunk, len(decoded))
return (input_chunk, decoded)
def _encode_tell_cookie(self, position, dec_flags=0,
                        feed_bytes=0, need_eof=0, skip_chars=0):
    """Pack the components of a tell() position into one integer cookie.

    A cookie means: seek to position, set the decoder flags to
    dec_flags, read feed_bytes bytes and feed them to the decoder with
    need_eof as the EOF flag, then skip skip_chars characters of the
    decoded result.  When every optional field is zero the cookie is
    simply the position.

    Layout, lowest bits first: position at bit 0, dec_flags at bit 64,
    feed_bytes at bit 128, skip_chars at bit 192, and need_eof (reduced
    to a bool) at bit 256.  No range checking is done on the fields.
    """
    cookie = position
    cookie |= dec_flags << 64
    cookie |= feed_bytes << 128
    cookie |= skip_chars << 192
    cookie |= bool(need_eof) << 256
    return cookie
def _decode_tell_cookie(self, bigint):
    """Split a tell() cookie back into its five components.

    Returns the tuple
    (position, dec_flags, feed_bytes, need_eof, skip_chars).
    The first, second, third and fifth items are successive 64-bit
    fields taken from the low end of bigint; need_eof is whatever
    remains above bit 256.
    """
    field_mask = (1 << 64) - 1
    position = bigint & field_mask
    dec_flags = (bigint >> 64) & field_mask
    feed_bytes = (bigint >> 128) & field_mask
    skip_chars = (bigint >> 192) & field_mask
    need_eof = bigint >> 256
    return position, dec_flags, feed_bytes, need_eof, skip_chars
def tell(self): def tell(self):
if not self._seekable: if not self._seekable:
raise IOError("Underlying stream is not seekable") raise IOError("underlying stream is not seekable")
if not self._telling: if not self._telling:
raise IOError("Telling position disabled by next() call") raise IOError("telling position disabled by next() call")
self.flush() self.flush()
position = self.buffer.tell() position = self.buffer.tell()
decoder = self._decoder decoder = self._decoder
if decoder is None or self._snapshot is None: if decoder is None or self._snapshot is None:
if self._pending: if self._decoded_text:
raise ValueError("pending data") # This should never happen.
raise AssertionError("pending decoded text")
return position return position
decoder_state, readahead, pending = self._snapshot
position -= len(readahead) # Skip backward to the snapshot point (see _read_chunk).
needed = len(pending) - len(self._pending) dec_flags, next_input, decoded_chars = self._snapshot
if not needed: position -= len(next_input)
return self._encode_decoder_state(decoder_state, position)
# How many decoded characters have been consumed since the snapshot?
skip_chars = decoded_chars - len(self._decoded_text)
if skip_chars == 0:
# We haven't moved from the snapshot point.
return self._encode_tell_cookie(position, dec_flags)
# Walk the decoder forward, one byte at a time, to find the minimum
# input necessary to give us the decoded characters we need to skip.
# As we go, look for the "safe point" nearest to the current location
# (i.e. a point where the decoder has nothing buffered, so we can
# safely start from there when trying to return to this location).
saved_state = decoder.getstate() saved_state = decoder.getstate()
try: try:
decoder.setstate((b"", decoder_state)) decoder.setstate((b"", dec_flags))
n = 0 fed_bytes = 0
bb = bytearray(1) decoded_chars = 0
for i, bb[0] in enumerate(readahead): need_eof = 0
n += len(decoder.decode(bb)) last_safe_point = (dec_flags, 0, 0)
if n >= needed:
decoder_buffer, decoder_state = decoder.getstate() next_byte = bytearray(1)
return self._encode_decoder_state( for next_byte[0] in next_input:
decoder_state, decoded = decoder.decode(next_byte)
position + (i+1) - len(decoder_buffer) - (n - needed)) fed_bytes += 1
raise IOError("Can't reconstruct logical file position") decoded_chars += len(decoded)
dec_buffer, dec_flags = decoder.getstate()
if not dec_buffer and decoded_chars <= skip_chars:
# Decoder buffer is empty, so it's safe to start from here.
last_safe_point = (dec_flags, fed_bytes, decoded_chars)
if decoded_chars >= skip_chars:
break
else:
# We didn't get enough decoded data; send EOF to get more.
decoded = decoder.decode(b"", True)
decoded_chars += len(decoded)
need_eof = 1
if decoded_chars < skip_chars:
raise IOError("can't reconstruct logical file position")
# Advance the starting position to the last safe point.
dec_flags, safe_fed_bytes, safe_decoded_chars = last_safe_point
position += safe_fed_bytes
fed_bytes -= safe_fed_bytes
skip_chars -= safe_decoded_chars
return self._encode_tell_cookie(
position, dec_flags, fed_bytes, need_eof, skip_chars)
finally: finally:
decoder.setstate(saved_state) decoder.setstate(saved_state)
def seek(self, pos, whence=0): def seek(self, cookie, whence=0):
if not self._seekable: if not self._seekable:
raise IOError("Underlying stream is not seekable") raise IOError("underlying stream is not seekable")
if whence == 1: if whence == 1: # seek relative to current position
if pos != 0: if cookie != 0:
raise IOError("Can't do nonzero cur-relative seeks") raise IOError("can't do nonzero cur-relative seeks")
pos = self.tell() # Seeking to the current position should attempt to
# sync the underlying buffer with the current position.
whence = 0 whence = 0
if whence == 2: cookie = self.tell()
if pos != 0: if whence == 2: # seek relative to end of file
raise IOError("Can't do nonzero end-relative seeks") if cookie != 0:
raise IOError("can't do nonzero end-relative seeks")
self.flush() self.flush()
pos = self.buffer.seek(0, 2) position = self.buffer.seek(0, 2)
self._decoded_text = ""
self._snapshot = None self._snapshot = None
self._pending = ""
if self._decoder: if self._decoder:
self._decoder.reset() self._decoder.reset()
return pos return position
if whence != 0: if whence != 0:
raise ValueError("Invalid whence (%r, should be 0, 1 or 2)" % raise ValueError("invalid whence (%r, should be 0, 1 or 2)" %
(whence,)) (whence,))
if pos < 0: if cookie < 0:
raise ValueError("Negative seek position %r" % (pos,)) raise ValueError("negative seek position %r" % (cookie,))
self.flush() self.flush()
orig_pos = pos
ds, pos = self._decode_decoder_state(pos) # Seek back to the snapshot point.
if not ds: position, dec_flags, feed_bytes, need_eof, skip_chars = \
self.buffer.seek(pos) self._decode_tell_cookie(cookie)
self.buffer.seek(position)
self._decoded_text = ""
self._snapshot = None self._snapshot = None
self._pending = ""
if self._decoder: if self._decoder or dec_flags or feed_bytes or need_eof:
self._decoder.reset() # Restore the decoder flags to their values from the snapshot.
return pos self._decoder = self._decoder or self._get_decoder()
decoder = self._decoder or self._get_decoder() self._decoder.setstate((b"", dec_flags))
decoder.set_state(("", ds))
self.buffer.seek(pos) if feed_bytes or need_eof:
self._snapshot = (ds, b"", "") # Feed feed_bytes bytes to the decoder.
self._pending = "" input_chunk = self.buffer.read(feed_bytes)
self._decoder = decoder decoded = self._decoder.decode(input_chunk, need_eof)
return orig_pos if len(decoded) < skip_chars:
raise IOError("can't restore logical file position")
# Skip skip_chars of the decoded characters.
self._decoded_text = decoded[skip_chars:]
# Restore the snapshot.
self._snapshot = (dec_flags, input_chunk, len(decoded))
return cookie
def read(self, n=None): def read(self, n=None):
if n is None: if n is None:
n = -1 n = -1
decoder = self._decoder or self._get_decoder() decoder = self._decoder or self._get_decoder()
res = self._pending result = self._decoded_text
if n < 0: if n < 0:
res += decoder.decode(self.buffer.read(), True) result += decoder.decode(self.buffer.read(), True)
self._pending = "" self._decoded_text = ""
self._snapshot = None self._snapshot = None
return res return result
else: else:
while len(res) < n: while len(result) < n:
readahead, pending = self._read_chunk() input_chunk, decoded = self._read_chunk()
res += pending result += decoded
if not readahead: if not input_chunk:
break break
self._pending = res[n:] self._decoded_text = result[n:]
return res[:n] return result[:n]
def __next__(self): def __next__(self):
self._telling = False self._telling = False
@ -1400,10 +1479,11 @@ def readline(self, limit=None):
line = self.readline() line = self.readline()
if len(line) <= limit: if len(line) <= limit:
return line return line
line, self._pending = line[:limit], line[limit:] + self._pending line, self._decoded_text = \
line[:limit], line[limit:] + self._decoded_text
return line return line
line = self._pending line = self._decoded_text
start = 0 start = 0
decoder = self._decoder or self._get_decoder() decoder = self._decoder or self._get_decoder()
@ -1467,11 +1547,11 @@ def readline(self, limit=None):
line += more_line line += more_line
else: else:
# end of file # end of file
self._pending = '' self._decoded_text = ''
self._snapshot = None self._snapshot = None
return line return line
self._pending = line[endpos:] self._decoded_text = line[endpos:]
return line[:endpos] return line[:endpos]
@property @property

View file

@ -8,6 +8,7 @@
from itertools import chain from itertools import chain
from test import test_support from test import test_support
import codecs
import io # The module under test import io # The module under test
@ -486,6 +487,122 @@ def testSeekAndTell(self):
self.assertEquals(b"fl", rw.read(11)) self.assertEquals(b"fl", rw.read(11))
self.assertRaises(TypeError, rw.seek, 0.0) self.assertRaises(TypeError, rw.seek, 0.0)
# To fully exercise seek/tell, the StatefulIncrementalDecoder has these
# properties:
# - A single output character can correspond to many bytes of input.
# - The number of input bytes to complete the character can be
# undetermined until the last input byte is received.
# - The number of input bytes can vary depending on previous input.
# - A single input byte can correspond to many characters of output.
# - The number of output characters can be undetermined until the
# last input byte is received.
# - The number of output characters can vary depending on previous input.
class StatefulIncrementalDecoder(codecs.IncrementalDecoder):
    """
    For testing seek/tell behavior with a stateful, buffering decoder.

    Input is a sequence of words.  Words may be fixed-length (length set
    by input) or variable-length (period-terminated).  In variable-length
    mode, extra periods are ignored.  Possible words are:
      - 'i' followed by a number sets the input length, I (maximum 99).
        When I is set to 0, words are period-terminated.
      - 'o' followed by a number sets the output length, O (maximum 99).
      - Any other word is converted into a word followed by a period on
        the output.  The output word consists of the input word truncated
        or padded out with hyphens to make its length equal to O.  If O
        is 0, the word is output verbatim without truncating or padding.
    I and O are initially set to 1.  A new I applies to the words that
    follow it (the buffer is always empty once a word has been
    processed).  EOF also terminates the last word.
    """

    def __init__(self, errors='strict'):
        # Initialise the decoder base class (not the encoder one).
        codecs.IncrementalDecoder.__init__(self, errors)
        self.reset()

    def __repr__(self):
        return '<SID %x>' % id(self)

    def reset(self):
        """Return to the initial state: I = O = 1 and nothing buffered."""
        self.i = 1
        self.o = 1
        self.buffer = bytearray()

    def getstate(self):
        """Return (buffered_input, flags) with flags == 0 after reset().

        I and O are each XORed with 1 so that their initial value of 1
        maps to 0, then packed as i*100 + o.
        """
        i, o = self.i ^ 1, self.o ^ 1  # so that flags = 0 after reset()
        return bytes(self.buffer), i*100 + o

    def setstate(self, state):
        """Restore a state previously returned by getstate()."""
        buffer, io = state
        self.buffer = bytearray(buffer)
        i, o = divmod(io, 100)
        self.i, self.o = i ^ 1, o ^ 1

    def decode(self, input, final=False):
        """Feed the bytes of input to the decoder; return decoded text.

        If final is true, any word still in the buffer is emitted.
        """
        pieces = []
        for b in input:
            if self.i == 0:  # variable-length, terminated with period
                if b == ord('.'):
                    if self.buffer:
                        pieces.append(self.process_word())
                else:
                    self.buffer.append(b)
            else:  # fixed-length, terminate after self.i bytes
                self.buffer.append(b)
                if len(self.buffer) == self.i:
                    pieces.append(self.process_word())
        if final and self.buffer:  # EOF terminates the last word
            pieces.append(self.process_word())
        return ''.join(pieces)

    def process_word(self):
        """Consume the buffered word and return the text it produces.

        'i<n>' and 'o<n>' words update I and O and produce no output;
        int() raises ValueError if <n> is not a number.  The buffer is
        emptied in every case.
        """
        output = ''
        if self.buffer[0] == ord('i'):
            self.i = min(99, int(self.buffer[1:] or 0))  # set input length
        elif self.buffer[0] == ord('o'):
            self.o = min(99, int(self.buffer[1:] or 0))  # set output length
        else:
            output = self.buffer.decode('ascii')
            if self.o:
                # Pad with hyphens, then truncate, to exactly O characters.
                output = output.ljust(self.o, '-')[:self.o]
            output += '.'
        self.buffer = bytearray()
        return output
class StatefulIncrementalDecoderTest(unittest.TestCase):
    """
    Make sure the StatefulIncrementalDecoder actually works.
    """

    # Each entry is (input bytes, final flag passed to decode(),
    # expected decoded text).
    test_cases = [
        # I=1 fixed-length mode
        (b'abcd', False, 'a.b.c.d.'),
        # I=0, O=0, variable-length mode
        (b'oiabcd', True, 'abcd.'),
        # I=0, O=0, variable-length mode, should ignore extra periods
        (b'oi...abcd...', True, 'abcd.'),
        # I=0, O=6
        (b'i.o6.xyz.', False, 'xyz---.'),
        # I=2, O=6
        (b'i.i2.o6xyz', True, 'xy----.z-----.'),
        # I=0, O=3
        (b'i.o3.x.xyz.toolong.', False, 'x--.xyz.too.'),
        # I=6, O=3
        (b'i.o3.i6.abcdefghijklmnop', True, 'abc.ghi.mno.')
    ]

    def testDecoder(self):
        # Try a few one-shot test cases.
        for data, eof, output in self.test_cases:
            d = StatefulIncrementalDecoder()
            self.assertEqual(d.decode(data, eof), output)

        # Also test an unfinished decode, followed by forcing EOF.
        d = StatefulIncrementalDecoder()
        self.assertEqual(d.decode(b'oiabcd'), '')
        self.assertEqual(d.decode(b'', True), 'abcd.')
class TextIOWrapperTest(unittest.TestCase): class TextIOWrapperTest(unittest.TestCase):
@ -765,6 +882,60 @@ def testSeekingToo(self):
f.readline() f.readline()
f.tell() f.tell()
def testSeekAndTell(self):
    """Test seek/tell using the StatefulIncrementalDecoder."""

    def lookupTestDecoder(name):
        # Codec search function: only answers for 'test_decoder', and
        # only while self.codecEnabled is set.  Returns None otherwise.
        if self.codecEnabled and name == 'test_decoder':
            return codecs.CodecInfo(
                name='test_decoder', encode=None, decode=None,
                incrementalencoder=None,
                streamreader=None, streamwriter=None,
                incrementaldecoder=StatefulIncrementalDecoder)
        return None

    def testSeekAndTellWithData(data, min_pos=0):
        """Tell/seek to various points within a data stream and ensure
        that the decoded data returned by read() is consistent."""
        with io.open(test_support.TESTFN, 'wb') as f:
            f.write(data)
        with io.open(test_support.TESTFN, encoding='test_decoder') as f:
            decoded = f.read()

        for i in range(min_pos, len(decoded) + 1):  # seek positions
            for j in [1, 5, len(decoded) - i]:  # read lengths
                # The with-block closes the file even when an
                # assertion fails part-way through.
                with io.open(test_support.TESTFN,
                             encoding='test_decoder') as f:
                    self.assertEqual(f.read(i), decoded[:i])
                    cookie = f.tell()
                    self.assertEqual(f.read(j), decoded[i:i + j])
                    f.seek(cookie)
                    self.assertEqual(f.read(), decoded[i:])

    # Register a special incremental decoder for testing.
    codecs.register(lookupTestDecoder)
    self.codecEnabled = 1

    # Run the tests.
    try:
        # Try each test case.
        for data, _, _ in StatefulIncrementalDecoderTest.test_cases:
            testSeekAndTellWithData(data)

        # Position each test case so that it crosses a chunk boundary.
        CHUNK_SIZE = io.TextIOWrapper._CHUNK_SIZE
        for data, _, _ in StatefulIncrementalDecoderTest.test_cases:
            offset = CHUNK_SIZE - len(data)//2
            prefix = b'.'*offset
            # Don't bother seeking into the prefix (takes too long).
            min_pos = offset*2
            testSeekAndTellWithData(prefix + data, min_pos)

    # Ensure our test decoder won't interfere with subsequent tests.
    finally:
        # The search function stays registered; clearing the flag makes
        # it answer None for every lookup from here on.
        self.codecEnabled = 0
def testEncodedWrites(self): def testEncodedWrites(self):
data = "1234567890" data = "1234567890"
tests = ("utf-16", tests = ("utf-16",