From a518a3ae1d5d7c3b1b8d7d13684b228ee3e58434 Mon Sep 17 00:00:00 2001 From: John Turner Date: Tue, 2 Sep 2025 01:08:21 -0400 Subject: rewrite using aiohttp --- paste/__main__.py | 407 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ paste/s3.py | 104 ++++++++++++++ src/paste.py | 371 ------------------------------------------------- 3 files changed, 511 insertions(+), 371 deletions(-) create mode 100644 paste/__main__.py create mode 100644 paste/s3.py delete mode 100755 src/paste.py diff --git a/paste/__main__.py b/paste/__main__.py new file mode 100644 index 0000000..fb1b413 --- /dev/null +++ b/paste/__main__.py @@ -0,0 +1,407 @@ +# Copyright (C) 2025 John Turner + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>.
import sys
import asyncio
import sqlite3
import secrets
import threading
import aiohttp
import zstandard
import s3 as s3
from hashlib import sha256
from argparse import ArgumentParser
from aiohttp import web
from datetime import datetime, UTC
from pathlib import Path
from dataclasses import dataclass
from typing import Callable, Optional, List, TypeVar
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
from pygments.styles import get_style_by_name

T = TypeVar("T")

# Upper bound on attempts to draw a key that is not already in the database,
# so a tiny dictionary cannot make an upload spin forever.
MAX_KEY_ATTEMPTS = 16

# Statuses treated as a successful S3 object deletion (S3 answers 204).
S3_DELETE_OK = (200, 204)


def pygmentize(
    content: str, syntax: Optional[str], style: str, line_numbers: str | bool
) -> str:
    """Render *content* as a complete, syntax-highlighted HTML document.

    Args:
        content: The paste text to highlight.
        syntax: Lexer alias to use; when ``None`` or unknown the lexer is
            guessed from the content.
        style: Pygments style name; falls back to ``"default"`` when the
            name cannot be resolved.
        line_numbers: Passed to the formatter as ``linenos``
            (``False``, ``"table"`` or ``"inline"``).

    Returns:
        The HTML page produced by Pygments.
    """
    lexer = None
    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            # The alias comes from the query string; an unknown one is not
            # an error, we just fall back to guessing.
            print(e, file=sys.stderr)
    if lexer is None:
        lexer = guess_lexer(content)

    try:
        resolved_style = get_style_by_name(style)
    except Exception as e:
        print(e, file=sys.stderr)
        resolved_style = "default"

    formatter = HtmlFormatter(full=True, style=resolved_style, linenos=line_numbers)

    return highlight(content, lexer, formatter)


def generate_key(words: List[str], length: int) -> str:
    """Return *length* randomly chosen words joined by ``-``, lower-cased.

    Raises:
        IndexError: If *words* is empty and *length* is positive.
    """
    return "-".join(secrets.choice(words) for _ in range(length)).lower()


@dataclass
class PasteRow:
    """One row of the ``pastes`` table."""

    key: str
    date: datetime
    # Size in bytes of the stored (compressed) object.
    size: int
    syntax: Optional[str]


class Database:
    """Async facade over a sqlite3 connection holding the ``pastes`` table.

    Every operation runs in a worker thread via ``asyncio.to_thread``. The
    connection must therefore have been opened with
    ``check_same_thread=False``; a lock serializes access to it.
    """

    def __init__(self, connection: sqlite3.Connection):
        self.connection = connection
        self._lock = threading.Lock()

    async def _run(self, operation: Callable[[sqlite3.Connection], T]) -> T:
        """Run *operation* on the connection in a worker thread, under the lock."""

        def locked() -> T:
            with self._lock:
                return operation(self.connection)

        return await asyncio.to_thread(locked)

    async def insert(self, paste: PasteRow) -> None:
        """Insert *paste* and commit."""

        def do_insert(connection: sqlite3.Connection) -> None:
            # Columns are named explicitly so the statement does not depend
            # on the declaration order of the table.
            connection.cursor().execute(
                'insert into pastes ("key", "datetime", "size", "syntax") '
                "values (?, ?, ?, ?)",
                (
                    paste.key,
                    paste.date.isoformat(),
                    paste.size,
                    paste.syntax,
                ),
            )
            connection.commit()

        await self._run(do_insert)

    async def delete(self, key: str) -> None:
        """Delete the row for *key* (if any) and commit."""

        def do_delete(connection: sqlite3.Connection) -> None:
            connection.cursor().execute(
                "delete from pastes where pastes.key=?", (key,)
            )
            connection.commit()

        await self._run(do_delete)

    async def exists(self, key: str) -> bool:
        """Return whether a row for *key* is present."""

        def do_exists(connection: sqlite3.Connection) -> bool:
            row = (
                connection.cursor()
                .execute("select * from pastes where pastes.key=?", (key,))
                .fetchone()
            )
            return row is not None

        return await self._run(do_exists)

    async def syntax(self, key: str) -> Optional[str]:
        """Return the syntax stored for *key*, or ``None`` if absent/unset."""

        def do_syntax(connection: sqlite3.Connection) -> Optional[str]:
            row = (
                connection.cursor()
                .execute("select pastes.syntax from pastes where pastes.key=?", (key,))
                .fetchone()
            )
            return row[0] if row is not None else None

        return await self._run(do_syntax)

    async def oldest(self) -> Optional[str]:
        """Return the key with the smallest ``datetime``, or ``None`` if empty."""

        def do_oldest(connection: sqlite3.Connection) -> Optional[str]:
            row = (
                connection.cursor()
                .execute(
                    "select pastes.key from pastes order by pastes.datetime limit 1"
                )
                .fetchone()
            )
            return row[0] if row is not None else None

        return await self._run(do_oldest)

    async def storage_use(self) -> Optional[int]:
        """Return the sum of all paste sizes, or ``None`` when the table is empty."""

        def do_storage_use(connection: sqlite3.Connection) -> Optional[int]:
            row = (
                connection.cursor()
                .execute("select sum(pastes.size) from pastes")
                .fetchone()
            )
            # sum() over zero rows yields a single row holding NULL.
            return row[0] if row is not None else None

        return await self._run(do_storage_use)


@dataclass
class AppConfig:
    """Runtime settings for :class:`App`."""

    site: str
    content_length_max_bytes: int
    s3_max_bytes: int
    key_length: int
    dictionary: List[str]
    default_style: str
    line_numbers: str | bool


class App:
    """HTTP handlers of the paste service."""

    def __init__(
        self,
        config: AppConfig,
        database: Database,
        bucket: s3.Bucket,
    ):
        self.database = database
        self.config = config
        self.bucket = bucket

    async def download(self, request: web.Request) -> web.Response:
        """Serve the paste named by the ``key`` path segment.

        Query parameters: ``raw`` (any value) returns plain text; ``syntax``
        and ``style`` override the lexer and the highlight style. Without
        ``syntax`` the value stored at upload time is used.
        """
        try:
            key = request.match_info["key"]
        except KeyError:
            return web.HTTPBadRequest(text="provide a key to fetch")

        try:
            known = await self.database.exists(key)
        except Exception as e:
            print(f"failed to query database for {key}: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        if not known:
            return web.HTTPNotFound()

        req = self.bucket.get(key)

        try:
            # The session is a context manager so its connector is closed.
            async with aiohttp.ClientSession() as session:
                async with session.get(req.url, headers=req.headers) as get:
                    if get.status != 200:
                        print(
                            f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}",
                            file=sys.stderr,
                        )
                        return web.HTTPInternalServerError()
                    data = await get.read()
        except Exception as e:
            print(f"failed to get {key} from s3: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        def decompress():
            return zstandard.decompress(data)

        try:
            decompressed = await asyncio.to_thread(decompress)
        except Exception as e:
            print(f"failed to decompress blob {key}: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        try:
            text = decompressed.decode()
        except Exception as e:
            print(f"failed to decode blob: {key}: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        if request.query.get("raw") is not None:
            return web.HTTPOk(text=text, content_type="text/plain")

        syntax = request.query.get("syntax")
        if syntax is None:
            # Fall back to the syntax recorded at upload time; a lookup
            # failure only costs us the hint, not the response.
            try:
                syntax = await self.database.syntax(key)
            except Exception as e:
                print(f"failed to look up syntax for {key}: {e}", file=sys.stderr)
                syntax = None

        style = request.query.get("style")
        if style is None:
            style = self.config.default_style

        def render():
            return pygmentize(text, syntax, style, self.config.line_numbers)

        highlighted = await asyncio.to_thread(render)

        return web.HTTPOk(text=highlighted, content_type="text/html")

    async def _unused_key(self) -> str:
        """Draw a key that is not yet in the database.

        Raises:
            RuntimeError: If no free key was found in ``MAX_KEY_ATTEMPTS`` draws.
        """
        for _ in range(MAX_KEY_ATTEMPTS):
            key = generate_key(self.config.dictionary, self.config.key_length)
            if not await self.database.exists(key):
                return key
        raise RuntimeError(f"no unused key found in {MAX_KEY_ATTEMPTS} attempts")

    async def upload(self, request: web.Request) -> web.Response:
        """Store the request body as a new paste and answer with its URL.

        The body must be valid UTF-8 and at most
        ``config.content_length_max_bytes`` long. The optional ``syntax``
        query parameter is recorded with the paste.
        """
        try:
            await self.vacuum()
        except Exception as e:
            print(f"vacuum failed: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        limit = self.config.content_length_max_bytes
        too_large = web.HTTPBadRequest(text=f"max content length is {limit}")

        declared = request.content_length
        if declared is not None and declared > limit:
            return too_large

        try:
            data = await request.read()
        except web.HTTPException:
            # e.g. 413 raised by aiohttp's own body size limit.
            raise
        except Exception as e:
            print(f"failed to read data: {e}", file=sys.stderr)
            return web.HTTPInternalServerError(text="failed to read data")

        # Content-Length may be absent (chunked) or wrong; check what we got.
        if len(data) > limit:
            return too_large

        try:
            data.decode()
        except UnicodeError:
            return web.HTTPBadRequest(
                text="content must be unicode only, no binary data is allowed"
            )

        def compress():
            return zstandard.compress(data)

        try:
            compressed = await asyncio.to_thread(compress)
        except Exception as e:
            print(f"failed to compress data: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        try:
            key = await self._unused_key()
        except Exception as e:
            print(f"failed to generate a key: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        req = self.bucket.put(
            key,
            len(compressed),
            "application/octet-stream",
            sha256(compressed).hexdigest(),
        )

        try:
            async with aiohttp.ClientSession() as session:
                async with session.put(
                    req.url, headers=req.headers, data=compressed
                ) as put:
                    if put.status != 200:
                        print(
                            f"failed to put {key} to bucket with status: {put.status}",
                            file=sys.stderr,
                        )
                        return web.HTTPInternalServerError()
        except Exception as e:
            print(f"failed to put {key} to bucket: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        try:
            await self.database.insert(
                PasteRow(
                    key, datetime.now(UTC), len(compressed), request.query.get("syntax")
                )
            )
        except Exception as e:
            print(f"failed to insert {key} into database: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        return web.HTTPOk(text=f"https://{self.config.site}/paste/{key}")

    async def vacuum(self):
        """Delete oldest pastes until stored bytes fit ``config.s3_max_bytes``.

        Raises:
            Exception: If the bucket rejects a delete, or on database or
                network errors.
        """
        ceiling = self.config.s3_max_bytes

        use = await self.database.storage_use()
        if use is None or use <= ceiling:
            return

        async with aiohttp.ClientSession() as session:
            while use is not None and use > ceiling:
                oldest = await self.database.oldest()
                if oldest is None:
                    # Nothing left to evict.
                    return

                req = self.bucket.delete(oldest)

                async with session.delete(req.url, headers=req.headers) as delete:
                    if delete.status not in S3_DELETE_OK:
                        raise Exception(
                            f"failed to delete {oldest} from bucket with status: {delete.status}"
                        )

                await self.database.delete(oldest)
                use = await self.database.storage_use()


def main() -> int:
    """Parse the command line, wire up the service and run it.

    Returns:
        Process exit status: 0 after a clean shutdown, 1 on a startup error.
    """
    parser = ArgumentParser()
    parser.add_argument("--host")
    parser.add_argument("--port", type=int)
    parser.add_argument("--path")
    parser.add_argument("--site", required=True)
    parser.add_argument("--content-length-max-bytes", type=int, required=True)
    parser.add_argument("--key-length", type=int, required=True)
    parser.add_argument("--dictionary", type=Path, required=True)
    parser.add_argument("--database", type=Path, required=True)
    parser.add_argument("--endpoint", required=True)
    parser.add_argument("--region", required=True)
    parser.add_argument("--bucket", required=True)
    parser.add_argument("--access-key", required=True)
    parser.add_argument("--secret-key", type=Path, required=True)
    parser.add_argument("--s3-max-bytes", type=int, required=True)
    parser.add_argument("--default-style", default="native")
    parser.add_argument("--line-numbers", action="store_true")
    parser.add_argument("--line-numbers-inline", action="store_true")

    args = parser.parse_args()

    if args.key_length < 1:
        print("--key-length must be at least 1", file=sys.stderr)
        return 1

    try:
        secret_key = args.secret_key.read_text().strip()
    except Exception as e:
        print(
            f"failed to read secret key from: {str(args.secret_key)}: {e}",
            file=sys.stderr,
        )
        return 1

    try:
        raw_words = args.dictionary.read_text().splitlines()
    except Exception as e:
        print(
            f"failed to open dictionary: {str(args.dictionary)}: {e}",
            file=sys.stderr,
        )
        return 1

    # Drop blank lines (e.g. from the trailing newline) so keys never
    # contain empty segments.
    dictionary = [word for word in (line.strip() for line in raw_words) if word]
    if not dictionary:
        print(f"dictionary is empty: {str(args.dictionary)}", file=sys.stderr)
        return 1

    # --line-numbers wins over --line-numbers-inline when both are given.
    line_numbers: str | bool
    if args.line_numbers:
        line_numbers = "table"
    elif args.line_numbers_inline:
        line_numbers = "inline"
    else:
        line_numbers = False

    config = AppConfig(
        args.site,
        args.content_length_max_bytes,
        args.s3_max_bytes,
        args.key_length,
        dictionary,
        args.default_style,
        line_numbers,
    )

    try:
        # Database runs every statement in a worker thread, which sqlite3
        # rejects unless the same-thread check is disabled.
        connection = sqlite3.connect(args.database, check_same_thread=False)
    except Exception as e:
        print(
            f"failed to connect to database: {args.database}: {e}", file=sys.stderr
        )
        return 1

    try:
        try:
            database = Database(connection)
        except Exception as e:
            print(f"failed to open database: {e}", file=sys.stderr)
            return 1

        bucket = s3.Bucket(
            args.endpoint,
            args.region,
            args.bucket,
            args.access_key,
            secret_key,
        )

        # Let aiohttp accept bodies up to the configured limit rather than
        # its built-in default.
        webapp = web.Application(client_max_size=args.content_length_max_bytes)

        app = App(config, database, bucket)

        webapp.add_routes(
            [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)]
        )

        web.run_app(webapp, host=args.host, port=args.port, path=args.path)
    finally:
        connection.close()

    return 0


if __name__ == "__main__":
    sys.exit(main())
from datetime import datetime, UTC
from typing import Dict, Optional
from dataclasses import dataclass
from bozo4 import s3v4_sign_request, s3v4_datetime_string

# Value sent as X-Amz-Content-SHA256 for requests whose body is not hashed.
_UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"


@dataclass
class Request:
    """A ready-to-send HTTP request: target URL plus signed headers."""

    url: str
    headers: Dict[str, str]


@dataclass
class Bucket:
    """Builds signed requests for objects in one S3 bucket.

    No network I/O happens here; callers send the returned :class:`Request`
    with their own HTTP client.
    """

    endpoint: str
    region: str
    bucket: str
    access_key: str
    secret_key: str

    def _signed(
        self,
        method: str,
        key: str,
        payload_hash: str,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Request:
        """Build and sign a *method* request for object *key*.

        Args:
            method: HTTP method name handed to the signer.
            key: Object key inside the bucket.
            payload_hash: Value for ``X-Amz-Content-SHA256``.
            extra_headers: Additional headers to send; they are included in
                the header set handed to the signer.
        """
        now = datetime.now(UTC)
        uri = f"/{self.bucket}/{key}"

        headers = {
            "Host": self.endpoint,
            "X-Amz-Date": s3v4_datetime_string(now),
            "X-Amz-Content-SHA256": payload_hash,
        }
        if extra_headers:
            headers.update(extra_headers)

        auth = s3v4_sign_request(
            endpoint=self.endpoint,
            region=self.region,
            access_key=self.access_key,
            secret_key=self.secret_key,
            request_method=method,
            date=now,
            payload_hash=payload_hash,
            uri=uri,
            parameters={},
            headers=headers,
            service="s3",
        )

        # Added only after signing, so the signer never sees it.
        headers["Authorization"] = auth

        return Request(f"https://{self.endpoint}{uri}", headers)

    def get(self, key: str) -> Request:
        """Return a signed GET request for *key*."""
        return self._signed("GET", key, _UNSIGNED_PAYLOAD)

    def put(
        self, key: str, content_length: int, content_type: str, payload_hash: str
    ) -> Request:
        """Return a signed PUT request for *key*.

        Args:
            key: Object key inside the bucket.
            content_length: Body size in bytes, sent as ``Content-Length``.
            content_type: Sent as ``Content-Type``.
            payload_hash: Hash of the body, sent as ``X-Amz-Content-SHA256``.
        """
        return self._signed(
            "PUT",
            key,
            payload_hash,
            {
                "Content-Length": str(content_length),
                "Content-Type": content_type,
            },
        )

    def delete(self, key: str) -> Request:
        """Return a signed DELETE request for *key*."""
        return self._signed("DELETE", key, _UNSIGNED_PAYLOAD)