summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs190
1 files changed, 44 insertions, 146 deletions
diff --git a/src/main.rs b/src/main.rs
index 401af02..cc19487 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,8 @@
#![allow(dead_code)]
-use std::{
- ffi::OsString,
- os::unix::ffi::OsStringExt,
- path::{Component, Path, PathBuf},
- process::ExitCode,
-};
+use std::process::ExitCode;
-use mlua::{Function, Lua, Table};
+use mlua::{FromLua, Lua, Table, Value};
use tokio::{
io::{self, BufReader, BufWriter},
@@ -16,13 +11,12 @@ use tokio::{
use crate::{
client::Client,
- handler::{Handle, Handler},
request::Request,
response::{Response, Status},
};
mod client;
-mod handler;
+mod lua;
mod request;
mod response;
@@ -36,75 +30,44 @@ macro_rules! exit {
}
}
-#[derive(Debug, thiserror::Error)]
-enum HandleError {
- #[error("error reading handler function: {0}")]
- Function(mlua::Error),
-
- #[error("error calling handler: {0}")]
- InvokeHandler(mlua::Error),
-
- #[error("handler error: {0}")]
- Handler(#[from] handler::Error),
-}
+static INTERNAL_SERVER_ERROR_TEXT: &str = "internal server error";
#[derive(Debug, thiserror::Error)]
enum ResponseError {
- #[error("error reading request: {0}")]
- Request(request::Error),
-
- #[error("error sending response: {0}")]
- Response(io::Error),
-}
-
-#[derive(Debug, thiserror::Error)]
-enum InitLuaError {
- #[error("failed to create variable: {0}")]
- CreateTable(mlua::Error),
-
- #[error("failed to create helper function: {0}")]
- Function(mlua::Error),
+ #[error("failed to interpet lua handler return value as response: {0}")]
+ ResponseFromLua(mlua::Error),
- #[error("failed to set variable {1}: {0}")]
- SetVar(mlua::Error, String),
-
- #[error("failed to load variable {1}: {0}")]
- LoadVar(mlua::Error, String),
-
- #[error("failed to load config: {0}")]
- LoadConfig(std::io::Error),
-
- #[error("failed to eval config: {0}")]
- EvalConfig(mlua::Error),
+ #[error("failed to invoke handler: {0}")]
+ Handler(mlua::Error),
}
#[derive(Debug, thiserror::Error)]
-enum JoinPathError {
- #[error("unsafe path")]
- Unsafe,
+enum HandleError {
+ #[error("error reading request: {0}")]
+ ReadRequest(request::Error),
- #[error("io error: {0}")]
- Io(#[from] std::io::Error),
+ #[error("error sending response: {0}")]
+ SendResponse(io::Error),
}
-async fn handle(handlers: Table, request: Request) -> Result<Response, HandleError> {
- let method = request.method().to_string();
-
- let function = handlers
- .get::<Function>(method.as_str())
- .map_err(HandleError::Function)?;
-
- let handler = function
- .call::<Handler>(request.clone())
- .map_err(HandleError::InvokeHandler)?;
+async fn response(lua: Lua, handlers: Table, request: Request) -> Result<Response, ResponseError> {
+ match handlers
+ .get::<Value>(request.method().to_string().as_str())
+ .map_err(ResponseError::Handler)?
+ {
+ Value::Function(function) => {
+ let ret = function
+ .call_async(request)
+ .await
+ .map_err(ResponseError::Handler)?;
- match handler {
- Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?),
- Handler::Lua(lua_response) => Ok(lua_response.handle(request).await?),
+ Response::from_lua(ret, &lua).map_err(ResponseError::ResponseFromLua)
+ }
+ _ => todo!(),
}
}
-async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseError> {
+async fn handle(lua: Lua, handlers: Table, stream: TcpStream) -> Result<(), HandleError> {
let mut client = {
let (r, w) = stream.into_split();
@@ -114,93 +77,43 @@ async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseErro
while let Some(request) = client
.read_request()
.await
- .map_err(ResponseError::Request)?
+ .map_err(HandleError::ReadRequest)?
{
- let response = match handle(handlers.clone(), request).await {
- Ok(response) => response,
+ let response = match response(lua.clone(), handlers.clone(), request).await {
+ Ok(r) => r,
Err(e) => {
- eprintln!("failed to handle request: {e:?}");
+ eprintln!("{e}");
Response::builder()
.status(Status::InternalServerError)
- .headers([("content-length", "0")])
- .body(response::Body::Empty)
+ .headers([
+ (
+ "content-length",
+ INTERNAL_SERVER_ERROR_TEXT.len().to_string(),
+ ),
+ ("content-type", "text/plain".to_string()),
+ ])
+ .body(response::Body::Buffer(Vec::from(
+ INTERNAL_SERVER_ERROR_TEXT.to_string(),
+ )))
}
};
client
.send_response(response)
.await
- .map_err(ResponseError::Response)?;
+ .map_err(HandleError::SendResponse)?;
}
Ok(())
}
-fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> {
- let mut pathbuf = root.as_ref().to_path_buf();
-
- for component in rest.as_ref().components() {
- match component {
- Component::Normal(path) => {
- pathbuf.push(path);
- }
- Component::ParentDir => {
- pathbuf.pop();
-
- if !pathbuf.starts_with(root.as_ref()) {
- return Err(JoinPathError::Unsafe);
- }
- }
- _ => continue,
- }
- }
-
- let canon = match pathbuf.canonicalize() {
- Ok(canon) => canon,
- Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf,
- Err(e) => return Err(JoinPathError::Io(e)),
- };
-
- if !canon.starts_with(root.as_ref()) {
- return Err(JoinPathError::Unsafe);
- }
-
- Ok(canon)
-}
-
-fn init_lua(lua: Lua) -> Result<(), InitLuaError> {
- let http = lua.create_table().map_err(InitLuaError::CreateTable)?;
-
- lua.globals()
- .set("http", http.clone())
- .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?;
-
- http.set(
- "join_paths",
- lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| {
- let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec()));
- let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec()));
-
- join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}")))
- })
- .map_err(InitLuaError::Function)?,
- )
- .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?;
-
- let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?);
-
- chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?;
-
- Ok(())
-}
-
#[allow(unexpected_cfgs)]
#[tokio::main(flavor = "local")]
async fn main() -> ExitCode {
let lua = Lua::new();
- if let Err(e) = init_lua(lua.clone()) {
+ if let Err(e) = lua::init(lua.clone()) {
exit!("failed to init lua: {}", e);
}
@@ -235,7 +148,7 @@ async fn main() -> ExitCode {
eprintln!("accepted connection from {addr}");
- let future = response(handlers.clone(), stream);
+ let future = handle(lua.clone(), handlers.clone(), stream);
tokio::task::spawn_local(async {
if let Err(e) = future.await {
@@ -244,18 +157,3 @@ async fn main() -> ExitCode {
});
}
}
-
-#[cfg(test)]
-mod test {
-
- use super::*;
-
- #[test]
- fn test_path_traversal() {
- let root = "/var/www";
- let get = "/../../bin/sh";
- let joined = join_path(root, get);
-
- assert!(matches!(joined, Err(JoinPathError::Unsafe)));
- }
-}