summaryrefslogtreecommitdiff
path: root/src/client.rs
blob: c5c8e1833ac48aefd05218227e4356d94436f097 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use httparse::EMPTY_HEADER;

use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};

use crate::{request::Request, response::Response};

/// Errors returned by [`Client::read_request`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Reading the request head from the underlying reader failed.
    #[error("io error: {0}")]
    Io(#[from] io::Error),

    /// `httparse` rejected the raw request head.
    #[error("http parse error: {0}")]
    Parse(#[from] httparse::Error),

    /// The request method could not be converted into an [`http::Method`].
    #[error("invalid method: {0}")]
    Method(#[from] http::method::InvalidMethod),

    /// Assembling the [`http::Request`] failed, e.g. because of an unusable URI or header.
    #[error("http error: {0}")]
    Http(#[from] http::Error),

    /// The request's HTTP version is not one this client accepts.
    #[error("unsupported version")]
    Version,
}

/// A connection that reads HTTP requests from `reader` and writes responses to `writer`.
///
/// The reader and writer are held separately; the bounds they must satisfy are placed on
/// the `impl` block rather than on the type itself.
pub struct Client<R, W> {
    /// Source of incoming request bytes.
    reader: R,
    /// Sink that serialized responses are written to.
    writer: W,
}

impl<R, W> Client<R, W>
where
    R: AsyncBufRead + Unpin,
    W: AsyncWrite + Unpin,
{
    pub fn new(reader: R, writer: W) -> Self {
        Self { reader, writer }
    }

    pub async fn send_response(&mut self, response: Response) -> io::Result<()> {
        response.to_wire(&mut self.writer).await?;

        self.writer.flush().await?;

        Ok(())
    }

    pub async fn read_request(&mut self) -> Result<Option<Request<Vec<u8>>>, Error> {
        let mut buf = Vec::new();
        let mut line = Vec::new();

        loop {
            line.clear();

            if self.reader.read_until(b'\n', &mut line).await? == 0 {
                return Ok(None);
            }

            if line == b"\r\n" || line.is_empty() {
                break;
            }

            buf.extend_from_slice(&line);
            buf.extend_from_slice(b"\r\n");
        }

        let mut headers = [EMPTY_HEADER; 64];
        let mut parsed = httparse::Request::new(&mut headers);

        parsed.parse(&buf)?;

        let mut builder = http::Request::builder();

        builder = builder.method(http::Method::from_bytes(parsed.method.unwrap().as_bytes())?);
        builder = builder.uri(parsed.path.unwrap());
        builder = builder.version(match parsed.version.unwrap() {
            1 => http::Version::HTTP_11,
            _ => return Err(Error::Version),
        });

        for header in parsed.headers {
            builder = builder.header(header.name, header.value);
        }

        let body: Vec<u8> = Vec::new();

        Ok(Some(Request::new(builder.body(body)?)))
    }
}