diff options
-rw-r--r-- | meson.build | 9 | ||||
-rw-r--r-- | pypaste/__init__.py | 199 | ||||
-rw-r--r-- | pypaste/__main__.py | 78 | ||||
-rw-r--r-- | pypaste/database.py | 95 | ||||
-rw-r--r-- | pypaste/meson.build | 3 | ||||
-rw-r--r-- | pypaste/s3/__init__.py | 126 | ||||
-rw-r--r-- | pypaste/s3/bucket.py (renamed from pypaste/s3.py) | 46 | ||||
-rw-r--r-- | pypaste/s3/meson.build | 1 |
8 files changed, 290 insertions, 267 deletions
diff --git a/meson.build b/meson.build index 9bfcaca..87ad91c 100644 --- a/meson.build +++ b/meson.build @@ -4,11 +4,8 @@ python = import('python').find_installation( modules: ['pygments', 'zstandard', 'aiohttp', 'bozo4'], ) -sources = files( - 'pypaste/__init__.py', - 'pypaste/__main__.py', - 'pypaste/database.py', - 'pypaste/s3.py', -) +sources = [] + +subdir('pypaste') python.install_sources(sources, preserve_path: true) diff --git a/pypaste/__init__.py b/pypaste/__init__.py index b5b26ff..0206e9e 100644 --- a/pypaste/__init__.py +++ b/pypaste/__init__.py @@ -16,19 +16,16 @@ import sys import asyncio import secrets -import aiohttp -import zstandard -from pypaste import s3 -from pypaste.database import Database, PasteRow -from hashlib import sha256 +import aiosqlite from aiohttp import web -from datetime import datetime, UTC +from datetime import datetime from dataclasses import dataclass -from typing import Optional, List +from typing import Optional, List, Tuple from pygments import highlight from pygments.lexers import guess_lexer, get_lexer_by_name from pygments.formatters import HtmlFormatter from pygments.styles import get_style_by_name +from abc import abstractmethod RESET = "\x1b[0m" RED = "\x1b[31m" @@ -83,10 +80,58 @@ def generate_key(words: List[str], length: int) -> str: @dataclass +class Paste: + key: str + dt: datetime + syntax: Optional[str] + text: str + + +@dataclass +class Storage: + connection: aiosqlite.Connection + + @abstractmethod + async def insert(self, paste: Paste) -> None: + pass + + @abstractmethod + async def retrieve(self, key: str) -> Optional[Paste]: + pass + + @abstractmethod + async def delete(self, key) -> None: + pass + + @abstractmethod + async def vacuum(self, size: int) -> None: + pass + + async def read_row(self, key: str) -> Optional[Tuple[datetime, int, Optional[str]]]: + async with self.connection.execute( + "select pastes.datetime,pastes.size,pastes.syntax from pastes where pastes.key=? limit 1", + (key,), + ) as cursor: + match await cursor.fetchone(): + case [str(dt), int(size), syntax]: + return (datetime.fromisoformat(dt), size, syntax) + case None: + return None + case _: + raise Exception("unreachable") + + async def exists(self, key: str) -> bool: + async with self.connection.execute( + "select 1 from pastes where key=?", (key,) + ) as cursor: + return await cursor.fetchone() is not None + + +@dataclass class AppConfig: site: str - content_length_max_bytes: int - s3_max_bytes: int + content_max_bytes: int + storage_max_bytes: int key_length: int dictionary: List[str] default_style: str @@ -95,15 +140,9 @@ class AppConfig: class App: - def __init__( - self, - config: AppConfig, - database: Database, - bucket: s3.Bucket, - ): - self.database = database + def __init__(self, config: AppConfig, storage: Storage): self.config = config - self.bucket = bucket + self.storage = storage async def download(self, request: web.Request) -> web.Response: try: @@ -112,45 +151,15 @@ class App: return web.HTTPBadRequest(text="provide a key to fetch") try: - exists = await self.database.exists(key) + paste = await self.storage.retrieve(key) except Exception as e: - log_error(f"failed to check if key exists in database: {e}") + log_error(f"failed to retrieve paste {key}: {e}") return web.HTTPInternalServerError() - if not exists: + if paste is None: + log_info(f"{key} does not exist, returning 404") return web.HTTPNotFound() - req = self.bucket.get(key) - - try: - async with aiohttp.ClientSession().get(req.url, headers=req.headers) as get: - if get.status == 200: - data = await get.read() - else: - log_error( - f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}" - ) - - return web.HTTPInternalServerError() - except Exception as e: - log_error(f"failed to get {key} from s3: {e}") - return web.HTTPInternalServerError() - - def decompress(): - return zstandard.decompress(data) - - try: - decompressed = await asyncio.to_thread(decompress) - except Exception as e: - log_error(f"failed to decompress blob {key}: {e}") - return web.HTTPInternalServerError() - - try: - text = decompressed.decode() - except Exception as e: - log_error(f"failed to decode blob: {key}: {e}") - return web.HTTPInternalServerError() - syntax = request.query.get("syntax") raw = request.query.get("raw") @@ -160,11 +169,11 @@ class App: if raw is not None: log_info(f"sending raw paste {key}") - return web.HTTPOk(text=text, content_type="text/plain") + return web.HTTPOk(text=paste.text, content_type="text/plain") else: def render(): - return pygmentize(text, syntax, style, self.config.line_numbers) + return pygmentize(paste.text, syntax, style, self.config.line_numbers) highlighted = await asyncio.to_thread(render) @@ -177,19 +186,12 @@ class App: async def upload(self, request: web.Request) -> web.Response: syntax = request.query.get("syntax") - try: - await self.vacuum() - except Exception as e: - log_error(f"vacuum failed: {e}") - return web.HTTPInternalServerError() - - match request.content_length: - case int(i) if i > self.config.content_length_max_bytes: - return web.HTTPBadRequest( - text=f"max content length is {self.config.content_length_max_bytes}" - ) - case _: - pass + if ( + content_length := request.content_length + ) and content_length > self.config.content_max_bytes: + return web.HTTPBadRequest( + text=f"max content length is {self.config.content_max_bytes}" + ) try: data = await request.read() @@ -198,80 +200,25 @@ class App: return web.HTTPInternalServerError(text="failed to read data") try: - data.decode() + text = data.decode() except UnicodeError: return web.HTTPBadRequest( text="content must be unicode only, no binary data is allowed" ) - def compress(): - return zstandard.compress(data) - - try: - compressed = await asyncio.to_thread(compress) - except Exception as e: - log_error(f"failed to compress data: {e}") - return web.HTTPInternalServerError() - key = generate_key(self.config.dictionary, self.config.key_length) - req = self.bucket.put( - key, - len(compressed), - "application/octet-stream", - sha256(compressed).hexdigest(), - ) - try: - async with aiohttp.ClientSession().put( - req.url, headers=req.headers, data=compressed - ) as put: - if put.status != 200: - log_error( - f"failed to put {key} to bucket with status: {put.status}" - ) - return web.HTTPInternalServerError() + paste = Paste(key, datetime.now(), syntax, text) + await self.storage.insert(paste) except Exception as e: - log_error(f"failed to put {key} to bucket: {e}") - return web.HTTPInternalServerError() - - try: - await self.database.insert( - PasteRow(key, datetime.now(UTC), len(compressed), syntax) - ) - except Exception as e: - log_error(f"failed to insert {key} into database: {e}") + log_error(f"failed to insert paste {key} to storage: {e}") return web.HTTPInternalServerError() url = f"{self.config.site}/paste/{key}" log_info( - f"uploaded paste {key} with syntax {syntax} of size {len(compressed)} bytes: {url}" + f"uploaded paste {key} with syntax {syntax} of size {len(data)} bytes: {url}" ) return web.HTTPOk(text=url) - - async def vacuum(self): - log_info("starting vaccum") - while ( - use := await self.database.storage_use() - ) is not None and use > self.config.s3_max_bytes: - oldest = await self.database.oldest() - - # If use is not None, there must be at least 1 paste, so we let the - # type checker know that this is an unreachable case. - assert oldest is not None - - req = self.bucket.delete(oldest) - - async with aiohttp.ClientSession().delete( - req.url, headers=req.headers - ) as delete: - if delete.status == 200: - log_info(f"successfully deleted {oldest}") - - await self.database.delete(oldest) - else: - log_warning( - f"failed to delete {oldest}: got status {delete.status}", - ) diff --git a/pypaste/__main__.py b/pypaste/__main__.py index 374e32a..c7704d5 100644 --- a/pypaste/__main__.py +++ b/pypaste/__main__.py @@ -16,8 +16,9 @@ import sys import os import asyncio -from pypaste.database import Database -from pypaste import App, AppConfig, s3, log_error, log_info +import aiosqlite +from pypaste import App, AppConfig, log_error, log_info +from pypaste.s3 import S3 from socket import socket, AF_UNIX, SOCK_STREAM from argparse import ArgumentParser from aiohttp import web @@ -26,6 +27,8 @@ from pathlib import Path async def main() -> int: parser = ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + parser.add_argument("--socket", required=True) parser.add_argument("--socket-mode", default="0600") parser.add_argument("--site", required=True) @@ -33,26 +36,19 @@ async def main() -> int: parser.add_argument("--key-length", type=int, required=True) parser.add_argument("--dictionary", type=Path, required=True) parser.add_argument("--database", type=Path, required=True) - parser.add_argument("--endpoint", required=True) - parser.add_argument("--region", required=True) - parser.add_argument("--bucket", required=True) - parser.add_argument("--access-key", required=True) - parser.add_argument("--secret-key", type=Path, required=True) - parser.add_argument("--s3-max-bytes", type=int, required=True) + parser.add_argument("--storage-max-bytes", type=int, required=True) parser.add_argument("--default-style", default="native") parser.add_argument("--line-numbers", action="store_true") parser.add_argument("--line-numbers-inline", action="store_true") - args = parser.parse_args() + s3parser = subparsers.add_parser("s3") + s3parser.add_argument("--endpoint", required=True) + s3parser.add_argument("--region", required=True) + s3parser.add_argument("--bucket", required=True) + s3parser.add_argument("--access-key", required=True) + s3parser.add_argument("--secret-key", type=Path, required=True) - try: - secret_key = args.secret_key.read_text().strip() - except Exception as e: - print( - f"failed to read secret key from: {str(args.secret_key)}: {e}", - file=sys.stderr, - ) - return 1 + args = parser.parse_args() try: dictionary = args.dictionary.read_text().split("\n") @@ -71,30 +67,52 @@ async def main() -> int: config = AppConfig( args.site, args.content_length_max_bytes, - args.s3_max_bytes, + args.storage_max_bytes, args.key_length, dictionary, args.default_style, line_numbers, ) - if not args.database.is_file(): - print(f"{args.database} does not exist or is not a file", file=sys.stderr) + try: + connection = await aiosqlite.connect(args.database) + except Exception as e: + log_error(f"failed to connect to database {args.database}: {e}") + return 1 + + try: + await connection.execute( + ( + "create table if not exists pastes(key text, datetime text, size int, syntax text)" + ) + ) + await connection.commit() + except Exception as e: + log_error(f"failed to initialize database: {e}") return 1 - database = Database(args.database) + match args.command: + case "s3": + try: + secret_key = args.secret_key.read_text().strip() + except Exception as e: + log_error(f"failed to load secret key from {args.secret_key}: {e}") + return 1 - bucket = s3.Bucket( - args.endpoint, - args.region, - args.bucket, - args.access_key, - secret_key, - ) + storage = S3( + connection, + args.endpoint, + args.region, + args.bucket, + args.access_key, + secret_key, + ) - app = web.Application() + await storage.setup() - pypaste = App(config, database, bucket) + pypaste = App(config, storage) + + app = web.Application() app.add_routes( [web.get("/paste/{key}", pypaste.download), web.post("/paste", pypaste.upload)] diff --git a/pypaste/database.py b/pypaste/database.py deleted file mode 100644 index 0ada3fa..0000000 --- a/pypaste/database.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (C) 2025 John Turner - -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. - -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <https://www.gnu.org/licenses/>. - -import asyncio -import sqlite3 -from typing import Optional -from datetime import datetime -from pathlib import Path -from dataclasses import dataclass - - -@dataclass -class PasteRow: - key: str - date: datetime - size: int - syntax: Optional[str] - - -class Database: - - def __init__(self, path: Path): - self.path = path - - async def insert(self, paste: PasteRow): - def do(): - with sqlite3.connect(self.path) as connection: - connection.execute( - "insert into pastes values(?, ?, ?, ?)", - ( - paste.key, - paste.date.isoformat(), - paste.size, - paste.syntax, - ), - ) - - await asyncio.to_thread(do) - - async def delete(self, key: str): - def do(): - with sqlite3.connect(self.path) as connection: - connection.execute("delete from pastes where key=?", (key,)) - - await asyncio.to_thread(do) - - async def exists(self, key: str) -> bool: - def do(): - with sqlite3.connect(self.path) as connection: - return ( - connection.execute( - "select 1 from pastes where pastes.key=?", (key,) - ).fetchone() - is not None - ) - - return await asyncio.to_thread(do) - - async def oldest(self) -> Optional[str]: - def do(): - with sqlite3.connect(self.path) as connection: - return connection.execute( - "select pastes.key from pastes order by pastes.datetime limit 1", - ).fetchone() - - match await asyncio.to_thread(do): - case str(key): - return key - case _: - return None - - async def storage_use(self) -> Optional[int]: - def do(): - with sqlite3.connect(self.path) as connection: - return connection.execute( - "select sum(pastes.size) from pastes" - ).fetchone() - - match asyncio.to_thread(do): - case int(use): - return use - case _: - return None diff --git a/pypaste/meson.build b/pypaste/meson.build new file mode 100644 index 0000000..f9ac8bc --- /dev/null +++ b/pypaste/meson.build @@ -0,0 +1,3 @@ +sources += files('__init__.py', '__main__.py') + +subdir('s3') diff --git a/pypaste/s3/__init__.py b/pypaste/s3/__init__.py new file mode 100644 index 0000000..7d21703 --- /dev/null +++ b/pypaste/s3/__init__.py @@ -0,0 +1,126 @@ +# Copyright (C) 2025 John Turner + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. + +import asyncio +import zstandard +import asyncio +import aiosqlite +from pypaste import Storage, Paste +from pypaste.s3.bucket import Bucket +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class S3(Storage): + connection: aiosqlite.Connection + + def __init__( + self, + connection: aiosqlite.Connection, + endpoint: str, + region: str, + bucket: str, + access_key: str, + secret_key: str, + ): + self.connection = connection + self.bucket = Bucket(endpoint, region, bucket, access_key, secret_key) + + async def setup(self) -> None: + await self.connection.execute("create table if not exists s3(key text)") + await self.connection.commit() + + async def insert(self, paste: Paste) -> None: + def compress(): + return zstandard.compress(paste.text.encode()) + + compressed = await asyncio.to_thread(compress) + + await self.connection.execute( + "insert into pastes values(?, ?, ?, ?)", + (paste.key, paste.dt.isoformat(), len(compressed), paste.syntax), + ) + + try: + await self.bucket.put(paste.key, compressed) + await self.connection.commit() + except Exception as e: + await self.connection.rollback() + raise e + + async def retrieve(self, key: str) -> Optional[Paste]: + if not await self.exists(key): + return None + + row = await self.read_row(key) + + assert row is not None + + (dt, size, syntax) = row + + data = await self.bucket.get(key) + + assert data is not None + + def decompress() -> str: + return zstandard.decompress(data).decode() + + text = await asyncio.to_thread(decompress) + + return Paste(key, dt, syntax, text) + + async def delete(self, key: str) -> None: + await self.connection.execute("delete from pastes where key=?", (key,)) + + try: + await self.bucket.delete(key) + await self.connection.commit() + except Exception as e: + await self.connection.rollback() + raise e + + async def vacuum(self, max: int) -> None: + while True: + async with self.connection.execute( + ( + "select sum(pastes.size) from pastes " + "inner join s3 on s3.key " + "where s3.key=pastes.key" + ) + ) as cursor: + if (row := await cursor.fetchone()) is None: + return + else: + use = row[0] + + async with self.connection.execute( + ( + "select pastes.key from pastes " + "inner join s3 on s3.key " + "where s3.key=pastes.key " + "order by pastes.datetime " + "limit 1" + ) + ) as cursor: + if (row := await cursor.fetchone()) is None: + return + else: + oldest = row[0] + + if use > max: + await self.delete(oldest) + else: + return diff --git a/pypaste/s3.py b/pypaste/s3/bucket.py index 936d4e8..c795bbd 100644 --- a/pypaste/s3.py +++ b/pypaste/s3/bucket.py @@ -13,10 +13,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. +import aiohttp from datetime import datetime, UTC from typing import Dict from dataclasses import dataclass from bozo4 import s3v4_sign_request, s3v4_datetime_string +from hashlib import sha256 +from typing import Optional @dataclass @@ -33,7 +36,7 @@ class Bucket: access_key: str secret_key: str - def get(self, key: str) -> Request: + async def get(self, key: str) -> Optional[bytes]: now = datetime.now(UTC) headers = { @@ -58,19 +61,30 @@ class Bucket: headers["Authorization"] = auth - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) + url = f"https://{self.endpoint}/{self.bucket}/{key}" + + async with aiohttp.ClientSession().get(url, headers=headers) as get: + match get.status: + case 200: + return await get.read() + case 404: + return None + case _: + raise Exception( + f"failed to get {self.endpoint}/{self.bucket}/{key} with status {get.status}" + ) + + async def put(self, key: str, data: bytes) -> None: + payload_hash = sha256(data).hexdigest() - def put( - self, key: str, content_length: int, content_type: str, payload_hash: str - ) -> Request: now = datetime.now(UTC) headers = { "Host": self.endpoint, "X-Amz-Date": s3v4_datetime_string(now), "X-Amz-Content-SHA256": payload_hash, - "Content-Length": str(content_length), - "Content-Type": content_type, + "Content-Length": str(len(data)), + "Content-Type": "application/octect-stream", } auth = s3v4_sign_request( @@ -89,9 +103,15 @@ class Bucket: headers["Authorization"] = auth - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) + url = f"https://{self.endpoint}/{self.bucket}/{key}" - def delete(self, key: str) -> Request: + async with aiohttp.ClientSession().put(url, headers=headers, data=data) as put: + if put.status != 200: + raise Exception( + f"failed put {self.endpoint}/{self.bucket}/{key} with {put.status}" + ) + + async def delete(self, key: str) -> None: now = datetime.now(UTC) headers = { @@ -117,4 +137,10 @@ class Bucket: headers["Authorization"] = auth - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) + url = f"https://{self.endpoint}/{self.bucket}/{key}" + + async with aiohttp.ClientSession().delete(url, headers=headers) as delete: + if delete.status != 200: + raise Exception( + f"failed to delete {self.endpoint}/{self.bucket}/{key} with {delete.status}" + ) diff --git a/pypaste/s3/meson.build b/pypaste/s3/meson.build new file mode 100644 index 0000000..dc7bce5 --- /dev/null +++ b/pypaste/s3/meson.build @@ -0,0 +1 @@ +sources += files('__init__.py', 'bucket.py') |