diff options
Diffstat (limited to 'src/request.rs')
| -rw-r--r-- | src/request.rs | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 0000000..99b7482 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,183 @@ +use mlua::UserData; +use mon::{ + Parser, ParserIter, any, ascii_alphanumeric, ascii_whitespace, input::InputIter, one_of, tag, + whitespace, +}; + +use strum::Display; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +use std::{collections::HashMap, io, num::ParseIntError, str}; + +use get::Get; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("io error: {0}")] + Io(#[from] io::Error), + + #[error("parser error")] + Parse(usize), + + #[error("parse int error: {0}")] + ParseInt(ParseIntError), + + #[error("unicode error: {0}")] + Unicode(#[from] str::Utf8Error), +} + +#[derive(Debug, Clone, Copy, Display)] +pub enum Method { + #[strum(to_string = "GET")] + Get, + + #[strum(to_string = "HEAD")] + Head, +} + +#[derive(Debug, Clone, Get)] +pub struct Path(#[get(method = "inner")] Vec<u8>); + +#[derive(Debug, Clone, Get)] +pub struct Request { + method: Method, + path: Path, + headers: HashMap<String, Vec<u8>>, +} + +impl Request { + pub async fn parse<T: AsyncBufRead + Unpin>( + mut reader: T, + buf: &mut Vec<u8>, + line: &mut Vec<u8>, + ) -> Result<Option<Self>, Error> { + buf.clear(); + + loop { + line.clear(); + + if reader.read_until(b'\n', line).await? == 0 { + return Ok(None); + } + + if line == b"\r\n" { + break; + } + + buf.extend_from_slice(line); + } + + let (method, path, headers) = match parse().parse_finished(InputIter::new(buf)) { + Ok(((method, path), headers)) => (method, Path(path), headers), + Err(mon::ParserFinishedError::Err(e) | mon::ParserFinishedError::Unfinished(e)) => { + return Err(Error::Parse(e.position())); + } + }; + + let headers = headers + .into_iter() + .map(|(key, value)| { + Ok::<(String, Vec<u8>), str::Utf8Error>(( + str::from_utf8(&key)?.to_lowercase(), + value, + )) + }) + .collect::<Result<HashMap<String, Vec<u8>>, _>>()?; + + Ok(Some(Self { + method, + path, + headers, + })) + } +} + +impl UserData for Request { + fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) { + fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); + + fields.add_field_method_get("path", |_, this| { + Ok(mlua::BString::new(this.path().inner().clone())) + }); + + fields.add_field_method_get("headers", |lua, this| { + let table = lua.create_table()?; + + for (key, value) in this.headers() { + table.set(key.clone(), value.clone())?; + } + + Ok(table) + }) + } +} + +fn method<'a>() -> impl Parser<&'a [u8], Output = Method> { + tag(b"GET".as_slice()) + .map(|_| Method::Get) + .or(tag(b"HEAD".as_slice()).map(|_| Method::Head)) +} + +fn path<'a>() -> impl Parser<&'a [u8], Output = Vec<u8>> { + any().and_not(whitespace()).repeated().at_least(1) +} + +fn header<'a>() -> impl Parser<&'a [u8], Output = (Vec<u8>, Vec<u8>)> { + let key = ascii_alphanumeric() + .followed_by( + ascii_alphanumeric() + .or(one_of(b"-".iter().copied())) + .repeated() + .many(), + ) + .recognize() + .map(|output: &[u8]| output.to_vec()); + + let value = any() + .and_not(tag(b"\r\n".as_slice())) + .repeated() + .at_least(1); + + key.and(value.preceded_by(tag(b": ".as_slice()))) +} + +#[allow(clippy::type_complexity)] +fn parse<'a>() -> impl Parser<&'a [u8], Output = ((Method, Vec<u8>), Vec<(Vec<u8>, Vec<u8>)>)> { + method() + .followed_by(ascii_whitespace()) + .and(path()) + .followed_by(ascii_whitespace()) + .followed_by(tag(b"HTTP/1.1\r\n".as_slice())) + .and( + header() + .separated_by_with_trailing(tag(b"\r\n".as_slice())) + .many(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_header() { + let input = b"Content-Length: 100"; + + header().parse_finished(InputIter::new(input)).unwrap(); + } + + #[tokio::test] + async fn test_parse_get() { + let mut buf = Vec::new(); + let mut line = Vec::new(); + let input = b"GET /path HTTP/1.1\r\nContent-Length: 100\r\n"; + + match Request::parse(input.as_slice(), &mut buf, &mut line).await { + Ok(_) => (), + Err(Error::Parse(position)) => { + panic!("{}", &str::from_utf8(input).unwrap()[position..]) + } + Err(e) => panic!("{e}"), + } + } +} |
