From 4a16841789604614bc495c36972236749e5f35b0 Mon Sep 17 00:00:00 2001 From: John Turner Date: Wed, 4 Mar 2026 20:53:24 -0500 Subject: roll our own http types --- src/client.rs | 95 ++++++------------------ src/handler.rs | 38 ++++++++++ src/handler/staticfile.rs | 70 ++++++++++++++++++ src/handlers/mod.rs | 38 ---------- src/handlers/staticfile.rs | 83 --------------------- src/main.rs | 30 +++----- src/request.rs | 178 ++++++++++++++++++++++++++++++++++++++------- src/response.rs | 71 ++++++++++++------ src/response/builder.rs | 75 +++++++++++++++++++ 9 files changed, 419 insertions(+), 259 deletions(-) create mode 100644 src/handler.rs create mode 100644 src/handler/staticfile.rs delete mode 100644 src/handlers/mod.rs delete mode 100644 src/handlers/staticfile.rs create mode 100644 src/response/builder.rs (limited to 'src') diff --git a/src/client.rs b/src/client.rs index c5c8e18..f5ed2f9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,88 +1,39 @@ -use httparse::EMPTY_HEADER; +use tokio::io::{self, AsyncBufRead, AsyncWrite, AsyncWriteExt}; -use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt}; - -use crate::{request::Request, response::Response}; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("unsupported version")] - Version, - - #[error("io error: {0}")] - Io(#[from] io::Error), - - #[error("invalid method: {0}")] - Method(#[from] http::method::InvalidMethod), - - #[error("http error: {0}")] - Http(#[from] http::Error), - - #[error("http parse error: {0}")] - Parse(#[from] httparse::Error), -} +use crate::{ + request::{self, Request}, + response::Response, +}; +#[derive(Debug)] pub struct Client { reader: R, writer: W, + buf: Vec, + line: Vec, } -impl Client -where - R: AsyncBufRead + Unpin, - W: AsyncWrite + Unpin, -{ +impl Client { pub fn new(reader: R, writer: W) -> Self { - Self { reader, writer } + Self { + reader, + writer, + buf: Vec::new(), + line: Vec::new(), + } } +} - pub async fn send_response(&mut self, response: Response) -> io::Result<()> { - response.to_wire(&mut self.writer).await?; - - self.writer.flush().await?; - - Ok(()) +impl Client { + pub async fn read_request(&mut self) -> Result, request::Error> { + Request::parse(&mut self.reader, &mut self.buf, &mut self.line).await } - pub async fn read_request(&mut self) -> Result>>, Error> { - let mut buf = Vec::new(); - let mut line = Vec::new(); + pub async fn send_response(&mut self, response: Response) -> Result<(), io::Error> { + response.serialize(&mut self.writer).await?; - loop { - line.clear(); - - if self.reader.read_until(b'\n', &mut line).await? == 0 { - return Ok(None); - } - - if line == b"\r\n" || line.is_empty() { - break; - } - - buf.extend_from_slice(&line); - buf.extend_from_slice(b"\r\n"); - } - - let mut headers = [EMPTY_HEADER; 64]; - let mut parsed = httparse::Request::new(&mut headers); - - parsed.parse(&buf)?; - - let mut builder = http::Request::builder(); - - builder = builder.method(http::Method::from_bytes(parsed.method.unwrap().as_bytes())?); - builder = builder.uri(parsed.path.unwrap()); - builder = builder.version(match parsed.version.unwrap() { - 1 => http::Version::HTTP_11, - _ => return Err(Error::Version), - }); - - for header in parsed.headers { - builder = builder.header(header.name, header.value); - } - - let body: Vec = Vec::new(); + self.writer.flush().await?; - Ok(Some(Request::new(builder.body(body)?))) + Ok(()) } } diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..45dc879 --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,38 @@ +mod staticfile; + +use mlua::{FromLua, Value}; + +use crate::{handler::staticfile::StaticFile, request::Request, response::Response}; + +pub trait Handle { + fn handle(self, request: Request) -> impl Future>; +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("unsupported method")] + Unsupported, + + #[error("static file handler error: {0}")] + StaticFile(#[from] staticfile::Error), +} + +#[derive(Debug)] +pub enum Handler { + StaticFile(StaticFile), +} + +impl FromLua for Handler { + fn from_lua(value: Value, lua: &mlua::Lua) -> mlua::Result { + match value { + Value::Table(table) => match table.get::("handler")?.as_str() { + "staticfile" => Ok(Self::StaticFile(StaticFile::from_lua( + Value::Table(table.clone()), + lua, + )?)), + _ => Err(mlua::Error::runtime("unknown handler")), + }, + _ => Err(mlua::Error::runtime("expected table")), + } + } +} diff --git a/src/handler/staticfile.rs b/src/handler/staticfile.rs new file mode 100644 index 0000000..871f414 --- /dev/null +++ b/src/handler/staticfile.rs @@ -0,0 +1,70 @@ +use crate::{ + handler::{self, Handle}, + request::{Method, Request}, + response::{Body, Response, Status}, +}; + +use std::{ffi::OsString, os::unix::ffi::OsStringExt, path::PathBuf}; + +use mlua::{FromLua, Value}; + +use tokio::{ + fs::File, + io::{self}, +}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("io error: {0}")] + Io(#[from] io::Error), +} + +#[derive(Debug, Clone)] +pub struct StaticFile { + path: PathBuf, + mime: String, +} + +impl Handle for StaticFile { + async fn handle(self, request: Request) -> Result { + match request.method() { + Method::Get | Method::Head => match File::open(&self.path).await { + Ok(file) => { + let metadata = file.metadata().await.map_err(Error::Io)?; + + if metadata.is_file() { + Ok(Response::builder() + .status(Status::Ok) + .headers([ + ("content-length", format!("{}", metadata.len())), + ("content-type", self.mime), + ]) + .body(Body::File(file))) + } else { + Ok(Response::builder() + .status(Status::NotFound) + .body(Body::Empty)) + } + } + Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Response::builder() + .status(Status::NotFound) + .body(Body::Empty)), + Err(e) => Err(Error::Io(e))?, + }, + } + } +} + +impl FromLua for StaticFile { + fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result { + match value { + Value::Table(table) => Ok(Self { + path: PathBuf::from(OsString::from_vec( + table.get::("path")?.as_bytes().to_vec(), + )), + mime: table.get::("mime")?, + }), + _ => Err(mlua::Error::runtime("expected table")), + } + } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs deleted file mode 100644 index 800b61e..0000000 --- a/src/handlers/mod.rs +++ /dev/null @@ -1,38 +0,0 @@ -use mlua::{FromLua, Value}; - -use crate::{handlers::staticfile::StaticFile, request::Request, response::Response}; - -mod staticfile; - -pub(super) trait Handle { - async fn handle(&self, request: Request) -> Result; -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("unsupported method")] - Unsupported, - - #[error("static file handler error: {0}")] - StaticFile(#[from] staticfile::Error), -} - -#[derive(Debug, Clone)] -pub enum Handlers { - StaticFile(StaticFile), -} - -impl FromLua for Handlers { - fn from_lua(value: Value, lua: &mlua::Lua) -> mlua::Result { - match value { - Value::Table(table) => match table.get::("handler")?.as_str() { - "staticfile" => Ok(Self::StaticFile(StaticFile::from_lua( - Value::Table(table.clone()), - lua, - )?)), - _ => Err(mlua::Error::runtime("unknown handler")), - }, - _ => Err(mlua::Error::runtime("expected table")), - } - } -} diff --git a/src/handlers/staticfile.rs b/src/handlers/staticfile.rs deleted file mode 100644 index a765315..0000000 --- a/src/handlers/staticfile.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{ffi::OsString, os::unix::ffi::OsStringExt, path::PathBuf, time}; - -use mlua::{FromLua, Value}; -use tokio::{ - fs::{self, File}, - io, -}; - -use crate::{ - Handle, handlers, - request::Request, - response::{self, Response}, -}; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("io error: {0}")] - Io(#[from] io::Error), - - #[error("http error: {0}")] - Http(#[from] http::Error), -} - -#[derive(Debug, Clone)] -pub struct StaticFile { - path: PathBuf, - mime: String, -} - -impl Handle for StaticFile { - async fn handle(&self, request: Request) -> Result { - if let http::Method::GET | http::Method::HEAD = request.inner().method().clone() { - if !fs::try_exists(&self.path).await.map_err(Error::Io)? { - return Ok(Response::new( - http::Response::builder() - .status(http::StatusCode::NOT_FOUND) - .body(response::Body::Empty) - .map_err(Error::Http)?, - )); - } - - let file = File::open(&self.path).await.map_err(Error::Io)?; - let metadata = file.metadata().await.map_err(Error::Io)?; - - let now = time::SystemTime::now(); - let date = httpdate::fmt_http_date(now); - - let response = http::Response::builder() - .status(http::StatusCode::OK) - .header("CONTENT-LENGTH", metadata.len()) - .header("CONTENT-TYPE", &self.mime) - .header("DATE", date); - - match request.inner().method().clone() { - http::Method::GET => Ok(Response::new( - response - .body(response::Body::File(file)) - .map_err(Error::Http)?, - )), - http::Method::HEAD => Ok(Response::new( - response.body(response::Body::Empty).map_err(Error::Http)?, - )), - _ => unreachable!(), - } - } else { - Err(handlers::Error::Unsupported) - } - } -} - -impl FromLua for StaticFile { - fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result { - match value { - Value::Table(table) => Ok(Self { - path: PathBuf::from(OsString::from_vec( - table.get::("path")?.as_bytes().to_vec(), - )), - mime: table.get::("mime")?, - }), - _ => Err(mlua::Error::runtime("expected table")), - } - } -} diff --git a/src/main.rs b/src/main.rs index ea459d9..3d69f36 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,13 +11,13 @@ use tokio::{ use crate::{ client::Client, - handlers::{Handle, Handlers}, + handler::{Handle, Handler}, request::Request, - response::Response, + response::{Response, Status}, }; mod client; -mod handlers; +mod handler; mod request; mod response; @@ -40,13 +40,13 @@ enum HandleError { InvokeHandler(mlua::Error), #[error("handler error: {0}")] - Handler(#[from] handlers::Error), + Handler(#[from] handler::Error), } #[derive(Debug, thiserror::Error)] enum ResponseError { #[error("error reading request: {0}")] - Request(client::Error), + Request(request::Error), #[error("error sending response: {0}")] Response(io::Error), @@ -70,22 +70,19 @@ enum InitLuaError { EvalConfig(mlua::Error), } -async fn handle( - handlers: Table, - request: Request, -) -> Result { - let method = request.inner().method().as_str().to_string(); +async fn handle(handlers: Table, request: Request) -> Result { + let method = request.method().to_string(); let function = handlers .get::(method.as_str()) .map_err(HandleError::Function)?; let handler = function - .call::(request.clone()) + .call::(request.clone()) .map_err(HandleError::InvokeHandler)?; match handler { - Handlers::StaticFile(staticfile) => Ok(staticfile.handle(request).await?), + Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?), } } @@ -106,12 +103,9 @@ async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseErro Err(e) => { eprintln!("failed to handle request: {e:?}"); - Response::new( - http::Response::builder() - .status(http::StatusCode::INTERNAL_SERVER_ERROR) - .body(response::Body::Empty) - .unwrap(), - ) + Response::builder() + .status(Status::InternalServerError) + .body(response::Body::Empty) } }; diff --git a/src/request.rs b/src/request.rs index 2814a14..6da46ee 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,55 +1,181 @@ - use mlua::UserData; +use mon::{ + Parser, ParserIter, any, ascii_alphanumeric, ascii_whitespace, input::InputIter, one_of, tag, + whitespace, +}; + +use strum::Display; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +use std::{collections::HashMap, io, num::ParseIntError, str}; -use tokio::io::{self}; +use get::Get; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("error parsing request: {0}")] - Parse(#[from] httparse::Error), - #[error("io error: {0}")] Io(#[from] io::Error), - #[error("invalid method: {0}")] - Method(#[from] http::method::InvalidMethod), + #[error("parser error")] + Parse(usize), - #[error("invalid request: {0}")] - Request(#[from] http::Error), + #[error("parse int error: {0}")] + ParseInt(ParseIntError), - #[error("unsupported version")] - Version, + #[error("unicode error: {0}")] + Unicode(#[from] str::Utf8Error), } -#[derive(Debug, Clone)] -pub struct Request(http::Request); +#[derive(Debug, Clone, Copy, Display)] +pub enum Method { + #[strum(to_string = "GET")] + Get, -impl Request { - pub fn inner(&self) -> &http::Request { - &self.0 - } + #[strum(to_string = "HEAD")] + Head, +} + +#[derive(Debug, Clone, Get)] +pub struct Path(#[get(method = "inner")] Vec); + +#[derive(Debug, Clone, Get)] +pub struct Request { + method: Method, + path: Path, + headers: HashMap>, +} + +impl Request { + pub async fn parse( + mut reader: T, + buf: &mut Vec, + line: &mut Vec, + ) -> Result, Error> { + buf.clear(); + + loop { + line.clear(); + + if reader.read_until(b'\n', line).await? == 0 { + return Ok(None); + } + + if line == b"\r\n" { + break; + } + + buf.extend_from_slice(line); + } + + let (method, path, headers) = match parse().parse_finished(InputIter::new(buf)) { + Ok(((method, path), headers)) => (method, Path(path), headers), + Err(mon::ParserFinishedError::Err(e) | mon::ParserFinishedError::Unfinished(e)) => { + return Err(Error::Parse(e.position())); + } + }; + + let headers = headers + .into_iter() + .map(|(key, value)| { + Ok::<(String, Vec), str::Utf8Error>(( + str::from_utf8(&key)?.to_lowercase(), + value, + )) + }) + .collect::>, _>>()?; - pub fn new(request: http::Request) -> Self { - Self(request) + Ok(Some(Self { + method, + path, + headers, + })) } } -impl UserData for Request { +fn method<'a>() -> impl Parser<&'a [u8], Output = Method> { + tag(b"GET".as_slice()) + .map(|_| Method::Get) + .or(tag(b"HEAD".as_slice()).map(|_| Method::Head)) +} + +fn path<'a>() -> impl Parser<&'a [u8], Output = Vec> { + any().and_not(whitespace()).repeated().at_least(1) +} + +fn header<'a>() -> impl Parser<&'a [u8], Output = (Vec, Vec)> { + let key = ascii_alphanumeric() + .followed_by( + ascii_alphanumeric() + .or(one_of(b"-".iter().copied())) + .repeated() + .many(), + ) + .recognize() + .map(|output: &[u8]| output.to_vec()); + + let value = any() + .and_not(tag(b"\r\n".as_slice())) + .repeated() + .at_least(1); + + key.and(value.preceded_by(tag(b": ".as_slice()))) +} + +impl UserData for Request { fn add_fields>(fields: &mut F) { - fields.add_field_method_get("method", |_, this| { - Ok(this.inner().method().as_str().to_string()) - }); + fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); - fields.add_field_method_get("path", |_, this| Ok(this.inner().uri().path().to_string())); + fields.add_field_method_get("path", |_, this| Ok(this.path().0.clone())); fields.add_field_method_get("headers", |lua, this| { let table = lua.create_table()?; - for (key, value) in this.inner().headers() { - table.set(key.as_str(), value.as_bytes())?; + for (key, value) in this.headers() { + table.set(key.clone(), value.clone())?; } Ok(table) }) } } + +#[allow(clippy::type_complexity)] +fn parse<'a>() -> impl Parser<&'a [u8], Output = ((Method, Vec), Vec<(Vec, Vec)>)> { + method() + .followed_by(ascii_whitespace()) + .and(path()) + .followed_by(ascii_whitespace()) + .followed_by(tag(b"HTTP/1.1\r\n".as_slice())) + .and( + header() + .separated_by_with_trailing(tag(b"\r\n".as_slice())) + .many(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_header() { + let input = b"Content-Length: 100"; + + header().parse_finished(InputIter::new(input)).unwrap(); + } + + #[tokio::test] + async fn test_parse_get() { + let mut buf = Vec::new(); + let mut line = Vec::new(); + let input = b"GET /path HTTP/1.1\r\nContent-Length: 100\r\n"; + + match Request::parse(input.as_slice(), &mut buf, &mut line).await { + Ok(_) => (), + Err(Error::Parse(position)) => { + panic!("{}", &str::from_utf8(input).unwrap()[position..]) + } + Err(e) => panic!("{e}"), + } + } +} diff --git a/src/response.rs b/src/response.rs index 801e682..fa1a294 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,48 +1,75 @@ +mod builder; + +use get::Get; + +use std::collections::HashMap; + use tokio::{ fs::File, io::{self, AsyncWrite, AsyncWriteExt}, }; +use strum::Display; + +use builder::Builder; + +#[allow(unreachable_patterns)] +#[derive(Debug, Clone, Copy, Display)] +pub enum Status { + #[strum(to_string = "200 OK")] + Ok, + + #[strum(to_string = "404 Not Found")] + NotFound, + + #[strum(to_string = "500 Internal Server Error")] + InternalServerError, +} + #[derive(Debug)] pub enum Body { File(File), - Bytes(Vec), + Buffer(Vec), Empty, } -#[derive(Debug)] -pub struct Response(http::Response); +#[derive(Debug, Get)] +pub struct Response { + status: Status, + headers: HashMap>, + body: Body, +} impl Response { - pub fn new(inner: http::Response) -> Self { - Self(inner) - } - - pub fn inner(&self) -> &http::Response { - &self.0 + pub fn builder() -> Builder { + Builder::new() } } impl Response { - pub async fn to_wire(self, writer: &mut W) -> io::Result<()> { - writer - .write_all(format!("HTTP/1.1 {}\r\n", self.0.status()).as_bytes()) - .await?; - - for (key, value) in self.0.headers() { - writer.write_all(format!("{key}: ").as_bytes()).await?; - writer.write_all(value.as_bytes()).await?; + pub async fn serialize( + mut self, + mut writer: W, + ) -> Result<(), io::Error> { + writer.write_all(b"HTTP/1.1 ").await?; + writer.write_all(self.status.to_string().as_bytes()).await?; + writer.write_all(b"\r\n").await?; + + for (key, value) in &self.headers { + writer.write_all(key.as_bytes()).await?; + writer.write_all(b": ").await?; + writer.write_all(value).await?; writer.write_all(b"\r\n").await?; } writer.write_all(b"\r\n").await?; - match self.0.into_body() { - Body::File(mut file) => { - io::copy(&mut file, writer).await?; + match &mut self.body { + Body::File(file) => { + io::copy(file, &mut writer).await?; } - Body::Bytes(buf) => { - writer.write_all(&buf).await?; + Body::Buffer(buf) => { + writer.write_all(buf).await?; } Body::Empty => (), } diff --git a/src/response/builder.rs b/src/response/builder.rs new file mode 100644 index 0000000..b99b824 --- /dev/null +++ b/src/response/builder.rs @@ -0,0 +1,75 @@ +use std::collections::HashMap; + +use crate::response::{self, Response}; + +#[derive(Debug, Clone, Copy)] +pub struct Builder; + +#[derive(Debug, Clone, Copy)] +pub struct Status { + status: response::Status, +} + +#[derive(Debug, Clone)] +pub struct Header { + status: response::Status, + headers: HashMap>, +} + +impl Builder { + pub fn new() -> Self { + Self + } + + pub fn status(self, status: response::Status) -> Status { + Status { status } + } +} + +impl Status { + pub fn headers, V: Into>, T: IntoIterator>( + self, + headers: T, + ) -> Header { + Header { + status: self.status, + headers: HashMap::from_iter( + headers + .into_iter() + .map(|(key, value)| (key.into(), value.into())), + ), + } + } + + pub fn body(self, body: response::Body) -> Response { + Response { + status: self.status, + headers: HashMap::new(), + body, + } + } +} + +impl Header { + pub fn headers, V: Into>, T: IntoIterator>( + self, + headers: T, + ) -> Header { + Header { + status: self.status, + headers: HashMap::from_iter( + headers + .into_iter() + .map(|(key, value)| (key.into(), value.into())), + ), + } + } + + pub fn body(self, body: response::Body) -> Response { + Response { + status: self.status, + headers: self.headers, + body, + } + } +} -- cgit v1.2.3