From c6d95a3fd8653a25811c1da70c11e58d761f0ca2 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Wed, 27 Aug 2025 16:13:50 +0530 Subject: [PATCH 1/2] Graph Role core impl --- rust_core/src/dag.rs | 102 +++++++++++++++++++++++++- rust_core/src/graph_role.rs | 139 ++++++++++++++++++++++++++++++++++++ rust_core/src/lib.rs | 1 + rust_core/tests/test_dag.rs | 129 ++++++++++++++++++++++++++++++++- 4 files changed, 368 insertions(+), 3 deletions(-) create mode 100644 rust_core/src/graph_role.rs diff --git a/rust_core/src/dag.rs b/rust_core/src/dag.rs index bf4d31d..05d5b09 100644 --- a/rust_core/src/dag.rs +++ b/rust_core/src/dag.rs @@ -1,6 +1,8 @@ use petgraph::Direction; use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; use std::collections::{HashMap, HashSet, VecDeque}; +use crate::graph_role::GraphRoles; +use std::hash::{Hash, Hasher}; /// Directed Acyclic Graph (DAG) with optional latent variables. /// @@ -24,6 +26,89 @@ pub struct RustDAG { pub node_map: HashMap, pub reverse_node_map: HashMap, pub latents: HashSet, + pub roles: HashMap>, // New: role -> set of nodes +} + +impl PartialEq for RustDAG { + fn eq(&self, other: &Self) -> bool { + // Compare nodes + let self_nodes: HashSet<&String> = self.node_map.keys().collect(); + let other_nodes: HashSet<&String> = other.node_map.keys().collect(); + if self_nodes != other_nodes { + return false; + } + + // Compare edges + let self_edges: HashSet<(String, String)> = self.edges().into_iter().collect(); + let other_edges: HashSet<(String, String)> = other.edges().into_iter().collect(); + if self_edges != other_edges { + return false; + } + + // Compare latents + if self.latents != other.latents { + return false; + } + + // Compare roles + let mut self_roles: Vec<(String, Vec)> = self + .get_roles() + .into_iter() + .map(|role| { + let mut nodes = self.get_role(&role); + nodes.sort(); + (role, nodes) + }) + .collect(); + self_roles.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut other_roles: Vec<(String, Vec)> = other + .get_roles() + .into_iter() + .map(|role| { + let mut nodes = other.get_role(&role); + nodes.sort(); + (role, nodes) + }) + .collect(); + other_roles.sort_by(|a, b| a.0.cmp(&b.0)); + + self_roles == other_roles + } +} + +impl Eq for RustDAG {} + +impl Hash for RustDAG { + fn hash(&self, state: &mut H) { + // Hash nodes + let mut nodes: Vec<&String> = self.node_map.keys().collect(); + nodes.sort(); + nodes.hash(state); + + // Hash edges + let mut edges: Vec<(String, String)> = self.edges(); + edges.sort(); + edges.hash(state); + + // Hash latents + let mut latents: Vec<&String> = self.latents.iter().collect(); + latents.sort(); + latents.hash(state); + + // Hash roles + let mut roles: Vec<(String, Vec)> = self + .get_roles() + .into_iter() + .map(|role| { + let mut nodes: Vec = self.get_role(&role); + nodes.sort(); + (role, nodes) + }) + .collect(); + roles.sort_by(|a, b| a.0.cmp(&b.0)); + roles.hash(state); + } } impl RustDAG { @@ -37,6 +122,7 @@ impl RustDAG { node_map: HashMap::new(), reverse_node_map: HashMap::new(), latents: HashSet::new(), + roles: HashMap::new(), } } @@ -111,7 +197,6 @@ impl RustDAG { Ok(()) } - /// Add multiple directed edges. /// /// # Parameters @@ -593,3 +678,18 @@ impl RustDAG { self.graph.edge_count() } } + + +impl GraphRoles for RustDAG { + fn has_node(&self, node: &str) -> bool { + self.node_map.contains_key(node) + } + + fn get_roles_map(&self) -> &HashMap> { + &self.roles + } + + fn get_roles_map_mut(&mut self) -> &mut HashMap> { + &mut self.roles + } +} \ No newline at end of file diff --git a/rust_core/src/graph_role.rs b/rust_core/src/graph_role.rs new file mode 100644 index 0000000..122cb70 --- /dev/null +++ b/rust_core/src/graph_role.rs @@ -0,0 +1,139 @@ +use std::collections::{HashMap, HashSet}; + +/// Custom error type for graph operations. +#[derive(Debug)] +pub enum GraphError { + NodeNotFound(String), + InvalidOperation(String), +} + +impl std::fmt::Display for GraphError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GraphError::NodeNotFound(node) => write!(f, "Node '{}' not found in the graph", node), + GraphError::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg), + } + } +} + +impl std::error::Error for GraphError {} + +/// Trait for handling roles in graphs (similar to Python mixin). +pub trait GraphRoles: Clone { + /// Check if a node exists in the graph. + fn has_node(&self, node: &str) -> bool; + + /// Get immutable reference to the roles map. + fn get_roles_map(&self) -> &HashMap>; + + /// Get mutable reference to the roles map. + fn get_roles_map_mut(&mut self) -> &mut HashMap>; + + /// Get nodes with a specific role. + fn get_role(&self, role: &str) -> Vec { + self.get_roles_map() + .get(role) + .cloned() + .unwrap_or_default() + .into_iter() + .collect() + } + + /// Get list of all roles. + fn get_roles(&self) -> Vec { + self.get_roles_map().keys().cloned().collect() + } + + /// Get dict of roles to nodes. + fn get_role_dict(&self) -> HashMap> { + self.get_roles_map() + .iter() + .map(|(k, v)| (k.clone(), v.iter().cloned().collect())) + .collect() + } + + /// Check if a role exists and has nodes. + fn has_role(&self, role: &str) -> bool { + self.get_roles_map() + .get(role) + .map(|set| !set.is_empty()) + .unwrap_or(false) + } + + /// Assign role to variables. Modifies in place if `inplace=true`, otherwise returns a new graph. + fn with_role(&mut self, role: String, variables: Vec, inplace: bool) -> Result { + if inplace { + // Modify self directly + for var in &variables { + if !self.has_node(var) { + return Err(GraphError::NodeNotFound(var.clone())); + } + } + let roles_map = self.get_roles_map_mut(); + let entry = roles_map.entry(role).or_insert(HashSet::new()); + for var in variables { + entry.insert(var); + } + Ok(self.clone()) // Return self.clone() for consistency, but self is modified + } else { + // Create and modify a new graph + let mut new_graph = self.clone(); + for var in &variables { + if !new_graph.has_node(var) { + return Err(GraphError::NodeNotFound(var.clone())); + } + } + let roles_map = new_graph.get_roles_map_mut(); + let entry = roles_map.entry(role).or_insert(HashSet::new()); + for var in variables { + entry.insert(var); + } + Ok(new_graph) + } + } + + /// Remove role from variables (or all if None). Modifies in place if `inplace=true`, otherwise returns a new graph. + fn without_role(&mut self, role: &str, variables: Option>, inplace: bool) -> Self { + if inplace { + if let Some(set) = self.get_roles_map_mut().get_mut(role) { + if let Some(vars) = variables { + for var in vars { + set.remove(&var); + } + } else { + set.clear(); + } + } + self.clone() // Return self.clone() for consistency + } else { + let mut new_graph = self.clone(); + if let Some(set) = new_graph.get_roles_map_mut().get_mut(role) { + if let Some(vars) = variables { + for var in vars { + set.remove(&var); + } + } else { + set.clear(); + } + } + new_graph + } + } + + /// Validate causal structure (has exposure and outcome). + fn is_valid_causal_structure(&self) -> Result { + let has_exposure = self.has_role("exposure"); + let has_outcome = self.has_role("outcome"); + if !has_exposure || !has_outcome { + let mut problems = Vec::new(); + if !has_exposure { + problems.push("no 'exposure' role was defined"); + } + if !has_outcome { + problems.push("no 'outcome' role was defined"); + } + return Err(GraphError::InvalidOperation(problems.join(", and "))); + } + Ok(true) + } +} \ No newline at end of file diff --git a/rust_core/src/lib.rs b/rust_core/src/lib.rs index bcf918f..85d8fc4 100644 --- a/rust_core/src/lib.rs +++ b/rust_core/src/lib.rs @@ -2,6 +2,7 @@ pub mod dag; pub mod independencies; pub mod pdag; // Add PDAG.rs later if needed +pub mod graph_role; pub use dag::RustDAG; pub use pdag::RustPDAG; diff --git a/rust_core/tests/test_dag.rs b/rust_core/tests/test_dag.rs index 01ca853..74eceb4 100644 --- a/rust_core/tests/test_dag.rs +++ b/rust_core/tests/test_dag.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; - -use rust_core::RustDAG; +use rust_core::graph_role::GraphError; +use std::hash::{Hash, Hasher}; +use rust_core::{graph_role::GraphRoles, RustDAG}; #[test] fn test_add_nodes_and_edges() { @@ -257,3 +258,127 @@ fn test_minimal_dseparator_adjacent_error() { assert!(result.is_err()); assert!(result.unwrap_err().contains("adjacent")); } + +#[test] +fn test_role_hash_equality() { + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ("A".to_string(), "E".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); + + let dag1 = dag.clone(); + let dag2 = dag.clone(); + + assert_eq!(dag1, dag2); +} + +// Helper function to calculate hash value as u64 + fn calculate_hash(t: &T) -> u64 { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + t.hash(&mut hasher); + hasher.finish() + } + +#[test] + fn test_hash() { + // Create two identical M-bias DAGs: A -> B <- C, B -> D, E -> D + let mut dag1 = RustDAG::new(); + dag1.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "D".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); + + let mut dag2 = RustDAG::new(); + dag2.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("C".to_string(), "B".to_string()), + ("B".to_string(), "D".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); + + // Test identical DAGs have same hash + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); + + // Add exposure role to dag1 + dag1.with_role("exposure".to_string(), vec!["E".to_string()], true).unwrap(); + assert_ne!(dag1, dag2); + assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); + + // Add exposure role to dag2 + dag2.with_role("exposure".to_string(), vec!["E".to_string()], true).unwrap(); + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); + + // Add outcome role to dag1 + dag1.with_role("outcome".to_string(), vec!["D".to_string()], true).unwrap(); + assert_ne!(dag1, dag2); + assert_ne!(calculate_hash(&dag1), calculate_hash(&dag2)); + + // Add outcome role to dag2 + dag2.with_role("outcome".to_string(), vec!["D".to_string()], true).unwrap(); + assert_eq!(dag1, dag2); + assert_eq!(calculate_hash(&dag1), calculate_hash(&dag2)); + } + + + +#[test] +fn test_roles() { + let mut dag = RustDAG::new(); + dag.add_edges_from( + vec![ + ("A".to_string(), "B".to_string()), + ("B".to_string(), "C".to_string()), + ("C".to_string(), "D".to_string()), + ("A".to_string(), "E".to_string()), + ("E".to_string(), "D".to_string()), + ], + None, + ) + .unwrap(); + + // Test assigning role to existing node + let result = dag.with_role("exposure".to_string(), vec!["A".to_string()], true); + assert!(result.is_ok()); + assert!(dag.has_role("exposure")); + assert_eq!(dag.get_role("exposure"), vec!["A".to_string()]); + + // Test assigning role to non-existent node (should fail) + let result = dag.with_role("exposure".to_string(), vec!["Z".to_string()], true); + assert!(matches!(result, Err(GraphError::NodeNotFound(ref s)) if s == "Z")); + // Verify exposure role still contains only "A" + assert_eq!(dag.get_role("exposure"), vec!["A".to_string()]); + + // Test assigning outcome role + let result = dag.with_role("outcome".to_string(), vec!["D".to_string()], true); + assert!(result.is_ok()); + assert!(dag.has_role("outcome")); + assert_eq!(dag.get_role("outcome"), vec!["D".to_string()]); + + // Test valid causal structure + assert!(dag.is_valid_causal_structure().is_ok()); + + // Test removing role + dag.without_role("exposure", None, true); + assert!(!dag.has_role("exposure")); + assert!(dag.is_valid_causal_structure().is_err()); +} From 51c5d755478a2ff140a25d88587947fa54f96128 Mon Sep 17 00:00:00 2001 From: Mohammed Razak Date: Sat, 30 Aug 2025 22:55:42 +0530 Subject: [PATCH 2/2] base identification class impl --- rust_core/src/dag.rs | 21 +++- rust_core/src/graph.rs | 14 +++ rust_core/src/identification/base.rs | 20 ++++ rust_core/src/identification/mod.rs | 3 + rust_core/src/lib.rs | 2 + rust_core/tests/base_tests.rs | 168 +++++++++++++++++++++++++++ 6 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 rust_core/src/graph.rs create mode 100644 rust_core/src/identification/base.rs create mode 100644 rust_core/src/identification/mod.rs create mode 100644 rust_core/tests/base_tests.rs diff --git a/rust_core/src/dag.rs b/rust_core/src/dag.rs index 05d5b09..463c672 100644 --- a/rust_core/src/dag.rs +++ b/rust_core/src/dag.rs @@ -1,8 +1,9 @@ use petgraph::Direction; use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex}; use std::collections::{HashMap, HashSet, VecDeque}; -use crate::graph_role::GraphRoles; +use crate::graph_role::{GraphError, GraphRoles}; use std::hash::{Hash, Hasher}; +use crate::graph::Graph; /// Directed Acyclic Graph (DAG) with optional latent variables. /// @@ -679,6 +680,21 @@ impl RustDAG { } } +impl Graph for RustDAG { + fn nodes(&self) -> Vec { + self.node_map.keys().cloned().collect() + } + + fn parents(&self, node: &str) -> Result, GraphError> { + self.get_parents(node) + .map_err(|e| GraphError::NodeNotFound(e)) + } + + fn ancestors(&self, nodes: Vec) -> Result, GraphError> { + self.get_ancestors_of(nodes) + .map_err(|e| GraphError::NodeNotFound(e)) + } +} impl GraphRoles for RustDAG { fn has_node(&self, node: &str) -> bool { @@ -692,4 +708,5 @@ impl GraphRoles for RustDAG { fn get_roles_map_mut(&mut self) -> &mut HashMap> { &mut self.roles } -} \ No newline at end of file +} + diff --git a/rust_core/src/graph.rs b/rust_core/src/graph.rs new file mode 100644 index 0000000..488df17 --- /dev/null +++ b/rust_core/src/graph.rs @@ -0,0 +1,14 @@ +use crate::graph_role::GraphError; +use std::collections::HashSet; + +/// Trait for core graph operations required by causal graphs. +pub trait Graph { + /// Get all nodes in the graph. + fn nodes(&self) -> Vec; + + /// Get the parents of a node. + fn parents(&self, node: &str) -> Result, GraphError>; + + /// Get the ancestors of a set of nodes (including the nodes themselves). + fn ancestors(&self, nodes: Vec) -> Result, GraphError>; +} \ No newline at end of file diff --git a/rust_core/src/identification/base.rs b/rust_core/src/identification/base.rs new file mode 100644 index 0000000..0e34a85 --- /dev/null +++ b/rust_core/src/identification/base.rs @@ -0,0 +1,20 @@ +use crate::graph::Graph; +use crate::graph_role::{GraphError, GraphRoles}; + +/// Trait for causal identification algorithms, mirroring Python's BaseIdentification. +pub trait BaseIdentification { + /// Internal identification method to be implemented by specific algorithms. + fn _identify( + &self, + causal_graph: &T, + ) -> Result<(T, bool), GraphError>; + + /// Run the identification algorithm on a causal graph. + fn identify( + &self, + causal_graph: &T, + ) -> Result<(T, bool), GraphError> { + causal_graph.is_valid_causal_structure()?; + self._identify(causal_graph) + } +} \ No newline at end of file diff --git a/rust_core/src/identification/mod.rs b/rust_core/src/identification/mod.rs new file mode 100644 index 0000000..1d63ce9 --- /dev/null +++ b/rust_core/src/identification/mod.rs @@ -0,0 +1,3 @@ +pub mod base; + +pub use base::BaseIdentification; \ No newline at end of file diff --git a/rust_core/src/lib.rs b/rust_core/src/lib.rs index 85d8fc4..28d3c0f 100644 --- a/rust_core/src/lib.rs +++ b/rust_core/src/lib.rs @@ -1,8 +1,10 @@ // Re-export modules/structs from your core logic pub mod dag; pub mod independencies; +pub mod identification; pub mod pdag; // Add PDAG.rs later if needed pub mod graph_role; +pub mod graph; pub use dag::RustDAG; pub use pdag::RustPDAG; diff --git a/rust_core/tests/base_tests.rs b/rust_core/tests/base_tests.rs new file mode 100644 index 0000000..98ded7e --- /dev/null +++ b/rust_core/tests/base_tests.rs @@ -0,0 +1,168 @@ +use std::collections::{HashMap, HashSet}; +use rust_core::{graph::Graph, graph_role::{GraphError, GraphRoles}, identification::base, RustDAG}; + +use base::BaseIdentification; + + +/// A simple identification method that assigns the "adjustment" role to either +/// the first or last non-exposure, non-outcome node (alphabetically sorted), +/// based on the variant parameter. +#[derive(Debug, Clone)] +struct DummyIdentification { + variant: Option, +} +impl DummyIdentification { + fn new(variant: Option<&str>) -> Self { + DummyIdentification { + variant: variant.map(|s| s.to_string()), + } + } +} +impl BaseIdentification for DummyIdentification { + fn _identify(&self, causal_graph: &T) -> Result<(T, bool), GraphError> { + let mut mutable_graph = causal_graph.clone(); + match self.variant.as_deref() { + Some("first") => { + let non_role_nodes: HashSet = causal_graph + .nodes() + .into_iter() + .collect::>() + .difference( + &causal_graph + .get_role("exposure") + .into_iter() + .chain(causal_graph.get_role("outcome").into_iter()) + .collect::>(), + ) + .cloned() + .collect(); + let mut sorted_nodes: Vec = non_role_nodes.into_iter().collect(); + sorted_nodes.sort(); + if let Some(adjustment_node) = sorted_nodes.first() { + let identified_cg = mutable_graph.with_role( + "adjustment".to_string(), + vec![adjustment_node.clone()], + false, + )?; + Ok((identified_cg, true)) + } else { + Ok((causal_graph.clone(), false)) + } + } + Some("last") => { + let non_role_nodes: HashSet = causal_graph + .nodes() + .into_iter() + .collect::>() + .difference( + &causal_graph + .get_role("exposure") + .into_iter() + .chain(causal_graph.get_role("outcome").into_iter()) + .collect::>(), + ) + .cloned() + .collect(); + let mut sorted_nodes: Vec = non_role_nodes.into_iter().collect(); + sorted_nodes.sort(); + if let Some(adjustment_node) = sorted_nodes.last() { + let identified_cg = mutable_graph.with_role( + "adjustment".to_string(), + vec![adjustment_node.clone()], + false, + )?; + Ok((identified_cg, true)) + } else { + Ok((causal_graph.clone(), false)) + } + } + _ => Ok((causal_graph.clone(), false)), + } + } +} + + +#[test] +fn test_base_identification_first() { + let mut cg = RustDAG::new(); + cg.add_edges_from( + vec![ + ("U".to_string(), "X".to_string()), + ("X".to_string(), "M".to_string()), + ("M".to_string(), "Y".to_string()), + ("U".to_string(), "Y".to_string()), + ], + None, + ) + .unwrap(); + + cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + let identifier = DummyIdentification::new(Some("first")); + let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); + assert!(is_identified); + let expected_roles: HashMap> = [ + ("exposure".to_string(), vec!["X".to_string()]), + ("outcome".to_string(), vec!["Y".to_string()]), + ("adjustment".to_string(), vec!["M".to_string()]), + ] + .into_iter() + .collect(); + assert_eq!(identified_cg.get_role_dict(), expected_roles); +} + + +#[test] +fn test_base_identification_last() { + let mut cg = RustDAG::new(); + cg.add_edges_from( + vec![ + ("U".to_string(), "X".to_string()), + ("X".to_string(), "M".to_string()), + ("M".to_string(), "Y".to_string()), + ("U".to_string(), "Y".to_string()), + ], + None, + ) + .unwrap(); + cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + let identifier = DummyIdentification::new(Some("last")); + let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); + assert!(is_identified); + let expected_roles: HashMap> = [ + ("exposure".to_string(), vec!["X".to_string()]), + ("outcome".to_string(), vec!["Y".to_string()]), + ("adjustment".to_string(), vec!["U".to_string()]), + ] + .into_iter() + .collect(); + assert_eq!(identified_cg.get_role_dict(), expected_roles); +} + +#[test] +fn test_base_identification_gibberish() { + let mut cg = RustDAG::new(); + cg.add_edges_from( + vec![ + ("U".to_string(), "X".to_string()), + ("X".to_string(), "M".to_string()), + ("M".to_string(), "Y".to_string()), + ("U".to_string(), "Y".to_string()), + ], + None, + ) + .unwrap(); + cg.with_role("exposure".to_string(), vec!["X".to_string()], true).unwrap(); + cg.with_role("outcome".to_string(), vec!["Y".to_string()], true).unwrap(); + let identifier = DummyIdentification::new(Some("gibberish")); + let (identified_cg, is_identified) = identifier.identify(&cg).unwrap(); + assert!(!is_identified); + let expected_roles: HashMap> = [ + ("exposure".to_string(), vec!["X".to_string()]), + ("outcome".to_string(), vec!["Y".to_string()]), + ] + .into_iter() + .collect(); + assert_eq!(identified_cg.get_role_dict(), expected_roles); +} \ No newline at end of file