Skip to content

Commit 04baf7f

Browse files
Sanitize profile filename (#21395)
* sanitize filename
* add testing
* changelog
* fix comment

---------

Co-authored-by: Bhimraj Yadav <[email protected]>
1 parent 716c2c6 commit 04baf7f

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626

2727
### Fixed
2828

29+
- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))
30+
31+
2932
- Fix `StochasticWeightAveraging` with infinite epochs ([#21396](https://github.com/Lightning-AI/pytorch-lightning/pull/21396))
3033

3134

src/lightning/pytorch/profilers/profiler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import os
18+
import re
1819
from abc import ABC, abstractmethod
1920
from collections.abc import Generator
2021
from contextlib import contextmanager
@@ -80,7 +81,6 @@ def _prepare_filename(
8081
self,
8182
action_name: Optional[str] = None,
8283
extension: str = ".txt",
83-
split_token: str = "-", # noqa: S107
8484
) -> str:
8585
args = []
8686
if self._stage is not None:
@@ -91,7 +91,14 @@ def _prepare_filename(
9191
args.append(str(self._local_rank))
9292
if action_name is not None:
9393
args.append(action_name)
94-
return split_token.join(args) + extension
94+
base = "-".join(args)
95+
# Replace a set of path-unsafe characters across platforms with '_'
96+
base = re.sub(r"[\\/:*?\"<>|\n\r\t]", "_", base)
97+
base = re.sub(r"_+", "_", base)
98+
base = base.strip()
99+
if not base:
100+
base = "profile"
101+
return base + extension
95102

96103
def _prepare_streams(self) -> None:
97104
if self._write_stream is not None:

tests/tests_pytorch/profilers/test_profiler.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,33 @@ def test_advanced_profiler_dump_states(tmp_path):
322322
assert len(data) > 0
323323

324324

325+
@pytest.mark.parametrize("char", ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"])
326+
def test_advanced_profiler_dump_states_sanitizes_filename(tmp_path, char):
327+
"""Profiler should sanitize action names to produce filesystem-safe .prof filenames.
328+
329+
This guards against errors when callbacks or actions include path-unsafe characters (e.g., metric names with '/').
330+
331+
"""
332+
profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
333+
action_name = f"before{char}after"
334+
with profiler.profile(action_name):
335+
pass
336+
337+
profiler.describe()
338+
339+
prof_files = [f for f in os.listdir(tmp_path) if f.endswith(".prof")]
340+
assert len(prof_files) == 1
341+
prof_name = prof_files[0]
342+
343+
# Ensure none of the path-unsafe characters are present in the produced filename
344+
forbidden = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"]
345+
for bad in forbidden:
346+
assert bad not in prof_name
347+
348+
# File should be non-empty
349+
assert (tmp_path / prof_name).read_bytes()
350+
351+
325352
def test_advanced_profiler_value_errors(advanced_profiler):
326353
"""Ensure errors are raised where expected."""
327354
action = "test"

0 commit comments

Comments
 (0)