[ptx] Support for CUDA JIT compiler flags #713
… flags can be passed to the CUDA JIT compiler, implement the JNI of cuModuleLoad so TornadoVM prebuilt can read .cubin files, currently both compiler flags and the path of the .cubin are hardcoded, should pass them from API
@yrq0208 what does it need in order to accept compiler flags, as the OpenCL backend does?
…ornadoVM into PTX_cuModuleLoadDataEx
…to process the CUDA JIT flags passed from TornadoOption.java to the relevant CUDA function
Pull Request Overview
This PR implements support for passing CUDA JIT compiler flags through the cuModuleLoadDataEx() method, enabling explicit control over PTX compilation optimization levels and other performance-related flags. Previously, the implementation only used cuModuleLoadData() which didn't allow passing compiler flags.
Key Changes:
- Added new JNI method `cuModuleLoadDataEx` to handle CUDA JIT compiler flags
- Updated default PTX compiler flags to include `CU_JIT_OPTIMIZATION_LEVEL 4` for optimal performance
- Modified PTX module loading pipeline to pass `TaskDataContext` metadata through the compilation chain
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java | Updated default PTX compiler flags and added documentation on flag format |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java | Updated installSource call to pass metadata parameter |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java | Updated installCode call to pass task metadata |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java | Modified compilation methods to pass task metadata for compiler flags |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXModule.java | Added new constructor parameter for compiler flags and native method declaration |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java | Updated installCode methods to accept and forward task metadata |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java | Modified to extract compiler flags from task metadata and pass to PTXModule |
| tornado-drivers/ptx-jni/src/main/cpp/source/PTXModule.h | Added JNI header declaration for cuModuleLoadDataEx |
| tornado-drivers/ptx-jni/src/main/cpp/source/PTXModule.cpp | Implemented cuModuleLoadDataEx with CUDA JIT flag parsing and application |
    char ptx[ptx_length + 1];
    #endif
    env->GetByteArrayRegion(source, 0, ptx_length, reinterpret_cast<jbyte *>(ptx));
    ptx[ptx_length] = 0; // Make sure string terminates with a 0
Is this necessary? Is it a CUDA function requirement?
As far as I can see, it is not used again after line 95. Are you trying to reset it?
It looks like this is necessary: if it is removed, cuModuleLoadDataEx() returns error 218, which indicates invalid PTX input.
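The terminator matters because the CUDA driver API documentation describes a PTX image handed to the module-load functions as a NULL-terminated text string, while a Java byte array carries no terminator of its own. A minimal sketch of the copy step follows, using `std::vector` rather than a variable-length array. The helper name `toNullTerminated` is hypothetical, and a plain pointer stands in for the bytes that the JNI code obtains through `GetByteArrayRegion`.

```cpp
#include <cstddef>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical helper, not part of the PR: copies `length` bytes of PTX text
// into a buffer that is one byte longer and ends with '\0'.
// In the JNI code the bytes would come from env->GetByteArrayRegion(...).
std::vector<char> toNullTerminated(const char *bytes, std::size_t length) {
    std::vector<char> buffer(length + 1);  // elements start out as zero
    if (length > 0) {
        std::memcpy(buffer.data(), bytes, length);
    }
    buffer[length] = '\0';                 // explicit terminator, mirroring the PR
    return buffer;
}
```

The resulting `buffer.data()` could then be passed as the image argument, and the vector releases its storage automatically once the module has been loaded.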
Should be ready now.
    auto it = CUDAJITFlagsMap.find(flagName);
    if (it == CUDAJITFlagsMap.end()) {
        std::cerr << "Unsupported CUDA JIT flag: " << flagName << "\n";
        continue;
    }
Why do we have both a map in C++ and one in Java?
See in PTXCodeCache.java:

    if (!SUPPORTED_PTX_JIT_FLAGS.contains(flag)) {
        throw new TornadoBailoutRuntimeException("Unsupported PTX JIT compiler flag: " + flag + ". Supported flags are: " + SUPPORTED_PTX_JIT_FLAGS);
    }

It seems redundant.
Unfortunately, this redundancy is necessary. @mikepapadim wants the flags to be validated in Java rather than in JNI, because checking them in JNI would be too late: the PTX code will already have been compiled by then. You might then ask why the string is not processed in Java at the point where the flags are checked. I could indeed split it there into a jstring and a jint array, since cuModuleLoadDataEx takes two separate parameters for the flags (the options themselves and their values). The problem is the options parameter of cuModuleLoadDataEx, which must be of type CUjit_option*; I don't think parameters of that type can be passed to JNI. Thus, another map is needed in JNI to look up which CUjit_option each jstring corresponds to.
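The lookup described in this reply can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `JitOption` is a stand-in for `CUjit_option` so that the sketch compiles without `cuda.h`, the names `ParsedFlags` and `parseJitFlags` are hypothetical, and only three of the flags are listed. Unlike the excerpt above, which logs an unsupported flag and continues, the sketch throws. Integer values are cast into `void *` slots, which is how the CUDA driver API samples fill the `optionValues` array.

```cpp
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for CUDA's CUjit_option (assumption: the real JNI code uses the
// enum from cuda.h; these enumerators are placeholders, not CUDA constants).
enum class JitOption { OptimizationLevel, CacheMode, MaxRegisters };

// Two parallel arrays, matching the two flag parameters of cuModuleLoadDataEx.
struct ParsedFlags {
    std::vector<JitOption> options;  // would be CUjit_option[]
    std::vector<void *> values;      // would be the void*[] optionValues
};

// Parses a string of the form "NAME value NAME value ...".
ParsedFlags parseJitFlags(const std::string &flags) {
    static const std::unordered_map<std::string, JitOption> lookup = {
        {"CU_JIT_OPTIMIZATION_LEVEL", JitOption::OptimizationLevel},
        {"CU_JIT_CACHE_MODE", JitOption::CacheMode},
        {"CU_JIT_MAX_REGISTERS", JitOption::MaxRegisters},
    };
    ParsedFlags parsed;
    std::istringstream stream(flags);
    std::string name;
    while (stream >> name) {
        unsigned int value = 0;
        if (!(stream >> value)) {
            throw std::invalid_argument("Missing or non-numeric value for " + name);
        }
        auto it = lookup.find(name);
        if (it == lookup.end()) {
            throw std::invalid_argument("Unsupported CUDA JIT flag: " + name);
        }
        parsed.options.push_back(it->second);
        // Integer option values travel inside the pointer-sized slot itself.
        parsed.values.push_back(
            reinterpret_cast<void *>(static_cast<std::uintptr_t>(value)));
    }
    return parsed;
}
```

With the real enum in place of the stand-in, `parsed.options.size()`, `parsed.options.data()`, and `parsed.values.data()` would map onto the `numOptions`, `options`, and `optionValues` arguments of `cuModuleLoadDataEx`.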
stratika left a comment
I tested your changes; they work on my machine. I added only one comment regarding potential redundancy, please have a look.
Also, please sync with the latest develop. We should test the changes again to ensure they still work, since we have a new release and GitHub Actions.
Works with the latest Tornado on my side as well. However, I noticed that CI has no test for PTX. Currently, the tests cover OpenCL, Apple OpenCL, and SPIRV.
Here it is.
Looks like all CI checks have passed.
Description
This patch enables the explicit passing of CUDA JIT compiler flags by using the cuModuleLoadDataEx() method and TornadoOptions.java.
Problem description
The current Java_uk_ac_manchester_tornado_drivers_ptx_PTXModule_cuModuleLoadData implementation does not allow the explicit passing of CUDA JIT compiler flags.
Backend/s tested
Mark the backends affected by this PR.
OS tested
Mark the OS where this PR is tested.
Did you check on FPGAs?
This patch is not applicable to FPGAs
How to test the new patch?
The CUDA JIT flags are passed via the TornadoVM CLI in the form of a string. Please refer to the documentation for a list of currently supported CUDA JIT flags. Feel free to try other flags. By default, the TestCompilerFlagsAPI unit test for PTX uses optimization level 0. When I pass optimization level 4 via the CLI, I observe a speedup of around 3.8x in TASK_KERNEL_TIME on my machine, which suggests that the flags are passed successfully and the optimization level is overridden from 0 to 4:

    make BACKEND=ptx tornado-test -V --jvm="-Ds0.t0.device=0:0 -Dtornado.ptx.compiler.flags=CU_JIT_OPTIMIZATION_LEVEL\ 4\ CU_JIT_CACHE_MODE\ 0" --enableProfiler console uk.ac.manchester.tornado.unittests.compiler.TestCompilerFlagsAPI#testPTX

Here is another example command with all the flags (set your own CU_JIT_TARGET accordingly! Older versions of CUDA might not support GPUs with compute capability 12.0):

    tornado-test -V --jvm="-Ds0.t0.device=0:0 -Dtornado.ptx.compiler.flags=CU_JIT_OPTIMIZATION_LEVEL\ 4\ CU_JIT_CACHE_MODE\ 0\ CU_JIT_MAX_REGISTERS\ 255\ CU_JIT_TARGET\ 120\ CU_JIT_GENERATE_DEBUG_INFO\ 0\ CU_JIT_LOG_VERBOSE\ 0\ CU_JIT_GENERATE_LINE_INFO\ 0" --enableProfiler console uk.ac.manchester.tornado.unittests.compiler.TestCompilerFlagsAPI#testPTX