diff options
| -rw-r--r-- | src/handler.rs | 48 | ||||
| -rw-r--r-- | src/handler/lua.rs | 75 | ||||
| -rw-r--r-- | src/handler/staticfile.rs | 75 | ||||
| -rw-r--r-- | src/lua.rs | 176 | ||||
| -rw-r--r-- | src/main.rs | 190 | ||||
| -rw-r--r-- | src/response.rs | 51 |
6 files changed, 271 insertions, 344 deletions
diff --git a/src/handler.rs b/src/handler.rs deleted file mode 100644 index eb4d64a..0000000 --- a/src/handler.rs +++ /dev/null @@ -1,48 +0,0 @@ -mod lua; -mod staticfile; - -use mlua::{FromLua, Value}; - -use crate::{ - handler::{lua::LuaResponse, staticfile::StaticFile}, - request::Request, - response::Response, -}; - -pub trait Handle { - fn handle(self, request: Request) -> impl Future<Output = Result<Response, Error>>; -} - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("unsupported method")] - Unsupported, - - #[error("static file handler error: {0}")] - StaticFile(#[from] staticfile::Error), -} - -#[derive(Debug)] -pub enum Handler { - StaticFile(StaticFile), - Lua(LuaResponse), -} - -impl FromLua for Handler { - fn from_lua(value: Value, lua: &mlua::Lua) -> mlua::Result<Self> { - match value { - Value::Table(table) => match table.get::<String>("handler")?.as_str() { - "staticfile" => Ok(Self::StaticFile(StaticFile::from_lua( - Value::Table(table.clone()), - lua, - )?)), - "lua" => Ok(Self::Lua(LuaResponse::from_lua( - Value::Table(table.clone()), - lua, - )?)), - _ => Err(mlua::Error::runtime("unknown handler")), - }, - _ => Err(mlua::Error::runtime("expected table")), - } - } -} diff --git a/src/handler/lua.rs b/src/handler/lua.rs deleted file mode 100644 index 377111f..0000000 --- a/src/handler/lua.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::{collections::HashMap, str::FromStr}; - -use mlua::{FromLua, Lua, Value}; - -use crate::{ - handler::{self, Handle}, - request::Request, - response::{Body, Response, Status}, -}; - -#[derive(Debug, Clone)] -pub struct LuaResponse { - content: Option<Vec<u8>>, - status: Status, - headers: HashMap<String, Vec<u8>>, -} - -impl Handle for LuaResponse { - async fn handle(self, _: Request) -> Result<Response, handler::Error> { - Ok(Response::builder() - .status(self.status) - .headers(self.headers) - .body(match self.content { - Some(content) => Body::Buffer(content), - None => Body::Empty, - })) - } -} - -impl FromLua for LuaResponse { - fn from_lua(value: Value, _: &Lua) -> mlua::Result<Self> { - match value { - Value::Table(table) => { - let content = match table.get("content")? { - Value::String(string) => Some(string.as_bytes().to_vec()), - _ => None, - }; - - let status = match table.get("status")? { - Value::String(string) => Status::from_str(&string.to_str()?).map_err(|e| { - mlua::Error::RuntimeError(format!( - "failed to parse status from string: {e}" - )) - })?, - Value::Integer(i) => Status::from_repr(i as usize).ok_or_else(|| { - mlua::Error::runtime(format!("failed to parse status from integer: {i}")) - })?, - _ => return Err(mlua::Error::runtime("invalid type when reading status")), - }; - - let headers: HashMap<String, Vec<u8>> = match table.get::<Value>("headers")? { - Value::Table(table) => { - let mut hashmap = HashMap::<String, Vec<u8>>::new(); - - for result in table.pairs() { - let (key, value): (String, mlua::BString) = result?; - - hashmap.insert(key, value.to_vec()); - } - - hashmap - } - _ => HashMap::new(), - }; - - Ok(Self { - content, - status, - headers, - }) - } - _ => Err(mlua::Error::runtime("expected table")), - } - } -} diff --git a/src/handler/staticfile.rs b/src/handler/staticfile.rs deleted file mode 100644 index f208528..0000000 --- a/src/handler/staticfile.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::{ - handler::{self, Handle}, - request::{Method, Request}, - response::{Body, Response, Status}, -}; - -use std::{ffi::OsString, os::unix::ffi::OsStringExt, path::PathBuf}; - -use mlua::{FromLua, Value}; - -use tokio::{ - fs::File, - io::{self}, -}; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("io error: {0}")] - Io(#[from] io::Error), -} - -#[derive(Debug, Clone)] -pub struct StaticFile { - path: PathBuf, - mime: String, -} - -impl Handle for StaticFile { - async fn handle(self, request: Request) -> Result<Response, handler::Error> { - match request.method() { - Method::Get | Method::Head => match File::open(&self.path).await { - Ok(file) => { - let metadata = file.metadata().await.map_err(Error::Io)?; - - if metadata.is_file() { - Ok(Response::builder() - .status(Status::Ok) - .headers([ - ("content-length", format!("{}", metadata.len())), - ("content-type", self.mime), - ]) - .body(match request.method() { - Method::Get => Body::File(file), - Method::Head => Body::Empty, - })) - } else { - Ok(Response::builder() - .status(Status::NotFound) - .headers([("content-length", "0")]) - .body(Body::Empty)) - } - } - Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Response::builder() - .status(Status::NotFound) - .headers([("content-length", "0")]) - .body(Body::Empty)), - Err(e) => Err(Error::Io(e))?, - }, - } - } -} - -impl FromLua for StaticFile { - fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result<Self> { - match value { - Value::Table(table) => Ok(Self { - path: PathBuf::from(OsString::from_vec( - table.get::<mlua::String>("path")?.as_bytes().to_vec(), - )), - mime: table.get::<String>("mime")?, - }), - _ => Err(mlua::Error::runtime("expected table")), - } - } -} diff --git a/src/lua.rs b/src/lua.rs new file mode 100644 index 0000000..d116433 --- /dev/null +++ b/src/lua.rs @@ -0,0 +1,176 @@ +use std::{ + ffi::OsString, + os::unix::ffi::OsStringExt, + path::{Component, Path, PathBuf}, +}; + +use mlua::{AnyUserData, Lua, UserData, Value}; + +use tokio::{fs, io}; + +#[derive(Debug, thiserror::Error)] +pub enum InitLuaError { + #[error("failed to create variable: {0}")] + CreateTable(mlua::Error), + + #[error("failed to create helper function: {0}")] + Function(mlua::Error), + + #[error("failed to set variable {1}: {0}")] + SetVar(mlua::Error, String), + + #[error("failed to load variable {1}: {0}")] + LoadVar(mlua::Error, String), + + #[error("failed to load config: {0}")] + LoadConfig(std::io::Error), + + #[error("failed to eval config: {0}")] + EvalConfig(mlua::Error), +} + +#[derive(Debug, thiserror::Error)] +enum JoinPathError { + #[error("unsafe path")] + Unsafe, + + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} + +#[derive(Debug)] +pub struct File(Option<fs::File>); + +#[derive(Debug)] +struct Metadata(std::fs::Metadata); + +impl File { + pub fn take(&mut self) -> Option<fs::File> { + self.0.take() + } +} + +impl UserData for File {} + +impl UserData for Metadata { + fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) { + fields.add_field_method_get("len", |_, this| Ok(this.0.len())); + } +} + +fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> { + let mut pathbuf = root.as_ref().to_path_buf(); + + for component in rest.as_ref().components() { + match component { + Component::Normal(path) => { + pathbuf.push(path); + } + Component::ParentDir => { + pathbuf.pop(); + + if !pathbuf.starts_with(root.as_ref()) { + return Err(JoinPathError::Unsafe); + } + } + _ => continue, + } + } + + let canon = match pathbuf.canonicalize() { + Ok(canon) => canon, + Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf, + Err(e) => return Err(JoinPathError::Io(e)), + }; + + if !canon.starts_with(root.as_ref()) { + return Err(JoinPathError::Unsafe); + } + + Ok(canon) +} + +pub(super) fn init(lua: Lua) -> Result<(), InitLuaError> { + let http = lua.create_table().map_err(InitLuaError::CreateTable)?; + + lua.globals() + .set("http", http.clone()) + .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?; + + http.set( + "join_paths", + lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| { + let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec())); + let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec())); + + join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}"))) + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?; + + let io = lua.create_table().map_err(InitLuaError::CreateTable)?; + + http.set("io", io.clone()) + .map_err(|e| InitLuaError::SetVar(e, "http.io".to_string()))?; + + io.set( + "fopen", + lua.create_async_function(|lua, path: mlua::String| async move { + let pathbuf = PathBuf::from(OsString::from_vec(path.as_bytes().to_vec())); + + match fs::File::open(&pathbuf).await { + Ok(f) => { + let data = lua.create_any_userdata(File(Some(f)))?; + + Ok(Value::UserData(data)) + } + Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Value::Nil), + Err(e) => Err(mlua::Error::runtime(format!("failed to open file: {e}"))), + } + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.io.fopen".to_string()))?; + + io.set( + "metadata", + lua.create_async_function(|_, file: AnyUserData| async move { + let metadata = match file.borrow::<File>() { + Ok(f) => f.0.as_ref().unwrap().metadata().await.map_err(|e| { + mlua::Error::runtime(format!("failed to read file metadata: {e}")) + })?, + Err(_) => { + return Err(mlua::Error::runtime( + "failed to read file metadata, expected file type", + )); + } + }; + + Ok(Metadata(metadata)) + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.io.stat".to_string()))?; + + let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?); + + chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?; + + Ok(()) +} + +#[cfg(test)] +mod test { + + use super::*; + + #[test] + fn test_path_traversal() { + let root = "/var/www"; + let get = "/../../bin/sh"; + let joined = join_path(root, get); + + assert!(matches!(joined, Err(JoinPathError::Unsafe))); + } +} diff --git a/src/main.rs b/src/main.rs index 401af02..cc19487 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,8 @@ #![allow(dead_code)] -use std::{ - ffi::OsString, - os::unix::ffi::OsStringExt, - path::{Component, Path, PathBuf}, - process::ExitCode, -}; +use std::process::ExitCode; -use mlua::{Function, Lua, Table}; +use mlua::{FromLua, Lua, Table, Value}; use tokio::{ io::{self, BufReader, BufWriter}, @@ -16,13 +11,12 @@ use tokio::{ use crate::{ client::Client, - handler::{Handle, Handler}, request::Request, response::{Response, Status}, }; mod client; -mod handler; +mod lua; mod request; mod response; @@ -36,75 +30,44 @@ macro_rules! exit { } } -#[derive(Debug, thiserror::Error)] -enum HandleError { - #[error("error reading handler function: {0}")] - Function(mlua::Error), - - #[error("error calling handler: {0}")] - InvokeHandler(mlua::Error), - - #[error("handler error: {0}")] - Handler(#[from] handler::Error), -} +static INTERNAL_SERVER_ERROR_TEXT: &str = "internal server error"; #[derive(Debug, thiserror::Error)] enum ResponseError { - #[error("error reading request: {0}")] - Request(request::Error), - - #[error("error sending response: {0}")] - Response(io::Error), -} - -#[derive(Debug, thiserror::Error)] -enum InitLuaError { - #[error("failed to create variable: {0}")] - CreateTable(mlua::Error), - - #[error("failed to create helper function: {0}")] - Function(mlua::Error), + #[error("failed to interpet lua handler return value as response: {0}")] + ResponseFromLua(mlua::Error), - #[error("failed to set variable {1}: {0}")] - SetVar(mlua::Error, String), - - #[error("failed to load variable {1}: {0}")] - LoadVar(mlua::Error, String), - - #[error("failed to load config: {0}")] - LoadConfig(std::io::Error), - - #[error("failed to eval config: {0}")] - EvalConfig(mlua::Error), + #[error("failed to invoke handler: {0}")] + Handler(mlua::Error), } #[derive(Debug, thiserror::Error)] -enum JoinPathError { - #[error("unsafe path")] - Unsafe, +enum HandleError { + #[error("error reading request: {0}")] + ReadRequest(request::Error), - #[error("io error: {0}")] - Io(#[from] std::io::Error), + #[error("error sending response: {0}")] + SendResponse(io::Error), } -async fn handle(handlers: Table, request: Request) -> Result<Response, HandleError> { - let method = request.method().to_string(); - - let function = handlers - .get::<Function>(method.as_str()) - .map_err(HandleError::Function)?; - - let handler = function - .call::<Handler>(request.clone()) - .map_err(HandleError::InvokeHandler)?; +async fn response(lua: Lua, handlers: Table, request: Request) -> Result<Response, ResponseError> { + match handlers + .get::<Value>(request.method().to_string().as_str()) + .map_err(ResponseError::Handler)? + { + Value::Function(function) => { + let ret = function + .call_async(request) + .await + .map_err(ResponseError::Handler)?; - match handler { - Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?), - Handler::Lua(lua_response) => Ok(lua_response.handle(request).await?), + Response::from_lua(ret, &lua).map_err(ResponseError::ResponseFromLua) + } + _ => todo!(), } } -async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseError> { +async fn handle(lua: Lua, handlers: Table, stream: TcpStream) -> Result<(), HandleError> { let mut client = { let (r, w) = stream.into_split(); @@ -114,93 +77,43 @@ async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseErro while let Some(request) = client .read_request() .await - .map_err(ResponseError::Request)? + .map_err(HandleError::ReadRequest)? { - let response = match handle(handlers.clone(), request).await { - Ok(response) => response, + let response = match response(lua.clone(), handlers.clone(), request).await { + Ok(r) => r, Err(e) => { - eprintln!("failed to handle request: {e:?}"); + eprintln!("{e}"); Response::builder() .status(Status::InternalServerError) - .headers([("content-length", "0")]) - .body(response::Body::Empty) + .headers([ + ( + "content-length", + INTERNAL_SERVER_ERROR_TEXT.len().to_string(), + ), + ("content-type", "text/plain".to_string()), + ]) + .body(response::Body::Buffer(Vec::from( + INTERNAL_SERVER_ERROR_TEXT.to_string(), + ))) } }; client .send_response(response) .await - .map_err(ResponseError::Response)?; + .map_err(HandleError::SendResponse)?; } Ok(()) } -fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> { - let mut pathbuf = root.as_ref().to_path_buf(); - - for component in rest.as_ref().components() { - match component { - Component::Normal(path) => { - pathbuf.push(path); - } - Component::ParentDir => { - pathbuf.pop(); - - if !pathbuf.starts_with(root.as_ref()) { - return Err(JoinPathError::Unsafe); - } - } - _ => continue, - } - } - - let canon = match pathbuf.canonicalize() { - Ok(canon) => canon, - Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf, - Err(e) => return Err(JoinPathError::Io(e)), - }; - - if !canon.starts_with(root.as_ref()) { - return Err(JoinPathError::Unsafe); - } - - Ok(canon) -} - -fn init_lua(lua: Lua) -> Result<(), InitLuaError> { - let http = lua.create_table().map_err(InitLuaError::CreateTable)?; - - lua.globals() - .set("http", http.clone()) - .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?; - - http.set( - "join_paths", - lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| { - let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec())); - let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec())); - - join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}"))) - }) - .map_err(InitLuaError::Function)?, - ) - .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?; - - let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?); - - chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?; - - Ok(()) -} - #[allow(unexpected_cfgs)] #[tokio::main(flavor = "local")] async fn main() -> ExitCode { let lua = Lua::new(); - if let Err(e) = init_lua(lua.clone()) { + if let Err(e) = lua::init(lua.clone()) { exit!("failed to init lua: {}", e); } @@ -235,7 +148,7 @@ async fn main() -> ExitCode { eprintln!("accepted connection from {addr}"); - let future = response(handlers.clone(), stream); + let future = handle(lua.clone(), handlers.clone(), stream); tokio::task::spawn_local(async { if let Err(e) = future.await { @@ -244,18 +157,3 @@ async fn main() -> ExitCode { }); } } - -#[cfg(test)] -mod test { - - use super::*; - - #[test] - fn test_path_traversal() { - let root = "/var/www"; - let get = "/../../bin/sh"; - let joined = join_path(root, get); - - assert!(matches!(joined, Err(JoinPathError::Unsafe))); - } -} diff --git a/src/response.rs b/src/response.rs index 68189fb..f6f0fa0 100644 --- a/src/response.rs +++ b/src/response.rs @@ -2,6 +2,8 @@ mod builder; use get::Get; +use mlua::{FromLua, Lua, Value}; + use std::collections::HashMap; use tokio::{ @@ -13,6 +15,8 @@ use strum::{Display, EnumString, FromRepr}; use builder::Builder; +use crate::lua; + #[allow(unreachable_patterns)] #[derive(Debug, Clone, Copy, Display, EnumString, FromRepr)] pub enum Status { @@ -80,3 +84,50 @@ impl Response { Ok(()) } } + +impl FromLua for Response { + fn from_lua(value: Value, _: &Lua) -> mlua::Result<Self> { + match value { + Value::Table(table) => { + let status = match table.get("status")? { + Value::Integer(i) => match Status::from_repr(i as usize) { + Some(status) => status, + None => return Err(mlua::Error::runtime(format!("invalid status: {i}"))), + }, + _ => return Err(mlua::Error::runtime("invalid status")), + }; + + let headers = match table.get("headers")? { + Value::Table(table) => { + let mut headers = HashMap::new(); + + for result in table.pairs() { + let (key, value): (String, mlua::String) = result?; + + headers.insert(key, value.as_bytes().to_vec()); + } + + headers + } + Value::Nil => HashMap::new(), + _ => return Err(mlua::Error::runtime("invalid headers")), + }; + + let body = match table.get("body")? { + Value::String(string) => Body::Buffer(string.as_bytes().to_vec()), + Value::UserData(userdata) if userdata.is::<lua::File>() => { + Body::File(userdata.borrow_mut::<lua::File>().unwrap().take().unwrap()) + } + _ => return Err(mlua::Error::runtime("invalid body")), + }; + + Ok(Self { + status, + headers, + body, + }) + } + _ => Err(mlua::Error::runtime("invalid response, expected a table")), + } + } +} |
