diff --git a/avulto.pyi b/avulto.pyi index 657bdce..e91c86e 100644 --- a/avulto.pyi +++ b/avulto.pyi @@ -7,6 +7,8 @@ class Path: """The final part of the path.""" stem: str + """The parent path.""" + parent: Path def __init__(self, value): """Returns a new path.""" @@ -122,8 +124,10 @@ class DME: This is slower than the default but provides more reflection information. """ - def paths_prefixed(self, prefix: Path | str) -> list[str]: - """Returns a list of paths with the given `prefix`.""" + def typesof(self, prefix: Path | str) -> list[str]: + """Returns a list of type paths with the given `prefix`.""" + def subtypesof(self, prefix: Path | str) -> list[str]: + """Returns a list of type paths with the given `prefix`, excluding `prefix` itself.""" def typedecl(self, path: Path | str) -> TypeDecl: """Return the type declaration of the given `path`.""" @@ -201,4 +205,4 @@ class DMI: Iterates over all icon states. """ def data_rgba8(self, rect:Rect) -> bytes: - """Return the byte data of the spritesheet in 8-bit RGBA.""" \ No newline at end of file + """Return the byte data of the spritesheet in 8-bit RGBA.""" diff --git a/src/dme.rs b/src/dme.rs index ca26347..657f46d 100644 --- a/src/dme.rs +++ b/src/dme.rs @@ -1,6 +1,5 @@ extern crate dreammaker; -use itertools::Itertools; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, @@ -24,6 +23,19 @@ pub struct Dme { filepath: Py, } +impl Dme { + fn collect_child_paths(&self, needle: &Path, strict: bool, out: &mut Vec) { + for ty in self.objtree.iter_types() { + if needle.internal_parent_of_string(&ty.path, strict) { + out.push(Path(ty.path.clone())); + } + } + + out.sort(); + out.dedup(); + } +} + #[pymethods] impl Dme { #[staticmethod] @@ -93,28 +105,36 @@ impl Dme { } } - fn paths_prefixed(&self, prefix: &Bound, py: Python<'_>) -> PyResult { + fn typesof(&self, prefix: &Bound, py: Python<'_>) -> PyResult { let mut out: Vec = Vec::new(); - if let Ok(path) = prefix.extract::() { - for ty in self.objtree.iter_types() { - if ty.path.starts_with(&path.0) { - out.push(Path(ty.path.clone())); - } - } + let prefix_path = if let Ok(path) = prefix.extract::() { + path } else if let Ok(pystr) = prefix.downcast::() { - for ty in self.objtree.iter_types() { - if ty.path.starts_with(&pystr.to_string()) { - out.push(Path(ty.path.clone())); - } - } - } + Path(pystr.to_string()) + } else { + return Err(PyRuntimeError::new_err(format!("invalid path {}", prefix))); + }; + self.collect_child_paths(&prefix_path, false, &mut out); - let mut x = out.into_iter().unique().collect::>(); - x.sort(); - Ok(PyList::new_bound(py, x.into_iter().map(|m| m.into_py(py))).to_object(py)) + Ok(PyList::new_bound(py, out.into_iter().map(|m| m.into_py(py))).to_object(py)) } + fn subtypesof(&self, prefix: &Bound, py: Python<'_>) -> PyResult { + let mut out: Vec = Vec::new(); + + let prefix_path = if let Ok(path) = prefix.extract::() { + path + } else if let Ok(pystr) = prefix.downcast::() { + Path(pystr.to_string()) + } else { + return Err(PyRuntimeError::new_err(format!("invalid path {}", prefix))); + }; + self.collect_child_paths(&prefix_path, true, &mut out); + + Ok(PyList::new_bound(py, out.into_iter().map(|m| m.into_py(py))).to_object(py)) + } + fn walk_proc( &self, path: &Bound, @@ -139,7 +159,7 @@ impl Dme { "cannot coerce proc name to string".to_string(), )); }; - + if let Some(ty) = objtree.find(&objpath) { if let Some(p) = ty.get_proc(&procname) { if let Some(ref code) = p.get().code { diff --git a/src/dme/walker.rs b/src/dme/walker.rs index e44b00f..0e1b056 100644 --- a/src/dme/walker.rs +++ b/src/dme/walker.rs @@ -59,9 +59,6 @@ impl Dme { dreammaker::ast::Term::Float(_) => { visit_name = "visit_Constant"; } - dreammaker::ast::Term::Ident(_) => { - visit_name = "visit_Constant"; - } dreammaker::ast::Term::String(_) => { visit_name = "visit_Constant"; } @@ -180,6 +177,9 @@ impl Dme { if let Some(in_list_expr) = &l.in_list { self.walk_expr(in_list_expr, walker, py)?; } + for stmt in l.block.iter() { + self.walk_stmt(&stmt.elem, walker, py)?; + } } } dreammaker::ast::Statement::ForRange(f) => { @@ -246,27 +246,27 @@ impl Dme { cases, default, } => { - if walker.hasattr("walk_Switch").unwrap() { - walker.call_method1("walk_Switch", (from_statement_to_node(stmt, py)?,))?; + if walker.hasattr("visit_Switch").unwrap() { + walker.call_method1("visit_Switch", (from_statement_to_node(stmt, py)?,))?; } else { self.walk_expr(input, walker, py)?; for (case_types, block) in cases.iter() { - if walker.hasattr("walk_Expr").unwrap() { + if walker.hasattr("visit_Expr").unwrap() { for case_elem in &case_types.elem { match case_elem { dreammaker::ast::Case::Exact(e) => { walker.call_method1( - "walk_Expr", + "visit_Expr", (from_expression_to_node(e, py)?,), )?; } dreammaker::ast::Case::Range(s, e) => { walker.call_method1( - "walk_Expr", + "visit_Expr", (from_expression_to_node(s, py)?,), )?; walker.call_method1( - "walk_Expr", + "visit_Expr", (from_expression_to_node(e, py)?,), )?; } diff --git a/src/path.rs b/src/path.rs index 6dfb626..46d89ca 100644 --- a/src/path.rs +++ b/src/path.rs @@ -5,7 +5,7 @@ use std::{ }; use pyo3::{ - exceptions::{PyRuntimeError, PyValueError}, pyclass::CompareOp, pymethods, types::{PyAnyMethods, PyString}, Bound, PyAny, PyErr, PyResult + exceptions::{PyRuntimeError, PyValueError}, pyclass::CompareOp, pymethods, types::{PyAnyMethods, PyString, PyStringMethods}, Bound, PyAny, PyErr, PyResult }; use pyo3::pyclass; @@ -25,6 +25,50 @@ impl From for String { } } +impl Path { + pub fn internal_child_of_string(&self, rhs: &String, strict: bool) -> bool { + if self.0.eq(rhs) { + return !strict; + } + if rhs == "/" { + return true; + } + let parts: Vec<&str> = self.0.split('/').collect(); + let oparts: Vec<&str> = rhs.split('/').collect(); + if parts.len() < oparts.len() { + return false; + } + for (a, b) in parts.iter().zip(oparts) { + if !(*a).eq(b) { + return false; + } + } + + return true; + } + + pub fn internal_parent_of_string(&self, rhs: &String, strict: bool) -> bool { + if self.0.eq(rhs){ + return !strict; + } + if self.0 == "/" { + return true; + } + let parts: Vec<&str> = self.0.split('/').collect(); + let oparts: Vec<&str> = rhs.split('/').collect(); + if parts.len() > oparts.len() { + return false; + } + for (a, b) in parts.iter().zip(oparts) { + if !(*a).eq(b) { + return false; + } + } + + return true; + } +} + #[pymethods] impl Path { #[new] @@ -38,44 +82,9 @@ impl Path { #[pyo3(signature = (other, strict=false))] fn child_of(&self, other: &Bound, strict: bool) -> PyResult { if let Ok(rhs) = other.extract::() { - if self.0 == rhs.0 { - return Ok(!strict); - } - if rhs.0 == "/" { - return Ok(true); - } - let parts: Vec<&str> = self.0.split('/').collect(); - let oparts: Vec<&str> = rhs.0.split('/').collect(); - if parts.len() < oparts.len() { - return Ok(false); - } - for (a, b) in parts.iter().zip(oparts) { - if !(*a).eq(b) { - return Ok(false); - } - } - - return Ok(true); + return Ok(self.internal_child_of_string(&rhs.0, strict)); } else if let Ok(rhs) = other.downcast::() { - let rs = rhs.to_string(); - if self.0 == rs { - return Ok(!strict); - } - if rs == "/" { - return Ok(true); - } - let sparts: Vec<&str> = self.0.split('/').collect(); - let soparts: Vec<&str> = rs.split('/').collect(); - if sparts.len() < soparts.len() { - return Ok(false); - } - for (a, b) in sparts.iter().zip(soparts) { - if *a != b { - return Ok(false); - } - } - - return Ok(true); + return Ok(self.internal_child_of_string(&rhs.to_cow().unwrap().to_string(), strict)); } Err(PyErr::new::("not a valid path")) @@ -84,46 +93,25 @@ impl Path { #[pyo3(signature = (other, strict=false))] fn parent_of(&self, other: &Bound, strict: bool) -> PyResult { if let Ok(rhs) = other.extract::() { - if self.0 == rhs.0 { - return Ok(!strict); - } - if self.0 == "/" { - return Ok(true); - } - let parts: Vec<&str> = self.0.split('/').collect(); - let oparts: Vec<&str> = rhs.0.split('/').collect(); - if parts.len() > oparts.len() { - return Ok(false); - } - for (a, b) in parts.iter().zip(oparts) { - if !(*a).eq(b) { - return Ok(false); - } - } - - return Ok(true); + return Ok(self.internal_parent_of_string(&rhs.0, strict)); } else if let Ok(rhs) = other.downcast::() { - let rs = rhs.to_string(); - if self.0 == rs { - return Ok(!strict); - } - let sparts: Vec<&str> = self.0.split('/').collect(); - let soparts: Vec<&str> = rs.split('/').collect(); - if sparts.len() > soparts.len() { - return Ok(false); - } - for (a, b) in sparts.iter().zip(soparts) { - if *a != b { - return Ok(false); - } - } - - return Ok(true); + return Ok(self.internal_parent_of_string(&rhs.to_cow().unwrap().to_string(), strict)); } Err(PyErr::new::("not a valid path")) } + #[getter] + fn get_parent(&self) -> PyResult { + if self.0 == "/" { + return Ok(self.clone()); + } + let mut parts: Vec<&str> = self.0.split('/').collect(); + let _ = parts.split_off(parts.len() - 1); + let parent = parts.join("/"); + Ok(Path(parent)) + } + #[getter] fn get_stem(&self) -> PyResult { let parts: Vec<&str> = self.0.split('/').collect(); diff --git a/tests/test_dme.py b/tests/test_dme.py index 93dc6e8..7bbad3b 100644 --- a/tests/test_dme.py +++ b/tests/test_dme.py @@ -14,13 +14,18 @@ def dme() -> DME: return DME.from_file(get_fixture_path("testenv.dme")) -def test_dme_paths_prefixed(dme: DME): - assert dme.paths_prefixed("/obj/foo") == [ +def test_dme_typesof(dme: DME): + assert dme.typesof("/obj/foo") == [ "/obj/foo", "/obj/foo/bar", "/obj/foo/baz", ] + assert dme.subtypesof("/obj/foo") == [ + "/obj/foo/bar", + "/obj/foo/baz", + ] + def test_missing_type(dme: DME): with pytest.raises(RuntimeError) as ex: