189 changes: 188 additions & 1 deletion source/module_basis/module_pw/fft.cpp
@@ -3,6 +3,7 @@
#include "module_base/memory.h"
#include "module_base/tool_quit.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_parameter/parameter.h"

namespace ModulePW
{
@@ -808,7 +809,8 @@ void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<fl
#endif
}
template <>
void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/,
std::complex<double>* in,
std::complex<double>* out) const
{
#if defined(__CUDA)
@@ -877,4 +879,189 @@ void FFT::set_precision(std::string precision_)
this->precision = std::move(precision_);
}

#if defined(__CUDA) || defined(__ROCM)
template <typename FPTYPE>
BatchedFFT<FPTYPE>::BatchedFFT(int nx_, int ny_, int nz_): nx(nx_), ny(ny_), nz(nz_)
{
}

template <typename FPTYPE>
BatchedFFT<FPTYPE>::BatchedFFT(): nx(0), ny(0), nz(0)
{
}

template <typename FPTYPE>
void BatchedFFT<FPTYPE>::initFFT(int nx_, int ny_, int nz_)
{
nx = nx_;
ny = ny_;
nz = nz_;
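// Grid dimensions changed: drop cached plans and device buffers so they are rebuilt for the new sizes.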
this->cleanFFT();
this->clear_data();
}

template <typename FPTYPE>
BatchedFFT<FPTYPE>::~BatchedFFT()
{
this->cleanFFT();
this->clear_data();
}

template <typename FPTYPE>
void BatchedFFT<FPTYPE>::cleanFFT()
{
for (auto& pair : this->plans) {
#if defined(__CUDA)
CHECK_CUFFT(cufftDestroy(pair.second));
#elif defined(__ROCM)
CHECK_CUFFT(hipfftDestroy(pair.second));
#endif
}
// Drop the destroyed handles so get_plan_from_cache() does not return stale plans after a re-init.
this->plans.clear();
}

template <typename FPTYPE>
void BatchedFFT<FPTYPE>::clear_data() const
{
if (this->auxr_3d){
base_device::memory::delete_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d);
this->auxr_3d = nullptr;
this->auxr_3d_size = 0;
}

if (this->sharedWorkArea){
base_device::memory::delete_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea);
this->sharedWorkArea = nullptr;
this->sharedWorkAreaSize = 0;
}
}

template <typename FPTYPE>
std::complex<FPTYPE>* BatchedFFT<FPTYPE>::get_auxr_3d_data(int batchSize) const
{
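// Reuse the existing device buffer when it is already large enough for this batch; otherwise (re)allocate it.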
if(this->auxr_3d_size >= sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz && this->auxr_3d != nullptr){
return this->auxr_3d;
}

base_device::memory::resize_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d, this->nx * this->ny * this->nz * batchSize);
this->auxr_3d_size = sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz;
return this->auxr_3d;
}

template <typename FPTYPE>
void BatchedFFT<FPTYPE>::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize)const
{
#if defined(__CUDA)
cufftHandle plan = this->get_plan_from_cache(batchSize);
if (this->fftType == CUFFT_C2C){
CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out),
CUFFT_FORWARD));
}else{
CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out),
CUFFT_FORWARD));
}
#elif defined(__ROCM)
hipfftHandle plan = this->get_plan_from_cache(batchSize);
if (this->fftType == HIPFFT_C2C){
CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out),
HIPFFT_FORWARD));
}else{
CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out),
HIPFFT_FORWARD));
}
#endif
}

template <typename FPTYPE>
void BatchedFFT<FPTYPE>::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize)const
{
#if defined(__CUDA)
cufftHandle plan = this->get_plan_from_cache(batchSize);
if (this->fftType == CUFFT_C2C){
CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out),
CUFFT_INVERSE));
}else{
CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out),
CUFFT_INVERSE));
}
#elif defined(__ROCM)
hipfftHandle plan = this->get_plan_from_cache(batchSize);
if (this->fftType == HIPFFT_C2C){
CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out),
HIPFFT_INVERSE));
}else{
CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out),
HIPFFT_INVERSE));
}
#endif

}

template <typename FPTYPE>
int BatchedFFT<FPTYPE>::estimate_batch_size(size_t additional_memory)
{
int input_batchSize = PARAM.inp.fft_batch_size;
if (input_batchSize == 0){
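// No batch size given by the user: budget FREE_MEM_COEFF_FFT of the currently free device memory.
// Illustration: ~12 GiB free and 512 MiB needed per batch entry gives 12 GiB * 0.8 / 512 MiB ≈ 19 entries.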
size_t free_mem, total_mem;
#if defined(__CUDA)
cudaMemGetInfo(&free_mem, &total_mem);
#elif defined(__ROCM)
hipMemGetInfo(&free_mem, &total_mem);
#endif
input_batchSize = free_mem * FREE_MEM_COEFF_FFT / additional_memory;
}
return std::max(1, std::min(MAX_BATCH_SIZE_FFT, input_batchSize));
}

template <typename FPTYPE>
typename BatchedFFT<FPTYPE>::fftHandleType BatchedFFT<FPTYPE>::get_plan_from_cache(int batchSize) const
{
auto it = this->plans.find(batchSize);
if (it != this->plans.end()) {
return it->second;
}

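// No plan cached for this batch size yet: create a batched 3-D C2C/Z2Z plan keyed by the batch size.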
fftHandleType plan;
int rank = 3;
int n[3] = {this->nx, this->ny, this->nz};
int *inembed = nullptr;
int *onembed = nullptr;
int istride = 1, ostride = 1;
int idist = this->nx * this->ny * this->nz;
int odist = idist;

size_t workAreaSize;

#if defined(__CUDA)
CHECK_CUFFT(cufftPlanMany(&plan, rank, n, inembed, istride, idist,
onembed, ostride, odist,
fftType, batchSize));
CHECK_CUFFT(cufftGetSize(plan, &workAreaSize));
#elif defined(__ROCM)
CHECK_CUFFT(hipfftPlanMany(&plan, rank, n, inembed, istride, idist,
onembed, ostride, odist,
fftType, batchSize));
CHECK_CUFFT(hipfftGetSize(plan, &workAreaSize));
#endif

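// Make sure the shared work area is large enough for this plan, growing it on demand.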
if (workAreaSize >= this->sharedWorkAreaSize && workAreaSize > 0){
base_device::memory::resize_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea, workAreaSize);
this->sharedWorkAreaSize = workAreaSize;
}

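// Attach the shared work area so each cached plan does not keep its own cuFFT/hipFFT scratch buffer.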
if (workAreaSize > 0){
#if defined(__CUDA)
CHECK_CUFFT(cufftSetWorkArea(plan, this->sharedWorkArea));
#elif defined(__ROCM)
CHECK_CUFFT(hipfftSetWorkArea(plan, this->sharedWorkArea));
#endif
}

this->plans[batchSize] = plan;
return plan;
}

template class BatchedFFT<float>;
template class BatchedFFT<double>;

#endif
} // namespace ModulePW
81 changes: 77 additions & 4 deletions source/module_basis/module_pw/fft.h
@@ -3,6 +3,8 @@

#include <complex>
#include <string>
#include <type_traits>
#include <unordered_map>

#include "fftw3.h"
#if defined(__FFTW3_MPI) && defined(__MPI)
@@ -40,13 +42,13 @@ class FFT
FFT();
~FFT();
void clear(); //reset fft

// init parameters of fft
void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false);

//init fftw_plans
void setupFFT();

//destroy fftw_plans
void cleanFFT();
@@ -106,7 +108,7 @@ public :
template <typename FPTYPE>
std::complex<FPTYPE>* get_auxr_3d_data() const;

int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive

private:
bool gamma_only = false;
@@ -167,6 +169,77 @@ public :
void set_precision(std::string precision_);

};

#if defined(__CUDA) || defined(__ROCM)
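// Maps the scalar precision (float or double) to the matching cuFFT/hipFFT complex type and transform enum.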
template <typename T>
struct FFTTypeTraits;

template <>
struct FFTTypeTraits<float> {
#if defined(__CUDA)
using cuComplexType = cufftComplex;
static constexpr cufftType Type = CUFFT_C2C;
#elif defined(__ROCM)
using hipComplexType = hipfftComplex;
static constexpr hipfftType Type = HIPFFT_C2C;
#endif
};

template <>
struct FFTTypeTraits<double> {
#if defined(__CUDA)
using cuComplexType = cufftDoubleComplex;
static constexpr cufftType Type = CUFFT_Z2Z;
#elif defined(__ROCM)
using hipComplexType = hipfftDoubleComplex;
static constexpr hipfftType Type = HIPFFT_Z2Z;
#endif
};

constexpr float FREE_MEM_COEFF_FFT = 0.8; // fraction of free device memory the automatic batch-size estimate may budget
constexpr int MAX_BATCH_SIZE_FFT = 32;    // hard upper bound on the FFT batch size

template<typename FPTYPE> // float or double
class BatchedFFT
{
public:
#if defined(__CUDA)
using cuComplexType = typename FFTTypeTraits<FPTYPE>::cuComplexType;
static constexpr cufftType fftType = FFTTypeTraits<FPTYPE>::Type;
using fftHandleType = cufftHandle;
#elif defined(__ROCM)
using hipComplexType = typename FFTTypeTraits<FPTYPE>::hipComplexType;
static constexpr hipfftType fftType = FFTTypeTraits<FPTYPE>::Type;
using fftHandleType = hipfftHandle;
#endif

BatchedFFT(int nx_, int ny_, int nz_);
BatchedFFT();
~BatchedFFT();
void initFFT(int nx_, int ny_, int nz_);
void cleanFFT();
void clear_data() const;
public:
std::complex<FPTYPE>* get_auxr_3d_data(int batchSize)const;
void fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize)const;
void fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize)const;
static int estimate_batch_size(size_t additional_memory);

private:
int nx, ny, nz;
mutable std::unordered_map<int, fftHandleType> plans; // cuFFT/hipFFT plans cached per batch size

mutable std::complex<FPTYPE> *auxr_3d = nullptr; // fft space
mutable ::size_t auxr_3d_size = 0;
mutable char *sharedWorkArea = nullptr; // work-area memory shared by all cached plans
mutable size_t sharedWorkAreaSize = 0;

fftHandleType get_plan_from_cache(int batchSize)const;

};

#endif

}

#endif
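
A minimal usage sketch of the new BatchedFFT interface, assuming a CUDA or ROCm build; the grid sizes, the per-entry memory estimate, and the device pointer d_data are hypothetical placeholders, not part of this patch:

// Sketch only: d_data (hypothetical) is assumed to already hold
// batch * nx * ny * nz complex values in device memory.
const int nx = 48, ny = 48, nz = 48;
ModulePW::BatchedFFT<double> bfft(nx, ny, nz);
const size_t bytes_per_entry = sizeof(std::complex<double>) * nx * ny * nz;
const int batch = ModulePW::BatchedFFT<double>::estimate_batch_size(bytes_per_entry);
std::complex<double>* aux = bfft.get_auxr_3d_data(batch); // grow-on-demand device scratch
bfft.fft3D_forward(nullptr, d_data, aux, batch);  // the ctx argument is ignored by these wrappers
bfft.fft3D_backward(nullptr, aux, d_data, batch);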