diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/client.rs | 39 | ||||
| -rw-r--r-- | src/handler.rs | 48 | ||||
| -rw-r--r-- | src/handler/lua.rs | 75 | ||||
| -rw-r--r-- | src/handler/staticfile.rs | 75 | ||||
| -rw-r--r-- | src/main.rs | 216 | ||||
| -rw-r--r-- | src/request.rs | 183 | ||||
| -rw-r--r-- | src/response.rs | 82 | ||||
| -rw-r--r-- | src/response/builder.rs | 75 |
8 files changed, 793 insertions, 0 deletions
diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..f5ed2f9 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,39 @@ +use tokio::io::{self, AsyncBufRead, AsyncWrite, AsyncWriteExt}; + +use crate::{ + request::{self, Request}, + response::Response, +}; + +#[derive(Debug)] +pub struct Client<R, W> { + reader: R, + writer: W, + buf: Vec<u8>, + line: Vec<u8>, +} + +impl<R, W> Client<R, W> { + pub fn new(reader: R, writer: W) -> Self { + Self { + reader, + writer, + buf: Vec::new(), + line: Vec::new(), + } + } +} + +impl<R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin> Client<R, W> { + pub async fn read_request(&mut self) -> Result<Option<Request>, request::Error> { + Request::parse(&mut self.reader, &mut self.buf, &mut self.line).await + } + + pub async fn send_response(&mut self, response: Response) -> Result<(), io::Error> { + response.serialize(&mut self.writer).await?; + + self.writer.flush().await?; + + Ok(()) + } +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..eb4d64a --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,48 @@ +mod lua; +mod staticfile; + +use mlua::{FromLua, Value}; + +use crate::{ + handler::{lua::LuaResponse, staticfile::StaticFile}, + request::Request, + response::Response, +}; + +pub trait Handle { + fn handle(self, request: Request) -> impl Future<Output = Result<Response, Error>>; +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("unsupported method")] + Unsupported, + + #[error("static file handler error: {0}")] + StaticFile(#[from] staticfile::Error), +} + +#[derive(Debug)] +pub enum Handler { + StaticFile(StaticFile), + Lua(LuaResponse), +} + +impl FromLua for Handler { + fn from_lua(value: Value, lua: &mlua::Lua) -> mlua::Result<Self> { + match value { + Value::Table(table) => match table.get::<String>("handler")?.as_str() { + "staticfile" => Ok(Self::StaticFile(StaticFile::from_lua( + Value::Table(table.clone()), + lua, + )?)), + "lua" => Ok(Self::Lua(LuaResponse::from_lua( + Value::Table(table.clone()), + lua, + )?)), + _ => Err(mlua::Error::runtime("unknown handler")), + }, + _ => Err(mlua::Error::runtime("expected table")), + } + } +} diff --git a/src/handler/lua.rs b/src/handler/lua.rs new file mode 100644 index 0000000..377111f --- /dev/null +++ b/src/handler/lua.rs @@ -0,0 +1,75 @@ +use std::{collections::HashMap, str::FromStr}; + +use mlua::{FromLua, Lua, Value}; + +use crate::{ + handler::{self, Handle}, + request::Request, + response::{Body, Response, Status}, +}; + +#[derive(Debug, Clone)] +pub struct LuaResponse { + content: Option<Vec<u8>>, + status: Status, + headers: HashMap<String, Vec<u8>>, +} + +impl Handle for LuaResponse { + async fn handle(self, _: Request) -> Result<Response, handler::Error> { + Ok(Response::builder() + .status(self.status) + .headers(self.headers) + .body(match self.content { + Some(content) => Body::Buffer(content), + None => Body::Empty, + })) + } +} + +impl FromLua for LuaResponse { + fn from_lua(value: Value, _: &Lua) -> mlua::Result<Self> { + match value { + Value::Table(table) => { + let content = match table.get("content")? { + Value::String(string) => Some(string.as_bytes().to_vec()), + _ => None, + }; + + let status = match table.get("status")? { + Value::String(string) => Status::from_str(&string.to_str()?).map_err(|e| { + mlua::Error::RuntimeError(format!( + "failed to parse status from string: {e}" + )) + })?, + Value::Integer(i) => Status::from_repr(i as usize).ok_or_else(|| { + mlua::Error::runtime(format!("failed to parse status from integer: {i}")) + })?, + _ => return Err(mlua::Error::runtime("invalid type when reading status")), + }; + + let headers: HashMap<String, Vec<u8>> = match table.get::<Value>("headers")? { + Value::Table(table) => { + let mut hashmap = HashMap::<String, Vec<u8>>::new(); + + for result in table.pairs() { + let (key, value): (String, mlua::BString) = result?; + + hashmap.insert(key, value.to_vec()); + } + + hashmap + } + _ => HashMap::new(), + }; + + Ok(Self { + content, + status, + headers, + }) + } + _ => Err(mlua::Error::runtime("expected table")), + } + } +} diff --git a/src/handler/staticfile.rs b/src/handler/staticfile.rs new file mode 100644 index 0000000..f208528 --- /dev/null +++ b/src/handler/staticfile.rs @@ -0,0 +1,75 @@ +use crate::{ + handler::{self, Handle}, + request::{Method, Request}, + response::{Body, Response, Status}, +}; + +use std::{ffi::OsString, os::unix::ffi::OsStringExt, path::PathBuf}; + +use mlua::{FromLua, Value}; + +use tokio::{ + fs::File, + io::{self}, +}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("io error: {0}")] + Io(#[from] io::Error), +} + +#[derive(Debug, Clone)] +pub struct StaticFile { + path: PathBuf, + mime: String, +} + +impl Handle for StaticFile { + async fn handle(self, request: Request) -> Result<Response, handler::Error> { + match request.method() { + Method::Get | Method::Head => match File::open(&self.path).await { + Ok(file) => { + let metadata = file.metadata().await.map_err(Error::Io)?; + + if metadata.is_file() { + Ok(Response::builder() + .status(Status::Ok) + .headers([ + ("content-length", format!("{}", metadata.len())), + ("content-type", self.mime), + ]) + .body(match request.method() { + Method::Get => Body::File(file), + Method::Head => Body::Empty, + })) + } else { + Ok(Response::builder() + .status(Status::NotFound) + .headers([("content-length", "0")]) + .body(Body::Empty)) + } + } + Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Response::builder() + .status(Status::NotFound) + .headers([("content-length", "0")]) + .body(Body::Empty)), + Err(e) => Err(Error::Io(e))?, + }, + } + } +} + +impl FromLua for StaticFile { + fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result<Self> { + match value { + Value::Table(table) => Ok(Self { + path: PathBuf::from(OsString::from_vec( + table.get::<mlua::String>("path")?.as_bytes().to_vec(), + )), + mime: table.get::<String>("mime")?, + }), + _ => Err(mlua::Error::runtime("expected table")), + } + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..b422f47 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,216 @@ +#![allow(dead_code)] + +use std::{ + ffi::OsString, + os::unix::ffi::OsStringExt, + path::{Component, Path, PathBuf}, + process::ExitCode, +}; + +use mlua::{Function, Lua, Table}; + +use tokio::{ + io::{self, BufReader, BufWriter}, + net::{TcpListener, TcpStream}, +}; + +use crate::{ + client::Client, + handler::{Handle, Handler}, + request::Request, + response::{Response, Status}, +}; + +mod client; +mod handler; +mod request; +mod response; + +macro_rules! exit { + ($fmt:literal, $($s:expr),*) => { + { + eprintln!($fmt, $($s),*); + + return ::std::process::ExitCode::FAILURE + } + } +} + +#[derive(Debug, thiserror::Error)] +enum HandleError { + #[error("error reading handler function: {0}")] + Function(mlua::Error), + + #[error("error calling handler: {0}")] + InvokeHandler(mlua::Error), + + #[error("handler error: {0}")] + Handler(#[from] handler::Error), +} + +#[derive(Debug, thiserror::Error)] +enum ResponseError { + #[error("error reading request: {0}")] + Request(request::Error), + + #[error("error sending response: {0}")] + Response(io::Error), +} + +#[derive(Debug, thiserror::Error)] +enum InitLuaError { + #[error("failed to create variable: {0}")] + CreateTable(mlua::Error), + + #[error("failed to create helper function: {0}")] + Function(mlua::Error), + + #[error("failed to set variable {1}: {0}")] + SetVar(mlua::Error, String), + + #[error("failed to load variable {1}: {0}")] + LoadVar(mlua::Error, String), + + #[error("failed to load config: {0}")] + LoadConfig(std::io::Error), + + #[error("failed to eval config: {0}")] + EvalConfig(mlua::Error), +} + +async fn handle(handlers: Table, request: Request) -> Result<Response, HandleError> { + let method = request.method().to_string(); + + let function = handlers + .get::<Function>(method.as_str()) + .map_err(HandleError::Function)?; + + let handler = function + .call::<Handler>(request.clone()) + .map_err(HandleError::InvokeHandler)?; + + match handler { + Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?), + Handler::Lua(lua_response) => Ok(lua_response.handle(request).await?), + } +} + +async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseError> { + let mut client = { + let (r, w) = stream.into_split(); + + Client::new(BufReader::new(r), BufWriter::new(w)) + }; + + while let Some(request) = client + .read_request() + .await + .map_err(ResponseError::Request)? + { + let response = match handle(handlers.clone(), request).await { + Ok(response) => response, + Err(e) => { + eprintln!("failed to handle request: {e:?}"); + + Response::builder() + .status(Status::InternalServerError) + .headers([("content-length", "0")]) + .body(response::Body::Empty) + } + }; + + client + .send_response(response) + .await + .map_err(ResponseError::Response)?; + } + + Ok(()) +} + +fn join_path(root: &Path, rest: &Path) -> Result<PathBuf, std::io::Error> { + root.components() + .chain( + rest.components() + .filter(|p| !matches!(p, Component::RootDir)), + ) + .filter(|p| !matches!(p, Component::CurDir | Component::ParentDir)) + .collect::<PathBuf>() + .canonicalize() +} + +fn init_lua(lua: Lua) -> Result<(), InitLuaError> { + let http = lua.create_table().map_err(InitLuaError::CreateTable)?; + + lua.globals() + .set("http", http.clone()) + .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?; + + http.set( + "join_paths", + lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| { + let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec())); + let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec())); + + join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}"))) + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?; + + let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?); + + chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?; + + Ok(()) +} + +#[allow(unexpected_cfgs)] +#[tokio::main(flavor = "local")] +async fn main() -> ExitCode { + let lua = Lua::new(); + + if let Err(e) = init_lua(lua.clone()) { + exit!("failed to init lua: {}", e); + } + + let http = match lua.globals().get::<Table>("http") { + Ok(http) => http, + Err(e) => exit!("failed to load table 'http': {}", e), + }; + + let bind = match http.get::<String>("bind") { + Ok(bind) => bind, + Err(e) => exit!("failed to load string 'http.bind': {}", e), + }; + + let handlers = match http.get::<Table>("handlers") { + Ok(handlers) => handlers, + Err(e) => exit!("failed to load 'http.handlers': {}", e), + }; + + let listener = match TcpListener::bind(&bind).await { + Ok(listener) => listener, + Err(e) => exit!("failed to bind to {}: {}", bind, e), + }; + + loop { + let (stream, addr) = match listener.accept().await { + Ok((stream, addr)) => (stream, addr), + Err(e) => { + eprintln!("failed to accept connection: {e}"); + continue; + } + }; + + eprintln!("accepted connection from {addr}"); + + let future = response(handlers.clone(), stream); + + tokio::task::spawn_local(async { + if let Err(e) = future.await { + eprintln!("response failure: {e:?}"); + } + }); + } +} diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 0000000..99b7482 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,183 @@ +use mlua::UserData; +use mon::{ + Parser, ParserIter, any, ascii_alphanumeric, ascii_whitespace, input::InputIter, one_of, tag, + whitespace, +}; + +use strum::Display; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +use std::{collections::HashMap, io, num::ParseIntError, str}; + +use get::Get; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("io error: {0}")] + Io(#[from] io::Error), + + #[error("parser error")] + Parse(usize), + + #[error("parse int error: {0}")] + ParseInt(ParseIntError), + + #[error("unicode error: {0}")] + Unicode(#[from] str::Utf8Error), +} + +#[derive(Debug, Clone, Copy, Display)] +pub enum Method { + #[strum(to_string = "GET")] + Get, + + #[strum(to_string = "HEAD")] + Head, +} + +#[derive(Debug, Clone, Get)] +pub struct Path(#[get(method = "inner")] Vec<u8>); + +#[derive(Debug, Clone, Get)] +pub struct Request { + method: Method, + path: Path, + headers: HashMap<String, Vec<u8>>, +} + +impl Request { + pub async fn parse<T: AsyncBufRead + Unpin>( + mut reader: T, + buf: &mut Vec<u8>, + line: &mut Vec<u8>, + ) -> Result<Option<Self>, Error> { + buf.clear(); + + loop { + line.clear(); + + if reader.read_until(b'\n', line).await? == 0 { + return Ok(None); + } + + if line == b"\r\n" { + break; + } + + buf.extend_from_slice(line); + } + + let (method, path, headers) = match parse().parse_finished(InputIter::new(buf)) { + Ok(((method, path), headers)) => (method, Path(path), headers), + Err(mon::ParserFinishedError::Err(e) | mon::ParserFinishedError::Unfinished(e)) => { + return Err(Error::Parse(e.position())); + } + }; + + let headers = headers + .into_iter() + .map(|(key, value)| { + Ok::<(String, Vec<u8>), str::Utf8Error>(( + str::from_utf8(&key)?.to_lowercase(), + value, + )) + }) + .collect::<Result<HashMap<String, Vec<u8>>, _>>()?; + + Ok(Some(Self { + method, + path, + headers, + })) + } +} + +impl UserData for Request { + fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) { + fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); + + fields.add_field_method_get("path", |_, this| { + Ok(mlua::BString::new(this.path().inner().clone())) + }); + + fields.add_field_method_get("headers", |lua, this| { + let table = lua.create_table()?; + + for (key, value) in this.headers() { + table.set(key.clone(), value.clone())?; + } + + Ok(table) + }) + } +} + +fn method<'a>() -> impl Parser<&'a [u8], Output = Method> { + tag(b"GET".as_slice()) + .map(|_| Method::Get) + .or(tag(b"HEAD".as_slice()).map(|_| Method::Head)) +} + +fn path<'a>() -> impl Parser<&'a [u8], Output = Vec<u8>> { + any().and_not(whitespace()).repeated().at_least(1) +} + +fn header<'a>() -> impl Parser<&'a [u8], Output = (Vec<u8>, Vec<u8>)> { + let key = ascii_alphanumeric() + .followed_by( + ascii_alphanumeric() + .or(one_of(b"-".iter().copied())) + .repeated() + .many(), + ) + .recognize() + .map(|output: &[u8]| output.to_vec()); + + let value = any() + .and_not(tag(b"\r\n".as_slice())) + .repeated() + .at_least(1); + + key.and(value.preceded_by(tag(b": ".as_slice()))) +} + +#[allow(clippy::type_complexity)] +fn parse<'a>() -> impl Parser<&'a [u8], Output = ((Method, Vec<u8>), Vec<(Vec<u8>, Vec<u8>)>)> { + method() + .followed_by(ascii_whitespace()) + .and(path()) + .followed_by(ascii_whitespace()) + .followed_by(tag(b"HTTP/1.1\r\n".as_slice())) + .and( + header() + .separated_by_with_trailing(tag(b"\r\n".as_slice())) + .many(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_header() { + let input = b"Content-Length: 100"; + + header().parse_finished(InputIter::new(input)).unwrap(); + } + + #[tokio::test] + async fn test_parse_get() { + let mut buf = Vec::new(); + let mut line = Vec::new(); + let input = b"GET /path HTTP/1.1\r\nContent-Length: 100\r\n"; + + match Request::parse(input.as_slice(), &mut buf, &mut line).await { + Ok(_) => (), + Err(Error::Parse(position)) => { + panic!("{}", &str::from_utf8(input).unwrap()[position..]) + } + Err(e) => panic!("{e}"), + } + } +} diff --git a/src/response.rs b/src/response.rs new file mode 100644 index 0000000..68189fb --- /dev/null +++ b/src/response.rs @@ -0,0 +1,82 @@ +mod builder; + +use get::Get; + +use std::collections::HashMap; + +use tokio::{ + fs::File, + io::{self, AsyncWrite, AsyncWriteExt}, +}; + +use strum::{Display, EnumString, FromRepr}; + +use builder::Builder; + +#[allow(unreachable_patterns)] +#[derive(Debug, Clone, Copy, Display, EnumString, FromRepr)] +pub enum Status { + #[strum(to_string = "200 OK", serialize = "OK")] + Ok = 200, + + #[strum(to_string = "404 Not Found", serialize = "Not Found")] + NotFound = 404, + + #[strum( + to_string = "500 Internal Server Error", + serialize = "Internal ServerError" + )] + InternalServerError = 500, +} + +#[derive(Debug)] +pub enum Body { + File(File), + Buffer(Vec<u8>), + Empty, +} + +#[derive(Debug, Get)] +pub struct Response { + status: Status, + headers: HashMap<String, Vec<u8>>, + body: Body, +} + +impl Response { + pub fn builder() -> Builder { + Builder::new() + } +} + +impl Response { + pub async fn serialize<W: AsyncWrite + Unpin>( + mut self, + mut writer: W, + ) -> Result<(), io::Error> { + writer.write_all(b"HTTP/1.1 ").await?; + writer.write_all(self.status.to_string().as_bytes()).await?; + writer.write_all(b"\r\n").await?; + + for (key, value) in &self.headers { + writer.write_all(key.as_bytes()).await?; + writer.write_all(b": ").await?; + writer.write_all(value).await?; + writer.write_all(b"\r\n").await?; + } + + writer.write_all(b"\r\n").await?; + + match &mut self.body { + Body::File(file) => { + io::copy(file, &mut writer).await?; + } + Body::Buffer(buf) => { + writer.write_all(buf).await?; + } + Body::Empty => (), + } + + Ok(()) + } +} diff --git a/src/response/builder.rs b/src/response/builder.rs new file mode 100644 index 0000000..b99b824 --- /dev/null +++ b/src/response/builder.rs @@ -0,0 +1,75 @@ +use std::collections::HashMap; + +use crate::response::{self, Response}; + +#[derive(Debug, Clone, Copy)] +pub struct Builder; + +#[derive(Debug, Clone, Copy)] +pub struct Status { + status: response::Status, +} + +#[derive(Debug, Clone)] +pub struct Header { + status: response::Status, + headers: HashMap<String, Vec<u8>>, +} + +impl Builder { + pub fn new() -> Self { + Self + } + + pub fn status(self, status: response::Status) -> Status { + Status { status } + } +} + +impl Status { + pub fn headers<K: Into<String>, V: Into<Vec<u8>>, T: IntoIterator<Item = (K, V)>>( + self, + headers: T, + ) -> Header { + Header { + status: self.status, + headers: HashMap::from_iter( + headers + .into_iter() + .map(|(key, value)| (key.into(), value.into())), + ), + } + } + + pub fn body(self, body: response::Body) -> Response { + Response { + status: self.status, + headers: HashMap::new(), + body, + } + } +} + +impl Header { + pub fn headers<K: Into<String>, V: Into<Vec<u8>>, T: IntoIterator<Item = (K, V)>>( + self, + headers: T, + ) -> Header { + Header { + status: self.status, + headers: HashMap::from_iter( + headers + .into_iter() + .map(|(key, value)| (key.into(), value.into())), + ), + } + } + + pub fn body(self, body: response::Body) -> Response { + Response { + status: self.status, + headers: self.headers, + body, + } + } +} |
