Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion nemo_automodel/components/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,86 @@ def to_dict(self):
k: self._unwrap(v) for k, v in self.__dict__.items() if k not in ("raise_on_missing_attr", "_raw_config")
}

def _to_dotted_path(self, obj):
    """
    Convert a callable/class/method object to a dotted path string.

    Best-effort normalization for a few common cases to produce concise, user-friendly paths:

    - Method bound to a class (e.g., ``Class.from_pretrained``): ``module.Class.method``. The module
      is shortened to ``nemo_automodel`` when the class name starts with ``NeMoAutoModel``.
    - Classes: ``module.qualname``, so a nested class keeps its enclosing class in the path.
    - Functions, instance-bound methods and other named callables (e.g., builtins): ``module.qualname``.
    - Callable instances that carry no name of their own (e.g., ``functools.partial`` objects):
      the dotted path of their type, instead of a ``<... object at 0x...>`` style string.
    - Anything else without a name: ``module`` + ``str(obj)``.

    Args:
        obj: The object to describe. ``self`` is not consulted.

    Returns:
        str: The dotted path, with leading dots stripped when no module is known.
        Falls back to ``repr(obj)`` if anything goes wrong while inspecting ``obj``.
    """
    try:
        import inspect as _inspect  # local alias to avoid confusion with top-level import

        missing = object()

        def _named(target, default):
            # __qualname__ keeps enclosing classes; __name__ is the fallback for objects without it.
            return getattr(target, "__qualname__", getattr(target, "__name__", default))

        module_name = getattr(obj, "__module__", None) or ""
        owner = getattr(obj, "__self__", None) if _inspect.ismethod(obj) else None

        if _inspect.isclass(owner):
            # Bound method on a class (e.g., Class.from_pretrained)
            module_name = getattr(owner, "__module__", None) or ""
            # Prefer shortened top-level for NeMoAutoModel* classes if possible
            if getattr(owner, "__name__", "UnknownClass").startswith("NeMoAutoModel"):
                module_name = "nemo_automodel"
            name = f"{_named(owner, 'UnknownClass')}.{getattr(obj, '__name__', 'unknown')}"
        elif _inspect.isclass(obj):
            name = _named(obj, "UnknownClass")
        elif _inspect.ismethod(obj) or _inspect.isfunction(obj):
            # Plain functions and methods bound to an instance: module + qualname
            name = _named(obj, "unknown")
        else:
            name = _named(obj, missing)
            if name is missing:
                if callable(obj):
                    # Nameless callable instance: describe it by its type so the result
                    # is stable across runs and free of memory addresses.
                    obj_type = type(obj)
                    module_name = getattr(obj_type, "__module__", None) or ""
                    name = _named(obj_type, "UnknownClass")
                else:
                    name = str(obj)
        dotted = f"{module_name}.{name}".lstrip(".")
    except Exception:
        # Fallback to repr if anything goes wrong
        return repr(obj)
    return dotted

def to_yaml_dict(self):
    """
    Convert configuration to a YAML-ready dictionary:
    - Preserves typed scalars (ints, floats, bools)
    - Converts callables/classes/methods (e.g., _target_, *_fn) to dotted path strings
    - Recurses through nested ConfigNodes, lists and dicts

    Returns:
        dict: One entry per instance attribute, excluding the internal
        ``raise_on_missing_attr`` and ``_raw_config`` attributes.
    """

    def _convert(value):
        # Nested config
        if isinstance(value, ConfigNode):
            return value.to_yaml_dict()
        # Lists
        if isinstance(value, list):
            return [_convert(item) for item in value]
        # Dicts (shouldn't normally appear because we wrap into ConfigNode, but handle defensively)
        if isinstance(value, dict):
            return {k: _convert(v) for k, v in value.items()}
        # Every callable (function, method, class, callable instance) becomes a dotted path,
        # whatever key it is stored under, to avoid <function ...> repr in the output.
        # _to_dotted_path is itself best-effort and falls back to repr(), so no guard is needed here.
        if callable(value):
            return self._to_dotted_path(value)
        # Primitive – already typed via translate_value/_wrap
        return value

    # Walk live attributes to preserve translated scalars
    return {
        k: _convert(v) for k, v in self.__dict__.items() if k not in ("raise_on_missing_attr", "_raw_config")
    }

def _unwrap(self, v):
"""
Recursively convert wrapped configuration values to basic Python types.
Expand Down Expand Up @@ -508,7 +588,7 @@ def __repr__(self, level=0):
for key, value in self.__dict__.items()
if key not in ("raise_on_missing_attr", "_raw_config")
]
return "\n#path: " + "\n".join(lines) + f"\n{indent}"
return "\n".join(lines) + f"\n{indent}"

def _repr_value(self, value, level):
"""
Expand Down
41 changes: 14 additions & 27 deletions nemo_automodel/recipes/base_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.optim import Optimizer
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers.processing_utils import ProcessorMixin
Expand All @@ -36,13 +37,6 @@
from nemo_automodel.components.training.rng import StatefulRNG
from nemo_automodel.components.training.step_scheduler import StepScheduler

try:
import yaml as _yaml
except Exception:
_yaml = None
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizerBase


def has_load_restore_state(object):
"""
Expand Down Expand Up @@ -384,32 +378,25 @@ def _log_experiment_details(self):
and getattr(self.cfg.model, "pretrained_model_name_or_path", None),
}
try:
if _yaml is not None:
details_yaml = _yaml.safe_dump(details, sort_keys=False, default_flow_style=False).strip()
else:
details_yaml = "\n".join(f"{k}: {v}" for k, v in details.items())
list(map(logging.info, ("Experiment_details:\n" + details_yaml).splitlines()))
details_yaml = yaml.safe_dump(details, sort_keys=False, default_flow_style=False).strip()
for line in ("Experiment_details:\n" + details_yaml).splitlines():
logging.info(line)
except Exception:
logging.info(f"Experiment details: {details}")
# Resolved config
try:
cfg_obj = getattr(self, "cfg", None)
cfg_dict = (
cfg_obj.to_dict() if hasattr(cfg_obj, "to_dict") else (dict(cfg_obj) if cfg_obj is not None else {})
)
# Prefer YAML-ready dict that converts callables/classes to dotted paths and preserves typed scalars
if hasattr(cfg_obj, "to_yaml_dict"):
cfg_dict = cfg_obj.to_yaml_dict()
elif hasattr(cfg_obj, "to_dict"):
cfg_dict = cfg_obj.to_dict()
else:
cfg_dict = dict(cfg_obj) if cfg_obj is not None else {}

def rec_print(log_fn, cfg_dict: dict | None, indent: int = 2):
if cfg_dict is None:
return
for k, v in cfg_dict.items():
if isinstance(v, dict):
log_fn(f"{' ' * indent}{k}:")
rec_print(log_fn, v, indent + 2)
else:
log_fn(f"{' ' * indent}{k}: {v}")

logging.info("Recipe config:")
rec_print(logging.info, cfg_dict)
# Print as clean YAML on stdout for easy copy/paste and readability
cfg_yaml = yaml.safe_dump(cfg_dict, sort_keys=False, default_flow_style=False).strip()
print(cfg_yaml, flush=True)
except Exception:
logging.info("Recipe config: <unavailable>")

Expand Down
Loading