@Indrapal-70

This update aligns jax.checkpoint with jax.jit by allowing arguments to be marked static by name.

Key updates:

  • static_argnames Support: Enables @jax.checkpoint(static_argnames=(...)) for clearer intent.
  • Signature Resolution: Uses inspect.signature to correctly identify static parameters whether passed via position or keyword.
  • Decorator Factory: Refactored checkpoint to support optional fun arguments for cleaner decorator usage.
  • Testing: Added tests/checkpoint_args_test.py to verify mixed-argument handling and prevent regressions.
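
The signature-resolution idea can be sketched with the standard library alone. This is a minimal illustration, not the PR's code: `resolve_static_names` and `layer` are hypothetical names, and the sketch only covers parameters named explicitly in the signature.

```python
import inspect


def resolve_static_names(fun, static_argnames, args, kwargs):
    """Return {name: value} for the static parameters, however they were passed."""
    bound = inspect.signature(fun).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults so omitted parameters resolve too
    missing = [name for name in static_argnames if name not in bound.arguments]
    if missing:
        raise ValueError(f"static_argnames not found in signature: {missing}")
    return {name: bound.arguments[name] for name in static_argnames}


def layer(x, width, activation="relu"):  # hypothetical function under test
    return x


# The same parameter resolves identically by position and by keyword.
by_position = resolve_static_names(layer, ("width",), (1.0, 8), {})
by_keyword = resolve_static_names(layer, ("width",), (1.0,), {"width": 8})
```

Binding first and reading `bound.arguments` afterwards means the caller's choice of positional or keyword passing no longer matters when deciding which values are static.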

@google-cla

google-cla bot commented Dec 21, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Indrapal-70, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the jax.checkpoint and jax.remat functions by adding support for static_argnames. This change brings consistency with jax.jit's API, allowing users to explicitly mark function arguments as static using their names, which is crucial for preventing ConcretizationTypeErrors during tracing. The implementation involves a refactoring to enable a decorator factory pattern and leverages Python's inspect module for accurate argument binding and separation.

Highlights

  • Static Argument Naming: Introduced static_argnames to jax.checkpoint and jax.remat, allowing users to specify arguments to be treated as static by their name, similar to jax.jit.
  • Decorator Factory Pattern: Refactored jax.checkpoint and jax.remat to support being used as decorator factories (e.g., @jax.checkpoint(static_argnames=('foo',))), improving API flexibility.
  • Robust Argument Resolution: Implemented inspect.signature to correctly identify and separate static and dynamic parameters, whether they are passed positionally or as keywords.
  • Improved Error Messages: Updated error messages for ConcretizationTypeErrors to suggest using static_argnames as a potential solution.
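
A decorator that works both bare and as a factory is commonly written with an optional first argument. The sketch below shows that general pattern only; `checkpoint_like` is a hypothetical stand-in and does not reproduce how jax.checkpoint is implemented in this PR.

```python
import functools


def checkpoint_like(fun=None, *, static_argnames=()):
    """Hypothetical decorator usable as @checkpoint_like or @checkpoint_like(...)."""
    if fun is None:
        # Called with keyword options only: return the actual decorator.
        return functools.partial(checkpoint_like, static_argnames=static_argnames)

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)

    # Recorded purely for illustration; a real implementation would act on it.
    wrapped.static_argnames = tuple(static_argnames)
    return wrapped


@checkpoint_like
def f(x):
    return x + 1


@checkpoint_like(static_argnames=("n",))
def g(x, n):
    return x * n
```

Making the options keyword-only keeps the two call forms unambiguous: a single positional argument is always the function being decorated.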

@Indrapal-70 force-pushed the fix/checkpoint-static-argnames branch from 1d4ea80 to 3da5d66 on December 21, 2025 09:28
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully adds support for static_argnames to jax.checkpoint, aligning its interface with jax.jit. The use of inspect.signature for argument resolution is a good approach. However, the implementation of _remat_static_args and its helpers is very messy, with a lot of commented-out code and development notes that should be removed to improve maintainability. Additionally, the handling of static_argnames within **kwargs is incomplete and should be either implemented or documented as a limitation.

Comment on lines +418 to +594
def _remat_static_args(fun, static_argnums, static_argnames, args, kwargs):
  if not static_argnums and not static_argnames:
    return fun, args, kwargs

  static_argnums = static_argnums or ()
  static_argnames = static_argnames or ()

  sig = inspect.signature(fun)
  bound = sig.bind(*args, **kwargs)
  bound.apply_defaults()
  arg_names = list(bound.arguments.keys())

  # Normalize static_argnums to be a set of indices
  if isinstance(static_argnums, int):
    static_argnums = (static_argnums,)
  elif not (type(static_argnums) is tuple and
            all(type(d) is int for d in static_argnums)):
    raise TypeError("the `static_argnums` argument to `jax.checkpoint` / "
                    "`jax.remat` must be an int, tuple of ints or, bool, but "
                    f"got value {static_argnums}")

  if not all(-len(args) <= d < len(args) for d in static_argnums):
    raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
                     "`jax.remat` can only take integer values greater than or "
                     "equal to `-len(args)` and less than `len(args)`, but got "
                     f"{static_argnums}, while `len(args)` = {len(args)}")

  if not static_argnums:
    return fun, args
  nargs = len(args)
  static_argnums_ = frozenset(d % len(args) for d in static_argnums)
  dyn_args, static_args = [], []
  for i, x in enumerate(args):
    if i in static_argnums_: static_args.append(WrapHashably(x))
    else: dyn_args.append(x)
  new_fun = _dyn_args_fun(fun, static_argnums_, tuple(static_args), nargs)
  return new_fun, dyn_args
  static_argnums_set = set()
  for i in static_argnums:
    if not -len(arg_names) <= i < len(arg_names):
      raise ValueError(f"static_argnums index {i} is out of bounds for arguments {arg_names}")
    static_argnums_set.add(i % len(arg_names))

  # Normalize static_argnames to be a set of indices
  if isinstance(static_argnames, str):
    static_argnames = (static_argnames,)
  for name in static_argnames:
    try:
      idx = arg_names.index(name)
      static_argnums_set.add(idx)
    except ValueError:
      # For **kwargs which are not in signature but passed, we can't easily handle them as static by name
      # if they are part of **kwargs. JIT handles this by keeping them in separate dict?
      # For now, let's assume static_argnames must refer to named arguments in signature OR
      # if the function has **kwargs, we might need to handle it.
      # Current api_util.resolve_argnums logic is more complex.
      # Simplified approach: if name not in arg_names (explicit args), maybe its in kwargs?
      if name in bound.kwargs.values(): # Wait, bound.arguments contains everything?
        # bound.arguments contains mapping from arg name to value.
        # if 'kwargs' is the name of **kwargs param, it holds a dict.
        pass
      raise ValueError(f"static_argname '{name}' not found in function signature or arguments.")

  static_args = []
  dyn_args = []
  dyn_kwargs = {}

  # Iterate over bound arguments and separate static/dynamic
  # We basically reconstruct args/kwargs for the new function
  # But the new function will receive dynamic args/kwargs.
  # The issue is that we need to match the signature of the wrapper dynamic function
  # to what tree_flatten expects?
  # The wrapper effectively becomes:
  # def new_fun(*dyn_args_list, **dyn_kwargs_dict): ...

  # Let's collect static values to close over.
  # And create a mapping for dynamic values.

  static_refs = {} # index -> value
  dyn_vals = []

  for i, (name, val) in enumerate(bound.arguments.items()):
    # Handle potentially var_positional (*args) or var_keyword (**kwargs)
    param = sig.parameters.get(name)
    if param and param.kind == inspect.Parameter.VAR_POSITIONAL:
      # val is a tuple.
      # We probably don't support partial static-ness inside *args yet with this simple logic,
      # unless we explode it? jax.jit usually treats *args as a unit or explodes it?
      # Let's assume standard args for now.
      # If the user passed *args, bound.arguments has 'args': (val1, val2...)
      # This complicates numerical indexing.
      pass

    if i in static_argnums_set:
      static_refs[i] = WrapHashably(val)
    else:
      dyn_vals.append(val)

  # Check: does checkpoint/trace_to_jaxpr support *args/**kwargs in the partial?
  # trace_to_jaxpr calls flatten_fun(lu.wrap_init(fun, ...))
  # flatten_fun uses tree_flatten((args, kwargs)).
  # So we just need to return fun_, dyn_args, dyn_kwargs such that
  # fun_(*dyn_args, **dyn_kwargs) calls fun(*args, **kwargs).

  # BUT: trace_to_jaxpr expects the function it traces to look like it takes the pytree leaves.
  # The `_dyn_args_fun` logic in unmodified code constructed a function that takes *dyn_args only.
  # And `fun_remat` flattened `(args, kwargs)`.

  # If we allow arbitrary args/kwargs structure, `tree_flatten((dyn_args, dyn_kwargs))`
  # will produce leaves.

  # A robust way is: identifying which args (by index in the flattened list or by name) are static. This is hard.

  # Let's look at `_remat_static_argnums` logic again.
  # It takes `args` (positional). It filters `static_argnums` out.
  # It creates `new_fun` that takes `*dyn_args`.
  # `fun_remat` then does `tree_flatten((dyn_args, kwargs))`? No.
  # `fun_remat` does `fun_, args = _remat_static_argnums(fun, static_argnums, args)` (so args is now only dynamic positionals)
  # Then `args_flat, in_tree = tree_flatten((args, kwargs))`.
  # So all original kwargs are treated as dynamic.

  # To support static_argnames (which might be in kwargs) or static_argnums referring to kwargs (if user passed them as kwargs):
  # We need to process both args and kwargs, extract statics, and leave the rest as dynamic.

  # Let's stick to the behavior: `fun_` will take `(*dyn_pos_args, **dyn_kwargs)`.
  # `fun_remat` will call `tree_flatten((dyn_pos_args, dyn_kwargs))`.

  # Re-implementing logic with inspect:

  static_val_map = {} # name -> value (Hashable)
  dyn_args_list = []
  dyn_kwargs_dict = {}

  # We need to map static_argnums to names if possible or just use indices in the bound arguments?
  # `bound.arguments` is an OrderedDict.

  arg_names = list(bound.arguments.keys())

  # Determine which names are static
  static_names = set(static_argnames)
  for i in static_argnums:
    if 0 <= i < len(arg_names):
      static_names.add(arg_names[i])
    elif -len(arg_names) <= i < 0:
      static_names.add(arg_names[len(arg_names) + i])

  for name, val in bound.arguments.items():
    # Note: bound.arguments collapses *args and **kwargs into single entries if they exist.
    # E.g. def f(a, *args): ... call f(1, 2, 3) -> bound.arguments={'a':1, 'args':(2,3)}
    # If user said static_argnums=1, they probably meant the first element of *args?
    # jax.jit doesn't support static_argnums inside *args easily unless you use explicit naming?
    # Actually jax.jit documentation says "arguments that are not array-like ... must be marked as static".

    if name in static_names:
      static_val_map[name] = WrapHashably(val)
    else:
      # It's dynamic.
      # We need to know if it was passed as pos or kwarg to reconstruct correctly?
      # bound.arguments doesn't preserve passed-as-kwarg distinction fully if apply_defaults is used?
      # `bound.arguments` reflects the values of the parameters.
      # If we use `fun(**bound.arguments)`, it should work for most cases.
      if name == 'kwargs' and sig.parameters[name].kind == inspect.Parameter.VAR_KEYWORD:
        # This is the **kwargs dict.
        # If static_argnames refer to keys inside here?
        # For now, let's assume we don't peer inside **kwargs for staticness unless needed.
        if isinstance(val, dict):
          # We can allow capturing partial kwargs if we want, but simpler is:
          # if 'kwargs' is not static, then all of it is dynamic.
          dyn_kwargs_dict.update(val)
      elif name == 'args' and sig.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL:
        # This is *args tuple.
        dyn_args_list.extend(val)
      else:
        # Regular argument.
        # We can just put it in dyn_kwargs_dict for simplicity?
        # Or keep it positionally if it was positional?
        # Since bound.arguments is ordered, we can try to respect that?
        # But we can just pass everything as kwargs to the wrapper if we are lazy,
        # UNLESS there are positional-only args.

        param = sig.parameters[name]
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
          dyn_args_list.append(val)
        else:
          dyn_kwargs_dict[name] = val

  # Wait, mixing dyn_args_list and dyn_kwargs_dict might be tricky if preserving order matters for some reason
  # but binding should handle it.

  # Let's refine the "reconstruct" part.
  # The `new_fun` will take mixed args.
  # Actually, `trace_to_jaxpr` flattens `(args, kwargs)`.
  # So we just need `new_fun` to accept whatever `tree_unflatten` produces.
  # The `in_tree` will be created from `(dyn_args_list, dyn_kwargs_dict)`.

  # So `new_fun` will be called with `(*dyn_args_list, **dyn_kwargs_dict)`.
  # Inside `new_fun`, we need to merge `static_val_map` and these dynamic args to call `fun`.

  return _dyn_args_fun_inspect(fun, sig, tuple(sorted(static_val_map.items()))), dyn_args_list, dyn_kwargs_dict

high

This function is very difficult to read due to a large amount of commented-out code and development notes. Please remove these to improve readability and maintainability. For example, lines 459-522 contain large blocks of commented-out logic and implementation sketches.
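
For comparison, the core split-and-reassemble step can be quite short once the notes are stripped away. The following is only a sketch under a stated simplifying assumption (every parameter can be passed by keyword, so positional-only, *args, and **kwargs parameters are rejected); `split_static` and `scale` are hypothetical names, and the sketch omits the hashable wrapping and caching that the PR's code performs.

```python
import inspect


def split_static(fun, static_names, args, kwargs):
    """Close over static values and return (new_fun, dynamic_kwargs)."""
    sig = inspect.signature(fun)
    allowed = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
               inspect.Parameter.KEYWORD_ONLY)
    for param in sig.parameters.values():
        if param.kind not in allowed:
            raise NotImplementedError(f"unsupported parameter kind: {param}")

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    static = {k: v for k, v in bound.arguments.items() if k in static_names}
    dynamic = {k: v for k, v in bound.arguments.items() if k not in static_names}

    def new_fun(**dyn_kwargs):
        # Each parameter is supplied exactly once, always by keyword.
        return fun(**static, **dyn_kwargs)

    return new_fun, dynamic


def scale(x, factor, offset=0):  # hypothetical function under test
    return x * factor + offset


new_fun, dyn = split_static(scale, {"factor"}, (2, 3), {})
result = new_fun(**dyn)
```

Passing everything by keyword sidesteps positional bookkeeping entirely, at the cost of not covering the parameter kinds excluded above.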

Comment on lines +597 to +669
@weakref_lru_cache
def _dyn_args_fun_inspect(fun, sig, static_val_map): # static_val_map is frozendict/tuple
  def new_fun(*args, **kwargs):
    # args are the dynamic positional arguments (corresponding to POSITIONAL_ONLY or VAR_POSITIONAL)
    # kwargs are the dynamic keyword arguments

    # Merging is tricky without exact mapping meta-data.
    # But wait, we simplified: we put everything that CAN be kwarg into kwargs in `_remat_static_args`.
    # So `args` only contains POSITIONAL_ONLY and exploded VAR_POSITIONAL?

    # If we have `def f(a, /, b)`:
    # If `a` is dynamic, it's in `args[0]`.
    # If `b` is dynamic, it's in `kwargs['b']`.

    # If `a` is static, `args` is empty (if no other pos only).

    # We need to construct the call.
    # We can build a dict of arguments and then ordered args list.

    # final_args = []
    # final_kwargs = {}

    # Iterate sig parameters.
    # If name in static_val_map: use that.
    # Else if kind is POSITIONAL_ONLY: pop from args?
    # Else: pop from kwargs?

    # This requires tracking state of `args` consumption.

    arg_iter = iter(args)
    ba_args = []
    ba_kwargs = {}

    for name, param in sig.parameters.items():
      if name in static_val_map:
        val = static_val_map[name].val
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
          ba_args.append(val) # Wait, bind/apply_defaults puts everything in structure.
          # But we are calling `fun`. `fun` expects args/kwargs matching signature.

        # Actually, simpler: just prepare arguments for `fun(*ba_args, **ba_kwargs)`.

        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
          ba_args.append(val)
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
          ba_args.extend(val) # static *args?
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
          ba_kwargs.update(val) # static **kwargs?
        else:
          # POSITIONAL_OR_KEYWORD or KEYWORD_ONLY
          # We can pass as kwarg usually, unless it's already satisfied?
          # If we pass as kwarg, it works.
          ba_kwargs[name] = val

      else:
        # Dynamic. Look in args/kwargs.
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
          ba_args.append(next(arg_iter))
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
          # Consumes rest of args?
          ba_args.extend(list(arg_iter))
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
          ba_kwargs.update(kwargs) # The rest of kwargs
        else:
          # POSITIONAL_OR_KEYWORD or KEYWORD_ONLY
          if name in kwargs:
            ba_kwargs[name] = kwargs[name]
          else:
            # Should have been provided?
            pass

    return fun(*ba_args, **ba_kwargs)
  return new_fun

high

This function definition appears to be an incomplete and buggy draft that is immediately overwritten on line 749. It should be removed. For instance, it incorrectly handles POSITIONAL_OR_KEYWORD static arguments, which could lead to them being passed twice. The actual implementation is in _dyn_args_fun_inspect_uncached.
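
The double-passing failure mode is easy to reproduce in isolation. The sketch below does not run the PR's code; it hand-builds the `ba_args` and `ba_kwargs` that the quoted draft appears to produce for a static POSITIONAL_OR_KEYWORD parameter (one positional append followed by a keyword assignment), using a hypothetical two-argument function `f`.

```python
def f(a, b):  # hypothetical target function
    return a + b


# State after processing static parameter `a` the way the draft seems to:
# the value is appended positionally and then also stored under its name.
ba_args = [1]
ba_kwargs = {"a": 1, "b": 2}

error_message = None
try:
    f(*ba_args, **ba_kwargs)
except TypeError as err:
    # Python rejects a parameter that is supplied both ways.
    error_message = str(err)

# Supplying each parameter exactly once, by keyword, avoids the conflict.
result = f(**{"a": 1, "b": 2})
```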

Comment on lines +442 to +457
for name in static_argnames:
  try:
    idx = arg_names.index(name)
    static_argnums_set.add(idx)
  except ValueError:
    # For **kwargs which are not in signature but passed, we can't easily handle them as static by name
    # if they are part of **kwargs. JIT handles this by keeping them in separate dict?
    # For now, let's assume static_argnames must refer to named arguments in signature OR
    # if the function has **kwargs, we might need to handle it.
    # Current api_util.resolve_argnums logic is more complex.
    # Simplified approach: if name not in arg_names (explicit args), maybe its in kwargs?
    if name in bound.kwargs.values(): # Wait, bound.arguments contains everything?
      # bound.arguments contains mapping from arg name to value.
      # if 'kwargs' is the name of **kwargs param, it holds a dict.
      pass
    raise ValueError(f"static_argname '{name}' not found in function signature or arguments.")

medium

The handling of static_argnames that are part of **kwargs seems incomplete. The except ValueError block is entered when an argument name is not found in the function's explicit signature, which is the case for arguments passed via **kwargs.

The current logic:

except ValueError:
   # ... (comments)
   if name in bound.kwargs.values(): # This check is incorrect and has no effect
       pass
   raise ValueError(...)

This will always raise a ValueError, meaning static_argnames in **kwargs are not supported. The check if name in bound.kwargs.values() is also incorrect: bound.kwargs is a dict mapping keyword argument names to their values, so .values() compares the name against argument values rather than keys, and the branch only executes pass in any case.

This seems to be a known limitation, as the test test_static_argnames_in_var_kwargs is empty.

To align with jax.jit's behavior, this should be implemented. If not, this limitation should be clearly documented in the checkpoint docstring.
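
One possible way to look inside **kwargs is sketched below. It illustrates the lookup only, does not mirror jax.jit's implementation, and uses hypothetical names (`find_static_values`, `f`, `options`). It relies on the fact that `bound.arguments` stores the whole **kwargs dict under the name of the variadic parameter.

```python
import inspect


def find_static_values(fun, static_argnames, args, kwargs):
    """Resolve static names from explicit parameters or from inside **kwargs."""
    sig = inspect.signature(fun)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    # Name of the **kwargs parameter, if the function declares one.
    var_kw = next((p.name for p in sig.parameters.values()
                   if p.kind is inspect.Parameter.VAR_KEYWORD), None)
    extra = bound.arguments.get(var_kw, {}) if var_kw is not None else {}

    found = {}
    for name in static_argnames:
        if name in bound.arguments:   # explicitly named parameter
            found[name] = bound.arguments[name]
        elif name in extra:           # key passed through **kwargs
            found[name] = extra[name]
        else:
            raise ValueError(f"static_argname '{name}' not found")
    return found


def f(x, **options):  # hypothetical function with a variadic keyword parameter
    return x


bound = inspect.signature(f).bind(1.0, mode="fast")
static_values = find_static_values(f, ("mode",), (1.0,), {"mode": "fast"})
```

Here `"mode"` is absent from `bound.arguments` itself (it sits inside `bound.arguments["options"]`), which is why a plain `arg_names.index(name)` lookup cannot find it.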
