diff --git a/crates/bevy_ecs/src/schedule/auto_insert_apply_deferred.rs b/crates/bevy_ecs/src/schedule/auto_insert_apply_deferred.rs index 64dc05c8feccb..00eae83a75a7f 100644 --- a/crates/bevy_ecs/src/schedule/auto_insert_apply_deferred.rs +++ b/crates/bevy_ecs/src/schedule/auto_insert_apply_deferred.rs @@ -1,9 +1,9 @@ use alloc::{boxed::Box, collections::BTreeSet, vec::Vec}; -use bevy_platform::collections::HashMap; +use bevy_platform::collections::{HashMap, HashSet}; use crate::{ - schedule::{SystemKey, SystemSetKey}, + schedule::{graph::Dag, SystemKey, SystemSetKey}, system::{IntoSystem, System}, world::World, }; @@ -72,11 +72,11 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { &mut self, _world: &mut World, graph: &mut ScheduleGraph, - dependency_flattened: &mut DiGraph, + dependency_flattened: &mut Dag, ) -> Result<(), ScheduleBuildError> { - let mut sync_point_graph = dependency_flattened.clone(); - let topo = dependency_flattened - .toposort() + let mut sync_point_graph = dependency_flattened.graph().clone(); + let (topo, flat_dependency) = dependency_flattened + .toposort_and_graph() .map_err(ScheduleBuildError::FlatDependencySort)?; fn set_has_conditions(graph: &ScheduleGraph, set: SystemSetKey) -> bool { @@ -124,7 +124,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { let mut distance_to_explicit_sync_node: HashMap = HashMap::default(); // Determine the distance for every node and collect the explicit sync points. 
- for &key in &topo { + for &key in topo { let (node_distance, mut node_needs_sync) = distances_and_pending_sync .get(&key) .copied() @@ -146,7 +146,7 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { node_needs_sync = graph.systems[key].has_deferred(); } - for target in dependency_flattened.neighbors_directed(key, Direction::Outgoing) { + for target in flat_dependency.neighbors_directed(key, Direction::Outgoing) { let (target_distance, target_pending_sync) = distances_and_pending_sync.entry(target).or_default(); @@ -179,13 +179,13 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { // Find any edges which have a different number of sync points between them and make sure // there is a sync point between them. - for &key in &topo { + for &key in topo { let (node_distance, _) = distances_and_pending_sync .get(&key) .copied() .unwrap_or_default(); - for target in dependency_flattened.neighbors_directed(key, Direction::Outgoing) { + for target in flat_dependency.neighbors_directed(key, Direction::Outgoing) { let (target_distance, _) = distances_and_pending_sync .get(&target) .copied() @@ -215,14 +215,14 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { } } - *dependency_flattened = sync_point_graph; + **dependency_flattened = sync_point_graph; Ok(()) } fn collapse_set( &mut self, set: SystemSetKey, - systems: &[SystemKey], + systems: &HashSet, dependency_flattening: &DiGraph, ) -> impl Iterator { if systems.is_empty() { diff --git a/crates/bevy_ecs/src/schedule/error.rs b/crates/bevy_ecs/src/schedule/error.rs index cb566f7606dad..48826630c8964 100644 --- a/crates/bevy_ecs/src/schedule/error.rs +++ b/crates/bevy_ecs/src/schedule/error.rs @@ -1,13 +1,18 @@ use alloc::{format, string::String, vec::Vec}; +use bevy_platform::collections::HashSet; use core::fmt::Write as _; use thiserror::Error; use crate::{ - component::{ComponentId, Components}, + component::Components, schedule::{ - graph::{DiGraphToposortError, GraphNodeId}, - NodeId, 
ScheduleGraph, SystemKey, SystemSetKey, + graph::{ + DagCrossDependencyError, DagOverlappingGroupError, DagRedundancyError, + DiGraphToposortError, GraphNodeId, + }, + AmbiguousSystemConflictsWarning, ConflictingSystems, NodeId, ScheduleGraph, SystemKey, + SystemSetKey, SystemTypeSetAmbiguityError, }, world::World, }; @@ -26,14 +31,14 @@ pub enum ScheduleBuildError { #[error("Failed to topologically sort the flattened dependency graph: {0}")] FlatDependencySort(DiGraphToposortError), /// Tried to order a system (set) relative to a system set it belongs to. - #[error("`{0:?}` and `{1:?}` have both `in_set` and `before`-`after` relationships (these might be transitive). This combination is unsolvable as a system cannot run before or after a set it belongs to.")] - CrossDependency(NodeId, NodeId), + #[error("`{:?}` and `{:?}` have both `in_set` and `before`-`after` relationships (these might be transitive). This combination is unsolvable as a system cannot run before or after a set it belongs to.", .0.0, .0.1)] + CrossDependency(#[from] DagCrossDependencyError), /// Tried to order system sets that share systems. - #[error("`{0:?}` and `{1:?}` have a `before`-`after` relationship (which may be transitive) but share systems.")] - SetsHaveOrderButIntersect(SystemSetKey, SystemSetKey), + #[error("`{:?}` and `{:?}` have a `before`-`after` relationship (which may be transitive) but share systems.", .0.0, .0.1)] + SetsHaveOrderButIntersect(#[from] DagOverlappingGroupError), /// Tried to order a system (set) relative to all instances of some system function. - #[error("Tried to order against `{0:?}` in a schedule that has more than one `{0:?}` instance. `{0:?}` is a `SystemTypeSet` and cannot be used for ordering if ambiguous. Use a different set without this restriction.")] - SystemTypeSetAmbiguity(SystemSetKey), + #[error(transparent)] + SystemTypeSetAmbiguity(#[from] SystemTypeSetAmbiguityError), /// Tried to run a schedule before all of its systems have been initialized. 
#[error("Tried to run a schedule before all of its systems have been initialized.")] Uninitialized, @@ -56,7 +61,7 @@ pub enum ScheduleBuildWarning { /// [`LogLevel::Ignore`]: crate::schedule::LogLevel::Ignore /// [`LogLevel::Error`]: crate::schedule::LogLevel::Error #[error("The hierarchy of system sets contains redundant edges: {0:?}")] - HierarchyRedundancy(Vec<(NodeId, NodeId)>), + HierarchyRedundancy(#[from] DagRedundancyError), /// Systems with conflicting access have indeterminate run order. /// /// This warning is **disabled** by default, but can be enabled by setting @@ -66,8 +71,8 @@ pub enum ScheduleBuildWarning { /// [`ScheduleBuildSettings::ambiguity_detection`]: crate::schedule::ScheduleBuildSettings::ambiguity_detection /// [`LogLevel::Warn`]: crate::schedule::LogLevel::Warn /// [`LogLevel::Error`]: crate::schedule::LogLevel::Error - #[error("Systems with conflicting access have indeterminate run order: {0:?}")] - Ambiguity(Vec<(SystemKey, SystemKey, Vec)>), + #[error(transparent)] + Ambiguity(#[from] AmbiguousSystemConflictsWarning), } impl ScheduleBuildError { @@ -101,13 +106,13 @@ impl ScheduleBuildError { ScheduleBuildError::FlatDependencySort(DiGraphToposortError::Cycle(cycles)) => { Self::dependency_cycle_to_string(cycles, graph) } - ScheduleBuildError::CrossDependency(a, b) => { - Self::cross_dependency_to_string(a, b, graph) + ScheduleBuildError::CrossDependency(error) => { + Self::cross_dependency_to_string(error, graph) } - ScheduleBuildError::SetsHaveOrderButIntersect(a, b) => { + ScheduleBuildError::SetsHaveOrderButIntersect(DagOverlappingGroupError(a, b)) => { Self::sets_have_order_but_intersect_to_string(a, b, graph) } - ScheduleBuildError::SystemTypeSetAmbiguity(set) => { + ScheduleBuildError::SystemTypeSetAmbiguity(SystemTypeSetAmbiguityError(set)) => { Self::system_type_set_ambiguity_to_string(set, graph) } ScheduleBuildError::Uninitialized => Self::uninitialized_to_string(), @@ -144,7 +149,7 @@ impl ScheduleBuildError { } fn 
hierarchy_redundancy_to_string( - transitive_edges: &[(NodeId, NodeId)], + transitive_edges: &HashSet<(NodeId, NodeId)>, graph: &ScheduleGraph, ) -> String { let mut message = String::from("hierarchy contains redundant edge(s)"); @@ -195,7 +200,11 @@ impl ScheduleBuildError { message } - fn cross_dependency_to_string(a: &NodeId, b: &NodeId, graph: &ScheduleGraph) -> String { + fn cross_dependency_to_string( + error: &DagCrossDependencyError, + graph: &ScheduleGraph, + ) -> String { + let DagCrossDependencyError(a, b) = error; format!( "{} `{}` and {} `{}` have both `in_set` and `before`-`after` relationships (these might be transitive). \ This combination is unsolvable as a system cannot run before or after a set it belongs to.", @@ -227,7 +236,7 @@ impl ScheduleBuildError { } pub(crate) fn ambiguity_to_string( - ambiguities: &[(SystemKey, SystemKey, Vec)], + ambiguities: &ConflictingSystems, graph: &ScheduleGraph, components: &Components, ) -> String { @@ -236,7 +245,7 @@ impl ScheduleBuildError { "{n_ambiguities} pairs of systems with conflicting data access have indeterminate execution order. \ Consider adding `before`, `after`, or `ambiguous_with` relationships between these:\n", ); - let ambiguities = graph.conflicts_to_string(ambiguities, components); + let ambiguities = ambiguities.to_string(graph, components); for (name_a, name_b, conflicts) in ambiguities { writeln!(message, " -- {name_a} and {name_b}").unwrap(); @@ -261,10 +270,10 @@ impl ScheduleBuildWarning { /// replaced with their names. 
pub fn to_string(&self, graph: &ScheduleGraph, world: &World) -> String { match self { - ScheduleBuildWarning::HierarchyRedundancy(transitive_edges) => { + ScheduleBuildWarning::HierarchyRedundancy(DagRedundancyError(transitive_edges)) => { ScheduleBuildError::hierarchy_redundancy_to_string(transitive_edges, graph) } - ScheduleBuildWarning::Ambiguity(ambiguities) => { + ScheduleBuildWarning::Ambiguity(AmbiguousSystemConflictsWarning(ambiguities)) => { ScheduleBuildError::ambiguity_to_string(ambiguities, graph, world.components()) } } diff --git a/crates/bevy_ecs/src/schedule/graph/dag.rs b/crates/bevy_ecs/src/schedule/graph/dag.rs new file mode 100644 index 0000000000000..3913b057e799c --- /dev/null +++ b/crates/bevy_ecs/src/schedule/graph/dag.rs @@ -0,0 +1,995 @@ +use alloc::vec::Vec; +use core::{ + fmt::{self, Debug}, + hash::{BuildHasher, Hash}, + ops::{Deref, DerefMut}, +}; + +use bevy_platform::{ + collections::{HashMap, HashSet}, + hash::FixedHasher, +}; +use fixedbitset::FixedBitSet; +use thiserror::Error; + +use crate::{ + error::Result, + schedule::graph::{ + index, row_col, DiGraph, DiGraphToposortError, + Direction::{Incoming, Outgoing}, + GraphNodeId, UnGraph, + }, +}; + +/// A directed acyclic graph structure. +#[derive(Clone)] +pub struct Dag { + /// The underlying directed graph. + graph: DiGraph, + /// A cached topological ordering of the graph. This is recomputed when the + /// graph is modified, and is not valid when `dirty` is true. + toposort: Vec, + /// Whether the graph has been modified since the last topological sort. + dirty: bool, +} + +impl Dag { + /// Creates a new directed acyclic graph. + pub fn new() -> Self + where + S: Default, + { + Self::default() + } + + /// Read-only access to the underlying directed graph. + #[must_use] + pub fn graph(&self) -> &DiGraph { + &self.graph + } + + /// Mutable access to the underlying directed graph. Marks the graph as dirty. 
+ #[must_use = "This function marks the graph as dirty, so it should be used."] + pub fn graph_mut(&mut self) -> &mut DiGraph { + self.dirty = true; + &mut self.graph + } + + /// Returns whether the graph is dirty (i.e., has been modified since the + /// last topological sort). + #[must_use] + pub fn is_dirty(&self) -> bool { + self.dirty + } + + /// Returns whether the graph is topologically sorted (i.e., not dirty). + #[must_use] + pub fn is_toposorted(&self) -> bool { + !self.dirty + } + + /// Ensures the graph is topologically sorted, recomputing the toposort if + /// the graph is dirty. + /// + /// # Errors + /// + /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be + /// topologically sorted. + pub fn ensure_toposorted(&mut self) -> Result<(), DiGraphToposortError> { + if self.dirty { + // recompute the toposort, reusing the existing allocation + self.toposort = self.graph.toposort(core::mem::take(&mut self.toposort))?; + self.dirty = false; + } + Ok(()) + } + + /// Returns the cached toposort if the graph is not dirty, otherwise returns + /// `None`. + #[must_use = "This method only returns a cached value and does not compute anything."] + pub fn get_toposort(&self) -> Option<&[N]> { + if self.dirty { + None + } else { + Some(&self.toposort) + } + } + + /// Returns a topological ordering of the graph, computing it if the graph + /// is dirty. + /// + /// # Errors + /// + /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be + /// topologically sorted. + pub fn toposort(&mut self) -> Result<&[N], DiGraphToposortError> { + self.ensure_toposorted()?; + Ok(&self.toposort) + } + + /// Returns both the topological ordering and the underlying graph, + /// computing the toposort if the graph is dirty. + /// + /// This function is useful to avoid multiple borrow issues when both + /// the graph and the toposort are needed. 
+ /// + /// # Errors + /// + /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be + /// topologically sorted. + pub fn toposort_and_graph( + &mut self, + ) -> Result<(&[N], &DiGraph), DiGraphToposortError> { + self.ensure_toposorted()?; + Ok((&self.toposort, &self.graph)) + } + + /// Processes a DAG and computes its: + /// - transitive reduction (along with the set of removed edges) + /// - transitive closure + /// - reachability matrix (as a bitset) + /// - pairs of nodes connected by a path + /// - pairs of nodes not connected by a path + /// + /// The algorithm implemented comes from + /// ["On the calculation of transitive reduction-closure of orders"][1] by Habib, Morvan and Rampon. + /// + /// # Note + /// + /// If the DAG is dirty, this method will first attempt to topologically sort it. + /// + /// # Errors + /// + /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be + /// topologically sorted. + /// + /// [1]: https://doi.org/10.1016/0012-365X(93)90164-O + pub fn analyze(&mut self) -> Result, DiGraphToposortError> + where + S: Default, + { + let (toposort, graph) = self.toposort_and_graph()?; + Ok(DagAnalysis::new(graph, toposort)) + } + + /// Replaces the current graph with its transitive reduction based on the + /// provided analysis. + /// + /// # Note + /// + /// The given [`DagAnalysis`] must have been generated from this DAG. + pub fn remove_redundant_edges(&mut self, analysis: &DagAnalysis) + where + S: Clone, + { + // We don't need to mark the graph as dirty, since transitive reduction + // is guaranteed to have the same topological ordering as the original graph. + self.graph = analysis.transitive_reduction.clone(); + } + + /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V` + /// under all of their ancestor key nodes. `num_groups` hints at the + /// expected number of groups, for memory allocation optimization. 
+ /// + /// The node type `N` must be convertible into either a key type `K` or + /// a value type `V` via the [`TryInto`] trait. + /// + /// # Errors + /// + /// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be + /// topologically sorted. + pub fn group_by( + &mut self, + num_groups: usize, + ) -> Result, DiGraphToposortError> + where + N: TryInto, + K: Eq + Hash, + V: Clone + Eq + Hash, + S: BuildHasher + Default, + { + let (toposort, graph) = self.toposort_and_graph()?; + Ok(DagGroups::with_capacity(num_groups, graph, toposort)) + } + + /// Converts from one [`GraphNodeId`] type to another. If the conversion fails, + /// it returns the error from the target type's [`TryFrom`] implementation. + /// + /// Nodes must uniquely convert from `N` to `T` (i.e. no two `N` can convert + /// to the same `T`). The resulting DAG must be re-topologically sorted. + /// + /// # Errors + /// + /// If the conversion fails, it returns an error of type `N::Error`. + pub fn try_convert(self) -> Result, N::Error> + where + N: TryInto, + T: GraphNodeId, + S: Default, + { + Ok(Dag { + graph: self.graph.try_convert()?, + toposort: Vec::new(), + dirty: true, + }) + } +} + +impl Deref for Dag { + type Target = DiGraph; + + fn deref(&self) -> &Self::Target { + self.graph() + } +} + +impl DerefMut for Dag { + fn deref_mut(&mut self) -> &mut Self::Target { + self.graph_mut() + } +} + +impl Default for Dag { + fn default() -> Self { + Self { + graph: Default::default(), + toposort: Default::default(), + dirty: false, + } + } +} + +impl Debug for Dag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.dirty { + f.debug_struct("Dag") + .field("graph", &self.graph) + .field("dirty", &self.dirty) + .finish() + } else { + f.debug_struct("Dag") + .field("graph", &self.graph) + .field("toposort", &self.toposort) + .finish() + } + } +} + +/// Stores the results of a call to [`Dag::analyze`]. 
+pub struct DagAnalysis { + /// Boolean reachability matrix for the graph. + reachable: FixedBitSet, + /// Pairs of nodes that have a path connecting them. + connected: HashSet<(N, N), S>, + /// Pairs of nodes that don't have a path connecting them. + disconnected: HashSet<(N, N), S>, + /// Edges that are redundant because a longer path exists. + transitive_edges: HashSet<(N, N), S>, + /// Variant of the graph with no transitive edges. + transitive_reduction: DiGraph, + /// Variant of the graph with all possible transitive edges. + transitive_closure: DiGraph, +} + +impl DagAnalysis { + /// Analyzes the given DAG and computes various properties about it. + pub fn new(graph: &DiGraph, topological_order: &[N]) -> Self + where + S: Default, + { + if graph.node_count() == 0 { + return DagAnalysis::default(); + } + let n = graph.node_count(); + + // build a copy of the graph where the nodes and edges appear in topsorted order + let mut map = >::with_capacity_and_hasher(n, Default::default()); + let mut topsorted = DiGraph::::default(); + // iterate nodes in topological order + for (i, &node) in topological_order.iter().enumerate() { + map.insert(node, i); + topsorted.add_node(node); + // insert nodes as successors to their predecessors + for pred in graph.neighbors_directed(node, Incoming) { + topsorted.add_edge(pred, node); + } + } + + let mut reachable = FixedBitSet::with_capacity(n * n); + let mut connected = HashSet::default(); + let mut disconnected = HashSet::default(); + let mut transitive_edges = HashSet::default(); + let mut transitive_reduction = DiGraph::default(); + let mut transitive_closure = DiGraph::default(); + + let mut visited = FixedBitSet::with_capacity(n); + + // iterate nodes in topological order + for node in topsorted.nodes() { + transitive_reduction.add_node(node); + transitive_closure.add_node(node); + } + + // iterate nodes in reverse topological order + for a in topsorted.nodes().rev() { + let index_a = *map.get(&a).unwrap(); + // iterate 
their successors in topological order + for b in topsorted.neighbors_directed(a, Outgoing) { + let index_b = *map.get(&b).unwrap(); + debug_assert!(index_a < index_b); + if !visited[index_b] { + // edge is not redundant + transitive_reduction.add_edge(a, b); + transitive_closure.add_edge(a, b); + reachable.insert(index(index_a, index_b, n)); + + let successors = transitive_closure + .neighbors_directed(b, Outgoing) + .collect::>(); + for c in successors { + let index_c = *map.get(&c).unwrap(); + debug_assert!(index_b < index_c); + if !visited[index_c] { + visited.insert(index_c); + transitive_closure.add_edge(a, c); + reachable.insert(index(index_a, index_c, n)); + } + } + } else { + // edge is redundant + transitive_edges.insert((a, b)); + } + } + + visited.clear(); + } + + // partition pairs of nodes into "connected by path" and "not connected by path" + for i in 0..(n - 1) { + // reachable is upper triangular because the nodes were topsorted + for index in index(i, i + 1, n)..=index(i, n - 1, n) { + let (a, b) = row_col(index, n); + let pair = (topological_order[a], topological_order[b]); + if reachable[index] { + connected.insert(pair); + } else { + disconnected.insert(pair); + } + } + } + + // fill diagonal (nodes reach themselves) + // for i in 0..n { + // reachable.set(index(i, i, n), true); + // } + + DagAnalysis { + reachable, + connected, + disconnected, + transitive_edges, + transitive_reduction, + transitive_closure, + } + } + + /// Returns the reachability matrix. + pub fn reachable(&self) -> &FixedBitSet { + &self.reachable + } + + /// Returns the set of node pairs that are connected by a path. + pub fn connected(&self) -> &HashSet<(N, N), S> { + &self.connected + } + + /// Returns the list of node pairs that are not connected by a path. + pub fn disconnected(&self) -> &HashSet<(N, N), S> { + &self.disconnected + } + + /// Returns the list of redundant edges because a longer path exists. 
+ pub fn transitive_edges(&self) -> &HashSet<(N, N), S> { + &self.transitive_edges + } + + /// Returns the transitive reduction of the graph. + pub fn transitive_reduction(&self) -> &DiGraph { + &self.transitive_reduction + } + + /// Returns the transitive closure of the graph. + pub fn transitive_closure(&self) -> &DiGraph { + &self.transitive_closure + } + + /// Checks if the graph has any redundant (transitive) edges. + /// + /// # Errors + /// + /// If there are redundant edges, returns a [`DagRedundancyError`] + /// containing the list of redundant edges. + pub fn check_for_redundant_edges(&self) -> Result<(), DagRedundancyError> + where + S: Clone, + { + if self.transitive_edges.is_empty() { + Ok(()) + } else { + Err(DagRedundancyError(self.transitive_edges.clone())) + } + } + + /// Checks if there are any pairs of nodes that have a path in both this + /// graph and another graph. + /// + /// # Errors + /// + /// Returns [`DagCrossDependencyError`] if any node pair is connected in + /// both graphs. + pub fn check_for_cross_dependencies( + &self, + other: &Self, + ) -> Result<(), DagCrossDependencyError> { + for &(a, b) in &self.connected { + if other.connected.contains(&(a, b)) || other.connected.contains(&(b, a)) { + return Err(DagCrossDependencyError(a, b)); + } + } + + Ok(()) + } + + /// Checks if any connected node pairs that are both keys have overlapping + /// groups. + /// + /// # Errors + /// + /// If there are overlapping groups, returns a [`DagOverlappingGroupError`] + /// containing the first pair of keys that have overlapping groups. 
+ pub fn check_for_overlapping_groups( + &self, + groups: &DagGroups, + ) -> Result<(), DagOverlappingGroupError> + where + N: TryInto, + K: Eq + Hash, + V: Eq + Hash, + { + for &(a, b) in &self.connected { + let (Ok(a_key), Ok(b_key)) = (a.try_into(), b.try_into()) else { + continue; + }; + let a_group = groups.get(&a_key).unwrap(); + let b_group = groups.get(&b_key).unwrap(); + if !a_group.is_disjoint(b_group) { + return Err(DagOverlappingGroupError(a_key, b_key)); + } + } + Ok(()) + } +} + +impl Default for DagAnalysis { + fn default() -> Self { + Self { + reachable: Default::default(), + connected: Default::default(), + disconnected: Default::default(), + transitive_edges: Default::default(), + transitive_reduction: Default::default(), + transitive_closure: Default::default(), + } + } +} + +impl Debug for DagAnalysis { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DagAnalysis") + .field("reachable", &self.reachable) + .field("connected", &self.connected) + .field("disconnected", &self.disconnected) + .field("transitive_edges", &self.transitive_edges) + .field("transitive_reduction", &self.transitive_reduction) + .field("transitive_closure", &self.transitive_closure) + .finish() + } +} + +/// A mapping of keys to groups of values in a [`Dag`]. +pub struct DagGroups(HashMap, S>); + +impl DagGroups { + /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V` + /// under all of their ancestor key nodes. + /// + /// The node type `N` must be convertible into either a key type `K` or + /// a value type `V` via the [`TryInto`] trait. + pub fn new(graph: &DiGraph, toposort: &[N]) -> Self + where + N: GraphNodeId + TryInto, + { + Self::with_capacity(0, graph, toposort) + } + + /// Groups nodes in this DAG by a key type `K`, collecting value nodes `V` + /// under all of their ancestor key nodes. `capacity` hints at the + /// expected number of groups, for memory allocation optimization. 
+ /// + /// The node type `N` must be convertible into either a key type `K` or + /// a value type `V` via the [`TryInto`] trait. + pub fn with_capacity(capacity: usize, graph: &DiGraph, toposort: &[N]) -> Self + where + N: GraphNodeId + TryInto, + { + let mut groups: HashMap, S> = + HashMap::with_capacity_and_hasher(capacity, Default::default()); + + // Iterate in reverse topological order (bottom-up) so we hit children before parents. + for &id in toposort.iter().rev() { + let Ok(key) = id.try_into() else { + continue; + }; + + let mut children = HashSet::default(); + + for node in graph.neighbors_directed(id, Outgoing) { + match node.try_into() { + Ok(key) => { + // If the child is a key, this key inherits all of its children. + let key_children = groups.get(&key).unwrap(); + children.extend(key_children.iter().cloned()); + } + Err(value) => { + // If the child is a value, add it directly. + children.insert(value); + } + } + } + + groups.insert(key, children); + } + + Self(groups) + } +} + +impl DagGroups { + /// Converts the given [`Dag`] into a flattened version where key nodes + /// (`K`) are replaced by their associated value nodes (`V`). Edges to/from + /// key nodes are redirected to connect their value nodes instead. + /// + /// The `collapse_group` function is called for each key node to customize + /// how its group is collapsed. + /// + /// The resulting [`Dag`] will have only value nodes (`V`). + pub fn flatten( + &self, + dag: Dag, + mut collapse_group: impl FnMut(K, &HashSet, &Dag, &mut Vec<(N, N)>), + ) -> Dag + where + N: GraphNodeId + TryInto + From + From, + { + let mut flattening = dag; + let mut temp = Vec::new(); + + for (&key, values) in self.iter() { + // Call the user-provided function to handle collapsing the group. + collapse_group(key, values, &flattening, &mut temp); + + if values.is_empty() { + // Replace connections to the key node with connections between its neighbors. 
+ for a in flattening.neighbors_directed(N::from(key), Incoming) { + for b in flattening.neighbors_directed(N::from(key), Outgoing) { + temp.push((a, b)); + } + } + } else { + // Redirect edges to/from the key node to connect to its value nodes. + for a in flattening.neighbors_directed(N::from(key), Incoming) { + for &value in values { + temp.push((a, N::from(value))); + } + } + for b in flattening.neighbors_directed(N::from(key), Outgoing) { + for &value in values { + temp.push((N::from(value), b)); + } + } + } + + // Remove the key node from the graph. + flattening.remove_node(N::from(key)); + // Add all previously collected edges. + for (a, b) in temp.drain(..) { + flattening.add_edge(a, b); + } + } + + // By this point, we should have removed all keys from the graph, + // so this conversion should never fail. + flattening + .try_convert::() + .unwrap_or_else(|n| unreachable!("Flattened graph has a leftover key {n:?}")) + } + + /// Converts an undirected graph by replacing key nodes (`K`) with their + /// associated value nodes (`V`). Edges connected to key nodes are + /// redirected to connect their value nodes instead. + /// + /// The resulting undirected graph will have only value nodes (`V`). 
+ pub fn flatten_undirected(&self, graph: &UnGraph) -> UnGraph + where + N: GraphNodeId + TryInto, + { + let mut flattened = UnGraph::default(); + + for (lhs, rhs) in graph.all_edges() { + match (lhs.try_into(), rhs.try_into()) { + (Ok(lhs), Ok(rhs)) => { + // Normal edge between two value nodes + flattened.add_edge(lhs, rhs); + } + (Err(lhs_key), Ok(rhs)) => { + // Edge from a key node to a value node, expand to all values in the key's group + for &lhs in self.get(&lhs_key).into_iter().flatten() { + flattened.add_edge(lhs, rhs); + } + } + (Ok(lhs), Err(rhs_key)) => { + // Edge from a value node to a key node, expand to all values in the key's group + for &rhs in self.get(&rhs_key).into_iter().flatten() { + flattened.add_edge(lhs, rhs); + } + } + (Err(lhs_key), Err(rhs_key)) => { + // Edge between two key nodes, expand to all combinations of their value nodes + for &lhs in self.get(&lhs_key).into_iter().flatten() { + for &rhs in self.get(&rhs_key).into_iter().flatten() { + flattened.add_edge(lhs, rhs); + } + } + } + } + } + + flattened + } +} + +impl Deref for DagGroups { + type Target = HashMap, S>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DagGroups { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Default for DagGroups +where + S: BuildHasher + Default, +{ + fn default() -> Self { + Self(Default::default()) + } +} + +impl Debug for DagGroups { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("DagGroups").field(&self.0).finish() + } +} + +/// Error indicating that the graph has redundant edges. +#[derive(Error, Debug)] +#[error("DAG has redundant edges: {0:?}")] +pub struct DagRedundancyError(pub HashSet<(N, N), S>); + +/// Error indicating that two graphs both have a dependency between the same nodes. 
+#[derive(Error, Debug)] +#[error("DAG has a cross-dependency between nodes {0:?} and {1:?}")] +pub struct DagCrossDependencyError(pub N, pub N); + +/// Error indicating that the graph has overlapping groups between two keys. +#[derive(Error, Debug)] +#[error("DAG has overlapping groups between keys {0:?} and {1:?}")] +pub struct DagOverlappingGroupError(pub K, pub K); + +#[cfg(test)] +mod tests { + use core::ops::DerefMut; + + use crate::schedule::graph::{index, Dag, Direction, GraphNodeId, UnGraph}; + + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct TestNode(u32); + + impl GraphNodeId for TestNode { + type Adjacent = (TestNode, Direction); + type Edge = (TestNode, TestNode); + + fn kind(&self) -> &'static str { + "test node" + } + } + + #[test] + fn test_dirty_on_deref_mut() { + { + let mut dag = Dag::::new(); + dag.add_node(TestNode(1)); + assert!(dag.is_dirty()); + } + { + let mut dag = Dag::::new(); + dag.add_edge(TestNode(1), TestNode(2)); + assert!(dag.is_dirty()); + } + { + let mut dag = Dag::::new(); + dag.deref_mut(); + assert!(dag.is_dirty()); + } + { + let mut dag = Dag::::new(); + let _ = dag.graph_mut(); + assert!(dag.is_dirty()); + } + } + + #[test] + fn test_toposort() { + let mut dag = Dag::::new(); + dag.add_edge(TestNode(1), TestNode(2)); + dag.add_edge(TestNode(2), TestNode(3)); + dag.add_edge(TestNode(1), TestNode(3)); + + assert_eq!( + dag.toposort().unwrap(), + &[TestNode(1), TestNode(2), TestNode(3)] + ); + assert_eq!( + dag.get_toposort().unwrap(), + &[TestNode(1), TestNode(2), TestNode(3)] + ); + } + + #[test] + fn test_analyze() { + let mut dag1 = Dag::::new(); + dag1.add_edge(TestNode(1), TestNode(2)); + dag1.add_edge(TestNode(2), TestNode(3)); + dag1.add_edge(TestNode(1), TestNode(3)); // redundant edge + + let analysis1 = dag1.analyze().unwrap(); + + assert!(analysis1.reachable().contains(index(0, 1, 3))); + assert!(analysis1.reachable().contains(index(1, 2, 3))); + 
assert!(analysis1.reachable().contains(index(0, 2, 3))); + + assert!(analysis1.connected().contains(&(TestNode(1), TestNode(2)))); + assert!(analysis1.connected().contains(&(TestNode(2), TestNode(3)))); + assert!(analysis1.connected().contains(&(TestNode(1), TestNode(3)))); + + assert!(!analysis1 + .disconnected() + .contains(&(TestNode(2), TestNode(1)))); + assert!(!analysis1 + .disconnected() + .contains(&(TestNode(3), TestNode(2)))); + assert!(!analysis1 + .disconnected() + .contains(&(TestNode(3), TestNode(1)))); + + assert!(analysis1 + .transitive_edges() + .contains(&(TestNode(1), TestNode(3)))); + + assert!(analysis1.check_for_redundant_edges().is_err()); + + let mut dag2 = Dag::::new(); + dag2.add_edge(TestNode(3), TestNode(4)); + + let analysis2 = dag2.analyze().unwrap(); + + assert!(analysis2.check_for_redundant_edges().is_ok()); + assert!(analysis1.check_for_cross_dependencies(&analysis2).is_ok()); + + let mut dag3 = Dag::::new(); + dag3.add_edge(TestNode(1), TestNode(2)); + + let analysis3 = dag3.analyze().unwrap(); + + assert!(analysis1.check_for_cross_dependencies(&analysis3).is_err()); + + dag1.remove_redundant_edges(&analysis1); + let analysis1 = dag1.analyze().unwrap(); + assert!(analysis1.check_for_redundant_edges().is_ok()); + } + + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + enum Node { + Key(Key), + Value(Value), + } + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct Key(u32); + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct Value(u32); + + impl GraphNodeId for Node { + type Adjacent = (Node, Direction); + type Edge = (Node, Node); + + fn kind(&self) -> &'static str { + "node" + } + } + + impl TryInto for Node { + type Error = Value; + + fn try_into(self) -> Result { + match self { + Node::Key(k) => Ok(k), + Node::Value(v) => Err(v), + } + } + } + + impl TryInto for Node { + type Error = Key; + + fn try_into(self) -> Result { + match self { + Node::Value(v) => 
Ok(v), + Node::Key(k) => Err(k), + } + } + } + + impl GraphNodeId for Key { + type Adjacent = (Key, Direction); + type Edge = (Key, Key); + + fn kind(&self) -> &'static str { + "key" + } + } + + impl GraphNodeId for Value { + type Adjacent = (Value, Direction); + type Edge = (Value, Value); + + fn kind(&self) -> &'static str { + "value" + } + } + + impl From for Node { + fn from(key: Key) -> Self { + Node::Key(key) + } + } + + impl From for Node { + fn from(value: Value) -> Self { + Node::Value(value) + } + } + + #[test] + fn test_group_by() { + let mut dag = Dag::::new(); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20))); + dag.add_edge(Node::Key(Key(2)), Node::Key(Key(1))); + dag.add_edge(Node::Value(Value(10)), Node::Value(Value(11))); + + let groups = dag.group_by::(2).unwrap(); + assert_eq!(groups.len(), 2); + + let group_key1 = groups.get(&Key(1)).unwrap(); + assert!(group_key1.contains(&Value(10))); + assert!(group_key1.contains(&Value(11))); + + let group_key2 = groups.get(&Key(2)).unwrap(); + assert!(group_key2.contains(&Value(10))); + assert!(group_key2.contains(&Value(11))); + assert!(group_key2.contains(&Value(20))); + } + + #[test] + fn test_flatten() { + let mut dag = Dag::::new(); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21))); + dag.add_edge(Node::Value(Value(30)), Node::Key(Key(1))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(40))); + + let groups = dag.group_by::(2).unwrap(); + let flattened = groups.flatten(dag, |_key, _values, _dag, _temp| {}); + + assert!(flattened.contains_node(Value(10))); + assert!(flattened.contains_node(Value(11))); + assert!(flattened.contains_node(Value(20))); + 
assert!(flattened.contains_node(Value(21))); + assert!(flattened.contains_node(Value(30))); + assert!(flattened.contains_node(Value(40))); + + assert!(flattened.contains_edge(Value(30), Value(10))); + assert!(flattened.contains_edge(Value(30), Value(11))); + assert!(flattened.contains_edge(Value(10), Value(40))); + assert!(flattened.contains_edge(Value(11), Value(40))); + } + + #[test] + fn test_flatten_undirected() { + let mut dag = Dag::::new(); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21))); + + let groups = dag.group_by::(2).unwrap(); + + let mut ungraph = UnGraph::::default(); + ungraph.add_edge(Node::Value(Value(10)), Node::Value(Value(11))); + ungraph.add_edge(Node::Key(Key(1)), Node::Value(Value(30))); + ungraph.add_edge(Node::Value(Value(40)), Node::Key(Key(2))); + ungraph.add_edge(Node::Key(Key(1)), Node::Key(Key(2))); + + let flattened = groups.flatten_undirected(&ungraph); + + assert!(flattened.contains_edge(Value(10), Value(11))); + assert!(flattened.contains_edge(Value(10), Value(30))); + assert!(flattened.contains_edge(Value(11), Value(30))); + assert!(flattened.contains_edge(Value(40), Value(20))); + assert!(flattened.contains_edge(Value(40), Value(21))); + assert!(flattened.contains_edge(Value(10), Value(20))); + assert!(flattened.contains_edge(Value(10), Value(21))); + assert!(flattened.contains_edge(Value(11), Value(20))); + assert!(flattened.contains_edge(Value(11), Value(21))); + } + + #[test] + fn test_overlapping_groups() { + let mut dag = Dag::::new(); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(11))); // overlap + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20))); + dag.add_edge(Node::Key(Key(1)), Node::Key(Key(2))); + + let 
groups = dag.group_by::(2).unwrap(); + let analysis = dag.analyze().unwrap(); + + let result = analysis.check_for_overlapping_groups(&groups); + assert!(result.is_err()); + } + + #[test] + fn test_disjoint_groups() { + let mut dag = Dag::::new(); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10))); + dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20))); + dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21))); + + let groups = dag.group_by::(2).unwrap(); + let analysis = dag.analyze().unwrap(); + + let result = analysis.check_for_overlapping_groups(&groups); + assert!(result.is_ok()); + } +} diff --git a/crates/bevy_ecs/src/schedule/graph/graph_map.rs b/crates/bevy_ecs/src/schedule/graph/graph_map.rs index 414ecaef9cf5d..4af139674a445 100644 --- a/crates/bevy_ecs/src/schedule/graph/graph_map.rs +++ b/crates/bevy_ecs/src/schedule/graph/graph_map.rs @@ -287,15 +287,17 @@ impl Graph /// /// # Errors /// - /// If the conversion fails, it returns an error of type `T::Error`. - pub fn try_into>(self) -> Result, T::Error> + /// If the conversion fails, it returns an error of type `N::Error`. + pub fn try_convert(self) -> Result, N::Error> where + N: TryInto, + T: GraphNodeId, S: Default, { // Converts the node key and every adjacency list entry from `N` to `T`. - fn try_convert_node>( + fn try_convert_node, T: GraphNodeId>( (key, adj): (N, Vec), - ) -> Result<(T, Vec), T::Error> { + ) -> Result<(T, Vec), N::Error> { let key = key.try_into()?; let adj = adj .into_iter() @@ -303,13 +305,13 @@ impl Graph let (id, dir) = node.into(); Ok(T::Adjacent::from((id.try_into()?, dir))) }) - .collect::>()?; + .collect::>()?; Ok((key, adj)) } // Unpacks the edge pair, converts the nodes from `N` to `T`, and repacks them. 
- fn try_convert_edge>( + fn try_convert_edge, T: GraphNodeId>( edge: N::Edge, - ) -> Result { + ) -> Result { let (a, b) = edge.into(); Ok(T::Edge::from((a.try_into()?, b.try_into()?))) } @@ -318,12 +320,12 @@ impl Graph .nodes .into_iter() .map(try_convert_node::) - .collect::>()?; + .collect::>()?; let edges = self .edges .into_iter() .map(try_convert_edge::) - .collect::>()?; + .collect::>()?; Ok(Graph { nodes, edges }) } } @@ -351,7 +353,7 @@ impl DiGraph { /// /// - If the graph contains a self-loop, returns [`DiGraphToposortError::Loop`]. /// - If the graph contains cycles, returns [`DiGraphToposortError::Cycle`]. - pub fn toposort(&self) -> Result, DiGraphToposortError> { + pub fn toposort(&self, mut scratch: Vec) -> Result, DiGraphToposortError> { // Check explicitly for self-edges. // `iter_sccs` won't report them as cycles because they still form components of one node. if let Some((node, _)) = self.all_edges().find(|(left, right)| left == right) { @@ -359,7 +361,9 @@ impl DiGraph { } // Tarjan's SCC algorithm returns elements in *reverse* topological order. 
- let mut top_sorted_nodes = Vec::with_capacity(self.node_count()); + scratch.clear(); + scratch.reserve_exact(self.node_count().saturating_sub(scratch.capacity())); + let mut top_sorted_nodes = scratch; let mut sccs_with_cycles = Vec::new(); for scc in self.iter_sccs() { diff --git a/crates/bevy_ecs/src/schedule/graph/mod.rs b/crates/bevy_ecs/src/schedule/graph/mod.rs index 292954b979bec..f0da67d88b3bd 100644 --- a/crates/bevy_ecs/src/schedule/graph/mod.rs +++ b/crates/bevy_ecs/src/schedule/graph/mod.rs @@ -4,16 +4,15 @@ use core::{ fmt::Debug, }; -use bevy_platform::collections::{HashMap, HashSet}; use bevy_utils::TypeIdMap; -use fixedbitset::FixedBitSet; - -use crate::schedule::set::*; +use crate::schedule::InternedSystemSet; +mod dag; mod graph_map; mod tarjan_scc; +pub use dag::*; pub use graph_map::{DiGraph, DiGraphToposortError, Direction, GraphNodeId, UnGraph}; /// Specifies what kind of edge should be added to the dependency graph. @@ -77,147 +76,3 @@ pub(crate) fn index(row: usize, col: usize, num_cols: usize) -> usize { pub(crate) fn row_col(index: usize, num_cols: usize) -> (usize, usize) { (index / num_cols, index % num_cols) } - -/// Stores the results of the graph analysis. -pub(crate) struct CheckGraphResults { - /// Boolean reachability matrix for the graph. - pub(crate) reachable: FixedBitSet, - /// Pairs of nodes that have a path connecting them. - pub(crate) connected: HashSet<(N, N)>, - /// Pairs of nodes that don't have a path connecting them. - pub(crate) disconnected: Vec<(N, N)>, - /// Edges that are redundant because a longer path exists. - pub(crate) transitive_edges: Vec<(N, N)>, - /// Variant of the graph with no transitive edges. - pub(crate) transitive_reduction: DiGraph, - /// Variant of the graph with all possible transitive edges. 
- // TODO: this will very likely be used by "if-needed" ordering - #[expect(dead_code, reason = "See the TODO above this attribute.")] - pub(crate) transitive_closure: DiGraph, -} - -impl Default for CheckGraphResults { - fn default() -> Self { - Self { - reachable: FixedBitSet::new(), - connected: HashSet::default(), - disconnected: Vec::new(), - transitive_edges: Vec::new(), - transitive_reduction: DiGraph::default(), - transitive_closure: DiGraph::default(), - } - } -} - -/// Processes a DAG and computes its: -/// - transitive reduction (along with the set of removed edges) -/// - transitive closure -/// - reachability matrix (as a bitset) -/// - pairs of nodes connected by a path -/// - pairs of nodes not connected by a path -/// -/// The algorithm implemented comes from -/// ["On the calculation of transitive reduction-closure of orders"][1] by Habib, Morvan and Rampon. -/// -/// [1]: https://doi.org/10.1016/0012-365X(93)90164-O -pub(crate) fn check_graph( - graph: &DiGraph, - topological_order: &[N], -) -> CheckGraphResults { - if graph.node_count() == 0 { - return CheckGraphResults::default(); - } - - let n = graph.node_count(); - - // build a copy of the graph where the nodes and edges appear in topsorted order - let mut map = >::with_capacity_and_hasher(n, Default::default()); - let mut topsorted = DiGraph::::default(); - // iterate nodes in topological order - for (i, &node) in topological_order.iter().enumerate() { - map.insert(node, i); - topsorted.add_node(node); - // insert nodes as successors to their predecessors - for pred in graph.neighbors_directed(node, Direction::Incoming) { - topsorted.add_edge(pred, node); - } - } - - let mut reachable = FixedBitSet::with_capacity(n * n); - let mut connected = >::default(); - let mut disconnected = Vec::new(); - - let mut transitive_edges = Vec::new(); - let mut transitive_reduction = DiGraph::default(); - let mut transitive_closure = DiGraph::default(); - - let mut visited = FixedBitSet::with_capacity(n); - 
- // iterate nodes in topological order - for node in topsorted.nodes() { - transitive_reduction.add_node(node); - transitive_closure.add_node(node); - } - - // iterate nodes in reverse topological order - for a in topsorted.nodes().rev() { - let index_a = *map.get(&a).unwrap(); - // iterate their successors in topological order - for b in topsorted.neighbors_directed(a, Direction::Outgoing) { - let index_b = *map.get(&b).unwrap(); - debug_assert!(index_a < index_b); - if !visited[index_b] { - // edge is not redundant - transitive_reduction.add_edge(a, b); - transitive_closure.add_edge(a, b); - reachable.insert(index(index_a, index_b, n)); - - let successors = transitive_closure - .neighbors_directed(b, Direction::Outgoing) - .collect::>(); - for c in successors { - let index_c = *map.get(&c).unwrap(); - debug_assert!(index_b < index_c); - if !visited[index_c] { - visited.insert(index_c); - transitive_closure.add_edge(a, c); - reachable.insert(index(index_a, index_c, n)); - } - } - } else { - // edge is redundant - transitive_edges.push((a, b)); - } - } - - visited.clear(); - } - - // partition pairs of nodes into "connected by path" and "not connected by path" - for i in 0..(n - 1) { - // reachable is upper triangular because the nodes were topsorted - for index in index(i, i + 1, n)..=index(i, n - 1, n) { - let (a, b) = row_col(index, n); - let pair = (topological_order[a], topological_order[b]); - if reachable[index] { - connected.insert(pair); - } else { - disconnected.push(pair); - } - } - } - - // fill diagonal (nodes reach themselves) - // for i in 0..n { - // reachable.set(index(i, i, n), true); - // } - - CheckGraphResults { - reachable, - connected, - disconnected, - transitive_edges, - transitive_reduction, - transitive_closure, - } -} diff --git a/crates/bevy_ecs/src/schedule/mod.rs b/crates/bevy_ecs/src/schedule/mod.rs index f66549921bdae..f3b687dfc5ba7 100644 --- a/crates/bevy_ecs/src/schedule/mod.rs +++ b/crates/bevy_ecs/src/schedule/mod.rs @@ -744,7 
+744,7 @@ mod tests { let result = schedule.initialize(&mut world); assert!(matches!( result, - Err(ScheduleBuildError::CrossDependency(_, _)) + Err(ScheduleBuildError::CrossDependency(_)) )); } @@ -769,7 +769,7 @@ mod tests { // `foo` can't be in both `A` and `C` because they can't run at the same time. assert!(matches!( result, - Err(ScheduleBuildError::SetsHaveOrderButIntersect(_, _)) + Err(ScheduleBuildError::SetsHaveOrderButIntersect(_)) )); } @@ -1151,7 +1151,8 @@ mod tests { let ambiguities: Vec<_> = schedule .graph() - .conflicts_to_string(schedule.graph().conflicting_systems(), world.components()) + .conflicting_systems() + .to_string(schedule.graph(), world.components()) .map(|item| { ( item.0, @@ -1209,7 +1210,8 @@ mod tests { let ambiguities: Vec<_> = schedule .graph() - .conflicts_to_string(schedule.graph().conflicting_systems(), world.components()) + .conflicting_systems() + .to_string(schedule.graph(), world.components()) .map(|item| { ( item.0, diff --git a/crates/bevy_ecs/src/schedule/node.rs b/crates/bevy_ecs/src/schedule/node.rs index 25ce6f55ac757..f73d8f9d49c0f 100644 --- a/crates/bevy_ecs/src/schedule/node.rs +++ b/crates/bevy_ecs/src/schedule/node.rs @@ -1,22 +1,29 @@ -use alloc::{boxed::Box, vec::Vec}; -use bevy_utils::prelude::DebugName; +use alloc::{boxed::Box, collections::BTreeSet, string::String, vec::Vec}; use core::{ any::TypeId, fmt::{self, Debug}, - ops::{Index, IndexMut, Range}, + ops::{Deref, Index, IndexMut, Range}, }; -use bevy_platform::collections::HashMap; +use bevy_platform::collections::{HashMap, HashSet}; +use bevy_utils::prelude::DebugName; use slotmap::{new_key_type, Key, KeyData, SecondaryMap, SlotMap}; +use thiserror::Error; use crate::{ change_detection::{CheckChangeTicks, Tick}, + component::{ComponentId, Components}, prelude::{SystemIn, SystemSet}, - query::FilteredAccessSet, + query::{AccessConflicts, FilteredAccessSet}, schedule::{ - graph::{Direction, GraphNodeId}, - BoxedCondition, InternedSystemSet, + graph::{ 
+ DagAnalysis, DagGroups, DiGraph, + Direction::{self, Incoming, Outgoing}, + GraphNodeId, UnGraph, + }, + BoxedCondition, InternedSystemSet, ScheduleGraph, }, + storage::SparseSetIndex, system::{ ReadOnlySystem, RunSystemError, ScheduleSystem, System, SystemParamValidationError, SystemStateFlags, @@ -24,7 +31,7 @@ use crate::{ world::{unsafe_world_cell::UnsafeWorldCell, DeferredWorld, World}, }; -/// A [`SystemWithAccess`] stored in a [`ScheduleGraph`](crate::schedule::ScheduleGraph). +/// A [`SystemWithAccess`] stored in a [`ScheduleGraph`]. pub(crate) struct SystemNode { pub(crate) inner: Option, } @@ -590,6 +597,60 @@ impl Systems { } } } + + /// Calculates the list of systems that conflict with each other based on + /// their access patterns. + /// + /// If the `Box<[ComponentId]>` is empty for a given pair of systems, then the + /// systems conflict on [`World`] access in general (e.g. one of them is + /// exclusive, or both systems have `Query`). + pub fn get_conflicting_systems( + &self, + flat_dependency_analysis: &DagAnalysis, + flat_ambiguous_with: &UnGraph, + ambiguous_with_all: &HashSet, + ignored_ambiguities: &BTreeSet, + ) -> ConflictingSystems { + let mut conflicting_systems: Vec<(_, _, Box<[_]>)> = Vec::new(); + for &(a, b) in flat_dependency_analysis.disconnected() { + if flat_ambiguous_with.contains_edge(a, b) + || ambiguous_with_all.contains(&NodeId::System(a)) + || ambiguous_with_all.contains(&NodeId::System(b)) + { + continue; + } + + let system_a = &self[a]; + let system_b = &self[b]; + if system_a.is_exclusive() || system_b.is_exclusive() { + conflicting_systems.push((a, b, Box::new([]))); + } else { + let access_a = &system_a.access; + let access_b = &system_b.access; + if !access_a.is_compatible(access_b) { + match access_a.get_conflicts(access_b) { + AccessConflicts::Individual(conflicts) => { + let conflicts: Box<[_]> = conflicts + .ones() + .map(ComponentId::get_sparse_set_index) + .filter(|id| !ignored_ambiguities.contains(id)) + 
.collect(); + if !conflicts.is_empty() { + conflicting_systems.push((a, b, conflicts)); + } + } + AccessConflicts::All => { + // there is no specific component conflicting, but the systems are overall incompatible + // for example 2 systems with `Query` + conflicting_systems.push((a, b, Box::new([]))); + } + } + } + } + } + + ConflictingSystems(conflicting_systems) + } } impl Index for Systems { @@ -610,6 +671,58 @@ impl IndexMut for Systems { } } +/// Pairs of systems that conflict with each other along with the components +/// they conflict on, which prevents them from running in parallel. If the +/// component list is empty, the systems conflict on [`World`] access in general +/// (e.g. one of them is exclusive, or both systems have `Query`). +#[derive(Clone, Debug, Default)] +pub struct ConflictingSystems(pub Vec<(SystemKey, SystemKey, Box<[ComponentId]>)>); + +impl ConflictingSystems { + /// Checks if there are any conflicting systems, returning [`Ok`] if there + /// are none, or an [`AmbiguousSystemConflictsWarning`] if there are. + pub fn check_if_not_empty(&self) -> Result<(), AmbiguousSystemConflictsWarning> { + if self.0.is_empty() { + Ok(()) + } else { + Err(AmbiguousSystemConflictsWarning(self.clone())) + } + } + + /// Converts the conflicting systems into an iterator of their system names + /// and the names of the components they conflict on. 
+ pub fn to_string( + &self, + graph: &ScheduleGraph, + components: &Components, + ) -> impl Iterator)> { + self.iter().map(move |(system_a, system_b, conflicts)| { + let name_a = graph.get_node_name(&NodeId::System(*system_a)); + let name_b = graph.get_node_name(&NodeId::System(*system_b)); + + let conflict_names: Box<[_]> = conflicts + .iter() + .map(|id| components.get_name(*id).unwrap()) + .collect(); + + (name_a, name_b, conflict_names) + }) + } +} + +impl Deref for ConflictingSystems { + type Target = Vec<(SystemKey, SystemKey, Box<[ComponentId]>)>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// Error returned when there are ambiguous system conflicts detected. +#[derive(Error, Debug)] +#[error("Systems with conflicting access have indeterminate run order: {:?}", .0.0)] +pub struct AmbiguousSystemConflictsWarning(pub ConflictingSystems); + /// Container for system sets in a schedule. #[derive(Default)] pub struct SystemSets { @@ -764,6 +877,30 @@ impl SystemSets { } } } + + /// Ensures that there are no edges to system-type sets that have multiple + /// instances. + pub fn check_type_set_ambiguity( + &self, + set_systems: &DagGroups, + ambiguous_with: &UnGraph, + dependency: &DiGraph, + ) -> Result<(), SystemTypeSetAmbiguityError> { + for (&key, systems) in set_systems.iter() { + let set = &self[key]; + if set.system_type().is_some() { + let instances = systems.len(); + let ambiguous_with = ambiguous_with.edges(NodeId::Set(key)); + let before = dependency.edges_directed(NodeId::Set(key), Incoming); + let after = dependency.edges_directed(NodeId::Set(key), Outgoing); + let relations = before.count() + after.count() + ambiguous_with.count(); + if instances > 1 && relations > 0 { + return Err(SystemTypeSetAmbiguityError(key)); + } + } + } + Ok(()) + } } impl Index for SystemSets { @@ -780,6 +917,11 @@ impl Index for SystemSets { } } +/// Error returned when calling [`SystemSets::check_type_set_ambiguity`]. 
+#[derive(Error, Debug)] +#[error("Tried to order against `{0:?}` in a schedule that has more than one `{0:?}` instance. `{0:?}` is a `SystemTypeSet` and cannot be used for ordering if ambiguous. Use a different set without this restriction.")] +pub struct SystemTypeSetAmbiguityError(pub SystemSetKey); + #[cfg(test)] mod tests { use alloc::{boxed::Box, vec}; diff --git a/crates/bevy_ecs/src/schedule/pass.rs b/crates/bevy_ecs/src/schedule/pass.rs index 9072c3ae65c97..5a3d52f0767d3 100644 --- a/crates/bevy_ecs/src/schedule/pass.rs +++ b/crates/bevy_ecs/src/schedule/pass.rs @@ -1,9 +1,10 @@ use alloc::{boxed::Box, vec::Vec}; +use bevy_platform::collections::HashSet; use core::any::{Any, TypeId}; use super::{DiGraph, NodeId, ScheduleBuildError, ScheduleGraph}; use crate::{ - schedule::{SystemKey, SystemSetKey}, + schedule::{graph::Dag, SystemKey, SystemSetKey}, world::World, }; use bevy_utils::TypeIdMap; @@ -23,7 +24,7 @@ pub trait ScheduleBuildPass: Send + Sync + Debug + 'static { fn collapse_set( &mut self, set: SystemSetKey, - systems: &[SystemKey], + systems: &HashSet, dependency_flattening: &DiGraph, ) -> impl Iterator; @@ -32,7 +33,7 @@ pub trait ScheduleBuildPass: Send + Sync + Debug + 'static { &mut self, world: &mut World, graph: &mut ScheduleGraph, - dependency_flattened: &mut DiGraph, + dependency_flattened: &mut Dag, ) -> Result<(), ScheduleBuildError>; } @@ -42,13 +43,13 @@ pub(super) trait ScheduleBuildPassObj: Send + Sync + Debug { &mut self, world: &mut World, graph: &mut ScheduleGraph, - dependency_flattened: &mut DiGraph, + dependency_flattened: &mut Dag, ) -> Result<(), ScheduleBuildError>; fn collapse_set( &mut self, set: SystemSetKey, - systems: &[SystemKey], + systems: &HashSet, dependency_flattening: &DiGraph, dependencies_to_add: &mut Vec<(NodeId, NodeId)>, ); @@ -60,14 +61,14 @@ impl ScheduleBuildPassObj for T { &mut self, world: &mut World, graph: &mut ScheduleGraph, - dependency_flattened: &mut DiGraph, + dependency_flattened: &mut Dag, ) -> 
Result<(), ScheduleBuildError> { self.build(world, graph, dependency_flattened) } fn collapse_set( &mut self, set: SystemSetKey, - systems: &[SystemKey], + systems: &HashSet, dependency_flattening: &DiGraph, dependencies_to_add: &mut Vec<(NodeId, NodeId)>, ) { diff --git a/crates/bevy_ecs/src/schedule/schedule.rs b/crates/bevy_ecs/src/schedule/schedule.rs index b21ec09984b5d..1a8688306da26 100644 --- a/crates/bevy_ecs/src/schedule/schedule.rs +++ b/crates/bevy_ecs/src/schedule/schedule.rs @@ -11,7 +11,7 @@ use alloc::{ vec::Vec, }; use bevy_platform::collections::{HashMap, HashSet}; -use bevy_utils::{default, prelude::DebugName, TypeIdMap}; +use bevy_utils::{default, TypeIdMap}; use core::{ any::{Any, TypeId}, fmt::{Debug, Write}, @@ -33,7 +33,6 @@ use crate::{ world::World, }; -use crate::{query::AccessConflicts, storage::SparseSetIndex}; pub use stepping::Stepping; use Direction::{Incoming, Outgoing}; @@ -675,44 +674,6 @@ impl Schedule { } } -/// A directed acyclic graph structure. -pub struct Dag { - /// A directed graph. - graph: DiGraph, - /// A cached topological ordering of the graph. - topsort: Vec, -} - -impl Dag { - fn new() -> Self { - Self { - graph: DiGraph::default(), - topsort: Vec::new(), - } - } - - /// The directed graph of the stored systems, connected by their ordering dependencies. - pub fn graph(&self) -> &DiGraph { - &self.graph - } - - /// A cached topological ordering of the graph. - /// - /// The order is determined by the ordering dependencies between systems. - pub fn cached_topsort(&self) -> &[N] { - &self.topsort - } -} - -impl Default for Dag { - fn default() -> Self { - Self { - graph: Default::default(), - topsort: Default::default(), - } - } -} - /// Metadata for a [`Schedule`]. 
/// /// The order isn't optimized; calling `ScheduleGraph::build_schedule` will return a @@ -728,11 +689,11 @@ pub struct ScheduleGraph { /// Directed acyclic graph of the dependency (which systems/sets have to run before which other systems/sets) dependency: Dag, /// Map of systems in each set - set_systems: HashMap>, + set_systems: DagGroups, ambiguous_with: UnGraph, /// Nodes that are allowed to have ambiguous ordering relationship with any other systems. pub ambiguous_with_all: HashSet, - conflicting_systems: Vec<(SystemKey, SystemKey, Vec)>, + conflicting_systems: ConflictingSystems, anonymous_sets: usize, changed: bool, settings: ScheduleBuildSettings, @@ -747,10 +708,10 @@ impl ScheduleGraph { system_sets: SystemSets::default(), hierarchy: Dag::new(), dependency: Dag::new(), - set_systems: HashMap::new(), + set_systems: DagGroups::default(), ambiguous_with: UnGraph::default(), ambiguous_with_all: HashSet::default(), - conflicting_systems: Vec::new(), + conflicting_systems: ConflictingSystems::default(), anonymous_sets: 0, changed: false, settings: default(), @@ -778,7 +739,7 @@ impl ScheduleGraph { /// /// If the `Vec` is empty, the systems conflict on [`World`] access. /// Must be called after [`ScheduleGraph::build_schedule`] to be non-empty. 
- pub fn conflicting_systems(&self) -> &[(SystemKey, SystemKey, Vec)] { + pub fn conflicting_systems(&self) -> &ConflictingSystems { &self.conflicting_systems } @@ -883,9 +844,7 @@ impl ScheduleGraph { for previous_node in previous_nodes { for current_node in current_nodes { - self.dependency - .graph - .add_edge(*previous_node, *current_node); + self.dependency.add_edge(*previous_node, *current_node); for pass in self.passes.values_mut() { pass.add_dependency( @@ -959,7 +918,7 @@ impl ScheduleGraph { pub fn systems_in_set( &self, system_set: InternedSystemSet, - ) -> Result<&[SystemKey], ScheduleError> { + ) -> Result<&HashSet, ScheduleError> { if self.changed { return Err(ScheduleError::Uninitialized); } @@ -969,42 +928,25 @@ impl ScheduleGraph { .ok_or(ScheduleError::SetNotFound)?; self.set_systems .get(&system_set_id) - .map(Vec::as_slice) .ok_or(ScheduleError::SetNotFound) } fn add_edges_for_transitive_dependencies(&mut self, node: NodeId) { - let in_nodes: Vec<_> = self - .hierarchy - .graph - .neighbors_directed(node, Incoming) - .collect(); - let out_nodes: Vec<_> = self - .hierarchy - .graph - .neighbors_directed(node, Outgoing) - .collect(); + let in_nodes: Vec<_> = self.hierarchy.neighbors_directed(node, Incoming).collect(); + let out_nodes: Vec<_> = self.hierarchy.neighbors_directed(node, Outgoing).collect(); for &in_node in &in_nodes { for &out_node in &out_nodes { - self.hierarchy.graph.add_edge(in_node, out_node); + self.hierarchy.add_edge(in_node, out_node); } } - let in_nodes: Vec<_> = self - .dependency - .graph - .neighbors_directed(node, Incoming) - .collect(); - let out_nodes: Vec<_> = self - .dependency - .graph - .neighbors_directed(node, Outgoing) - .collect(); + let in_nodes: Vec<_> = self.dependency.neighbors_directed(node, Incoming).collect(); + let out_nodes: Vec<_> = self.dependency.neighbors_directed(node, Outgoing).collect(); for &in_node in &in_nodes { for &out_node in &out_nodes { - self.dependency.graph.add_edge(in_node, out_node); 
+ self.dependency.add_edge(in_node, out_node); } } } @@ -1018,7 +960,7 @@ impl ScheduleGraph { let set = system_set.into_system_set(); let interned = set.intern(); // clone the keys out of the schedule as the systems are getting removed from self - let keys = self.systems_in_set(interned)?.to_vec(); + let keys = self.systems_in_set(interned)?.clone(); self.changed = true; @@ -1066,12 +1008,12 @@ impl ScheduleGraph { } } - fn remove_systems_by_keys(&mut self, keys: &[SystemKey]) { + fn remove_systems_by_keys(&mut self, keys: &HashSet) { for &key in keys { self.systems.remove(key); - self.hierarchy.graph.remove_node(key.into()); - self.dependency.graph.remove_node(key.into()); + self.hierarchy.remove_node(key.into()); + self.dependency.remove_node(key.into()); self.ambiguous_with.remove_node(key.into()); self.ambiguous_with_all.remove(&NodeId::from(key)); } @@ -1080,8 +1022,8 @@ impl ScheduleGraph { fn remove_set_by_key(&mut self, key: SystemSetKey) { self.system_sets.remove(key); self.set_systems.remove(&key); - self.hierarchy.graph.remove_node(key.into()); - self.dependency.graph.remove_node(key.into()); + self.hierarchy.remove_node(key.into()); + self.dependency.remove_node(key.into()); self.ambiguous_with.remove_node(key.into()); self.ambiguous_with_all.remove(&NodeId::from(key)); } @@ -1097,17 +1039,17 @@ impl ScheduleGraph { .. 
} = graph_info; - self.hierarchy.graph.add_node(id); - self.dependency.graph.add_node(id); + self.hierarchy.add_node(id); + self.dependency.add_node(id); for key in sets .into_iter() .map(|set| self.system_sets.get_key_or_insert(set)) { - self.hierarchy.graph.add_edge(NodeId::Set(key), id); + self.hierarchy.add_edge(NodeId::Set(key), id); // ensure set also appears in dependency graph - self.dependency.graph.add_node(NodeId::Set(key)); + self.dependency.add_node(NodeId::Set(key)); } for (kind, key, options) in @@ -1121,13 +1063,13 @@ impl ScheduleGraph { DependencyKind::Before => (id, NodeId::Set(key)), DependencyKind::After => (NodeId::Set(key), id), }; - self.dependency.graph.add_edge(lhs, rhs); + self.dependency.add_edge(lhs, rhs); for pass in self.passes.values_mut() { pass.add_dependency(lhs, rhs, &options); } // ensure set also appears in hierarchy graph - self.hierarchy.graph.add_node(NodeId::Set(key)); + self.hierarchy.add_node(NodeId::Set(key)); } match ambiguous_with { @@ -1167,272 +1109,114 @@ impl ScheduleGraph { ) -> Result<(SystemSchedule, Vec), ScheduleBuildError> { let mut warnings = Vec::new(); - // check hierarchy for cycles - self.hierarchy.topsort = self + // Check system set memberships for cycles. + let hierarchy_analysis = self .hierarchy - .graph - .toposort() + .analyze() .map_err(ScheduleBuildError::HierarchySort)?; - let hier_results = check_graph(&self.hierarchy.graph, &self.hierarchy.topsort); - if let Some(warning) = - self.optionally_check_hierarchy_conflicts(&hier_results.transitive_edges)? + // Check for redundant system set memberships, logging warnings or + // returning errors as configured. 
+ if self.settings.hierarchy_detection != LogLevel::Ignore + && let Err(e) = hierarchy_analysis.check_for_redundant_edges() { - warnings.push(warning); + match self.settings.hierarchy_detection { + LogLevel::Error => return Err(ScheduleBuildWarning::HierarchyRedundancy(e).into()), + LogLevel::Warn => warnings.push(ScheduleBuildWarning::HierarchyRedundancy(e)), + LogLevel::Ignore => unreachable!(), + } } + // Remove redundant system set memberships. + self.hierarchy.remove_redundant_edges(&hierarchy_analysis); - // remove redundant edges - self.hierarchy.graph = hier_results.transitive_reduction; - - // check dependencies for cycles - self.dependency.topsort = self + // Check system and system set ordering dependencies for cycles. + let dependency_analysis = self .dependency - .graph - .toposort() + .analyze() .map_err(ScheduleBuildError::DependencySort)?; - // check for systems or system sets depending on sets they belong to - let dep_results = check_graph(&self.dependency.graph, &self.dependency.topsort); - self.check_for_cross_dependencies(&dep_results, &hier_results.connected)?; - - // map all system sets to their systems - // go in reverse topological order (bottom-up) for efficiency - let (set_systems, set_system_bitsets) = - self.map_sets_to_systems(&self.hierarchy.topsort, &self.hierarchy.graph); - self.check_order_but_intersect(&dep_results.connected, &set_system_bitsets)?; + // Check for systems or system sets with ordering dependencies on sets they belong to. + dependency_analysis.check_for_cross_dependencies(&hierarchy_analysis)?; - // check that there are no edges to system-type sets that have multiple instances - self.check_system_type_set_ambiguity(&set_systems)?; - - let mut dependency_flattened = self.get_dependency_flattened(&set_systems); + // Group all systems by the system sets they belong to. 
+ self.set_systems = self + .hierarchy + .group_by(self.system_sets.len()) + .map_err(ScheduleBuildError::HierarchySort)?; + // Check for system sets that share systems but have an ordering dependency. + dependency_analysis.check_for_overlapping_groups(&self.set_systems)?; + + // Check that there are no edges to system-type sets that have multiple instances. + self.system_sets.check_type_set_ambiguity( + &self.set_systems, + &self.ambiguous_with, + &self.dependency, + )?; + + // Flatten system ordering dependencies by collapsing system sets. This + // means that if a system set has ordering dependencies, those + // dependencies are applied to all systems in the set. + let mut flat_dependency = + self.set_systems + .flatten(self.dependency.clone(), |set, systems, flattening, temp| { + for pass in self.passes.values_mut() { + pass.collapse_set(set, systems, flattening, temp); + } + }); - // modify graph with build passes + // Allow modification of the schedule graph by build passes. let mut passes = core::mem::take(&mut self.passes); for pass in passes.values_mut() { - pass.build(world, self, &mut dependency_flattened)?; + pass.build(world, self, &mut flat_dependency)?; } self.passes = passes; - // topsort - let mut dependency_flattened_dag = Dag { - topsort: dependency_flattened - .toposort() - .map_err(ScheduleBuildError::FlatDependencySort)?, - graph: dependency_flattened, - }; - - let flat_results = check_graph( - &dependency_flattened_dag.graph, - &dependency_flattened_dag.topsort, - ); - - // remove redundant edges - dependency_flattened_dag.graph = flat_results.transitive_reduction; - - // flatten: combine `in_set` with `ambiguous_with` information - let ambiguous_with_flattened = self.get_ambiguous_with_flattened(&set_systems); - self.set_systems = set_systems; - - // check for conflicts - let conflicting_systems = self.get_conflicting_systems( - &flat_results.disconnected, - &ambiguous_with_flattened, + // Check system ordering dependencies for cycles after 
collapsing sets + // and applying build passes. + let flat_dependency_analysis = flat_dependency + .analyze() + .map_err(ScheduleBuildError::FlatDependencySort)?; + + // Remove redundant system ordering dependencies. + flat_dependency.remove_redundant_edges(&flat_dependency_analysis); + + // Flatten accepted system ordering ambiguities by collapsing system sets. + // This means that if a system set is allowed to have ambiguous ordering + // with another set, all systems in the first set are allowed to have + // ambiguous ordering with all systems in the second set. + let flat_ambiguous_with = self.set_systems.flatten_undirected(&self.ambiguous_with); + + // Find all system ordering ambiguities, ignoring those that are accepted. + self.conflicting_systems = self.systems.get_conflicting_systems( + &flat_dependency_analysis, + &flat_ambiguous_with, + &self.ambiguous_with_all, ignored_ambiguities, ); - if let Some(warning) = self.optionally_check_conflicts(&conflicting_systems)? { - warnings.push(warning); + // If there are any ambiguities, log warnings or return errors as configured. + if self.settings.ambiguity_detection != LogLevel::Ignore + && let Err(e) = self.conflicting_systems.check_if_not_empty() + { + match self.settings.ambiguity_detection { + LogLevel::Error => return Err(ScheduleBuildWarning::Ambiguity(e).into()), + LogLevel::Warn => warnings.push(ScheduleBuildWarning::Ambiguity(e)), + LogLevel::Ignore => unreachable!(), + } } - self.conflicting_systems = conflicting_systems; // build the schedule Ok(( - self.build_schedule_inner(dependency_flattened_dag, hier_results.reachable), + self.build_schedule_inner(flat_dependency, hierarchy_analysis), warnings, )) } - /// Return a map from system set `NodeId` to a list of system `NodeId`s that are included in the set. 
- /// Also return a map from system set `NodeId` to a `FixedBitSet` of system `NodeId`s that are included in the set, - /// where the bitset order is the same as `self.systems` - fn map_sets_to_systems( - &self, - hierarchy_topsort: &[NodeId], - hierarchy_graph: &DiGraph, - ) -> ( - HashMap>, - HashMap>, - ) { - let mut set_systems: HashMap> = - HashMap::with_capacity_and_hasher(self.system_sets.len(), Default::default()); - let mut set_system_sets: HashMap> = - HashMap::with_capacity_and_hasher(self.system_sets.len(), Default::default()); - for &id in hierarchy_topsort.iter().rev() { - let NodeId::Set(set_key) = id else { - continue; - }; - - let mut systems = Vec::new(); - let mut system_set = HashSet::with_capacity(self.systems.len()); - - for child in hierarchy_graph.neighbors_directed(id, Outgoing) { - match child { - NodeId::System(key) => { - systems.push(key); - system_set.insert(key); - } - NodeId::Set(key) => { - let child_systems = set_systems.get(&key).unwrap(); - let child_system_set = set_system_sets.get(&key).unwrap(); - systems.extend_from_slice(child_systems); - system_set.extend(child_system_set.iter()); - } - } - } - - set_systems.insert(set_key, systems); - set_system_sets.insert(set_key, system_set); - } - (set_systems, set_system_sets) - } - - fn get_dependency_flattened( - &mut self, - set_systems: &HashMap>, - ) -> DiGraph { - // flatten: combine `in_set` with `before` and `after` information - // have to do it like this to preserve transitivity - let mut dependency_flattening = self.dependency.graph.clone(); - let mut temp = Vec::new(); - for (&set, systems) in set_systems { - for pass in self.passes.values_mut() { - pass.collapse_set(set, systems, &dependency_flattening, &mut temp); - } - if systems.is_empty() { - // collapse dependencies for empty sets - for a in dependency_flattening.neighbors_directed(NodeId::Set(set), Incoming) { - for b in dependency_flattening.neighbors_directed(NodeId::Set(set), Outgoing) { - temp.push((a, b)); - } 
- } - } else { - for a in dependency_flattening.neighbors_directed(NodeId::Set(set), Incoming) { - for &sys in systems { - temp.push((a, NodeId::System(sys))); - } - } - - for b in dependency_flattening.neighbors_directed(NodeId::Set(set), Outgoing) { - for &sys in systems { - temp.push((NodeId::System(sys), b)); - } - } - } - - dependency_flattening.remove_node(NodeId::Set(set)); - for (a, b) in temp.drain(..) { - dependency_flattening.add_edge(a, b); - } - } - - // By this point, we should have removed all system sets from the graph, - // so this conversion should never fail. - dependency_flattening - .try_into::() - .unwrap_or_else(|n| { - unreachable!( - "Flattened dependency graph has a leftover system set {}", - self.get_node_name(&NodeId::Set(n)) - ) - }) - } - - fn get_ambiguous_with_flattened( - &self, - set_systems: &HashMap>, - ) -> UnGraph { - let mut ambiguous_with_flattened = UnGraph::default(); - for (lhs, rhs) in self.ambiguous_with.all_edges() { - match (lhs, rhs) { - (NodeId::System(_), NodeId::System(_)) => { - ambiguous_with_flattened.add_edge(lhs, rhs); - } - (NodeId::Set(lhs), NodeId::System(_)) => { - for &lhs_ in set_systems.get(&lhs).unwrap_or(&Vec::new()) { - ambiguous_with_flattened.add_edge(NodeId::System(lhs_), rhs); - } - } - (NodeId::System(_), NodeId::Set(rhs)) => { - for &rhs_ in set_systems.get(&rhs).unwrap_or(&Vec::new()) { - ambiguous_with_flattened.add_edge(lhs, NodeId::System(rhs_)); - } - } - (NodeId::Set(lhs), NodeId::Set(rhs)) => { - for &lhs_ in set_systems.get(&lhs).unwrap_or(&Vec::new()) { - for &rhs_ in set_systems.get(&rhs).unwrap_or(&vec![]) { - ambiguous_with_flattened - .add_edge(NodeId::System(lhs_), NodeId::System(rhs_)); - } - } - } - } - } - - ambiguous_with_flattened - } - - fn get_conflicting_systems( - &self, - flat_results_disconnected: &Vec<(SystemKey, SystemKey)>, - ambiguous_with_flattened: &UnGraph, - ignored_ambiguities: &BTreeSet, - ) -> Vec<(SystemKey, SystemKey, Vec)> { - let mut conflicting_systems = 
Vec::new(); - for &(a, b) in flat_results_disconnected { - if ambiguous_with_flattened.contains_edge(NodeId::System(a), NodeId::System(b)) - || self.ambiguous_with_all.contains(&NodeId::System(a)) - || self.ambiguous_with_all.contains(&NodeId::System(b)) - { - continue; - } - - let system_a = &self.systems[a]; - let system_b = &self.systems[b]; - if system_a.is_exclusive() || system_b.is_exclusive() { - conflicting_systems.push((a, b, Vec::new())); - } else { - let access_a = &system_a.access; - let access_b = &system_b.access; - if !access_a.is_compatible(access_b) { - match access_a.get_conflicts(access_b) { - AccessConflicts::Individual(conflicts) => { - let conflicts: Vec<_> = conflicts - .ones() - .map(ComponentId::get_sparse_set_index) - .filter(|id| !ignored_ambiguities.contains(id)) - .collect(); - if !conflicts.is_empty() { - conflicting_systems.push((a, b, conflicts)); - } - } - AccessConflicts::All => { - // there is no specific component conflicting, but the systems are overall incompatible - // for example 2 systems with `Query` - conflicting_systems.push((a, b, Vec::new())); - } - } - } - } - } - - conflicting_systems - } - fn build_schedule_inner( &self, - dependency_flattened_dag: Dag, - hier_results_reachable: FixedBitSet, + flat_dependency: Dag, + hierarchy_analysis: DagAnalysis, ) -> SystemSchedule { - let dg_system_ids = dependency_flattened_dag.topsort; + let dg_system_ids = flat_dependency.get_toposort().unwrap().to_vec(); let dg_system_idx_map = dg_system_ids .iter() .cloned() @@ -1440,18 +1224,14 @@ impl ScheduleGraph { .map(|(i, id)| (id, i)) .collect::>(); - let hg_systems = self - .hierarchy - .topsort + let hierarchy_toposort = self.hierarchy.get_toposort().unwrap(); + let hg_systems = hierarchy_toposort .iter() .cloned() .enumerate() .filter_map(|(i, id)| Some((i, id.as_system()?))) .collect::>(); - - let (hg_set_with_conditions_idxs, hg_set_ids): (Vec<_>, Vec<_>) = self - .hierarchy - .topsort + let (hg_set_with_conditions_idxs, 
hg_set_ids): (Vec<_>, Vec<_>) = hierarchy_toposort .iter() .cloned() .enumerate() @@ -1465,20 +1245,18 @@ impl ScheduleGraph { let sys_count = self.systems.len(); let set_with_conditions_count = hg_set_ids.len(); - let hg_node_count = self.hierarchy.graph.node_count(); + let hg_node_count = self.hierarchy.node_count(); // get the number of dependencies and the immediate dependents of each system // (needed by multi_threaded executor to run systems in the correct order) let mut system_dependencies = Vec::with_capacity(sys_count); let mut system_dependents = Vec::with_capacity(sys_count); for &sys_key in &dg_system_ids { - let num_dependencies = dependency_flattened_dag - .graph + let num_dependencies = flat_dependency .neighbors_directed(sys_key, Incoming) .count(); - let dependents = dependency_flattened_dag - .graph + let dependents = flat_dependency .neighbors_directed(sys_key, Outgoing) .map(|dep_id| dg_system_idx_map[&dep_id]) .collect::>(); @@ -1495,7 +1273,7 @@ impl ScheduleGraph { let bitset = &mut systems_in_sets_with_conditions[i]; for &(col, sys_key) in &hg_systems { let idx = dg_system_idx_map[&sys_key]; - let is_descendant = hier_results_reachable[index(row, col, hg_node_count)]; + let is_descendant = hierarchy_analysis.reachable()[index(row, col, hg_node_count)]; bitset.set(idx, is_descendant); } } @@ -1510,7 +1288,7 @@ impl ScheduleGraph { .enumerate() .take_while(|&(_idx, &row)| row < col) { - let is_ancestor = hier_results_reachable[index(row, col, hg_node_count)]; + let is_ancestor = hierarchy_analysis.reachable()[index(row, col, hg_node_count)]; bitset.set(idx, is_ancestor); } } @@ -1697,7 +1475,6 @@ impl ScheduleGraph { format!( "({})", self.hierarchy - .graph .edges_directed(*id, Outgoing) // never get the sets of the members or this will infinite recurse when the report_sets setting is on. 
.map(|(_, member_id)| self.get_node_name_inner(&member_id, false)) @@ -1706,126 +1483,8 @@ impl ScheduleGraph { ) } - /// If [`ScheduleBuildSettings::hierarchy_detection`] is [`LogLevel::Ignore`] this check - /// is skipped. - fn optionally_check_hierarchy_conflicts( - &self, - transitive_edges: &[(NodeId, NodeId)], - ) -> Result, ScheduleBuildError> { - match ( - self.settings.hierarchy_detection, - !transitive_edges.is_empty(), - ) { - (LogLevel::Warn, true) => Ok(Some(ScheduleBuildWarning::HierarchyRedundancy( - transitive_edges.to_vec(), - ))), - (LogLevel::Error, true) => { - Err(ScheduleBuildWarning::HierarchyRedundancy(transitive_edges.to_vec()).into()) - } - _ => Ok(None), - } - } - - fn check_for_cross_dependencies( - &self, - dep_results: &CheckGraphResults, - hier_results_connected: &HashSet<(NodeId, NodeId)>, - ) -> Result<(), ScheduleBuildError> { - for &(a, b) in &dep_results.connected { - if hier_results_connected.contains(&(a, b)) || hier_results_connected.contains(&(b, a)) - { - return Err(ScheduleBuildError::CrossDependency(a, b)); - } - } - - Ok(()) - } - - fn check_order_but_intersect( - &self, - dep_results_connected: &HashSet<(NodeId, NodeId)>, - set_system_sets: &HashMap>, - ) -> Result<(), ScheduleBuildError> { - // check that there is no ordering between system sets that intersect - for &(a, b) in dep_results_connected { - let (NodeId::Set(a_key), NodeId::Set(b_key)) = (a, b) else { - continue; - }; - - let a_systems = set_system_sets.get(&a_key).unwrap(); - let b_systems = set_system_sets.get(&b_key).unwrap(); - - if !a_systems.is_disjoint(b_systems) { - return Err(ScheduleBuildError::SetsHaveOrderButIntersect(a_key, b_key)); - } - } - - Ok(()) - } - - fn check_system_type_set_ambiguity( - &self, - set_systems: &HashMap>, - ) -> Result<(), ScheduleBuildError> { - for (&key, systems) in set_systems { - let set = &self.system_sets[key]; - if set.system_type().is_some() { - let instances = systems.len(); - let ambiguous_with = 
self.ambiguous_with.edges(NodeId::Set(key)); - let before = self - .dependency - .graph - .edges_directed(NodeId::Set(key), Incoming); - let after = self - .dependency - .graph - .edges_directed(NodeId::Set(key), Outgoing); - let relations = before.count() + after.count() + ambiguous_with.count(); - if instances > 1 && relations > 0 { - return Err(ScheduleBuildError::SystemTypeSetAmbiguity(key)); - } - } - } - Ok(()) - } - - /// if [`ScheduleBuildSettings::ambiguity_detection`] is [`LogLevel::Ignore`], this check is skipped - fn optionally_check_conflicts( - &self, - conflicts: &[(SystemKey, SystemKey, Vec)], - ) -> Result, ScheduleBuildError> { - match (self.settings.ambiguity_detection, !conflicts.is_empty()) { - (LogLevel::Warn, true) => Ok(Some(ScheduleBuildWarning::Ambiguity(conflicts.to_vec()))), - (LogLevel::Error, true) => { - Err(ScheduleBuildWarning::Ambiguity(conflicts.to_vec()).into()) - } - _ => Ok(None), - } - } - - /// convert conflicts to human readable format - pub fn conflicts_to_string<'a>( - &'a self, - ambiguities: &'a [(SystemKey, SystemKey, Vec)], - components: &'a Components, - ) -> impl Iterator)> + 'a { - ambiguities - .iter() - .map(move |(system_a, system_b, conflicts)| { - let name_a = self.get_node_name(&NodeId::System(*system_a)); - let name_b = self.get_node_name(&NodeId::System(*system_b)); - - let conflict_names: Vec<_> = conflicts - .iter() - .map(|id| components.get_name(*id).unwrap()) - .collect(); - - (name_a, name_b, conflict_names) - }) - } - fn traverse_sets_containing_node(&self, id: NodeId, f: &mut impl FnMut(SystemSetKey) -> bool) { - for (set_id, _) in self.hierarchy.graph.edges_directed(id, Incoming) { + for (set_id, _) in self.hierarchy.edges_directed(id, Incoming) { let NodeId::Set(set_key) = set_id else { continue; }; @@ -2792,7 +2451,7 @@ mod tests { .graph() .systems_in_set(test_system.into_system_set().intern()) .unwrap(); - assert_ne!(keys[0], keys[1]); + assert_eq!(keys.len(), 2); } #[test] diff --git 
a/release-content/migration-guides/schedule_cleanup.md b/release-content/migration-guides/schedule_cleanup.md index 19ab4dc6d87e4..0b1281519d677 100644 --- a/release-content/migration-guides/schedule_cleanup.md +++ b/release-content/migration-guides/schedule_cleanup.md @@ -1,13 +1,24 @@ --- title: "Schedule cleanup" -pull_requests: [21608] +pull_requests: [21608, 21817] --- -- `ScheduleGraph::topsort_graph` has been moved to `DiGraph::toposort`. +- `ScheduleGraph::topsort_graph` has been moved to `DiGraph::toposort`, and now takes a `Vec` parameter for allocation reuse. - `ReportCycles` was removed: instead, `DiGraphToposortError`s should be immediately wrapped into hierarchy graph or dependency graph `ScheduleBuildError` variants. - `ScheduleBuildError::HierarchyLoop` variant was removed, use `ScheduleBuildError::HierarchySort(DiGraphToposortError::Loop())` instead. - `ScheduleBuildError::HierarchyCycle` variant was removed, use `ScheduleBuildError::HierarchySort(DiGraphToposortError::Cycle())` instead. - `ScheduleBuildError::DependencyLoop` variant was removed, use `ScheduleBuildError::DependencySort(DiGraphToposortError::Loop())` instead. - `ScheduleBuildError::DependencyCycle` variant was removed, use `ScheduleBuildError::DependencySort(DiGraphToposortError::Cycle())` instead. +- `ScheduleBuildError::CrossDependency` now wraps a `DagCrossDependencyError` instead of directly holding two `NodeId`s. Fetch them from the wrapped struct instead. +- `ScheduleBuildError::SetsHaveOrderButIntersect` now wraps a `DagOverlappingGroupError` instead of directly holding two `SystemSetKey`s. Fetch them from the wrapped struct instead. +- `ScheduleBuildError::SystemTypeSetAmbiguity` now wraps a `SystemTypeSetAmbiguityError` instead of directly holding a `SystemSetKey`. Fetch them from the wrapped struct instead. +- `ScheduleBuildWarning::HierarchyRedundancy` now wraps a `DagRedundancyError` instead of directly holding a `Vec<(NodeId, NodeId)>`. 
Fetch them from the wrapped struct instead. +- `ScheduleBuildWarning::Ambiguity` now wraps an `AmbiguousSystemConflictsWarning` instead of directly holding a `Vec`. Fetch them from the wrapped struct instead. +- `ScheduleGraph::conflicting_systems` now returns a `&ConflictingSystems` instead of a slice. Fetch conflicts from the wrapped struct instead. +- `ScheduleGraph::systems_in_set` now returns a `&HashSet` instead of a slice, to reduce redundant allocations. +- `ScheduleGraph::conflicts_to_string` functionality has been replaced with `ConflictingSystems::to_string`. +- `ScheduleBuildPass::build` now takes `&mut Dag` instead of `&mut DiGraph`, to allow reusing previous toposorts. +- `ScheduleBuildPass::collapse_set` now takes `&HashSet` instead of a slice, to reduce redundant allocations. - `simple_cycles_in_component` has been changed from a free function into a method on `DiGraph`. +- `DiGraph::try_into`/`UnGraph::try_into` was renamed to `DiGraph::try_convert`/`UnGraph::try_convert` to prevent overlap with the `TryInto` trait, and now makes use of `TryInto` instead of `TryFrom` for conversions.