# Copyright (C) 2025 John Turner
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import sys
import asyncio
import secrets

import aiohttp
import aiosqlite
import zstandard

from . import s3

from hashlib import sha256
from argparse import ArgumentParser
from aiohttp import web
from datetime import datetime, UTC
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List

from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
from pygments.styles import get_style_by_name


def pygmentize(
    content: str, syntax: Optional[str], style: str, line_numbers: str | bool
) -> str:
    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            print(e, file=sys.stderr)
            lexer = guess_lexer(content)
    else:
        lexer = guess_lexer(content)

    try:
        style = get_style_by_name(style)
    except Exception as e:
        print(e, file=sys.stderr)
        style = "default"

    # Honour the configured line-number mode ("table", "inline", or False)
    # instead of hard-coding "table".
    formatter = HtmlFormatter(full=True, style=style, linenos=line_numbers)
    return highlight(content, lexer, formatter)


def generate_key(words: List[str], length: int) -> str:
    choices = [secrets.choice(words) for _ in range(length)]
    return "-".join(choices).lower()


@dataclass
class PasteRow:
    key: str
    date: datetime
    size: int
    syntax: Optional[str]


class Database:
    def __init__(self, path: Path):
        self.path = path

    async def insert(self, paste: PasteRow):
        async with aiosqlite.connect(self.path) as db:
            await db.execute(
                "insert into pastes values(?, ?, ?, ?)",
                (
                    paste.key,
                    paste.date.isoformat(),
                    paste.size,
                    paste.syntax,
                ),
            )
            await db.commit()

    async def delete(self, key: str):
        async with aiosqlite.connect(self.path) as db:
            await db.execute("delete from pastes where key=?", (key,))
            await db.commit()

    async def exists(self, key: str) -> bool:
        async with aiosqlite.connect(self.path) as db:
            result = await db.execute("select 1 from pastes where pastes.key=?", (key,))
            return await result.fetchone() is not None

    async def oldest(self) -> Optional[str]:
        async with aiosqlite.connect(self.path) as db:
            result = await db.execute(
                "select pastes.key from pastes order by pastes.datetime limit 1",
            )
            row = await result.fetchone()
            return row[0] if row is not None else None

    async def storage_use(self) -> Optional[int]:
        async with aiosqlite.connect(self.path) as db:
            result = await db.execute("select sum(pastes.size) from pastes")
            row = await result.fetchone()
            return row[0]


@dataclass
class AppConfig:
    site: str
    content_length_max_bytes: int
    s3_max_bytes: int
    key_length: int
    dictionary: List[str]
    default_style: str
    line_numbers: str | bool


class App:
    def __init__(
        self,
        config: AppConfig,
        database: Database,
        bucket: s3.Bucket,
    ):
        self.database = database
        self.config = config
        self.bucket = bucket

    async def download(self, request: web.Request) -> web.Response:
        try:
            key = request.match_info["key"]
        except KeyError:
            return web.HTTPBadRequest(text="provide a key to fetch")

        if not await self.database.exists(key):
            return web.HTTPNotFound()

        req = self.bucket.get(key)
        try:
            # Use a session context manager so the connection pool is closed
            # after the request instead of leaking an unclosed session.
            async with aiohttp.ClientSession() as session:
                async with session.get(req.url, headers=req.headers) as get:
                    if get.status == 200:
                        data = await get.read()
                    else:
                        print(
                            f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}",
                            file=sys.stderr,
                        )
                        return web.HTTPInternalServerError()
        except Exception as e:
            print(f"failed to get {key} from s3: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        def decompress():
            return zstandard.decompress(data)

        try:
            decompressed = await asyncio.to_thread(decompress)
        except Exception as e:
            print(f"failed to decompress blob {key}: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        try:
            text = decompressed.decode()
        except Exception as e:
            print(f"failed to decode blob: {key}: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        syntax = request.query.get("syntax")
        raw = request.query.get("raw")

        if (style := request.query.get("style")) is None:
            style = self.config.default_style

        if raw is not None:
            return web.HTTPOk(text=text, content_type="text/plain")
        else:

            def render():
                return pygmentize(text, syntax, style, self.config.line_numbers)

            highlighted = await asyncio.to_thread(render)
            return web.HTTPOk(text=highlighted, content_type="text/html")

    async def upload(self, request: web.Request) -> web.Response:
        try:
            await self.vacuum()
        except Exception as e:
            print(f"vacuum failed: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        match request.content_length:
            case int(i) if i > self.config.content_length_max_bytes:
                return web.HTTPBadRequest(
                    text=f"max content length is {self.config.content_length_max_bytes}"
                )
            case _:
                pass

        try:
            data = await request.read()
        except Exception as e:
            print(f"failed to read data: {e}", file=sys.stderr)
            return web.HTTPInternalServerError(text="failed to read data")

        try:
            data.decode()
        except UnicodeError:
            return web.HTTPBadRequest(
                text="content must be unicode only, no binary data is allowed"
            )

        def compress():
            return zstandard.compress(data)

        try:
            compressed = await asyncio.to_thread(compress)
        except Exception as e:
            print(f"failed to compress data: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        key = generate_key(self.config.dictionary, self.config.key_length)
        req = self.bucket.put(
            key,
            len(compressed),
            "application/octet-stream",
            sha256(compressed).hexdigest(),
        )

        try:
            async with aiohttp.ClientSession() as session:
                async with session.put(
                    req.url, headers=req.headers, data=compressed
                ) as put:
                    if put.status != 200:
                        print(
                            f"failed to put {key} to bucket with status: {put.status}",
                            file=sys.stderr,
                        )
                        return web.HTTPInternalServerError()
        except Exception as e:
            print(f"failed to put {key} to bucket: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        try:
            await self.database.insert(
                PasteRow(
                    key, datetime.now(UTC), len(compressed), request.query.get("syntax")
                )
            )
        except Exception as e:
            print(f"failed to insert {key} into database: {e}", file=sys.stderr)
            return web.HTTPInternalServerError()

        return web.HTTPOk(text=f"https://{self.config.site}/paste/{key}")

    async def vacuum(self):
        while (
            use := await self.database.storage_use()
        ) is not None and use > self.config.s3_max_bytes:
            oldest = await self.database.oldest()
            # If use is not None, there must be at least 1 paste, so we let the
            # type checker know that this is an unreachable case.
            assert oldest is not None

            req = self.bucket.delete(oldest)
            async with aiohttp.ClientSession() as session:
                async with session.delete(req.url, headers=req.headers) as delete:
                    if delete.status != 200:
                        raise Exception(
                            f"failed to delete {oldest} from bucket with status: {delete.status}"
                        )

            await self.database.delete(oldest)


def main() -> int:
    parser = ArgumentParser()
    parser.add_argument("--host")
    parser.add_argument("--port", type=int)
    parser.add_argument("--path")
    parser.add_argument("--site", required=True)
    parser.add_argument("--content-length-max-bytes", type=int, required=True)
    parser.add_argument("--key-length", type=int, required=True)
    parser.add_argument("--dictionary", type=Path, required=True)
    parser.add_argument("--database", type=Path, required=True)
    parser.add_argument("--endpoint", required=True)
    parser.add_argument("--region", required=True)
    parser.add_argument("--bucket", required=True)
    parser.add_argument("--access-key", required=True)
    parser.add_argument("--secret-key", type=Path, required=True)
    parser.add_argument("--s3-max-bytes", type=int, required=True)
    parser.add_argument("--default-style", default="native")
    parser.add_argument("--line-numbers", action="store_true")
    parser.add_argument("--line-numbers-inline", action="store_true")
    args = parser.parse_args()

    try:
        secret_key = args.secret_key.read_text().strip()
    except Exception as e:
        print(
            f"failed to read secret key from: {str(args.secret_key)}: {e}",
            file=sys.stderr,
        )
        return 1

    try:
        # Drop empty lines so generated keys never contain empty segments.
        dictionary = [
            word for word in args.dictionary.read_text().split("\n") if word
        ]
    except Exception as e:
        print(
            f"failed to open dictionary: {str(args.dictionary)}: {e}",
            file=sys.stderr,
        )
        return 1

    match [args.line_numbers, args.line_numbers_inline]:
        case [True, _]:
            line_numbers: str | bool = "table"
        case [_, True]:
            line_numbers = "inline"
        case [_, _]:
            line_numbers = False

    config = AppConfig(
        args.site,
        args.content_length_max_bytes,
        args.s3_max_bytes,
        args.key_length,
        dictionary,
        args.default_style,
        line_numbers,
    )

    if not args.database.is_file():
        print(f"{args.database} does not exist or is not a file", file=sys.stderr)
        return 1

    try:
        database = Database(args.database)
    except Exception as e:
        print(f"failed to open database: {e}", file=sys.stderr)
        return 1

    bucket = s3.Bucket(
        args.endpoint,
        args.region,
        args.bucket,
        args.access_key,
        secret_key,
    )

    webapp = web.Application()
    app = App(config, database, bucket)
    webapp.add_routes(
        [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)]
    )

    web.run_app(webapp, host=args.host, port=args.port, path=args.path)
    return 0


if __name__ == "__main__":
    sys.exit(main())
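
# A minimal sketch of a local invocation, for reference only. Every value
# below (host, port, site, file paths, bucket coordinates) is a placeholder,
# and the module path "pastebin" is an assumption about how this package is
# laid out; substitute your own deployment's values.
#
#   python -m pastebin \
#       --host 127.0.0.1 --port 8080 \
#       --site paste.example.com \
#       --content-length-max-bytes 1048576 \
#       --s3-max-bytes 1073741824 \
#       --key-length 4 \
#       --dictionary /usr/share/dict/words \
#       --database /var/lib/pastebin/pastes.db \
#       --endpoint s3.example.com --region us-east-1 --bucket pastes \
#       --access-key AKIA... --secret-key /etc/pastebin/secret-key \
#       --line-numbers
#
# The database file must already exist (main() checks for it) and is assumed
# to contain a "pastes" table shaped roughly like the statement below; the
# exact column names are not defined in this file, so treat this as a sketch
# inferred from the queries above:
#
#   create table pastes(key text primary key, datetime text, size integer, syntax text);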