summaryrefslogtreecommitdiff
path: root/paste
diff options
context:
space:
mode:
Diffstat (limited to 'paste')
-rw-r--r--paste/__main__.py79
1 files changed, 27 insertions, 52 deletions
diff --git a/paste/__main__.py b/paste/__main__.py
index 039210c..5ff869a 100644
--- a/paste/__main__.py
+++ b/paste/__main__.py
@@ -15,9 +15,9 @@
import sys
import asyncio
-import sqlite3
import secrets
import aiohttp
+import aiosqlite
import zstandard
from . import s3
from hashlib import sha256
@@ -74,12 +74,12 @@ class PasteRow:
class Database:
- def __init__(self, connection: sqlite3.Connection):
- self.connection = connection
+ def __init__(self, path: Path):
+ self.path = path
async def insert(self, paste: PasteRow):
- def do_insert():
- self.connection.cursor().execute(
+ async with aiosqlite.connect(self.path) as db:
+ await db.execute(
"insert into pastes values(?, ?, ?, ?)",
(
paste.key,
@@ -89,60 +89,37 @@ class Database:
),
)
- self.connection.commit()
-
- await asyncio.to_thread(do_insert)
+ await db.commit()
async def delete(self, key: str):
- def do_delete():
- self.connection.cursor().execute("delete pastes where pastes.key=?", (key,))
-
- self.connection.commit()
+ async with aiosqlite.connect(self.path) as db:
+ await db.execute("delete pastes where key=?", (key,))
- await asyncio.to_thread(do_delete)
+ await db.commit()
async def exists(self, key: str) -> bool:
- def do_exists():
- return (
- self.connection.cursor()
- .execute("select * from pastes where pastes.key=?", (key,))
- .fetchone()
- is not None
- )
+ async with aiosqlite.connect(self.path) as db:
+ result = await db.execute("select 1 from pastes where pastes.key=?", (key,))
- return await asyncio.to_thread(do_exists())
+ return await result.fetchone() is not None
async def oldest(self) -> Optional[str]:
- def do_oldest():
- return (
- self.connection.cursor()
- .execute(
- "select pastes.key from pastes order by pastes.datetime limit 1"
- )
- .fetchone()
+ async with aiosqlite.connect(self.path) as db:
+ result = await db.execute(
+ "select pastes.key from pastes order by pastes.datetime limit 1",
)
- result = await asyncio.to_thread(do_oldest)
+ row = await result.fetchone()
- if result is not None:
- return result[0]
- else:
- return None
+ return row[0]
async def storage_use(self) -> Optional[int]:
- def do_storage_use():
- return (
- self.connection.cursor()
- .execute("select sum(pastes.size) from pastes")
- .fetchone()
- )
+ async with aiosqlite.connect(self.path) as db:
+ result = await db.execute("select sum(pastes.size) from pastes")
- result = await asyncio.to_thread(do_storage_use)
+ row = await result.fetchone()
- if result is not None:
- return result[0]
- else:
- return None
+ return row[0]
@dataclass
@@ -295,9 +272,9 @@ class App:
async def vacuum(self):
while (
- use := self.database.storage_use()
+ use := await self.database.storage_use()
) is not None and use > self.config.s3_max_bytes:
- oldest = self.database.oldest()
+ oldest = await self.database.oldest()
# If use is not None, there must be at least 1 paste, so we let the
# type checker know that this is an unreachable case.
assert oldest is not None
@@ -312,7 +289,7 @@ class App:
f"failed to delete {oldest} from bucket with status: {delete.status}"
)
- self.database.delete(oldest)
+ await self.database.delete(oldest)
def main() -> int:
@@ -370,14 +347,12 @@ def main() -> int:
line_numbers,
)
- try:
- connection = sqlite3.connect(args.database)
- except Exception as e:
- print(f"failed to connect to database: {args.database}: {e}")
+ if not args.database.is_file():
+ print(f"{args.database} does not exist or is not a file", file=sys.stderr)
return 1
try:
- database = Database(connection)
+ database = Database(args.database)
except Exception as e:
print(f"failed to open database: {e}", file=sys.stderr)
return 1