summaryrefslogtreecommitdiff
path: root/src/request.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/request.rs')
-rw-r--r--src/request.rs183
1 files changed, 183 insertions, 0 deletions
diff --git a/src/request.rs b/src/request.rs
new file mode 100644
index 0000000..99b7482
--- /dev/null
+++ b/src/request.rs
@@ -0,0 +1,183 @@
+use mlua::UserData;
+use mon::{
+ Parser, ParserIter, any, ascii_alphanumeric, ascii_whitespace, input::InputIter, one_of, tag,
+ whitespace,
+};
+
+use strum::Display;
+use tokio::io::{AsyncBufRead, AsyncBufReadExt};
+
+use std::{collections::HashMap, io, num::ParseIntError, str};
+
+use get::Get;
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("io error: {0}")]
+ Io(#[from] io::Error),
+
+ #[error("parser error")]
+ Parse(usize),
+
+ #[error("parse int error: {0}")]
+ ParseInt(ParseIntError),
+
+ #[error("unicode error: {0}")]
+ Unicode(#[from] str::Utf8Error),
+}
+
+#[derive(Debug, Clone, Copy, Display)]
+pub enum Method {
+ #[strum(to_string = "GET")]
+ Get,
+
+ #[strum(to_string = "HEAD")]
+ Head,
+}
+
+#[derive(Debug, Clone, Get)]
+pub struct Path(#[get(method = "inner")] Vec<u8>);
+
+#[derive(Debug, Clone, Get)]
+pub struct Request {
+ method: Method,
+ path: Path,
+ headers: HashMap<String, Vec<u8>>,
+}
+
+impl Request {
+ pub async fn parse<T: AsyncBufRead + Unpin>(
+ mut reader: T,
+ buf: &mut Vec<u8>,
+ line: &mut Vec<u8>,
+ ) -> Result<Option<Self>, Error> {
+ buf.clear();
+
+ loop {
+ line.clear();
+
+ if reader.read_until(b'\n', line).await? == 0 {
+ return Ok(None);
+ }
+
+ if line == b"\r\n" {
+ break;
+ }
+
+ buf.extend_from_slice(line);
+ }
+
+ let (method, path, headers) = match parse().parse_finished(InputIter::new(buf)) {
+ Ok(((method, path), headers)) => (method, Path(path), headers),
+ Err(mon::ParserFinishedError::Err(e) | mon::ParserFinishedError::Unfinished(e)) => {
+ return Err(Error::Parse(e.position()));
+ }
+ };
+
+ let headers = headers
+ .into_iter()
+ .map(|(key, value)| {
+ Ok::<(String, Vec<u8>), str::Utf8Error>((
+ str::from_utf8(&key)?.to_lowercase(),
+ value,
+ ))
+ })
+ .collect::<Result<HashMap<String, Vec<u8>>, _>>()?;
+
+ Ok(Some(Self {
+ method,
+ path,
+ headers,
+ }))
+ }
+}
+
+impl UserData for Request {
+ fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
+ fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
+
+ fields.add_field_method_get("path", |_, this| {
+ Ok(mlua::BString::new(this.path().inner().clone()))
+ });
+
+ fields.add_field_method_get("headers", |lua, this| {
+ let table = lua.create_table()?;
+
+ for (key, value) in this.headers() {
+ table.set(key.clone(), value.clone())?;
+ }
+
+ Ok(table)
+ })
+ }
+}
+
+fn method<'a>() -> impl Parser<&'a [u8], Output = Method> {
+ tag(b"GET".as_slice())
+ .map(|_| Method::Get)
+ .or(tag(b"HEAD".as_slice()).map(|_| Method::Head))
+}
+
+fn path<'a>() -> impl Parser<&'a [u8], Output = Vec<u8>> {
+ any().and_not(whitespace()).repeated().at_least(1)
+}
+
+fn header<'a>() -> impl Parser<&'a [u8], Output = (Vec<u8>, Vec<u8>)> {
+ let key = ascii_alphanumeric()
+ .followed_by(
+ ascii_alphanumeric()
+ .or(one_of(b"-".iter().copied()))
+ .repeated()
+ .many(),
+ )
+ .recognize()
+ .map(|output: &[u8]| output.to_vec());
+
+ let value = any()
+ .and_not(tag(b"\r\n".as_slice()))
+ .repeated()
+ .at_least(1);
+
+ key.and(value.preceded_by(tag(b": ".as_slice())))
+}
+
+#[allow(clippy::type_complexity)]
+fn parse<'a>() -> impl Parser<&'a [u8], Output = ((Method, Vec<u8>), Vec<(Vec<u8>, Vec<u8>)>)> {
+ method()
+ .followed_by(ascii_whitespace())
+ .and(path())
+ .followed_by(ascii_whitespace())
+ .followed_by(tag(b"HTTP/1.1\r\n".as_slice()))
+ .and(
+ header()
+ .separated_by_with_trailing(tag(b"\r\n".as_slice()))
+ .many(),
+ )
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_parse_header() {
+ let input = b"Content-Length: 100";
+
+ header().parse_finished(InputIter::new(input)).unwrap();
+ }
+
+ #[tokio::test]
+ async fn test_parse_get() {
+ let mut buf = Vec::new();
+ let mut line = Vec::new();
+ let input = b"GET /path HTTP/1.1\r\nContent-Length: 100\r\n";
+
+ match Request::parse(input.as_slice(), &mut buf, &mut line).await {
+ Ok(_) => (),
+ Err(Error::Parse(position)) => {
+ panic!("{}", &str::from_utf8(input).unwrap()[position..])
+ }
+ Err(e) => panic!("{e}"),
+ }
+ }
+}