Skip to content

Commit ad35b30

Browse files
Jake VanderPlasGoogle-ML-Automation
authored andcommitted
[dep] finalize a number of deprecations for JAX v0.9.0
PiperOrigin-RevId: 846831674
1 parent 69518c5 commit ad35b30

File tree

8 files changed

+180
-391
lines changed

8 files changed

+180
-391
lines changed

docs/jax.experimental.rst

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,3 @@ Experimental Modules
2525
jax.experimental.pallas
2626
jax.experimental.serialize_executable
2727
jax.experimental.sparse
28-
29-
Experimental APIs
30-
-----------------
31-
32-
.. autosummary::
33-
:toctree: _autosummary
34-
35-
enable_x64
36-
disable_x64

jax/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@
189189
del _ccache
190190

191191
_deprecations = {
192-
# Added for v0.8.0
192+
# Remove in v0.10.0
193193
"array_ref": (
194-
"jax.array_ref is deprecated; use jax.new_ref instead.",
195-
new_ref
194+
"jax.array_ref was removed in JAX v0.9.0; use jax.new_ref instead.",
195+
None,
196196
),
197197
"ArrayRef": (
198-
"jax.ArrayRef is deprecated; use jax.Ref instead.",
199-
Ref
198+
"jax.ArrayRef was removed in JAX v0.9.0; use jax.Ref instead.",
199+
None
200200
),
201201
# Added for v0.8.1
202202
"device_put_replicated": (
@@ -212,8 +212,6 @@
212212

213213
import typing as _typing
214214
if _typing.TYPE_CHECKING:
215-
array_ref = new_ref
216-
ArrayRef = Ref
217215
device_put_replicated = _deprecated_device_put_replicated
218216
device_put_sharded = _deprecated_device_put_sharded
219217
else:

jax/experimental/__init__.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,46 +30,34 @@
3030
from jax._src.earray import (
3131
EArray as EArray
3232
)
33-
from jax._src import core as _src_core
3433
from jax._src.core import (
3534
cur_qdd as cur_qdd,
3635
)
37-
from jax.experimental import x64_context as _x64_context
3836

3937
_deprecations = {
40-
# Added for v0.8.0
38+
# Remove in v0.10.0
4139
"disable_x64": (
42-
("jax.experimental.disable_x64 is deprecated in JAX v0.8.0 and will be removed"
43-
" in JAX v0.9.0; use jax.enable_x64(False) instead."),
44-
_x64_context._disable_x64
40+
("jax.experimental.disable_x64 was removed in JAX v0.9.0;"
41+
" use jax.enable_x64(False) instead."),
42+
None,
4543
),
4644
"enable_x64": (
47-
("jax.experimental.enable_x64 is deprecated in JAX v0.8.0 and will be removed"
48-
" in JAX v0.9.0; use jax.enable_x64(True) instead."),
49-
_x64_context._enable_x64
45+
("jax.experimental.enable_x64 was removed in JAX v0.9.0;"
46+
" use jax.enable_x64(True) instead."),
47+
None
5048
),
5149
"mutable_array": (
52-
("jax.experimental.mutable_array is deprecated in JAX v0.8.0 and will be removed"
53-
" in JAX v0.9.0; use jax.new_ref instead."),
54-
_src_core.new_ref
50+
("jax.experimental.mutable_array was removed in JAX v0.9.0;"
51+
" use jax.new_ref instead."),
52+
None,
5553
),
5654
"MutableArray": (
57-
("jax.experimental.MutableArray is deprecated in JAX v0.8.0 and will be removed"
58-
" in JAX v0.9.0; use jax.Ref instead."),
59-
_src_core.Ref
55+
("jax.experimental.MutableArray was removed in JAX v0.9.0;"
56+
" use jax.Ref instead."),
57+
None,
6058
),
6159
}
6260

63-
import typing as _typing
64-
if _typing.TYPE_CHECKING:
65-
mutable_array = _src_core.new_ref
66-
MutableArray = _src_core.Ref
67-
enable_x64 = _x64_context._enable_x64
68-
disable_x64 = _x64_context._disable_x64
69-
else:
70-
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
71-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
72-
del _deprecation_getattr
73-
del _typing
74-
del _src_core
75-
del _x64_context
61+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
62+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
63+
del _deprecation_getattr

jax/experimental/x64_context.py

Lines changed: 10 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,78 +17,20 @@
1717
**Deprecated: use :func:`jax.enable_x64` instead.**
1818
"""
1919

20-
from contextlib import contextmanager
21-
from jax._src import config
22-
23-
@contextmanager
24-
def _enable_x64(new_val: bool = True):
25-
"""Experimental context manager to temporarily enable X64 mode.
26-
27-
.. warning::
28-
29-
This context manager is deprecated as of JAX v0.8.0, and will be removed in
30-
JAX v0.9.0. Use :func:`jax.enable_x64` instead.
31-
32-
Usage::
33-
34-
>>> import jax
35-
>>> x = np.arange(5, dtype='float64')
36-
>>> with _enable_x64(True):
37-
... print(jnp.asarray(x).dtype)
38-
...
39-
float64
40-
41-
See Also
42-
--------
43-
jax.experimental.disable_x64 : temporarily disable X64 mode.
44-
"""
45-
with config.enable_x64(new_val):
46-
yield
47-
48-
@contextmanager
49-
def _disable_x64():
50-
"""Experimental context manager to temporarily disable X64 mode.
51-
52-
.. warning::
53-
54-
This context manager is deprecated as of JAX v0.8.0, and will be removed in
55-
JAX v0.9.0. Use :func:`jax.enable_x64` instead.
56-
57-
Usage::
58-
59-
>>> x = np.arange(5, dtype='float64')
60-
>>> with _disable_x64():
61-
... print(jnp.asarray(x).dtype)
62-
...
63-
float32
64-
65-
See Also
66-
--------
67-
jax.experimental.enable_x64 : temporarily enable X64 mode.
68-
"""
69-
with config.enable_x64(False):
70-
yield
71-
7220
_deprecations = {
73-
# Added for v0.8.0
21+
# Remove in v0.10.0
7422
"disable_x64": (
75-
("jax.experimental.x64_context.disable_x64 is deprecated in JAX v0.8.0 and will be removed"
76-
" in JAX v0.9.0; use jax.enable_x64(False) instead."),
77-
_disable_x64
23+
("jax.experimental.x64_context.disable_x64 was removed in JAX v0.9.0;"
24+
" use jax.enable_x64(False) instead."),
25+
None
7826
),
7927
"enable_x64": (
80-
("jax.experimental.x64_context.enable_x64 is deprecated in JAX v0.8.0 and will be removed"
81-
" in JAX v0.9.0; use jax.enable_x64(True) instead."),
82-
_enable_x64
28+
("jax.experimental.x64_context.enable_x64 was removed in JAX v0.9.0;"
29+
" use jax.enable_x64(True) instead."),
30+
None
8331
),
8432
}
8533

86-
import typing as _typing
87-
if _typing.TYPE_CHECKING:
88-
enable_x64 = _enable_x64
89-
disable_x64 = _disable_x64
90-
else:
91-
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
92-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
93-
del _deprecation_getattr
94-
del _typing
34+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
35+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
36+
del _deprecation_getattr

jax/interpreters/ad.py

Lines changed: 50 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717

1818
from __future__ import annotations
1919

20-
from jax._src import ad_util as _src_ad_util
21-
from jax._src.interpreters import ad as _src_ad
22-
2320
from jax._src.interpreters.ad import (
2421
JVPTrace as JVPTrace,
2522
JVPTracer as JVPTracer,
@@ -46,128 +43,101 @@
4643

4744

4845
_deprecations = {
49-
# Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
46+
# Remove in v0.10.0
5047
"zeros_like_p": (
51-
"jax.interpreters.ad.zeros_like_p is deprecated in JAX v0.7.1. It has been unused since v0.4.24.",
52-
_src_ad_util.zeros_like_p,
48+
"jax.interpreters.ad.zeros_like_p was removed in JAX v0.9.0.",
49+
None,
5350
),
5451
"bilinear_transpose": (
55-
"jax.interpreters.ad.bilinear_transpose is deprecated.",
56-
_src_ad.bilinear_transpose,
52+
"jax.interpreters.ad.bilinear_transpose was removed in JAX v0.9.0.",
53+
None,
5754
),
5855
"call_param_updaters": (
59-
"jax.interpreters.ad.call_param_updaters is deprecated.",
60-
_src_ad.call_param_updaters,
56+
"jax.interpreters.ad.call_param_updaters was removed in JAX v0.9.0.",
57+
None,
6158
),
6259
"call_transpose": (
63-
"jax.interpreters.ad.call_transpose is deprecated.",
64-
_src_ad.call_transpose,
60+
"jax.interpreters.ad.call_transpose was removed in JAX v0.9.0.",
61+
None,
6562
),
6663
"call_transpose_param_updaters": (
67-
"jax.interpreters.ad.call_transpose_param_updaters is deprecated.",
68-
_src_ad.call_transpose_param_updaters,
64+
"jax.interpreters.ad.call_transpose_param_updaters was removed in JAX v0.9.0.",
65+
None,
6966
),
7067
"custom_lin_p": (
71-
"jax.interpreters.ad.custom_lin_p is deprecated.",
72-
_src_ad.custom_lin_p,
68+
"jax.interpreters.ad.custom_lin_p was removed in JAX v0.9.0.",
69+
None,
7370
),
7471
"defjvp_zero": (
75-
"jax.interpreters.ad.defjvp_zero is deprecated.",
76-
_src_ad.defjvp_zero,
72+
"jax.interpreters.ad.defjvp_zero was removed in JAX v0.9.0.",
73+
None,
7774
),
7875
"f_jvp_traceable": (
79-
"jax.interpreters.ad.f_jvp_traceable is deprecated.",
80-
_src_ad.f_jvp_traceable,
76+
"jax.interpreters.ad.f_jvp_traceable was removed in JAX v0.9.0.",
77+
None,
8178
),
8279
"jvp_jaxpr": (
83-
"jax.interpreters.ad.jvp_jaxpr is deprecated.",
84-
_src_ad.jvp_jaxpr,
80+
"jax.interpreters.ad.jvp_jaxpr was removed in JAX v0.9.0.",
81+
None,
8582
),
8683
"jvp_subtrace": (
87-
"jax.interpreters.ad.jvp_subtrace is deprecated.",
88-
_src_ad.jvp_subtrace,
84+
"jax.interpreters.ad.jvp_subtrace was removed in JAX v0.9.0.",
85+
None,
8986
),
9087
"jvp_subtrace_aux": (
91-
"jax.interpreters.ad.jvp_subtrace_aux is deprecated.",
92-
_src_ad.jvp_subtrace_aux,
88+
"jax.interpreters.ad.jvp_subtrace_aux was removed in JAX v0.9.0.",
89+
None,
9390
),
9491
"jvpfun": (
95-
"jax.interpreters.ad.jvpfun is deprecated.",
96-
_src_ad.jvpfun,
92+
"jax.interpreters.ad.jvpfun was removed in JAX v0.9.0.",
93+
None,
9794
),
9895
"linear_jvp": (
99-
"jax.interpreters.ad.linear_jvp is deprecated.",
100-
_src_ad.linear_jvp,
96+
"jax.interpreters.ad.linear_jvp was removed in JAX v0.9.0.",
97+
None,
10198
),
10299
"linear_transpose": (
103-
"jax.interpreters.ad.linear_transpose is deprecated.",
104-
_src_ad.linear_transpose,
100+
"jax.interpreters.ad.linear_transpose was removed in JAX v0.9.0.",
101+
None,
105102
),
106103
"linear_transpose2": (
107-
"jax.interpreters.ad.linear_transpose2 is deprecated.",
108-
_src_ad.linear_transpose2,
104+
"jax.interpreters.ad.linear_transpose2 was removed in JAX v0.9.0.",
105+
None,
109106
),
110107
"map_transpose": (
111-
"jax.interpreters.ad.map_transpose is deprecated.",
112-
_src_ad.map_transpose,
108+
"jax.interpreters.ad.map_transpose was removed in JAX v0.9.0.",
109+
None,
113110
),
114111
"nonzero_outputs": (
115-
"jax.interpreters.ad.nonzero_outputs is deprecated.",
116-
_src_ad.nonzero_outputs,
112+
"jax.interpreters.ad.nonzero_outputs was removed in JAX v0.9.0.",
113+
None,
117114
),
118115
"nonzero_tangent_outputs": (
119-
"jax.interpreters.ad.nonzero_tangent_outputs is deprecated.",
120-
_src_ad.nonzero_tangent_outputs,
116+
"jax.interpreters.ad.nonzero_tangent_outputs was removed in JAX v0.9.0.",
117+
None,
121118
),
122119
"rearrange_binders": (
123-
"jax.interpreters.ad.rearrange_binders is deprecated.",
124-
_src_ad.rearrange_binders,
120+
"jax.interpreters.ad.rearrange_binders was removed in JAX v0.9.0.",
121+
None,
125122
),
126123
"standard_jvp": (
127-
"jax.interpreters.ad.standard_jvp is deprecated.",
128-
_src_ad.standard_jvp,
124+
"jax.interpreters.ad.standard_jvp was removed in JAX v0.9.0.",
125+
None,
129126
),
130127
"standard_jvp2": (
131-
"jax.interpreters.ad.standard_jvp2 is deprecated.",
132-
_src_ad.standard_jvp2,
128+
"jax.interpreters.ad.standard_jvp2 was removed in JAX v0.9.0.",
129+
None,
133130
),
134131
"traceable": (
135-
"jax.interpreters.ad.traceable is deprecated.",
136-
_src_ad.traceable,
132+
"jax.interpreters.ad.traceable was removed in JAX v0.9.0.",
133+
None,
137134
),
138135
"zero_jvp": (
139-
"jax.interpreters.ad.zero_jvp is deprecated.",
140-
_src_ad.zero_jvp,
136+
"jax.interpreters.ad.zero_jvp was removed in JAX v0.9.0.",
137+
None,
141138
),
142139
}
143140

144-
import typing
145-
if typing.TYPE_CHECKING:
146-
bilinear_transpose = _src_ad.bilinear_transpose
147-
call_param_updaters = _src_ad.call_param_updaters
148-
call_transpose = _src_ad.call_transpose
149-
call_transpose_param_updaters = _src_ad.call_transpose_param_updaters
150-
custom_lin_p = _src_ad.custom_lin_p
151-
defjvp_zero = _src_ad.defjvp_zero
152-
f_jvp_traceable = _src_ad.f_jvp_traceable
153-
jvp_jaxpr = _src_ad.jvp_jaxpr
154-
jvp_subtrace = _src_ad.jvp_subtrace
155-
jvp_subtrace_aux = _src_ad.jvp_subtrace_aux
156-
jvpfun = _src_ad.jvpfun
157-
linear_jvp = _src_ad.linear_jvp
158-
linear_transpose = _src_ad.linear_transpose
159-
linear_transpose2 = _src_ad.linear_transpose2
160-
map_transpose = _src_ad.map_transpose
161-
nonzero_outputs = _src_ad.nonzero_outputs
162-
nonzero_tangent_outputs = _src_ad.nonzero_tangent_outputs
163-
rearrange_binders = _src_ad.rearrange_binders
164-
standard_jvp = _src_ad.standard_jvp
165-
standard_jvp2 = _src_ad.standard_jvp2
166-
traceable = _src_ad.traceable
167-
zero_jvp = _src_ad.zero_jvp
168-
zeros_like_p = _src_ad_util.zeros_like_p
169-
else:
170-
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
171-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
172-
del _deprecation_getattr
173-
del typing
141+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
142+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
143+
del _deprecation_getattr

0 commit comments

Comments
 (0)