diff --git a/launch_testing_ros/launch_testing_ros/wait_for_topics.py b/launch_testing_ros/launch_testing_ros/wait_for_topics.py index 1f3743ca2..107765df6 100644 --- a/launch_testing_ros/launch_testing_ros/wait_for_topics.py +++ b/launch_testing_ros/launch_testing_ros/wait_for_topics.py @@ -19,6 +19,8 @@ from threading import Thread import rclpy +from rclpy.event_handler import QoSSubscriptionMatchedInfo +from rclpy.event_handler import SubscriptionEventCallbacks from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node @@ -50,12 +52,32 @@ def method_2(): print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'} print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...] wait_for_topics.shutdown() + + # Method3, calling a trigger function before the wait. The trigger function takes + # the WaitForTopics node object as the first argument. Any additional arguments have + # to be passed to the wait(*args, **kwargs) method directly. + def trigger_function(node, arg=""): + node.get_logger().info('Trigger function called with argument: ' + arg) + + def method_3(): + topic_list = [('topic_1', String), ('topic_2', String)] + wait_for_topics = WaitForTopics(topic_list, timeout=5.0, trigger=trigger_function) + # The trigger function will be called inside the wait() method after the + # subscribers are created and before the publishers are connected. + assert wait_for_topics.wait("Hello World!") + print('Given topics are receiving messages !') + wait_for_topics.shutdown() """ - def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10) -> None: + def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10, + trigger=None, node_namespace=None) -> None: self.topic_tuples = topic_tuples self.timeout = timeout self.messages_received_buffer_length = messages_received_buffer_length + self.trigger = trigger + self.node_namespace = node_namespace + if self.trigger is not None and not callable(self.trigger): + raise TypeError('The passed trigger is not callable') self.__ros_context = rclpy.Context() rclpy.init(context=self.__ros_context) self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context) @@ -80,11 +102,15 @@ def _prepare_ros_node(self): name=node_name, node_context=self.__ros_context, messages_received_buffer_length=self.messages_received_buffer_length, + node_namespace=self.node_namespace ) self.__ros_executor.add_node(self.__ros_node) - def wait(self): + def wait(self, *args, **kwargs): self.__ros_node.start_subscribers(self.topic_tuples) + if self.trigger: + self.trigger(self.__ros_node, *args, **kwargs) + self.__ros_node.any_publisher_connected.wait(self.timeout) return self.__ros_node.msg_event_object.wait(self.timeout) def shutdown(self): @@ -121,9 +147,12 @@ class _WaitForTopicsNode(Node): """Internal node used for subscribing to a set of topics.""" def __init__( - self, name='test_node', node_context=None, messages_received_buffer_length=None + self, name='test_node', + node_context=None, + messages_received_buffer_length=None, + node_namespace=None ) -> None: - super().__init__(node_name=name, context=node_context) # type: ignore + super().__init__(node_name=name, context=node_context, namespace=node_namespace) self.msg_event_object = Event() self.messages_received_buffer_length = messages_received_buffer_length self.subscriber_list = [] @@ -131,6 +160,13 @@ def __init__( self.expected_topics = set() self.received_topics = set() self.received_messages_buffer = {} + self.any_publisher_connected = Event() + + def _sub_matched_event_callback(self, info: QoSSubscriptionMatchedInfo): + if info.current_count != 0: + self.any_publisher_connected.set() + else: + self.any_publisher_connected.clear() def _reset(self): self.msg_event_object.clear() @@ -149,12 +185,16 @@ def start_subscribers(self, topic_tuples): maxlen=self.messages_received_buffer_length ) # Create a subscriber + sub_event_callback = SubscriptionEventCallbacks( + matched=self._sub_matched_event_callback + ) self.subscriber_list.append( self.create_subscription( topic_type, topic_name, self.callback_template(topic_name), - 10 + 10, + event_callbacks=sub_event_callback, ) ) diff --git a/launch_testing_ros/test/examples/repeater.py b/launch_testing_ros/test/examples/repeater.py new file mode 100644 index 000000000..4f41e01f0 --- /dev/null +++ b/launch_testing_ros/test/examples/repeater.py @@ -0,0 +1,52 @@ +# Copyright 2025 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import rclpy +from rclpy.node import Node + +from std_msgs.msg import String + + +class Repeater(Node): + + def __init__(self): + super().__init__('repeater') + self.subscription = self.create_subscription( + String, 'input', self.callback, 10 + ) + self.publisher = self.create_publisher(String, 'output', 10) + + def callback(self, input_msg): + self.get_logger().info(f'I heard: [{input_msg.data}]') + output_msg_data = input_msg.data + self.get_logger().info(f'Publishing: "{output_msg_data}"') + self.publisher.publish(String(data=output_msg_data)) + + +def main(args=None): + rclpy.init(args=args) + + node = Repeater() + + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/launch_testing_ros/test/examples/wait_for_topic_inject_trigger_test.py b/launch_testing_ros/test/examples/wait_for_topic_inject_trigger_test.py new file mode 100644 index 000000000..abde4d324 --- /dev/null +++ b/launch_testing_ros/test/examples/wait_for_topic_inject_trigger_test.py @@ -0,0 +1,78 @@ +# Copyright 2025 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import unittest + +import launch +import launch_ros.actions +import launch_testing.actions +import launch_testing.markers +from launch_testing_ros import WaitForTopics +import pytest +from std_msgs.msg import String + + +def generate_node(): + """Return node.""" + path_to_test = os.path.dirname(__file__) + return launch_ros.actions.Node( + executable=sys.executable, + arguments=[os.path.join(path_to_test, 'repeater.py')], + name='demo_node', + additional_env={'PYTHONUNBUFFERED': '1'}, + ) + + +def trigger_function(node): + if not hasattr(node, 'my_publisher'): + node.my_publisher = node.create_publisher(String, 'input', 10) + while node.my_publisher.get_subscription_count() == 0: + time.sleep(0.1) + msg = String() + msg.data = 'Hello World' + node.my_publisher.publish(msg) + print('Published message') + + +@pytest.mark.launch_test +@launch_testing.markers.keep_alive +def generate_test_description(): + description = [generate_node(), launch_testing.actions.ReadyToTest()] + return launch.LaunchDescription(description) + + +# TODO: Test cases fail on Windows debug builds +# https://github.com/ros2/launch_ros/issues/292 +if sys.platform.startswith('win'): + pytest.skip( + 'CLI tests can block for a pathological amount of time on Windows.', + allow_module_level=True) + + +class TestFixture(unittest.TestCase): + + def test_topics_successful(self): + """All the supplied topics should be read successfully.""" + topic_list = [('output', String)] + expected_topics = {'output'} + + # Method 1 : Using the magic methods and 'with' keyword + with WaitForTopics( + topic_list, timeout=10.0, trigger=trigger_function + ) as wait_for_node_object_1: + assert wait_for_node_object_1.topics_received() == expected_topics + assert wait_for_node_object_1.topics_not_received() == set() diff --git a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py index c32ca3613..4bc775248 100644 --- a/launch_testing_ros/test/examples/wait_for_topic_launch_test.py +++ b/launch_testing_ros/test/examples/wait_for_topic_launch_test.py @@ -104,3 +104,22 @@ def test_topics_unsuccessful(self, count: int): assert wait_for_node_object.topics_received() == expected_topics assert wait_for_node_object.topics_not_received() == {'invalid_topic'} wait_for_node_object.shutdown() + + def test_trigger_function(self, count): + topic_list = [('chatter_' + str(i), String) for i in range(count)] + expected_topics = {'chatter_' + str(i) for i in range(count)} + + # Method 3 : Using a trigger function + + # Using a list to store the trigger function's argument as it is mutable + is_trigger_called = [False] + + def trigger_function(node, arg): + node.get_logger().info(f'Trigger function called with argument: {arg[0]}') + arg[0] = True + + wait_for_node_object = WaitForTopics(topic_list, timeout=2.0, trigger=trigger_function) + assert wait_for_node_object.wait(is_trigger_called) + assert wait_for_node_object.topics_received() == expected_topics + assert wait_for_node_object.topics_not_received() == set() + assert is_trigger_called[0]