Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions docs/jax.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,3 @@ Experimental Modules
jax.experimental.pallas
jax.experimental.serialize_executable
jax.experimental.sparse

Experimental APIs
-----------------

.. autosummary::
:toctree: _autosummary

enable_x64
disable_x64
12 changes: 5 additions & 7 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,14 @@
del _ccache

_deprecations = {
# Added for v0.8.0
# Remove in v0.10.0
"array_ref": (
"jax.array_ref is deprecated; use jax.new_ref instead.",
new_ref
"jax.array_ref was removed in JAX v0.9.0; use jax.new_ref instead.",
None,
),
"ArrayRef": (
"jax.ArrayRef is deprecated; use jax.Ref instead.",
Ref
"jax.ArrayRef was removed in JAX v0.9.0; use jax.Ref instead.",
None
),
# Added for v0.8.1
"device_put_replicated": (
Expand All @@ -212,8 +212,6 @@

import typing as _typing
if _typing.TYPE_CHECKING:
array_ref = new_ref
ArrayRef = Ref
device_put_replicated = _deprecated_device_put_replicated
device_put_sharded = _deprecated_device_put_sharded
else:
Expand Down
44 changes: 16 additions & 28 deletions jax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,46 +30,34 @@
from jax._src.earray import (
EArray as EArray
)
from jax._src import core as _src_core
from jax._src.core import (
cur_qdd as cur_qdd,
)
from jax.experimental import x64_context as _x64_context

_deprecations = {
# Added for v0.8.0
# Remove in v0.10.0
"disable_x64": (
("jax.experimental.disable_x64 is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.enable_x64(False) instead."),
_x64_context._disable_x64
("jax.experimental.disable_x64 was removed in JAX v0.9.0;"
" use jax.enable_x64(False) instead."),
None,
),
"enable_x64": (
("jax.experimental.enable_x64 is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.enable_x64(True) instead."),
_x64_context._enable_x64
("jax.experimental.enable_x64 was removed in JAX v0.9.0;"
" use jax.enable_x64(True) instead."),
None
),
"mutable_array": (
("jax.experimental.mutable_array is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.new_ref instead."),
_src_core.new_ref
("jax.experimental.mutable_array was removed in JAX v0.9.0;"
" use jax.new_ref instead."),
None,
),
"MutableArray": (
("jax.experimental.MutableArray is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.Ref instead."),
_src_core.Ref
("jax.experimental.MutableArray was removed in JAX v0.9.0;"
" use jax.Ref instead."),
None,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
mutable_array = _src_core.new_ref
MutableArray = _src_core.Ref
enable_x64 = _x64_context._enable_x64
disable_x64 = _x64_context._disable_x64
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _src_core
del _x64_context
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
78 changes: 10 additions & 68 deletions jax/experimental/x64_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,78 +17,20 @@
**Deprecated: use :func:`jax.enable_x64` instead.**
"""

from contextlib import contextmanager
from jax._src import config

@contextmanager
def _enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.

.. warning::

This context manager is deprecated as of JAX v0.8.0, and will be removed in
JAX v0.9.0. Use :func:`jax.enable_x64` instead.

Usage::

>>> import jax
>>> x = np.arange(5, dtype='float64')
>>> with _enable_x64(True):
... print(jnp.asarray(x).dtype)
...
float64

See Also
--------
jax.experimental.disable_x64 : temporarily disable X64 mode.
"""
with config.enable_x64(new_val):
yield

@contextmanager
def _disable_x64():
"""Experimental context manager to temporarily disable X64 mode.

.. warning::

This context manager is deprecated as of JAX v0.8.0, and will be removed in
JAX v0.9.0. Use :func:`jax.enable_x64` instead.

Usage::

>>> x = np.arange(5, dtype='float64')
>>> with _disable_x64():
... print(jnp.asarray(x).dtype)
...
float32

See Also
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
with config.enable_x64(False):
yield

_deprecations = {
# Added for v0.8.0
# Remove in v0.10.0
"disable_x64": (
("jax.experimental.x64_context.disable_x64 is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.enable_x64(False) instead."),
_disable_x64
("jax.experimental.x64_context.disable_x64 was removed in JAX v0.9.0;"
" use jax.enable_x64(False) instead."),
None
),
"enable_x64": (
("jax.experimental.x64_context.enable_x64 is deprecated in JAX v0.8.0 and will be removed"
" in JAX v0.9.0; use jax.enable_x64(True) instead."),
_enable_x64
("jax.experimental.x64_context.enable_x64 was removed in JAX v0.9.0;"
" use jax.enable_x64(True) instead."),
None
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
enable_x64 = _enable_x64
disable_x64 = _disable_x64
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
130 changes: 50 additions & 80 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

from __future__ import annotations

from jax._src import ad_util as _src_ad_util
from jax._src.interpreters import ad as _src_ad

from jax._src.interpreters.ad import (
JVPTrace as JVPTrace,
JVPTracer as JVPTracer,
Expand All @@ -46,128 +43,101 @@


_deprecations = {
# Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
# Remove in v0.10.0
"zeros_like_p": (
"jax.interpreters.ad.zeros_like_p is deprecated in JAX v0.7.1. It has been unused since v0.4.24.",
_src_ad_util.zeros_like_p,
"jax.interpreters.ad.zeros_like_p was removed in JAX v0.9.0.",
None,
),
"bilinear_transpose": (
"jax.interpreters.ad.bilinear_transpose is deprecated.",
_src_ad.bilinear_transpose,
"jax.interpreters.ad.bilinear_transpose was removed in JAX v0.9.0.",
None,
),
"call_param_updaters": (
"jax.interpreters.ad.call_param_updaters is deprecated.",
_src_ad.call_param_updaters,
"jax.interpreters.ad.call_param_updaters was removed in JAX v0.9.0.",
None,
),
"call_transpose": (
"jax.interpreters.ad.call_transpose is deprecated.",
_src_ad.call_transpose,
"jax.interpreters.ad.call_transpose was removed in JAX v0.9.0.",
None,
),
"call_transpose_param_updaters": (
"jax.interpreters.ad.call_transpose_param_updaters is deprecated.",
_src_ad.call_transpose_param_updaters,
"jax.interpreters.ad.call_transpose_param_updaters was removed in JAX v0.9.0.",
None,
),
"custom_lin_p": (
"jax.interpreters.ad.custom_lin_p is deprecated.",
_src_ad.custom_lin_p,
"jax.interpreters.ad.custom_lin_p was removed in JAX v0.9.0.",
None,
),
"defjvp_zero": (
"jax.interpreters.ad.defjvp_zero is deprecated.",
_src_ad.defjvp_zero,
"jax.interpreters.ad.defjvp_zero was removed in JAX v0.9.0.",
None,
),
"f_jvp_traceable": (
"jax.interpreters.ad.f_jvp_traceable is deprecated.",
_src_ad.f_jvp_traceable,
"jax.interpreters.ad.f_jvp_traceable was removed in JAX v0.9.0.",
None,
),
"jvp_jaxpr": (
"jax.interpreters.ad.jvp_jaxpr is deprecated.",
_src_ad.jvp_jaxpr,
"jax.interpreters.ad.jvp_jaxpr was removed in JAX v0.9.0.",
None,
),
"jvp_subtrace": (
"jax.interpreters.ad.jvp_subtrace is deprecated.",
_src_ad.jvp_subtrace,
"jax.interpreters.ad.jvp_subtrace was removed in JAX v0.9.0.",
None,
),
"jvp_subtrace_aux": (
"jax.interpreters.ad.jvp_subtrace_aux is deprecated.",
_src_ad.jvp_subtrace_aux,
"jax.interpreters.ad.jvp_subtrace_aux was removed in JAX v0.9.0.",
None,
),
"jvpfun": (
"jax.interpreters.ad.jvpfun is deprecated.",
_src_ad.jvpfun,
"jax.interpreters.ad.jvpfun was removed in JAX v0.9.0.",
None,
),
"linear_jvp": (
"jax.interpreters.ad.linear_jvp is deprecated.",
_src_ad.linear_jvp,
"jax.interpreters.ad.linear_jvp was removed in JAX v0.9.0.",
None,
),
"linear_transpose": (
"jax.interpreters.ad.linear_transpose is deprecated.",
_src_ad.linear_transpose,
"jax.interpreters.ad.linear_transpose was removed in JAX v0.9.0.",
None,
),
"linear_transpose2": (
"jax.interpreters.ad.linear_transpose2 is deprecated.",
_src_ad.linear_transpose2,
"jax.interpreters.ad.linear_transpose2 was removed in JAX v0.9.0.",
None,
),
"map_transpose": (
"jax.interpreters.ad.map_transpose is deprecated.",
_src_ad.map_transpose,
"jax.interpreters.ad.map_transpose was removed in JAX v0.9.0.",
None,
),
"nonzero_outputs": (
"jax.interpreters.ad.nonzero_outputs is deprecated.",
_src_ad.nonzero_outputs,
"jax.interpreters.ad.nonzero_outputs was removed in JAX v0.9.0.",
None,
),
"nonzero_tangent_outputs": (
"jax.interpreters.ad.nonzero_tangent_outputs is deprecated.",
_src_ad.nonzero_tangent_outputs,
"jax.interpreters.ad.nonzero_tangent_outputs was removed in JAX v0.9.0.",
None,
),
"rearrange_binders": (
"jax.interpreters.ad.rearrange_binders is deprecated.",
_src_ad.rearrange_binders,
"jax.interpreters.ad.rearrange_binders was removed in JAX v0.9.0.",
None,
),
"standard_jvp": (
"jax.interpreters.ad.standard_jvp is deprecated.",
_src_ad.standard_jvp,
"jax.interpreters.ad.standard_jvp was removed in JAX v0.9.0.",
None,
),
"standard_jvp2": (
"jax.interpreters.ad.standard_jvp2 is deprecated.",
_src_ad.standard_jvp2,
"jax.interpreters.ad.standard_jvp2 was removed in JAX v0.9.0.",
None,
),
"traceable": (
"jax.interpreters.ad.traceable is deprecated.",
_src_ad.traceable,
"jax.interpreters.ad.traceable was removed in JAX v0.9.0.",
None,
),
"zero_jvp": (
"jax.interpreters.ad.zero_jvp is deprecated.",
_src_ad.zero_jvp,
"jax.interpreters.ad.zero_jvp was removed in JAX v0.9.0.",
None,
),
}

import typing
if typing.TYPE_CHECKING:
bilinear_transpose = _src_ad.bilinear_transpose
call_param_updaters = _src_ad.call_param_updaters
call_transpose = _src_ad.call_transpose
call_transpose_param_updaters = _src_ad.call_transpose_param_updaters
custom_lin_p = _src_ad.custom_lin_p
defjvp_zero = _src_ad.defjvp_zero
f_jvp_traceable = _src_ad.f_jvp_traceable
jvp_jaxpr = _src_ad.jvp_jaxpr
jvp_subtrace = _src_ad.jvp_subtrace
jvp_subtrace_aux = _src_ad.jvp_subtrace_aux
jvpfun = _src_ad.jvpfun
linear_jvp = _src_ad.linear_jvp
linear_transpose = _src_ad.linear_transpose
linear_transpose2 = _src_ad.linear_transpose2
map_transpose = _src_ad.map_transpose
nonzero_outputs = _src_ad.nonzero_outputs
nonzero_tangent_outputs = _src_ad.nonzero_tangent_outputs
rearrange_binders = _src_ad.rearrange_binders
standard_jvp = _src_ad.standard_jvp
standard_jvp2 = _src_ad.standard_jvp2
traceable = _src_ad.traceable
zero_jvp = _src_ad.zero_jvp
zeros_like_p = _src_ad_util.zeros_like_p
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
Loading
Loading