Description
Using EasyDel with JAX >= 0.7.0 fails at import time with the following error:
RuntimeError: Failed to import easydel.modules.auto because of the following error (look up to see its traceback):
module 'jax._src.pjit' has no attribute 'pjit_p'
The JAX 0.7.0 changelog explains the cause:
The jax.extend.core.primitives.pjit_p primitive has been renamed to jit_p, and its name attribute has changed from "pjit" to "jit".
This affects the string representations of jaxprs.
The same primitive is no longer exported from the jax.experimental.pjit module.
Replacing EasyDel's reference to pjit_p with jit_p fixes the issue.
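If EasyDel needs to keep working on JAX versions from both before and after the rename, a version-tolerant lookup is one option. The following is a minimal sketch, not EasyDel's actual fix. It assumes the primitive is exposed as `jit_p` (JAX >= 0.7.0) or `pjit_p` (older releases) in either the public `jax.extend.core.primitives` module named in the changelog or the private `jax._src.pjit` module named in the traceback. The helper names `resolve_jit_primitive` and `load_jit_primitive` are hypothetical. `jax._src` is private and can change between releases.

```python
import importlib
from typing import Any


def resolve_jit_primitive(module: Any) -> Any:
    """Return the jit primitive from `module`, tolerating the JAX 0.7.0 rename.

    Tries the new name (jit_p) first, then the pre-0.7.0 name (pjit_p).
    """
    for name in ("jit_p", "pjit_p"):
        prim = getattr(module, name, None)
        if prim is not None:
            return prim
    raise AttributeError(
        f"{getattr(module, '__name__', module)!r} has neither 'jit_p' nor 'pjit_p'"
    )


def load_jit_primitive() -> Any:
    # Assumption: one of these modules exports the primitive. The public module
    # comes from the changelog; the private one is the module in the traceback.
    for mod_name in ("jax.extend.core.primitives", "jax._src.pjit"):
        try:
            module = importlib.import_module(mod_name)
        except ImportError:
            continue  # module missing in this JAX version; try the next one
        try:
            return resolve_jit_primitive(module)
        except AttributeError:
            continue  # module exists but exports neither name
    raise ImportError("could not locate the jit/pjit primitive in this JAX install")


if __name__ == "__main__":
    prim = load_jit_primitive()
    # Per the changelog, the name is "jit" on JAX >= 0.7.0 and "pjit" before it.
    print(prim.name)
```

Checking the new name first means code built on this helper picks up `jit_p` whenever it is available and falls back to `pjit_p` only on older installs.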