diff --git a/tasktiger/runner.py b/tasktiger/runner.py index 037c813..36b8b4d 100644 --- a/tasktiger/runner.py +++ b/tasktiger/runner.py @@ -42,7 +42,9 @@ def run_eager_task(self, task: "Task") -> None: """ raise NotImplementedError("Eager tasks are not supported.") - def on_permanent_error(self, task: "Task", execution: Dict[str, Any] | None) -> None: + def on_permanent_error( + self, task: "Task", execution: Dict[str, Any] | None + ) -> None: """ Called if the task fails permanently. diff --git a/tasktiger/tasktiger.py b/tasktiger/tasktiger.py index f0a6c39..ef8d3df 100644 --- a/tasktiger/tasktiger.py +++ b/tasktiger/tasktiger.py @@ -311,6 +311,11 @@ def _get_current_serialized_func(self) -> str: raise RuntimeError("Must be accessed from within a task.") return g["current_tasks"][0].serialized_func + def _get_current_task_is_batch(self) -> bool: + if g["current_task_is_batch"] is None: + raise RuntimeError("Must be accessed from within a task.") + return g["current_task_is_batch"] + """ Properties to access the currently processing task (or tasks, in case of a batch task) from within the task. They must be invoked from within a task. @@ -318,6 +323,7 @@ def _get_current_serialized_func(self) -> str: current_task = property(_get_current_task) current_tasks = property(_get_current_tasks) current_serialized_func = property(_get_current_serialized_func) + current_task_is_batch = property(_get_current_task_is_batch) @classproperty def current_instance(self) -> "TaskTiger": diff --git a/tests/tasks.py b/tests/tasks.py index 4eb2900..c6860c7 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -163,6 +163,19 @@ def verify_current_serialized_func_batch(tasks): conn.set("serialized_func", serialized_func) +def verify_current_task_is_batch(): + with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn: + is_batch = tiger.current_task_is_batch + conn.set("current_task_is_batch", str(is_batch)) + + +@tiger.task(batch=True, queue="batch") +def verify_current_task_is_batch_batch(tasks): + with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn: + is_batch = tiger.current_task_is_batch + conn.set("current_task_is_batch", str(is_batch)) + + @tiger.task() def verify_tasktiger_instance(): # Not necessarily the same object, but the same configuration. diff --git a/tests/test_base.py b/tests/test_base.py index a41a03e..df05ea1 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -49,6 +49,8 @@ verify_current_serialized_func, verify_current_serialized_func_batch, verify_current_task, + verify_current_task_is_batch, + verify_current_task_is_batch_batch, verify_current_tasks, verify_tasktiger_instance, ) @@ -1172,6 +1174,31 @@ def test_current_serialized_func_batch(self, always_eager): ) +class TestCurrentTaskIsBatch(BaseTestCase): + """ + Ensure current_task_is_batch is set. + """ + + @pytest.mark.parametrize("always_eager", [False, True]) + def test_current_task_is_batch(self, always_eager): + self.tiger.config["ALWAYS_EAGER"] = always_eager + task = Task(self.tiger, verify_current_task_is_batch) + task.delay() + Worker(self.tiger).run(once=True) + assert not self.conn.exists("runtime_error") + assert self.conn.get("current_task_is_batch") == "False" + + @pytest.mark.parametrize("always_eager", [False, True]) + def test_current_task_is_batch_batch(self, always_eager): + self.tiger.config["ALWAYS_EAGER"] = always_eager + task1 = Task(self.tiger, verify_current_task_is_batch_batch) + task1.delay() + task2 = Task(self.tiger, verify_current_task_is_batch_batch) + task2.delay() + Worker(self.tiger).run(once=True) + assert self.conn.get("current_task_is_batch") == "True" + + class TestTaskTigerGlobal(BaseTestCase): """ Ensure TaskTiger.current_instance is set.