
[tx] General implementation of trainable Hyper Connections #1008

Open

tanmaysachan wants to merge 16 commits into NovaSky-AI:main from tanmaysachan:tanmay/mhc

Conversation

@tanmaysachan (Contributor) commented Feb 2, 2026

Addresses #952

This PR is a general implementation of Hyper Connections.

It is designed as an extension in the same spirit as LoRA: the default case mimics a standard residual connection via identity mappings.

Default case: trainable is false and the expansion rate is 1.

  1. H_res is the 1x1 matrix [1]
  2. H_pre and H_post are vectors of ones ([1, 1, 1, ...]) whose matmuls are no-ops

For expansion rate > 1:

  1. H_res is initialized as the n x n identity (n is the expansion rate)
  2. H_pre is [1/n, 1/n, ...]
  3. H_post is [1, 1, 1, ...]

These matrices preserve the identity mapping, so an expansion rate > 1 that is left untrainable still produces the same outputs as the baseline.
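
For illustration, here is a minimal JAX sketch of one hyper-connection step with these matrices (names and shapes are assumptions for the sketch, not the Connector API in this PR):

import jax.numpy as jnp

def hyper_connection_step(x, block, H_pre, H_post, H_res):
    """x: (..., n, hidden) carries n residual streams; block is an attention or MLP module."""
    block_in = jnp.einsum("...nd,n->...d", x, H_pre)        # width connection: mix streams into one block input
    residual = jnp.einsum("...nd,nm->...md", x, H_res)      # depth connection: mix streams for the residual
    out = block(block_in)
    return residual + H_post[:, None] * out[..., None, :]   # scatter the block output back across streams

# With H_res = I, H_pre = [1/n] * n, H_post = [1] * n and all n streams holding the
# same tensor, each stream receives x + block(x), i.e. a standard residual connection.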

Todos

  • Simplify RMS norm integration - added elementwise_affine as a flag
  • Benchmark/ensure no regression for expansion_rate = 1 - minimal difference in step time when expansion rate is 1 and untrainable.

Future work

  • Fine-tune on custom data with mHC + LoRA to see perf gains


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.

My review found a couple of issues:

  • An unused trainable parameter in the Connector class which should be removed for clarity.
  • A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.

Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.

for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
-       all_hidden_states.append(hidden_states)
+       all_hidden_states.append(hidden_states.squeeze())

Severity: high

hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change:
-   all_hidden_states.append(hidden_states.squeeze())
+   all_hidden_states.append(hidden_states.mean(axis=-2))

hidden_dim: int,
expansion_rate: int,
*,
trainable: bool = False,

Severity: medium

The trainable parameter is defined but it is not used anywhere in the Connector class. This could be misleading for developers using this module. Consider removing it from the method signature, and also the assignment self.trainable = trainable on line 27, to improve code clarity.

@pcmoritz added the tx label Feb 2, 2026
self.eps = eps
self.weight = Param(
    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
@tanmaysachan (author) commented:
Temporary, testing


@pcmoritz (Collaborator) commented Feb 5, 2026

This looks very elegant, thanks a lot for putting it together! Have you tried any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy and in terms of how much slowdown it incurs? :)

@tanmaysachan (author) commented:

Just waiting for the weekend to give it a spin 😅

I'll give Qwen0.6B a shot on an A/H100

@pcmoritz (Collaborator) commented Feb 5, 2026

Sounds great! I'm putting together the 0.3.0 release at the moment, so it will probably need to wait until then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyway; if somebody wants to try it out, they can just apply the diff themselves given how simple it is :)

@devin-ai-integration (bot) left a comment

Devin Review found 3 potential issues.

View 6 additional findings in Devin Review.


for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
-       all_hidden_states.append(hidden_states)
+       all_hidden_states.append(hidden_states.squeeze())

🔴 squeeze() without axis removes batch dimension when batch_size=1

In both Qwen3Model and DeepseekV3Model, hidden_states.squeeze() is called without specifying an axis. This is intended to remove the expansion dimension (size n) added by the Connector, but jnp.squeeze() removes all size-1 dimensions. When batch_size=1 (common during inference/generation), this also removes the batch dimension, corrupting the hidden state shapes stored in all_hidden_states.

Root Cause and Impact

For expansion_rate=1 (the default), hidden_states before layers 1+ has shape (B, S, 1, H). Calling .squeeze() on this:

  • When B > 1: produces (B, S, H) — correct.
  • When B = 1: produces (S, H), losing the batch dimension.

Similarly for expansion_rate > 1, hidden_states before layer 0 is (B, S, H) and .squeeze() with B=1 produces (S, H), while before layer 1+ it's (B, S, n, H) and .squeeze() with B=1 produces (S, n, H), resulting in inconsistent shapes across layers.

Impact: Any downstream consumer of output_hidden_states (e.g., probing, analysis, reward models) would receive incorrectly shaped tensors when batch_size=1, likely causing crashes or silent data corruption.

Prompt for agents
In skyrl-tx/tx/models/qwen3.py line 372 and skyrl-tx/tx/models/deepseekv3.py line 525, replace hidden_states.squeeze() with a shape-aware reduction. For layers after layer 0 (where the expansion dimension exists), use hidden_states.sum(axis=-2) or hidden_states.squeeze(axis=-2). For layer 0 (where hidden_states is still 3D from embed_tokens), no squeezing is needed. Consider checking hidden_states.ndim before deciding. A simple approach: use hidden_states.squeeze(axis=-2) if hidden_states.ndim == 4 else hidden_states, or simply always record without squeezing and let consumers handle the shape.
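
For reference, a quick standalone shape check of the options discussed above (not repository code):

import jax.numpy as jnp

h = jnp.ones((1, 16, 1, 128))        # (B=1, S, n=1, H)
print(h.squeeze().shape)             # (16, 128): the batch axis is dropped as well
print(h.squeeze(axis=-2).shape)      # (1, 16, 128): only the expansion axis is removed
print(h.mean(axis=-2).shape)         # (1, 16, 128): also handles n > 1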

for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
-       all_hidden_states.append(hidden_states)
+       all_hidden_states.append(hidden_states.squeeze())

🔴 squeeze() without axis removes batch dimension when batch_size=1 (DeepseekV3)

Same issue as in Qwen3Model but in DeepseekV3Model. hidden_states.squeeze() is called without specifying an axis, which removes all size-1 dimensions including the batch dimension when batch_size=1.

Root Cause and Impact

At skyrl-tx/tx/models/deepseekv3.py:525, hidden_states.squeeze() is called identically to the Qwen3 case. When batch_size=1, the batch dimension is removed, producing tensors of shape (S, H) instead of (1, S, H). This corrupts the all_hidden_states list returned via ModelOutput.hidden_states.

Prompt for agents
In skyrl-tx/tx/models/deepseekv3.py line 525, replace hidden_states.squeeze() with a shape-aware reduction that only removes the expansion dimension (axis=-2) when it exists (ndim==4), and preserves the batch dimension. For example: all_hidden_states.append(hidden_states.squeeze(axis=-2) if hidden_states.ndim == 4 else hidden_states)

@tanmaysachan (author) commented Feb 11, 2026

Did some analysis of the step times for each configuration on Qwen 0.6B (on a 5060 Ti).

Expansion rate 1 does cause a hit to the average step time: about 0.3s slower (baseline step time of 2.1s vs 2.4s with the connector). An easy fix would be to just short-circuit the entire thing for expansion rate = 1 (see the sketch below).

For expansion rate = 4, the step time was around 3.17s, so about 46% slower.
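
A possible shape for that short circuit (hypothetical helper; pre()/post() follow the connector interface discussed in the review comments, not the exact code in this PR):

def apply_with_connector(x, block, connector, expansion_rate: int, trainable: bool):
    # Default configuration: skip all connector matmuls and use a plain residual connection.
    if expansion_rate == 1 and not trainable:
        return x + block(x)
    return connector.post(x, block(connector.pre(x)))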

@tanmaysachan (author) commented Feb 11, 2026

[image: qwen_expansion_4_loss]

Loss plot for Qwen 0.6B with expansion rate = 4, max_lora_adapters=2, max_lora_rank=1.

(some more analysis to do)

@devin-ai-integration (bot) left a comment

Devin Review found 1 new potential issue.

View 14 additional findings in Devin Review.


Comment on lines 531 to 536

    """Compute full gradients, apply optimizer update, and reset accumulated grads."""
    optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index))
-   return accumulated_grads.reset_adapter(adapter_index)
+   if global_optimizer is not None and self.has_global_trainables:
+       global_optimizer.update(global_params, global_accumulated_grads.get_mean())
+       global_accumulated_grads = global_accumulated_grads.reset()
+   return accumulated_grads.reset_adapter(adapter_index), global_accumulated_grads

🔴 Global optimizer updated with zero gradients on second adapter's optim_step

When multiple LoRA adapters are active, the shared global optimizer receives spurious zero-gradient updates, corrupting its Adam state.

Root Cause

In compute_grads_and_update (jax.py:531-536), the global optimizer is updated and the global accumulated gradients are reset unconditionally on every call:

if global_optimizer is not None and self.has_global_trainables:
    global_optimizer.update(global_params, global_accumulated_grads.get_mean())
    global_accumulated_grads = global_accumulated_grads.reset()

Since optim_step is called once per adapter (jax.py:773-809), with two adapters the sequence is:

  1. optim_step(adapter_1) → updates global optimizer with real mean gradients, resets global_accumulated_grads to zero
  2. optim_step(adapter_2) → updates global optimizer again with get_mean() of the now-zeroed gradients (all zeros), resets again

The second zero-gradient update corrupts Adam's internal state:

  • First moments decay: m_t = β₁ · m_{t-1} + (1-β₁) · 0 — momentum decays toward zero
  • Second moments decay: v_t = β₂ · v_{t-1} + (1-β₂) · 0 — variance estimate shrinks
  • Step counter increments, affecting bias correction

Impact: Global trainable parameters (connectors) receive incorrect optimizer updates that degrade training quality, with severity proportional to the number of adapters.

Prompt for agents
The global optimizer should only be updated once per training iteration, not once per adapter. Currently in compute_grads_and_update (jax.py:531-536), the global optimizer is updated and global accumulated gradients are reset on every call, but optim_step is called once per adapter. Fix this by either: (1) tracking whether global grads have already been applied in this iteration and skipping if already done (e.g., check global_accumulated_grads.count > 0 before updating), or (2) decoupling the global optimizer step from the per-adapter optim_step so it runs exactly once per training iteration. Option (1) is simpler: guard the global optimizer update with a check like `if global_accumulated_grads.count > 0` before calling global_optimizer.update.
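
A minimal sketch of option (1), assuming the accumulator exposes the count, get_mean, and reset members mentioned above (not verified against jax.py):

if global_optimizer is not None and self.has_global_trainables:
    # Only apply global gradients once per training iteration; later per-adapter
    # calls see a freshly reset accumulator and skip the update.
    if global_accumulated_grads.count > 0:
        global_optimizer.update(global_params, global_accumulated_grads.get_mean())
        global_accumulated_grads = global_accumulated_grads.reset()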

@devin-ai-integration (bot) left a comment

Devin Review found 2 new potential issues.

View 17 additional findings in Devin Review.


Comment on lines +64 to +67

    def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array:
        if adapter_indices is None:
            return jnp.zeros((batch_size,), dtype=jnp.int32)
        return adapter_indices.astype(jnp.int32)

🟡 LoRAConnector broken when max_lora_adapters=0 — indexing into 0-sized parameter arrays returns wrong values

When a model is created with max_lora_adapters=0 (e.g., tx/run/train.py:80), the LoRAConnector creates all parameter arrays with a first dimension of 0. When pre() or post() is called, _get_adapter_indices returns jnp.zeros((B,), dtype=jnp.int32), and _get_params indexes into these 0-sized arrays, producing zero-filled results instead of the identity-preserving values.

Detailed Explanation

Unlike LoRAMixin.apply_lora which short-circuits when max_lora_adapters == 0 (lora.py:85), LoRAConnector has no such guard. When max_lora_adapters=0:

  • self.b_pre has shape (0, n), self.b_res has shape (0, n, n), etc.
  • _get_adapter_indices(B, None) returns jnp.zeros((B,)) at connectors.py:66
  • _get_params indexes into 0-sized arrays at connectors.py:71-80 — JAX clips out-of-bounds indices and returns zeros
  • In pre(): b_pre = 0, so H_pre = sigmoid(0) = 0.5 instead of 1/n
  • In post(): b_res = 0, so M = sinkhorn(zeros) produces a uniform 1/n matrix instead of identity

For the default expansion_rate=1, the impact on pre is masked by RMSNorm (the 0.5 scale cancels during normalization), and post still produces the correct residual + output. So the default case is approximately correct. However, for expansion_rate > 1 with max_lora_adapters=0, the connector would produce completely wrong outputs (uniform mixing instead of identity passthrough).

This path is exercised in production via tx/run/train.py:80 which uses max_lora_adapters=0.

Prompt for agents
Add a guard in LoRAConnector to handle the max_lora_adapters=0 case. The simplest approach is to add a check at the start of pre() and post() methods that bypasses the connector logic when max_lora_adapters is 0, falling back to identity behavior: pre() should return x.sum(axis=-2) / n (or equivalently the mean), and post() should return residual + output[..., None, :] (broadcasting output into the expansion dimension). Alternatively, ensure the constructor always creates at least 1 adapter slot (with identity initialization) even when max_lora_adapters=0, similar to how the default adapter_index=0 is used when adapter_indices is None.
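
A sketch of the suggested guard, assuming pre()/post() signatures as described above (the real LoRAConnector may differ):

class ConnectorWithFallback:
    def __init__(self, max_lora_adapters: int, expansion_rate: int):
        self.max_lora_adapters = max_lora_adapters
        self.n = expansion_rate

    def pre(self, x):                                  # x: (..., n, H)
        if self.max_lora_adapters == 0:
            return x.mean(axis=-2)                     # identity-preserving H_pre = 1/n
        raise NotImplementedError("trainable path omitted from this sketch")

    def post(self, residual, output):                  # residual: (..., n, H), output: (..., H)
        if self.max_lora_adapters == 0:
            return residual + output[..., None, :]     # identity H_res, unit H_post
        raise NotImplementedError("trainable path omitted from this sketch")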

@devin-ai-integration (bot) left a comment

Devin Review found 1 new potential issue.

View 20 additional findings in Devin Review.


Comment on lines 402 to +405

    key = path[-2].key
+   normalized_path = tuple(p.key if hasattr(p, "key") else p.name for p in path)
+   if "connector" in normalized_path:
+       return value.at[adapter_index].set(0.0)

🔴 clear_lora_adapter zeros out connector params, destroying identity mapping

clear_lora_adapter blanket-sets all connector parameters to 0.0 via value.at[adapter_index].set(0.0), but several connector parameters require specific non-zero values to maintain the identity-mapping invariant.

Root Cause and Impact

The connector's identity mapping depends on:

  • input_norm_weight = 1.0 (zeroing it makes _norm at tx/layers/connectors.py:96-97 return all-zeros)
  • b_pre = inv_sigmoid(1/n) (zeroing it makes sigmoid(0)=0.5 instead of 1/n)
  • b_res = 10 * I (zeroing it makes sinkhorn_knopp(zeros) produce a uniform 1/n matrix instead of identity)

After clearing, the connector no longer acts as a residual connection. In particular, since _get_adapter_indices at tx/layers/connectors.py:64-67 defaults to adapter index 0 when adapter_indices is None, clearing adapter slot 0 will break all non-adapter inference — the model will produce garbled outputs because input_norm_weight=0 zeroes out the normalization, and the residual mixing matrix becomes uniform instead of identity.

Compare with init_lora_adapter at tx/layers/lora.py:349-371, which correctly re-initializes each parameter to its identity value. clear_lora_adapter should do the same rather than blanket-zeroing.

Prompt for agents
In skyrl-tx/tx/layers/lora.py, the clear_adapter function at lines 401-410 needs to reset connector parameters to their identity-mapping values instead of blanket-zeroing them. Specifically, within the `if "connector" in normalized_path:` block, replicate the per-key logic from init_lora_adapter (lines 349-371):

- alpha_pre, alpha_post, alpha_res → set to 0.0 (correct, keeps phi contribution off)
- input_norm_weight → set to 1.0
- phi_pre, phi_post, phi_res → set to 0.0 (fine since alpha is 0)
- b_pre → set to inv_sigmoid(1/n), where n = value.shape[1] for b_pre
- b_post → set to 0.0 (correct identity value since 2*sigmoid(0)=1.0)
- b_res → set to 10.0 * jnp.eye(n) where n = value.shape[1]

You need to extract key_name = path[-2].key and branch on it, similar to how init_lora_adapter handles connector parameters.
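
A sketch of that per-key reset (key names and identity values are taken from the review text above and are not verified against lora.py):

import jax
import jax.numpy as jnp

def reset_connector_param(key_name: str, value: jax.Array, adapter_index: int) -> jax.Array:
    """Reset one connector parameter slot to its identity-mapping value."""
    if key_name == "input_norm_weight":
        return value.at[adapter_index].set(1.0)
    if key_name == "b_pre":
        n = value.shape[1]
        # inv_sigmoid(1/n); note this is infinite for n == 1, which the real code may special-case
        return value.at[adapter_index].set(jax.scipy.special.logit(1.0 / n))
    if key_name == "b_res":
        n = value.shape[1]
        return value.at[adapter_index].set(10.0 * jnp.eye(n))
    # alpha_*, phi_*, and b_post: zero is already the identity-preserving value
    return value.at[adapter_index].set(0.0)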

@tanmaysachan (author) commented Feb 12, 2026

[image: qwen_loss_comparison]

Qwen 1.7B, without expansion and with expansion rate = 4 (roughly identical loss plots).
Per-step training time with mHC is about 93% higher than the regular run.

@tanmaysachan (author) commented:

The loss differences are on a similar scale to what is observed in the mHC paper.

[image]

Ground truth mHC analysis:

[image]

