summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtests/test_storage.py99
1 files changed, 58 insertions, 41 deletions
diff --git a/tests/test_storage.py b/tests/test_storage.py
index dde2f40..f2b8912 100755
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -10,7 +10,6 @@ from pypaste.server.sqlite import Sqlite
from pypaste.server.s3 import S3
from datetime import datetime
from pathlib import Path
-from typing import List
def truncate(path: Path) -> None:
@@ -67,37 +66,46 @@ async def test_insert_retrieve(storage: Storage) -> None:
assert paste.text == "hello world"
-async def main() -> int:
- stores: List[Storage] = []
-
- with tempfile.TemporaryDirectory() as tmpdir:
- f = Path(tmpdir) / "database"
- truncate(f)
- async with aiosqlite.connect(f) as connection:
- await connection.execute(
- (
- "create table pastes("
- "key blob,"
- "key_length int,"
- "datetime text,"
- "size int,"
- "syntax text"
- ")"
- )
- )
+async def init_db(connection: aiosqlite.Connection) -> None:
+ await connection.execute(
+ (
+ "create table pastes("
+ "key blob,"
+ "key_length int,"
+ "datetime text,"
+ "size int,"
+ "syntax text"
+ ")"
+ )
+ )
+
+
+async def test_sqlite(tests):
+ for test in tests:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = Path(tmpdir) / "pastes.sqlite"
+ truncate(database)
+
+ async with aiosqlite.connect(database) as connection:
+ await init_db(connection)
+
+ sqlite = Sqlite(connection)
+
+ await sqlite.setup()
- sqlite_storage = Sqlite(connection)
- await sqlite_storage.setup()
- stores.append(sqlite_storage)
+ await test(sqlite)
- try:
- os.environ["PYPASTE_TEST_S3"]
- test_s3 = True
- except KeyError:
- test_s3 = False
- if test_s3:
- s3_storage = S3(
+async def test_s3(tests):
+ for test in tests:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = Path(tmpdir) / "pastes.sqlite"
+ truncate(database)
+
+ async with aiosqlite.connect(database) as connection:
+ await init_db(connection)
+
+ s3 = S3(
connection,
os.environ["PYPASTE_TEST_ENDPOINT"],
os.environ["PYPASTE_TEST_REGION"],
@@ -105,18 +113,27 @@ async def main() -> int:
os.environ["PYPASTE_TEST_ACCESS_KEY"],
os.environ["PYPASTE_TEST_SECRET_KEY"],
)
- await s3_storage.setup()
- stores.append(s3_storage)
-
- for store in stores:
- await asyncio.gather(
- test_insert_retrieve(store),
- test_insert_retrieve(store),
- test_delete(store),
- test_delete(store),
- test_exists_but_not_in_our_table(store),
- test_exists_but_not_in_our_table(store),
- )
+
+ await s3.setup()
+
+ await test(s3)
+
+
+async def main() -> int:
+ tests = [
+ test_insert_retrieve,
+ test_delete,
+ test_exists,
+ test_exists_but_not_in_our_table,
+ ]
+
+ await test_sqlite(tests)
+
+ try:
+ os.environ["PYPASTE_TEST_S3"]
+ await test_s3(tests)
+ except KeyError:
+ pass
return 0