use mlua::UserData; use mon::{ Parser, ParserIter, any, ascii_alphanumeric, ascii_whitespace, input::InputIter, one_of, tag, whitespace, }; use strum::Display; use tokio::io::{AsyncBufRead, AsyncBufReadExt}; use std::{collections::HashMap, io, num::ParseIntError, str}; use get::Get; #[derive(Debug, thiserror::Error)] pub enum Error { #[error("io error: {0}")] Io(#[from] io::Error), #[error("parser error")] Parse(usize), #[error("parse int error: {0}")] ParseInt(ParseIntError), #[error("unicode error: {0}")] Unicode(#[from] str::Utf8Error), } #[derive(Debug, Clone, Copy, Display)] pub enum Method { #[strum(to_string = "GET")] Get, #[strum(to_string = "HEAD")] Head, } #[derive(Debug, Clone, Get)] pub struct Path(#[get(method = "inner")] Vec); #[derive(Debug, Clone, Get)] pub struct Request { method: Method, path: Path, headers: HashMap>, } impl Request { pub async fn parse( mut reader: T, buf: &mut Vec, line: &mut Vec, ) -> Result, Error> { buf.clear(); loop { line.clear(); if reader.read_until(b'\n', line).await? == 0 { return Ok(None); } if line == b"\r\n" { break; } buf.extend_from_slice(line); } let (method, path, headers) = match parse().parse_finished(InputIter::new(buf)) { Ok(((method, path), headers)) => (method, Path(path), headers), Err(mon::ParserFinishedError::Err(e) | mon::ParserFinishedError::Unfinished(e)) => { return Err(Error::Parse(e.position())); } }; let headers = headers .into_iter() .map(|(key, value)| { Ok::<(String, Vec), str::Utf8Error>(( str::from_utf8(&key)?.to_lowercase(), value, )) }) .collect::>, _>>()?; Ok(Some(Self { method, path, headers, })) } } impl UserData for Request { fn add_fields>(fields: &mut F) { fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); fields.add_field_method_get("path", |_, this| { Ok(mlua::BString::new(this.path().inner().clone())) }); fields.add_field_method_get("headers", |lua, this| { let table = lua.create_table()?; for (key, value) in this.headers() { table.set(key.clone(), value.clone())?; } Ok(table) }) } } fn method<'a>() -> impl Parser<&'a [u8], Output = Method> { tag(b"GET".as_slice()) .map(|_| Method::Get) .or(tag(b"HEAD".as_slice()).map(|_| Method::Head)) } fn path<'a>() -> impl Parser<&'a [u8], Output = Vec> { any().and_not(whitespace()).repeated().at_least(1) } fn header<'a>() -> impl Parser<&'a [u8], Output = (Vec, Vec)> { let key = ascii_alphanumeric() .followed_by( ascii_alphanumeric() .or(one_of(b"-".iter().copied())) .repeated() .many(), ) .recognize() .map(|output: &[u8]| output.to_vec()); let value = any() .and_not(tag(b"\r\n".as_slice())) .repeated() .at_least(1); key.and(value.preceded_by(tag(b": ".as_slice()))) } #[allow(clippy::type_complexity)] fn parse<'a>() -> impl Parser<&'a [u8], Output = ((Method, Vec), Vec<(Vec, Vec)>)> { method() .followed_by(ascii_whitespace()) .and(path()) .followed_by(ascii_whitespace()) .followed_by(tag(b"HTTP/1.1\r\n".as_slice())) .and( header() .separated_by_with_trailing(tag(b"\r\n".as_slice())) .many(), ) } #[cfg(test)] mod test { use super::*; #[test] fn test_parse_header() { let input = b"Content-Length: 100"; header().parse_finished(InputIter::new(input)).unwrap(); } #[tokio::test] async fn test_parse_get() { let mut buf = Vec::new(); let mut line = Vec::new(); let input = b"GET /path HTTP/1.1\r\nContent-Length: 100\r\n"; match Request::parse(input.as_slice(), &mut buf, &mut line).await { Ok(_) => (), Err(Error::Parse(position)) => { panic!("{}", &str::from_utf8(input).unwrap()[position..]) } Err(e) => panic!("{e}"), } } }