From bd4491d09b69159c89b53fe930f4b4339a1dc042 Mon Sep 17 00:00:00 2001 From: Michał Górny Date: Wed, 25 Oct 2017 22:24:09 +0200 Subject: compression: Support passing text-mode kwargs to open --- gemato/compression.py | 21 ++++++++--- tests/test_compression.py | 88 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) diff --git a/gemato/compression.py b/gemato/compression.py index 62d2977..023f308 100644 --- a/gemato/compression.py +++ b/gemato/compression.py @@ -71,29 +71,40 @@ class FileStack(object): f.close() -def open_potentially_compressed_path(path, mode): +def open_potentially_compressed_path(path, mode, **kwargs): """ Open the potentially compressed file at specified path @path with mode @mode. If the path ends with one of the known compression suffixes, the file will be decompressed transparently. Otherwise, it will be open directly. + @kwargs can be used to pass additional options for text files. + Only arguments supported by io.TextIOWrapper should be used there. + Returns an object that must be used via the context manager API. """ base, ext = os.path.splitext(path) if ext not in ('.gz', '.bz2', '.lzma', '.xz'): - return io.open(path, mode) + return io.open(path, mode, **kwargs) bmode = mode if 'b' not in bmode: bmode += 'b' f = io.open(path, bmode) + fs = FileStack([f]) try: - cf = open_compressed_file(ext[1:], f, mode) + cf = open_compressed_file(ext[1:], f, bmode if kwargs else mode) + fs.files.append(cf) + + # special args are not supported by compressor backends + # so add a TextIOWrapper on top + if kwargs: + iow = io.TextIOWrapper(cf, **kwargs) + fs.files.append(iow) except: - f.close() + fs.close() raise - return FileStack((f, cf)) + return fs diff --git a/tests/test_compression.py b/tests/test_compression.py index e0deaf3..97b319a 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -12,6 +12,8 @@ import gemato.compression TEST_STRING = b'The quick brown fox jumps over the lazy dog' +# we need to be specific on endianness to avoid unreliably writing BOM +UTF16_TEST_STRING = TEST_STRING.decode('utf8').encode('utf_16_be') class GzipCompressionTest(unittest.TestCase): @@ -72,6 +74,25 @@ L0stUijJSFXISayqVEjJTwcAlGd4GBcAAAA= with gemato.compression.open_compressed_file('gz', rf, 'rb') as gz: self.assertEqual(gz.read(), TEST_STRING) + def test_open_potentially_compressed_path_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.gz') as wf: + with gemato.compression.open_compressed_file('gz', wf, 'wb') as gz: + gz.write(UTF16_TEST_STRING) + wf.flush() + + with gemato.compression.open_potentially_compressed_path( + wf.name, 'r', encoding='utf_16_be') as cf: + self.assertEqual(cf.read(), TEST_STRING.decode('utf8')) + + def test_open_potentially_compressed_path_write_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.gz') as rf: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'w', encoding='utf_16_be') as cf: + cf.write(TEST_STRING.decode('utf8')) + + with gemato.compression.open_compressed_file('gz', rf, 'rb') as gz: + self.assertEqual(gz.read(), UTF16_TEST_STRING) + class Bzip2CompressionTest(unittest.TestCase): BASE64 = b''' @@ -150,6 +171,31 @@ OxleaA== except gemato.exceptions.UnsupportedCompression: raise unittest.SkipTest('bz2 compression unsupported') + def test_open_potentially_compressed_path_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.bz2') as wf: + try: + with gemato.compression.open_compressed_file('bz2', wf, 'wb') as bz2: + bz2.write(UTF16_TEST_STRING) + wf.flush() + + with gemato.compression.open_potentially_compressed_path( + wf.name, 'r', encoding='utf_16_be') as cf: + self.assertEqual(cf.read(), TEST_STRING.decode('utf8')) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('bz2 compression unsupported') + + def test_open_potentially_compressed_path_write_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.bz2') as rf: + try: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'w', encoding='utf_16_be') as cf: + cf.write(TEST_STRING.decode('utf8')) + + with gemato.compression.open_compressed_file('bz2', rf, 'rb') as bz2: + self.assertEqual(bz2.read(), UTF16_TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('bz2 compression unsupported') + class LZMALegacyCompressionTest(unittest.TestCase): BASE64 = b''' @@ -239,6 +285,31 @@ ADUdSd6zBOkOpekGFH46zix9wE9VT65OVeV479//7uUAAA== except gemato.exceptions.UnsupportedCompression: raise unittest.SkipTest('lzma compression unsupported') + def test_open_potentially_compressed_path_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.lzma') as wf: + try: + with gemato.compression.open_compressed_file('lzma', wf, 'wb') as lzma: + lzma.write(UTF16_TEST_STRING) + wf.flush() + + with gemato.compression.open_potentially_compressed_path( + wf.name, 'r', encoding='utf_16_be') as cf: + self.assertEqual(cf.read(), TEST_STRING.decode('utf8')) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('lzma compression unsupported') + + def test_open_potentially_compressed_path_write_with_encoding(self): + with tempfile.NamedTemporaryFile(suffix='.lzma') as rf: + try: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'w', encoding='utf_16_be') as cf: + cf.write(TEST_STRING.decode('utf8')) + + with gemato.compression.open_compressed_file('lzma', rf, 'rb') as lzma: + self.assertEqual(lzma.read(), UTF16_TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('lzma compression unsupported') + class XZCompressionTest(unittest.TestCase): BASE64 = b''' @@ -351,3 +422,20 @@ class NoCompressionTest(unittest.TestCase): cf.write(TEST_STRING) self.assertEqual(rf.read(), TEST_STRING) + + def test_open_potentially_compressed_path_with_encoding(self): + with tempfile.NamedTemporaryFile() as wf: + wf.write(UTF16_TEST_STRING) + wf.flush() + + with gemato.compression.open_potentially_compressed_path( + wf.name, 'r', encoding='utf_16_be') as cf: + self.assertEqual(cf.read(), TEST_STRING.decode('utf8')) + + def test_open_potentially_compressed_path_write_with_encoding(self): + with tempfile.NamedTemporaryFile() as rf: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'w', encoding='utf_16_be') as cf: + cf.write(TEST_STRING.decode('utf8')) + + self.assertEqual(rf.read(), UTF16_TEST_STRING) -- cgit v1.2.3