Description
Hi, I'm currently working on a JAX binding of VkFFT via the XLA FFI interface.
It's almost done, but I've run into an ordering issue caused by an inappropriate use of the CUDA stream (see jax-ml/jax#28005 for the discussion with a JAX developer).
A viable solution is to pass the CUDA stream at runtime, i.e., after the app has been initialized (see the CUDA binding example https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/cuda_examples.cu).
However, VkFFTLaunchParams currently doesn't hold a stream, and I'm not sure whether changing the stream in the configuration after app initialization would have the desired effect.
Of course, we could always reconstruct the app on every fft/ifft invocation, but I don't think that's efficient.
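Until the launch parameters can carry a stream, one possible workaround is to keep one initialized app per stream that the FFI handler sees and reuse it across calls. Initialization cost is then paid once per (stream, size) pair rather than on every fft/ifft. This is a minimal sketch under stated assumptions, not something VkFFT itself provides: `PlanCache`, `Plan`, and `StreamHandle` are hypothetical names, `StreamHandle` stands in for `cudaStream_t`, and `Plan` stands in for a wrapper around a VkFFT app whose configuration was built with that stream.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <utility>

// Hypothetical stand-in for cudaStream_t so the sketch compiles without CUDA.
using StreamHandle = void*;

// Hypothetical plan: in a real binding this would own a VkFFT app initialized
// with `stream` in its configuration and release it in its destructor.
struct Plan {
  StreamHandle stream;
  std::uint64_t size;
};

// Caches one plan per (stream, FFT length) so initialization happens once per
// pair instead of on every fft/ifft call.
class PlanCache {
 private:
  // Key on the integer value of the handle to get a well-defined ordering.
  using Key = std::pair<std::uintptr_t, std::uint64_t>;
  std::map<Key, std::unique_ptr<Plan>> plans_;
  int builds_ = 0;

 public:
  Plan& get(StreamHandle stream, std::uint64_t size) {
    assert(size > 0 && "FFT length must be positive");
    Key key{reinterpret_cast<std::uintptr_t>(stream), size};
    auto it = plans_.find(key);
    if (it == plans_.end()) {
      ++builds_;  // cache miss: this is where the app would be initialized
      it = plans_.emplace(key, std::make_unique<Plan>(Plan{stream, size})).first;
    }
    return *it->second;
  }

  // Number of plans built so far (i.e. number of cache misses).
  int builds() const { return builds_; }
};
```

In an FFI handler, the stream obtained at call time from the FFI context would be passed to `get`, and the returned plan used for the launch. The trade-off is memory for latency: each distinct stream keeps its own app alive for the lifetime of the cache.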
Could you please help me sort out this problem? Thanks!