summaryrefslogtreecommitdiff
path: root/tests/testutil.py
blob: d85cf3f8691536ec7cb72b7facb24cdaf54cefe7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# gemato: Test utility functions
# (c) 2017-2023 Michał Górny
# SPDX-License-Identifier: GPL-2.0-or-later

import errno
import functools
import os
import os.path
import random
import stat
import threading

import pytest

from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs


def disallow_writes(path):
    """
    Recursively strip write permission from a directory tree

    Every file and directory found below path has its write bits
    (user, group and other) cleared while the remaining mode bits
    are kept.  Symbolic links are left alone.  Finally, path itself
    is set to mode 0555.
    """
    write_bits = 0o222
    # walk bottom-up so that children are processed before parents
    for parent, subdirs, filenames in os.walk(path, topdown=False):
        for name in filenames + subdirs:
            entry = os.path.join(parent, name)
            mode = os.lstat(entry).st_mode
            if stat.S_ISLNK(mode):
                # os.chmod() would follow the link and alter its target
                continue
            os.chmod(entry, mode & ~write_bits)
    os.chmod(path, 0o555)


class HKPServerRequestHandler(BaseHTTPRequestHandler):
    """
    Minimal HKP keyserver request handler for tests

    Answers ``GET /pks/lookup?op=get&search=0x<KEY>`` requests using
    an in-memory mapping of key identifiers (without the ``0x``
    prefix) to bytes-like key data.  Malformed requests yield 400,
    unknown keys yield 404.
    """

    def __init__(self, keys, *args, **kwargs):
        """
        Store the keys mapping and pass remaining arguments to the base

        keys is the mapping used for lookups; the other arguments are
        the ones the server passes to every request handler.
        """
        # keys must be assigned first: the base class constructor
        # already processes the request (and hence calls do_GET())
        self.keys = keys
        super().__init__(*args, **kwargs)

    def log_message(self, *args, **kwargs):
        """Discard all log output to keep the test output clean"""
        pass

    def _requested_key(self):
        """
        Return the key identifier requested via self.path

        Returns the search term with the leading '0x' removed, or None
        if the request is not a well-formed key lookup.
        """
        parsed = urlparse(self.path)
        if parsed.path != '/pks/lookup':
            return None

        qs = parse_qs(parsed.query)
        if qs.get('op') != ['get']:
            return None

        search = qs.get('search', [])
        # exactly one search term is accepted
        if len(search) != 1:
            return None

        term = search[0]
        if not term.startswith('0x'):
            return None
        return term[2:]

    def do_GET(self):
        """Serve the requested key, or an HTTP error response"""
        # explicit checks rather than assert statements, so that
        # validation still takes place when running with python -O
        key = self._requested_key()
        if key is None:
            self.send_error(400, "Bad request")
            return

        if key not in self.keys:
            self.send_error(404, "Not found")
            return

        self.send_response(200, "OK")
        self.send_header("Content-type", "application/pgp-keys")
        self.end_headers()
        # note: technically we should be using ASCII armor here
        # but GnuPG seems happy with the binary form too
        self.wfile.write(self.keys[key])
        self.wfile.flush()


class HKPServer:
    """
    Local HKP keyserver running in a background thread

    Keys are served from the keys dict, which is shared with every
    request handler.  While the server is running, addr holds its
    hkp:// URL; otherwise it is None.
    """

    def __init__(self):
        """Create a stopped server with no keys"""
        self.keys = {}
        self.addr = None

    def start(self):
        """
        Bind to a random port on 127.0.0.1 and start serving

        Up to 10 randomly selected ports are tried.  If binding fails
        with an error other than "address in use", or if none of the
        ports is free, the test is skipped via pytest.skip().
        """
        handler_factory = functools.partial(HKPServerRequestHandler,
                                            self.keys)
        candidates = random.sample(range(1024, 32768), 10)
        for port in candidates:
            try:
                self.server = HTTPServer(('127.0.0.1', port),
                                         handler_factory)
            except OSError as e:
                if e.errno != errno.EADDRINUSE:
                    pytest.skip(
                        f'Unable to bind the HKP server: {e}')
                # port already taken, try the next candidate
                continue
            break
        else:
            pytest.skip('Unable to find a free port for HKP server')

        self.addr = f'hkp://127.0.0.1:{port}'
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.start()

    def stop(self):
        """
        Shut the server down and wait for its thread to finish

        Must only be called after a successful start().
        """
        assert self.addr is not None
        httpd = self.server
        httpd.shutdown()
        httpd.server_close()
        self.thread.join()
        self.addr = None