diff options
Diffstat (limited to 'paste')
-rw-r--r-- | paste/__main__.py | 385 | ||||
-rw-r--r-- | paste/s3.py | 105 |
2 files changed, 0 insertions, 490 deletions
diff --git a/paste/__main__.py b/paste/__main__.py deleted file mode 100644 index 21ac508..0000000 --- a/paste/__main__.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright (C) 2025 John Turner - -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. - -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <https://www.gnu.org/licenses/>. - -import sys -import asyncio -import secrets -import aiohttp -import aiosqlite -import zstandard -from . import s3 -from hashlib import sha256 -from argparse import ArgumentParser -from aiohttp import web -from datetime import datetime, UTC -from pathlib import Path -from dataclasses import dataclass -from typing import Optional, List -from pygments import highlight -from pygments.lexers import guess_lexer, get_lexer_by_name -from pygments.formatters import HtmlFormatter -from pygments.styles import get_style_by_name - - -def pygmentize( - content: str, syntax: Optional[str], style: str, line_numbers: str | bool -) -> str: - if syntax is not None: - try: - lexer = get_lexer_by_name(syntax) - except Exception as e: - print(e, file=sys.stderr) - lexer = guess_lexer(content) - else: - lexer = guess_lexer(content) - - try: - style = get_style_by_name(style) - except Exception as e: - print(e, file=sys.stderr) - style = "default" - - formatter = HtmlFormatter(full=True, style=style, linenos="table") - - return highlight(content, lexer, formatter) - - -def generate_key(words: List[str], length: int) -> str: - choices = [] - for _ in range(length): - choices.append(secrets.choice(words)) - - return "-".join(word for word in choices).lower() - - -@dataclass -class PasteRow: - key: str - date: datetime - size: int - syntax: Optional[str] - - -class Database: - - def __init__(self, path: Path): - self.path = path - - async def insert(self, paste: PasteRow): - async with aiosqlite.connect(self.path) as db: - await db.execute( - "insert into pastes values(?, ?, ?, ?)", - ( - paste.key, - paste.date.isoformat(), - paste.size, - paste.syntax, - ), - ) - - await db.commit() - - async def delete(self, key: str): - async with aiosqlite.connect(self.path) as db: - await db.execute("delete from pastes where key=?", (key,)) - - await db.commit() - - async def exists(self, key: str) -> bool: - async with aiosqlite.connect(self.path) as db: - result = await db.execute("select 1 from pastes where pastes.key=?", (key,)) - - return await result.fetchone() is not None - - async def oldest(self) -> Optional[str]: - async with aiosqlite.connect(self.path) as db: - result = await db.execute( - "select pastes.key from pastes order by pastes.datetime limit 1", - ) - - row = await result.fetchone() - - return row[0] - - async def storage_use(self) -> Optional[int]: - async with aiosqlite.connect(self.path) as db: - result = await db.execute("select sum(pastes.size) from pastes") - - row = await result.fetchone() - - return row[0] - - -@dataclass -class AppConfig: - site: str - content_length_max_bytes: int - s3_max_bytes: int - key_length: int - dictionary: List[str] - default_style: str - line_numbers: str | bool - - -class App: - - def __init__( - self, - config: AppConfig, - database: Database, - bucket: s3.Bucket, - ): - self.database = database - self.config = config - self.bucket = bucket - - async def download(self, request: web.Request) -> web.Response: - try: - key = request.match_info["key"] - except KeyError: - return web.HTTPBadRequest(text="provide a key to fetch") - - if not await self.database.exists(key): - return web.HTTPNotFound() - - req = self.bucket.get(key) - - try: - async with aiohttp.ClientSession().get(req.url, headers=req.headers) as get: - if get.status == 200: - data = await get.read() - else: - print( - f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}" - ) - return web.HTTPInternalServerError() - except Exception as e: - print(f"failed to get {key} from s3: {e}", file=sys.stderr) - return web.HTTPInternalServerError() - - def decompress(): - return zstandard.decompress(data) - - try: - decompressed = await asyncio.to_thread(decompress) - except Exception as e: - print(f"failed to decompress blob {key}: {e}", file=sys.stderr) - return web.HTTPInternalServerError() - - try: - text = decompressed.decode() - except Exception as e: - print(f"failed to decode blob: {key}: {e}", file=sys.stderr) - return web.HTTPInternalServerError() - - syntax = request.query.get("syntax") - raw = request.query.get("raw") - - if (style := request.query.get("style")) is None: - style = self.config.default_style - - if raw is not None: - return web.HTTPOk(text=text, content_type="text/plain") - else: - - def render(): - return pygmentize(text, syntax, style, self.config.line_numbers) - - highlighted = await asyncio.to_thread(render) - - return web.HTTPOk(text=highlighted, content_type="text/html") - - async def upload(self, request: web.Request) -> web.Response: - try: - await self.vacuum() - except Exception as e: - print(f"vacuum failed: {e}") - return web.HTTPInternalServerError() - - match request.content_length: - case int(i) if i > self.config.content_length_max_bytes: - return web.HTTPBadRequest( - text=f"max content length is {self.config.content_length_max_bytes}" - ) - case _: - pass - - try: - data = await request.read() - except Exception as e: - print(f"failed to read data: {e}", file=sys.stderr) - return web.HTTPInternalServerError(text="failed to read data") - - try: - data.decode() - except UnicodeError: - return web.HTTPBadRequest( - text="content must be unicode only, no binary data is allowed" - ) - - def compress(): - return zstandard.compress(data) - - try: - compressed = await asyncio.to_thread(compress) - except Exception as e: - print(f"failed to compress data: {e}", file=sys.stderr) - return web.HTTPInternalServerError() - - key = generate_key(self.config.dictionary, self.config.key_length) - - req = self.bucket.put( - key, - len(compressed), - "application/octet-stream", - sha256(compressed).hexdigest(), - ) - - try: - async with aiohttp.ClientSession().put( - req.url, headers=req.headers, data=compressed - ) as put: - if put.status != 200: - print(f"failed to put {key} to bucket with status: {put.status}") - return web.HTTPInternalServerError() - except Exception as e: - print(f"failed to put {key} to bucket: {e}") - return web.HTTPInternalServerError() - - try: - await self.database.insert( - PasteRow( - key, datetime.now(UTC), len(compressed), request.query.get("syntax") - ) - ) - except Exception as e: - print(f"failed to insert {key} into database: {e}", file=sys.stderr) - return web.HTTPInternalServerError() - - return web.HTTPOk(text=f"https://{self.config.site}/paste/{key}") - - async def vacuum(self): - while ( - use := await self.database.storage_use() - ) is not None and use > self.config.s3_max_bytes: - oldest = await self.database.oldest() - # If use is not None, there must be at least 1 paste, so we let the - # type checker know that this is an unreachable case. - assert oldest is not None - - req = self.bucket.delete(oldest) - - async with aiohttp.ClientSession().delete( - req.url, headers=req.headers - ) as delete: - if delete.status != 200: - print( - f"failed to delete {oldest}: got status {delete.status}", - file=sys.stderr, - ) - else: - print(f"successfully deleted {oldest}", file=sys.stderr) - - await self.database.delete(oldest) - - -def main() -> int: - parser = ArgumentParser() - parser.add_argument("--host") - parser.add_argument("--port", type=int) - parser.add_argument("--path") - parser.add_argument("--site", required=True) - parser.add_argument("--content-length-max-bytes", type=int, required=True) - parser.add_argument("--key-length", type=int, required=True) - parser.add_argument("--dictionary", type=Path, required=True) - parser.add_argument("--database", type=Path, required=True) - parser.add_argument("--endpoint", required=True) - parser.add_argument("--region", required=True) - parser.add_argument("--bucket", required=True) - parser.add_argument("--access-key", required=True) - parser.add_argument("--secret-key", type=Path, required=True) - parser.add_argument("--s3-max-bytes", type=int, required=True) - parser.add_argument("--default-style", default="native") - parser.add_argument("--line-numbers", action="store_true") - parser.add_argument("--line-numbers-inline", action="store_true") - - args = parser.parse_args() - - try: - secret_key = args.secret_key.read_text().strip() - except Exception as e: - print( - f"failed to read secret key from: {str(args.secret_key)}: {e}", - file=sys.stderr, - ) - return 1 - - try: - dictionary = args.dictionary.read_text().split("\n") - except Exception as e: - print(f"failed to open dictionary: {str(args.dictionary)}: {e}") - return 1 - - match [args.line_numbers, args.line_numbers_inline]: - case [True, _]: - line_numbers: str | bool = "table" - case [_, True]: - line_numbers = "inline" - case [_, _]: - line_numbers = False - - config = AppConfig( - args.site, - args.content_length_max_bytes, - args.s3_max_bytes, - args.key_length, - dictionary, - args.default_style, - line_numbers, - ) - - if not args.database.is_file(): - print(f"{args.database} does not exist or is not a file", file=sys.stderr) - return 1 - - try: - database = Database(args.database) - except Exception as e: - print(f"failed to open database: {e}", file=sys.stderr) - return 1 - - bucket = s3.Bucket( - args.endpoint, - args.region, - args.bucket, - args.access_key, - secret_key, - ) - - webapp = web.Application() - - app = App(config, database, bucket) - - webapp.add_routes( - [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)] - ) - - web.run_app(webapp, host=args.host, port=args.port, path=args.path) - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/paste/s3.py b/paste/s3.py deleted file mode 100644 index 673e428..0000000 --- a/paste/s3.py +++ /dev/null @@ -1,105 +0,0 @@ -from datetime import datetime, UTC -from typing import Dict -from dataclasses import dataclass -from bozo4 import s3v4_sign_request, s3v4_datetime_string - - -@dataclass -class Request: - url: str - headers: Dict[str, str] - - -@dataclass -class Bucket: - endpoint: str - region: str - bucket: str - access_key: str - secret_key: str - - def get(self, key: str) -> Request: - now = datetime.now(UTC) - - headers = { - "Host": self.endpoint, - "X-Amz-Date": s3v4_datetime_string(now), - "X-Amz-Content-SHA256": "UNSIGNED-PAYLOAD", - } - - auth = s3v4_sign_request( - endpoint=self.endpoint, - region=self.region, - access_key=self.access_key, - secret_key=self.secret_key, - request_method="GET", - date=now, - payload_hash="UNSIGNED-PAYLOAD", - uri=f"/{self.bucket}/{key}", - parameters={}, - headers=headers, - service="s3", - ) - - headers["Authorization"] = auth - - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) - - def put( - self, key: str, content_length: int, content_type: str, payload_hash: str - ) -> Request: - now = datetime.now(UTC) - - headers = { - "Host": self.endpoint, - "X-Amz-Date": s3v4_datetime_string(now), - "X-Amz-Content-SHA256": payload_hash, - "Content-Length": str(content_length), - "Content-Type": content_type, - } - - auth = s3v4_sign_request( - endpoint=self.endpoint, - region=self.region, - access_key=self.access_key, - secret_key=self.secret_key, - request_method="PUT", - date=now, - payload_hash=payload_hash, - uri=f"/{self.bucket}/{key}", - parameters={}, - headers=headers, - service="s3", - ) - - headers["Authorization"] = auth - - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) - - def delete(self, key: str) -> Request: - now = datetime.now(UTC) - - headers = { - "Host": self.endpoint, - "X-Amz-Date": s3v4_datetime_string(now), - "X-Amz-Content-SHA256": "UNSIGNED-PAYLOAD", - "Content-Length": "0", - } - - auth = s3v4_sign_request( - endpoint=self.endpoint, - region=self.region, - access_key=self.access_key, - secret_key=self.secret_key, - request_method="DELETE", - date=now, - payload_hash="UNSIGNED-PAYLOAD", - uri=f"/{self.bucket}/{key}", - parameters={}, - headers=headers, - service="s3", - ) - - headers["Authorization"] = auth - - return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers) |