diff options
Diffstat (limited to 'src')
-rwxr-xr-x | src/paste.py | 371 |
1 files changed, 371 insertions, 0 deletions
#!/usr/bin/env python3
"""CGI paste service backed by an S3-compatible bucket and a SQLite index.

``POST`` reads the request body, compresses it with zstandard, uploads it to
the bucket under a freshly generated word key and prints the paste URL.
``GET /<key>`` downloads the paste and prints it, either verbatim (when the
``raw`` query parameter is present) or as a pygments-highlighted HTML page.

Configuration is read from the CGI environment plus the ``PASTE_*``
variables listed on :class:`Environment`.
"""

import io
import os
import secrets
import sqlite3
import sys
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs

# NOTE: boto3/botocore, pygments and zstandard are imported inside the
# functions that use them.  Every request is a fresh CGI process, so a request
# only pays the import cost of what it actually needs, and a broken install
# turns into a logged 500 instead of an import-time traceback.

PLAIN_TEXT = "text/plain; charset=utf-8"
HTML = "text/html; charset=utf-8"


@dataclass
class Environment:
    """Request and service configuration parsed from environment variables."""

    request_method: str  # REQUEST_METHOD
    content_length: int  # CONTENT_LENGTH, 0 when absent or empty
    path: Optional[str]  # PATH_INFO, None when absent
    parameters: Dict[str, List[str]]  # parsed QUERY_STRING
    content_length_max_bytes: int  # PASTE_CONTENT_LENGTH_MAX_BYTES
    access_key: str  # PASTE_ACCESS_KEY
    secret_key: str  # PASTE_SECRET_KEY
    endpoint: str  # PASTE_ENDPOINT
    region: str  # PASTE_REGION
    bucket: str  # PASTE_BUCKET_NAME
    signature_version: str  # PASTE_SIGNATURE_VERSION
    key_length: int  # PASTE_KEY_LENGTH, number of words per key
    site: str  # PASTE_SITE, URL prefix printed after a POST
    dict: Path  # PASTE_DICT, word list used to build keys
    db_path: Path  # PASTE_DB, SQLite database file
    storage_max_bytes: int  # PASTE_STORAGE_MAX_BYTES

    def __init__(self, environment: Dict[str, str]):
        """Parse *environment* (typically ``os.environ``).

        Raises:
            KeyError: a required variable is missing.
            ValueError: a numeric variable is not an integer, the content
                length is negative, or the key length is not positive.
        """
        self.request_method = environment["REQUEST_METHOD"]

        # Servers omit CONTENT_LENGTH (or send it empty) when there is no
        # request body, which is the normal case for GET.
        self.content_length = int(environment.get("CONTENT_LENGTH") or 0)
        if self.content_length < 0:
            # read(-n) would read until EOF and bypass the size limit.
            raise ValueError("CONTENT_LENGTH must not be negative")

        self.path = environment.get("PATH_INFO")

        # keep_blank_values so that a bare "?raw" is seen as present.
        self.parameters = parse_qs(
            environment.get("QUERY_STRING", ""), keep_blank_values=True
        )

        self.content_length_max_bytes = int(
            environment["PASTE_CONTENT_LENGTH_MAX_BYTES"]
        )
        self.access_key = environment["PASTE_ACCESS_KEY"]
        self.secret_key = environment["PASTE_SECRET_KEY"]
        self.endpoint = environment["PASTE_ENDPOINT"]
        self.region = environment["PASTE_REGION"]
        self.bucket = environment["PASTE_BUCKET_NAME"]
        self.signature_version = environment["PASTE_SIGNATURE_VERSION"]
        self.db_path = Path(environment["PASTE_DB"])

        self.key_length = int(environment["PASTE_KEY_LENGTH"])
        if self.key_length <= 0:
            # A zero-word key is the empty string for every paste.
            raise ValueError("PASTE_KEY_LENGTH must be positive")

        self.site = environment["PASTE_SITE"]
        self.dict = Path(environment["PASTE_DICT"])
        self.storage_max_bytes = int(environment["PASTE_STORAGE_MAX_BYTES"])


def connect_to_s3(
    endpoint: str,
    region: str,
    signature_version: str,
    access_key: str,
    secret_key: str,
) -> Any:
    """Return a boto3 S3 resource for the given endpoint and credentials."""
    from boto3.session import Session
    from botocore.client import Config
    from botocore.handlers import set_list_objects_encoding_type_url

    session = Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        region_name=region,
    )

    # NOTE(review): presumably the target endpoint rejects the EncodingType
    # parameter this handler adds to ListObjects calls -- confirm.
    session.events.unregister(
        "before-parameter-build.s3.ListObjects", set_list_objects_encoding_type_url
    )

    return session.resource(
        "s3",
        endpoint_url=endpoint,
        config=Config(signature_version=signature_version),
    )


def key_in_db(cursor: sqlite3.Cursor, key: str) -> bool:
    """Return whether a paste with *key* is recorded in the database."""
    results = cursor.execute(
        "select 1 from pastes where pastes.key=? limit 1", (key,)
    )
    return results.fetchone() is not None


def get_syntax_from_db(cursor: sqlite3.Cursor, key: str) -> Optional[str]:
    """Return the syntax stored for *key*, or None if unset or unknown key."""
    results = cursor.execute(
        "select pastes.syntax from pastes where pastes.key=?", (key,)
    )
    row = results.fetchone()
    # fetchone() yields a 1-tuple; callers want the column value itself.
    return row[0] if row is not None else None


def get_storage_use(cursor: sqlite3.Cursor) -> Optional[int]:
    """Return the summed size of all pastes, or None when there are none."""
    results = cursor.execute("select sum(pastes.size) from pastes")
    row = results.fetchone()
    return row[0] if row is not None else None


def delete_oldest_paste(cursor: sqlite3.Cursor) -> Optional[str]:
    """Delete the row with the smallest datetime and return its key.

    Returns None when the table is empty.  The deletion is not committed.
    """
    results = cursor.execute(
        "select pastes.key from pastes order by pastes.datetime limit 1"
    )
    row = results.fetchone()
    if row is None:
        return None

    key = row[0]
    cursor.execute("delete from pastes where pastes.key=?", (key,))
    return key


def load_words(path: Path) -> List[str]:
    """Read the word list at *path*, one word per line, skipping blanks.

    Raises:
        OSError: the file cannot be read.
        ValueError: the file contains no words.
    """
    words = [
        word for line in path.read_text().splitlines() if (word := line.strip())
    ]
    if not words:
        raise ValueError(f"{path} contains no words")
    return words


def generate_key(words: List[str], length: int) -> str:
    """Return *length* randomly chosen words, lowercased and joined by '-'."""
    return "-".join(secrets.choice(words) for _ in range(length)).lower()


def pygmentize(content: str, syntax: Optional[str]) -> str:
    """Render *content* as a standalone highlighted HTML document.

    *syntax* is a pygments lexer name; when it is None or no lexer can be
    obtained for it, the lexer is guessed from the content.
    """
    from pygments import highlight
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import get_lexer_by_name, guess_lexer

    lexer = None
    if syntax is not None:
        try:
            lexer = get_lexer_by_name(syntax)
        except Exception as e:
            # Best effort: an unusable name falls back to guessing below.
            print(e, file=sys.stderr)
    if lexer is None:
        lexer = guess_lexer(content)

    formatter = HtmlFormatter(full=True, cssclass="source")
    return highlight(content, lexer, formatter)


def _log(message: str) -> None:
    """Write a diagnostic to stderr, never to the client."""
    print(message, file=sys.stderr)


def _respond(
    status: int, body: Optional[str] = None, content_type: str = PLAIN_TEXT
) -> None:
    """Write a complete CGI response (headers, blank line, body) to stdout.

    Bytes are written directly so the output does not depend on the locale
    encoding of the CGI process.  A newline is appended to the body.
    """
    out = sys.stdout.buffer
    headers = f"Status: {status}\r\nContent-Type: {content_type}\r\n\r\n"
    out.write(headers.encode("ascii"))
    if body is not None:
        out.write(body.encode("utf-8"))
        out.write(b"\n")
    out.flush()


def _first_parameter(parameters: Dict[str, List[str]], name: str) -> Optional[str]:
    """Return the first non-empty value of query parameter *name*, or None."""
    values = parameters.get(name)
    if values and values[0]:
        return values[0]
    return None


def _compress(data: bytes) -> bytes:
    """Compress *data* with zstandard."""
    import zstandard

    return zstandard.compress(data)


def _decompress(data: bytes) -> bytes:
    """Decompress zstandard-compressed *data*."""
    import zstandard

    return zstandard.decompress(data)


def _make_room(
    environment: Environment,
    s3: Any,
    connection: sqlite3.Connection,
    cursor: sqlite3.Cursor,
    incoming_bytes: int,
) -> bool:
    """Evict oldest pastes until *incoming_bytes* fits in the storage budget.

    Returns False after sending a 500 response when an eviction fails.
    """
    while True:
        storage_use = get_storage_use(cursor)
        if (
            storage_use is None
            or storage_use + incoming_bytes <= environment.storage_max_bytes
        ):
            return True

        key = delete_oldest_paste(cursor)
        if key is None:
            return True

        # Delete the object before committing so a failed delete leaves the
        # row in place (the uncommitted row deletion is discarded on close).
        try:
            s3.Object(environment.bucket, key).delete()
        except Exception as e:
            _log(f"failed to delete object: {key}: {e}")
            _respond(500)
            return False

        try:
            connection.commit()
        except Exception as e:
            _log(f"failed to commit to database: {e}")
            _respond(500)
            return False

        _log(f"rotated {key} out to make room for new paste")


def _handle_post(
    environment: Environment,
    s3: Any,
    bucket: Any,
    connection: sqlite3.Connection,
    cursor: sqlite3.Cursor,
) -> None:
    """Store the request body as a new paste and respond with its URL."""
    try:
        words = load_words(environment.dict)
    except Exception as e:
        _log(f"failed to load dictionary: {e}")
        _respond(500)
        return

    # Validate the body before evicting anything, so a rejected upload never
    # costs existing pastes.
    content = sys.stdin.buffer.read(environment.content_length)
    try:
        content.decode()
    except UnicodeDecodeError as e:
        _respond(400, f"failed to decode content: {e}")
        return

    try:
        compressed = _compress(content)
    except Exception as e:
        _log(f"failed to compress content: {e}")
        _respond(500)
        return

    # The size column holds compressed sizes, so budget with the same unit.
    if not _make_room(environment, s3, connection, cursor, len(compressed)):
        return

    syntax = _first_parameter(environment.parameters, "syntax")

    try:
        key = generate_key(words, environment.key_length)
        while key_in_db(cursor, key):
            key = generate_key(words, environment.key_length)
    except Exception as e:
        _log(f"failed to query database: {e}")
        _respond(500)
        return

    try:
        cursor.execute(
            "insert into pastes values (?, ?, ?, ?)",
            (
                key,
                len(compressed),
                # Eviction orders rows by this ISO-8601 string.
                datetime.now().isoformat(),
                syntax,
            ),
        )
    except Exception as e:
        _log(f"failed to insert into database: {e}")
        _respond(500)
        return

    # Upload before committing: if the upload fails the row is never
    # committed and is discarded when the connection closes.
    try:
        bucket.upload_fileobj(io.BytesIO(compressed), key)
    except Exception as e:
        _log(f"failed to upload paste: {e}")
        _respond(500)
        return

    try:
        connection.commit()
    except Exception as e:
        _log(f"failed to commit to db: {e}")
        _respond(500)
        return

    _respond(200, f"{environment.site}/{key}")


def _handle_get(environment: Environment, bucket: Any, cursor: sqlite3.Cursor) -> None:
    """Respond with the paste named by the request path."""
    # strip leading / from path
    key = (environment.path or "").removeprefix("/")

    try:
        exists = key_in_db(cursor, key)
    except Exception as e:
        _log(f"failed to query db: {e}")
        _respond(500)
        return

    if not exists:
        _respond(404)
        return

    buffer = io.BytesIO()
    try:
        bucket.download_fileobj(key, buffer)
    except Exception as e:
        _log(f"failed to download from bucket: {e}")
        _respond(500)
        return

    try:
        content = _decompress(buffer.getvalue()).decode()
    except Exception as e:
        _log(f"failed to decode paste: {key}: {e}")
        _respond(500)
        return

    # Presence of the parameter is what matters, whatever its value.
    if "raw" in environment.parameters:
        _respond(200, content)
        return

    syntax = _first_parameter(environment.parameters, "syntax")
    if syntax is None:
        try:
            syntax = get_syntax_from_db(cursor, key)
        except Exception as e:
            _log(f"failed to query database: {e}")
            _respond(500)
            return

    try:
        highlighted = pygmentize(content, syntax)
    except Exception as e:
        _log(f"failed to highlight paste: {e}")
        _respond(500)
        return

    _respond(200, highlighted, HTML)


def main() -> None:
    """Handle the single CGI request described by the process environment."""
    try:
        environment = Environment(dict(os.environ))
    except Exception as e:
        _log(f"failed to load environment: {e}")
        _respond(500)
        return

    method = environment.request_method

    # Reject bad requests before touching S3 or the database.
    if method not in ("GET", "POST"):
        _respond(400, f"unsupported request method: {method}")
        return

    if (
        method == "POST"
        and environment.content_length > environment.content_length_max_bytes
    ):
        _respond(
            400,
            "content length exceeded maximum length of "
            f"{environment.content_length_max_bytes}",
        )
        return

    if method == "GET" and environment.path is None:
        _respond(400, "please provide a key to fetch")
        return

    try:
        s3 = connect_to_s3(
            environment.endpoint,
            environment.region,
            environment.signature_version,
            environment.access_key,
            environment.secret_key,
        )
    except Exception as e:
        _log(f"failed to connect to s3: {e}")
        _respond(500)
        return

    try:
        bucket = s3.Bucket(environment.bucket)
    except Exception as e:
        _log(f"failed to connect to bucket: {e}")
        _respond(500)
        return

    try:
        connection = sqlite3.connect(environment.db_path)
    except Exception as e:
        _log(f"failed to connect to database: {e}")
        _respond(500)
        return

    with closing(connection):
        try:
            cursor = connection.cursor()
            if method == "POST":
                _handle_post(environment, s3, bucket, connection, cursor)
            else:
                _handle_get(environment, bucket, cursor)
        except Exception as e:
            # Last-resort boundary: handlers respond exactly once, right
            # before returning, so nothing has been sent if we get here.
            _log(f"unhandled error: {e}")
            _respond(500)


if __name__ == "__main__":
    main()