Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,9 +1843,35 @@ def inputs(self) -> Sequence[Value | None]:
@inputs.setter
def inputs(self, _: Any) -> None:
raise AttributeError(
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
"Node.inputs cannot be assigned to. Please use 'resize_inputs' and "
"'replace_input_with' instead."
)

def resize_inputs(self, new_size: int, /) -> None:
"""Resize the inputs of the node.

If the new size is greater than the current size, new inputs are added as None.
If the new size is less than the current size, the extra inputs are removed.

After ``inputs`` is resized, you can use :meth:`replace_input_with` to set the new inputs.

.. versionadded:: 0.1.13

Args:
new_size: The new number of inputs.
"""
current_size = len(self._inputs)
if new_size == current_size:
return
if new_size < current_size:
# Remove extra inputs
for i in range(new_size, current_size):
self.replace_input_with(i, None)
self._inputs = self._inputs[:new_size]
else:
# Add new inputs as None
self._inputs = self._inputs + (None,) * (new_size - current_size)

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
Expand Down Expand Up @@ -1920,15 +1946,54 @@ def append(self, /, nodes: Node | Iterable[Node]) -> None:
def outputs(self) -> Sequence[Value]:
"""The output values of the node.

The outputs are immutable. To change the outputs, create a new node and
replace the inputs of the using nodes of this node's outputs by calling
:meth:`replace_input_with` on the using nodes of this node's outputs.
The outputs are always attached to this node once initialized (immutable),
except that the list can be resized to remove or add outputs.

Use :meth:`resize_outputs` to change the number of outputs of the node.
"""
return self._outputs

@outputs.setter
def outputs(self, _: Sequence[Value]) -> None:
raise AttributeError("outputs is immutable. Please create a new node instead.")
raise AttributeError(
"Node.outputs cannot be assigned to. Please use 'resize_outputs' or create a new node instead."
)

def resize_outputs(self, new_size: int, /) -> None:
"""Resize the outputs of the node.

If the new size is greater than the current size, new output values are created.
If the new size is less than the current size, the extra output values are removed.
The removed output values must not have any uses.

.. versionadded:: 0.1.13

Args:
new_size: The new number of outputs.

Raises:
ValueError: If the new size is less than the current size and
the removed outputs have uses.
"""
current_size = len(self._outputs)
if new_size == current_size:
return
if new_size < current_size:
# Check that the removed outputs have no uses
for output in self._outputs[new_size:]:
if output.uses():
raise ValueError(
f"Cannot remove output {output} because it has uses: {output.uses()}"
)
for output in self._outputs[new_size:]:
# Detach the output from this node
output._producer = None # pylint: disable=protected-access
output._index = -1 # pylint: disable=protected-access
self._outputs = self._outputs[:new_size]
else:
# Create new outputs
new_outputs = [Value(self, index=i) for i in range(current_size, new_size)]
self._outputs = self._outputs + tuple(new_outputs)

@property
def attributes(self) -> _graph_containers.Attributes:
Expand Down
201 changes: 201 additions & 0 deletions src/onnx_ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,6 +1507,207 @@ def test_attributes_get_tensors(self):
node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1]
)

def test_resize_inputs_increase_size(self):
"""Test that resize_inputs increases the number of inputs by adding None values."""
v0 = _core.Value(name="v0")
v1 = _core.Value(name="v1")
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)

self.assertEqual(len(node.inputs), 2)
self.assertIs(node.inputs[0], v0)
self.assertIs(node.inputs[1], v1)

# Resize to 4 inputs
node.resize_inputs(4)

self.assertEqual(len(node.inputs), 4)
self.assertIs(node.inputs[0], v0)
self.assertIs(node.inputs[1], v1)
self.assertIsNone(node.inputs[2])
self.assertIsNone(node.inputs[3])

def test_resize_inputs_decrease_size(self):
"""Test that resize_inputs decreases the number of inputs and removes uses."""
v0 = _core.Value(name="v0")
v1 = _core.Value(name="v1")
v2 = _core.Value(name="v2")
node = _core.Node("", "TestOp", inputs=(v0, v1, v2), num_outputs=1)

self.assertEqual(len(node.inputs), 3)
# Check that node is in v2's uses
self.assertEqual(len(v2.uses()), 1)
self.assertIn(_core.Usage(node, 2), v2.uses())

# Resize to 2 inputs (remove v2)
node.resize_inputs(2)

self.assertEqual(len(node.inputs), 2)
self.assertIs(node.inputs[0], v0)
self.assertIs(node.inputs[1], v1)
# Check that node is no longer in v2's uses
self.assertEqual(len(v2.uses()), 0)

def test_resize_inputs_same_size(self):
"""Test that resize_inputs does nothing when size is unchanged."""
v0 = _core.Value(name="v0")
v1 = _core.Value(name="v1")
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)

# Resize to same size
node.resize_inputs(2)

self.assertEqual(len(node.inputs), 2)
self.assertIs(node.inputs[0], v0)
self.assertIs(node.inputs[1], v1)

def test_resize_inputs_to_zero(self):
"""Test that resize_inputs can reduce inputs to zero."""
v0 = _core.Value(name="v0")
v1 = _core.Value(name="v1")
node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1)

node.resize_inputs(0)

self.assertEqual(len(node.inputs), 0)
self.assertEqual(node.inputs, ())
# Check that uses are removed
self.assertEqual(len(v0.uses()), 0)
self.assertEqual(len(v1.uses()), 0)

def test_resize_inputs_from_zero(self):
"""Test that resize_inputs can increase from zero inputs."""
node = _core.Node("", "TestOp", inputs=(), num_outputs=1)

self.assertEqual(len(node.inputs), 0)

node.resize_inputs(3)

self.assertEqual(len(node.inputs), 3)
self.assertIsNone(node.inputs[0])
self.assertIsNone(node.inputs[1])
self.assertIsNone(node.inputs[2])

def test_resize_inputs_preserves_none_inputs(self):
"""Test that resize_inputs preserves None inputs when decreasing size."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0, None, None), num_outputs=1)

node.resize_inputs(2)

self.assertEqual(len(node.inputs), 2)
self.assertIs(node.inputs[0], v0)
self.assertIsNone(node.inputs[1])

def test_resize_outputs_increase_size(self):
"""Test that resize_outputs increases the number of outputs."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)

self.assertEqual(len(node.outputs), 2)
old_output_0 = node.outputs[0]
old_output_1 = node.outputs[1]

# Resize to 4 outputs
node.resize_outputs(4)

self.assertEqual(len(node.outputs), 4)
# Verify old outputs are preserved
self.assertIs(node.outputs[0], old_output_0)
self.assertIs(node.outputs[1], old_output_1)
# Verify new outputs are created
self.assertIsNotNone(node.outputs[2])
self.assertIsNotNone(node.outputs[3])
# Verify new outputs have correct producer and index
self.assertIs(node.outputs[2].producer(), node)
self.assertIs(node.outputs[3].producer(), node)
self.assertEqual(node.outputs[2].index(), 2)
self.assertEqual(node.outputs[3].index(), 3)

def test_resize_outputs_decrease_size(self):
"""Test that resize_outputs decreases the number of outputs when they have no uses."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3)

self.assertEqual(len(node.outputs), 3)
old_output_0 = node.outputs[0]

# Resize to 1 output
node.resize_outputs(1)

self.assertEqual(len(node.outputs), 1)
self.assertIs(node.outputs[0], old_output_0)

def test_resize_outputs_decrease_size_raises_when_output_has_uses(self):
"""Test that resize_outputs raises ValueError when removing outputs with uses."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3)
# Create a consumer for the third output
_consumer = _core.Node("", "Consumer", inputs=(node.outputs[2],), num_outputs=1)

self.assertEqual(len(node.outputs[2].uses()), 1)

# Try to resize to 2 outputs (remove the third one)
with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"):
node.resize_outputs(2)

# Verify outputs are unchanged
self.assertEqual(len(node.outputs), 3)

def test_resize_outputs_same_size(self):
"""Test that resize_outputs does nothing when size is unchanged."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)

old_outputs = node.outputs

# Resize to same size
node.resize_outputs(2)

self.assertEqual(len(node.outputs), 2)
self.assertIs(node.outputs[0], old_outputs[0])
self.assertIs(node.outputs[1], old_outputs[1])

def test_resize_outputs_to_zero(self):
"""Test that resize_outputs can reduce outputs to zero."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2)

node.resize_outputs(0)

self.assertEqual(len(node.outputs), 0)
self.assertEqual(node.outputs, ())

def test_resize_outputs_from_zero(self):
"""Test that resize_outputs can increase from zero outputs."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=0)

self.assertEqual(len(node.outputs), 0)

node.resize_outputs(2)

self.assertEqual(len(node.outputs), 2)
self.assertIsNotNone(node.outputs[0])
self.assertIsNotNone(node.outputs[1])
self.assertIs(node.outputs[0].producer(), node)
self.assertIs(node.outputs[1].producer(), node)
self.assertEqual(node.outputs[0].index(), 0)
self.assertEqual(node.outputs[1].index(), 1)

def test_resize_outputs_decrease_with_middle_output_having_uses(self):
"""Test that resize_outputs raises when removing a middle output with uses."""
v0 = _core.Value(name="v0")
node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=4)
# Create a consumer for the second output (index 1)
_consumer = _core.Node("", "Consumer", inputs=(node.outputs[1],), num_outputs=1)

# Try to resize to 1 output (remove outputs at indices 1, 2, 3)
with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"):
node.resize_outputs(1)

# Verify outputs are unchanged
self.assertEqual(len(node.outputs), 4)

# TODO(justinchuby): Test all methods


Expand Down
21 changes: 21 additions & 0 deletions src/onnx_ir/passes/common/unused_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,26 @@ def is_used_output(i: int) -> bool:
if out not in graph_outputs and (not out.uses()) and optional_info[i] is True:
out.name = ""

# Remove trailing outputs with empty names by counting backwards
new_output_count = len(node.outputs)
for i in reversed(range(len(node.outputs))):
if not node.outputs[i].name:
new_output_count -= 1
else:
break
node.resize_outputs(new_output_count)


def _remove_trailing_empty_inputs(node: ir.Node) -> None:
# Remove trailing None inputs
new_input_count = len(node.inputs)
for i in reversed(range(len(node.inputs))):
if node.inputs[i] is None:
new_input_count -= 1
else:
break
node.resize_inputs(new_input_count)


def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
graph_outputs = frozenset(function_or_graph.outputs)
Expand All @@ -79,6 +99,7 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph
function_or_graph.remove(node, safe=True)
count += 1
else:
_remove_trailing_empty_inputs(node)
if onnx_opset_version is not None:
_remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
for attr in node.attributes.values():
Expand Down