Skip to content

Commit f281864

Browse files
jakeharmon8TF2JAXDev
authored andcommitted
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886640
1 parent 171d7e3 commit f281864

File tree

6 files changed

+13
-13
lines changed

6 files changed

+13
-13
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ performance.
325325

326326
### Platform Specificity
327327

328-
Natively serialized JAX programs are platform specific ([link](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#natively-serialized-jax-modules-are-platform-specific)). Executing a natively
328+
Natively serialized JAX programs are platform specific ([link](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#natively-serialized-jax-modules-are-platform-specific)). Executing a natively
329329
serialized program on platforms other than the one for which it was lowered,
330330
would raise a ValueError, e.g.:
331331

@@ -399,8 +399,8 @@ ops.
399399

400400
[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
401401
[DeepMind JAX Ecosystem citation]: https://github.com/google-deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
402-
[JAX]: https://github.com/google/jax "JAX on GitHub"
402+
[JAX]: https://github.com/jax-ml/jax "JAX on GitHub"
403403
[TensorFlow]: https://github.com/tensorflow/tensorflow "TensorFlow on GitHub"
404-
[jax2tf documentation]: https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax "jax2tf documentation"
405-
[jax2tf_cumulative_reduction]: https://github.com/google/jax/blob/main/jax/experimental/jax2tf/jax2tf.py#L2172
404+
[jax2tf documentation]: https://github.com/jax-ml/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax "jax2tf documentation"
405+
[jax2tf_cumulative_reduction]: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax2tf.py#L2172
406406
[StableHLO]: https://github.com/openxla/stablehlo

test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax
7070
# Native lowering is in active development so we test against nightly and github head.
7171
pip uninstall --yes tensorflow
7272
pip install tf-nightly
73-
pip install git+https://github.com/google/jax.git
73+
pip install git+https://github.com/jax-ml/jax.git
7474
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
7575
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
7676
cd ..

tf2jax/_src/numpy_compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def is_poly_dim(x) -> bool:
7676
return export.is_symbolic_dim(x)
7777

7878
# This should reflect is_poly_dim() at
79-
# https://github.com/google/jax/blob/main/jax/experimental/jax2tf/shape_poly.py#L676
79+
# https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/shape_poly.py#L676
8080
# Array types.
8181
if isinstance(x, (np.ndarray, jax.core.Tracer, xc.ArrayImpl)): # pylint: disable=isinstance-second-argument-not-valid-type
8282
return False

tf2jax/_src/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def _func(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
844844
evals, evecs = jnp.linalg.eigh(x, symmetrize_input=False)
845845
else:
846846
# symmetrize_input does not exist for eigvalsh.
847-
# See https://github.com/google/jax/issues/9473
847+
# See https://github.com/jax-ml/jax/issues/9473
848848
evals, evecs = jnp.linalg.eigvalsh(symmetrize(x)), None
849849

850850
# Sorting by eigenvalues to tf.raw_ops.Eig better.
@@ -2655,7 +2655,7 @@ def _xla_variadic_sort(proto):
26552655
return _XlaVariadicSort(dict(comparator=comparator), is_stable=is_stable)
26562656

26572657

2658-
# Taken from https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L1056
2658+
# Taken from https://github.com/jax-ml/jax/blob/main/jax/_src/lax/lax.py#L1056
26592659
def _get_max_identity(dtype):
26602660
if jax.dtypes.issubdtype(dtype, np.inexact):
26612661
return np.array(-np.inf, dtype)
@@ -2665,7 +2665,7 @@ def _get_max_identity(dtype):
26652665
return np.array(False, np.bool_)
26662666

26672667

2668-
# Taken from https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L1064
2668+
# Taken from https://github.com/jax-ml/jax/blob/main/jax/_src/lax/lax.py#L1064
26692669
def _get_min_identity(dtype):
26702670
if jax.dtypes.issubdtype(dtype, np.inexact):
26712671
return np.array(np.inf, dtype)

tf2jax/experimental/mhlo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def mhlo_apply_impl(*args, module: MhloModule):
6161
mhlo_apply_p.def_impl(mhlo_apply_impl)
6262

6363

64-
# See https://github.com/google/jax/blob/main/jax/_src/interpreters/mlir.py#L115
64+
# See https://github.com/jax-ml/jax/blob/main/jax/_src/interpreters/mlir.py#L115
6565
# for reference
6666
def ir_type_to_dtype(ir_type: ir.Type) -> jnp.dtype:
6767
"""Converts MLIR type to JAX dtype."""
@@ -154,7 +154,7 @@ def mhlo_apply_abstract_eval(
154154

155155

156156
# Taken from
157-
# github.com/google/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859
157+
# github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859
158158
def refine_polymorphic_shapes(
159159
module: ir.Module, validate_static_shapes: bool
160160
) -> ir.Module:

tf2jax/experimental/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
# See canonicalize_platform for reference
34-
# https://github.com/google/jax/blob/main/jax/_src/xla_bridge.py#L344
34+
# https://github.com/jax-ml/jax/blob/main/jax/_src/xla_bridge.py#L344
3535
def _platform_to_alias(platform: str) -> str:
3636
aliases = {
3737
"cuda": "gpu",
@@ -41,7 +41,7 @@ def _platform_to_alias(platform: str) -> str:
4141

4242

4343
# Adapted from
44-
# https://github.com/google/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c
44+
# https://github.com/jax-ml/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c
4545
@functools.lru_cache(None)
4646
def _refine_with_static_input_shapes(
4747
module_text: str, operands: Tuple[jax.core.ShapedArray, ...]

0 commit comments

Comments
 (0)