diff --git a/.gitignore b/.gitignore index ce3feeb8ff98..50f3a77710c0 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ jax.iml # virtualenv/venv directories /venv/ +/venv_new/ /bin/ /include/ /lib/ diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 62002cd8eb51..5d36b6afbe5d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -14,11 +14,12 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial import logging from typing import Any import types +import inspect import numpy as np @@ -201,9 +202,10 @@ def policy(prim, *args, **params): ### Main API @partial(api_boundary, repro_api_name="jax.checkpoint") -def checkpoint(fun: Callable, *, prevent_cse: bool = True, +def checkpoint(fun: Callable | None = None, *, prevent_cse: bool = True, policy: Callable[..., bool] | None = None, static_argnums: int | tuple[int, ...] = (), + static_argnames: str | Iterable[str] = (), concrete: bool | DeprecatedArg = DeprecatedArg()) -> Callable: """Make ``fun`` recompute internal linearization points when differentiated. @@ -252,6 +254,10 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True, caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. See the example below. + static_argnames: Optional, string or collection of strings, a keyword-only + argument indicating named arguments to treat as static. These are + arguments that will be specialized on and traced with their concrete + values. policy: Optional, callable keyword-only argument. It should be one of the attributes of ``jax.checkpoint_policies``. The callable takes as input a type-level specification of a first-order primitive application and @@ -347,6 +353,11 @@ def foo(x, y): ``jax.ensure_compile_time_eval``), it may be easier to compute some values outside the :func:`jax.checkpoint`-decorated function and then close over them. """ + if fun is None: + return partial(checkpoint, prevent_cse=prevent_cse, policy=policy, + static_argnums=static_argnums, + static_argnames=static_argnames, concrete=concrete) + if not isinstance(concrete, DeprecatedArg): concrete_msg = ( "The `concrete` option to `jax.checkpoint` has been deprecated." @@ -359,6 +370,10 @@ def foo(x, y): if isinstance(static_argnums, int): static_argnums = static_argnums, + if isinstance(static_argnames, str): + static_argnames = (static_argnames,) + else: + static_argnames = tuple(static_argnames) if isinstance(prevent_cse, list): prevent_cse = tuple(prevent_cse) if not isinstance(prevent_cse, (tuple, bool)): @@ -370,8 +385,10 @@ def foo(x, y): def fun_remat(*args, **kwargs): debug = api_util.debug_info( "checkpoint / remat", fun, - args, kwargs, static_argnums=static_argnums) - fun_, args = _remat_static_argnums(fun, static_argnums, args) + args, kwargs, static_argnums=static_argnums, + static_argnames=static_argnames) + fun_, args, kwargs = _remat_static_args( + fun, static_argnums, static_argnames, args, kwargs) args_flat, in_tree = tree_flatten((args, kwargs)) in_avals = [core.shaped_abstractify(x) for x in args_flat] jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug) @@ -390,10 +407,12 @@ def fun_remat(*args, **kwargs): def remat(fun: Callable, *, prevent_cse: bool = True, policy: Callable[..., bool] | None = None, static_argnums: int | tuple[int, ...] = (), + static_argnames: str | Iterable[str] = (), concrete: bool | DeprecatedArg = DeprecatedArg()) -> Callable: """Alias of :func:`jax.checkpoint`.""" return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, - static_argnums=static_argnums, concrete=concrete) + static_argnums=static_argnums, + static_argnames=static_argnames, concrete=concrete) # This function is similar to api_util.argnums_partial, except the error # messages are specific to jax.remat (and thus more actionable), the @@ -425,6 +444,109 @@ def _remat_static_argnums(fun, static_argnums, args): new_fun = _dyn_args_fun(fun, static_argnums_, tuple(static_args), nargs) return new_fun, dyn_args +def _remat_static_args(fun, static_argnums, static_argnames, args, kwargs): + if not static_argnums and not static_argnames: + return fun, args, kwargs + + static_argnums = static_argnums or () + static_argnames = static_argnames or () + + if isinstance(static_argnums, int): + static_argnums = (static_argnums,) + if isinstance(static_argnames, str): + static_argnames = (static_argnames,) + + sig = inspect.signature(fun) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + static_argnames_set = set(static_argnames) + arg_names = list(bound.arguments.keys()) + for i in static_argnums: + try: + if i < 0: i += len(bound.arguments) + if 0 <= i < len(arg_names): + static_argnames_set.add(arg_names[i]) + else: + raise ValueError(f"static_argnums index {i} is out of bounds for {len(arg_names)} arguments") + except IndexError: + raise ValueError(f"static_argnums index {i} is out of bounds") + + static_vals = {} + dyn_args = [] + dyn_kwargs = {} + + for name, val in bound.arguments.items(): + param = sig.parameters.get(name) + if name in static_argnames_set: + static_vals[name] = WrapHashably(val) + elif param and param.kind == inspect.Parameter.VAR_POSITIONAL: + dyn_args.extend(val) + elif param and param.kind == inspect.Parameter.VAR_KEYWORD: + if isinstance(val, dict): + for k, v in val.items(): + if k in static_argnames_set: + static_vals[k] = WrapHashably(v) + else: + dyn_kwargs[k] = v + else: + if param and param.kind == inspect.Parameter.POSITIONAL_ONLY: + dyn_args.append(val) + else: + dyn_kwargs[name] = val + + return _dyn_args_fun_inspect(fun, sig, tuple(sorted(static_vals.items()))), dyn_args, dyn_kwargs + + +def _dyn_args_fun_inspect(fun, sig, static_vals_items): + if any(isinstance(v.val, core.Tracer) for _, v in static_vals_items): + return _dyn_args_fun_inspect_uncached(fun, sig, static_vals_items) + return _dyn_args_fun_inspect_cached(fun, sig, static_vals_items) + + +def _dyn_args_fun_inspect_uncached(fun, sig, static_vals_items): + static_map = {k: v.val for k, v in static_vals_items} + + def new_fun(*args, **kwargs): + arg_iter = iter(args) + final_args = [] + final_kwargs = {} + + for name, param in sig.parameters.items(): + if name in static_map: + val = static_map[name] + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + final_args.append(val) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + final_args.extend(val) + elif param.kind == inspect.Parameter.VAR_KEYWORD: + final_kwargs.update(val) + else: + final_kwargs[name] = val + elif param.kind == inspect.Parameter.VAR_KEYWORD: + final_kwargs.update(kwargs) + else: + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + final_args.append(next(arg_iter)) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + final_args.extend(list(arg_iter)) + else: + if name in kwargs: + final_kwargs[name] = kwargs[name] + + for k, v in static_map.items(): + if k not in final_kwargs and k not in sig.parameters: + final_kwargs[k] = v + + return fun(*final_args, **final_kwargs) + return new_fun + +_dyn_args_fun_inspect_cached = weakref_lru_cache(_dyn_args_fun_inspect_uncached) + +# Remove old _remat_static_args and helpers +# (The replace block replaces the old function completely) + + class WrapHashably: val: Any hash: int @@ -484,9 +606,9 @@ def _trace_to_jaxpr(fun: Callable, msg, = e.args if 'for checkpoint' in msg: msg += "\n\n" + ( - "Consider using the `static_argnums` parameter for `jax.remat` or " - "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " - "involving `static_argnums`:\n" + "Consider using the `static_argnums` or `static_argnames` parameter " + "for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` " + "docstring and its example involving `static_argnums`:\n" "https://docs.jax.dev/en/latest/_autosummary/jax.checkpoint.html" "\n") e.args = msg,