author     John Turner <jturner.usa@gmail.com>  2025-09-09 19:29:48 -0400
committer  John Turner <jturner.usa@gmail.com>  2025-09-09 19:33:07 -0400
commit     e6db914af092e4ddcfd865e8d23852579755d224 (patch)
tree       41b43398b9f31d265bbc81980c9b7bfd605339af
parent     fa39a16ee4592514877530f2e4ec58c3dcf2c12c (diff)
download   pypaste-e6db914af092e4ddcfd865e8d23852579755d224.tar.gz
rewrite to support modular storage backends
-rw-r--r--  meson.build                                           9
-rw-r--r--  pypaste/__init__.py                                 199
-rw-r--r--  pypaste/__main__.py                                  78
-rw-r--r--  pypaste/database.py                                  95
-rw-r--r--  pypaste/meson.build                                   3
-rw-r--r--  pypaste/s3/__init__.py                              126
-rw-r--r--  pypaste/s3/bucket.py (renamed from pypaste/s3.py)    46
-rw-r--r--  pypaste/s3/meson.build                                1
8 files changed, 290 insertions, 267 deletions
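
The core of the commit is the new Storage base class in pypaste/__init__.py: App now talks only to Storage, and backends such as the S3 class implement insert/retrieve/delete/vacuum on top of a shared aiosqlite "pastes" table. A minimal sketch of what another backend could look like under that interface follows; the FilesystemStorage name, its root directory, and its no-op eviction stub are hypothetical illustrations, not part of this commit.

import aiosqlite
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from pypaste import Storage, Paste


@dataclass
class FilesystemStorage(Storage):
    # Inherits `connection: aiosqlite.Connection` from Storage; `root` is a
    # hypothetical directory holding one file per paste key.
    root: Path

    async def insert(self, paste: Paste) -> None:
        await self.connection.execute(
            "insert into pastes values(?, ?, ?, ?)",
            (paste.key, paste.dt.isoformat(), len(paste.text.encode()), paste.syntax),
        )
        try:
            # Blocking file I/O; a real backend would wrap this in asyncio.to_thread.
            (self.root / paste.key).write_text(paste.text)
            await self.connection.commit()
        except Exception:
            await self.connection.rollback()
            raise

    async def retrieve(self, key: str) -> Optional[Paste]:
        row = await self.read_row(key)  # shared helper defined on Storage
        if row is None:
            return None
        (dt, _size, syntax) = row
        return Paste(key, dt, syntax, (self.root / key).read_text())

    async def delete(self, key: str) -> None:
        await self.connection.execute("delete from pastes where key=?", (key,))
        (self.root / key).unlink(missing_ok=True)
        await self.connection.commit()

    async def vacuum(self, size: int) -> None:
        # Eviction policy is backend-specific; the S3 backend deletes the
        # oldest pastes until the stored total drops below `size`.
        pass
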
diff --git a/meson.build b/meson.build
index 9bfcaca..87ad91c 100644
--- a/meson.build
+++ b/meson.build
@@ -4,11 +4,8 @@ python = import('python').find_installation(
modules: ['pygments', 'zstandard', 'aiohttp', 'bozo4'],
)
-sources = files(
- 'pypaste/__init__.py',
- 'pypaste/__main__.py',
- 'pypaste/database.py',
- 'pypaste/s3.py',
-)
+sources = []
+
+subdir('pypaste')
python.install_sources(sources, preserve_path: true)
diff --git a/pypaste/__init__.py b/pypaste/__init__.py
index b5b26ff..0206e9e 100644
--- a/pypaste/__init__.py
+++ b/pypaste/__init__.py
@@ -16,19 +16,16 @@
import sys
import asyncio
import secrets
-import aiohttp
-import zstandard
-from pypaste import s3
-from pypaste.database import Database, PasteRow
-from hashlib import sha256
+import aiosqlite
from aiohttp import web
-from datetime import datetime, UTC
+from datetime import datetime
from dataclasses import dataclass
-from typing import Optional, List
+from typing import Optional, List, Tuple
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
from pygments.styles import get_style_by_name
+from abc import abstractmethod
RESET = "\x1b[0m"
RED = "\x1b[31m"
@@ -83,10 +80,58 @@ def generate_key(words: List[str], length: int) -> str:
@dataclass
+class Paste:
+ key: str
+ dt: datetime
+ syntax: Optional[str]
+ text: str
+
+
+@dataclass
+class Storage:
+ connection: aiosqlite.Connection
+
+ @abstractmethod
+ async def insert(self, paste: Paste) -> None:
+ pass
+
+ @abstractmethod
+ async def retrieve(self, key: str) -> Optional[Paste]:
+ pass
+
+ @abstractmethod
+    async def delete(self, key: str) -> None:
+ pass
+
+ @abstractmethod
+ async def vacuum(self, size: int) -> None:
+ pass
+
+ async def read_row(self, key: str) -> Optional[Tuple[datetime, int, Optional[str]]]:
+ async with self.connection.execute(
+ "select pastes.datetime,pastes.size,pastes.syntax from pastes where pastes.key=? limit 1",
+ (key,),
+ ) as cursor:
+ match await cursor.fetchone():
+ case [str(dt), int(size), syntax]:
+ return (datetime.fromisoformat(dt), size, syntax)
+ case None:
+ return None
+ case _:
+ raise Exception("unreachable")
+
+ async def exists(self, key: str) -> bool:
+ async with self.connection.execute(
+ "select 1 from pastes where key=?", (key,)
+ ) as cursor:
+ return await cursor.fetchone() is not None
+
+
+@dataclass
class AppConfig:
site: str
- content_length_max_bytes: int
- s3_max_bytes: int
+ content_max_bytes: int
+ storage_max_bytes: int
key_length: int
dictionary: List[str]
default_style: str
@@ -95,15 +140,9 @@ class AppConfig:
class App:
- def __init__(
- self,
- config: AppConfig,
- database: Database,
- bucket: s3.Bucket,
- ):
- self.database = database
+ def __init__(self, config: AppConfig, storage: Storage):
self.config = config
- self.bucket = bucket
+ self.storage = storage
async def download(self, request: web.Request) -> web.Response:
try:
@@ -112,45 +151,15 @@ class App:
return web.HTTPBadRequest(text="provide a key to fetch")
try:
- exists = await self.database.exists(key)
+ paste = await self.storage.retrieve(key)
except Exception as e:
- log_error(f"failed to check if key exists in database: {e}")
+ log_error(f"failed to retrieve paste {key}: {e}")
return web.HTTPInternalServerError()
- if not exists:
+ if paste is None:
+ log_info(f"{key} does not exist, returning 404")
return web.HTTPNotFound()
- req = self.bucket.get(key)
-
- try:
- async with aiohttp.ClientSession().get(req.url, headers=req.headers) as get:
- if get.status == 200:
- data = await get.read()
- else:
- log_error(
- f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}"
- )
-
- return web.HTTPInternalServerError()
- except Exception as e:
- log_error(f"failed to get {key} from s3: {e}")
- return web.HTTPInternalServerError()
-
- def decompress():
- return zstandard.decompress(data)
-
- try:
- decompressed = await asyncio.to_thread(decompress)
- except Exception as e:
- log_error(f"failed to decompress blob {key}: {e}")
- return web.HTTPInternalServerError()
-
- try:
- text = decompressed.decode()
- except Exception as e:
- log_error(f"failed to decode blob: {key}: {e}")
- return web.HTTPInternalServerError()
-
syntax = request.query.get("syntax")
raw = request.query.get("raw")
@@ -160,11 +169,11 @@ class App:
if raw is not None:
log_info(f"sending raw paste {key}")
- return web.HTTPOk(text=text, content_type="text/plain")
+ return web.HTTPOk(text=paste.text, content_type="text/plain")
else:
def render():
- return pygmentize(text, syntax, style, self.config.line_numbers)
+ return pygmentize(paste.text, syntax, style, self.config.line_numbers)
highlighted = await asyncio.to_thread(render)
@@ -177,19 +186,12 @@ class App:
async def upload(self, request: web.Request) -> web.Response:
syntax = request.query.get("syntax")
- try:
- await self.vacuum()
- except Exception as e:
- log_error(f"vacuum failed: {e}")
- return web.HTTPInternalServerError()
-
- match request.content_length:
- case int(i) if i > self.config.content_length_max_bytes:
- return web.HTTPBadRequest(
- text=f"max content length is {self.config.content_length_max_bytes}"
- )
- case _:
- pass
+ if (
+ content_length := request.content_length
+ ) and content_length > self.config.content_max_bytes:
+ return web.HTTPBadRequest(
+ text=f"max content length is {self.config.content_max_bytes}"
+ )
try:
data = await request.read()
@@ -198,80 +200,25 @@ class App:
return web.HTTPInternalServerError(text="failed to read data")
try:
- data.decode()
+ text = data.decode()
except UnicodeError:
return web.HTTPBadRequest(
text="content must be unicode only, no binary data is allowed"
)
- def compress():
- return zstandard.compress(data)
-
- try:
- compressed = await asyncio.to_thread(compress)
- except Exception as e:
- log_error(f"failed to compress data: {e}")
- return web.HTTPInternalServerError()
-
key = generate_key(self.config.dictionary, self.config.key_length)
- req = self.bucket.put(
- key,
- len(compressed),
- "application/octet-stream",
- sha256(compressed).hexdigest(),
- )
-
try:
- async with aiohttp.ClientSession().put(
- req.url, headers=req.headers, data=compressed
- ) as put:
- if put.status != 200:
- log_error(
- f"failed to put {key} to bucket with status: {put.status}"
- )
- return web.HTTPInternalServerError()
+ paste = Paste(key, datetime.now(), syntax, text)
+ await self.storage.insert(paste)
except Exception as e:
- log_error(f"failed to put {key} to bucket: {e}")
- return web.HTTPInternalServerError()
-
- try:
- await self.database.insert(
- PasteRow(key, datetime.now(UTC), len(compressed), syntax)
- )
- except Exception as e:
- log_error(f"failed to insert {key} into database: {e}")
+ log_error(f"failed to insert paste {key} to storage: {e}")
return web.HTTPInternalServerError()
url = f"{self.config.site}/paste/{key}"
log_info(
- f"uploaded paste {key} with syntax {syntax} of size {len(compressed)} bytes: {url}"
+ f"uploaded paste {key} with syntax {syntax} of size {len(data)} bytes: {url}"
)
return web.HTTPOk(text=url)
-
- async def vacuum(self):
- log_info("starting vaccum")
- while (
- use := await self.database.storage_use()
- ) is not None and use > self.config.s3_max_bytes:
- oldest = await self.database.oldest()
-
- # If use is not None, there must be at least 1 paste, so we let the
- # type checker know that this is an unreachable case.
- assert oldest is not None
-
- req = self.bucket.delete(oldest)
-
- async with aiohttp.ClientSession().delete(
- req.url, headers=req.headers
- ) as delete:
- if delete.status == 200:
- log_info(f"successfully deleted {oldest}")
-
- await self.database.delete(oldest)
- else:
- log_warning(
- f"failed to delete {oldest}: got status {delete.status}",
- )
diff --git a/pypaste/__main__.py b/pypaste/__main__.py
index 374e32a..c7704d5 100644
--- a/pypaste/__main__.py
+++ b/pypaste/__main__.py
@@ -16,8 +16,9 @@
import sys
import os
import asyncio
-from pypaste.database import Database
-from pypaste import App, AppConfig, s3, log_error, log_info
+import aiosqlite
+from pypaste import App, AppConfig, log_error, log_info
+from pypaste.s3 import S3
from socket import socket, AF_UNIX, SOCK_STREAM
from argparse import ArgumentParser
from aiohttp import web
@@ -26,6 +27,8 @@ from pathlib import Path
async def main() -> int:
parser = ArgumentParser()
+ subparsers = parser.add_subparsers(dest="command")
+
parser.add_argument("--socket", required=True)
parser.add_argument("--socket-mode", default="0600")
parser.add_argument("--site", required=True)
@@ -33,26 +36,19 @@ async def main() -> int:
parser.add_argument("--key-length", type=int, required=True)
parser.add_argument("--dictionary", type=Path, required=True)
parser.add_argument("--database", type=Path, required=True)
- parser.add_argument("--endpoint", required=True)
- parser.add_argument("--region", required=True)
- parser.add_argument("--bucket", required=True)
- parser.add_argument("--access-key", required=True)
- parser.add_argument("--secret-key", type=Path, required=True)
- parser.add_argument("--s3-max-bytes", type=int, required=True)
+ parser.add_argument("--storage-max-bytes", type=int, required=True)
parser.add_argument("--default-style", default="native")
parser.add_argument("--line-numbers", action="store_true")
parser.add_argument("--line-numbers-inline", action="store_true")
- args = parser.parse_args()
+ s3parser = subparsers.add_parser("s3")
+ s3parser.add_argument("--endpoint", required=True)
+ s3parser.add_argument("--region", required=True)
+ s3parser.add_argument("--bucket", required=True)
+ s3parser.add_argument("--access-key", required=True)
+ s3parser.add_argument("--secret-key", type=Path, required=True)
- try:
- secret_key = args.secret_key.read_text().strip()
- except Exception as e:
- print(
- f"failed to read secret key from: {str(args.secret_key)}: {e}",
- file=sys.stderr,
- )
- return 1
+ args = parser.parse_args()
try:
dictionary = args.dictionary.read_text().split("\n")
@@ -71,30 +67,52 @@ async def main() -> int:
config = AppConfig(
args.site,
args.content_length_max_bytes,
- args.s3_max_bytes,
+ args.storage_max_bytes,
args.key_length,
dictionary,
args.default_style,
line_numbers,
)
- if not args.database.is_file():
- print(f"{args.database} does not exist or is not a file", file=sys.stderr)
+ try:
+ connection = await aiosqlite.connect(args.database)
+ except Exception as e:
+ log_error(f"failed to connect to database {args.database}: {e}")
+ return 1
+
+ try:
+ await connection.execute(
+ (
+ "create table if not exists pastes(key text, datetime text, size int, syntax text)"
+ )
+ )
+ await connection.commit()
+ except Exception as e:
+ log_error(f"failed to initialize database: {e}")
return 1
- database = Database(args.database)
+ match args.command:
+ case "s3":
+ try:
+ secret_key = args.secret_key.read_text().strip()
+ except Exception as e:
+ log_error(f"failed to load secret key from {args.secret_key}: {e}")
+ return 1
- bucket = s3.Bucket(
- args.endpoint,
- args.region,
- args.bucket,
- args.access_key,
- secret_key,
- )
+ storage = S3(
+ connection,
+ args.endpoint,
+ args.region,
+ args.bucket,
+ args.access_key,
+ secret_key,
+ )
- app = web.Application()
+ await storage.setup()
- pypaste = App(config, database, bucket)
+ pypaste = App(config, storage)
+
+ app = web.Application()
app.add_routes(
[web.get("/paste/{key}", pypaste.download), web.post("/paste", pypaste.upload)]
diff --git a/pypaste/database.py b/pypaste/database.py
deleted file mode 100644
index 0ada3fa..0000000
--- a/pypaste/database.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (C) 2025 John Turner
-
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-import asyncio
-import sqlite3
-from typing import Optional
-from datetime import datetime
-from pathlib import Path
-from dataclasses import dataclass
-
-
-@dataclass
-class PasteRow:
- key: str
- date: datetime
- size: int
- syntax: Optional[str]
-
-
-class Database:
-
- def __init__(self, path: Path):
- self.path = path
-
- async def insert(self, paste: PasteRow):
- def do():
- with sqlite3.connect(self.path) as connection:
- connection.execute(
- "insert into pastes values(?, ?, ?, ?)",
- (
- paste.key,
- paste.date.isoformat(),
- paste.size,
- paste.syntax,
- ),
- )
-
- await asyncio.to_thread(do)
-
- async def delete(self, key: str):
- def do():
- with sqlite3.connect(self.path) as connection:
- connection.execute("delete from pastes where key=?", (key,))
-
- await asyncio.to_thread(do)
-
- async def exists(self, key: str) -> bool:
- def do():
- with sqlite3.connect(self.path) as connection:
- return (
- connection.execute(
- "select 1 from pastes where pastes.key=?", (key,)
- ).fetchone()
- is not None
- )
-
- return await asyncio.to_thread(do)
-
- async def oldest(self) -> Optional[str]:
- def do():
- with sqlite3.connect(self.path) as connection:
- return connection.execute(
- "select pastes.key from pastes order by pastes.datetime limit 1",
- ).fetchone()
-
- match await asyncio.to_thread(do):
- case str(key):
- return key
- case _:
- return None
-
- async def storage_use(self) -> Optional[int]:
- def do():
- with sqlite3.connect(self.path) as connection:
- return connection.execute(
- "select sum(pastes.size) from pastes"
- ).fetchone()
-
- match asyncio.to_thread(do):
- case int(use):
- return use
- case _:
- return None
diff --git a/pypaste/meson.build b/pypaste/meson.build
new file mode 100644
index 0000000..f9ac8bc
--- /dev/null
+++ b/pypaste/meson.build
@@ -0,0 +1,3 @@
+sources += files('__init__.py', '__main__.py')
+
+subdir('s3')
diff --git a/pypaste/s3/__init__.py b/pypaste/s3/__init__.py
new file mode 100644
index 0000000..7d21703
--- /dev/null
+++ b/pypaste/s3/__init__.py
@@ -0,0 +1,126 @@
+# Copyright (C) 2025 John Turner
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import asyncio
+import zstandard
+import aiosqlite
+from pypaste import Storage, Paste
+from pypaste.s3.bucket import Bucket
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass
+class S3(Storage):
+ connection: aiosqlite.Connection
+
+ def __init__(
+ self,
+ connection: aiosqlite.Connection,
+ endpoint: str,
+ region: str,
+ bucket: str,
+ access_key: str,
+ secret_key: str,
+ ):
+ self.connection = connection
+ self.bucket = Bucket(endpoint, region, bucket, access_key, secret_key)
+
+ async def setup(self) -> None:
+ await self.connection.execute("create table if not exists s3(key text)")
+ await self.connection.commit()
+
+ async def insert(self, paste: Paste) -> None:
+ def compress():
+ return zstandard.compress(paste.text.encode())
+
+ compressed = await asyncio.to_thread(compress)
+
+ await self.connection.execute(
+ "insert into pastes values(?, ?, ?, ?)",
+ (paste.key, paste.dt.isoformat(), len(compressed), paste.syntax),
+ )
+
+ try:
+ await self.bucket.put(paste.key, compressed)
+ await self.connection.commit()
+ except Exception as e:
+ await self.connection.rollback()
+ raise e
+
+ async def retrieve(self, key: str) -> Optional[Paste]:
+ if not await self.exists(key):
+ return None
+
+ row = await self.read_row(key)
+
+ assert row is not None
+
+ (dt, size, syntax) = row
+
+ data = await self.bucket.get(key)
+
+ assert data is not None
+
+ def decompress() -> str:
+ return zstandard.decompress(data).decode()
+
+ text = await asyncio.to_thread(decompress)
+
+ return Paste(key, dt, syntax, text)
+
+ async def delete(self, key: str) -> None:
+ await self.connection.execute("delete from pastes where key=?", (key,))
+
+ try:
+ await self.bucket.delete(key)
+ await self.connection.commit()
+ except Exception as e:
+ await self.connection.rollback()
+ raise e
+
+ async def vacuum(self, max: int) -> None:
+ while True:
+ async with self.connection.execute(
+ (
+ "select sum(pastes.size) from pastes "
+ "inner join s3 on s3.key "
+ "where s3.key=pastes.key"
+ )
+ ) as cursor:
+ if (row := await cursor.fetchone()) is None:
+ return
+ else:
+ use = row[0]
+
+ async with self.connection.execute(
+ (
+ "select pastes.key from pastes "
+ "inner join s3 on s3.key "
+ "where s3.key=pastes.key "
+ "order by pastes.datetime "
+ "limit 1"
+ )
+ ) as cursor:
+ if (row := await cursor.fetchone()) is None:
+ return
+ else:
+ oldest = row[0]
+
+ if use > max:
+ await self.delete(oldest)
+ else:
+ return
diff --git a/pypaste/s3.py b/pypaste/s3/bucket.py
index 936d4e8..c795bbd 100644
--- a/pypaste/s3.py
+++ b/pypaste/s3/bucket.py
@@ -13,10 +13,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import aiohttp
from datetime import datetime, UTC
from typing import Dict
from dataclasses import dataclass
from bozo4 import s3v4_sign_request, s3v4_datetime_string
+from hashlib import sha256
+from typing import Optional
@dataclass
@@ -33,7 +36,7 @@ class Bucket:
access_key: str
secret_key: str
- def get(self, key: str) -> Request:
+ async def get(self, key: str) -> Optional[bytes]:
now = datetime.now(UTC)
headers = {
@@ -58,19 +61,30 @@ class Bucket:
headers["Authorization"] = auth
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)
+ url = f"https://{self.endpoint}/{self.bucket}/{key}"
+
+ async with aiohttp.ClientSession().get(url, headers=headers) as get:
+ match get.status:
+ case 200:
+ return await get.read()
+ case 404:
+ return None
+ case _:
+ raise Exception(
+ f"failed to get {self.endpoint}/{self.bucket}/{key} with status {get.status}"
+ )
+
+ async def put(self, key: str, data: bytes) -> None:
+ payload_hash = sha256(data).hexdigest()
- def put(
- self, key: str, content_length: int, content_type: str, payload_hash: str
- ) -> Request:
now = datetime.now(UTC)
headers = {
"Host": self.endpoint,
"X-Amz-Date": s3v4_datetime_string(now),
"X-Amz-Content-SHA256": payload_hash,
- "Content-Length": str(content_length),
- "Content-Type": content_type,
+ "Content-Length": str(len(data)),
+            "Content-Type": "application/octet-stream",
}
auth = s3v4_sign_request(
@@ -89,9 +103,15 @@ class Bucket:
headers["Authorization"] = auth
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)
+ url = f"https://{self.endpoint}/{self.bucket}/{key}"
- def delete(self, key: str) -> Request:
+ async with aiohttp.ClientSession().put(url, headers=headers, data=data) as put:
+ if put.status != 200:
+ raise Exception(
+ f"failed put {self.endpoint}/{self.bucket}/{key} with {put.status}"
+ )
+
+ async def delete(self, key: str) -> None:
now = datetime.now(UTC)
headers = {
@@ -117,4 +137,10 @@ class Bucket:
headers["Authorization"] = auth
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)
+ url = f"https://{self.endpoint}/{self.bucket}/{key}"
+
+ async with aiohttp.ClientSession().delete(url, headers=headers) as delete:
+ if delete.status != 200:
+ raise Exception(
+ f"failed to delete {self.endpoint}/{self.bucket}/{key} with {delete.status}"
+ )
diff --git a/pypaste/s3/meson.build b/pypaste/s3/meson.build
new file mode 100644
index 0000000..dc7bce5
--- /dev/null
+++ b/pypaste/s3/meson.build
@@ -0,0 +1 @@
+sources += files('__init__.py', 'bucket.py')
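
For reference, the new wiring in pypaste/__main__.py reduces to roughly the following once the s3 subcommand is selected. This is a condensed, hedged sketch; the endpoint, bucket, credentials, and AppConfig values below are placeholder assumptions, not values from the commit.

import aiosqlite
from aiohttp import web
from pypaste import App, AppConfig
from pypaste.s3 import S3


async def build_app() -> web.Application:
    # Shared metadata table used by every storage backend.
    connection = await aiosqlite.connect("pastes.db")
    await connection.execute(
        "create table if not exists pastes(key text, datetime text, size int, syntax text)"
    )
    await connection.commit()

    # Backend-specific setup (creates the s3 bookkeeping table).
    storage = S3(
        connection,
        "s3.example.com",  # endpoint (placeholder)
        "us-east-1",       # region (placeholder)
        "pastes",          # bucket (placeholder)
        "ACCESS_KEY",      # access key (placeholder)
        "SECRET_KEY",      # secret key (placeholder)
    )
    await storage.setup()

    config = AppConfig(
        "https://paste.example.com",    # site
        1_000_000,                      # content_max_bytes
        1_000_000_000,                  # storage_max_bytes
        3,                              # key_length
        ["alpha", "bravo", "charlie"],  # dictionary
        "native",                       # default_style
        False,                          # line_numbers
    )

    pypaste = App(config, storage)
    app = web.Application()
    app.add_routes(
        [web.get("/paste/{key}", pypaste.download), web.post("/paste", pypaste.upload)]
    )
    return app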