diff options
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 190 |
1 files changed, 44 insertions, 146 deletions
diff --git a/src/main.rs b/src/main.rs index 401af02..cc19487 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,8 @@ #![allow(dead_code)] -use std::{ - ffi::OsString, - os::unix::ffi::OsStringExt, - path::{Component, Path, PathBuf}, - process::ExitCode, -}; +use std::process::ExitCode; -use mlua::{Function, Lua, Table}; +use mlua::{FromLua, Lua, Table, Value}; use tokio::{ io::{self, BufReader, BufWriter}, @@ -16,13 +11,12 @@ use tokio::{ use crate::{ client::Client, - handler::{Handle, Handler}, request::Request, response::{Response, Status}, }; mod client; -mod handler; +mod lua; mod request; mod response; @@ -36,75 +30,44 @@ macro_rules! exit { } } -#[derive(Debug, thiserror::Error)] -enum HandleError { - #[error("error reading handler function: {0}")] - Function(mlua::Error), - - #[error("error calling handler: {0}")] - InvokeHandler(mlua::Error), - - #[error("handler error: {0}")] - Handler(#[from] handler::Error), -} +static INTERNAL_SERVER_ERROR_TEXT: &str = "internal server error"; #[derive(Debug, thiserror::Error)] enum ResponseError { - #[error("error reading request: {0}")] - Request(request::Error), - - #[error("error sending response: {0}")] - Response(io::Error), -} - -#[derive(Debug, thiserror::Error)] -enum InitLuaError { - #[error("failed to create variable: {0}")] - CreateTable(mlua::Error), - - #[error("failed to create helper function: {0}")] - Function(mlua::Error), + #[error("failed to interpet lua handler return value as response: {0}")] + ResponseFromLua(mlua::Error), - #[error("failed to set variable {1}: {0}")] - SetVar(mlua::Error, String), - - #[error("failed to load variable {1}: {0}")] - LoadVar(mlua::Error, String), - - #[error("failed to load config: {0}")] - LoadConfig(std::io::Error), - - #[error("failed to eval config: {0}")] - EvalConfig(mlua::Error), + #[error("failed to invoke handler: {0}")] + Handler(mlua::Error), } #[derive(Debug, thiserror::Error)] -enum JoinPathError { - #[error("unsafe path")] - Unsafe, +enum HandleError { + #[error("error reading request: {0}")] + ReadRequest(request::Error), - #[error("io error: {0}")] - Io(#[from] std::io::Error), + #[error("error sending response: {0}")] + SendResponse(io::Error), } -async fn handle(handlers: Table, request: Request) -> Result<Response, HandleError> { - let method = request.method().to_string(); - - let function = handlers - .get::<Function>(method.as_str()) - .map_err(HandleError::Function)?; - - let handler = function - .call::<Handler>(request.clone()) - .map_err(HandleError::InvokeHandler)?; +async fn response(lua: Lua, handlers: Table, request: Request) -> Result<Response, ResponseError> { + match handlers + .get::<Value>(request.method().to_string().as_str()) + .map_err(ResponseError::Handler)? + { + Value::Function(function) => { + let ret = function + .call_async(request) + .await + .map_err(ResponseError::Handler)?; - match handler { - Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?), - Handler::Lua(lua_response) => Ok(lua_response.handle(request).await?), + Response::from_lua(ret, &lua).map_err(ResponseError::ResponseFromLua) + } + _ => todo!(), } } -async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseError> { +async fn handle(lua: Lua, handlers: Table, stream: TcpStream) -> Result<(), HandleError> { let mut client = { let (r, w) = stream.into_split(); @@ -114,93 +77,43 @@ async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseErro while let Some(request) = client .read_request() .await - .map_err(ResponseError::Request)? + .map_err(HandleError::ReadRequest)? { - let response = match handle(handlers.clone(), request).await { - Ok(response) => response, + let response = match response(lua.clone(), handlers.clone(), request).await { + Ok(r) => r, Err(e) => { - eprintln!("failed to handle request: {e:?}"); + eprintln!("{e}"); Response::builder() .status(Status::InternalServerError) - .headers([("content-length", "0")]) - .body(response::Body::Empty) + .headers([ + ( + "content-length", + INTERNAL_SERVER_ERROR_TEXT.len().to_string(), + ), + ("content-type", "text/plain".to_string()), + ]) + .body(response::Body::Buffer(Vec::from( + INTERNAL_SERVER_ERROR_TEXT.to_string(), + ))) } }; client .send_response(response) .await - .map_err(ResponseError::Response)?; + .map_err(HandleError::SendResponse)?; } Ok(()) } -fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> { - let mut pathbuf = root.as_ref().to_path_buf(); - - for component in rest.as_ref().components() { - match component { - Component::Normal(path) => { - pathbuf.push(path); - } - Component::ParentDir => { - pathbuf.pop(); - - if !pathbuf.starts_with(root.as_ref()) { - return Err(JoinPathError::Unsafe); - } - } - _ => continue, - } - } - - let canon = match pathbuf.canonicalize() { - Ok(canon) => canon, - Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf, - Err(e) => return Err(JoinPathError::Io(e)), - }; - - if !canon.starts_with(root.as_ref()) { - return Err(JoinPathError::Unsafe); - } - - Ok(canon) -} - -fn init_lua(lua: Lua) -> Result<(), InitLuaError> { - let http = lua.create_table().map_err(InitLuaError::CreateTable)?; - - lua.globals() - .set("http", http.clone()) - .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?; - - http.set( - "join_paths", - lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| { - let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec())); - let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec())); - - join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}"))) - }) - .map_err(InitLuaError::Function)?, - ) - .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?; - - let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?); - - chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?; - - Ok(()) -} - #[allow(unexpected_cfgs)] #[tokio::main(flavor = "local")] async fn main() -> ExitCode { let lua = Lua::new(); - if let Err(e) = init_lua(lua.clone()) { + if let Err(e) = lua::init(lua.clone()) { exit!("failed to init lua: {}", e); } @@ -235,7 +148,7 @@ async fn main() -> ExitCode { eprintln!("accepted connection from {addr}"); - let future = response(handlers.clone(), stream); + let future = handle(lua.clone(), handlers.clone(), stream); tokio::task::spawn_local(async { if let Err(e) = future.await { @@ -244,18 +157,3 @@ async fn main() -> ExitCode { }); } } - -#[cfg(test)] -mod test { - - use super::*; - - #[test] - fn test_path_traversal() { - let root = "/var/www"; - let get = "/../../bin/sh"; - let joined = join_path(root, get); - - assert!(matches!(joined, Err(JoinPathError::Unsafe))); - } -} |
