Skip to content

Commit a584008

Browse files
[lang] Support ti.FieldsBuilder() (#2501)
* temp * dynamic wip (runtime get root) * dynamic snode * clean up comments * clean up commented code * hide root in GetRootStmt * resolve conversations * remove num_roots * add const * edit test * Auto Format * add default value * add default value to pass tests * Auto Format Co-authored-by: Taichi Gardener <[email protected]>
1 parent dedd976 commit a584008

File tree

13 files changed

+117
-100
lines changed

13 files changed

+117
-100
lines changed

python/taichi/lang/impl.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -442,15 +442,6 @@ def field(dtype, shape=None, offset=None, needs_grad=False):
442442

443443
assert (offset is not None and shape is None
444444
) == False, f'The shape cannot be None when offset is being set'
445-
'''
446-
if get_runtime().materialized:
447-
raise RuntimeError(
448-
"No new variables can be declared after materialization, i.e. kernel invocations "
449-
"or Python-scope field accesses. I.e., data layouts must be specified before "
450-
"any computation. Try appending ti.init() or ti.reset() "
451-
"right after 'import taichi as ti' if you are using Jupyter notebook or Blender."
452-
)
453-
'''
454445

455446
del _taichi_skip_traceback
456447

taichi/codegen/codegen_llvm.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,12 +1251,19 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,
12511251
}
12521252

12531253
void CodeGenLLVM::visit(GetRootStmt *stmt) {
1254-
llvm_val[stmt] = builder->CreateBitCast(
1255-
get_root(),
1256-
llvm::PointerType::get(
1257-
StructCompilerLLVM::get_llvm_node_type(
1258-
module.get(), prog->get_snode_root(SNodeTree::kFirstID)),
1259-
0));
1254+
if (stmt->root() == nullptr)
1255+
llvm_val[stmt] = builder->CreateBitCast(
1256+
get_root(SNodeTree::kFirstID),
1257+
llvm::PointerType::get(
1258+
StructCompilerLLVM::get_llvm_node_type(
1259+
module.get(), prog->get_snode_root(SNodeTree::kFirstID)),
1260+
0));
1261+
else
1262+
llvm_val[stmt] = builder->CreateBitCast(
1263+
get_root(stmt->root()->get_snode_tree_id()),
1264+
llvm::PointerType::get(
1265+
StructCompilerLLVM::get_llvm_node_type(module.get(), stmt->root()),
1266+
0));
12601267
}
12611268

12621269
void CodeGenLLVM::visit(BitExtractStmt *stmt) {
@@ -2011,8 +2018,9 @@ llvm::Type *CodeGenLLVM::get_xlogue_function_type() {
20112018
get_xlogue_argument_types(), false);
20122019
}
20132020

2014-
llvm::Value *CodeGenLLVM::get_root() {
2015-
return create_call("LLVMRuntime_get_root", {get_runtime()});
2021+
llvm::Value *CodeGenLLVM::get_root(int snode_tree_id) {
2022+
return create_call("LLVMRuntime_get_roots",
2023+
{get_runtime(), tlctx->get_constant(snode_tree_id)});
20162024
}
20172025

20182026
llvm::Value *CodeGenLLVM::get_runtime() {

taichi/codegen/codegen_llvm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
101101

102102
llvm::Type *get_xlogue_function_type();
103103

104-
llvm::Value *get_root();
104+
llvm::Value *get_root(int snode_tree_id);
105105

106106
llvm::Value *get_runtime();
107107

taichi/inc/constants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
constexpr int taichi_max_num_indices = 8;
66
constexpr int taichi_max_num_args = 8;
77
constexpr int taichi_max_num_snodes = 1024;
8+
constexpr int taichi_max_num_snode_trees = 32;
89
constexpr int taichi_max_gpu_block_dim = 1024;
910
constexpr std::size_t taichi_global_tmp_buffer_size = 1024 * 1024;
1011
constexpr int taichi_max_num_mem_requests = 1024 * 64;

taichi/ir/snode.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,12 @@ SNode *SNode::get_grad() const {
200200
return grad_info->grad_snode();
201201
}
202202

203+
void SNode::set_snode_tree_id(int id) {
204+
snode_tree_id_ = id;
205+
}
206+
207+
int SNode::get_snode_tree_id() {
208+
return snode_tree_id_;
209+
}
210+
203211
TLANG_NAMESPACE_END

taichi/ir/snode.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,15 @@ class SNode {
293293
void begin_shared_exp_placement();
294294

295295
void end_shared_exp_placement();
296+
297+
// SNodeTree part
298+
299+
void set_snode_tree_id(int id);
300+
301+
int get_snode_tree_id();
302+
303+
private:
304+
int snode_tree_id_{0};
296305
};
297306

298307
} // namespace lang

taichi/ir/statements.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,16 +912,28 @@ class BitExtractStmt : public Stmt {
912912
*/
913913
class GetRootStmt : public Stmt {
914914
public:
915-
GetRootStmt() {
915+
GetRootStmt(SNode *root = nullptr) : root_(root) {
916+
if (this->root_ != nullptr) {
917+
while (this->root_->parent) {
918+
this->root_ = this->root_->parent;
919+
}
920+
}
916921
TI_STMT_REG_FIELDS;
917922
}
918923

919924
bool has_global_side_effect() const override {
920925
return false;
921926
}
922927

923-
TI_STMT_DEF_FIELDS(ret_type);
928+
TI_STMT_DEF_FIELDS(ret_type, root_);
924929
TI_DEFINE_ACCEPT_AND_CLONE
930+
931+
SNode *root() {
932+
return root_;
933+
}
934+
935+
private:
936+
SNode *root_;
925937
};
926938

927939
/**

taichi/program/program.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree,
402402
TI_TRACE("Allocating data structure of size {} bytes", scomp->root_size);
403403
runtime_jit->call<void *, std::size_t, int, int>(
404404
"runtime_initialize_snodes", llvm_runtime, scomp->root_size, root_id,
405-
(int)snodes.size());
405+
(int)snodes.size(), tree->id());
406406
for (int i = 0; i < (int)snodes.size(); i++) {
407407
if (is_gc_able(snodes[i]->type)) {
408408
std::size_t node_size;
@@ -430,6 +430,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree,
430430
int Program::add_snode_tree(std::unique_ptr<SNode> root) {
431431
const int id = snode_trees_.size();
432432
auto tree = std::make_unique<SNodeTree>(id, std::move(root));
433+
tree->root()->set_snode_tree_id(id);
433434
materialize_snode_tree(tree.get());
434435
snode_trees_.push_back(std::move(tree));
435436
return id;
@@ -655,7 +656,9 @@ void Program::visualize_layout(const std::string &fn) {
655656
emit("]");
656657
};
657658

658-
visit(get_snode_root(SNodeTree::kFirstID));
659+
for (auto &a : snode_trees_) {
660+
visit(a->root());
661+
}
659662

660663
auto tail = R"(
661664
\end{tikzpicture}
@@ -891,7 +894,9 @@ void Program::print_memory_profiler_info() {
891894
}
892895
};
893896

894-
visit(get_snode_root(SNodeTree::kFirstID), /*depth=*/0);
897+
for (auto &a : snode_trees_) {
898+
visit(a->root(), /*depth=*/0);
899+
}
895900

896901
auto total_requested_memory = runtime_query<std::size_t>(
897902
"LLVMRuntime_get_total_requested_memory", llvm_runtime);

taichi/runtime/llvm/runtime.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,10 @@ struct LLVMRuntime {
525525
host_printf_type host_printf;
526526
host_vsnprintf_type host_vsnprintf;
527527
Ptr program;
528-
Ptr root;
529-
size_t root_mem_size;
528+
529+
Ptr roots[taichi_max_num_snode_trees];
530+
size_t root_mem_sizes[taichi_max_num_snode_trees];
531+
530532
Ptr thread_pool;
531533
parallel_for_type parallel_for;
532534
ListManager *element_lists[taichi_max_num_snodes];
@@ -573,8 +575,8 @@ struct LLVMRuntime {
573575
// TODO: are these necessary?
574576
STRUCT_FIELD_ARRAY(LLVMRuntime, element_lists);
575577
STRUCT_FIELD_ARRAY(LLVMRuntime, node_allocators);
576-
STRUCT_FIELD(LLVMRuntime, root);
577-
STRUCT_FIELD(LLVMRuntime, root_mem_size);
578+
STRUCT_FIELD_ARRAY(LLVMRuntime, roots);
579+
STRUCT_FIELD_ARRAY(LLVMRuntime, root_mem_sizes);
578580
STRUCT_FIELD(LLVMRuntime, temporaries);
579581
STRUCT_FIELD(LLVMRuntime, assert_failed);
580582
STRUCT_FIELD(LLVMRuntime, host_printf);
@@ -890,14 +892,15 @@ void runtime_initialize(
890892

891893
void runtime_initialize_snodes(LLVMRuntime *runtime,
892894
std::size_t root_size,
893-
int root_id,
894-
int num_snodes) {
895+
const int root_id,
896+
const int num_snodes,
897+
const int snode_tree_id) {
895898
// For Metal runtime, we have to make sure that both the beginning address
896899
// and the size of the root buffer memory are aligned to page size.
897-
runtime->root_mem_size =
900+
runtime->root_mem_sizes[snode_tree_id] =
898901
taichi::iroundup((size_t)root_size, taichi_page_size);
899-
runtime->root =
900-
runtime->allocate_aligned(runtime->root_mem_size, taichi_page_size);
902+
runtime->roots[snode_tree_id] = runtime->allocate_aligned(
903+
runtime->root_mem_sizes[snode_tree_id], taichi_page_size);
901904
// runtime->request_allocate_aligned ready to use
902905
// initialize the root node element list
903906
for (int i = root_id; i < root_id + num_snodes; i++) {
@@ -908,7 +911,7 @@ void runtime_initialize_snodes(LLVMRuntime *runtime,
908911
Element elem;
909912
elem.loop_bounds[0] = 0;
910913
elem.loop_bounds[1] = 1;
911-
elem.element = runtime->root;
914+
elem.element = runtime->roots[snode_tree_id];
912915
for (int i = 0; i < taichi_max_num_indices; i++) {
913916
elem.pcoord.val[i] = 0;
914917
}
@@ -1743,9 +1746,10 @@ i32 wasm_materialize(Context *context) {
17431746
(RandState *)((size_t)context->runtime + sizeof(LLVMRuntime));
17441747
// set random seed to (1, 0, 0, 0)
17451748
context->runtime->rand_states[0].x = 1;
1746-
context->runtime->root =
1749+
// TODO: remove hard coding on root id 0(SNodeTree::kFirstID)
1750+
context->runtime->roots[0] =
17471751
(Ptr)((size_t)context->runtime->rand_states + sizeof(RandState));
1748-
return (i32)(size_t)context->runtime->root;
1752+
return (i32)(size_t)context->runtime->roots[0];
17491753
}
17501754
}
17511755

taichi/transforms/ir_printer.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,12 @@ class IRPrinter : public IRVisitor {
466466
}
467467

468468
void visit(GetRootStmt *stmt) override {
469-
print("{}{} = get root", stmt->type_hint(), stmt->name());
469+
if (stmt->root() == nullptr)
470+
print("{}{} = get root nullptr", stmt->type_hint(), stmt->name());
471+
else
472+
print("{}{} = get root [{}][{}]", stmt->type_hint(), stmt->name(),
473+
stmt->root()->get_node_type_name_hinted(),
474+
stmt->root()->type_name());
470475
}
471476

472477
void visit(SNodeLookupStmt *stmt) override {

0 commit comments

Comments
 (0)