@@ -412,9 +412,79 @@ def wrapper(g):
412412 return wrapper
413413
414414
415+ RUNNING_INIT_HINT = """
416+ Hint: A common mistake is to use hk.cond(..) or `hk.switch(..)` at init time and
417+ create module parameters in one of the branches. At init time you should
418+ unconditionally create the parameters of all modules you might want to use
419+ at apply.
420+
421+ For hk.cond():
422+
423+ if hk.running_init():
424+ # At init time unconditionally create parameters in my_module.
425+ my_other_module(x)
426+ out = my_module(x)
427+ else:
428+ out = hk.cond(pred, my_module, my_other_module)
429+
430+ For hk.switch():
431+
432+ branches = [my_module, lambda x: x]
433+ if hk.running_init():
434+ # At init time unconditionally create parameters in all branches.
435+ for branch in branches:
436+ out = my_module(x)
437+ else:
438+ out = hk.switch(idx, branches, x)
439+ """ .strip ()
440+
441+
442+ def with_output_structure_hint (f ):
443+ """Adds a helpful hint to branch structure errors."""
444+ @functools .wraps (f )
445+ def wrapper (* args , ** kwargs ):
446+ try :
447+ return f (* args , ** kwargs )
448+ except TypeError as e :
449+ if not base .params_frozen () and "must have same type structure" in str (e ):
450+ raise TypeError (RUNNING_INIT_HINT ) from e
451+ else :
452+ raise e
453+ return wrapper
454+
455+
456+ # pylint: disable=g-doc-args
415457@functools .wraps (jax .lax .cond )
458+ @with_output_structure_hint
416459def cond (* args , ** kwargs ):
417- """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out."""
460+ """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out.
461+
462+ >>> true_fn = hk.nets.ResNet50(10)
463+ >>> false_fn = hk.Sequential([hk.Flatten(), hk.nets.MLP([300, 100, 10])])
464+ >>> x = jnp.ones([1, 224, 224, 3])
465+ >>> if hk.running_init():
466+ ... # At `init` run both branches to create parameters everywhere.
467+ ... true_fn(x)
468+ ... out = false_fn(x)
469+ ... else:
470+ ... # At `apply` conditionally call one of the modules.
471+ ... i = jax.random.randint(hk.next_rng_key(), [], 0, 100)
472+ ... out = hk.cond(i > 50, true_fn, false_fn, x)
473+
474+ Args:
475+ pred: Boolean scalar type.
476+ true_fun: Function (A -> B), to be applied if ``pred`` is ``True``.
477+ false_fun: Function (A -> B), to be applied if ``pred`` is ``False``.
478+ operands: Operands (A) input to either branch depending on ``pred``. The
479+ type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
480+ thereof.
481+
482+ Returns:
483+ Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
484+ depending on the value of ``pred``. The type can be a scalar, array, or any
485+ pytree (nested Python tuple/list/dict) thereof.
486+ """
487+ # pylint: enable=g-doc-args
418488 if not base .inside_transform ():
419489 raise ValueError ("hk.cond() should not be used outside of hk.transform(). "
420490 "Use jax.cond() instead." )
@@ -453,8 +523,34 @@ def cond(*args, **kwargs):
453523 return out
454524
455525
526+ @with_output_structure_hint
456527def switch (index , branches , operand ):
457- """Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out."""
528+ """Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out.
529+
530+ Note that creating parameters inside a switch branch is not supported, as such
531+ at init time we recommend you unconditionally evaluate all branches of your
532+ switch and only use the switch at apply. For example:
533+
534+ >>> experts = [hk.nets.MLP([300, 100, 10]) for _ in range(5)]
535+ >>> x = jnp.ones([1, 28 * 28])
536+ >>> if hk.running_init():
537+ ... # During init unconditionally create params/state for all experts.
538+ ... for expert in experts:
539+ ... out = expert(x)
540+ ... else:
541+ ... # During apply conditionally apply (and update) only one expert.
542+ ... index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts) - 1)
543+ ... out = hk.switch(index, experts, x)
544+
545+ Args:
546+ index: Integer scalar type, indicating which branch function to apply.
547+ branches: Sequence of functions (A -> B) to be applied based on index.
548+ operand: Operands (A) input to whichever branch is applied.
549+
550+ Returns:
551+ Value (B) of branch(*operands) for the branch that was selected based on
552+ index.
553+ """
458554 if not base .inside_transform ():
459555 raise ValueError (
460556 "hk.switch() should not be used outside of hk.transform(). "
0 commit comments