Diffstat (limited to 'paste')
-rw-r--r--  paste/__main__.py  385
-rw-r--r--  paste/s3.py        105
2 files changed, 0 insertions, 490 deletions
diff --git a/paste/__main__.py b/paste/__main__.py
deleted file mode 100644
index 21ac508..0000000
--- a/paste/__main__.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright (C) 2025 John Turner
-
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-import sys
-import asyncio
-import secrets
-import aiohttp
-import aiosqlite
-import zstandard
-from . import s3
-from hashlib import sha256
-from argparse import ArgumentParser
-from aiohttp import web
-from datetime import datetime, UTC
-from pathlib import Path
-from dataclasses import dataclass
-from typing import Optional, List
-from pygments import highlight
-from pygments.lexers import guess_lexer, get_lexer_by_name
-from pygments.formatters import HtmlFormatter
-from pygments.styles import get_style_by_name
-
-
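-# Render content as syntax-highlighted HTML. Falls back to guessing the lexer when
-# the requested syntax is unknown and to the "default" style on a bad style name.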
-def pygmentize(
- content: str, syntax: Optional[str], style: str, line_numbers: str | bool
-) -> str:
- if syntax is not None:
- try:
- lexer = get_lexer_by_name(syntax)
- except Exception as e:
- print(e, file=sys.stderr)
- lexer = guess_lexer(content)
- else:
- lexer = guess_lexer(content)
-
- try:
- style = get_style_by_name(style)
- except Exception as e:
- print(e, file=sys.stderr)
- style = "default"
-
-    formatter = HtmlFormatter(full=True, style=style, linenos=line_numbers)
-
- return highlight(content, lexer, formatter)
-
-
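-# Build a paste key by joining `length` randomly chosen dictionary words with dashes.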
-def generate_key(words: List[str], length: int) -> str:
- choices = []
- for _ in range(length):
- choices.append(secrets.choice(words))
-
- return "-".join(word for word in choices).lower()
-
-
-@dataclass
-class PasteRow:
- key: str
- date: datetime
- size: int
- syntax: Optional[str]
-
-
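-# Thin async wrapper around the SQLite table that indexes pastes stored in the bucket.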
-class Database:
-
- def __init__(self, path: Path):
- self.path = path
-
- async def insert(self, paste: PasteRow):
- async with aiosqlite.connect(self.path) as db:
- await db.execute(
- "insert into pastes values(?, ?, ?, ?)",
- (
- paste.key,
- paste.date.isoformat(),
- paste.size,
- paste.syntax,
- ),
- )
-
- await db.commit()
-
- async def delete(self, key: str):
- async with aiosqlite.connect(self.path) as db:
- await db.execute("delete from pastes where key=?", (key,))
-
- await db.commit()
-
- async def exists(self, key: str) -> bool:
- async with aiosqlite.connect(self.path) as db:
-            result = await db.execute("select 1 from pastes where pastes.key=?", (key,))
-
-            # Fetch before the connection is closed at the end of the with-block.
-            return await result.fetchone() is not None
-
- async def oldest(self) -> Optional[str]:
- async with aiosqlite.connect(self.path) as db:
- result = await db.execute(
- "select pastes.key from pastes order by pastes.datetime limit 1",
- )
-
- row = await result.fetchone()
-
-        return row[0] if row is not None else None
-
- async def storage_use(self) -> Optional[int]:
- async with aiosqlite.connect(self.path) as db:
- result = await db.execute("select sum(pastes.size) from pastes")
-
- row = await result.fetchone()
-
- return row[0]
-
-
-@dataclass
-class AppConfig:
- site: str
- content_length_max_bytes: int
- s3_max_bytes: int
- key_length: int
- dictionary: List[str]
- default_style: str
- line_numbers: str | bool
-
-
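-# aiohttp request handlers: pastes are zstandard-compressed, stored in S3, and
-# indexed in SQLite so that old pastes can be evicted when storage fills up.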
-class App:
-
- def __init__(
- self,
- config: AppConfig,
- database: Database,
- bucket: s3.Bucket,
- ):
- self.database = database
- self.config = config
- self.bucket = bucket
-
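-    # GET /paste/{key}: fetch the blob from S3, decompress it, and return it raw or
-    # highlighted with pygments depending on the query parameters.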
- async def download(self, request: web.Request) -> web.Response:
- try:
- key = request.match_info["key"]
- except KeyError:
- return web.HTTPBadRequest(text="provide a key to fetch")
-
- if not await self.database.exists(key):
- return web.HTTPNotFound()
-
- req = self.bucket.get(key)
-
-        try:
-            # Use the session as a context manager so it is closed after the request.
-            async with aiohttp.ClientSession() as session:
-                async with session.get(req.url, headers=req.headers) as get:
-                    if get.status == 200:
-                        data = await get.read()
-                    else:
-                        print(
-                            f"{self.bucket.endpoint} returned status ({get.status}) while fetching {key}",
-                            file=sys.stderr,
-                        )
-                        return web.HTTPInternalServerError()
-        except Exception as e:
-            print(f"failed to get {key} from s3: {e}", file=sys.stderr)
-            return web.HTTPInternalServerError()
-
- def decompress():
- return zstandard.decompress(data)
-
- try:
- decompressed = await asyncio.to_thread(decompress)
- except Exception as e:
- print(f"failed to decompress blob {key}: {e}", file=sys.stderr)
- return web.HTTPInternalServerError()
-
- try:
- text = decompressed.decode()
- except Exception as e:
- print(f"failed to decode blob: {key}: {e}", file=sys.stderr)
- return web.HTTPInternalServerError()
-
- syntax = request.query.get("syntax")
- raw = request.query.get("raw")
-
- if (style := request.query.get("style")) is None:
- style = self.config.default_style
-
- if raw is not None:
- return web.HTTPOk(text=text, content_type="text/plain")
- else:
-
- def render():
- return pygmentize(text, syntax, style, self.config.line_numbers)
-
- highlighted = await asyncio.to_thread(render)
-
- return web.HTTPOk(text=highlighted, content_type="text/html")
-
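-    # POST /paste: reject binary or oversized bodies, compress the text, upload it to
-    # S3 under a freshly generated key, and record the paste in the database.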
- async def upload(self, request: web.Request) -> web.Response:
- try:
- await self.vacuum()
- except Exception as e:
- print(f"vacuum failed: {e}")
- return web.HTTPInternalServerError()
-
- match request.content_length:
- case int(i) if i > self.config.content_length_max_bytes:
- return web.HTTPBadRequest(
- text=f"max content length is {self.config.content_length_max_bytes}"
- )
- case _:
- pass
-
- try:
- data = await request.read()
- except Exception as e:
- print(f"failed to read data: {e}", file=sys.stderr)
- return web.HTTPInternalServerError(text="failed to read data")
-
- try:
- data.decode()
- except UnicodeError:
- return web.HTTPBadRequest(
- text="content must be unicode only, no binary data is allowed"
- )
-
- def compress():
- return zstandard.compress(data)
-
- try:
- compressed = await asyncio.to_thread(compress)
- except Exception as e:
- print(f"failed to compress data: {e}", file=sys.stderr)
- return web.HTTPInternalServerError()
-
- key = generate_key(self.config.dictionary, self.config.key_length)
-
- req = self.bucket.put(
- key,
- len(compressed),
- "application/octet-stream",
- sha256(compressed).hexdigest(),
- )
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.put(
-                    req.url, headers=req.headers, data=compressed
-                ) as put:
-                    if put.status != 200:
-                        print(
-                            f"failed to put {key} to bucket with status: {put.status}",
-                            file=sys.stderr,
-                        )
-                        return web.HTTPInternalServerError()
-        except Exception as e:
-            print(f"failed to put {key} to bucket: {e}", file=sys.stderr)
-            return web.HTTPInternalServerError()
-
- try:
- await self.database.insert(
- PasteRow(
- key, datetime.now(UTC), len(compressed), request.query.get("syntax")
- )
- )
- except Exception as e:
- print(f"failed to insert {key} into database: {e}", file=sys.stderr)
- return web.HTTPInternalServerError()
-
- return web.HTTPOk(text=f"https://{self.config.site}/paste/{key}")
-
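-    # Evict the oldest pastes until total stored size drops below s3_max_bytes.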
- async def vacuum(self):
- while (
- use := await self.database.storage_use()
- ) is not None and use > self.config.s3_max_bytes:
- oldest = await self.database.oldest()
- # If use is not None, there must be at least 1 paste, so we let the
- # type checker know that this is an unreachable case.
- assert oldest is not None
-
- req = self.bucket.delete(oldest)
-
-            async with aiohttp.ClientSession() as session:
-                # S3 DeleteObject responds with 204 No Content on success.
-                async with session.delete(req.url, headers=req.headers) as delete:
-                    if delete.status not in (200, 204):
-                        print(
-                            f"failed to delete {oldest}: got status {delete.status}",
-                            file=sys.stderr,
-                        )
-                    else:
-                        print(f"successfully deleted {oldest}", file=sys.stderr)
-
- await self.database.delete(oldest)
-
-
-def main() -> int:
- parser = ArgumentParser()
- parser.add_argument("--host")
- parser.add_argument("--port", type=int)
- parser.add_argument("--path")
- parser.add_argument("--site", required=True)
- parser.add_argument("--content-length-max-bytes", type=int, required=True)
- parser.add_argument("--key-length", type=int, required=True)
- parser.add_argument("--dictionary", type=Path, required=True)
- parser.add_argument("--database", type=Path, required=True)
- parser.add_argument("--endpoint", required=True)
- parser.add_argument("--region", required=True)
- parser.add_argument("--bucket", required=True)
- parser.add_argument("--access-key", required=True)
- parser.add_argument("--secret-key", type=Path, required=True)
- parser.add_argument("--s3-max-bytes", type=int, required=True)
- parser.add_argument("--default-style", default="native")
- parser.add_argument("--line-numbers", action="store_true")
- parser.add_argument("--line-numbers-inline", action="store_true")
-
- args = parser.parse_args()
-
- try:
- secret_key = args.secret_key.read_text().strip()
- except Exception as e:
- print(
- f"failed to read secret key from: {str(args.secret_key)}: {e}",
- file=sys.stderr,
- )
- return 1
-
-    try:
-        # Drop blank lines so generated keys never contain empty segments.
-        dictionary = [word for word in args.dictionary.read_text().splitlines() if word]
-    except Exception as e:
-        print(f"failed to open dictionary: {str(args.dictionary)}: {e}", file=sys.stderr)
-        return 1
-
- match [args.line_numbers, args.line_numbers_inline]:
- case [True, _]:
- line_numbers: str | bool = "table"
- case [_, True]:
- line_numbers = "inline"
- case [_, _]:
- line_numbers = False
-
- config = AppConfig(
- args.site,
- args.content_length_max_bytes,
- args.s3_max_bytes,
- args.key_length,
- dictionary,
- args.default_style,
- line_numbers,
- )
-
- if not args.database.is_file():
- print(f"{args.database} does not exist or is not a file", file=sys.stderr)
- return 1
-
- try:
- database = Database(args.database)
- except Exception as e:
- print(f"failed to open database: {e}", file=sys.stderr)
- return 1
-
- bucket = s3.Bucket(
- args.endpoint,
- args.region,
- args.bucket,
- args.access_key,
- secret_key,
- )
-
- webapp = web.Application()
-
- app = App(config, database, bucket)
-
- webapp.add_routes(
- [web.get("/paste/{key}", app.download), web.post("/paste", app.upload)]
- )
-
- web.run_app(webapp, host=args.host, port=args.port, path=args.path)
-
- return 0
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/paste/s3.py b/paste/s3.py
deleted file mode 100644
index 673e428..0000000
--- a/paste/s3.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from datetime import datetime, UTC
-from typing import Dict
-from dataclasses import dataclass
-from bozo4 import s3v4_sign_request, s3v4_datetime_string
-
-
-@dataclass
-class Request:
- url: str
- headers: Dict[str, str]
-
-
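-# Builds SigV4-signed requests (URL plus headers) for GET, PUT, and DELETE against a
-# single bucket; callers send them with their own HTTP client.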
-@dataclass
-class Bucket:
- endpoint: str
- region: str
- bucket: str
- access_key: str
- secret_key: str
-
- def get(self, key: str) -> Request:
- now = datetime.now(UTC)
-
- headers = {
- "Host": self.endpoint,
- "X-Amz-Date": s3v4_datetime_string(now),
- "X-Amz-Content-SHA256": "UNSIGNED-PAYLOAD",
- }
-
- auth = s3v4_sign_request(
- endpoint=self.endpoint,
- region=self.region,
- access_key=self.access_key,
- secret_key=self.secret_key,
- request_method="GET",
- date=now,
- payload_hash="UNSIGNED-PAYLOAD",
- uri=f"/{self.bucket}/{key}",
- parameters={},
- headers=headers,
- service="s3",
- )
-
- headers["Authorization"] = auth
-
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)
-
- def put(
- self, key: str, content_length: int, content_type: str, payload_hash: str
- ) -> Request:
- now = datetime.now(UTC)
-
- headers = {
- "Host": self.endpoint,
- "X-Amz-Date": s3v4_datetime_string(now),
- "X-Amz-Content-SHA256": payload_hash,
- "Content-Length": str(content_length),
- "Content-Type": content_type,
- }
-
- auth = s3v4_sign_request(
- endpoint=self.endpoint,
- region=self.region,
- access_key=self.access_key,
- secret_key=self.secret_key,
- request_method="PUT",
- date=now,
- payload_hash=payload_hash,
- uri=f"/{self.bucket}/{key}",
- parameters={},
- headers=headers,
- service="s3",
- )
-
- headers["Authorization"] = auth
-
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)
-
- def delete(self, key: str) -> Request:
- now = datetime.now(UTC)
-
- headers = {
- "Host": self.endpoint,
- "X-Amz-Date": s3v4_datetime_string(now),
- "X-Amz-Content-SHA256": "UNSIGNED-PAYLOAD",
- "Content-Length": "0",
- }
-
- auth = s3v4_sign_request(
- endpoint=self.endpoint,
- region=self.region,
- access_key=self.access_key,
- secret_key=self.secret_key,
- request_method="DELETE",
- date=now,
- payload_hash="UNSIGNED-PAYLOAD",
- uri=f"/{self.bucket}/{key}",
- parameters={},
- headers=headers,
- service="s3",
- )
-
- headers["Authorization"] = auth
-
- return Request(f"https://{self.endpoint}/{self.bucket}/{key}", headers)