Conversation

@raulchen
Contributor

@raulchen raulchen commented Jan 30, 2026

This PR implements per-layer gradient checkpointing to reduce training memory, enabled by a stacked layer weights architecture that allows using jax.lax.scan with jax.checkpoint over transformer layers.

Key Changes

1. Per-Layer Gradient Checkpointing

Checkpoint each transformer layer individually using jax.lax.scan with jax.checkpoint. Activation memory is O(1) in the number of layers instead of O(n): only one layer's activations are held in memory at a time during the backward pass.
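A minimal sketch of the mechanism (a toy layer body stands in for the real attention/MLP block; this is not the PR's StackedDecoderLayers code):

```python
import jax
import jax.numpy as jnp

def layer_fn(hidden, layer_params):
    # Toy layer body; the real model applies attention + MLP here.
    return jnp.tanh(hidden @ layer_params["w"]), None

def forward(stacked_params, hidden):
    # jax.checkpoint makes the backward pass recompute each layer's
    # activations instead of storing them for every layer.
    body = jax.checkpoint(layer_fn)
    hidden, _ = jax.lax.scan(body, hidden, stacked_params)
    return hidden

num_layers, d = 4, 8
stacked = {"w": 0.1 * jnp.ones((num_layers, d, d))}  # leading layer axis
x = jnp.ones((2, d))
grads = jax.grad(lambda p: forward(p, x).sum())(stacked)
```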

2. Stacked Layer Weights

Store all layer parameters as stacked arrays with a leading (num_layers, ...) dimension instead of separate per-layer modules. This gives scan a uniform data structure to iterate over (see the sketch after the list below).

  • StackedDecoderLayers: Manages stacked parameters with a scan-based forward pass
  • MultiStackedDecoderLayers: Supports heterogeneous architectures (e.g., DeepSeek with dense + MoE layer groups)
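A minimal sketch of the stacking step itself, with an assumed helper name; the real classes additionally handle sharding, LoRA parameters, and heterogeneous layer groups:

```python
import jax
import jax.numpy as jnp

def stack_layer_params(per_layer_params):
    """Stack a list of identical per-layer pytrees into (num_layers, ...) arrays."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_layer_params)

layers = [{"w": jnp.full((8, 8), 0.1 * i)} for i in range(4)]
stacked = stack_layer_params(layers)
print(stacked["w"].shape)  # (4, 8, 8): one leading num_layers dimension
```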

3. Checkpoint Loading

unstack_state() transforms the stacked format into a per-layer view using ArrayRef write-through variables for transparent checkpoint loading.
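An illustrative sketch of the write-through idea only; the PR's ArrayRef is an nnx variable, and the class and attribute names below are assumptions:

```python
import jax.numpy as jnp

class LayerSliceRef:
    """Per-layer view that reads from and writes through to a stacked parent array."""

    def __init__(self, parent, idx):
        self._parent = parent  # mutable container holding the stacked array
        self._idx = idx

    @property
    def value(self):
        return self._parent["array"][self._idx]

    @value.setter
    def value(self, new):
        # Writing the per-layer view updates the stacked parent in place.
        self._parent["array"] = self._parent["array"].at[self._idx].set(new)

stacked = {"array": jnp.zeros((4, 8, 8))}   # (num_layers, d, d)
layer0 = LayerSliceRef(stacked, 0)
layer0.value = jnp.ones((8, 8))             # e.g. a tensor loaded from a checkpoint
print(stacked["array"][0].sum())            # 64.0: the write went through
```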

Architecture Support

  • Homogeneous models (Llama3, Qwen3): Single StackedDecoderLayers
  • Heterogeneous models (DeepSeek): Multiple groups via MultiStackedDecoderLayers

Benchmark

  ┌────────────────────────┬───────────────┬─────────────┬──────────┬───────────────┐
  │          Mode          │    Branch     │ Peak Memory │ JIT Time │ Post-JIT Time │
  ├────────────────────────┼───────────────┼─────────────┼──────────┼───────────────┤
  │ Sample (bs=16, seq=4k) │ main          │ 6,737 MiB   │ 331.5s   │ 21.28s        │
  ├────────────────────────┼───────────────┼─────────────┼──────────┼───────────────┤
  │                        │ stack-weights │ 6,727 MiB   │ 406.2s   │ 24.18s        │
  ├────────────────────────┼───────────────┼─────────────┼──────────┼───────────────┤
  │ Train (bs=1, seq=4k)   │ main          │ 22,451 MiB  │ 331.1s   │ 0.85s         │
  ├────────────────────────┼───────────────┼─────────────┼──────────┼───────────────┤
  │                        │ stack-weights │ 6,699 MiB   │ 227.2s   │ 0.51s         │
  └────────────────────────┴───────────────┴─────────────┴──────────┴───────────────┘

raulchen and others added 30 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
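A rough sketch of the chunked log-prob computation described in the commit above; the shapes, names, and chunking details below are assumptions rather than the repo's actual API:

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head, targets, chunk_size=1024):
    """hidden: [N, D] flattened hidden states, lm_head: [D, V], targets: [N]."""
    n, d = hidden.shape
    assert n % chunk_size == 0, "sketch assumes N is a multiple of chunk_size"
    hidden = hidden.reshape(-1, chunk_size, d)
    targets = targets.reshape(-1, chunk_size)

    def one_chunk(args):
        h, t = args
        logits = h @ lm_head                              # only [chunk_size, V] live at once
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, t[:, None], axis=-1)[:, 0]

    return jax.lax.map(one_chunk, (hidden, targets)).reshape(-1)

lp = chunked_logprobs(jnp.ones((2048, 16)), 0.01 * jnp.ones((16, 1000)),
                      jnp.zeros((2048,), dtype=jnp.int32))   # shape (2048,)
```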
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's
activations are held at a time during backward pass, reducing peak
memory by ~num_layers factor.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles
ONE loop body and reuses buffers during backward recomputation. With a
Python loop, XLA unrolls N separate checkpoint regions and can't optimize
buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing,
activations are stored anyway, so fori_loop's buffer reuse doesn't help
and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <[email protected]>
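A tiny sketch of the jax.lax.cond dispatch described in the commit above; the names and the stand-in loss bodies are assumptions:

```python
import jax
import jax.numpy as jnp

def full_nll(h, w, t):
    lp = jax.nn.log_softmax(h @ w, axis=-1)      # full [N, V] logits (LoRA-on-lm_head path)
    return -jnp.take_along_axis(lp, t[:, None], axis=-1).mean()

def chunked_nll(h, w, t, chunk=512):
    h, t = h.reshape(-1, chunk, h.shape[-1]), t.reshape(-1, chunk)
    def body(args):
        hc, tc = args
        lp = jax.nn.log_softmax(hc @ w, axis=-1)  # only [chunk, V] at a time
        return -jnp.take_along_axis(lp, tc[:, None], axis=-1)[:, 0]
    return jax.lax.map(body, (h, t)).mean()

def loss(h, w, t, needs_unembed_lora):
    # Runtime choice: any adapter training the unembedding forces the full path.
    return jax.lax.cond(
        needs_unembed_lora,
        lambda ops: full_nll(*ops),
        lambda ops: chunked_nll(*ops),
        (h, w, t),
    )

out = loss(jnp.ones((1024, 16)), 0.01 * jnp.ones((16, 1000)),
           jnp.zeros((1024,), dtype=jnp.int32), jnp.array(False))
```

Both branches of jax.lax.cond must return the same shape and dtype, which is why the dispatch happens at the loss level rather than inside the chunked map.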
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Address Gemini code review feedback: KVCache uses list[jax.Array]
format, not stacked arrays. Update tests to:
- Use list format in DummyModel initialization
- Check list length instead of array shape
- Update comments to reflect list format

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
@raulchen
Contributor Author

raulchen commented Feb 6, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This PR introduces a significant and well-executed architectural change to enable per-layer gradient checkpointing. The introduction of StackedDecoderLayers and MultiStackedDecoderLayers is a clean way to handle both homogeneous and heterogeneous model architectures while enabling efficient jax.lax.scan operations. The ArrayRef write-through view is a clever solution for transparently loading per-layer checkpoints into the new stacked format. The refactoring of model implementations and test utilities is thorough. The new tests for gradient checkpointing, which verify not only outputs but also gradients, are excellent and provide strong confidence in the correctness of this change. I have one suggestion regarding code structure to resolve a circular dependency, but overall this is a high-quality contribution that should significantly improve training memory efficiency.

    filter_fn: Callable[[tuple], bool] | None = None,
) -> None:
    """Load safetensors weights into a model with stacked layers."""
    from tx.layers.stacked import unstack_state
Contributor

Severity: medium

The local import of unstack_state here is a workaround for a circular dependency: utils.models -> layers.stacked -> utils.generator -> utils.models. This cycle is caused by utils.generator importing utils.models solely for the round_up_seq_len function.

To improve the code structure and break this cycle, consider moving round_up_seq_len to a more fundamental utility module (e.g., a new file like tx/utils/padding.py or tx.utils.misc.py). Both generator.py and models.py could then import it from the new location. This would eliminate the circular dependency and allow unstack_state to be imported at the top level of this file, leading to a cleaner and more maintainable module structure.

Collaborator

This is actually a good comment; we can also do it in a separate PR.


def __getitem__(self, key):
    parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
    return parent[idx] if key is Ellipsis else parent[idx][key]
Collaborator

I don't think we need a special case for the Ellipsis here and can just do

return parent[idx][key]

right? Since for key = Ellipsis that's equivalent to return parent[idx].

def __setitem__(self, key, value):
    """Write through to parent when value is set via indexing."""
    parent, idx = self.get_metadata("_parent"), self.get_metadata("_idx")
    if key is Ellipsis:
Collaborator

Similar to above, I don't think we need two different cases here.

Contributor Author

this actually only supports the ... key. See the latest comment.

for group in self.layer_groups:
    yield from group

def get_stacked_layers_list(self) -> list[StackedDecoderLayers]:
Collaborator

I believe this is not used anymore and can be removed.


self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)

def get_stacked_layers_list(self):
Collaborator

I believe this is not used any more and can be removed

assert p.ndim in {3, 4, 5}, f"LoRA parameters must have 3-5 dimensions, got shape {p.shape}"
idx = get_adapter_idx(path, adapter_index)
if key == "lora_A":
    return p[idx + (..., slice(None, rank))]
Collaborator

It might be a little more readable to write this as p[*idx, ..., :, :rank] and similar for the others and also below

self.config = config
self.num_dense_layers = config.first_k_dense_replace
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_layers = config.num_hidden_layers
Collaborator

self.num_layers isn't actually used, right? Same for the other models

is_leaf=lambda x: isinstance(x, PartitionSpec),
)

def make_ep_spec(spec, value):
Collaborator

@pcmoritz pcmoritz Feb 6, 2026

I find this function pretty confusing. Clearer would be: rename is_stacked_lora_path in the other file to is_stacked_path, replace jax.tree.map with jax.tree.map_with_path, and then write this function like:

def make_ep_spec(path, s):
    if not isinstance(s, PartitionSpec):
        return s
    # Strip leading stacking dimension if path is stacked
    dims = s[1:] if is_stacked_path(path) else s
    # Extract only 'ep' dims from PartitionSpecs, replacing others with None
    return PartitionSpec(*(p if p == "ep" else None for p in dims))

state_specs = jax.tree_util.tree_map_with_path(
    make_ep_spec, nnx.get_partition_spec(state), is_leaf=lambda x: isinstance(x, PartitionSpec)
)

Contributor Author

I tried the is_stacked_path + tree_map_with_path approach but it doesn't work for shard_map_ep. The issue is that shard_map_ep receives a sub-module (e.g. self.experts) that was extracted from a stacked layer during scan. The paths in that sub-module's state don't contain _stacked — that's the parent module's attribute name. So is_stacked_path returns False and the leading stacking dim isn't stripped, causing a spec/tensor rank mismatch.

jax_module.lora_ranks[...] = jax_module.lora_ranks[...].at[adapter_idx].set(rank)


def load_stacked_lora_weights(
Collaborator

This isn't needed any more with the automatic view of stacked layers as unstacked

@pcmoritz
Collaborator

pcmoritz commented Feb 6, 2026

You can revert most of the changes to the test (except for the new tests and factoring out the common methods) -- see #1018 where the tests just pass with the existing code by using the stacked abstraction

raulchen and others added 8 commits February 6, 2026 10:03
- Simplify ArrayRef __getitem__/__setitem__ by removing Ellipsis special cases
- Remove unreachable IndexError in MultiStackedDecoderLayers.__getitem__
- Remove unused get_stacked_layers_list from DeepseekV3Model and MultiStackedDecoderLayers
- Remove unused self.num_layers from DeepseekV3Model, Llama3Model, Qwen3Model
- Use extended unpacking syntax (p[*idx, ..., :rank]) for readability
- Rewrite make_ep_spec to use is_stacked_path + tree_map_with_path
- Rename is_stacked_lora_path to is_stacked_path (not LoRA-specific)
- Remove load_stacked_lora_weights in test_qwen3, use unstacked view instead

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Use model.model.layers[0] instead of accessing stacked format directly.
ArrayRef provides transparent unstacked views, so tests don't need
explicit stacked indexing.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Instead of extracting a GraphState and passing it to save/load_safetensors
(which expect an nnx.Module for unstack_state), pass the full model with
adapter_index/rank params to slice LoRA weights inline. Extract shared
adapter slicing logic into get_lora_adapter_slice helper.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@raulchen
Contributor Author

raulchen commented Feb 6, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This is an excellent pull request that introduces a significant architectural improvement for memory efficiency during training. The implementation of per-layer gradient checkpointing through stacked layer weights is well-executed and thoughtfully designed.

Key highlights:

  • Stacked Layer Architecture: The StackedDecoderLayers and MultiStackedDecoderLayers provide a clean and efficient way to manage layer parameters, enabling the use of jax.lax.scan for performance. The handling of heterogeneous layers in MultiStackedDecoderLayers is also a great addition.
  • Transparent Checkpointing: The ArrayRef class is a clever solution for providing a write-through, unstacked view of the parameters, which greatly simplifies checkpoint loading and saving logic.
  • Performance Optimizations: The choice to use a Python loop for decoding to enable buffer donation for the KV cache shows great attention to performance details.
  • Test Coverage: The PR includes comprehensive tests for the new functionality, including correctness checks for gradient checkpointing across different models, and unit tests for the new utility functions. The refactoring of existing tests into lora_test_utils.py also improves code quality.
  • Code Cleanup: The removal of a duplicated line in tx/utils/generator.py is a nice small cleanup.

The code is well-structured, clearly commented, and the changes are easy to follow despite the complexity of the refactoring. I've found one potential issue in save_safetensors which I've commented on. Overall, this is a high-quality contribution.

_, lora_params, _ = nnx.split(model, model.is_lora_param, ...)

adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank)

Contributor Author

We used to call extract_adapter_state and pass the result to save_safetensors. This is broken because extract_adapter_state returns a GraphState rather than a Module, which doesn't support the ArrayRef write-through.

raulchen and others added 2 commits February 6, 2026 16:43
The lora_A/lora_B guard in filter_lora caused init_lora_adapter to set
effective_rank=0 for lora_ranks and lora_scaling paths (which don't
contain lora_A/lora_B). Move the guard to the save/load call sites.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Stacked layers use scan which changes LoRA initialization order,
producing slightly different he_uniform values. With bf16 precision,
1e-4 learning rate needs >10 steps to register a loss change.
Bump to 1e-3 for reliable convergence within 10 steps.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
init_lora_adapter(model, adapter_index=1, lora_config=LoraConfig(rank=8, alpha=8, seed=1))

-optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=model.is_lora_param)
+optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=model.is_lora_param)
Contributor Author

Stacked layers use scan which changes LoRA initialization order,
producing slightly different he_uniform values. With bf16 precision,
1e-4 learning rate needs >10 steps to register a loss change.
Bump to 1e-3 for reliable convergence within 10 steps.

Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Devin Review found 2 potential issues.

View 10 additional findings in Devin Review.

@raulchen
Contributor Author

closing in favor of #1083 and stacked PRs

@raulchen raulchen closed this Feb 11, 2026