diff options
Diffstat (limited to 'src/lua.rs')
| -rw-r--r-- | src/lua.rs | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/src/lua.rs b/src/lua.rs new file mode 100644 index 0000000..d116433 --- /dev/null +++ b/src/lua.rs @@ -0,0 +1,176 @@ +use std::{ + ffi::OsString, + os::unix::ffi::OsStringExt, + path::{Component, Path, PathBuf}, +}; + +use mlua::{AnyUserData, Lua, UserData, Value}; + +use tokio::{fs, io}; + +#[derive(Debug, thiserror::Error)] +pub enum InitLuaError { + #[error("failed to create variable: {0}")] + CreateTable(mlua::Error), + + #[error("failed to create helper function: {0}")] + Function(mlua::Error), + + #[error("failed to set variable {1}: {0}")] + SetVar(mlua::Error, String), + + #[error("failed to load variable {1}: {0}")] + LoadVar(mlua::Error, String), + + #[error("failed to load config: {0}")] + LoadConfig(std::io::Error), + + #[error("failed to eval config: {0}")] + EvalConfig(mlua::Error), +} + +#[derive(Debug, thiserror::Error)] +enum JoinPathError { + #[error("unsafe path")] + Unsafe, + + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} + +#[derive(Debug)] +pub struct File(Option<fs::File>); + +#[derive(Debug)] +struct Metadata(std::fs::Metadata); + +impl File { + pub fn take(&mut self) -> Option<fs::File> { + self.0.take() + } +} + +impl UserData for File {} + +impl UserData for Metadata { + fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) { + fields.add_field_method_get("len", |_, this| Ok(this.0.len())); + } +} + +fn join_path<A: AsRef<Path>, B: AsRef<Path>>(root: A, rest: B) -> Result<PathBuf, JoinPathError> { + let mut pathbuf = root.as_ref().to_path_buf(); + + for component in rest.as_ref().components() { + match component { + Component::Normal(path) => { + pathbuf.push(path); + } + Component::ParentDir => { + pathbuf.pop(); + + if !pathbuf.starts_with(root.as_ref()) { + return Err(JoinPathError::Unsafe); + } + } + _ => continue, + } + } + + let canon = match pathbuf.canonicalize() { + Ok(canon) => canon, + Err(e) if matches!(e.kind(), std::io::ErrorKind::NotFound) => pathbuf, + Err(e) => return Err(JoinPathError::Io(e)), + }; + + if !canon.starts_with(root.as_ref()) { + return Err(JoinPathError::Unsafe); + } + + Ok(canon) +} + +pub(super) fn init(lua: Lua) -> Result<(), InitLuaError> { + let http = lua.create_table().map_err(InitLuaError::CreateTable)?; + + lua.globals() + .set("http", http.clone()) + .map_err(|e| InitLuaError::SetVar(e, "http".to_string()))?; + + http.set( + "join_paths", + lua.create_function(|_, (root, rest): (mlua::String, mlua::String)| { + let a = PathBuf::from(OsString::from_vec(root.as_bytes().to_vec())); + let b = PathBuf::from(OsString::from_vec(rest.as_bytes().to_vec())); + + join_path(&a, &b).map_err(|e| mlua::Error::runtime(format!("failed to join path: {e}"))) + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.join_paths".to_string()))?; + + let io = lua.create_table().map_err(InitLuaError::CreateTable)?; + + http.set("io", io.clone()) + .map_err(|e| InitLuaError::SetVar(e, "http.io".to_string()))?; + + io.set( + "fopen", + lua.create_async_function(|lua, path: mlua::String| async move { + let pathbuf = PathBuf::from(OsString::from_vec(path.as_bytes().to_vec())); + + match fs::File::open(&pathbuf).await { + Ok(f) => { + let data = lua.create_any_userdata(File(Some(f)))?; + + Ok(Value::UserData(data)) + } + Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => Ok(Value::Nil), + Err(e) => Err(mlua::Error::runtime(format!("failed to open file: {e}"))), + } + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.io.fopen".to_string()))?; + + io.set( + "metadata", + lua.create_async_function(|_, file: AnyUserData| async move { + let metadata = match file.borrow::<File>() { + Ok(f) => f.0.as_ref().unwrap().metadata().await.map_err(|e| { + mlua::Error::runtime(format!("failed to read file metadata: {e}")) + })?, + Err(_) => { + return Err(mlua::Error::runtime( + "failed to read file metadata, expected file type", + )); + } + }; + + Ok(Metadata(metadata)) + }) + .map_err(InitLuaError::Function)?, + ) + .map_err(|e| InitLuaError::SetVar(e, "http.io.stat".to_string()))?; + + let chunk = lua.load(std::fs::read_to_string("config.lua").map_err(InitLuaError::LoadConfig)?); + + chunk.eval::<()>().map_err(InitLuaError::EvalConfig)?; + + Ok(()) +} + +#[cfg(test)] +mod test { + + use super::*; + + #[test] + fn test_path_traversal() { + let root = "/var/www"; + let get = "/../../bin/sh"; + let joined = join_path(root, get); + + assert!(matches!(joined, Err(JoinPathError::Unsafe))); + } +} |
