Fix/checkpoint static argnames #34077
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Summary of Changes
Hello @Indrapal-70, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the …
Code Review
This pull request successfully adds support for static_argnames to jax.checkpoint, aligning its interface with jax.jit. The use of inspect.signature for argument resolution is a good approach. However, the implementation of _remat_static_args and its helpers is very messy, with a lot of commented-out code and development notes that should be removed to improve maintainability. Additionally, the handling of static_argnames within **kwargs is incomplete and should be either implemented or documented as a limitation.
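The signature-based resolution mentioned above can be sketched in plain Python. This is a minimal illustration, not the code from the pull request: the helper name `resolve_static_names` and the example function `f` are hypothetical, and only the standard-library `inspect` module is used. It shows that binding the call to the signature yields the same set of static parameter names whether an argument arrives by position or by keyword.

```python
import inspect

def resolve_static_names(fun, static_argnums, static_argnames, args, kwargs):
    """Return the parameter names to treat as static (hypothetical helper)."""
    sig = inspect.signature(fun)
    bound = sig.bind(*args, **kwargs)   # raises TypeError on an invalid call
    bound.apply_defaults()              # fill in parameters left at their defaults
    names = list(bound.arguments)       # parameter names in signature order

    static = set(static_argnames)
    for i in static_argnums:
        if not -len(names) <= i < len(names):
            raise ValueError(f"static_argnums index {i} is out of bounds for {names}")
        static.add(names[i % len(names)])  # normalize negative indices
    return static

def f(x, scale, mode="fast"):
    return x * scale

# `scale` passed positionally, `mode` left at its default.
by_position = resolve_static_names(f, (1,), ("mode",), (1.0, 2.0), {})
# `scale` and `mode` both passed by keyword.
by_keyword = resolve_static_names(f, (1,), ("mode",), (1.0,), {"scale": 2.0, "mode": "slow"})
```

Note that `bound.arguments` collapses `*args` and `**kwargs` into one entry each, so an index in this sketch refers to a parameter in the signature rather than to a position in the call.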
| def _remat_static_args(fun, static_argnums, static_argnames, args, kwargs): | ||
| if not static_argnums and not static_argnames: | ||
| return fun, args, kwargs | ||
|
|
||
| static_argnums = static_argnums or () | ||
| static_argnames = static_argnames or () | ||
|
|
||
| sig = inspect.signature(fun) | ||
| bound = sig.bind(*args, **kwargs) | ||
| bound.apply_defaults() | ||
| arg_names = list(bound.arguments.keys()) | ||
|
|
||
| # Normalize static_argnums to be a set of indices | ||
| if isinstance(static_argnums, int): | ||
| static_argnums = (static_argnums,) | ||
| elif not (type(static_argnums) is tuple and | ||
| all(type(d) is int for d in static_argnums)): | ||
| raise TypeError("the `static_argnums` argument to `jax.checkpoint` / " | ||
| "`jax.remat` must be an int, tuple of ints or, bool, but " | ||
| f"got value {static_argnums}") | ||
|
|
||
| if not all(-len(args) <= d < len(args) for d in static_argnums): | ||
| raise ValueError("the `static_argnums` argument to `jax.checkpoint` / " | ||
| "`jax.remat` can only take integer values greater than or " | ||
| "equal to `-len(args)` and less than `len(args)`, but got " | ||
| f"{static_argnums}, while `len(args)` = {len(args)}") | ||
|
|
||
| if not static_argnums: | ||
| return fun, args | ||
| nargs = len(args) | ||
| static_argnums_ = frozenset(d % len(args) for d in static_argnums) | ||
| dyn_args, static_args = [], [] | ||
| for i, x in enumerate(args): | ||
| if i in static_argnums_: static_args.append(WrapHashably(x)) | ||
| else: dyn_args.append(x) | ||
| new_fun = _dyn_args_fun(fun, static_argnums_, tuple(static_args), nargs) | ||
| return new_fun, dyn_args | ||
| static_argnums_set = set() | ||
| for i in static_argnums: | ||
| if not -len(arg_names) <= i < len(arg_names): | ||
| raise ValueError(f"static_argnums index {i} is out of bounds for arguments {arg_names}") | ||
| static_argnums_set.add(i % len(arg_names)) | ||
|
|
||
| # Normalize static_argnames to be a set of indices | ||
| if isinstance(static_argnames, str): | ||
| static_argnames = (static_argnames,) | ||
| for name in static_argnames: | ||
| try: | ||
| idx = arg_names.index(name) | ||
| static_argnums_set.add(idx) | ||
| except ValueError: | ||
| # For **kwargs which are not in signature but passed, we can't easily handle them as static by name | ||
| # if they are part of **kwargs. JIT handles this by keeping them in separate dict? | ||
| # For now, let's assume static_argnames must refer to named arguments in signature OR | ||
| # if the function has **kwargs, we might need to handle it. | ||
| # Current api_util.resolve_argnums logic is more complex. | ||
| # Simplified approach: if name not in arg_names (explicit args), maybe its in kwargs? | ||
| if name in bound.kwargs.values(): # Wait, bound.arguments contains everything? | ||
| # bound.arguments contains mapping from arg name to value. | ||
| # if 'kwargs' is the name of **kwargs param, it holds a dict. | ||
| pass | ||
| raise ValueError(f"static_argname '{name}' not found in function signature or arguments.") | ||
|
|
||
| static_args = [] | ||
| dyn_args = [] | ||
| dyn_kwargs = {} | ||
|
|
||
| # Iterate over bound arguments and separate static/dynamic | ||
| # We basically reconstruct args/kwargs for the new function | ||
| # But the new function will receive dynamic args/kwargs. | ||
| # The issue is that we need to match the signature of the wrapper dynamic function | ||
| # to what tree_flatten expects? | ||
| # The wrapper effectively becomes: | ||
| # def new_fun(*dyn_args_list, **dyn_kwargs_dict): ... | ||
|
|
||
| # Let's collect static values to close over. | ||
| # And create a mapping for dynamic values. | ||
|
|
||
| static_refs = {} # index -> value | ||
| dyn_vals = [] | ||
|
|
||
| for i, (name, val) in enumerate(bound.arguments.items()): | ||
| # Handle potentially var_positional (*args) or var_keyword (**kwargs) | ||
| param = sig.parameters.get(name) | ||
| if param and param.kind == inspect.Parameter.VAR_POSITIONAL: | ||
| # val is a tuple. | ||
| # We probably don't support partial static-ness inside *args yet with this simple logic, | ||
| # unless we explode it? jax.jit usually treats *args as a unit or explodes it? | ||
| # Let's assume standard args for now. | ||
| # If the user passed *args, bound.arguments has 'args': (val1, val2...) | ||
| # This complicates numerical indexing. | ||
| pass | ||
|
|
||
| if i in static_argnums_set: | ||
| static_refs[i] = WrapHashably(val) | ||
| else: | ||
| dyn_vals.append(val) | ||
|
|
||
| # Check: does checkpoint/trace_to_jaxpr support *args/**kwargs in the partial? | ||
| # trace_to_jaxpr calls flatten_fun(lu.wrap_init(fun, ...)) | ||
| # flatten_fun uses tree_flatten((args, kwargs)). | ||
| # So we just need to return fun_, dyn_args, dyn_kwargs such that | ||
| # fun_(*dyn_args, **dyn_kwargs) calls fun(*args, **kwargs). | ||
|
|
||
| # BUT: trace_to_jaxpr expects the function it traces to look like it takes the pytree leaves. | ||
| # The `_dyn_args_fun` logic in unmodified code constructed a function that takes *dyn_args only. | ||
| # And `fun_remat` flattened `(args, kwargs)`. | ||
|
|
||
| # If we allow arbitrary args/kwargs structure, `tree_flatten((dyn_args, dyn_kwargs))` | ||
| # will produce leaves. | ||
|
|
||
| # A robust way is: identifying which args (by index in the flattened list or by name) are static. This is hard. | ||
|
|
||
| # Let's look at `_remat_static_argnums` logic again. | ||
| # It takes `args` (positional). It filters `static_argnums` out. | ||
| # It creates `new_fun` that takes `*dyn_args`. | ||
| # `fun_remat` then does `tree_flatten((dyn_args, kwargs))`? No. | ||
| # `fun_remat` does `fun_, args = _remat_static_argnums(fun, static_argnums, args)` (so args is now only dynamic positionals) | ||
| # Then `args_flat, in_tree = tree_flatten((args, kwargs))`. | ||
| # So all original kwargs are treated as dynamic. | ||
|
|
||
| # To support static_argnames (which might be in kwargs) or static_argnums referring to kwargs (if user passed them as kwargs): | ||
| # We need to process both args and kwargs, extract statics, and leave the rest as dynamic. | ||
|
|
||
| # Let's stick to the behavior: `fun_` will take `(*dyn_pos_args, **dyn_kwargs)`. | ||
| # `fun_remat` will call `tree_flatten((dyn_pos_args, dyn_kwargs))`. | ||
|
|
||
| # Re-implementing logic with inspect: | ||
|
|
||
| static_val_map = {} # name -> value (Hashable) | ||
| dyn_args_list = [] | ||
| dyn_kwargs_dict = {} | ||
|
|
||
| # We need to map static_argnums to names if possible or just use indices in the bound arguments? | ||
| # `bound.arguments` is an OrderedDict. | ||
|
|
||
| arg_names = list(bound.arguments.keys()) | ||
|
|
||
| # Determine which names are static | ||
| static_names = set(static_argnames) | ||
| for i in static_argnums: | ||
| if 0 <= i < len(arg_names): | ||
| static_names.add(arg_names[i]) | ||
| elif -len(arg_names) <= i < 0: | ||
| static_names.add(arg_names[len(arg_names) + i]) | ||
|
|
||
| for name, val in bound.arguments.items(): | ||
| # Note: bound.arguments collapses *args and **kwargs into single entries if they exist. | ||
| # E.g. def f(a, *args): ... call f(1, 2, 3) -> bound.arguments={'a':1, 'args':(2,3)} | ||
| # If user said static_argnums=1, they probably meant the first element of *args? | ||
| # jax.jit doesn't support static_argnums inside *args easily unless you use explicit naming? | ||
| # Actually jax.jit documentation says "arguments that are not array-like ... must be marked as static". | ||
|
|
||
| if name in static_names: | ||
| static_val_map[name] = WrapHashably(val) | ||
| else: | ||
| # It's dynamic. | ||
| # We need to know if it was passed as pos or kwarg to reconstruct correctly? | ||
| # bound.arguments doesn't preserve passed-as-kwarg distinction fully if apply_defaults is used? | ||
| # `bound.arguments` reflects the values of the parameters. | ||
| # If we use `fun(**bound.arguments)`, it should work for most cases. | ||
| if name == 'kwargs' and sig.parameters[name].kind == inspect.Parameter.VAR_KEYWORD: | ||
| # This is the **kwargs dict. | ||
| # If static_argnames refer to keys inside here? | ||
| # For now, let's assume we don't peer inside **kwargs for staticness unless needed. | ||
| if isinstance(val, dict): | ||
| # We can allow capturing partial kwargs if we want, but simpler is: | ||
| # if 'kwargs' is not static, then all of it is dynamic. | ||
| dyn_kwargs_dict.update(val) | ||
| elif name == 'args' and sig.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL: | ||
| # This is *args tuple. | ||
| dyn_args_list.extend(val) | ||
| else: | ||
| # Regular argument. | ||
| # We can just put it in dyn_kwargs_dict for simplicity? | ||
| # Or keep it positionally if it was positional? | ||
| # Since bound.arguments is ordered, we can try to respect that? | ||
| # But we can just pass everything as kwargs to the wrapper if we are lazy, | ||
| # UNLESS there are positional-only args. | ||
|
|
||
| param = sig.parameters[name] | ||
| if param.kind == inspect.Parameter.POSITIONAL_ONLY: | ||
| dyn_args_list.append(val) | ||
| else: | ||
| dyn_kwargs_dict[name] = val | ||
|
|
||
| # Wait, mixing dyn_args_list and dyn_kwargs_dict might be tricky if preserving order matters for some reason | ||
| # but binding should handle it. | ||
|
|
||
| # Let's refine the "reconstruct" part. | ||
| # The `new_fun` will take mixed args. | ||
| # Actually, `trace_to_jaxpr` flattens `(args, kwargs)`. | ||
| # So we just need `new_fun` to accept whatever `tree_unflatten` produces. | ||
| # The `in_tree` will be created from `(dyn_args_list, dyn_kwargs_dict)`. | ||
|
|
||
| # So `new_fun` will be called with `(*dyn_args_list, **dyn_kwargs_dict)`. | ||
| # Inside `new_fun`, we need to merge `static_val_map` and these dynamic args to call `fun`. | ||
|
|
||
| return _dyn_args_fun_inspect(fun, sig, tuple(sorted(static_val_map.items()))), dyn_args_list, dyn_kwargs_dict |
| @weakref_lru_cache | ||
| def _dyn_args_fun_inspect(fun, sig, static_val_map): # static_val_map is frozendict/tuple | ||
| def new_fun(*args, **kwargs): | ||
| # args are the dynamic positional arguments (corresponding to POSITIONAL_ONLY or VAR_POSITIONAL) | ||
| # kwargs are the dynamic keyword arguments | ||
|
|
||
| # Merging is tricky without exact mapping meta-data. | ||
| # But wait, we simplified: we put everything that CAN be kwarg into kwargs in `_remat_static_args`. | ||
| # So `args` only contains POSITIONAL_ONLY and exploded VAR_POSITIONAL? | ||
|
|
||
| # If we have `def f(a, /, b)`: | ||
| # If `a` is dynamic, it's in `args[0]`. | ||
| # If `b` is dynamic, it's in `kwargs['b']`. | ||
|
|
||
| # If `a` is static, `args` is empty (if no other pos only). | ||
|
|
||
| # We need to construct the call. | ||
| # We can build a dict of arguments and then ordered args list. | ||
|
|
||
| # final_args = [] | ||
| # final_kwargs = {} | ||
|
|
||
| # Iterate sig parameters. | ||
| # If name in static_val_map: use that. | ||
| # Else if kind is POSITIONAL_ONLY: pop from args? | ||
| # Else: pop from kwargs? | ||
|
|
||
| # This requires tracking state of `args` consumption. | ||
|
|
||
| arg_iter = iter(args) | ||
| ba_args = [] | ||
| ba_kwargs = {} | ||
|
|
||
| for name, param in sig.parameters.items(): | ||
| if name in static_val_map: | ||
| val = static_val_map[name].val | ||
| if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): | ||
| ba_args.append(val) # Wait, bind/apply_defaults puts everything in structure. | ||
| # But we are calling `fun`. `fun` expects args/kwargs matching signature. | ||
|
|
||
| # Actually, simpler: just prepare arguments for `fun(*ba_args, **ba_kwargs)`. | ||
|
|
||
| if param.kind == inspect.Parameter.POSITIONAL_ONLY: | ||
| ba_args.append(val) | ||
| elif param.kind == inspect.Parameter.VAR_POSITIONAL: | ||
| ba_args.extend(val) # static *args? | ||
| elif param.kind == inspect.Parameter.VAR_KEYWORD: | ||
| ba_kwargs.update(val) # static **kwargs? | ||
| else: | ||
| # POSITIONAL_OR_KEYWORD or KEYWORD_ONLY | ||
| # We can pass as kwarg usually, unless it's already satisfied? | ||
| # If we pass as kwarg, it works. | ||
| ba_kwargs[name] = val | ||
|
|
||
| else: | ||
| # Dynamic. Look in args/kwargs. | ||
| if param.kind == inspect.Parameter.POSITIONAL_ONLY: | ||
| ba_args.append(next(arg_iter)) | ||
| elif param.kind == inspect.Parameter.VAR_POSITIONAL: | ||
| # Consumes rest of args? | ||
| ba_args.extend(list(arg_iter)) | ||
| elif param.kind == inspect.Parameter.VAR_KEYWORD: | ||
| ba_kwargs.update(kwargs) # The rest of kwargs | ||
| else: | ||
| # POSITIONAL_OR_KEYWORD or KEYWORD_ONLY | ||
| if name in kwargs: | ||
| ba_kwargs[name] = kwargs[name] | ||
| else: | ||
| # Should have been provided? | ||
| pass | ||
|
|
||
| return fun(*ba_args, **ba_kwargs) | ||
| return new_fun |
This function definition appears to be an incomplete and buggy draft that is immediately overwritten on line 749. It should be removed. For instance, it incorrectly handles POSITIONAL_OR_KEYWORD static arguments, which could lead to them being passed twice. The actual implementation is in _dyn_args_fun_inspect_uncached.
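The double-passing problem described in this comment can be reproduced without JAX. The sketch below is an illustration under assumptions: `f` and `build_call` are hypothetical names, and the `ba_args` / `ba_kwargs` values mimic what the draft would build for a static POSITIONAL_OR_KEYWORD parameter that is first appended positionally and then also set by keyword.

```python
import inspect

def f(a, b):
    return a + b

# What the draft effectively builds when `a` is static and POSITIONAL_OR_KEYWORD:
ba_args = [1]                  # first branch appends the value positionally
ba_kwargs = {"a": 1, "b": 2}   # the later branch also sets it by keyword

error_message = None
try:
    f(*ba_args, **ba_kwargs)
except TypeError as exc:
    error_message = str(exc)   # Python rejects the duplicated argument

def build_call(fun, values):
    """Place every parameter exactly once (hypothetical helper).

    Positional-only parameters go by position, everything else by keyword.
    *args and **kwargs parameters are not handled in this sketch.
    """
    args, kwargs = [], {}
    for name, param in inspect.signature(fun).parameters.items():
        if param.kind is inspect.Parameter.POSITIONAL_ONLY:
            args.append(values[name])
        else:
            kwargs[name] = values[name]
    return fun(*args, **kwargs)

result = build_call(f, {"a": 1, "b": 2})
```

Using a single `if`/`elif` chain per parameter, as in `build_call`, is one way to guarantee that no value is placed twice.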
| for name in static_argnames: | ||
| try: | ||
| idx = arg_names.index(name) | ||
| static_argnums_set.add(idx) | ||
| except ValueError: | ||
| # For **kwargs which are not in signature but passed, we can't easily handle them as static by name | ||
| # if they are part of **kwargs. JIT handles this by keeping them in separate dict? | ||
| # For now, let's assume static_argnames must refer to named arguments in signature OR | ||
| # if the function has **kwargs, we might need to handle it. | ||
| # Current api_util.resolve_argnums logic is more complex. | ||
| # Simplified approach: if name not in arg_names (explicit args), maybe its in kwargs? | ||
| if name in bound.kwargs.values(): # Wait, bound.arguments contains everything? | ||
| # bound.arguments contains mapping from arg name to value. | ||
| # if 'kwargs' is the name of **kwargs param, it holds a dict. | ||
| pass | ||
| raise ValueError(f"static_argname '{name}' not found in function signature or arguments.") |
The handling of static_argnames that are part of **kwargs seems incomplete. The except ValueError block is entered when an argument name is not found in the function's explicit signature, which is the case for arguments passed via **kwargs.
The current logic:
    except ValueError:
        # ... (comments)
        if name in bound.kwargs.values():  # This check is incorrect and has no effect
            pass
        raise ValueError(...)
This always raises a ValueError, so static_argnames passed through **kwargs are not supported. The check if name in bound.kwargs.values() is also incorrect: inspect.BoundArguments.kwargs is a dict mapping keyword names to values, so .values() yields argument values rather than names, and the branch only executes pass, so the outcome is the same either way.
This seems to be a known limitation, as the test test_static_argnames_in_var_kwargs is empty.
To align with jax.jit's behavior, this should be implemented. If not, this limitation should be clearly documented in the checkpoint docstring.
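One possible way to support static names inside `**kwargs` is to look into the VAR_KEYWORD dictionary while splitting the bound arguments. The following is a minimal sketch under assumptions, not the behaviour of `jax.jit` or of this pull request: `split_static` and `g` are hypothetical names, hashability wrapping is omitted, and collisions between a positional-only parameter name and a `**kwargs` key are ignored.

```python
import inspect

def split_static(fun, static_argnames, args, kwargs):
    """Split a call into static and dynamic arguments by name (hypothetical helper)."""
    sig = inspect.signature(fun)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    static, dynamic = {}, {}
    remaining = set(static_argnames)   # names not matched yet
    for name, val in bound.arguments.items():
        kind = sig.parameters[name].kind
        if kind is inspect.Parameter.VAR_KEYWORD:
            # `val` is the dict of extra keyword arguments: check its keys too.
            for key, item in val.items():
                if key in remaining:
                    static[key] = item
                    remaining.discard(key)
                else:
                    dynamic[key] = item
        elif name in remaining:
            static[name] = val
            remaining.discard(name)
        else:
            dynamic[name] = val

    if remaining:
        raise ValueError(f"static_argnames {sorted(remaining)} not found in the call")
    return static, dynamic

def g(x, **options):
    return x

static, dynamic = split_static(g, ("mode",), (1.0,), {"mode": "fast", "eps": 1e-3})
```

Here `mode` is not part of the explicit signature of `g`, yet it ends up in `static`, while `x` and `eps` stay dynamic. A name that appears neither in the signature nor in the passed keywords still raises a ValueError.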
This update aligns jax.checkpoint with jax.jit by allowing arguments to be marked static by name.
Key updates:
static_argnames Support: Enables @jax.checkpoint(static_argnames=(...)) for clearer intent.
Signature Resolution: Uses inspect.signature to correctly identify static parameters whether passed via position or keyword.
Decorator Factory: Refactored checkpoint to support optional fun arguments for cleaner decorator usage.
Testing: Added tests/checkpoint_args_test.py to verify mixed-argument handling and prevent regressions.
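The decorator-factory refactor listed above follows a common Python pattern, sketched here in plain Python. This is an assumption-laden illustration rather than JAX's implementation: `checkpoint_like` is a hypothetical stand-in that only records the static names and forwards the call, with no rematerialization.

```python
import functools

def checkpoint_like(fun=None, *, static_argnames=()):
    """Decorator usable both bare and with keyword options (hypothetical stand-in)."""
    if fun is None:
        # Called as @checkpoint_like(static_argnames=...): return the real decorator.
        return functools.partial(checkpoint_like, static_argnames=static_argnames)

    if isinstance(static_argnames, str):
        static_argnames = (static_argnames,)   # accept a single name

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        # A real implementation would split static and dynamic arguments here.
        return fun(*args, **kwargs)

    wrapped.static_argnames = tuple(static_argnames)
    return wrapped

@checkpoint_like
def plain(x):
    return x + 1

@checkpoint_like(static_argnames="mode")
def configured(x, mode="fast"):
    return (x, mode)
```

Making the options keyword-only keeps the bare form `@checkpoint_like` unambiguous, because the only positional argument the decorator can receive is the function itself.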