# Copyright (C) 2025 John Turner
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import sys
import asyncio
import sqlite3
import secrets
import threading
import aiohttp
import zstandard

from . import s3
from hashlib import sha256
from argparse import ArgumentParser
from aiohttp import web
from datetime import datetime, UTC
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
from pygments.styles import get_style_by_name


def pygmentize(
    content: str, syntax: Optional[str], style: str, line_numbers: str | bool
) -> str:
    """Render *content* as a complete, syntax-highlighted HTML document.

    Args:
        content: The text to highlight.
        syntax: A Pygments lexer alias. When it is None or not a known alias,
            the lexer is guessed from *content* instead.
        style: A Pygments style name. Unknown names fall back to "default".
        line_numbers: Passed through as HtmlFormatter's ``linenos`` option:
            "table", "inline", or False for no line numbers.

    Returns:
        A standalone HTML page (``full=True``).
    """
    lexer = None
    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            # syntax comes from the query string: log and fall back to guessing.
            print(e, file=sys.stderr)
    if lexer is None:
        lexer = guess_lexer(content)

    try:
        formatter_style = get_style_by_name(style)
    except Exception as e:
        print(e, file=sys.stderr)
        formatter_style = "default"

    # linenos must follow the configured mode so --line-numbers /
    # --line-numbers-inline (and their absence) actually take effect.
    formatter = HtmlFormatter(full=True, style=formatter_style, linenos=line_numbers)
    return highlight(content, lexer, formatter)


def generate_key(words: List[str], length: int) -> str:
    """Return *length* words drawn with ``secrets.choice``, hyphen-joined, lowercased.

    Raises:
        IndexError: if *words* is empty (raised by ``secrets.choice``).
    """
    return "-".join(secrets.choice(words) for _ in range(length)).lower()


@dataclass
class PasteRow:
    """One row of the ``pastes`` table."""

    key: str  # URL key, as produced by generate_key
    date: datetime  # upload time; stored as ISO-8601 text by Database.insert
    size: int  # byte length of the compressed blob stored in the bucket
    syntax: Optional[str]  # lexer alias supplied at upload time, if any


class Database:
    """Async access to the ``pastes`` SQLite table.

    Every statement runs on a worker thread via ``asyncio.to_thread``, so the
    connection must be opened with ``check_same_thread=False``. Because that
    single connection is shared between pool threads, a lock serialises all
    statements.

    This class never creates the table. NOTE(review): the queries assume
    columns named ``key``, ``datetime`` and ``size`` plus a fourth column for
    the syntax (``insert`` writes all four positionally) — confirm against
    wherever the schema is created.
    """

    def __init__(self, connection: sqlite3.Connection):
        self.connection = connection
        self.lock = threading.Lock()

    async def _query(self, sql: str, params: tuple = ()) -> Optional[tuple]:
        """Run one statement on a worker thread and return its first row.

        The statement runs inside the connection's context manager, so a write
        is committed on success and rolled back if it fails. Returns None when
        the statement produces no rows (which includes every write).
        """

        def run() -> Optional[tuple]:
            with self.lock, self.connection:
                return self.connection.execute(sql, params).fetchone()

        return await asyncio.to_thread(run)

    async def insert(self, paste: PasteRow):
        """Insert *paste* as a new row and commit."""
        await self._query(
            "insert into pastes values(?, ?, ?, ?)",
            (paste.key, paste.date.isoformat(), paste.size, paste.syntax),
        )

    async def delete(self, key: str):
        """Delete the row for *key*, if any, and commit."""
        await self._query("delete from pastes where pastes.key=?", (key,))

    async def exists(self, key: str) -> bool:
        """Return True if a paste with *key* is recorded."""
        row = await self._query("select 1 from pastes where pastes.key=?", (key,))
        return row is not None

    async def oldest(self) -> Optional[str]:
        """Return the key that sorts first by the datetime column, or None if empty."""
        row = await self._query(
            "select pastes.key from pastes order by pastes.datetime limit 1"
        )
        return None if row is None else row[0]

    async def storage_use(self) -> Optional[int]:
        """Return the sum of all ``size`` values, or None when there are no rows."""
        # SUM over an empty table yields a single NULL, i.e. (None,).
        row = await self._query("select sum(pastes.size) from pastes")
        return None if row is None else row[0]


@dataclass
class AppConfig:
    """Runtime settings, built from the command line in main()."""

    site: str  # host name used when building paste URLs
    content_length_max_bytes: int  # largest accepted upload body
    s3_max_bytes: int  # vacuum() evicts old pastes while stored sizes exceed this
    key_length: int  # number of dictionary words per key
    dictionary: List[str]  # words that keys are drawn from
    default_style: str  # Pygments style used when ?style= is absent
    line_numbers: str | bool  # HtmlFormatter linenos: "table", "inline" or False


class App:
    """aiohttp request handlers backed by a Database and an S3 bucket."""

    # Upper bound on key regeneration when a generated key is already taken.
    KEY_ATTEMPTS = 16

    def __init__(
        self,
        config: AppConfig,
        database: Database,
        bucket: s3.Bucket,
    ):
        self.database = database
        self.config = config
        self.bucket = bucket

    async def download(self, request: web.Request) -> web.Response:
        """Handle ``GET /paste/{key}``.

        Query parameters:
            raw: if present (any value), return the paste as text/plain.
            syntax: Pygments lexer alias; otherwise the lexer is guessed.
            style: Pygments style name; defaults to ``config.default_style``.

        Raises:
            web.HTTPBadRequest: the route carried no key.
            web.HTTPNotFound: the key is not in the database.
            web.HTTPInternalServerError: fetching, decompressing or decoding
                the stored blob failed.
        """
        key = request.match_info.get("key")
        if key is None:
            raise web.HTTPBadRequest(text="provide a key to fetch")

        if not await self.database.exists(key):
            raise web.HTTPNotFound()

        req = self.bucket.get(key)
        # Only the network call sits in the try: HTTP exceptions raised by this
        # handler must not be swallowed by the broad except below.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(req.url, headers=req.headers) as response:
                    status = response.status
                    data = await response.read() if status == 200 else None
        except Exception as e:
            print(f"failed to get {key} from s3: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        if data is None:
            print(
                f"{self.bucket.endpoint} returned status ({status}) while fetching {key}",
                file=sys.stderr,
            )
            raise web.HTTPInternalServerError()

        try:
            decompressed = await asyncio.to_thread(zstandard.decompress, data)
        except Exception as e:
            print(f"failed to decompress blob {key}: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        try:
            text = decompressed.decode()
        except Exception as e:
            print(f"failed to decode blob: {key}: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        if "raw" in request.query:
            return web.Response(text=text, content_type="text/plain")

        # NOTE(review): the syntax recorded at upload time is not consulted
        # here; only the ?syntax= query value is.
        syntax = request.query.get("syntax")
        style = request.query.get("style", self.config.default_style)
        highlighted = await asyncio.to_thread(
            pygmentize, text, syntax, style, self.config.line_numbers
        )
        return web.Response(text=highlighted, content_type="text/html")

    async def _unused_key(self) -> str:
        """Generate a key that is not yet recorded in the database.

        Reusing a recorded key would overwrite that paste's blob in the bucket.

        Raises:
            RuntimeError: if every one of ``KEY_ATTEMPTS`` candidates is taken.
        """
        for _ in range(self.KEY_ATTEMPTS):
            key = generate_key(self.config.dictionary, self.config.key_length)
            if not await self.database.exists(key):
                return key
        raise RuntimeError(f"no unused key found after {self.KEY_ATTEMPTS} attempts")

    async def upload(self, request: web.Request) -> web.Response:
        """Handle ``POST /paste``: store the request body and return its URL.

        The body must be valid UTF-8 and at most ``config.content_length_max_bytes``
        long. It is zstd-compressed, written to the bucket under a fresh key,
        and recorded in the database along with the optional ``?syntax=``
        value. ``vacuum()`` runs first to evict old pastes.

        Returns:
            A text/plain response containing ``https://{site}/paste/{key}``.

        Raises:
            web.HTTPBadRequest: the body is too large or not valid UTF-8.
            web.HTTPInternalServerError: vacuuming, reading, compression, key
                generation, the bucket PUT or the database insert failed.
        """
        try:
            await self.vacuum()
        except Exception as e:
            print(f"vacuum failed: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        max_bytes = self.config.content_length_max_bytes
        if request.content_length is not None and request.content_length > max_bytes:
            raise web.HTTPBadRequest(text=f"max content length is {max_bytes}")

        try:
            data = await request.read()
        except web.HTTPException:
            # e.g. 413 raised by aiohttp when the body exceeds client_max_size.
            raise
        except Exception as e:
            print(f"failed to read data: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError(text="failed to read data")

        # Content-Length is absent for chunked uploads, so check the real size too.
        if len(data) > max_bytes:
            raise web.HTTPBadRequest(text=f"max content length is {max_bytes}")

        try:
            data.decode()
        except UnicodeError:
            raise web.HTTPBadRequest(
                text="content must be unicode only, no binary data is allowed"
            ) from None

        try:
            compressed = await asyncio.to_thread(zstandard.compress, data)
        except Exception as e:
            print(f"failed to compress data: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        try:
            key = await self._unused_key()
        except Exception as e:
            print(f"failed to generate key: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        req = self.bucket.put(
            key,
            len(compressed),
            "application/octet-stream",
            sha256(compressed).hexdigest(),
        )
        try:
            async with aiohttp.ClientSession() as session:
                async with session.put(
                    req.url, headers=req.headers, data=compressed
                ) as put:
                    status = put.status
        except Exception as e:
            print(f"failed to put {key} to bucket: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        if status != 200:
            print(
                f"failed to put {key} to bucket with status: {status}",
                file=sys.stderr,
            )
            raise web.HTTPInternalServerError()

        try:
            await self.database.insert(
                PasteRow(
                    key, datetime.now(UTC), len(compressed), request.query.get("syntax")
                )
            )
        except Exception as e:
            print(f"failed to insert {key} into database: {e}", file=sys.stderr)
            raise web.HTTPInternalServerError()

        return web.Response(text=f"https://{self.config.site}/paste/{key}")

    async def vacuum(self):
        """Evict the oldest pastes while recorded storage exceeds ``s3_max_bytes``.

        Each eviction deletes the blob from the bucket first, then the row from
        the database.

        Raises:
            Exception: if the bucket answers a DELETE with a non-2xx status.
        """
        while (
            use := await self.database.storage_use()
        ) is not None and use > self.config.s3_max_bytes:
            oldest = await self.database.oldest()
            # A non-None sum implies at least one row; stop defensively otherwise
            # rather than relying on an assert that -O would strip.
            if oldest is None:
                break

            req = self.bucket.delete(oldest)
            async with aiohttp.ClientSession() as session:
                async with session.delete(req.url, headers=req.headers) as delete:
                    # S3 DeleteObject reports success as 204 No Content.
                    if not 200 <= delete.status < 300:
                        raise Exception(
                            f"failed to delete {oldest} from bucket with status: {delete.status}"
                        )

            await self.database.delete(oldest)


def main() -> int:
    """Parse command-line options, wire up the web app and serve until shutdown.

    Returns:
        1 if startup fails (unreadable secret key or dictionary, invalid
        settings, database connection error), otherwise 0 once the server stops.
    """
    parser = ArgumentParser()
    parser.add_argument("--host")
    parser.add_argument("--port", type=int)
    parser.add_argument("--path")
    parser.add_argument("--site", required=True)
    parser.add_argument("--content-length-max-bytes", type=int, required=True)
    parser.add_argument("--key-length", type=int, required=True)
    parser.add_argument("--dictionary", type=Path, required=True)
    parser.add_argument("--database", type=Path, required=True)
    parser.add_argument("--endpoint", required=True)
    parser.add_argument("--region", required=True)
    parser.add_argument("--bucket", required=True)
    parser.add_argument("--access-key", required=True)
    parser.add_argument("--secret-key", type=Path, required=True)
    parser.add_argument("--s3-max-bytes", type=int, required=True)
    parser.add_argument("--default-style", default="native")
    parser.add_argument("--line-numbers", action="store_true")
    parser.add_argument("--line-numbers-inline", action="store_true")
    args = parser.parse_args()

    try:
        secret_key = args.secret_key.read_text().strip()
    except Exception as e:
        print(
            f"failed to read secret key from: {str(args.secret_key)}: {e}",
            file=sys.stderr,
        )
        return 1

    try:
        lines = args.dictionary.read_text().splitlines()
    except Exception as e:
        print(f"failed to open dictionary: {str(args.dictionary)}: {e}", file=sys.stderr)
        return 1
    # Drop blank lines (e.g. the trailing newline) so keys never contain empty
    # words such as "foo--bar".
    dictionary = [word for word in (line.strip() for line in lines) if word]
    if not dictionary:
        print(f"dictionary is empty: {str(args.dictionary)}", file=sys.stderr)
        return 1

    if args.key_length < 1:
        print("--key-length must be at least 1", file=sys.stderr)
        return 1

    # --line-numbers wins when both flags are given.
    line_numbers: str | bool
    if args.line_numbers:
        line_numbers = "table"
    elif args.line_numbers_inline:
        line_numbers = "inline"
    else:
        line_numbers = False

    config = AppConfig(
        args.site,
        args.content_length_max_bytes,
        args.s3_max_bytes,
        args.key_length,
        dictionary,
        args.default_style,
        line_numbers,
    )

    try:
        # Database runs queries on asyncio.to_thread worker threads, so the
        # connection must be usable outside the thread that created it.
        connection = sqlite3.connect(args.database, check_same_thread=False)
    except Exception as e:
        print(f"failed to connect to database: {args.database}: {e}", file=sys.stderr)
        return 1

    database = Database(connection)

    bucket = s3.Bucket(
        args.endpoint,
        args.region,
        args.bucket,
        args.access_key,
        secret_key,
    )

    # aiohttp caps request.read() at 1 MiB by default; align it with the
    # configured limit so uploads up to that size are accepted.
    webapp = web.Application(client_max_size=args.content_length_max_bytes)
    app = App(config, database, bucket)
    webapp.add_routes(
        [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)]
    )
    web.run_app(webapp, host=args.host, port=args.port, path=args.path)

    return 0


if __name__ == "__main__":
    sys.exit(main())