diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 83055e36..61a8d32d 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1843,9 +1843,35 @@ def inputs(self) -> Sequence[Value | None]: @inputs.setter def inputs(self, _: Any) -> None: raise AttributeError( - "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." + "Node.inputs cannot be assigned to. Please use 'resize_inputs' and " + "'replace_input_with' instead." ) + def resize_inputs(self, new_size: int, /) -> None: + """Resize the inputs of the node. + + If the new size is greater than the current size, new inputs are added as None. + If the new size is less than the current size, the extra inputs are removed. + + After ``inputs`` is resized, you can use :meth:`replace_input_with` to set the new inputs. + + .. versionadded:: 0.1.13 + + Args: + new_size: The new number of inputs. + """ + current_size = len(self._inputs) + if new_size == current_size: + return + if new_size < current_size: + # Remove extra inputs + for i in range(new_size, current_size): + self.replace_input_with(i, None) + self._inputs = self._inputs[:new_size] + else: + # Add new inputs as None + self._inputs = self._inputs + (None,) * (new_size - current_size) + def predecessors(self) -> Sequence[Node]: """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes @@ -1920,15 +1946,54 @@ def append(self, /, nodes: Node | Iterable[Node]) -> None: def outputs(self) -> Sequence[Value]: """The output values of the node. - The outputs are immutable. To change the outputs, create a new node and - replace the inputs of the using nodes of this node's outputs by calling - :meth:`replace_input_with` on the using nodes of this node's outputs. + The outputs are always attached to this node once initialized (immutable), + except that the list can be resized to remove or add outputs. + + Use :meth:`resize_outputs` to change the number of outputs of the node. """ return self._outputs @outputs.setter def outputs(self, _: Sequence[Value]) -> None: - raise AttributeError("outputs is immutable. Please create a new node instead.") + raise AttributeError( + "Node.outputs cannot be assigned to. Please use 'resize_outputs' or create a new node instead." + ) + + def resize_outputs(self, new_size: int, /) -> None: + """Resize the outputs of the node. + + If the new size is greater than the current size, new output values are created. + If the new size is less than the current size, the extra output values are removed. + The removed output values must not have any uses. + + .. versionadded:: 0.1.13 + + Args: + new_size: The new number of outputs. + + Raises: + ValueError: If the new size is less than the current size and + the removed outputs have uses. + """ + current_size = len(self._outputs) + if new_size == current_size: + return + if new_size < current_size: + # Check that the removed outputs have no uses + for output in self._outputs[new_size:]: + if output.uses(): + raise ValueError( + f"Cannot remove output {output} because it has uses: {output.uses()}" + ) + for output in self._outputs[new_size:]: + # Detach the output from this node + output._producer = None # pylint: disable=protected-access + output._index = -1 # pylint: disable=protected-access + self._outputs = self._outputs[:new_size] + else: + # Create new outputs + new_outputs = [Value(self, index=i) for i in range(current_size, new_size)] + self._outputs = self._outputs + tuple(new_outputs) @property def attributes(self) -> _graph_containers.Attributes: diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 5c45f561..51a6eda3 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -1507,6 +1507,207 @@ def test_attributes_get_tensors(self): node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1] ) + def test_resize_inputs_increase_size(self): + """Test that resize_inputs increases the number of inputs by adding None values.""" + v0 = _core.Value(name="v0") + v1 = _core.Value(name="v1") + node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1) + + self.assertEqual(len(node.inputs), 2) + self.assertIs(node.inputs[0], v0) + self.assertIs(node.inputs[1], v1) + + # Resize to 4 inputs + node.resize_inputs(4) + + self.assertEqual(len(node.inputs), 4) + self.assertIs(node.inputs[0], v0) + self.assertIs(node.inputs[1], v1) + self.assertIsNone(node.inputs[2]) + self.assertIsNone(node.inputs[3]) + + def test_resize_inputs_decrease_size(self): + """Test that resize_inputs decreases the number of inputs and removes uses.""" + v0 = _core.Value(name="v0") + v1 = _core.Value(name="v1") + v2 = _core.Value(name="v2") + node = _core.Node("", "TestOp", inputs=(v0, v1, v2), num_outputs=1) + + self.assertEqual(len(node.inputs), 3) + # Check that node is in v2's uses + self.assertEqual(len(v2.uses()), 1) + self.assertIn(_core.Usage(node, 2), v2.uses()) + + # Resize to 2 inputs (remove v2) + node.resize_inputs(2) + + self.assertEqual(len(node.inputs), 2) + self.assertIs(node.inputs[0], v0) + self.assertIs(node.inputs[1], v1) + # Check that node is no longer in v2's uses + self.assertEqual(len(v2.uses()), 0) + + def test_resize_inputs_same_size(self): + """Test that resize_inputs does nothing when size is unchanged.""" + v0 = _core.Value(name="v0") + v1 = _core.Value(name="v1") + node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1) + + # Resize to same size + node.resize_inputs(2) + + self.assertEqual(len(node.inputs), 2) + self.assertIs(node.inputs[0], v0) + self.assertIs(node.inputs[1], v1) + + def test_resize_inputs_to_zero(self): + """Test that resize_inputs can reduce inputs to zero.""" + v0 = _core.Value(name="v0") + v1 = _core.Value(name="v1") + node = _core.Node("", "TestOp", inputs=(v0, v1), num_outputs=1) + + node.resize_inputs(0) + + self.assertEqual(len(node.inputs), 0) + self.assertEqual(node.inputs, ()) + # Check that uses are removed + self.assertEqual(len(v0.uses()), 0) + self.assertEqual(len(v1.uses()), 0) + + def test_resize_inputs_from_zero(self): + """Test that resize_inputs can increase from zero inputs.""" + node = _core.Node("", "TestOp", inputs=(), num_outputs=1) + + self.assertEqual(len(node.inputs), 0) + + node.resize_inputs(3) + + self.assertEqual(len(node.inputs), 3) + self.assertIsNone(node.inputs[0]) + self.assertIsNone(node.inputs[1]) + self.assertIsNone(node.inputs[2]) + + def test_resize_inputs_preserves_none_inputs(self): + """Test that resize_inputs preserves None inputs when decreasing size.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0, None, None), num_outputs=1) + + node.resize_inputs(2) + + self.assertEqual(len(node.inputs), 2) + self.assertIs(node.inputs[0], v0) + self.assertIsNone(node.inputs[1]) + + def test_resize_outputs_increase_size(self): + """Test that resize_outputs increases the number of outputs.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2) + + self.assertEqual(len(node.outputs), 2) + old_output_0 = node.outputs[0] + old_output_1 = node.outputs[1] + + # Resize to 4 outputs + node.resize_outputs(4) + + self.assertEqual(len(node.outputs), 4) + # Verify old outputs are preserved + self.assertIs(node.outputs[0], old_output_0) + self.assertIs(node.outputs[1], old_output_1) + # Verify new outputs are created + self.assertIsNotNone(node.outputs[2]) + self.assertIsNotNone(node.outputs[3]) + # Verify new outputs have correct producer and index + self.assertIs(node.outputs[2].producer(), node) + self.assertIs(node.outputs[3].producer(), node) + self.assertEqual(node.outputs[2].index(), 2) + self.assertEqual(node.outputs[3].index(), 3) + + def test_resize_outputs_decrease_size(self): + """Test that resize_outputs decreases the number of outputs when they have no uses.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3) + + self.assertEqual(len(node.outputs), 3) + old_output_0 = node.outputs[0] + + # Resize to 1 output + node.resize_outputs(1) + + self.assertEqual(len(node.outputs), 1) + self.assertIs(node.outputs[0], old_output_0) + + def test_resize_outputs_decrease_size_raises_when_output_has_uses(self): + """Test that resize_outputs raises ValueError when removing outputs with uses.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=3) + # Create a consumer for the third output + _consumer = _core.Node("", "Consumer", inputs=(node.outputs[2],), num_outputs=1) + + self.assertEqual(len(node.outputs[2].uses()), 1) + + # Try to resize to 2 outputs (remove the third one) + with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"): + node.resize_outputs(2) + + # Verify outputs are unchanged + self.assertEqual(len(node.outputs), 3) + + def test_resize_outputs_same_size(self): + """Test that resize_outputs does nothing when size is unchanged.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2) + + old_outputs = node.outputs + + # Resize to same size + node.resize_outputs(2) + + self.assertEqual(len(node.outputs), 2) + self.assertIs(node.outputs[0], old_outputs[0]) + self.assertIs(node.outputs[1], old_outputs[1]) + + def test_resize_outputs_to_zero(self): + """Test that resize_outputs can reduce outputs to zero.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=2) + + node.resize_outputs(0) + + self.assertEqual(len(node.outputs), 0) + self.assertEqual(node.outputs, ()) + + def test_resize_outputs_from_zero(self): + """Test that resize_outputs can increase from zero outputs.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=0) + + self.assertEqual(len(node.outputs), 0) + + node.resize_outputs(2) + + self.assertEqual(len(node.outputs), 2) + self.assertIsNotNone(node.outputs[0]) + self.assertIsNotNone(node.outputs[1]) + self.assertIs(node.outputs[0].producer(), node) + self.assertIs(node.outputs[1].producer(), node) + self.assertEqual(node.outputs[0].index(), 0) + self.assertEqual(node.outputs[1].index(), 1) + + def test_resize_outputs_decrease_with_middle_output_having_uses(self): + """Test that resize_outputs raises when removing a middle output with uses.""" + v0 = _core.Value(name="v0") + node = _core.Node("", "TestOp", inputs=(v0,), num_outputs=4) + # Create a consumer for the second output (index 1) + _consumer = _core.Node("", "Consumer", inputs=(node.outputs[1],), num_outputs=1) + + # Try to resize to 1 output (remove outputs at indices 1, 2, 3) + with self.assertRaisesRegex(ValueError, "Cannot remove output.*because it has uses"): + node.resize_outputs(1) + + # Verify outputs are unchanged + self.assertEqual(len(node.outputs), 4) + # TODO(justinchuby): Test all methods diff --git a/src/onnx_ir/passes/common/unused_removal.py b/src/onnx_ir/passes/common/unused_removal.py index aebc40eb..57d75b51 100644 --- a/src/onnx_ir/passes/common/unused_removal.py +++ b/src/onnx_ir/passes/common/unused_removal.py @@ -64,6 +64,26 @@ def is_used_output(i: int) -> bool: if out not in graph_outputs and (not out.uses()) and optional_info[i] is True: out.name = "" + # Remove trailing outputs with empty names by counting backwards + new_output_count = len(node.outputs) + for i in reversed(range(len(node.outputs))): + if not node.outputs[i].name: + new_output_count -= 1 + else: + break + node.resize_outputs(new_output_count) + + +def _remove_trailing_empty_inputs(node: ir.Node) -> None: + # Remove trailing None inputs + new_input_count = len(node.inputs) + for i in reversed(range(len(node.inputs))): + if node.inputs[i] is None: + new_input_count -= 1 + else: + break + node.resize_inputs(new_input_count) + def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int: graph_outputs = frozenset(function_or_graph.outputs) @@ -79,6 +99,7 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph function_or_graph.remove(node, safe=True) count += 1 else: + _remove_trailing_empty_inputs(node) if onnx_opset_version is not None: _remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version) for attr in node.attributes.values():