Skip to content

Commit 67a038f

Browse files
[aot] [vulkan] Add AotKernel and its Vulkan impl (#4387)
* [aot] [vulkan] Add AotKernel and its Vulkan impl * rename to launch * Auto Format * fix win * fix win * im damned * weird... * fix * fix * maybe fix * Auto Format * horrible c++ * the final fix Co-authored-by: Taichi Gardener <[email protected]>
1 parent 9244144 commit 67a038f

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

taichi/aot/module_loader.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "taichi/aot/module_loader.h"
2+
3+
namespace taichi {
4+
namespace lang {
5+
6+
AotKernel *AotModuleLoader::get_kernel(const std::string &name) {
7+
auto itr = loaded_kernels_.find(name);
8+
if (itr != loaded_kernels_.end()) {
9+
return itr->second.get();
10+
}
11+
auto k = make_new_kernel(name);
12+
auto *kptr = k.get();
13+
loaded_kernels_[name] = std::move(k);
14+
return kptr;
15+
}
16+
17+
} // namespace lang
18+
} // namespace taichi

taichi/aot/module_loader.h

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <memory>
34
#include <string>
5+
#include <unordered_map>
46
#include <vector>
57

68
#include "taichi/aot/module_data.h"
@@ -11,17 +13,62 @@
1113
namespace taichi {
1214
namespace lang {
1315

14-
class AotModuleLoader {
16+
class RuntimeContext;
17+
18+
// TODO: Instead of prefixing all these classes with "Aot", just put them into
19+
// the `aot` namespace.
20+
class TI_DLL_EXPORT AotKernel {
21+
public:
22+
// Rule of 5 to make MSVC happy
23+
AotKernel() = default;
24+
virtual ~AotKernel() = default;
25+
AotKernel(const AotKernel &) = delete;
26+
AotKernel &operator=(const AotKernel &) = delete;
27+
AotKernel(AotKernel &&) = default;
28+
AotKernel &operator=(AotKernel &&) = default;
29+
30+
/**
31+
* @brief Launches the kernel to the device
32+
*
33+
* This does not manage the device to host synchronization.
34+
*
35+
* @param ctx Host context
36+
*/
37+
virtual void launch(RuntimeContext *ctx) = 0;
38+
};
39+
40+
class TI_DLL_EXPORT AotModuleLoader {
1541
public:
42+
// Rule of 5 to make MSVC happy
43+
AotModuleLoader() = default;
1644
virtual ~AotModuleLoader() = default;
45+
AotModuleLoader(const AotModuleLoader &) = delete;
46+
AotModuleLoader &operator=(const AotModuleLoader &) = delete;
47+
AotModuleLoader(AotModuleLoader &&) = default;
48+
AotModuleLoader &operator=(AotModuleLoader &&) = default;
1749

1850
// TODO: Add method get_kernel(...) once the kernel field data will be
1951
// generic/common across all backends.
2052

2153
virtual bool get_field(const std::string &name,
2254
aot::CompiledFieldData &field) = 0;
2355

56+
/**
57+
* @brief Get the kernel object
58+
*
59+
* @param name Name of the kernel
60+
* @return AotKernel*
61+
*/
62+
AotKernel *get_kernel(const std::string &name);
63+
2464
virtual size_t get_root_size() const = 0;
65+
66+
protected:
67+
virtual std::unique_ptr<AotKernel> make_new_kernel(
68+
const std::string &name) = 0;
69+
70+
private:
71+
std::unordered_map<std::string, std::unique_ptr<AotKernel>> loaded_kernels_;
2572
};
2673

2774
// Only responsible for reporting device capabilities

taichi/backends/vulkan/aot_module_loader_impl.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,30 @@
33
#include <fstream>
44
#include <type_traits>
55

6+
#include "taichi/backends/vulkan/runtime.h"
7+
68
namespace taichi {
79
namespace lang {
810
namespace vulkan {
11+
namespace {
12+
13+
using KernelHandle = VkRuntime::KernelHandle;
14+
15+
class KernelImpl : public AotKernel {
16+
public:
17+
explicit KernelImpl(VkRuntime *runtime, KernelHandle handle)
18+
: runtime_(runtime), handle_(handle) {
19+
}
20+
21+
void launch(RuntimeContext *ctx) override {
22+
runtime_->launch_kernel(handle_, ctx);
23+
}
24+
25+
private:
26+
VkRuntime *const runtime_;
27+
const KernelHandle handle_;
28+
};
29+
} // namespace
930

1031
AotModuleLoaderImpl::AotModuleLoaderImpl(const std::string &output_dir) {
1132
const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir);
@@ -55,6 +76,17 @@ bool AotModuleLoaderImpl::get_kernel(const std::string &name,
5576
return false;
5677
}
5778

79+
std::unique_ptr<AotKernel> AotModuleLoaderImpl::make_new_kernel(
80+
const std::string &name) {
81+
VkRuntime::RegisterParams kparams;
82+
if (!get_kernel(name, kparams)) {
83+
TI_DEBUG("Failed to load kernel {}", name);
84+
return nullptr;
85+
}
86+
auto handle = runtime_->register_taichi_kernel(kparams);
87+
return std::make_unique<KernelImpl>(runtime_, handle);
88+
}
89+
5890
bool AotModuleLoaderImpl::get_field(const std::string &name,
5991
aot::CompiledFieldData &field) {
6092
TI_ERROR("AOT: get_field for Vulkan not implemented yet");

taichi/backends/vulkan/aot_module_loader_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace taichi {
1313
namespace lang {
1414
namespace vulkan {
1515

16+
class VkRuntime;
17+
1618
class TI_DLL_EXPORT AotModuleLoaderImpl : public AotModuleLoader {
1719
public:
1820
explicit AotModuleLoaderImpl(const std::string &output_dir);
@@ -25,10 +27,12 @@ class TI_DLL_EXPORT AotModuleLoaderImpl : public AotModuleLoader {
2527
size_t get_root_size() const override;
2628

2729
private:
30+
std::unique_ptr<AotKernel> make_new_kernel(const std::string &name) override;
2831
std::vector<uint32_t> read_spv_file(const std::string &output_dir,
2932
const TaskAttributes &k);
3033

3134
TaichiAotData ti_aot_data_;
35+
VkRuntime *runtime_{nullptr};
3236
};
3337

3438
} // namespace vulkan

0 commit comments

Comments
 (0)