summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gemato/compression.py21
-rw-r--r--tests/test_compression.py88
2 files changed, 104 insertions, 5 deletions
diff --git a/gemato/compression.py b/gemato/compression.py
index 62d2977..023f308 100644
--- a/gemato/compression.py
+++ b/gemato/compression.py
@@ -71,29 +71,40 @@ class FileStack(object):
f.close()
-def open_potentially_compressed_path(path, mode):
+def open_potentially_compressed_path(path, mode, **kwargs):
"""
Open the potentially compressed file at specified path @path
with mode @mode. If the path ends with one of the known compression
suffixes, the file will be decompressed transparently. Otherwise,
it will be open directly.
+ @kwargs can be used to pass additional options for text files.
+ Only arguments supported by io.TextIOWrapper should be used there.
+
Returns an object that must be used via the context manager API.
"""
base, ext = os.path.splitext(path)
if ext not in ('.gz', '.bz2', '.lzma', '.xz'):
- return io.open(path, mode)
+ return io.open(path, mode, **kwargs)
bmode = mode
if 'b' not in bmode:
bmode += 'b'
f = io.open(path, bmode)
+ fs = FileStack([f])
try:
- cf = open_compressed_file(ext[1:], f, mode)
+ cf = open_compressed_file(ext[1:], f, bmode if kwargs else mode)
+ fs.files.append(cf)
+
+ # special args are not supported by compressor backends
+ # so add a TextIOWrapper on top
+ if kwargs:
+ iow = io.TextIOWrapper(cf, **kwargs)
+ fs.files.append(iow)
except:
- f.close()
+ fs.close()
raise
- return FileStack((f, cf))
+ return fs
diff --git a/tests/test_compression.py b/tests/test_compression.py
index e0deaf3..97b319a 100644
--- a/tests/test_compression.py
+++ b/tests/test_compression.py
@@ -12,6 +12,8 @@ import gemato.compression
TEST_STRING = b'The quick brown fox jumps over the lazy dog'
+# we need to be specific on endianness to avoid unreliably writing BOM
+UTF16_TEST_STRING = TEST_STRING.decode('utf8').encode('utf_16_be')
class GzipCompressionTest(unittest.TestCase):
@@ -72,6 +74,25 @@ L0stUijJSFXISayqVEjJTwcAlGd4GBcAAAA=
with gemato.compression.open_compressed_file('gz', rf, 'rb') as gz:
self.assertEqual(gz.read(), TEST_STRING)
+ def test_open_potentially_compressed_path_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.gz') as wf:
+ with gemato.compression.open_compressed_file('gz', wf, 'wb') as gz:
+ gz.write(UTF16_TEST_STRING)
+ wf.flush()
+
+ with gemato.compression.open_potentially_compressed_path(
+ wf.name, 'r', encoding='utf_16_be') as cf:
+ self.assertEqual(cf.read(), TEST_STRING.decode('utf8'))
+
+ def test_open_potentially_compressed_path_write_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.gz') as rf:
+ with gemato.compression.open_potentially_compressed_path(
+ rf.name, 'w', encoding='utf_16_be') as cf:
+ cf.write(TEST_STRING.decode('utf8'))
+
+ with gemato.compression.open_compressed_file('gz', rf, 'rb') as gz:
+ self.assertEqual(gz.read(), UTF16_TEST_STRING)
+
class Bzip2CompressionTest(unittest.TestCase):
BASE64 = b'''
@@ -150,6 +171,31 @@ OxleaA==
except gemato.exceptions.UnsupportedCompression:
raise unittest.SkipTest('bz2 compression unsupported')
+ def test_open_potentially_compressed_path_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.bz2') as wf:
+ try:
+ with gemato.compression.open_compressed_file('bz2', wf, 'wb') as bz2:
+ bz2.write(UTF16_TEST_STRING)
+ wf.flush()
+
+ with gemato.compression.open_potentially_compressed_path(
+ wf.name, 'r', encoding='utf_16_be') as cf:
+ self.assertEqual(cf.read(), TEST_STRING.decode('utf8'))
+ except gemato.exceptions.UnsupportedCompression:
+ raise unittest.SkipTest('bz2 compression unsupported')
+
+ def test_open_potentially_compressed_path_write_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.bz2') as rf:
+ try:
+ with gemato.compression.open_potentially_compressed_path(
+ rf.name, 'w', encoding='utf_16_be') as cf:
+ cf.write(TEST_STRING.decode('utf8'))
+
+ with gemato.compression.open_compressed_file('bz2', rf, 'rb') as bz2:
+ self.assertEqual(bz2.read(), UTF16_TEST_STRING)
+ except gemato.exceptions.UnsupportedCompression:
+ raise unittest.SkipTest('bz2 compression unsupported')
+
class LZMALegacyCompressionTest(unittest.TestCase):
BASE64 = b'''
@@ -239,6 +285,31 @@ ADUdSd6zBOkOpekGFH46zix9wE9VT65OVeV479//7uUAAA==
except gemato.exceptions.UnsupportedCompression:
raise unittest.SkipTest('lzma compression unsupported')
+ def test_open_potentially_compressed_path_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.lzma') as wf:
+ try:
+ with gemato.compression.open_compressed_file('lzma', wf, 'wb') as lzma:
+ lzma.write(UTF16_TEST_STRING)
+ wf.flush()
+
+ with gemato.compression.open_potentially_compressed_path(
+ wf.name, 'r', encoding='utf_16_be') as cf:
+ self.assertEqual(cf.read(), TEST_STRING.decode('utf8'))
+ except gemato.exceptions.UnsupportedCompression:
+ raise unittest.SkipTest('lzma compression unsupported')
+
+ def test_open_potentially_compressed_path_write_with_encoding(self):
+ with tempfile.NamedTemporaryFile(suffix='.lzma') as rf:
+ try:
+ with gemato.compression.open_potentially_compressed_path(
+ rf.name, 'w', encoding='utf_16_be') as cf:
+ cf.write(TEST_STRING.decode('utf8'))
+
+ with gemato.compression.open_compressed_file('lzma', rf, 'rb') as lzma:
+ self.assertEqual(lzma.read(), UTF16_TEST_STRING)
+ except gemato.exceptions.UnsupportedCompression:
+ raise unittest.SkipTest('lzma compression unsupported')
+
class XZCompressionTest(unittest.TestCase):
BASE64 = b'''
@@ -351,3 +422,20 @@ class NoCompressionTest(unittest.TestCase):
cf.write(TEST_STRING)
self.assertEqual(rf.read(), TEST_STRING)
+
+ def test_open_potentially_compressed_path_with_encoding(self):
+ with tempfile.NamedTemporaryFile() as wf:
+ wf.write(UTF16_TEST_STRING)
+ wf.flush()
+
+ with gemato.compression.open_potentially_compressed_path(
+ wf.name, 'r', encoding='utf_16_be') as cf:
+ self.assertEqual(cf.read(), TEST_STRING.decode('utf8'))
+
+ def test_open_potentially_compressed_path_write_with_encoding(self):
+ with tempfile.NamedTemporaryFile() as rf:
+ with gemato.compression.open_potentially_compressed_path(
+ rf.name, 'w', encoding='utf_16_be') as cf:
+ cf.write(TEST_STRING.decode('utf8'))
+
+ self.assertEqual(rf.read(), UTF16_TEST_STRING)