
Commit e79efa3

committed
GDS interface design implementation.
1 parent eb68bd0 commit e79efa3

File tree

include/merlin/core_kernels.cuh
include/merlin/external_storage.cuh
include/merlin_hashtable.cuh

3 files changed: +168 −13 lines changed

include/merlin/core_kernels.cuh

Lines changed: 8 additions & 3 deletions
@@ -942,8 +942,9 @@ struct SelectUpsertKernelWithIO {
  */
 template <class K, class V, class M, uint32_t TILE_SIZE = 4>
 __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
-                              const K* __restrict keys, V** __restrict vectors,
-                              const M* __restrict metas,
+                              const K* __restrict keys,
+                              K* __restrict evicted_keys,
+                              V** __restrict vectors, const M* __restrict metas,
                               int* __restrict src_offset, size_t N) {
   Bucket<K, V, M>* buckets = table->buckets;
   int* buckets_size = table->buckets_size;
@@ -959,7 +960,7 @@ __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
 
   for (size_t t = tid; t < N; t += blockDim.x * gridDim.x) {
     int key_pos = -1;
-    size_t key_idx = t / TILE_SIZE;
+    const size_t key_idx{t / TILE_SIZE};
     int local_size = 0;
 
     const K insert_key = keys[key_idx];
@@ -1063,6 +1064,10 @@ __global__ void upsert_kernel(const Table<K, V, M>* __restrict table,
       key_pos = bucket->min_pos;
 
       if (rank == src_lane) {
+        if (evicted_keys) {
+          evicted_keys[key_idx] =
+              bucket->keys[key_pos].load(cuda::std::memory_order_relaxed);
+        }
         bucket->keys[key_pos].store(insert_key,
                                     cuda::std::memory_order_relaxed);
         *(vectors + key_idx) = (bucket->vectors + key_pos * dim);
include/merlin/external_storage.cuh

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+#include "merlin/memory_pool.cuh"
+
+namespace nv {
+namespace merlin {
+
+template <class Key, class Value>
+class ExternalStorage {
+ public:
+  using size_type = size_t;
+  using key_type = Key;
+  using value_type = Value;
+
+  using dev_mem_pool_type = MemoryPool<DeviceAllocator<char>>;
+  using host_mem_pool_type = MemoryPool<HostAllocator<char>>;
+
+  const size_type value_dim;
+
+  ExternalStorage() = delete;
+
+  /**
+   * Constructs the external storage object.
+   *
+   * @param value_dim The dimensionality of the values. In other words, each
+   * value stored is exactly `value_dim * sizeof(value_type)` bytes large.
+   */
+  ExternalStorage(const size_type value_dim) : value_dim{value_dim} {}
+
+  /**
+   * @brief Inserts key/value pairs into the external storage that are about to
+   * be evicted from the Merlin hashtable. If a key/value pair already exists,
+   * overwrites the current value.
+   *
+   * @param dev_mem_pool Memory pool for temporarily allocating device memory.
+   * @param host_mem_pool Memory pool for temporarily allocating host memory.
+   * @param hkvs_is_pure_hbm True if the Merlin hashtable is currently
+   * operating in pure HBM mode, false otherwise. In pure HBM mode, all `values`
+   * pointers are GUARANTEED to point to device memory.
+   * @param n Number of key/value slots provided in the other arguments.
+   * @param d_masked_keys Device pointer to an (n)-sized array of keys.
+   * Key/value slots that should be ignored have the key set to `EMPTY_KEY`.
+   * @param d_values Device pointer to an (n)-sized array of pointers, each
+   * pointing to the memory location where the current values for the
+   * respective key are stored. Each pointer points to a vector of length
+   * `value_dim`. Pointers *can* be `nullptr` for slots where the corresponding
+   * key equals `EMPTY_KEY`. The memory locations can be device or host memory
+   * (see also `hkvs_is_pure_hbm`).
+   * @param stream Stream that MUST be used for queuing asynchronous CUDA
+   * operations. If only the input arguments or resources obtained from
+   * `dev_mem_pool` and `host_mem_pool`, respectively, are used for such
+   * operations, it is not necessary to synchronize the stream prior to
+   * returning from the function.
+   */
+  virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool,
+                                host_mem_pool_type& host_mem_pool,
+                                bool hkvs_is_pure_hbm, size_type n,
+                                const key_type* d_masked_keys,      // (n)
+                                const value_type* const* d_values,  // (n)
+                                cudaStream_t stream) = 0;
+
+  /**
+   * @brief Attempts to find the supplied `d_keys` if the corresponding
+   * `d_founds`-flag is `false` and fills the retrieved values into the
+   * supplied memory locations (i.e., in `d_values`).
+   *
+   * @param dev_mem_pool Memory pool for temporarily allocating device memory.
+   * @param host_mem_pool Memory pool for temporarily allocating host memory.
+   * @param n Number of key/value slots provided in the other arguments.
+   * @param d_keys Device pointer to an (n)-sized array of keys.
+   * @param d_values Device pointer to an (n * value_dim)-sized array to store
+   * the retrieved values. For slots where the corresponding `d_founds`-flag
+   * is not `false`, the value may already have been assigned and, thus, MUST
+   * NOT be altered.
+   * @param d_founds Device pointer to an (n)-sized array which indicates
+   * whether the corresponding `d_values` slot is already filled or not. If
+   * and only if `d_founds` is still `false`, the implementation shall attempt
+   * to retrieve and fill in the value for the corresponding key. If a
+   * key/value pair was retrieved successfully from external storage, the
+   * implementation MUST also set `d_founds` to `true`.
+   * @param stream Stream that MUST be used for queuing asynchronous CUDA
+   * operations. If only the input arguments or resources obtained from
+   * `dev_mem_pool` and `host_mem_pool`, respectively, are used for such
+   * operations, it is not necessary to synchronize the stream prior to
+   * returning from the function.
+   */
+  virtual void find(dev_mem_pool_type& dev_mem_pool,
+                    host_mem_pool_type& host_mem_pool, size_type n,
+                    const key_type* d_keys,  // (n)
+                    value_type* d_values,    // (n * value_dim)
+                    bool* d_founds,          // (n)
+                    cudaStream_t stream) = 0;
+};
+
+}  // namespace merlin
+}  // namespace nv
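To make the contract above concrete, here is a minimal sketch of a backend that derives from ExternalStorage and keeps evicted pairs in a host-side std::unordered_map. It is not part of this commit: the name HostMapStorage, the empty_key constructor parameter (standing in for the EMPTY_KEY sentinel the documentation refers to), and the plain std::vector staging buffers used instead of the supplied memory pools are illustrative assumptions.

#include <cstddef>
#include <unordered_map>
#include <vector>

#include <cuda_runtime.h>

#include "merlin/external_storage.cuh"

// Hypothetical backend: stores one value vector per key in host memory.
template <class Key, class Value>
class HostMapStorage : public nv::merlin::ExternalStorage<Key, Value> {
 public:
  using base_type = nv::merlin::ExternalStorage<Key, Value>;
  using typename base_type::dev_mem_pool_type;
  using typename base_type::host_mem_pool_type;
  using typename base_type::size_type;

  // `empty_key` marks slots to be ignored (the interface calls it EMPTY_KEY).
  HostMapStorage(size_type value_dim, Key empty_key)
      : base_type{value_dim}, empty_key_{empty_key} {}

  void insert_or_assign(dev_mem_pool_type& /*dev_mem_pool*/,
                        host_mem_pool_type& /*host_mem_pool*/,
                        bool /*hkvs_is_pure_hbm*/, size_type n,
                        const Key* d_masked_keys,
                        const Value* const* d_values,
                        cudaStream_t stream) override {
    // Stage the keys and the value pointers on the host.
    std::vector<Key> keys(n);
    std::vector<const Value*> ptrs(n);
    cudaMemcpyAsync(keys.data(), d_masked_keys, n * sizeof(Key),
                    cudaMemcpyDeviceToHost, stream);
    cudaMemcpyAsync(ptrs.data(), d_values, n * sizeof(const Value*),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    for (size_type i = 0; i < n; ++i) {
      if (keys[i] == empty_key_ || ptrs[i] == nullptr) continue;
      std::vector<Value>& slot = store_[keys[i]];
      slot.resize(this->value_dim);
      // cudaMemcpyDefault handles both device (pure-HBM mode) and host
      // source pointers.
      cudaMemcpy(slot.data(), ptrs[i], this->value_dim * sizeof(Value),
                 cudaMemcpyDefault);
    }
  }

  void find(dev_mem_pool_type& /*dev_mem_pool*/,
            host_mem_pool_type& /*host_mem_pool*/, size_type n,
            const Key* d_keys, Value* d_values, bool* d_founds,
            cudaStream_t stream) override {
    std::vector<Key> keys(n);
    std::vector<char> founds(n);  // char used as a byte-sized stand-in for bool.
    cudaMemcpyAsync(keys.data(), d_keys, n * sizeof(Key),
                    cudaMemcpyDeviceToHost, stream);
    cudaMemcpyAsync(founds.data(), d_founds, n * sizeof(bool),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    for (size_type i = 0; i < n; ++i) {
      if (founds[i]) continue;  // Already served by the in-memory table.
      auto it = store_.find(keys[i]);
      if (it == store_.end()) continue;
      cudaMemcpyAsync(d_values + i * this->value_dim, it->second.data(),
                      this->value_dim * sizeof(Value),
                      cudaMemcpyHostToDevice, stream);
      founds[i] = 1;
    }
    // Write the updated found flags back before the staging buffers expire.
    cudaMemcpyAsync(d_founds, founds.data(), n * sizeof(bool),
                    cudaMemcpyHostToDevice, stream);
    cudaStreamSynchronize(stream);
  }

 private:
  Key empty_key_;
  std::unordered_map<Key, std::vector<Value>> store_;
};

A real GDS- or SSD-backed implementation would replace the map and the per-slot cudaMemcpy calls with batched I/O driven through the provided memory pools and stream; the sketch only illustrates the calling convention.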

include/merlin_hashtable.cuh

Lines changed: 47 additions & 10 deletions
@@ -26,6 +26,7 @@
 #include <shared_mutex>
 #include <type_traits>
 #include "merlin/core_kernels.cuh"
+#include "merlin/external_storage.cuh"
 #include "merlin/flexible_buffer.cuh"
 #include "merlin/memory_pool.cuh"
 #include "merlin/types.cuh"
@@ -152,6 +153,8 @@ class HashTable {
   using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
   using HostMemoryPool = MemoryPool<HostAllocator<char>>;
 
+  using external_storage_type = ExternalStorage<K, V>;
+
 #if THRUST_VERSION >= 101600
   static constexpr auto thrust_par = thrust::cuda::par_nosync;
 #else
@@ -169,6 +172,8 @@ class HashTable {
    * table object.
    */
   ~HashTable() {
+    unlink_external_storage();
+
     if (initialized_) {
       CUDA_CHECK(cudaDeviceSynchronize());
 
@@ -295,25 +300,34 @@ class HashTable {
        load_factor = fast_load_factor();
      }
 
-     Selector::execute_kernel(
-         load_factor, options_.block_size, stream, n, c_table_index_, d_table_,
-         keys, reinterpret_cast<const value_type*>(values), metas);
+     Selector::execute_kernel(load_factor, options_.block_size, stream, n,
+                              c_table_index_, d_table_, keys, values, metas);
    } else {
-     const size_type dev_ws_size{n * (sizeof(value_type*) + sizeof(int))};
+     const size_type dev_ws_base_size{n * (sizeof(value_type*) + sizeof(int))};
+     const size_type dev_ws_size{dev_ws_base_size +
+                                 (ext_store_ ? n : 0) * sizeof(key_type)};
      auto dev_ws{dev_mem_pool_->get_workspace<1>(dev_ws_size, stream)};
      auto d_dst{dev_ws.get<value_type**>(0)};
      auto d_src_offset{reinterpret_cast<int*>(d_dst + n)};
+     auto d_evicted_keys{reinterpret_cast<key_type*>(d_src_offset + n)};
 
-     CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));
+     CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_base_size, stream));
 
      {
        const size_t block_size = options_.block_size;
        const size_t N = n * TILE_SIZE;
        const size_t grid_size = SAFE_GET_GRID_SIZE(N, block_size);
 
        upsert_kernel<key_type, value_type, meta_type, TILE_SIZE>
-           <<<grid_size, block_size, 0, stream>>>(d_table_, keys, d_dst, metas,
-                                                  d_src_offset, N);
+           <<<grid_size, block_size, 0, stream>>>(
+               d_table_, keys, ext_store_ ? d_evicted_keys : nullptr, d_dst,
+               metas, d_src_offset, N);
+     }
+
+     if (ext_store_) {
+       ext_store_->insert_or_assign(
+           *dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n,
+           d_evicted_keys, reinterpret_cast<value_type**>(d_dst), stream);
      }
 
      {
@@ -326,16 +340,17 @@ class HashTable {
      }
 
      if (options_.io_by_cpu) {
-       const size_type host_ws_size{dev_ws_size +
+       const size_type host_ws_size{dev_ws_base_size +
                                     n * sizeof(value_type) * dim()};
        auto host_ws{host_mem_pool_->get_workspace<1>(host_ws_size, stream)};
        auto h_dst{host_ws.get<value_type**>(0)};
        auto h_src_offset{reinterpret_cast<int*>(h_dst + n)};
        auto h_values{reinterpret_cast<value_type*>(h_src_offset + n)};
 
-       CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_size,
+       CUDA_CHECK(cudaMemcpyAsync(h_dst, d_dst, dev_ws_base_size,
                                   cudaMemcpyDeviceToHost, stream));
-       CUDA_CHECK(cudaMemcpyAsync(h_values, values, host_ws_size - dev_ws_size,
+       CUDA_CHECK(cudaMemcpyAsync(h_values, values,
+                                  host_ws_size - dev_ws_base_size,
                                   cudaMemcpyDeviceToHost, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));
 
@@ -547,6 +562,11 @@ class HashTable {
      }
    }
 
+   if (ext_store_) {
+     ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
+                      stream);
+   }
+
    CudaCheckError();
  }
 
@@ -1097,6 +1117,21 @@ class HashTable {
    return total_count;
  }
 
+  void link_external_storage(
+      std::shared_ptr<external_storage_type>& ext_store) {
+    MERLIN_CHECK(
+        ext_store->value_dim == dim(),
+        "Provided external storage value dimension is not compatible!");
+
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+    ext_store_ = ext_store;
+  }
+
+  void unlink_external_storage() {
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+    ext_store_.reset();
+  }
+
 private:
  inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }
 
@@ -1173,6 +1208,8 @@ class HashTable {
  int c_table_index_ = -1;
  std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
  std::unique_ptr<HostMemoryPool> host_mem_pool_;
+
+  std::shared_ptr<external_storage_type> ext_store_;
 };
 
 }  // namespace merlin
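For orientation, a usage sketch of the new linking hooks follows. It is not taken from the commit: attach_external_storage, the templated table type, and the std::numeric_limits<K>::max() sentinel are assumptions; only link_external_storage, unlink_external_storage, and the behavior shown in the diff above come from this change.

#include <cstddef>
#include <limits>
#include <memory>

// TableT stands in for the concrete nv::merlin::HashTable instantiation,
// whose full template parameter list is not shown in this diff.
template <class K, class V, class TableT>
void attach_external_storage(TableT& table, size_t value_dim) {
  // HostMapStorage is the hypothetical backend sketched earlier; the empty-key
  // sentinel is assumed to be the maximum key value (substitute the library's
  // actual EMPTY_KEY constant if it differs).
  std::shared_ptr<nv::merlin::ExternalStorage<K, V>> storage =
      std::make_shared<HostMapStorage<K, V>>(value_dim,
                                             std::numeric_limits<K>::max());

  // Checks storage->value_dim == table.dim(), then keeps a shared_ptr.
  table.link_external_storage(storage);

  // With a backend linked (per the diff above):
  //  * the upsert path grows its device workspace by n keys, passes that
  //    buffer to upsert_kernel as `evicted_keys`, and forwards the evicted
  //    keys together with their value pointers to insert_or_assign();
  //  * find() forwards the keys, the output values, and the `founds` flags so
  //    the backend can fill the slots that missed in the table.
}

// Detach with table.unlink_external_storage(); the destructor also unlinks.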
