summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/handler.rs48
-rw-r--r--src/handler/lua.rs75
-rw-r--r--src/handler/staticfile.rs75
-rw-r--r--src/lua.rs176
-rw-r--r--src/main.rs190
-rw-r--r--src/response.rs51
6 files changed, 271 insertions, 344 deletions
diff --git a/src/handler.rs b/src/handler.rs
deleted file mode 100644
index eb4d64a..0000000
--- a/src/handler.rs
+++ /dev/null
@@ -1,48 +0,0 @@
-mod lua;
-mod staticfile;
-
-use mlua::{FromLua, Value};
-
-use crate::{
- handler::{lua::LuaResponse, staticfile::StaticFile},
- request::Request,
- response::Response,
-};
-
-pub trait Handle {
- fn handle(self, request: Request) -> impl Future<Output = Result<Response, Error>>;
-}
-
-#[derive(Debug, thiserror::Error)]
-pub enum Error {
- #[error("unsupported method")]
- Unsupported,
-
- #[error("static file handler error: {0}")]
- StaticFile(#[from] staticfile::Error),
-}
-
-#[derive(Debug)]
-pub enum Handler {
- StaticFile(StaticFile),
- Lua(LuaResponse),
-}
-
-impl FromLua for Handler {
- fn from_lua(value: Value, lua: &mlua::Lua) -> mlua::Result<Self> {
- match value {
- Value::Table(table) => match table.get::<String>("handler")?.as_str() {
- "staticfile" => Ok(Self::StaticFile(StaticFile::from_lua(
- Value::Table(table.clone()),
- lua,
- )?)),
- "lua" => Ok(Self::Lua(LuaResponse::from_lua(
- Value::Table(table.clone()),
- lua,
- )?)),
- _ => Err(mlua::Error::runtime("unknown handler")),
- },
- _ => Err(mlua::Error::runtime("expected table")),
- }
- }
-}
diff --git a/src/handler/lua.rs b/src/handler/lua.rs
deleted file mode 100644
index 377111f..0000000
--- a/src/handler/lua.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-use std::{collections::HashMap, str::FromStr};
-
-use mlua::{FromLua, Lua, Value};
-
-use crate::{
- handler::{self, Handle},
- request::Request,
- response::{Body, Response, Status},
-};
-
-#[derive(Debug, Clone)]
-pub struct LuaResponse {
- content: Option<Vec<u8>>,
- status: Status,
- headers: HashMap<String, Vec<u8>>,
-}
-
-impl Handle for LuaResponse {
- async fn handle(self, _: Request) -> Result<Response, handler::Error> {
- Ok(Response::builder()
- .status(self.status)
- .headers(self.headers)
- .body(match self.content {
- Some(content) => Body::Buffer(content),
- None => Body::Empty,
- }))
- }
-}
-
-impl FromLua for LuaResponse {
- fn from_lua(value: Value, _: &Lua) -> mlua::Result<Self> {
- match value {
- Value::Table(table) => {
- let content = match table.get("content")? {
- Value::String(string) => Some(string.as_bytes().to_vec()),
- _ => None,
- };
-
- let status = match table.get("status")? {
- Value::String(string) => Status::from_str(&string.to_str()?).map_err(|e| {
- mlua::Error::RuntimeError(format!(
- "failed to parse status from string: {e}"
- ))
- })?,
- Value::Integer(i) => Status::from_repr(i as usize).ok_or_else(|| {
- mlua::Error::runtime(format!("failed to parse status from integer: {i}"))
- })?,
- _ => return Err(mlua::Error::runtime("invalid type when reading status")),
- };
-
- let headers: HashMap<String, Vec<u8>> = match table.get::<Value>("headers")? {
- Value::Table(table) => {
- let mut hashmap = HashMap::<String, Vec<u8>>::new();
-
- for result in table.pairs() {
- let (key, value): (String, mlua::BString) = result?;
-
- hashmap.insert(key, value.to_vec());
- }
-
- hashmap
- }
- _ => HashMap::new(),
- };
-
- Ok(Self {
- content,
- status,
- headers,
- })
- }
- _ => Err(mlua::Error::runtime("expected table")),
- }
- }
-}
diff --git a/src/handler/staticfile.rs b/src/handler/staticfile.rs
deleted file mode 100644
index f208528..0000000
--- a/src/handler/staticfile.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-use crate::{
- handler::{self, Handle},
- request::{Method, Request},
- response::{Body, Response, Status},
-};
-
-use std::{ffi::OsString, os::unix::ffi::OsStringExt, path::PathBuf};
-
-use mlua::{FromLua, Value};
-
-use tokio::{
- fs::File,
- io::{self},
-};
-
-#[derive(Debug, thiserror::Error)]
-pub enum Error {
- #[error("io error: {0}")]
- Io(#[from] io::Error),
-}
-
-#[derive(Debug, Clone)]
-pub struct StaticFile {
- path: PathBuf,
- mime: String,
-}
-
-impl Handle for StaticFile {
- async fn handle(self, request: Request) -> Result<Response, handler::Error> {
- match request.method() {
- Method::Get | Method::Head => match File::open(&self.path).await {
- Ok(file) => {
- let metadata = file.metadata().await.map_err(Error::Io)?;
-
- if metadata.is_file() {
- Ok(Response::builder()
- .status(Status::Ok)
- .headers([
- ("content-length", format!("{}", metadata.len())),
- ("content-type", self.mime),
- ])
- .body(match request.method() {
- Method::Get => Body::File(file),
- Method::Head => Body::Empty,
- }))
- } else {
- Ok(Response::builder()
- .status(Status::NotFound)
- .headers([("content-length", "0")])
- .body(Body::Empty))
- }
- }
- Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Response::builder()
- .status(Status::NotFound)
- .headers([("content-length", "0")])
- .body(Body::Empty)),
- Err(e) => Err(Error::Io(e))?,
- },
- }
- }
-}
-
-impl FromLua for StaticFile {
- fn from_lua(value: mlua::Value, _: &mlua::Lua) -> mlua::Result<Self> {
- match value {
- Value::Table(table) => Ok(Self {
- path: PathBuf::from(OsString::from_vec(
- table.get::<mlua::String>("path")?.as_bytes().to_vec(),
- )),
- mime: table.get::<String>("mime")?,
- }),
- _ => Err(mlua::Error::runtime("expected table")),
- }
- }
-}
diff --git a/src/lua.rs b/src/lua.rs
new file mode 100644
index 0000000..d116433
--- /dev/null
+++ b/src/lua.rs
@@ -0,0 +1,176 @@
+use std::{
+ ffi::OsString,
+ os::unix::ffi::OsStringExt,
+ path::{Component, Path, PathBuf},
+};
+
+use mlua::{AnyUserData, Lua, UserData, Value};
+
+use tokio::{fs, io};
+
+#[derive(Debug, thiserror::Error)]
+pub enum InitLuaError {
+ #[error("failed to create variable: {0}")]
+ CreateTable(mlua::Error),
+
+ #[error("failed to create helper function: {0}")]
+ Function(mlua::Error),
+
+ #[error("failed to set variable {1}: {0}")]
+ SetVar(mlua::Error, String),
+
+ #[error("failed to load variable {1}: {0}")]
+ LoadVar(mlua::Error, String),
+
+ #[error("failed to load config: {0}")]
+ LoadConfig(std::io::Error),
+
+ #[error("failed to eval config: {0}")]
+ EvalConfig(mlua::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+enum JoinPathError {
+ #[error("unsafe path")]
+ Unsafe,
+
+ #[error("io error: {0}")]
+ Io(#[from] std::io::Error),
+}
+
+#[derive(Debug)]
+pub struct File(Option<fs::File>);
+
+#[derive(Debug)]
+struct Metadata(std::fs::Metadata);
+
+impl File {
+ pub fn take(&mut self) -> Option<fs::File> {
+ self.0.take()
+ }
+}
+
+impl UserData for File {}
+
+impl UserData for Metadata {
+ fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
+ fields.add_field_method_get("len", |_, this| Ok(this.0.len()));
+ }
+}
+
+fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> {
+ let mut pathbuf = root.as_ref().to_path_buf();
+
+ for component in rest.as_ref().components() {
+ match component {
+ Component::Normal(path) => {
+ pathbuf.push(path);
+ }
+ Component::ParentDir => {
+ pathbuf.pop();
+
+ if !pathbuf.starts_with(root.as_ref()) {
+ return Err(JoinPathError::Unsafe);
+ }
+ }
+ _ => continue,
+ }
+ }
+
+ let canon = match pathbuf.canonicalize() {
+ Ok(canon) => canon,
+ Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf,
+ Err(e) => return Err(JoinPathError::Io(e)),
+ };
+
+ if !canon.starts_with(root.as_ref()) {
+ return Err(JoinPathError::Unsafe);
+ }
+
+ Ok(canon)
+}
+
+pub(super) fn init(lua: Lua) -> Result<(), InitLuaError> {
+ let http = lua.create_table().map_err(InitLuaError::CreateTable)?;
+
+ lua.globals()
+ .set("http", http.clone())
+ .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?;
+
+ http.set(
+ "join_paths",
+ lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| {
+ let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec()));
+ let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec()));
+
+ join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}")))
+ })
+ .map_err(InitLuaError::Function)?,
+ )
+ .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?;
+
+ let io = lua.create_table().map_err(InitLuaError::CreateTable)?;
+
+ http.set("io", io.clone())
+ .map_err(|e| InitLuaError::SetVar(e, "http.io".to_string()))?;
+
+ io.set(
+ "fopen",
+ lua.create_async_function(|lua, path: mlua::String| async move {
+ let pathbuf = PathBuf::from(OsString::from_vec(path.as_bytes().to_vec()));
+
+ match fs::File::open(&pathbuf).await {
+ Ok(f) => {
+ let data = lua.create_any_userdata(File(Some(f)))?;
+
+ Ok(Value::UserData(data))
+ }
+ Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Value::Nil),
+ Err(e) => Err(mlua::Error::runtime(format!("failed to open file: {e}"))),
+ }
+ })
+ .map_err(InitLuaError::Function)?,
+ )
+ .map_err(|e| InitLuaError::SetVar(e, "http.io.fopen".to_string()))?;
+
+ io.set(
+ "metadata",
+ lua.create_async_function(|_, file: AnyUserData| async move {
+ let metadata = match file.borrow::<File>() {
+ Ok(f) => f.0.as_ref().unwrap().metadata().await.map_err(|e| {
+ mlua::Error::runtime(format!("failed to read file metadata: {e}"))
+ })?,
+ Err(_) => {
+ return Err(mlua::Error::runtime(
+ "failed to read file metadata, expected file type",
+ ));
+ }
+ };
+
+ Ok(Metadata(metadata))
+ })
+ .map_err(InitLuaError::Function)?,
+ )
+ .map_err(|e| InitLuaError::SetVar(e, "http.io.stat".to_string()))?;
+
+ let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?);
+
+ chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?;
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod test {
+
+ use super::*;
+
+ #[test]
+ fn test_path_traversal() {
+ let root = "/var/www";
+ let get = "/../../bin/sh";
+ let joined = join_path(root, get);
+
+ assert!(matches!(joined, Err(JoinPathError::Unsafe)));
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 401af02..cc19487 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,8 @@
#![allow(dead_code)]
-use std::{
- ffi::OsString,
- os::unix::ffi::OsStringExt,
- path::{Component, Path, PathBuf},
- process::ExitCode,
-};
+use std::process::ExitCode;
-use mlua::{Function, Lua, Table};
+use mlua::{FromLua, Lua, Table, Value};
use tokio::{
io::{self, BufReader, BufWriter},
@@ -16,13 +11,12 @@ use tokio::{
use crate::{
client::Client,
- handler::{Handle, Handler},
request::Request,
response::{Response, Status},
};
mod client;
-mod handler;
+mod lua;
mod request;
mod response;
@@ -36,75 +30,44 @@ macro_rules! exit {
}
}
-#[derive(Debug, thiserror::Error)]
-enum HandleError {
- #[error("error reading handler function: {0}")]
- Function(mlua::Error),
-
- #[error("error calling handler: {0}")]
- InvokeHandler(mlua::Error),
-
- #[error("handler error: {0}")]
- Handler(#[from] handler::Error),
-}
+static INTERNAL_SERVER_ERROR_TEXT: &str = "internal server error";
#[derive(Debug, thiserror::Error)]
enum ResponseError {
- #[error("error reading request: {0}")]
- Request(request::Error),
-
- #[error("error sending response: {0}")]
- Response(io::Error),
-}
-
-#[derive(Debug, thiserror::Error)]
-enum InitLuaError {
- #[error("failed to create variable: {0}")]
- CreateTable(mlua::Error),
-
- #[error("failed to create helper function: {0}")]
- Function(mlua::Error),
+ #[error("failed to interpet lua handler return value as response: {0}")]
+ ResponseFromLua(mlua::Error),
- #[error("failed to set variable {1}: {0}")]
- SetVar(mlua::Error, String),
-
- #[error("failed to load variable {1}: {0}")]
- LoadVar(mlua::Error, String),
-
- #[error("failed to load config: {0}")]
- LoadConfig(std::io::Error),
-
- #[error("failed to eval config: {0}")]
- EvalConfig(mlua::Error),
+ #[error("failed to invoke handler: {0}")]
+ Handler(mlua::Error),
}
#[derive(Debug, thiserror::Error)]
-enum JoinPathError {
- #[error("unsafe path")]
- Unsafe,
+enum HandleError {
+ #[error("error reading request: {0}")]
+ ReadRequest(request::Error),
- #[error("io error: {0}")]
- Io(#[from] std::io::Error),
+ #[error("error sending response: {0}")]
+ SendResponse(io::Error),
}
-async fn handle(handlers: Table, request: Request) -> Result<Response, HandleError> {
- let method = request.method().to_string();
-
- let function = handlers
- .get::<Function>(method.as_str())
- .map_err(HandleError::Function)?;
-
- let handler = function
- .call::<Handler>(request.clone())
- .map_err(HandleError::InvokeHandler)?;
+async fn response(lua: Lua, handlers: Table, request: Request) -> Result<Response, ResponseError> {
+ match handlers
+ .get::<Value>(request.method().to_string().as_str())
+ .map_err(ResponseError::Handler)?
+ {
+ Value::Function(function) => {
+ let ret = function
+ .call_async(request)
+ .await
+ .map_err(ResponseError::Handler)?;
- match handler {
- Handler::StaticFile(staticfile) => Ok(staticfile.handle(request).await?),
- Handler::Lua(lua_response) => Ok(lua_response.handle(request).await?),
+ Response::from_lua(ret, &lua).map_err(ResponseError::ResponseFromLua)
+ }
+ _ => todo!(),
}
}
-async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseError> {
+async fn handle(lua: Lua, handlers: Table, stream: TcpStream) -> Result<(), HandleError> {
let mut client = {
let (r, w) = stream.into_split();
@@ -114,93 +77,43 @@ async fn response(handlers: Table, stream: TcpStream) -> Result<(), ResponseErro
while let Some(request) = client
.read_request()
.await
- .map_err(ResponseError::Request)?
+ .map_err(HandleError::ReadRequest)?
{
- let response = match handle(handlers.clone(), request).await {
- Ok(response) => response,
+ let response = match response(lua.clone(), handlers.clone(), request).await {
+ Ok(r) => r,
Err(e) => {
- eprintln!("failed to handle request: {e:?}");
+ eprintln!("{e}");
Response::builder()
.status(Status::InternalServerError)
- .headers([("content-length", "0")])
- .body(response::Body::Empty)
+ .headers([
+ (
+ "content-length",
+ INTERNAL_SERVER_ERROR_TEXT.len().to_string(),
+ ),
+ ("content-type", "text/plain".to_string()),
+ ])
+ .body(response::Body::Buffer(Vec::from(
+ INTERNAL_SERVER_ERROR_TEXT.to_string(),
+ )))
}
};
client
.send_response(response)
.await
- .map_err(ResponseError::Response)?;
+ .map_err(HandleError::SendResponse)?;
}
Ok(())
}
-fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> {
- let mut pathbuf = root.as_ref().to_path_buf();
-
- for component in rest.as_ref().components() {
- match component {
- Component::Normal(path) => {
- pathbuf.push(path);
- }
- Component::ParentDir => {
- pathbuf.pop();
-
- if !pathbuf.starts_with(root.as_ref()) {
- return Err(JoinPathError::Unsafe);
- }
- }
- _ => continue,
- }
- }
-
- let canon = match pathbuf.canonicalize() {
- Ok(canon) => canon,
- Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf,
- Err(e) => return Err(JoinPathError::Io(e)),
- };
-
- if !canon.starts_with(root.as_ref()) {
- return Err(JoinPathError::Unsafe);
- }
-
- Ok(canon)
-}
-
-fn init_lua(lua: Lua) -> Result<(), InitLuaError> {
- let http = lua.create_table().map_err(InitLuaError::CreateTable)?;
-
- lua.globals()
- .set("http", http.clone())
- .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?;
-
- http.set(
- "join_paths",
- lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| {
- let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec()));
- let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec()));
-
- join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}")))
- })
- .map_err(InitLuaError::Function)?,
- )
- .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?;
-
- let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?);
-
- chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?;
-
- Ok(())
-}
-
#[allow(unexpected_cfgs)]
#[tokio::main(flavor = "local")]
async fn main() -> ExitCode {
let lua = Lua::new();
- if let Err(e) = init_lua(lua.clone()) {
+ if let Err(e) = lua::init(lua.clone()) {
exit!("failed to init lua: {}", e);
}
@@ -235,7 +148,7 @@ async fn main() -> ExitCode {
eprintln!("accepted connection from {addr}");
- let future = response(handlers.clone(), stream);
+ let future = handle(lua.clone(), handlers.clone(), stream);
tokio::task::spawn_local(async {
if let Err(e) = future.await {
@@ -244,18 +157,3 @@ async fn main() -> ExitCode {
});
}
}
-
-#[cfg(test)]
-mod test {
-
- use super::*;
-
- #[test]
- fn test_path_traversal() {
- let root = "/var/www";
- let get = "/../../bin/sh";
- let joined = join_path(root, get);
-
- assert!(matches!(joined, Err(JoinPathError::Unsafe)));
- }
-}
diff --git a/src/response.rs b/src/response.rs
index 68189fb..f6f0fa0 100644
--- a/src/response.rs
+++ b/src/response.rs
@@ -2,6 +2,8 @@ mod builder;
use get::Get;
+use mlua::{FromLua, Lua, Value};
+
use std::collections::HashMap;
use tokio::{
@@ -13,6 +15,8 @@ use strum::{Display, EnumString, FromRepr};
use builder::Builder;
+use crate::lua;
+
#[allow(unreachable_patterns)]
#[derive(Debug, Clone, Copy, Display, EnumString, FromRepr)]
pub enum Status {
@@ -80,3 +84,50 @@ impl Response {
Ok(())
}
}
+
+impl FromLua for Response {
+ fn from_lua(value: Value, _: &Lua) -> mlua::Result<Self> {
+ match value {
+ Value::Table(table) => {
+ let status = match table.get("status")? {
+ Value::Integer(i) => match Status::from_repr(i as usize) {
+ Some(status) => status,
+ None => return Err(mlua::Error::runtime(format!("invalid status: {i}"))),
+ },
+ _ => return Err(mlua::Error::runtime("invalid status")),
+ };
+
+ let headers = match table.get("headers")? {
+ Value::Table(table) => {
+ let mut headers = HashMap::new();
+
+ for result in table.pairs() {
+ let (key, value): (String, mlua::String) = result?;
+
+ headers.insert(key, value.as_bytes().to_vec());
+ }
+
+ headers
+ }
+ Value::Nil => HashMap::new(),
+ _ => return Err(mlua::Error::runtime("invalid headers")),
+ };
+
+ let body = match table.get("body")? {
+ Value::String(string) => Body::Buffer(string.as_bytes().to_vec()),
+ Value::UserData(userdata) if userdata.is::<lua::File>() => {
+ Body::File(userdata.borrow_mut::<lua::File>().unwrap().take().unwrap())
+ }
+ _ => return Err(mlua::Error::runtime("invalid body")),
+ };
+
+ Ok(Self {
+ status,
+ headers,
+ body,
+ })
+ }
+ _ => Err(mlua::Error::runtime("invalid response, expected a table")),
+ }
+ }
+}