
Commit b76b1a0

Perf: batch FFT and surrounding operators for performance
Signed-off-by: Tianxiang Wang <[email protected]>, contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd.
1 parent c297263 commit b76b1a0

File tree

24 files changed: 1734 additions, 83 deletions


source/module_basis/module_pw/fft.cpp

Lines changed: 188 additions & 1 deletion
@@ -3,6 +3,7 @@
 #include "module_base/memory.h"
 #include "module_base/tool_quit.h"
 #include "module_hamilt_pw/hamilt_pwdft/global.h"
+#include "module_parameter/parameter.h"
 
 namespace ModulePW
 {
@@ -808,7 +809,8 @@ void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<fl
 #endif
 }
 template <>
-void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<double>* in,
+void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/,
+                         std::complex<double>* in,
                          std::complex<double>* out) const
 {
 #if defined(__CUDA)
@@ -877,4 +879,189 @@ void FFT::set_precision(std::string precision_)
     this->precision = std::move(precision_);
 }
 
+#if defined(__CUDA) || defined(__ROCM)
+template <typename FPTYPE>
+BatchedFFT<FPTYPE>::BatchedFFT(int nx_, int ny_, int nz_) : nx(nx_), ny(ny_), nz(nz_)
+{
+}
+
+template <typename FPTYPE>
+BatchedFFT<FPTYPE>::BatchedFFT() : nx(0), ny(0), nz(0)
+{
+}
+
+template <typename FPTYPE>
+void BatchedFFT<FPTYPE>::initFFT(int nx_, int ny_, int nz_)
+{
+    nx = nx_;
+    ny = ny_;
+    nz = nz_;
+    this->cleanFFT();
+    this->clear_data();
+}
+
+template <typename FPTYPE>
+BatchedFFT<FPTYPE>::~BatchedFFT()
+{
+    this->cleanFFT();
+    this->clear_data();
+}
+
+template <typename FPTYPE>
+void BatchedFFT<FPTYPE>::cleanFFT()
+{
+    for (auto& pair : this->plans) {
+#if defined(__CUDA)
+        CHECK_CUFFT(cufftDestroy(pair.second));
+#elif defined(__ROCM)
+        CHECK_CUFFT(hipfftDestroy(pair.second));
+#endif
+    }
+}
+
+template <typename FPTYPE>
+void BatchedFFT<FPTYPE>::clear_data() const
+{
+    if (this->auxr_3d) {
+        base_device::memory::delete_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d);
+        this->auxr_3d = nullptr;
+        this->auxr_3d_size = 0;
+    }
+
+    if (this->sharedWorkArea) {
+        base_device::memory::delete_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea);
+        this->sharedWorkArea = nullptr;
+        this->sharedWorkAreaSize = 0;
+    }
+}
+
+template <typename FPTYPE>
+std::complex<FPTYPE>* BatchedFFT<FPTYPE>::get_auxr_3d_data(int batchSize) const
+{
+    if (this->auxr_3d_size >= sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz && this->auxr_3d != nullptr) {
+        return this->auxr_3d;
+    }
+
+    base_device::memory::resize_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d, this->nx * this->ny * this->nz * batchSize);
+    this->auxr_3d_size = sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz;
+    return this->auxr_3d;
+}
+
+template <typename FPTYPE>
+void BatchedFFT<FPTYPE>::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const
+{
+#if defined(__CUDA)
+    cufftHandle plan = this->get_plan_from_cache(batchSize);
+    if (this->fftType == CUFFT_C2C) {
+        CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out),
+                                 CUFFT_FORWARD));
+    } else {
+        CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out),
+                                 CUFFT_FORWARD));
+    }
+#elif defined(__ROCM)
+    hipfftHandle plan = this->get_plan_from_cache(batchSize);
+    if (this->fftType == HIPFFT_C2C) {
+        CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out),
+                                  HIPFFT_FORWARD));
+    } else {
+        CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out),
+                                  HIPFFT_FORWARD));
+    }
+#endif
+}
+
+template <typename FPTYPE>
+void BatchedFFT<FPTYPE>::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const
+{
+#if defined(__CUDA)
+    cufftHandle plan = this->get_plan_from_cache(batchSize);
+    if (this->fftType == CUFFT_C2C) {
+        CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out),
+                                 CUFFT_INVERSE));
+    } else {
+        CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out),
+                                 CUFFT_INVERSE));
+    }
+#elif defined(__ROCM)
+    hipfftHandle plan = this->get_plan_from_cache(batchSize);
+    if (this->fftType == HIPFFT_C2C) {
+        CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out),
+                                  HIPFFT_INVERSE));
+    } else {
+        CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out),
+                                  HIPFFT_INVERSE));
+    }
+#endif
+
+}
+
+template <typename FPTYPE>
+int BatchedFFT<FPTYPE>::estimate_batch_size(size_t additional_memory)
+{
+    int input_batchSize = PARAM.inp.fft_batch_size;
+    if (input_batchSize == 0) {
+        size_t free_mem, total_mem;
+#if defined(__CUDA)
+        cudaMemGetInfo(&free_mem, &total_mem);
+#elif defined(__ROCM)
+        hipMemGetInfo(&free_mem, &total_mem);
+#endif
+        input_batchSize = free_mem * FREE_MEM_COEFF_FFT / additional_memory;
+    }
+    return std::max(1, std::min(MAX_BATCH_SIZE_FFT, input_batchSize));
+}
+
+template <typename FPTYPE>
+typename BatchedFFT<FPTYPE>::fftHandleType BatchedFFT<FPTYPE>::get_plan_from_cache(int batchSize) const
+{
+    auto it = this->plans.find(batchSize);
+    if (it != this->plans.end()) {
+        return it->second;
+    }
+
+    fftHandleType plan;
+    int rank = 3;
+    int n[3] = {this->nx, this->ny, this->nz};
+    int* inembed = nullptr;
+    int* onembed = nullptr;
+    int istride = 1, ostride = 1;
+    int idist = this->nx * this->ny * this->nz;
+    int odist = idist;
+
+    size_t workAreaSize;
+
+#if defined(__CUDA)
+    CHECK_CUFFT(cufftPlanMany(&plan, rank, n, inembed, istride, idist,
+                              onembed, ostride, odist,
+                              fftType, batchSize));
+    CHECK_CUFFT(cufftGetSize(plan, &workAreaSize));
+#elif defined(__ROCM)
+    CHECK_CUFFT(hipfftPlanMany(&plan, rank, n, inembed, istride, idist,
+                               onembed, ostride, odist,
+                               fftType, batchSize));
+    CHECK_CUFFT(hipfftGetSize(plan, &workAreaSize));
+#endif
+
+    if (workAreaSize >= this->sharedWorkAreaSize && workAreaSize > 0) {
+        base_device::memory::resize_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea, workAreaSize);
+        this->sharedWorkAreaSize = workAreaSize;
+    }
+
+    if (workAreaSize > 0) {
+#if defined(__CUDA)
+        CHECK_CUFFT(cufftSetWorkArea(plan, this->sharedWorkArea));
+#elif defined(__ROCM)
+        CHECK_CUFFT(hipfftSetWorkArea(plan, this->sharedWorkArea));
+#endif
+    }
+
+    this->plans[batchSize] = plan;
+    return plan;
+}
+
+template class BatchedFFT<float>;
+template class BatchedFFT<double>;
+
+#endif
 } // namespace ModulePW
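
For reference, a minimal caller-side sketch of the BatchedFFT API added above, assuming a CUDA or ROCm build. The function name demo_batched_fft, the per_item_bytes argument, and the include path are illustrative and not part of this commit; the unused device-context parameter is simply passed as nullptr.

// Hypothetical usage sketch (not code from this commit).
#include <complex>
#include <cstddef>
#include "module_basis/module_pw/fft.h"

void demo_batched_fft(int nx, int ny, int nz, size_t per_item_bytes)
{
    ModulePW::BatchedFFT<double> bfft(nx, ny, nz);

    // Batch count derived from free device memory, clamped to [1, MAX_BATCH_SIZE_FFT].
    const int batch = ModulePW::BatchedFFT<double>::estimate_batch_size(per_item_bytes);

    // Device buffer lazily (re)sized to hold `batch` contiguous nx*ny*nz grids.
    std::complex<double>* d_grids = bfft.get_auxr_3d_data(batch);

    // ... fill d_grids with `batch` grids on the device ...

    bfft.fft3D_forward(nullptr, d_grids, d_grids, batch);   // one batched plan execution
    // ... operate on the transformed grids ...
    bfft.fft3D_backward(nullptr, d_grids, d_grids, batch);
}

The first forward or backward call for a given batch size creates and caches the cuFFT/hipFFT plan; later calls with the same batch size reuse it, and all cached plans share a single work area.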

source/module_basis/module_pw/fft.h

Lines changed: 77 additions & 4 deletions
@@ -3,6 +3,8 @@
 
 #include <complex>
 #include <string>
+#include <type_traits>
+#include <unordered_map>
 
 #include "fftw3.h"
 #if defined(__FFTW3_MPI) && defined(__MPI)
@@ -40,13 +42,13 @@ class FFT
     FFT();
     ~FFT();
     void clear(); //reset fft
-
+
     // init parameters of fft
-    void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
+    void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
                  int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false);
 
     //init fftw_plans
-    void setupFFT();
+    void setupFFT();
 
     //destroy fftw_plans
     void cleanFFT();
@@ -106,7 +108,7 @@ public :
     template <typename FPTYPE>
     std::complex<FPTYPE>* get_auxr_3d_data() const;
 
-    int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive
+    int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive
 
 private:
     bool gamma_only = false;
@@ -167,6 +169,77 @@ public :
     void set_precision(std::string precision_);
 
 };
+
+#if defined(__CUDA) || defined(__ROCM)
+template <typename T>
+struct FFTTypeTraits;
+
+template <>
+struct FFTTypeTraits<float> {
+#if defined(__CUDA)
+    using cuComplexType = cufftComplex;
+    static constexpr cufftType Type = CUFFT_C2C;
+#elif defined(__ROCM)
+    using hipComplexType = hipfftComplex;
+    static constexpr hipfftType Type = HIPFFT_C2C;
+#endif
+};
+
+template <>
+struct FFTTypeTraits<double> {
+#if defined(__CUDA)
+    using cuComplexType = cufftDoubleComplex;
+    static constexpr cufftType Type = CUFFT_Z2Z;
+#elif defined(__ROCM)
+    using hipComplexType = hipfftDoubleComplex;
+    static constexpr hipfftType Type = HIPFFT_Z2Z;
+#endif
+};
+
+constexpr float FREE_MEM_COEFF_FFT = 0.8;
+constexpr int MAX_BATCH_SIZE_FFT = 32;
+
+template <typename FPTYPE> // float or double
+class BatchedFFT
+{
+public:
+#if defined(__CUDA)
+    using cuComplexType = typename FFTTypeTraits<FPTYPE>::cuComplexType;
+    static constexpr cufftType fftType = FFTTypeTraits<FPTYPE>::Type;
+    using fftHandleType = cufftHandle;
+#elif defined(__ROCM)
+    using hipComplexType = typename FFTTypeTraits<FPTYPE>::hipComplexType;
+    static constexpr hipfftType fftType = FFTTypeTraits<FPTYPE>::Type;
+    using fftHandleType = hipfftHandle;
+#endif
+
+    BatchedFFT(int nx_, int ny_, int nz_);
+    BatchedFFT();
+    ~BatchedFFT();
+    void initFFT(int nx_, int ny_, int nz_);
+    void cleanFFT();
+    void clear_data() const;
+public:
+    std::complex<FPTYPE>* get_auxr_3d_data(int batchSize) const;
+    void fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const;
+    void fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const;
+    static int estimate_batch_size(size_t additional_memory);
+
+private:
+    int nx, ny, nz;
+    mutable std::unordered_map<int, fftHandleType> plans;
+
+    mutable std::complex<FPTYPE>* auxr_3d = nullptr; // fft space
+    mutable size_t auxr_3d_size = 0;
+    mutable char* sharedWorkArea = nullptr;
+    mutable size_t sharedWorkAreaSize = 0;
+
+    fftHandleType get_plan_from_cache(int batchSize) const;
+
+};
+
+#endif
+
 }
 
 #endif
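
To make the sizing logic concrete: when fft_batch_size is 0 in the input, estimate_batch_size takes the free device memory, scales it by FREE_MEM_COEFF_FFT = 0.8, divides by the per-item memory estimate passed in, and clamps the result to [1, MAX_BATCH_SIZE_FFT]. A caller would then walk its work in chunks of that size. The loop below is a hypothetical sketch; demo_chunked_bands, nbands, per_band_bytes, and the pack/unpack steps are illustrative, not code from this commit.

// Hypothetical chunked driver loop over nbands items.
#include <algorithm>
#include <complex>
#include <cstddef>
#include "module_basis/module_pw/fft.h"

void demo_chunked_bands(ModulePW::BatchedFFT<float>& bfft, int nbands, size_t per_band_bytes)
{
    const int batch = ModulePW::BatchedFFT<float>::estimate_batch_size(per_band_bytes);
    for (int ib = 0; ib < nbands; ib += batch)
    {
        const int this_batch = std::min(batch, nbands - ib);
        std::complex<float>* d_aux = bfft.get_auxr_3d_data(this_batch);
        // ... pack this_batch grids into d_aux on the device ...
        bfft.fft3D_forward(nullptr, d_aux, d_aux, this_batch);   // single batched transform
        // ... consume the transformed grids and unpack ...
    }
}

Note that a trailing partial chunk uses a second cached plan (the cache keys plans by batch size), while the work area is shared across all plans.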
