summaryrefslogtreecommitdiff
path: root/paste
diff options
context:
space:
mode:
Diffstat (limited to 'paste')
-rw-r--r--paste/__main__.py407
-rw-r--r--paste/s3.py104
2 files changed, 511 insertions, 0 deletions
diff --git a/paste/__main__.py b/paste/__main__.py
new file mode 100644
index 0000000..fb1b413
--- /dev/null
+++ b/paste/__main__.py
@@ -0,0 +1,407 @@
+# Copyright (C) 2025 John Turner
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import sys
+import asyncio
+import sqlite3
+import secrets
+import aiohttp
+import zstandard
+import s3 as s3
+from hashlib import sha256
+from argparse import ArgumentParser
+from aiohttp import web
+from datetime import datetime, UTC
+from pathlib import Path
+from dataclasses import dataclass
+from typing import Optional, List
+from pygments import highlight
+from pygments.lexers import guess_lexer, get_lexer_by_name
+from pygments.formatters import HtmlFormatter
+from pygments.styles import get_style_by_name
+
+
def pygmentize(
    content: str, syntax: Optional[str], style: str, line_numbers: str | bool
) -> str:
    """Render *content* as a complete highlighted HTML document with pygments.

    Falls back to lexer guessing when *syntax* is None or unknown, and to
    the "default" style when *style* is unknown.  *line_numbers* is passed
    straight through to HtmlFormatter's ``linenos`` option ("table",
    "inline", or False).
    """
    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            print(e, file=sys.stderr)
            lexer = guess_lexer(content)
    else:
        lexer = guess_lexer(content)

    # Use a separate name instead of shadowing the `style` str parameter
    # with a pygments Style class.
    try:
        resolved_style = get_style_by_name(style)
    except Exception as e:
        print(e, file=sys.stderr)
        resolved_style = "default"

    # Honor the caller's line-number preference; the original hard-coded
    # linenos="table" and silently ignored the parameter.
    formatter = HtmlFormatter(full=True, style=resolved_style, linenos=line_numbers)

    return highlight(content, lexer, formatter)
+
+
def generate_key(words: List[str], length: int) -> str:
    """Return a lowercase key of *length* randomly chosen words joined by '-'."""
    picked = [secrets.choice(words) for _ in range(length)]
    return "-".join(picked).lower()
+
+
@dataclass
class PasteRow:
    """One row of the ``pastes`` table (see Database.insert for column order)."""

    key: str  # randomly generated paste identifier, used in URLs and as S3 key
    date: datetime  # creation time; stored as ISO-8601 text by Database.insert
    size: int  # compressed size in bytes (used for the storage budget)
    syntax: Optional[str]  # requested pygments lexer name, or None to guess
+
+
+class Database:
+
+ def __init__(self, connection: sqlite3.Connection):
+ self.connection = connection
+
+ async def insert(self, paste: PasteRow):
+ def do_insert():
+ self.connection.cursor().execute(
+ "insert into pastes values(?, ?, ?, ?)",
+ (
+ paste.key,
+ paste.date.isoformat(),
+ paste.size,
+ paste.syntax,
+ ),
+ )
+
+ self.connection.commit()
+
+ await asyncio.to_thread(do_insert)
+
+ async def delete(self, key: str):
+ def do_delete():
+ self.connection.cursor().execute("delete pastes where pastes.key=?", (key,))
+
+ self.connection.commit()
+
+ await asyncio.to_thread(do_delete)
+
+ async def exists(self, key: str) -> bool:
+ def do_exists():
+ return (
+ self.connection.cursor()
+ .execute("select * from pastes where pastes.key=?", (key,))
+ .fetchone()
+ is not None
+ )
+
+ return await asyncio.to_thread(do_exists())
+
+ async def oldest(self) -> Optional[str]:
+ def do_oldest():
+ return (
+ self.connection.cursor()
+ .execute(
+ "select pastes.key from pastes order by pastes.datetime limit 1"
+ )
+ .fetchone()
+ )
+
+ result = await asyncio.to_thread(do_oldest)
+
+ if result is not None:
+ return result[0]
+ else:
+ return None
+
+ async def storage_use(self) -> Optional[int]:
+ def do_storage_use():
+ return (
+ self.connection.cursor()
+ .execute("select sum(pastes.size) from pastes")
+ .fetchone()
+ )
+
+ result = await asyncio.to_thread(do_storage_use)
+
+ if result is not None:
+ return result[0]
+ else:
+ return None
+
+
@dataclass
class AppConfig:
    """Runtime configuration shared by all request handlers."""

    site: str  # public hostname used when building the returned paste URL
    content_length_max_bytes: int  # maximum accepted upload size in bytes
    s3_max_bytes: int  # total storage budget before old pastes are evicted
    key_length: int  # number of dictionary words joined into a paste key
    dictionary: List[str]  # word list used to generate paste keys
    default_style: str  # pygments style used when the request names none
    line_numbers: str | bool  # pygments linenos setting: "table", "inline", or False
+
+
+class App:
+
+ def __init__(
+ self,
+ config: AppConfig,
+ database: Database,
+ bucket: s3.Bucket,
+ ):
+ self.database = database
+ self.config = config
+ self.bucket = bucket
+
+ async def download(self, request: web.Request) -> web.Response:
+ try:
+ key = request.match_info["key"]
+ except KeyError:
+ return web.HTTPBadRequest(text="provide a key to fetch")
+
+ if not await self.database.exists(key):
+ return web.HTTPNotFound()
+
+ req = self.bucket.get(key)
+
+ try:
+ async with aiohttp.ClientSession().get(req.url, headers=req.headers) as get:
+ if get.status == 200:
+ data = await get.read()
+ else:
+ print(
+ f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}"
+ )
+ return web.HTTPInternalServerError()
+ except Exception as e:
+ print(f"failed to get {key} from s3: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError()
+
+ def decompress():
+ return zstandard.decompress(data)
+
+ try:
+ decompressed = await asyncio.to_thread(decompress)
+ except Exception as e:
+ print(f"failed to decompress blob {key}: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError()
+
+ try:
+ text = decompressed.decode()
+ except Exception as e:
+ print(f"failed to decode blob: {key}: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError()
+
+ syntax = request.query.get("syntax")
+ raw = request.query.get("raw")
+
+ if (style := request.query.get("style")) is None:
+ style = self.config.default_style
+
+ if raw is not None:
+ return web.HTTPOk(text=text, content_type="text/plain")
+ else:
+
+ def render():
+ return pygmentize(text, syntax, style, self.config.line_numbers)
+
+ highlighted = await asyncio.to_thread(render)
+
+ return web.HTTPOk(text=highlighted, content_type="text/html")
+
+ async def upload(self, request: web.Request) -> web.Response:
+ try:
+ await self.vacuum()
+ except Exception as e:
+ print(f"vacuum failed: {e}")
+ return web.HTTPInternalServerError()
+
+ match request.content_length:
+ case int(i) if i > self.config.content_length_max_bytes:
+ return web.HTTPBadRequest(
+ text=f"max content length is {self.config.content_length_max_bytes}"
+ )
+ case _:
+ pass
+
+ try:
+ data = await request.read()
+ except Exception as e:
+ print(f"failed to read data: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError(text="failed to read data")
+
+ try:
+ data.decode()
+ except UnicodeError:
+ return web.HTTPBadRequest(
+ text="content must be unicode only, no binary data is allowed"
+ )
+
+ def compress():
+ return zstandard.compress(data)
+
+ try:
+ compressed = await asyncio.to_thread(compress)
+ except Exception as e:
+ print(f"failed to compress data: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError()
+
+ key = generate_key(self.config.dictionary, self.config.key_length)
+
+ req = self.bucket.put(
+ key,
+ len(compressed),
+ "application/octet-stream",
+ sha256(compressed).hexdigest(),
+ )
+
+ try:
+ async with aiohttp.ClientSession().put(
+ req.url, headers=req.headers, data=compressed
+ ) as put:
+ if put.status != 200:
+ print(f"failed to put {key} to bucket with status: {put.status}")
+ return web.HTTPInternalServerError()
+ except Exception as e:
+ print(f"failed to put {key} to bucket: {e}")
+ return web.HTTPInternalServerError()
+
+ try:
+ await self.database.insert(
+ PasteRow(
+ key, datetime.now(UTC), len(compressed), request.query.get("syntax")
+ )
+ )
+ except Exception as e:
+ print(f"failed to insert {key} into database: {e}", file=sys.stderr)
+ return web.HTTPInternalServerError()
+
+ return web.HTTPOk(text=f"https://{self.config.site}/paste/{key}")
+
+ async def vacuum(self):
+ while (
+ use := self.database.storage_use()
+ ) is not None and use > self.config.s3_max_bytes:
+ oldest = self.database.oldest()
+ # If use is not None, there must be at least 1 paste, so we let the
+ # type checker know that this is an unreachable case.
+ assert oldest is not None
+
+ req = self.bucket.delete(oldest)
+
+ async with aiohttp.ClientSession().delete(
+ req.url, headers=req.headers
+ ) as delete:
+ if delete.status != 200:
+ raise Exception(
+ f"failed to delete {oldest} from bucket with status: {delete.status}"
+ )
+
+ self.database.delete(oldest)
+
+
+def main() -> int:
+ parser = ArgumentParser()
+ parser.add_argument("--host")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--path")
+ parser.add_argument("--site", required=True)
+ parser.add_argument("--content-length-max-bytes", type=int, required=True)
+ parser.add_argument("--key-length", type=int, required=True)
+ parser.add_argument("--dictionary", type=Path, required=True)
+ parser.add_argument("--database", type=Path, required=True)
+ parser.add_argument("--endpoint", required=True)
+ parser.add_argument("--region", required=True)
+ parser.add_argument("--bucket", required=True)
+ parser.add_argument("--access-key", required=True)
+ parser.add_argument("--secret-key", type=Path, required=True)
+ parser.add_argument("--s3-max-bytes", type=int, required=True)
+ parser.add_argument("--default-style", default="native")
+ parser.add_argument("--line-numbers", action="store_true")
+ parser.add_argument("--line-numbers-inline", action="store_true")
+
+ args = parser.parse_args()
+
+ try:
+ secret_key = args.secret_key.read_text().strip()
+ except Exception as e:
+ print(
+ f"failed to read secret key from: {str(args.secret_key)}: {e}",
+ file=sys.stderr,
+ )
+ return 1
+
+ try:
+ dictionary = args.dictionary.read_text().split("\n")
+ except Exception as e:
+ print(f"failed to open dictionary: {str(args.dictionary)}: {e}")
+ return 1
+
+ match [args.line_numbers, args.line_numbers_inline]:
+ case [True, _]:
+ line_numbers: str | bool = "table"
+ case [_, True]:
+ line_numbers = "inline"
+ case [_, _]:
+ line_numbers = False
+
+ config = AppConfig(
+ args.site,
+ args.content_length_max_bytes,
+ args.s3_max_bytes,
+ args.key_length,
+ dictionary,
+ args.default_style,
+ line_numbers,
+ )
+
+ try:
+ connection = sqlite3.connect(args.database)
+ except Exception as e:
+ print(f"failed to connect to database: {args.database}: {e}")
+ return 1
+
+ try:
+ database = Database(connection)
+ except Exception as e:
+ print(f"failed to open database: {e}", file=sys.stderr)
+ return 1
+
+ bucket = s3.Bucket(
+ args.endpoint,
+ args.region,
+ args.bucket,
+ args.access_key,
+ secret_key,
+ )
+
+ webapp = web.Application()
+
+ app = App(config, database, bucket)
+
+ webapp.add_routes(
+ [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)]
+ )
+
+ web.run_app(webapp, host=args.host, port=args.port, path=args.path)
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/paste/s3.py b/paste/s3.py
new file mode 100644
index 0000000..a9b2b21
--- /dev/null
+++ b/paste/s3.py
@@ -0,0 +1,104 @@
+from datetime import datetime, UTC
+from typing import Dict
+from dataclasses import dataclass
+from bozo4 import s3v4_sign_request, s3v4_datetime_string
+
+
@dataclass
class Request:
    """A fully signed HTTP request, ready to be sent to the S3 endpoint."""

    url: str  # absolute object URL (https://endpoint/bucket/key)
    headers: Dict[str, str]  # signed headers, including Authorization
+
+
@dataclass
class Bucket:
    """Builds AWS Signature V4 signed requests for objects in one S3 bucket.

    Pure request construction: nothing here performs network I/O.
    """

    endpoint: str  # S3-compatible endpoint hostname
    region: str  # signing region
    bucket: str  # bucket name, used as the first path segment
    access_key: str
    secret_key: str

    def _signed_request(
        self,
        method: str,
        key: str,
        payload_hash: str = "UNSIGNED-PAYLOAD",
        extra_headers: Dict[str, str] | None = None,
    ) -> Request:
        """Sign *method* on object *key* and return the resulting Request.

        *extra_headers* are added after the standard Host/X-Amz headers so
        the header set matches what gets signed.
        """
        now = datetime.now(UTC)

        headers = {
            "Host": self.endpoint,
            "X-Amz-Date": s3v4_datetime_string(now),
            "X-Amz-Content-SHA256": payload_hash,
        }
        if extra_headers:
            headers.update(extra_headers)

        auth = s3v4_sign_request(
            endpoint=self.endpoint,
            region=self.region,
            access_key=self.access_key,
            secret_key=self.secret_key,
            request_method=method,
            date=now,
            payload_hash=payload_hash,
            uri=f"/{self.bucket}/{key}",
            parameters={},
            headers=headers,
            service="s3",
        )

        headers["Authorization"] = auth

        return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)

    def get(self, key: str) -> Request:
        """Signed GET for object *key* (unsigned payload)."""
        return self._signed_request("GET", key)

    def put(
        self, key: str, content_length: int, content_type: str, payload_hash: str
    ) -> Request:
        """Signed PUT for object *key* with the given length, type, and SHA-256."""
        return self._signed_request(
            "PUT",
            key,
            payload_hash=payload_hash,
            extra_headers={
                "Content-Length": str(content_length),
                "Content-Type": content_type,
            },
        )

    def delete(self, key: str) -> Request:
        """Signed DELETE for object *key* (unsigned payload)."""
        return self._signed_request("DELETE", key)