summaryrefslogtreecommitdiff
path: root/src/paste.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/paste.py')
-rwxr-xr-xsrc/paste.py371
1 files changed, 371 insertions, 0 deletions
diff --git a/src/paste.py b/src/paste.py
new file mode 100755
index 0000000..baef4fc
--- /dev/null
+++ b/src/paste.py
@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import io
+import secrets
+import sqlite3
+import zstandard
+from datetime import datetime
+from urllib.parse import parse_qs
+from pathlib import Path
+from dataclasses import dataclass
+from typing import Optional, Dict, List, Any
+from boto3.session import Session
+from botocore.client import Config
+from botocore.handlers import set_list_objects_encoding_type_url
+from pygments import highlight
+from pygments.lexers import guess_lexer, get_lexer_by_name
+from pygments.formatters import HtmlFormatter
+
+
@dataclass
class Environment:
    """Typed view of the CGI request variables and ``PASTE_*`` settings.

    The instance is built from a plain mapping (normally a copy of
    ``os.environ``).  ``PATH_INFO``, ``QUERY_STRING`` and ``CONTENT_LENGTH``
    are optional; every other variable read in ``__init__`` is required.

    Raises:
        KeyError: if a required variable is missing.
        ValueError: if a numeric variable is not a valid integer.
    """

    # NOTE: the dataclass-generated repr includes secret_key; avoid logging
    # whole instances.
    request_method: str  # REQUEST_METHOD
    content_length: int  # CONTENT_LENGTH, 0 when unset or empty
    path: Optional[str]  # PATH_INFO, None when unset
    parameters: Dict[str, List[str]]  # parsed QUERY_STRING
    content_length_max_bytes: int  # PASTE_CONTENT_LENGTH_MAX_BYTES
    access_key: str  # PASTE_ACCESS_KEY
    secret_key: str  # PASTE_SECRET_KEY
    endpoint: str  # PASTE_ENDPOINT
    region: str  # PASTE_REGION
    bucket: str  # PASTE_BUCKET_NAME
    signature_version: str  # PASTE_SIGNATURE_VERSION
    key_length: int  # PASTE_KEY_LENGTH, number of words per generated key
    site: str  # PASTE_SITE, prefix of the URL returned after a POST
    dict: Path  # PASTE_DICT, word list file used for key generation
    db_path: Path  # PASTE_DB, sqlite database file
    storage_max_bytes: int  # PASTE_STORAGE_MAX_BYTES

    def __init__(self, environment: Dict[str, str]):
        """Populate the fields from ``environment``.

        Args:
            environment: mapping of environment variable names to values.
        """
        self.request_method = environment["REQUEST_METHOD"]

        # Requests without a body (a plain GET, for instance) may leave
        # CONTENT_LENGTH unset or empty; treat both as a zero-length body
        # instead of failing the whole request.
        raw_content_length = environment.get("CONTENT_LENGTH", "").strip()
        self.content_length = int(raw_content_length) if raw_content_length else 0

        self.path = environment.get("PATH_INFO")

        # parse_qs("") yields an empty dict, matching the "no query" case.
        self.parameters = parse_qs(environment.get("QUERY_STRING", ""))

        self.content_length_max_bytes = int(
            environment["PASTE_CONTENT_LENGTH_MAX_BYTES"]
        )
        self.access_key = environment["PASTE_ACCESS_KEY"]
        self.secret_key = environment["PASTE_SECRET_KEY"]
        self.endpoint = environment["PASTE_ENDPOINT"]
        self.region = environment["PASTE_REGION"]
        self.bucket = environment["PASTE_BUCKET_NAME"]
        self.signature_version = environment["PASTE_SIGNATURE_VERSION"]
        self.db_path = Path(environment["PASTE_DB"])
        self.key_length = int(environment["PASTE_KEY_LENGTH"])
        self.site = environment["PASTE_SITE"]
        self.dict = Path(environment["PASTE_DICT"])
        self.storage_max_bytes = int(environment["PASTE_STORAGE_MAX_BYTES"])
+
+
def connect_to_s3(
    endpoint: str,
    region: str,
    signature_version: str,
    access_key: str,
    secret_key: str,
):
    """Create a boto3 S3 resource bound to a specific endpoint.

    Args:
        endpoint: URL of the S3 (or S3-compatible) service.
        region: region name given to the boto3 session.
        signature_version: request signing scheme passed to the client config.
        access_key: access key id used for authentication.
        secret_key: secret access key used for authentication.

    Returns:
        The ``s3`` service resource created from the session.
    """
    credentials = {
        "aws_access_key_id": access_key,
        "aws_secret_access_key": secret_key,
        "region_name": region,
    }
    s3_session = Session(**credentials)

    # Remove botocore's ListObjects parameter hook before any request is made.
    # NOTE(review): presumably the configured endpoint does not cope with the
    # encoding-type parameter this hook adds — confirm against the provider.
    s3_session.events.unregister(
        "before-parameter-build.s3.ListObjects",
        set_list_objects_encoding_type_url,
    )

    client_config = Config(signature_version=signature_version)
    return s3_session.resource("s3", endpoint_url=endpoint, config=client_config)
+
+
def key_in_db(cursor: sqlite3.Cursor, key: str) -> bool:
    """Tell whether the ``pastes`` table already holds a row for ``key``.

    Args:
        cursor: cursor on the paste database.
        key: paste key to look up.

    Returns:
        True when at least one matching row exists, False otherwise.
    """
    matching_row = cursor.execute(
        "select * from pastes where pastes.key=?", (key,)
    ).fetchone()

    if matching_row is None:
        return False
    return True
+
+
def get_syntax_from_db(cursor: sqlite3.Cursor, key: str) -> Optional[str]:
    """Look up the syntax name stored alongside a paste.

    Args:
        cursor: cursor on the paste database.
        key: paste key to look up.

    Returns:
        The value of the ``syntax`` column for ``key``, or None when the
        paste does not exist or was stored without a syntax.
    """
    row = cursor.execute(
        "select pastes.syntax from pastes where pastes.key=?", (key,)
    ).fetchone()

    if row is None:
        return None

    # fetchone() yields a one-element row; callers expect the bare value.
    return row[0]
+
+
def get_storage_use(cursor: sqlite3.Cursor) -> Optional[int]:
    """Return the summed ``size`` of every recorded paste.

    Args:
        cursor: cursor on the paste database.

    Returns:
        The first column of the aggregate row, or None when no row comes back.
        SQL ``sum`` over an empty table is NULL, so an empty table also
        produces None.
    """
    total_row = cursor.execute("select sum(pastes.size) from pastes").fetchone()
    return None if total_row is None else total_row[0]
+
+
def delete_oldest_paste(cursor: sqlite3.Cursor) -> Optional[str]:
    """Remove the row with the smallest ``datetime`` from ``pastes``.

    The deletion is executed on ``cursor`` but not committed here.

    Args:
        cursor: cursor on the paste database.

    Returns:
        The key of the deleted paste, or None when the table is empty.
    """
    oldest_row = cursor.execute(
        "select pastes.key from pastes order by pastes.datetime limit 1"
    ).fetchone()

    # Nothing stored, so nothing to evict.
    if oldest_row is None:
        return None

    oldest_key = oldest_row[0]
    cursor.execute("delete from pastes where pastes.key=?", (oldest_key,))
    return oldest_key
+
+
def generate_key(words: List[str], length: int) -> str:
    """Build a paste key from randomly chosen dictionary words.

    Args:
        words: candidate words; one is drawn with ``secrets.choice`` per slot.
        length: number of words in the key.

    Returns:
        The chosen words joined with ``-`` and lower-cased.
    """
    picked = [secrets.choice(words) for _ in range(length)]
    return "-".join(picked).lower()
+
+
def pygmentize(content: str, syntax: Optional[str]) -> str:
    """Render ``content`` as a standalone syntax-highlighted HTML page.

    Args:
        content: text of the paste.
        syntax: lexer name to use; None lets Pygments guess from the content.

    Returns:
        The HTML produced by Pygments' ``HtmlFormatter`` with ``full=True``.
    """
    lexer = None

    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            # The requested lexer could not be loaded: log and fall back to guessing.
            print(e, file=sys.stderr)

    if lexer is None:
        lexer = guess_lexer(content)

    html_formatter = HtmlFormatter(full=True, cssclass="source")
    return highlight(content, lexer, html_formatter)
+
+
def _send_status(status: int) -> None:
    """Write the CGI ``Status`` header.

    ``print`` appends a newline after the explicit CRLF, which produces the
    blank line that ends the header section.
    """
    print(f"Status: {status}\r\n")


def _fail(status: int, log_message: str) -> None:
    """Send ``status`` with an empty body and record ``log_message`` on stderr.

    Used for server-side failures so that exception details never end up in
    the response body.
    """
    _send_status(status)
    print(log_message, file=sys.stderr)


def _reject(status: int, body: str) -> None:
    """Send ``status`` followed by ``body`` as text shown to the client."""
    _send_status(status)
    print(body)


def _handle_post(
    environment: "Environment",
    words: List[str],
    s3: Any,
    bucket: Any,
    connection: sqlite3.Connection,
    cursor: sqlite3.Cursor,
) -> None:
    """Store the request body as a new paste and print its URL.

    Steps: enforce the per-paste size limit, evict the oldest pastes until the
    new one fits in the storage budget, pick an unused key, validate that the
    body decodes as text, then record it in the database and upload the
    zstd-compressed bytes to the bucket.
    """
    if environment.content_length > environment.content_length_max_bytes:
        _reject(
            400,
            f"content length exceeded maximum length of {environment.content_length_max_bytes}",
        )
        return

    # Evict the oldest pastes, one at a time, until the new one fits.
    while True:
        storage_use = get_storage_use(cursor)
        if storage_use is None:
            break
        if storage_use + environment.content_length <= environment.storage_max_bytes:
            break

        evicted = delete_oldest_paste(cursor)
        if evicted is None:
            break

        # Remove the object first; the row deletion is only committed once
        # the bucket no longer holds the data.
        try:
            s3.Object(environment.bucket, evicted).delete()
        except Exception as e:
            _fail(500, f"failed to delete object: {evicted}: {e}")
            return

        try:
            connection.commit()
        except Exception as e:
            _fail(500, f"failed to commit to database: {e}")
            return

        print(
            f"rotated {evicted} out to make room for new paste",
            file=sys.stderr,
        )

    syntax_values = environment.parameters.get("syntax")
    syntax = syntax_values[0] if syntax_values else None

    # Regenerate until the key does not collide with an existing paste.
    try:
        key = generate_key(words, environment.key_length)
        while key_in_db(cursor, key):
            key = generate_key(words, environment.key_length)
    except Exception as e:
        _fail(500, f"failed to query database: {e}")
        return

    content = sys.stdin.buffer.read(environment.content_length)

    # Only text pastes are accepted; the decoded value itself is not needed.
    try:
        content.decode()
    except Exception as e:
        _reject(400, f"failed to decode content: {e}")
        return

    compressed = zstandard.compress(content)

    try:
        cursor.execute(
            "insert into pastes values (?, ?, ?, ?)",
            (
                key,
                len(compressed),
                datetime.now().isoformat(),
                syntax,
            ),
        )
    except Exception as e:
        _fail(500, f"failed to insert into database: {e}")
        return

    # Upload before committing so a failed upload leaves no committed row.
    try:
        bucket.upload_fileobj(io.BytesIO(compressed), key)
    except Exception as e:
        _fail(500, f"failed to upload paste: {e}")
        return

    try:
        connection.commit()
    except Exception as e:
        _fail(500, f"failed to commit to db: {e}")
        return

    _send_status(200)
    print(f"{environment.site}/{key}")


def _handle_get(environment: "Environment", bucket: Any, cursor: sqlite3.Cursor) -> None:
    """Fetch a paste by the key in the request path and print it.

    The paste is printed verbatim when the ``raw`` query parameter is present
    with a non-blank value, otherwise as highlighted HTML.  A ``syntax`` query
    parameter overrides the syntax stored with the paste.
    """
    if environment.path is None:
        _reject(400, "please provide a key to fetch")
        return

    # Strip the leading "/" from the path.
    key = environment.path[1:]

    try:
        exists = key_in_db(cursor, key)
    except Exception as e:
        _fail(500, f"failed to query db: {e}")
        return

    if not exists:
        _send_status(404)
        return

    buffer = io.BytesIO()
    try:
        bucket.download_fileobj(key, buffer)
    except Exception as e:
        _fail(500, f"failed to download from bucket: {e}")
        return

    # A corrupt or non-text object is reported as a server error instead of
    # an unhandled traceback.
    try:
        content = zstandard.decompress(buffer.getvalue()).decode()
    except Exception as e:
        _fail(500, f"failed to decode stored paste: {key}: {e}")
        return

    syntax_values = environment.parameters.get("syntax")
    if syntax_values:
        syntax = syntax_values[0]
    else:
        try:
            syntax = get_syntax_from_db(cursor, key)
        except Exception as e:
            _fail(500, f"failed to query database: {e}")
            return
        # The lookup may hand back the fetched row rather than the bare
        # column value; unwrap it so the lexer lookup receives a string.
        if isinstance(syntax, tuple):
            syntax = syntax[0]

    if environment.parameters.get("raw"):
        _send_status(200)
        print(content)
        return

    highlighted = pygmentize(content, syntax)
    _send_status(200)
    print(highlighted)


def main():
    """CGI entry point: serve one paste request described by ``os.environ``.

    ``POST`` stores the request body and prints the URL of the new paste;
    ``GET`` returns the paste named by the request path.  Client errors are
    reported in the response body, server-side failures are logged to stderr
    and answered with a bare status.
    """
    try:
        environment = Environment(dict(os.environ))
    except Exception as e:
        _fail(500, f"failed to load environment: {e}")
        return

    if environment.request_method not in ("GET", "POST"):
        _reject(400, f"unsupported request method: {environment.request_method}")
        return

    try:
        dictionary_text = environment.dict.read_text()
    except Exception as e:
        _fail(500, f"failed to load dictionary: {e}")
        return

    # Skip blank lines (a trailing newline would otherwise contribute an empty
    # word and produce keys with empty segments).
    stripped_lines = (line.strip() for line in dictionary_text.splitlines())
    words = [word for word in stripped_lines if word]
    if not words:
        _fail(500, f"failed to load dictionary: {environment.dict} contains no words")
        return

    try:
        s3 = connect_to_s3(
            environment.endpoint,
            environment.region,
            environment.signature_version,
            environment.access_key,
            environment.secret_key,
        )
    except Exception as e:
        _fail(500, f"failed to connect to s3: {e}")
        return

    try:
        bucket = s3.Bucket(environment.bucket)
    except Exception as e:
        _fail(500, f"failed to connect to bucket: {e}")
        return

    try:
        connection = sqlite3.connect(environment.db_path)
        cursor = connection.cursor()
    except Exception as e:
        _fail(500, f"failed to connect to database: {e}")
        return

    try:
        if environment.request_method == "POST":
            _handle_post(environment, words, s3, bucket, connection, cursor)
        else:
            _handle_get(environment, bucket, cursor)
    finally:
        # Closing without a commit discards any uncommitted changes.
        connection.close()
+
+
# Run the handler only when executed as a script, so importing this module
# does not process a request as a side effect.
if __name__ == "__main__":
    main()