[tx] General implementation of trainable Hyper Connections #1008
tanmaysachan wants to merge 16 commits into NovaSky-AI:main from
Conversation
Code Review
This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.
My review found a couple of issues:
- An unused trainable parameter in the Connector class, which should be removed for clarity.
- A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.
Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.
    for layer_idx, layer in enumerate(self.layers):
        if output_hidden_states:
-           all_hidden_states.append(hidden_states)
+           all_hidden_states.append(hidden_states.squeeze())
hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.
A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
Suggested change:
-           all_hidden_states.append(hidden_states.squeeze())
+           all_hidden_states.append(hidden_states.mean(axis=-2))
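For illustration, a minimal JAX shape check (the dimension sizes here are made up) showing why squeeze() cannot remove the expansion axis when n > 1, while mean(axis=-2) always collapses it back to the expected (B, S, H):

```python
import jax.numpy as jnp

# Hypothetical sizes: batch B, sequence S, expansion rate n, hidden dim H.
B, S, n, H = 2, 5, 4, 8
hidden_states = jnp.ones((B, S, n, H))

print(hidden_states.squeeze().shape)      # (2, 5, 4, 8) -- expansion axis (n=4) survives
print(hidden_states.mean(axis=-2).shape)  # (2, 5, 8)    -- consistent with the other states
```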
        hidden_dim: int,
        expansion_rate: int,
        *,
        trainable: bool = False,
        self.eps = eps
        self.weight = Param(
-           size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+           size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
Temporary, testing
https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
Torch also initializes to one by default
This looks very elegant, thanks a lot for putting it together! Have you tried to do any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy, as well as how much slowdown it incurs :)
Just waiting for the weekend to give it a spin 😅 I'll give Qwen0.6B a shot on an A/H100
Sounds great! I'm putting together the 0.3.0 release at the moment, so this will probably need to wait until then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyways; if somebody wants to try it out, they can just apply the diff themselves given how simple this is :)
    for layer_idx, layer in enumerate(self.layers):
        if output_hidden_states:
-           all_hidden_states.append(hidden_states)
+           all_hidden_states.append(hidden_states.squeeze())
🔴 squeeze() without axis removes batch dimension when batch_size=1
In both Qwen3Model and DeepseekV3Model, hidden_states.squeeze() is called without specifying an axis. This is intended to remove the expansion dimension (size n) added by the Connector, but jnp.squeeze() removes all size-1 dimensions. When batch_size=1 (common during inference/generation), this also removes the batch dimension, corrupting the hidden state shapes stored in all_hidden_states.
Root Cause and Impact
For expansion_rate=1 (the default), hidden_states before layers 1+ has shape (B, S, 1, H). Calling .squeeze() on this:
- When B > 1: produces (B, S, H) — correct.
- When B = 1: produces (S, H) — batch dimension lost.
Similarly for expansion_rate > 1, hidden_states before layer 0 is (B, S, H) and .squeeze() with B=1 produces (S, H), while before layer 1+ it's (B, S, n, H) and .squeeze() with B=1 produces (S, n, H), resulting in inconsistent shapes across layers.
Impact: Any downstream consumer of output_hidden_states (e.g., probing, analysis, reward models) would receive incorrectly shaped tensors when batch_size=1, likely causing crashes or silent data corruption.
Prompt for agents
In skyrl-tx/tx/models/qwen3.py line 372 and skyrl-tx/tx/models/deepseekv3.py line 525, replace hidden_states.squeeze() with a shape-aware reduction. For layers after layer 0 (where the expansion dimension exists), use hidden_states.sum(axis=-2) or hidden_states.squeeze(axis=-2). For layer 0 (where hidden_states is still 3D from embed_tokens), no squeezing is needed. Consider checking hidden_states.ndim before deciding. A simple approach: use hidden_states.squeeze(axis=-2) if hidden_states.ndim == 4 else hidden_states, or simply always record without squeezing and let consumers handle the shape.
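As a sketch of the shape-aware reduction suggested above (record_hidden_state is a hypothetical helper name, not something in the codebase), one option looks like this:

```python
import jax.numpy as jnp

def record_hidden_state(hidden_states: jnp.ndarray) -> jnp.ndarray:
    """Collapse only the expansion axis (if present), never the batch axis."""
    if hidden_states.ndim == 4:             # (B, S, n, H): expansion axis present
        return hidden_states.mean(axis=-2)  # -> (B, S, H), works for any n and any B
    return hidden_states                    # (B, S, H): e.g. before layer 0, nothing to remove

# With batch_size=1, a plain squeeze() would also drop the batch axis:
x = jnp.ones((1, 5, 1, 8))
print(x.squeeze().shape)             # (5, 8)    -- batch axis lost
print(record_hidden_state(x).shape)  # (1, 5, 8) -- batch axis preserved
```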
    for layer_idx, layer in enumerate(self.layers):
        if output_hidden_states:
-           all_hidden_states.append(hidden_states)
+           all_hidden_states.append(hidden_states.squeeze())
🔴 squeeze() without axis removes batch dimension when batch_size=1 (DeepseekV3)
Same issue as in Qwen3Model but in DeepseekV3Model. hidden_states.squeeze() is called without specifying an axis, which removes all size-1 dimensions including the batch dimension when batch_size=1.
Root Cause and Impact
At skyrl-tx/tx/models/deepseekv3.py:525, hidden_states.squeeze() is called identically to the Qwen3 case. When batch_size=1, the batch dimension is removed, producing tensors of shape (S, H) instead of (1, S, H). This corrupts the all_hidden_states list returned via ModelOutput.hidden_states.
Prompt for agents
In skyrl-tx/tx/models/deepseekv3.py line 525, replace hidden_states.squeeze() with a shape-aware reduction that only removes the expansion dimension (axis=-2) when it exists (ndim==4), and preserves the batch dimension. For example: all_hidden_states.append(hidden_states.squeeze(axis=-2) if hidden_states.ndim == 4 else hidden_states)
Did some analysis of the step times on Qwen 0.6B (on a 5060Ti). Expansion rate 1 does cause a hit to the average step time: about 0.3s slower (baseline 2.1s vs 2.4s). An easy fix would be to just short-circuit the entire thing for expansion rate = 1. For expansion rate = 4, the step time was around 3.17s, so about 46% slower.
skyrl-tx/tx/tinker/backends/jax.py (Outdated)
| """Compute full gradients, apply optimizer update, and reset accumulated grads.""" | ||
| optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index)) | ||
| return accumulated_grads.reset_adapter(adapter_index) | ||
| if global_optimizer is not None and self.has_global_trainables: | ||
| global_optimizer.update(global_params, global_accumulated_grads.get_mean()) | ||
| global_accumulated_grads = global_accumulated_grads.reset() | ||
| return accumulated_grads.reset_adapter(adapter_index), global_accumulated_grads |
🔴 Global optimizer updated with zero gradients on second adapter's optim_step
When multiple LoRA adapters are active, the shared global optimizer receives spurious zero-gradient updates, corrupting its Adam state.
Root Cause
In compute_grads_and_update (jax.py:531-536), the global optimizer is updated and the global accumulated gradients are reset unconditionally on every call:
    if global_optimizer is not None and self.has_global_trainables:
        global_optimizer.update(global_params, global_accumulated_grads.get_mean())
        global_accumulated_grads = global_accumulated_grads.reset()

Since optim_step is called once per adapter (jax.py:773-809), with two adapters the sequence is:
1. optim_step(adapter_1) → updates the global optimizer with the real mean gradients, then resets global_accumulated_grads to zero
2. optim_step(adapter_2) → updates the global optimizer again with get_mean() of the now-zeroed gradients (all zeros), then resets again
The second zero-gradient update corrupts Adam's internal state:
- First moments decay: m_t = β₁ · m_{t-1} + (1-β₁) · 0 — momentum decays toward zero
- Second moments decay: v_t = β₂ · v_{t-1} + (1-β₂) · 0 — variance estimate shrinks
- Step counter increments, affecting bias correction
Impact: Global trainable parameters (connectors) receive incorrect optimizer updates that degrade training quality, with severity proportional to the number of adapters.
Prompt for agents
The global optimizer should only be updated once per training iteration, not once per adapter. Currently in compute_grads_and_update (jax.py:531-536), the global optimizer is updated and global accumulated gradients are reset on every call, but optim_step is called once per adapter. Fix this by either: (1) tracking whether global grads have already been applied in this iteration and skipping if already done (e.g., check global_accumulated_grads.count > 0 before updating), or (2) decoupling the global optimizer step from the per-adapter optim_step so it runs exactly once per training iteration. Option (1) is simpler: guard the global optimizer update with a check like `if global_accumulated_grads.count > 0` before calling global_optimizer.update.
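A self-contained toy using optax directly (parameter names and shapes are made up) that shows why a second, all-zero update to the shared Adam state is not a no-op:

```python
import jax.numpy as jnp
import optax

params = {"connector": jnp.ones((2, 2))}
opt = optax.adam(1e-3)
state = opt.init(params)

# optim_step(adapter_1): real accumulated mean gradients are applied.
real_grads = {"connector": jnp.full((2, 2), 0.5)}
updates, state = opt.update(real_grads, state, params)
params = optax.apply_updates(params, updates)

# optim_step(adapter_2): the accumulator was already reset, so get_mean() is all zeros.
# Adam still decays its first/second moments and bumps its step counter here.
zero_grads = {"connector": jnp.zeros((2, 2))}
updates, state = opt.update(zero_grads, state, params)
params = optax.apply_updates(params, updates)
```

Guarding the global update on whether any gradients have actually been accumulated (option 1 above) avoids this second call entirely.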
    def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array:
        if adapter_indices is None:
            return jnp.zeros((batch_size,), dtype=jnp.int32)
        return adapter_indices.astype(jnp.int32)
🟡 LoRAConnector broken when max_lora_adapters=0 — indexing into 0-sized parameter arrays returns wrong values
When a model is created with max_lora_adapters=0 (e.g., tx/run/train.py:80), the LoRAConnector creates all parameter arrays with a first dimension of 0. When pre() or post() is called, _get_adapter_indices returns jnp.zeros((B,), dtype=jnp.int32), and _get_params indexes into these 0-sized arrays, producing zero-filled results instead of the identity-preserving values.
Detailed Explanation
Unlike LoRAMixin.apply_lora which short-circuits when max_lora_adapters == 0 (lora.py:85), LoRAConnector has no such guard. When max_lora_adapters=0:
- self.b_pre has shape (0, n), self.b_res has shape (0, n, n), etc.
- _get_adapter_indices(B, None) returns jnp.zeros((B,)) at connectors.py:66
- _get_params indexes into 0-sized arrays at connectors.py:71-80 — JAX clips out-of-bounds indices and returns zeros
- In pre(): b_pre=0 → H_pre = sigmoid(0) = 0.5 instead of 1/n
- In post(): b_res=0 → M = sinkhorn(zeros) produces a uniform 1/n matrix instead of identity
For the default expansion_rate=1, the impact on pre is masked by RMSNorm (the 0.5 scale cancels during normalization), and post still produces the correct residual + output. So the default case is approximately correct. However, for expansion_rate > 1 with max_lora_adapters=0, the connector would produce completely wrong outputs (uniform mixing instead of identity passthrough).
This path is exercised in production via tx/run/train.py:80 which uses max_lora_adapters=0.
Prompt for agents
Add a guard in LoRAConnector to handle the max_lora_adapters=0 case. The simplest approach is to add a check at the start of pre() and post() methods that bypasses the connector logic when max_lora_adapters is 0, falling back to identity behavior: pre() should return x.sum(axis=-2) / n (or equivalently the mean), and post() should return residual + output[..., None, :] (broadcasting output into the expansion dimension). Alternatively, ensure the constructor always creates at least 1 adapter slot (with identity initialization) even when max_lora_adapters=0, similar to how the default adapter_index=0 is used when adapter_indices is None.
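A minimal sketch of the identity fallback described above; the function names are hypothetical and the shapes follow the (B, S, n, H) convention used elsewhere in this thread, so this is not the actual LoRAConnector API:

```python
import jax.numpy as jnp

def pre_identity(x: jnp.ndarray) -> jnp.ndarray:
    # x: (B, S, n, H) -> (B, S, H); the mean over the n streams equals x.sum(axis=-2) / n
    return x.mean(axis=-2)

def post_identity(residual: jnp.ndarray, output: jnp.ndarray) -> jnp.ndarray:
    # residual: (B, S, n, H), output: (B, S, H) -> (B, S, n, H)
    # Broadcast the block output back into every expanded stream.
    return residual + output[..., None, :]
```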
    key = path[-2].key
    normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path)
    if "connector" in normalized_path:
        return value.at[adapter_index].set(0.0)
🔴 clear_lora_adapter zeros out connector params, destroying identity mapping
clear_lora_adapter blanket-sets all connector parameters to 0.0 via value.at[adapter_index].set(0.0), but several connector parameters require specific non-zero values to maintain the identity-mapping invariant.
Root Cause and Impact
The connector's identity mapping depends on:
- input_norm_weight = 1.0 (zeroing it makes _norm at tx/layers/connectors.py:96-97 return all-zeros)
- b_pre = inv_sigmoid(1/n) (zeroing it makes sigmoid(0)=0.5 instead of 1/n)
- b_res = 10 * I (zeroing it makes sinkhorn_knopp(zeros) produce a uniform 1/n matrix instead of identity)
After clearing, the connector no longer acts as a residual connection. In particular, since _get_adapter_indices at tx/layers/connectors.py:64-67 defaults to adapter index 0 when adapter_indices is None, clearing adapter slot 0 will break all non-adapter inference — the model will produce garbled outputs because input_norm_weight=0 zeroes out the normalization, and the residual mixing matrix becomes uniform instead of identity.
Compare with init_lora_adapter at tx/layers/lora.py:349-371, which correctly re-initializes each parameter to its identity value. clear_lora_adapter should do the same rather than blanket-zeroing.
Prompt for agents
In skyrl-tx/tx/layers/lora.py, the clear_adapter function at lines 401-410 needs to reset connector parameters to their identity-mapping values instead of blanket-zeroing them. Specifically, within the `if "connector" in normalized_path:` block, replicate the per-key logic from init_lora_adapter (lines 349-371):
- alpha_pre, alpha_post, alpha_res → set to 0.0 (correct, keeps phi contribution off)
- input_norm_weight → set to 1.0
- phi_pre, phi_post, phi_res → set to 0.0 (fine since alpha is 0)
- b_pre → set to inv_sigmoid(1/n), where n = value.shape[1] for b_pre
- b_post → set to 0.0 (correct identity value since 2*sigmoid(0)=1.0)
- b_res → set to 10.0 * jnp.eye(n) where n = value.shape[1]
You need to extract key_name = path[-2].key and branch on it, similar to how init_lora_adapter handles connector parameters.
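A hedged sketch of that per-key reset, following the values listed above; the surrounding nnx tree traversal and the exact parameter layout are assumed rather than taken from the actual code:

```python
import jax.numpy as jnp

def inv_sigmoid(p: float) -> float:
    # logit, so that sigmoid(inv_sigmoid(1/n)) == 1/n (assumes 0 < p < 1, i.e. n > 1)
    return float(jnp.log(p) - jnp.log1p(-p))

def reset_connector_param(key_name: str, value: jnp.ndarray, adapter_index: int) -> jnp.ndarray:
    """Reset one connector parameter in the cleared adapter slot to its identity-mapping value."""
    if key_name in ("alpha_pre", "alpha_post", "alpha_res", "phi_pre", "phi_post", "phi_res", "b_post"):
        return value.at[adapter_index].set(0.0)
    if key_name == "input_norm_weight":
        return value.at[adapter_index].set(1.0)
    if key_name == "b_pre":
        n = value.shape[1]
        return value.at[adapter_index].set(inv_sigmoid(1.0 / n))
    if key_name == "b_res":
        n = value.shape[1]
        return value.at[adapter_index].set(10.0 * jnp.eye(n))
    return value
```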
Addresses #952
This PR is a general implementation of Hyper Connections.
This is supposed to be an extension like LoRA, where the default case mimics a standard residual connection with identity mappings.
Default case - Trainable is false. Expansion rate is 1.
For expansion rate > 1
These matrices preserve the identity mapping, so expansion rate > 1 but untrainable still results in the same outputs (see the sketch below).
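For intuition, a minimal standalone sketch (not the PR's Connector implementation) of why identity-initialized width connections reduce to a plain residual block for any expansion rate n:

```python
import jax.numpy as jnp

def identity_hyper_connection(x, block, n=4):
    # Expand x into n identical residual streams: (B, S, H) -> (B, S, n, H)
    streams = jnp.repeat(x[..., None, :], n, axis=-2)
    block_in = streams.mean(axis=-2)                            # identity "pre": recovers x
    mixed = jnp.einsum("ij,...jh->...ih", jnp.eye(n), streams)  # identity residual mixing
    streams = mixed + block(block_in)[..., None, :]             # identity "post": broadcast output
    return streams.mean(axis=-2)                                # collapse back: x + block(x)

x = jnp.ones((1, 3, 8))
out = identity_hyper_connection(x, lambda h: 2.0 * h)
assert jnp.allclose(out, x + 2.0 * x)
```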
Todos
Future work