diff options
| -rwxr-xr-x | tests/test_storage.py | 99 |
1 files changed, 58 insertions, 41 deletions
diff --git a/tests/test_storage.py b/tests/test_storage.py index dde2f40..f2b8912 100755 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -10,7 +10,6 @@ from pypaste.server.sqlite import Sqlite from pypaste.server.s3 import S3 from datetime import datetime from pathlib import Path -from typing import List def truncate(path: Path) -> None: @@ -67,37 +66,46 @@ async def test_insert_retrieve(storage: Storage) -> None: assert paste.text == "hello world" -async def main() -> int: - stores: List[Storage] = [] - - with tempfile.TemporaryDirectory() as tmpdir: - f = Path(tmpdir) / "database" - truncate(f) - async with aiosqlite.connect(f) as connection: - await connection.execute( - ( - "create table pastes(" - "key blob," - "key_length int," - "datetime text," - "size int," - "syntax text" - ")" - ) - ) +async def init_db(connection: aiosqlite.Connection) -> None: + await connection.execute( + ( + "create table pastes(" + "key blob," + "key_length int," + "datetime text," + "size int," + "syntax text" + ")" + ) + ) + + +async def test_sqlite(tests): + for test in tests: + with tempfile.TemporaryDirectory() as tmpdir: + database = Path(tmpdir) / "pastes.sqlite" + truncate(database) + + async with aiosqlite.connect(database) as connection: + await init_db(connection) + + sqlite = Sqlite(connection) + + await sqlite.setup() - sqlite_storage = Sqlite(connection) - await sqlite_storage.setup() - stores.append(sqlite_storage) + await test(sqlite) - try: - os.environ["PYPASTE_TEST_S3"] - test_s3 = True - except KeyError: - test_s3 = False - if test_s3: - s3_storage = S3( +async def test_s3(tests): + for test in tests: + with tempfile.TemporaryDirectory() as tmpdir: + database = Path(tmpdir) / "pastes.sqlite" + truncate(database) + + async with aiosqlite.connect(database) as connection: + await init_db(connection) + + s3 = S3( connection, os.environ["PYPASTE_TEST_ENDPOINT"], os.environ["PYPASTE_TEST_REGION"], @@ -105,18 +113,27 @@ async def main() -> int: os.environ["PYPASTE_TEST_ACCESS_KEY"], os.environ["PYPASTE_TEST_SECRET_KEY"], ) - await s3_storage.setup() - stores.append(s3_storage) - - for store in stores: - await asyncio.gather( - test_insert_retrieve(store), - test_insert_retrieve(store), - test_delete(store), - test_delete(store), - test_exists_but_not_in_our_table(store), - test_exists_but_not_in_our_table(store), - ) + + await s3.setup() + + await test(s3) + + +async def main() -> int: + tests = [ + test_insert_retrieve, + test_delete, + test_exists, + test_exists_but_not_in_our_table, + ] + + await test_sqlite(tests) + + try: + os.environ["PYPASTE_TEST_S3"] + await test_s3(tests) + except KeyError: + pass return 0 |
