@@ -525,8 +525,10 @@ struct LLVMRuntime {
525525 host_printf_type host_printf;
526526 host_vsnprintf_type host_vsnprintf;
527527 Ptr program;
528- Ptr root;
529- size_t root_mem_size;
528+
529+ Ptr roots[taichi_max_num_snode_trees];
530+ size_t root_mem_sizes[taichi_max_num_snode_trees];
531+
530532 Ptr thread_pool;
531533 parallel_for_type parallel_for;
532534 ListManager *element_lists[taichi_max_num_snodes];
@@ -573,8 +575,8 @@ struct LLVMRuntime {
573575// TODO: are these necessary?
574576STRUCT_FIELD_ARRAY (LLVMRuntime, element_lists);
575577STRUCT_FIELD_ARRAY (LLVMRuntime, node_allocators);
576- STRUCT_FIELD (LLVMRuntime, root );
577- STRUCT_FIELD (LLVMRuntime, root_mem_size );
578+ STRUCT_FIELD_ARRAY (LLVMRuntime, roots );
579+ STRUCT_FIELD_ARRAY (LLVMRuntime, root_mem_sizes );
578580STRUCT_FIELD (LLVMRuntime, temporaries);
579581STRUCT_FIELD (LLVMRuntime, assert_failed);
580582STRUCT_FIELD (LLVMRuntime, host_printf);
@@ -890,14 +892,15 @@ void runtime_initialize(
890892
891893void runtime_initialize_snodes (LLVMRuntime *runtime,
892894 std::size_t root_size,
893- int root_id,
894- int num_snodes) {
895+ const int root_id,
896+ const int num_snodes,
897+ const int snode_tree_id) {
895898 // For Metal runtime, we have to make sure that both the beginning address
896899 // and the size of the root buffer memory are aligned to page size.
897- runtime->root_mem_size =
900+ runtime->root_mem_sizes [snode_tree_id] =
898901 taichi::iroundup ((size_t )root_size, taichi_page_size);
899- runtime->root =
900- runtime->allocate_aligned (runtime-> root_mem_size , taichi_page_size);
902+ runtime->roots [snode_tree_id] = runtime-> allocate_aligned (
903+ runtime->root_mem_sizes [snode_tree_id] , taichi_page_size);
901904 // runtime->request_allocate_aligned ready to use
902905 // initialize the root node element list
903906 for (int i = root_id; i < root_id + num_snodes; i++) {
@@ -908,7 +911,7 @@ void runtime_initialize_snodes(LLVMRuntime *runtime,
908911 Element elem;
909912 elem.loop_bounds [0 ] = 0 ;
910913 elem.loop_bounds [1 ] = 1 ;
911- elem.element = runtime->root ;
914+ elem.element = runtime->roots [snode_tree_id] ;
912915 for (int i = 0 ; i < taichi_max_num_indices; i++) {
913916 elem.pcoord .val [i] = 0 ;
914917 }
@@ -1743,9 +1746,10 @@ i32 wasm_materialize(Context *context) {
17431746 (RandState *)((size_t )context->runtime + sizeof (LLVMRuntime));
17441747 // set random seed to (1, 0, 0, 0)
17451748 context->runtime ->rand_states [0 ].x = 1 ;
1746- context->runtime ->root =
1749+ // TODO: remove hard coding on root id 0(SNodeTree::kFirstID)
1750+ context->runtime ->roots [0 ] =
17471751 (Ptr)((size_t )context->runtime ->rand_states + sizeof (RandState));
1748- return (i32 )(size_t )context->runtime ->root ;
1752+ return (i32 )(size_t )context->runtime ->roots [ 0 ] ;
17491753}
17501754}
17511755
0 commit comments