2626#include < shared_mutex>
2727#include < type_traits>
2828#include " merlin/core_kernels.cuh"
29+ #include " merlin/external_storage.cuh"
2930#include " merlin/flexible_buffer.cuh"
3031#include " merlin/memory_pool.cuh"
3132#include " merlin/types.cuh"
@@ -152,6 +153,8 @@ class HashTable {
152153 using DeviceMemoryPool = MemoryPool<DeviceAllocator<char >>;
153154 using HostMemoryPool = MemoryPool<HostAllocator<char >>;
154155
156+ using external_storage_type = ExternalStorage<K, V>;
157+
155158#if THRUST_VERSION >= 101600
156159 static constexpr auto thrust_par = thrust::cuda::par_nosync;
157160#else
@@ -169,6 +172,8 @@ class HashTable {
169172 * table object.
170173 */
171174 ~HashTable () {
175+ unlink_external_storage ();
176+
172177 if (initialized_) {
173178 CUDA_CHECK (cudaDeviceSynchronize ());
174179
@@ -295,25 +300,34 @@ class HashTable {
295300 load_factor = fast_load_factor ();
296301 }
297302
298- Selector::execute_kernel (
299- load_factor, options_.block_size , stream, n, c_table_index_, d_table_,
300- keys, reinterpret_cast <const value_type*>(values), metas);
303+ Selector::execute_kernel (load_factor, options_.block_size , stream, n,
304+ c_table_index_, d_table_, keys, values, metas);
301305 } else {
302- const size_type dev_ws_size{n * (sizeof (value_type*) + sizeof (int ))};
306+ const size_type dev_ws_base_size{n * (sizeof (value_type*) + sizeof (int ))};
307+ const size_type dev_ws_size{dev_ws_base_size +
308+ (ext_store_ ? n : 0 ) * sizeof (key_type)};
303309 auto dev_ws{dev_mem_pool_->get_workspace <1 >(dev_ws_size, stream)};
304310 auto d_dst{dev_ws.get <value_type**>(0 )};
305311 auto d_src_offset{reinterpret_cast <int *>(d_dst + n)};
312+ auto d_evicted_keys{reinterpret_cast <key_type*>(d_src_offset + n)};
306313
307- CUDA_CHECK (cudaMemsetAsync (d_dst, 0 , dev_ws_size , stream));
314+ CUDA_CHECK (cudaMemsetAsync (d_dst, 0 , dev_ws_base_size , stream));
308315
309316 {
310317 const size_t block_size = options_.block_size ;
311318 const size_t N = n * TILE_SIZE;
312319 const size_t grid_size = SAFE_GET_GRID_SIZE (N, block_size);
313320
314321 upsert_kernel<key_type, value_type, meta_type, TILE_SIZE>
315- <<<grid_size, block_size, 0 , stream>>> (d_table_, keys, d_dst, metas,
316- d_src_offset, N);
322+ <<<grid_size, block_size, 0 , stream>>> (
323+ d_table_, keys, ext_store_ ? d_evicted_keys : nullptr , d_dst,
324+ metas, d_src_offset, N);
325+ }
326+
327+ if (ext_store_) {
328+ ext_store_->insert_or_assign (
329+ *dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm , n,
330+ d_evicted_keys, reinterpret_cast <value_type**>(d_dst), stream);
317331 }
318332
319333 {
@@ -326,16 +340,17 @@ class HashTable {
326340 }
327341
328342 if (options_.io_by_cpu ) {
329- const size_type host_ws_size{dev_ws_size +
343+ const size_type host_ws_size{dev_ws_base_size +
330344 n * sizeof (value_type) * dim ()};
331345 auto host_ws{host_mem_pool_->get_workspace <1 >(host_ws_size, stream)};
332346 auto h_dst{host_ws.get <value_type**>(0 )};
333347 auto h_src_offset{reinterpret_cast <int *>(h_dst + n)};
334348 auto h_values{reinterpret_cast <value_type*>(h_src_offset + n)};
335349
336- CUDA_CHECK (cudaMemcpyAsync (h_dst, d_dst, dev_ws_size ,
350+ CUDA_CHECK (cudaMemcpyAsync (h_dst, d_dst, dev_ws_base_size ,
337351 cudaMemcpyDeviceToHost, stream));
338- CUDA_CHECK (cudaMemcpyAsync (h_values, values, host_ws_size - dev_ws_size,
352+ CUDA_CHECK (cudaMemcpyAsync (h_values, values,
353+ host_ws_size - dev_ws_base_size,
339354 cudaMemcpyDeviceToHost, stream));
340355 CUDA_CHECK (cudaStreamSynchronize (stream));
341356
@@ -547,6 +562,11 @@ class HashTable {
547562 }
548563 }
549564
565+ if (ext_store_) {
566+ ext_store_->find (*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
567+ stream);
568+ }
569+
550570 CudaCheckError ();
551571 }
552572
@@ -1097,6 +1117,21 @@ class HashTable {
10971117 return total_count;
10981118 }
10991119
1120+ void link_external_storage (
1121+ std::shared_ptr<external_storage_type>& ext_store) {
1122+ MERLIN_CHECK (
1123+ ext_store->value_dim == dim (),
1124+ " Provided external storage value dimension is not incompatible!" );
1125+
1126+ std::unique_lock<std::shared_timed_mutex> lock (mutex_);
1127+ ext_store_ = ext_store;
1128+ }
1129+
1130+ void unlink_external_storage () {
1131+ std::unique_lock<std::shared_timed_mutex> lock (mutex_);
1132+ ext_store_.reset ();
1133+ }
1134+
11001135 private:
11011136 inline bool is_fast_mode () const noexcept { return table_->is_pure_hbm ; }
11021137
@@ -1173,6 +1208,8 @@ class HashTable {
11731208 int c_table_index_ = -1 ;
11741209 std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
11751210 std::unique_ptr<HostMemoryPool> host_mem_pool_;
1211+
1212+ std::shared_ptr<external_storage_type> ext_store_;
11761213};
11771214
11781215} // namespace merlin
0 commit comments