diff options
-rw-r--r-- | gemato/compression.py | 51 | ||||
-rw-r--r-- | tests/test_compression.py | 91 |
2 files changed, 142 insertions, 0 deletions
diff --git a/gemato/compression.py b/gemato/compression.py index c15de0c..62d2977 100644 --- a/gemato/compression.py +++ b/gemato/compression.py @@ -4,6 +4,8 @@ # Licensed under the terms of 2-clause BSD license import gzip +import io +import os.path import sys if sys.version_info >= (3, 3): @@ -46,3 +48,52 @@ def open_compressed_file(suffix, f, mode='r'): return lzma.LZMAFile(f, format=lzma.FORMAT_XZ, mode=mode) raise gemato.exceptions.UnsupportedCompression(suffix) + + +class FileStack(object): + """ + A context manager for stacked files. Maintains handles for all files + on stack, returns the topmost (last) layer on enter and closes them + all on exit. + """ + + def __init__(self, files=[]): + self.files = files + + def __enter__(self): + return self.files[-1] + + def __exit__(self, exc_type, exc_value, exc_cb): + self.close() + + def close(self): + for f in reversed(self.files): + f.close() + + +def open_potentially_compressed_path(path, mode): + """ + Open the potentially compressed file at specified path @path + with mode @mode. If the path ends with one of the known compression + suffixes, the file will be decompressed transparently. Otherwise, + it will be open directly. + + Returns an object that must be used via the context manager API. + """ + + base, ext = os.path.splitext(path) + if ext not in ('.gz', '.bz2', '.lzma', '.xz'): + return io.open(path, mode) + + bmode = mode + if 'b' not in bmode: + bmode += 'b' + + f = io.open(path, bmode) + try: + cf = open_compressed_file(ext[1:], f, mode) + except: + f.close() + raise + + return FileStack((f, cf)) diff --git a/tests/test_compression.py b/tests/test_compression.py index a81e80f..0870a1d 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -5,6 +5,7 @@ import base64 import io +import tempfile import unittest import gemato.compression @@ -53,6 +54,24 @@ L0stUijJSFXISayqVEjJTwcAlGd4GBcAAAA= with gemato.compression.open_compressed_file('gz', f, 'rb') as gz: self.assertEqual(gz.read(), TEST_STRING) + def test_open_potentially_compressed_path(self): + with tempfile.NamedTemporaryFile(suffix='.gz') as wf: + wf.write(base64.b64decode(self.BASE64)) + wf.flush() + + with gemato.compression.open_potentially_compressed_path( + wf.name, 'rb') as cf: + self.assertEqual(cf.read(), TEST_STRING) + + def test_open_potentially_compressed_path_write(self): + with tempfile.NamedTemporaryFile(suffix='.gz') as rf: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'wb') as cf: + cf.write(TEST_STRING) + + with gemato.compression.open_compressed_file('gz', rf, 'rb') as gz: + self.assertEqual(gz.read(), TEST_STRING) + class Bzip2CompressionTest(unittest.TestCase): BASE64 = b''' @@ -107,6 +126,30 @@ OxleaA== except gemato.exceptions.UnsupportedCompression: raise unittest.SkipTest('bz2 compression unsupported') + def test_open_potentially_compressed_path(self): + with tempfile.NamedTemporaryFile(suffix='.bz2') as wf: + wf.write(base64.b64decode(self.BASE64)) + wf.flush() + + try: + with gemato.compression.open_potentially_compressed_path( + wf.name, 'rb') as cf: + self.assertEqual(cf.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('bz2 compression unsupported') + + def test_open_potentially_compressed_path_write(self): + with tempfile.NamedTemporaryFile(suffix='.bz2') as rf: + try: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'wb') as cf: + cf.write(TEST_STRING) + + with gemato.compression.open_compressed_file('bz2', rf, 'rb') as bz2: + self.assertEqual(bz2.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('bz2 compression unsupported') + class LZMALegacyCompressionTest(unittest.TestCase): BASE64 = b''' @@ -172,6 +215,30 @@ ADUdSd6zBOkOpekGFH46zix9wE9VT65OVeV479//7uUAAA== with gemato.compression.open_compressed_file('xz', f, "rb") as xz: xz.read() + def test_open_potentially_compressed_path(self): + with tempfile.NamedTemporaryFile(suffix='.lzma') as wf: + wf.write(base64.b64decode(self.BASE64)) + wf.flush() + + try: + with gemato.compression.open_potentially_compressed_path( + wf.name, 'rb') as cf: + self.assertEqual(cf.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('lzma compression unsupported') + + def test_open_potentially_compressed_path_write(self): + with tempfile.NamedTemporaryFile(suffix='.lzma') as rf: + try: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'wb') as cf: + cf.write(TEST_STRING) + + with gemato.compression.open_compressed_file('lzma', rf, 'rb') as lzma: + self.assertEqual(lzma.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('lzma compression unsupported') + class XZCompressionTest(unittest.TestCase): BASE64 = b''' @@ -237,3 +304,27 @@ dGhlIGxhenkgZG9nAADjZCTmHjHqggABLxeBCEmxH7bzfQEAAAAABFla with self.assertRaises(gemato.compression.lzma.LZMAError): with gemato.compression.open_compressed_file('lzma', f, "rb") as lzma: lzma.read() + + def test_open_potentially_compressed_path(self): + with tempfile.NamedTemporaryFile(suffix='.xz') as wf: + wf.write(base64.b64decode(self.BASE64)) + wf.flush() + + try: + with gemato.compression.open_potentially_compressed_path( + wf.name, 'rb') as cf: + self.assertEqual(cf.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('xz compression unsupported') + + def test_open_potentially_compressed_path_write(self): + with tempfile.NamedTemporaryFile(suffix='.xz') as rf: + try: + with gemato.compression.open_potentially_compressed_path( + rf.name, 'wb') as cf: + cf.write(TEST_STRING) + + with gemato.compression.open_compressed_file('xz', rf, 'rb') as xz: + self.assertEqual(xz.read(), TEST_STRING) + except gemato.exceptions.UnsupportedCompression: + raise unittest.SkipTest('xz compression unsupported') |