|
17 | 17 |
|
18 | 18 | from __future__ import annotations |
19 | 19 |
|
20 | | -from jax._src import ad_util as _src_ad_util |
21 | | -from jax._src.interpreters import ad as _src_ad |
22 | | - |
23 | 20 | from jax._src.interpreters.ad import ( |
24 | 21 | JVPTrace as JVPTrace, |
25 | 22 | JVPTracer as JVPTracer, |
|
46 | 43 |
|
47 | 44 |
|
48 | 45 | _deprecations = { |
49 | | - # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0. |
| 46 | + # Remove in v0.10.0 |
50 | 47 | "zeros_like_p": ( |
51 | | - "jax.interpreters.ad.zeros_like_p is deprecated in JAX v0.7.1. It has been unused since v0.4.24.", |
52 | | - _src_ad_util.zeros_like_p, |
| 48 | + "jax.interpreters.ad.zeros_like_p was removed in JAX v0.9.0.", |
| 49 | + None, |
53 | 50 | ), |
54 | 51 | "bilinear_transpose": ( |
55 | | - "jax.interpreters.ad.bilinear_transpose is deprecated.", |
56 | | - _src_ad.bilinear_transpose, |
| 52 | + "jax.interpreters.ad.bilinear_transpose was removed in JAX v0.9.0.", |
| 53 | + None, |
57 | 54 | ), |
58 | 55 | "call_param_updaters": ( |
59 | | - "jax.interpreters.ad.call_param_updaters is deprecated.", |
60 | | - _src_ad.call_param_updaters, |
| 56 | + "jax.interpreters.ad.call_param_updaters was removed in JAX v0.9.0.", |
| 57 | + None, |
61 | 58 | ), |
62 | 59 | "call_transpose": ( |
63 | | - "jax.interpreters.ad.call_transpose is deprecated.", |
64 | | - _src_ad.call_transpose, |
| 60 | + "jax.interpreters.ad.call_transpose was removed in JAX v0.9.0.", |
| 61 | + None, |
65 | 62 | ), |
66 | 63 | "call_transpose_param_updaters": ( |
67 | | - "jax.interpreters.ad.call_transpose_param_updaters is deprecated.", |
68 | | - _src_ad.call_transpose_param_updaters, |
| 64 | + "jax.interpreters.ad.call_transpose_param_updaters was removed in JAX v0.9.0.", |
| 65 | + None, |
69 | 66 | ), |
70 | 67 | "custom_lin_p": ( |
71 | | - "jax.interpreters.ad.custom_lin_p is deprecated.", |
72 | | - _src_ad.custom_lin_p, |
| 68 | + "jax.interpreters.ad.custom_lin_p was removed in JAX v0.9.0.", |
| 69 | + None, |
73 | 70 | ), |
74 | 71 | "defjvp_zero": ( |
75 | | - "jax.interpreters.ad.defjvp_zero is deprecated.", |
76 | | - _src_ad.defjvp_zero, |
| 72 | + "jax.interpreters.ad.defjvp_zero was removed in JAX v0.9.0.", |
| 73 | + None, |
77 | 74 | ), |
78 | 75 | "f_jvp_traceable": ( |
79 | | - "jax.interpreters.ad.f_jvp_traceable is deprecated.", |
80 | | - _src_ad.f_jvp_traceable, |
| 76 | + "jax.interpreters.ad.f_jvp_traceable was removed in JAX v0.9.0.", |
| 77 | + None, |
81 | 78 | ), |
82 | 79 | "jvp_jaxpr": ( |
83 | | - "jax.interpreters.ad.jvp_jaxpr is deprecated.", |
84 | | - _src_ad.jvp_jaxpr, |
| 80 | + "jax.interpreters.ad.jvp_jaxpr was removed in JAX v0.9.0.", |
| 81 | + None, |
85 | 82 | ), |
86 | 83 | "jvp_subtrace": ( |
87 | | - "jax.interpreters.ad.jvp_subtrace is deprecated.", |
88 | | - _src_ad.jvp_subtrace, |
| 84 | + "jax.interpreters.ad.jvp_subtrace was removed in JAX v0.9.0.", |
| 85 | + None, |
89 | 86 | ), |
90 | 87 | "jvp_subtrace_aux": ( |
91 | | - "jax.interpreters.ad.jvp_subtrace_aux is deprecated.", |
92 | | - _src_ad.jvp_subtrace_aux, |
| 88 | + "jax.interpreters.ad.jvp_subtrace_aux was removed in JAX v0.9.0.", |
| 89 | + None, |
93 | 90 | ), |
94 | 91 | "jvpfun": ( |
95 | | - "jax.interpreters.ad.jvpfun is deprecated.", |
96 | | - _src_ad.jvpfun, |
| 92 | + "jax.interpreters.ad.jvpfun was removed in JAX v0.9.0.", |
| 93 | + None, |
97 | 94 | ), |
98 | 95 | "linear_jvp": ( |
99 | | - "jax.interpreters.ad.linear_jvp is deprecated.", |
100 | | - _src_ad.linear_jvp, |
| 96 | + "jax.interpreters.ad.linear_jvp was removed in JAX v0.9.0.", |
| 97 | + None, |
101 | 98 | ), |
102 | 99 | "linear_transpose": ( |
103 | | - "jax.interpreters.ad.linear_transpose is deprecated.", |
104 | | - _src_ad.linear_transpose, |
| 100 | + "jax.interpreters.ad.linear_transpose was removed in JAX v0.9.0.", |
| 101 | + None, |
105 | 102 | ), |
106 | 103 | "linear_transpose2": ( |
107 | | - "jax.interpreters.ad.linear_transpose2 is deprecated.", |
108 | | - _src_ad.linear_transpose2, |
| 104 | + "jax.interpreters.ad.linear_transpose2 was removed in JAX v0.9.0.", |
| 105 | + None, |
109 | 106 | ), |
110 | 107 | "map_transpose": ( |
111 | | - "jax.interpreters.ad.map_transpose is deprecated.", |
112 | | - _src_ad.map_transpose, |
| 108 | + "jax.interpreters.ad.map_transpose was removed in JAX v0.9.0.", |
| 109 | + None, |
113 | 110 | ), |
114 | 111 | "nonzero_outputs": ( |
115 | | - "jax.interpreters.ad.nonzero_outputs is deprecated.", |
116 | | - _src_ad.nonzero_outputs, |
| 112 | + "jax.interpreters.ad.nonzero_outputs was removed in JAX v0.9.0.", |
| 113 | + None, |
117 | 114 | ), |
118 | 115 | "nonzero_tangent_outputs": ( |
119 | | - "jax.interpreters.ad.nonzero_tangent_outputs is deprecated.", |
120 | | - _src_ad.nonzero_tangent_outputs, |
| 116 | + "jax.interpreters.ad.nonzero_tangent_outputs was removed in JAX v0.9.0.", |
| 117 | + None, |
121 | 118 | ), |
122 | 119 | "rearrange_binders": ( |
123 | | - "jax.interpreters.ad.rearrange_binders is deprecated.", |
124 | | - _src_ad.rearrange_binders, |
| 120 | + "jax.interpreters.ad.rearrange_binders was removed in JAX v0.9.0.", |
| 121 | + None, |
125 | 122 | ), |
126 | 123 | "standard_jvp": ( |
127 | | - "jax.interpreters.ad.standard_jvp is deprecated.", |
128 | | - _src_ad.standard_jvp, |
| 124 | + "jax.interpreters.ad.standard_jvp was removed in JAX v0.9.0.", |
| 125 | + None, |
129 | 126 | ), |
130 | 127 | "standard_jvp2": ( |
131 | | - "jax.interpreters.ad.standard_jvp2 is deprecated.", |
132 | | - _src_ad.standard_jvp2, |
| 128 | + "jax.interpreters.ad.standard_jvp2 was removed in JAX v0.9.0.", |
| 129 | + None, |
133 | 130 | ), |
134 | 131 | "traceable": ( |
135 | | - "jax.interpreters.ad.traceable is deprecated.", |
136 | | - _src_ad.traceable, |
| 132 | + "jax.interpreters.ad.traceable was removed in JAX v0.9.0.", |
| 133 | + None, |
137 | 134 | ), |
138 | 135 | "zero_jvp": ( |
139 | | - "jax.interpreters.ad.zero_jvp is deprecated.", |
140 | | - _src_ad.zero_jvp, |
| 136 | + "jax.interpreters.ad.zero_jvp was removed in JAX v0.9.0.", |
| 137 | + None, |
141 | 138 | ), |
142 | 139 | } |
143 | 140 |
|
144 | | -import typing |
145 | | -if typing.TYPE_CHECKING: |
146 | | - bilinear_transpose = _src_ad.bilinear_transpose |
147 | | - call_param_updaters = _src_ad.call_param_updaters |
148 | | - call_transpose = _src_ad.call_transpose |
149 | | - call_transpose_param_updaters = _src_ad.call_transpose_param_updaters |
150 | | - custom_lin_p = _src_ad.custom_lin_p |
151 | | - defjvp_zero = _src_ad.defjvp_zero |
152 | | - f_jvp_traceable = _src_ad.f_jvp_traceable |
153 | | - jvp_jaxpr = _src_ad.jvp_jaxpr |
154 | | - jvp_subtrace = _src_ad.jvp_subtrace |
155 | | - jvp_subtrace_aux = _src_ad.jvp_subtrace_aux |
156 | | - jvpfun = _src_ad.jvpfun |
157 | | - linear_jvp = _src_ad.linear_jvp |
158 | | - linear_transpose = _src_ad.linear_transpose |
159 | | - linear_transpose2 = _src_ad.linear_transpose2 |
160 | | - map_transpose = _src_ad.map_transpose |
161 | | - nonzero_outputs = _src_ad.nonzero_outputs |
162 | | - nonzero_tangent_outputs = _src_ad.nonzero_tangent_outputs |
163 | | - rearrange_binders = _src_ad.rearrange_binders |
164 | | - standard_jvp = _src_ad.standard_jvp |
165 | | - standard_jvp2 = _src_ad.standard_jvp2 |
166 | | - traceable = _src_ad.traceable |
167 | | - zero_jvp = _src_ad.zero_jvp |
168 | | - zeros_like_p = _src_ad_util.zeros_like_p |
169 | | -else: |
170 | | - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr |
171 | | - __getattr__ = _deprecation_getattr(__name__, _deprecations) |
172 | | - del _deprecation_getattr |
173 | | -del typing |
| 141 | +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr |
| 142 | +__getattr__ = _deprecation_getattr(__name__, _deprecations) |
| 143 | +del _deprecation_getattr |
0 commit comments