diff --git a/BUILD.bazel b/BUILD.bazel index f7dee0cb7bb0..38351eb6dca1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -137,14 +137,17 @@ jax_source_package( genrule( name = "wheel_additives", + testonly = True, srcs = [ "//jax/_src:internal_test_harnesses", "//jax/_src:internal_test_util", "//jax/_src:internal_export_back_compat_test_util", "//jax/_src:internal_export_back_compat_test_data", + "//jax/experimental:mosaic_gpu_test_util", "//jax/experimental/jax2tf/tests:tf_test_util", "//jax/experimental/mosaic/gpu/examples:flash_attention.py", "//jax/experimental/mosaic/gpu/examples:matmul.py", + "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", "//jax/_src:test_multiprocess", "//jax/_src/pallas:pallas_test_util", ], @@ -155,6 +158,7 @@ genrule( py_import( name = "jax_py_import", + testonly = True, wheel = ":jax_wheel", zip_deps = [":wheel_additives"], )