Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tasktiger/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def run_eager_task(self, task: "Task") -> None:
"""
raise NotImplementedError("Eager tasks are not supported.")

def on_permanent_error(self, task: "Task", execution: Dict[str, Any] | None) -> None:
def on_permanent_error(
self, task: "Task", execution: Dict[str, Any] | None
) -> None:
"""
Called if the task fails permanently.

Expand Down
6 changes: 6 additions & 0 deletions tasktiger/tasktiger.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,19 @@ def _get_current_serialized_func(self) -> str:
raise RuntimeError("Must be accessed from within a task.")
return g["current_tasks"][0].serialized_func

def _get_current_task_is_batch(self) -> bool:
if g["current_task_is_batch"] is None:
raise RuntimeError("Must be accessed from within a task.")
return g["current_task_is_batch"]

"""
Properties to access the currently processing task (or tasks, in case of a
batch task) from within the task. They must be invoked from within a task.
"""
current_task = property(_get_current_task)
current_tasks = property(_get_current_tasks)
current_serialized_func = property(_get_current_serialized_func)
current_task_is_batch = property(_get_current_task_is_batch)

@classproperty
def current_instance(self) -> "TaskTiger":
Expand Down
13 changes: 13 additions & 0 deletions tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ def verify_current_serialized_func_batch(tasks):
conn.set("serialized_func", serialized_func)


def verify_current_task_is_batch():
with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn:
is_batch = tiger.current_task_is_batch
conn.set("current_task_is_batch", str(is_batch))


@tiger.task(batch=True, queue="batch")
def verify_current_task_is_batch_batch(tasks):
with redis.Redis(host=REDIS_HOST, db=TEST_DB, decode_responses=True) as conn:
is_batch = tiger.current_task_is_batch
conn.set("current_task_is_batch", str(is_batch))


@tiger.task()
def verify_tasktiger_instance():
# Not necessarily the same object, but the same configuration.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
verify_current_serialized_func,
verify_current_serialized_func_batch,
verify_current_task,
verify_current_task_is_batch,
verify_current_task_is_batch_batch,
verify_current_tasks,
verify_tasktiger_instance,
)
Expand Down Expand Up @@ -1172,6 +1174,31 @@ def test_current_serialized_func_batch(self, always_eager):
)


class TestCurrentTaskIsBatch(BaseTestCase):
"""
Ensure current_task_is_batch is set.
"""

@pytest.mark.parametrize("always_eager", [False, True])
def test_current_task_is_batch(self, always_eager):
self.tiger.config["ALWAYS_EAGER"] = always_eager
task = Task(self.tiger, verify_current_task_is_batch)
task.delay()
Worker(self.tiger).run(once=True)
assert not self.conn.exists("runtime_error")
assert self.conn.get("current_task_is_batch") == "False"

@pytest.mark.parametrize("always_eager", [False, True])
def test_current_task_is_batch_batch(self, always_eager):
self.tiger.config["ALWAYS_EAGER"] = always_eager
task1 = Task(self.tiger, verify_current_task_is_batch_batch)
task1.delay()
task2 = Task(self.tiger, verify_current_task_is_batch_batch)
task2.delay()
Worker(self.tiger).run(once=True)
assert self.conn.get("current_task_is_batch") == "True"


class TestTaskTigerGlobal(BaseTestCase):
"""
Ensure TaskTiger.current_instance is set.
Expand Down