[tx] Per-layer gradient checkpointing with stacked decoder layers #996
Conversation
Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor. Key changes:
- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
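For illustration, a minimal JAX sketch of the chunking idea this commit describes, assuming hidden states flattened to [N, H] with N padded to a multiple of chunk_size; the function name and signature are illustrative, not the repo's actual API.

import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, targets, chunk_size=1024):
    """hidden: [N, H], lm_head_w: [H, V], targets: [N]; N assumed to be a multiple of chunk_size."""
    n, h = hidden.shape
    hidden_chunks = hidden.reshape(n // chunk_size, chunk_size, h)
    target_chunks = targets.reshape(n // chunk_size, chunk_size)

    def per_chunk(args):
        hc, tc = args
        logits = hc @ lm_head_w                        # only [chunk_size, V] is materialized
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, tc[:, None], axis=-1)[:, 0]

    # jax.lax.map applies per_chunk sequentially over the leading chunk axis
    return jax.lax.map(per_chunk, (hidden_chunks, target_chunks)).reshape(n)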
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head. Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's activations are held at a time during backward pass, reducing peak memory by ~num_layers factor.
- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
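A bare-bones sketch of per-layer checkpointing with a plain Python loop, as this commit describes; layer_fn and the parameter layout are placeholders, not the actual model code.

import jax

def forward_layers(layer_params_list, hidden, layer_fn, is_training=False):
    # layer_fn(params, hidden) -> hidden is a single decoder layer as a pure function
    apply_layer = jax.checkpoint(layer_fn) if is_training else layer_fn
    for params in layer_params_list:
        # jax.checkpoint saves only the layer inputs; activations inside the layer
        # are recomputed during the backward pass
        hidden = apply_layer(params, hidden)
    return hidden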
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing, activations are stored anyway, so fori_loop's buffer reuse doesn't help and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
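A sketch of the fori_loop variant under assumed names: weights are stacked along a leading num_layers axis so a single compiled body is reused for every layer. This is not the exact _forward_layers_checkpointed implementation.

import jax

def forward_layers_checkpointed(stacked_params, hidden, layer_fn):
    # stacked_params: pytree whose leaves all have a leading (num_layers, ...) axis
    num_layers = jax.tree_util.tree_leaves(stacked_params)[0].shape[0]
    ckpt_layer = jax.checkpoint(layer_fn)

    def body(i, h):
        # slice out layer i's weights from every stacked leaf
        params_i = jax.tree_util.tree_map(lambda p: p[i], stacked_params)
        return ckpt_layer(params_i, h)

    # the loop body is compiled once; the static trip count keeps it differentiable
    return jax.lax.fori_loop(0, num_layers, body, hidden)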
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError. Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <[email protected]>
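A rough sketch of the wrapper layering described in the preceding commits (using the __init__(lm_head) form from a later commit); the method names come from the commit messages, but the bodies and signatures here are simplified assumptions, not the repo's actual classes.

import jax

class LogitsProcessor:
    """Standalone utility; internal detail behind CausalLMBase in this sketch."""

    @staticmethod
    def compute_logits(hidden, lm_head_w):
        return hidden @ lm_head_w

    @staticmethod
    def logits_to_logprobs(logits):
        return jax.nn.log_softmax(logits, axis=-1)

    @staticmethod
    def compute_logprobs(hidden, lm_head_w):
        return LogitsProcessor.logits_to_logprobs(LogitsProcessor.compute_logits(hidden, lm_head_w))

class CausalLMBase:
    def __init__(self, lm_head):
        self.lm_head = lm_head  # subclasses explicitly call CausalLMBase.__init__(self, lm_head)

    def compute_logits(self, hidden):
        return LogitsProcessor.compute_logits(hidden, self.lm_head)

    def compute_logprobs(self, hidden):
        return LogitsProcessor.compute_logprobs(hidden, self.lm_head)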
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <[email protected]>
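A hedged sketch of the runtime dispatch this commit describes; train_unembed_mask, adapter_ids, and the two loss callables are stand-ins for the backend's real state, and only the jax.lax.cond pattern is the point.

import jax
import jax.numpy as jnp

def pick_loss_path(hidden, lm_head_w, targets, train_unembed_mask, adapter_ids,
                   non_chunked_loss, chunked_loss):
    # True if any adapter used in this batch applies LoRA to lm_head
    needs_lm_head_lora = jnp.any(train_unembed_mask[adapter_ids])
    # Both branches must return the same shape/dtype (a scalar loss here)
    return jax.lax.cond(
        needs_lm_head_lora,
        lambda: non_chunked_loss(hidden, lm_head_w, targets),  # LoRA-aware, full logits
        lambda: chunked_loss(hidden, lm_head_w, targets),      # memory-efficient chunked path
    )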
- Replace abstract property with __init__(lm_head) in base class - Subclasses explicitly call CausalLMBase.__init__(self, lm_head) - Fix test to support multiple adapters for mixed train_unembed test Co-Authored-By: Claude Opus 4.5 <[email protected]>
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Address Gemini code review feedback: KVCache uses list[jax.Array] format, not stacked arrays. Update tests to:
- Use list format in DummyModel initialization
- Check list length instead of array shape
- Update comments to reflect list format

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
/gemini review
Code Review
This PR introduces a significant and well-executed architectural change to enable per-layer gradient checkpointing. The introduction of StackedDecoderLayers and MultiStackedDecoderLayers is a clean way to handle both homogeneous and heterogeneous model architectures while enabling efficient jax.lax.scan operations. The ArrayRef write-through view is a clever solution for transparently loading per-layer checkpoints into the new stacked format. The refactoring of model implementations and test utilities is thorough. The new tests for gradient checkpointing, which verify not only outputs but also gradients, are excellent and provide strong confidence in the correctness of this change. I have one suggestion regarding code structure to resolve a circular dependency, but overall this is a high-quality contribution that should significantly improve training memory efficiency.
    filter_fn: Callable[[tuple], bool] | None = None,
) -> None:
    """Load safetensors weights into a model with stacked layers."""
    from tx.layers.stacked import unstack_state
The local import of unstack_state here is a workaround for a circular dependency: utils.models -> layers.stacked -> utils.generator -> utils.models. This cycle is caused by utils.generator importing utils.models solely for the round_up_seq_len function.
To improve the code structure and break this cycle, consider moving round_up_seq_len to a more fundamental utility module (e.g., a new file like tx/utils/padding.py or tx/utils/misc.py). Both generator.py and models.py could then import it from the new location. This would eliminate the circular dependency and allow unstack_state to be imported at the top level of this file, leading to a cleaner and more maintainable module structure.
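If this suggestion is taken, the new module could be as small as the following; the default multiple and the exact signature of round_up_seq_len are guesses, since the real function lives in tx/utils/models.py.

# hypothetical tx/utils/padding.py
def round_up_seq_len(seq_len: int, multiple: int = 128) -> int:
    """Round seq_len up to the next multiple so both generator.py and models.py can import it here."""
    return -(-seq_len // multiple) * multiple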
This is actually a good comment, we can also do it in a separate PR
skyrl-tx/tx/layers/stacked.py
Outdated
def __getitem__(self, key):
    parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
    return parent[idx] if key is Ellipsis else parent[idx][key]
I don't think we need a special case for the Ellipsis here and can just do
return parent[idx][key]
right? Since for key = Ellipsis that's equivalent to return parent[idx].
skyrl-tx/tx/layers/stacked.py
Outdated
def __setitem__(self, key, value):
    """Write through to parent when value is set via indexing."""
    parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
    if key is Ellipsis:
Similar to above, I don't think we need two different cases here
this actually only supports the ... key. See the latest comment.
skyrl-tx/tx/layers/stacked.py
Outdated
for group in self.layer_groups:
    yield from group

def get_stacked_layers_list(self) -> list[StackedDecoderLayers]:
I believe this is not used any more and can be removed
skyrl-tx/tx/models/deepseekv3.py
Outdated
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)

def get_stacked_layers_list(self):
I believe this is not used any more and can be removed
skyrl-tx/tx/utils/models.py
Outdated
assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}"
idx = get_adapter_idx(path, adapter_index)
if key == "lora_A":
    return p[idx + (..., slice(None, rank))]
It might be a little more readable to write this as p[*idx, ..., :, :rank] and similar for the others and also below
skyrl-tx/tx/models/deepseekv3.py
Outdated
self.config = config
self.num_dense_layers = config.first_k_dense_replace
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_layers = config.num_hidden_layers
self.num_layers isn't actually used, right? Same for the other models
    is_leaf=lambda x: isinstance(x, PartitionSpec),
)

def make_ep_spec(spec, value):
I find this function pretty confusing. Clearer would be: rename is_stacked_lora_path in the other file to is_stacked_path, replace jax.tree.map with jax.tree.map_with_path, and then write this function like:

def make_ep_spec(path, s):
    if not isinstance(s, PartitionSpec):
        return s
    # Strip leading stacking dimension if path is stacked
    dims = s[1:] if is_stacked_path(path) else s
    # Extract only 'ep' dims from PartitionSpecs, replacing others with None
    return PartitionSpec(*(p if p == "ep" else None for p in dims))

state_specs = jax.tree_util.tree_map_with_path(
    make_ep_spec, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec)
)
I tried the is_stacked_path + tree_map_with_path approach but it doesn't work for shard_map_ep. The issue is that shard_map_ep receives a sub-module (e.g. self.experts) that was extracted from a stacked layer during scan. The paths in that sub-module's state don't contain _stacked — that's the parent module's attribute name. So is_stacked_path returns False and the leading stacking dim isn't stripped, causing a spec/tensor rank mismatch.
skyrl-tx/tests/models/test_qwen3.py
Outdated
jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank)

def load_stacked_lora_weights(
This isn't needed any more with the automatic view of stacked layers as unstacked
You can revert most of the changes to the test (except for the new tests and factoring out the common methods) -- see #1018 where the tests just pass with the existing code by using the stacked abstraction
- Simplify ArrayRef __getitem__/__setitem__ by removing Ellipsis special cases
- Remove unreachable IndexError in MultiStackedDecoderLayers.__getitem__
- Remove unused get_stacked_layers_list from DeepseekV3Model and MultiStackedDecoderLayers
- Remove unused self.num_layers from DeepseekV3Model, Llama3Model, Qwen3Model
- Use extended unpacking syntax (p[*idx, ..., :rank]) for readability
- Rewrite make_ep_spec to use is_stacked_path + tree_map_with_path
- Rename is_stacked_lora_path to is_stacked_path (not LoRA-specific)
- Remove load_stacked_lora_weights in test_qwen3, use unstacked view instead

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Use model.model.layers[0] instead of accessing stacked format directly. ArrayRef provides transparent unstacked views, so tests don't need explicit stacked indexing. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Instead of extracting a GraphState and passing it to save/load_safetensors (which expect an nnx.Module for unstack_state), pass the full model with adapter_index/rank params to slice LoRA weights inline. Extract shared adapter slicing logic into get_lora_adapter_slice helper. Co-Authored-By: Claude Opus 4.6 <[email protected]>
/gemini review
Code Review
This is an excellent pull request that introduces a significant architectural improvement for memory efficiency during training. The implementation of per-layer gradient checkpointing through stacked layer weights is well-executed and thoughtfully designed.
Key highlights:
- Stacked Layer Architecture: The StackedDecoderLayers and MultiStackedDecoderLayers provide a clean and efficient way to manage layer parameters, enabling the use of jax.lax.scan for performance. The handling of heterogeneous layers in MultiStackedDecoderLayers is also a great addition.
- Transparent Checkpointing: The ArrayRef class is a clever solution for providing a write-through, unstacked view of the parameters, which greatly simplifies checkpoint loading and saving logic.
- Performance Optimizations: The choice to use a Python loop for decoding to enable buffer donation for the KV cache shows great attention to performance details.
- Test Coverage: The PR includes comprehensive tests for the new functionality, including correctness checks for gradient checkpointing across different models, and unit tests for the new utility functions. The refactoring of existing tests into lora_test_utils.py also improves code quality.
- Code Cleanup: The removal of a duplicated line in tx/utils/generator.py is a nice small cleanup.
The code is well-structured, clearly commented, and the changes are easy to follow despite the complexity of the refactoring. I've found one potential issue in save_safetensors which I've commented on. Overall, this is a high-quality contribution.
_, lora_params, _ = nnx.split(model, model.is_lora_param, ...)

adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank)
we used to call extract_adapter_state and pass the result to save_safetensors.
This is broken because extract_adapter_state returns a GraphState rather than a Module, which doesn't support the ArrayRef write-through.
The lora_A/lora_B guard in filter_lora caused init_lora_adapter to set effective_rank=0 for lora_ranks and lora_scaling paths (which don't contain lora_A/lora_B). Move the guard to the save/load call sites. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Stacked layers use scan which changes LoRA initialization order, producing slightly different he_uniform values. With bf16 precision, 1e-4 learning rate needs >10 steps to register a loss change. Bump to 1e-3 for reliable convergence within 10 steps. Co-Authored-By: Claude Opus 4.6 <[email protected]>
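A quick numerical check of the bf16 claim above, purely illustrative: with only 8 significand bits, values of order 1 have a resolution around 0.01, so per-step loss changes on the order of 1e-4 are rounded away.

import jax.numpy as jnp

# bf16 keeps 8 significand bits (machine epsilon ~= 0.0078), so adding a 1e-4-sized
# change to a loss of order 1 does not alter the stored value.
loss = jnp.asarray(2.345, dtype=jnp.bfloat16)
step = jnp.asarray(1e-4, dtype=jnp.bfloat16)
print(loss + step == loss)   # True: the change is below bf16 resolution at this magnitude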
init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1))

- optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param)
+ optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=model.is_lora_param)
Stacked layers use scan which changes LoRA initialization order,
producing slightly different he_uniform values. With bf16 precision,
1e-4 learning rate needs >10 steps to register a loss change.
Bump to 1e-3 for reliable convergence within 10 steps.
closing in favor of #1083 and stacked PRs
This PR implements per-layer gradient checkpointing to reduce training memory, enabled by a stacked layer weights architecture that allows using jax.lax.scan with jax.checkpoint over transformer layers.

Key Changes

1. Per-Layer Gradient Checkpointing
Checkpoint each transformer layer individually using jax.lax.scan with jax.checkpoint. Memory grows O(1) with layer count instead of O(n) - only one layer's activations are held in memory during the backward pass.

2. Stacked Layer Weights
Store all layer parameters as stacked arrays (num_layers, ...) instead of separate layer modules. This enables scan over uniform data structures (see the sketch after this list).
- StackedDecoderLayers: Manages stacked parameters with scan-based forward pass
- MultiStackedDecoderLayers: Supports heterogeneous architectures (e.g., DeepSeek with dense + MoE layer groups)
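Below is a simplified sketch, under assumed names, of how the stacked (num_layers, ...) layout combines jax.lax.scan with jax.checkpoint; the real StackedDecoderLayers handles nnx modules, KV caches, and LoRA state that are omitted here.

import jax

def forward_stacked(stacked_params, hidden, layer_fn, gradient_checkpointing=True):
    # stacked_params: pytree whose leaves all have a leading (num_layers, ...) axis
    body = jax.checkpoint(layer_fn) if gradient_checkpointing else layer_fn

    def step(h, params_i):
        # params_i is layer i's slice of every stacked leaf, provided by scan
        return body(params_i, h), None

    # one compiled layer body; only layer inputs are saved, activations are recomputed
    hidden, _ = jax.lax.scan(step, hidden, stacked_params)
    return hidden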
3. Checkpoint Loading
unstack_state() transforms stacked format to per-layer view using ArrayRef write-through variables for transparent checkpoint loading.
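A stripped-down illustration of the write-through idea (not the nnx-based ArrayRef in the PR): reads index into the stacked parent, and ref[...] = value writes the slice back. The real implementation stores the parent and index as variable metadata and updates nnx state; this sketch rebinds a functionally updated array instead.

import jax.numpy as jnp

class ArrayRefSketch:
    def __init__(self, parent, idx):
        self.parent, self.idx = parent, idx      # parent: stacked (num_layers, ...) array

    def __getitem__(self, key):
        return self.parent[self.idx][key]        # unstacked view of layer idx

    def __setitem__(self, key, value):
        assert key is Ellipsis                   # as in the PR discussion, only `ref[...] = x` is supported
        # JAX arrays are immutable, so write-through here rebinds an updated parent
        self.parent = self.parent.at[self.idx].set(value)

stacked = jnp.zeros((4, 3))                      # 4 "layers" of shape (3,)
ref = ArrayRefSketch(stacked, idx=1)
ref[...] = jnp.ones(3)                           # writes through to row 1 of the stacked array
print(ref.parent[1])                             # [1. 1. 1.]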
Architecture Support
Table: which models use StackedDecoderLayers vs. MultiStackedDecoderLayers (contents not preserved in this extract).

Benchmark