Skip to content

Commit aefcfc4

Browse files
tomhennigan authored and
copybara-github committed
Provide a more helpful message for common error in hk.{cond,switch}.
PiperOrigin-RevId: 457708404
1 parent 3f31e27 commit aefcfc4

File tree

2 files changed

+114
-2
lines changed

2 files changed

+114
-2
lines changed

haiku/_src/stateful.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,79 @@ def wrapper(g):
412412
return wrapper
413413

414414

415+
RUNNING_INIT_HINT = """
416+
Hint: A common mistake is to use `hk.cond(..)` or `hk.switch(..)` at init time and
417+
create module parameters in one of the branches. At init time you should
418+
unconditionally create the parameters of all modules you might want to use
419+
at apply.
420+
421+
For hk.cond():
422+
423+
if hk.running_init():
424+
# At init time unconditionally create parameters in my_module.
425+
my_other_module(x)
426+
out = my_module(x)
427+
else:
428+
out = hk.cond(pred, my_module, my_other_module, x)
429+
430+
For hk.switch():
431+
432+
branches = [my_module, lambda x: x]
433+
if hk.running_init():
434+
# At init time unconditionally create parameters in all branches.
435+
for branch in branches:
436+
out = branch(x)
437+
else:
438+
out = hk.switch(idx, branches, x)
439+
""".strip()
440+
441+
442+
def with_output_structure_hint(f):
443+
"""Adds a helpful hint to branch structure errors."""
444+
@functools.wraps(f)
445+
def wrapper(*args, **kwargs):
446+
try:
447+
return f(*args, **kwargs)
448+
except TypeError as e:
449+
if not base.params_frozen() and "must have same type structure" in str(e):
450+
raise TypeError(RUNNING_INIT_HINT) from e
451+
else:
452+
raise e
453+
return wrapper
454+
455+
456+
# pylint: disable=g-doc-args
415457
@functools.wraps(jax.lax.cond)
458+
@with_output_structure_hint
416459
def cond(*args, **kwargs):
417-
"""Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out."""
460+
"""Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out.
461+
462+
>>> true_fn = hk.nets.ResNet50(10)
463+
>>> false_fn = hk.Sequential([hk.Flatten(), hk.nets.MLP([300, 100, 10])])
464+
>>> x = jnp.ones([1, 224, 224, 3])
465+
>>> if hk.running_init():
466+
... # At `init` run both branches to create parameters everywhere.
467+
... true_fn(x)
468+
... out = false_fn(x)
469+
... else:
470+
... # At `apply` conditionally call one of the modules.
471+
... i = jax.random.randint(hk.next_rng_key(), [], 0, 100)
472+
... out = hk.cond(i > 50, true_fn, false_fn, x)
473+
474+
Args:
475+
pred: Boolean scalar type.
476+
true_fun: Function (A -> B), to be applied if ``pred`` is ``True``.
477+
false_fun: Function (A -> B), to be applied if ``pred`` is ``False``.
478+
operands: Operands (A) input to either branch depending on ``pred``. The
479+
type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
480+
thereof.
481+
482+
Returns:
483+
Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
484+
depending on the value of ``pred``. The type can be a scalar, array, or any
485+
pytree (nested Python tuple/list/dict) thereof.
486+
"""
487+
# pylint: enable=g-doc-args
418488
if not base.inside_transform():
419489
raise ValueError("hk.cond() should not be used outside of hk.transform(). "
418490
"Use jax.lax.cond() instead.")
@@ -453,8 +523,34 @@ def cond(*args, **kwargs):
453523
return out
454524

455525

526+
@with_output_structure_hint
456527
def switch(index, branches, operand):
457-
"""Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out."""
528+
"""Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out.
529+
530+
Note that creating parameters inside a switch branch is not supported, as such
531+
at init time we recommend you unconditionally evaluate all branches of your
532+
switch and only use the switch at apply. For example:
533+
534+
>>> experts = [hk.nets.MLP([300, 100, 10]) for _ in range(5)]
535+
>>> x = jnp.ones([1, 28 * 28])
536+
>>> if hk.running_init():
537+
... # During init unconditionally create params/state for all experts.
538+
... for expert in experts:
539+
... out = expert(x)
540+
... else:
541+
... # During apply conditionally apply (and update) only one expert.
542+
... index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
543+
... out = hk.switch(index, experts, x)
544+
545+
Args:
546+
index: Integer scalar type, indicating which branch function to apply.
547+
branches: Sequence of functions (A -> B) to be applied based on index.
548+
operand: Operands (A) input to whichever branch is applied.
549+
550+
Returns:
551+
Value (B) of ``branch(operand)`` for the branch that was selected based on
552+
index.
553+
"""
458554
if not base.inside_transform():
459555
raise ValueError(
460556
"hk.switch() should not be used outside of hk.transform(). "

haiku/_src/stateful_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,22 @@ def f(i, x):
332332
self.assertEqual(state, {"square_module": {"y": y}})
333333
self.assertEqual(out, y)
334334

335+
@test_utils.transform_and_run(run_apply=False)
336+
def test_cond_branch_structure_error(self):
337+
true_fn = lambda x: base.get_parameter("w", x.shape, x.dtype, init=jnp.ones)
338+
false_fn = lambda x: x
339+
with self.assertRaisesRegex(TypeError, "Hint: A common mistake"):
340+
stateful.cond(False, true_fn, false_fn, 0)
341+
342+
@test_utils.transform_and_run(run_apply=False)
343+
def test_switch_branch_structure_error(self):
344+
branches = [
345+
lambda x: base.get_parameter("w", x.shape, x.dtype, init=jnp.ones),
346+
lambda x: x,
347+
]
348+
with self.assertRaisesRegex(TypeError, "Hint: A common mistake"):
349+
stateful.switch(0, branches, 0)
350+
335351
@parameterized.parameters(1, 2, 4, 8)
336352
@test_utils.transform_and_run
337353
def test_switch_traces_cases_with_same_id_once(self, n):

0 commit comments

Comments
 (0)