Commit c5c6496

feat(trainer): Add CustomTrainerContainer to create TrainJobs from image (#127)
* feat(trainer): Add CustomTrainerContainer to create TrainJobs from image

  Signed-off-by: Andrey Velichkevich <[email protected]>

* Update kubeflow/trainer/types/types.py

  Co-authored-by: Anya Kramar <[email protected]>
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Update train docstring

  Signed-off-by: Andrey Velichkevich <[email protected]>

---------

Signed-off-by: Andrey Velichkevich <[email protected]>
Co-authored-by: Anya Kramar <[email protected]>
1 parent 7d682fd commit c5c6496

8 files changed: +166 -31 lines changed

AGENTS.md

Lines changed: 57 additions & 12 deletions
@@ -7,18 +7,19 @@
 ## Agent Behavior Policy

 AI agents should:
+
 - Make atomic, minimal, and reversible changes.
 - Prefer local analysis (`uv run`, `make verify`, `pytest`) before proposing commits.
 - NEVER modify configuration, CI/CD, or release automation unless explicitly requested.
 - Avoid non-deterministic code or random seeds without fixtures.
 - Use `AGENTS.md` and `Makefile` as the source of truth for development commands.

 Agents must NOT:
+
 - Bypass tests or linters
 - Introduce dependencies without updating `pyproject.toml`
 - Generate or commit large autogenerated files

-
 ### Context Awareness

 Before writing code, agents should:
@@ -27,7 +28,6 @@ Before writing code, agents should:
 - Match import patterns from neighboring files
 - Preserve existing logging and error-handling conventionso

-
 ## Repository Map

 ```
@@ -53,17 +53,21 @@ Root files: AGENTS.md, README.md, pyproject.toml, Makefile, CI workflows
 ## Quick Start

 <!-- BEGIN: AGENT_COMMANDS -->
+
 **Setup**:
+
 ```bash
 make install-dev # Install uv, create .venv, sync deps
 ```

 **Verify (CI parity)**:
+
 ```bash
 make verify # Runs ruff check --show-fixes and ruff format --check
 ```

 **Testing**:
+
 ```bash
 make test-python # All unit tests + coverage (HTML by default)
 make test-python report=xml # XML coverage report
@@ -73,38 +77,45 @@ uv run coverage run -m pytest <path> && uv run coverage report # Ad-hoc
 ```

 **Local lint/format**:
+
 ```bash
 uv run ruff check --fix . # Fix lint issues
 uv run ruff format kubeflow # Format code
 ```

 **Type checking**:
+
 ```bash
 uv run mypy kubeflow # Run type checker
 ```

 **Pre-commit**:
+
 ```bash
 uv run pre-commit install # Install hooks
 uv run pre-commit run --all-files # Run all hooks
 ```
+
 <!-- END: AGENT_COMMANDS -->

 ## Development Workflow for AI Agents

 **Preferred commands**: use `uv run ...` to ensure tool consistency and `.venv` usage

 **Before making changes**:
+
 1. Read existing code patterns and docstrings for alignment
 2. Follow the Core Development Principles below
 3. Run validation commands before proposing changes

 **Validation before proposing changes**:
+
 - Lint/format: `make verify`
 - Tests: `make test-python` or targeted `pytest` invocations
 - Type checking: `uv run mypy kubeflow` (if available)

 **Commit/PR hygiene**:
+
 - Follow Conventional Commits in titles and messages
 - Include rationale ("why") in commit messages/PR descriptions
 - Do not push secrets or change git config
@@ -117,19 +128,22 @@ uv run pre-commit run --all-files # Run all hooks
 **Always attempt to preserve function signatures, argument positions, and names for exported/public methods.**

 **Bad - Breaking Change:**
+
 ```python
 def train_model(id, verbose=False): # Changed from `model_id`
     pass
 ```

 **Good - Stable Interface:**
+
 ```python
 def train_model(model_id: str, verbose: bool = False) -> TrainingResult:
     """Train model with optional verbose output."""
     pass
 ```

 **Before making ANY changes to public APIs:**
+
 - Check if the function/class is exported in `__init__.py`
 - Look for existing usage patterns in tests and examples
 - Use keyword-only arguments for new parameters: `*, new_param: str = "default"`
@@ -140,27 +154,30 @@ def train_model(model_id: str, verbose: bool = False) -> TrainingResult:
 **All Python code MUST include type hints and return types.**

 **Bad:**
+
 ```python
 def p(u, d):
     return [x for x in u if x not in d]
 ```

 **Good:**
+
 ```python
 def filter_completed_jobs(jobs: list[str], completed: set[str]) -> list[str]:
     """Filter out jobs that are already completed.
-    
+
     Args:
         jobs: List of job identifiers to filter.
         completed: Set of completed job identifiers.
-    
+
     Returns:
         List of jobs that are not yet completed.
     """
     return [job for job in jobs if job not in completed]
 ```

 **Style Requirements:**
+
 - Line length 100, Python 3.9 target, double quotes, spaces indent
 - Imports: isort via ruff; first-party is `kubeflow`; prefer absolute imports
 - Naming: pep8-naming; functions/vars `snake_case`, classes `PascalCase`, constants `UPPER_SNAKE_CASE`; prefix private with `_`
@@ -173,18 +190,21 @@ def filter_completed_jobs(jobs: list[str], completed: set[str]) -> list[str]:
 **Every new feature or bugfix MUST be covered by unit tests.**

 **Test Organization:**
+
 - Unit tests: `kubeflow/trainer/**/*_test.py` (no network calls allowed)
 - Use `pytest` as the testing framework
 - See `kubeflow/trainer/test/common.py` for fixtures and patterns
 - Unit test structure must be consistent between each other (see `kubeflow/trainer/backends/kubernetes/backend_test.py` for reference)

 **Test Structure Pattern** (following `backend_test.py`):
+
 - Use `TestCase` dataclass for parametrized tests
 - Include `name`, `expected_status`, `config`, `expected_output/error` fields
 - Print test execution status for debugging
 - Handle both success and exception cases in the same test function

 **Test Quality Checklist:**
+
 - [ ] Tests fail when your new logic is broken
 - [ ] Happy path is covered
 - [ ] Edge cases and error conditions are tested
@@ -194,19 +214,21 @@ def filter_completed_jobs(jobs: list[str], completed: set[str]) -> list[str]:
 **Test Examples:**

 Simple test:
+
 ```python
 def test_filter_completed_jobs():
     """Test filtering completed jobs from a list."""
     jobs = ["job-1", "job-2", "job-3"]
     completed = {"job-1", "job-2"}
-    
+
     result = filter_completed_jobs(jobs, completed)
-    
+
     assert result == ["job-3"]
     assert len(result) == 1
 ```

 Parametrized test cases (preferred for multiple scenarios):
+
 ```python
 @pytest.mark.parametrize(
     "test_case",
@@ -234,20 +256,23 @@ def test_filter_jobs_parametrized(test_case):
 ### 4. Security and Risk Assessment

 **Security Checklist:**
+
 - [ ] No `eval()`, `exec()`, or `pickle` on user-controlled input
 - [ ] Proper exception handling (no bare `except:`) and use descriptive error messages
 - [ ] Remove unreachable/commented code before committing
 - [ ] Ensure proper resource cleanup (file handles, connections)
 - [ ] No secrets in code, logs, or examples

 **Bad:**
+
 ```python
 def load_config(path):
     with open(path) as f:
         return eval(f.read()) # ⚠️ Never eval user input
 ```

 **Good:**
+
 ```python
 import yaml

@@ -262,31 +287,34 @@ def load_config(path: str) -> dict:
 **Use Google-style docstrings with Args section for all public functions.**

 **Insufficient Documentation:**
+
 ```python
 def submit_job(name, config):
     """Submit a job."""
 ```

 **Complete Documentation:**
+
 ```python
 def submit_job(name: str, config: dict, *, priority: str = "normal") -> str:
     """Submit a training job with specified configuration.
-    
+
     Args:
         name: The job name identifier.
         config: Job configuration dictionary.
         priority: Job priority level ('low', 'normal', 'high').
-    
+
     Returns:
         Job ID string for tracking the submitted job.
-    
+
     Raises:
         InvalidConfigError: If the configuration is invalid.
         ResourceUnavailableError: If required resources are not available.
     """
 ```

 **Documentation Guidelines:**
+
 - Types go in function signatures, NOT in docstrings
 - Focus on "why" rather than "what" in descriptions
 - Document all parameters, return values, and exceptions
@@ -298,6 +326,7 @@ def submit_job(name: str, config: dict, *, priority: str = "normal") -> str:
 **When you encounter code that could be improved, suggest better designs:**

 **Poor Design:**
+
 ```python
 def process_training(data, k8s_client, storage, logger):
     # Function doing too many things
@@ -309,21 +338,22 @@ def process_training(data, k8s_client, storage, logger):
 ```

 **Better Design:**
+
 ```python
 @dataclass
 class TrainingJobResult:
     """Result of training job submission."""
     job_id: str
     status: str
     created_at: datetime
-    
+
 class TrainingJobManager:
     """Handles training job lifecycle operations."""
-    
+
     def __init__(self, k8s_client: KubernetesClient, storage: Storage):
         self.k8s = k8s_client
         self.storage = storage
-    
+
     def submit_job(self, config: TrainingConfig) -> TrainingJobResult:
         """Submit and track a new training job."""
         validated_config = self._validate_config(config)
@@ -343,27 +373,39 @@ class TrainingJobManager:
 **Trainer Types**:

 **CustomTrainer** (`kubeflow.trainer.types.CustomTrainer`):
+
 - **Purpose**: For custom, self-contained training functions that you write yourself
 - **Flexibility**: Complete control over the training process
 - **Use case**: "Bring your own training code" - maximum flexibility
 - **Key attributes**: `func` (your training function), `func_args`, `packages_to_install`, `pip_index_urls`, `num_nodes`, `resources_per_node`, `env`

+**CustomTrainerContainer** (`kubeflow.trainer.types.CustomTrainerContainer`):
+
+- **Purpose**: For custom, self-contained container image that you create yourself
+- **Flexibility**: Complete control over the training process
+- **Use case**: "Bring your own training image" - maximum flexibility
+- **Key attributes**: `num_nodes`, `resources_per_node`, `env`
+
 **BuiltinTrainer** (`kubeflow.trainer.types.BuiltinTrainer`):
+
 - **Purpose**: For pre-built training frameworks with existing fine-tuning logic
 - **Convenience**: Just configure parameters, training logic is already implemented
 - **Use case**: "Use our pre-built trainers" - convenience for common scenarios
 - **Key attributes**: `config` (currently only supports `TorchTuneConfig` for LLM fine-tuning with TorchTune)

 **Backends**:
+
 - `localprocess`: local execution for fast iteration
 - `kubernetes`: K8s-backed jobs, see `backends/kubernetes`

 **Typical flow**:
+
 1. Get runtime, define trainer, submit with `TrainerClient().train(...)`
 2. `wait_for_job_status(...)` then fetch logs with `get_job_logs(...)`
 3. For full example, see README "Run your first PyTorch distributed job"

 **Integration patterns**:
+
 - Follow existing patterns in `kubeflow.trainer.backends` for new backends
 - Use `kubeflow.trainer.types` for data models and type definitions
 - Implement proper error handling and resource cleanup
@@ -372,6 +414,7 @@ class TrainingJobManager:
 ## CI & PRs

 **PR Requirements**:
+
 - Title must follow Conventional Commits:
   - Types: `chore`, `fix`, `feat`, `revert`
   - Scopes: `ci`, `docs`, `examples`, `scripts`, `test`, `trainer`
@@ -381,9 +424,11 @@ class TrainingJobManager:
 ## Releasing

 **Version management**:
+
 ```bash
 make release VERSION=X.Y.Z # Updates kubeflow/__init__.py and generates changelog
 ```
+
 - Do not commit secrets; verify coverage and lint pass before tagging

 ## Troubleshooting

kubeflow/trainer/api/trainer_client.py

Lines changed: 8 additions & 3 deletions
@@ -95,21 +95,26 @@ def train(
         self,
         runtime: Optional[types.Runtime] = None,
         initializer: Optional[types.Initializer] = None,
-        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
+        trainer: Optional[
+            Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
+        ] = None,
     ) -> str:
         """Create a TrainJob. You can configure the TrainJob using one of these trainers:

         - CustomTrainer: Runs training with a user-defined function that fully encapsulates the
             training process.
+        - CustomTrainerContainer: Runs training with a user-defined image that fully encapsulates
+            the training process.
         - BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring
             only parameter configuration.

         Args:
             runtime: Optional reference to one of the existing runtimes. Defaults to the
                 torch-distributed runtime if not provided.
             initializer: Optional configuration for the dataset and model initializers.
-            trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
-                the TrainJob will use the runtime's default values.
+            trainer: Optional configuration for a CustomTrainer, CustomTrainerContainer, or
+                BuiltinTrainer. If not specified, the TrainJob will use the
+                runtime's default values.

         Returns:
             The unique name of the TrainJob that has been generated.

kubeflow/trainer/backends/base.py

Lines changed: 3 additions & 1 deletion
@@ -38,7 +38,9 @@ def train(
         self,
         runtime: Optional[types.Runtime] = None,
         initializer: Optional[types.Initializer] = None,
-        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
+        trainer: Optional[
+            Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
+        ] = None,
     ) -> str:
         raise NotImplementedError()
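
To make the widened `trainer` annotation concrete, here is a minimal, self-contained sketch of how a backend might branch on the three trainer kinds. The dataclasses are hypothetical stand-ins, not the real `kubeflow.trainer.types` classes. Only `func`, `num_nodes`, `env`, and `config` are taken from the attribute lists in the AGENTS.md diff; the `image` field and the `describe_trainer` helper are assumptions made purely for illustration and are not shown anywhere in this commit view.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Union


# Hypothetical stand-ins for the classes in kubeflow.trainer.types.
@dataclass
class CustomTrainer:
    func: Callable[..., None]
    num_nodes: Optional[int] = None


@dataclass
class CustomTrainerContainer:
    image: str  # Assumed field name; the diff does not show this class body.
    num_nodes: Optional[int] = None
    env: Optional[dict[str, str]] = None


@dataclass
class BuiltinTrainer:
    config: object


# Mirrors the Union used in the updated train() signatures.
Trainer = Union[CustomTrainer, CustomTrainerContainer, BuiltinTrainer]


def describe_trainer(trainer: Optional[Trainer] = None) -> str:
    """Return a short description of what would drive the training job."""
    if trainer is None:
        # No trainer given: fall back to whatever the runtime defines.
        return "runtime defaults"
    if isinstance(trainer, CustomTrainer):
        return f"function {trainer.func.__name__}"
    if isinstance(trainer, CustomTrainerContainer):
        return f"image {trainer.image}"
    if isinstance(trainer, BuiltinTrainer):
        return f"builtin config {type(trainer.config).__name__}"
    raise ValueError(f"Unsupported trainer type: {type(trainer).__name__}")
```

Branching with `isinstance` keeps each trainer kind explicit, and the final `raise` means an unexpected type fails loudly instead of silently falling through to the runtime defaults.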
