path: root/pypaste/server/s3/__init__.py
# Copyright (C) 2025 John Turner

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
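
"""S3-backed storage backend for pypaste.

Paste bodies are zstandard-compressed and stored as objects in an
S3-compatible bucket; paste metadata and the keys owned by this backend
are tracked in SQLite.
"""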

import asyncio
from dataclasses import dataclass
from typing import Optional

import aiosqlite
import zstandard

from pypaste.server import Storage, Paste
from pypaste.server.s3.bucket import Bucket


@dataclass
class S3(Storage):
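    """Paste storage backed by an S3-compatible bucket, with metadata kept in SQLite."""
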
    connection: aiosqlite.Connection
    bucket: Bucket

    def __init__(
        self,
        connection: aiosqlite.Connection,
        endpoint: str,
        region: str,
        bucket: str,
        access_key: str,
        secret_key: str,
    ):
        self.connection = connection
        self.bucket = Bucket(endpoint, region, bucket, access_key, secret_key)

    async def setup(self) -> None:
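        """Create the table tracking which paste keys are stored in this backend."""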
        await self.connection.execute("create table if not exists s3(key text)")
        await self.connection.commit()

    async def insert(self, paste: Paste) -> None:
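        """Compress the paste text, record its metadata, and upload the body to
        the bucket, rolling back the metadata on upload failure."""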
        def compress():
            return zstandard.compress(paste.text.encode())

        compressed = await asyncio.to_thread(compress)

        await self.connection.execute(
            "insert into pastes values(?, ?, ?, ?)",
            (paste.key, paste.dt.isoformat(), len(compressed), paste.syntax),
        )
        # Track the key in the s3 table so vacuum() can find pastes owned by
        # this backend; without this row the joins in vacuum() never match.
        await self.connection.execute("insert into s3 values(?)", (paste.key,))

        try:
            await self.bucket.put(paste.key, compressed)
            await self.connection.commit()
        except Exception:
            await self.connection.rollback()
            raise

    async def retrieve(self, key: str) -> Optional[Paste]:
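        """Return the paste for key, downloading and decompressing its body from
        the bucket, or None if the key is unknown."""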
        if not await self.exists(key):
            return None

        row = await self.read_row(key)

        assert row is not None

        (dt, size, syntax) = row

        data = await self.bucket.get(key)

        assert data is not None

        def decompress() -> str:
            return zstandard.decompress(data).decode()

        text = await asyncio.to_thread(decompress)

        return Paste(key, dt, syntax, text)

    async def delete(self, key: str) -> None:
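        """Delete the paste's metadata and remove its object from the bucket,
        rolling back the metadata change if the bucket delete fails."""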
        await self.connection.execute("delete from pastes where key=?", (key,))

        try:
            await self.bucket.delete(key)
            await self.connection.commit()
        except Exception as e:
            await self.connection.rollback()
            raise e

    async def vacuum(self, max: int) -> None:
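        """Delete the oldest pastes stored in this backend until the total
        compressed size is at most max bytes."""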
        while True:
            async with self.connection.execute(
                (
                    "select sum(pastes.size) from pastes "
                    "inner join s3 on s3.key = pastes.key"
                )
            ) as cursor:
                row = await cursor.fetchone()
                # sum() returns NULL when no pastes are tracked, so guard
                # against both a missing row and a NULL total.
                if row is None or row[0] is None:
                    return
                use = row[0]

            async with self.connection.execute(
                (
                    "select pastes.key from pastes "
                    "inner join s3 on s3.key = pastes.key "
                    "order by pastes.datetime "
                    "limit 1"
                )
            ) as cursor:
                if (row := await cursor.fetchone()) is None:
                    return
                oldest = row[0]

            if use > max:
                await self.delete(oldest)
            else:
                return