Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion rust_core/src/dag.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use petgraph::Direction;
use rustworkx_core::petgraph::graph::{DiGraph, NodeIndex};
use std::collections::{HashMap, HashSet, VecDeque};
use crate::graph_role::{GraphError, GraphRoles};
use std::hash::{Hash, Hasher};
use crate::graph::Graph;

/// Directed Acyclic Graph (DAG) with optional latent variables.
///
Expand All @@ -24,6 +27,89 @@ pub struct RustDAG {
pub node_map: HashMap<String, NodeIndex>,
pub reverse_node_map: HashMap<NodeIndex, String>,
pub latents: HashSet<String>,
pub roles: HashMap<String, HashSet<String>>, // New: role -> set of nodes
}

impl PartialEq for RustDAG {
fn eq(&self, other: &Self) -> bool {
// Compare nodes
let self_nodes: HashSet<&String> = self.node_map.keys().collect();
let other_nodes: HashSet<&String> = other.node_map.keys().collect();
if self_nodes != other_nodes {
return false;
}

// Compare edges
let self_edges: HashSet<(String, String)> = self.edges().into_iter().collect();
let other_edges: HashSet<(String, String)> = other.edges().into_iter().collect();
if self_edges != other_edges {
return false;
}

// Compare latents
if self.latents != other.latents {
return false;
}

// Compare roles
let mut self_roles: Vec<(String, Vec<String>)> = self
.get_roles()
.into_iter()
.map(|role| {
let mut nodes = self.get_role(&role);
nodes.sort();
(role, nodes)
})
.collect();
self_roles.sort_by(|a, b| a.0.cmp(&b.0));

let mut other_roles: Vec<(String, Vec<String>)> = other
.get_roles()
.into_iter()
.map(|role| {
let mut nodes = other.get_role(&role);
nodes.sort();
(role, nodes)
})
.collect();
other_roles.sort_by(|a, b| a.0.cmp(&b.0));

self_roles == other_roles
}
}

impl Eq for RustDAG {}

impl Hash for RustDAG {
fn hash<H: Hasher>(&self, state: &mut H) {
// Hash nodes
let mut nodes: Vec<&String> = self.node_map.keys().collect();
nodes.sort();
nodes.hash(state);

// Hash edges
let mut edges: Vec<(String, String)> = self.edges();
edges.sort();
edges.hash(state);

// Hash latents
let mut latents: Vec<&String> = self.latents.iter().collect();
latents.sort();
latents.hash(state);

// Hash roles
let mut roles: Vec<(String, Vec<String>)> = self
.get_roles()
.into_iter()
.map(|role| {
let mut nodes: Vec<String> = self.get_role(&role);
nodes.sort();
(role, nodes)
})
.collect();
roles.sort_by(|a, b| a.0.cmp(&b.0));
roles.hash(state);
}
}

impl RustDAG {
Expand All @@ -37,6 +123,7 @@ impl RustDAG {
node_map: HashMap::new(),
reverse_node_map: HashMap::new(),
latents: HashSet::new(),
roles: HashMap::new(),
}
}

Expand Down Expand Up @@ -111,7 +198,6 @@ impl RustDAG {
Ok(())
}


/// Add multiple directed edges.
///
/// # Parameters
Expand Down Expand Up @@ -593,3 +679,34 @@ impl RustDAG {
self.graph.edge_count()
}
}

impl Graph for RustDAG {
fn nodes(&self) -> Vec<String> {
self.node_map.keys().cloned().collect()
}

fn parents(&self, node: &str) -> Result<Vec<String>, GraphError> {
self.get_parents(node)
.map_err(|e| GraphError::NodeNotFound(e))
}

fn ancestors(&self, nodes: Vec<String>) -> Result<HashSet<String>, GraphError> {
self.get_ancestors_of(nodes)
.map_err(|e| GraphError::NodeNotFound(e))
}
}

impl GraphRoles for RustDAG {
fn has_node(&self, node: &str) -> bool {
self.node_map.contains_key(node)
}

fn get_roles_map(&self) -> &HashMap<String, HashSet<String>> {
&self.roles
}

fn get_roles_map_mut(&mut self) -> &mut HashMap<String, HashSet<String>> {
&mut self.roles
}
}

14 changes: 14 additions & 0 deletions rust_core/src/graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::graph_role::GraphError;
use std::collections::HashSet;

/// Trait for core graph operations required by causal graphs.
pub trait Graph {
/// Get all nodes in the graph.
fn nodes(&self) -> Vec<String>;

/// Get the parents of a node.
fn parents(&self, node: &str) -> Result<Vec<String>, GraphError>;

/// Get the ancestors of a set of nodes (including the nodes themselves).
fn ancestors(&self, nodes: Vec<String>) -> Result<HashSet<String>, GraphError>;
}
139 changes: 139 additions & 0 deletions rust_core/src/graph_role.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::collections::{HashMap, HashSet};

/// Custom error type for graph operations.
#[derive(Debug)]
pub enum GraphError {
NodeNotFound(String),
InvalidOperation(String),
}

impl std::fmt::Display for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
GraphError::NodeNotFound(node) => write!(f, "Node '{}' not found in the graph", node),
GraphError::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg),
}
}
}

impl std::error::Error for GraphError {}

/// Trait for handling roles in graphs (similar to Python mixin).
pub trait GraphRoles: Clone {
/// Check if a node exists in the graph.
fn has_node(&self, node: &str) -> bool;

/// Get immutable reference to the roles map.
fn get_roles_map(&self) -> &HashMap<String, HashSet<String>>;

/// Get mutable reference to the roles map.
fn get_roles_map_mut(&mut self) -> &mut HashMap<String, HashSet<String>>;

/// Get nodes with a specific role.
fn get_role(&self, role: &str) -> Vec<String> {
self.get_roles_map()
.get(role)
.cloned()
.unwrap_or_default()
.into_iter()
.collect()
}

/// Get list of all roles.
fn get_roles(&self) -> Vec<String> {
self.get_roles_map().keys().cloned().collect()
}

/// Get dict of roles to nodes.
fn get_role_dict(&self) -> HashMap<String, Vec<String>> {
self.get_roles_map()
.iter()
.map(|(k, v)| (k.clone(), v.iter().cloned().collect()))
.collect()
}

/// Check if a role exists and has nodes.
fn has_role(&self, role: &str) -> bool {
self.get_roles_map()
.get(role)
.map(|set| !set.is_empty())
.unwrap_or(false)
}

/// Assign role to variables. Modifies in place if `inplace=true`, otherwise returns a new graph.
fn with_role(&mut self, role: String, variables: Vec<String>, inplace: bool) -> Result<Self, GraphError> {
if inplace {
// Modify self directly
for var in &variables {
if !self.has_node(var) {
return Err(GraphError::NodeNotFound(var.clone()));
}
}
let roles_map = self.get_roles_map_mut();
let entry = roles_map.entry(role).or_insert(HashSet::new());
for var in variables {
entry.insert(var);
}
Ok(self.clone()) // Return self.clone() for consistency, but self is modified
} else {
// Create and modify a new graph
let mut new_graph = self.clone();
for var in &variables {
if !new_graph.has_node(var) {
return Err(GraphError::NodeNotFound(var.clone()));
}
}
let roles_map = new_graph.get_roles_map_mut();
let entry = roles_map.entry(role).or_insert(HashSet::new());
for var in variables {
entry.insert(var);
}
Ok(new_graph)
}
}

/// Remove role from variables (or all if None). Modifies in place if `inplace=true`, otherwise returns a new graph.
fn without_role(&mut self, role: &str, variables: Option<Vec<String>>, inplace: bool) -> Self {
if inplace {
if let Some(set) = self.get_roles_map_mut().get_mut(role) {
if let Some(vars) = variables {
for var in vars {
set.remove(&var);
}
} else {
set.clear();
}
}
self.clone() // Return self.clone() for consistency
} else {
let mut new_graph = self.clone();
if let Some(set) = new_graph.get_roles_map_mut().get_mut(role) {
if let Some(vars) = variables {
for var in vars {
set.remove(&var);
}
} else {
set.clear();
}
}
new_graph
}
}

/// Validate causal structure (has exposure and outcome).
fn is_valid_causal_structure(&self) -> Result<bool, GraphError> {
let has_exposure = self.has_role("exposure");
let has_outcome = self.has_role("outcome");
if !has_exposure || !has_outcome {
let mut problems = Vec::new();
if !has_exposure {
problems.push("no 'exposure' role was defined");
}
if !has_outcome {
problems.push("no 'outcome' role was defined");
}
return Err(GraphError::InvalidOperation(problems.join(", and ")));
}
Ok(true)
}
}
20 changes: 20 additions & 0 deletions rust_core/src/identification/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::graph::Graph;
use crate::graph_role::{GraphError, GraphRoles};

/// Trait for causal identification algorithms, mirroring Python's BaseIdentification.
pub trait BaseIdentification {
/// Internal identification method to be implemented by specific algorithms.
fn _identify<T: Graph + GraphRoles>(
&self,
causal_graph: &T,
) -> Result<(T, bool), GraphError>;

/// Run the identification algorithm on a causal graph.
fn identify<T: Graph + GraphRoles>(
&self,
causal_graph: &T,
) -> Result<(T, bool), GraphError> {
causal_graph.is_valid_causal_structure()?;
self._identify(causal_graph)
}
}
3 changes: 3 additions & 0 deletions rust_core/src/identification/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod base;

pub use base::BaseIdentification;
3 changes: 3 additions & 0 deletions rust_core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Re-export modules/structs from your core logic
pub mod dag;
pub mod independencies;
pub mod identification;
pub mod pdag; // Add PDAG.rs later if needed
pub mod graph_role;
pub mod graph;

pub use dag::RustDAG;
pub use pdag::RustPDAG;
Expand Down
Loading
Loading