|
3 | 3 | #include "module_base/memory.h" |
4 | 4 | #include "module_base/tool_quit.h" |
5 | 5 | #include "module_hamilt_pw/hamilt_pwdft/global.h" |
| 6 | +#include "module_parameter/parameter.h" |
6 | 7 |
|
7 | 8 | namespace ModulePW |
8 | 9 | { |
@@ -808,7 +809,8 @@ void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<fl |
808 | 809 | #endif |
809 | 810 | } |
810 | 811 | template <> |
811 | | -void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<double>* in, |
| 812 | +void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, |
| 813 | + std::complex<double>* in, |
812 | 814 | std::complex<double>* out) const |
813 | 815 | { |
814 | 816 | #if defined(__CUDA) |
@@ -877,4 +879,189 @@ void FFT::set_precision(std::string precision_) |
877 | 879 | this->precision = std::move(precision_); |
878 | 880 | } |
879 | 881 |
|
| 882 | +#if defined(__CUDA) || defined(__ROCM) |
| 883 | +template <typename FPTYPE> |
| 884 | +BatchedFFT<FPTYPE>::BatchedFFT(int nx_, int ny_, int nz_): nx(nx_), ny(ny_), nz(nz_) |
| 885 | +{ |
| 886 | +} |
| 887 | + |
| 888 | +template <typename FPTYPE> |
| 889 | +BatchedFFT<FPTYPE>::BatchedFFT(): nx(0), ny(0), nz(0) |
| 890 | +{ |
| 891 | +} |
| 892 | + |
| 893 | +template <typename FPTYPE> |
| 894 | +void BatchedFFT<FPTYPE>::initFFT(int nx_, int ny_, int nz_) |
| 895 | +{ |
| 896 | + nx = nx_; |
| 897 | + ny = ny_; |
| 898 | + nz = nz_; |
| 899 | + this->cleanFFT(); |
| 900 | + this->clear_data(); |
| 901 | +} |
| 902 | + |
| 903 | +template <typename FPTYPE> |
| 904 | +BatchedFFT<FPTYPE>::~BatchedFFT() |
| 905 | +{ |
| 906 | + this->cleanFFT(); |
| 907 | + this->clear_data(); |
| 908 | +} |
| 909 | + |
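| | +// Release every cached cuFFT/hipFFT plan; called from initFFT() and the destructor. |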
| 910 | +template <typename FPTYPE> |
| 911 | +void BatchedFFT<FPTYPE>::cleanFFT() |
| 912 | +{ |
| 913 | + for (auto& pair : this->plans) { |
| 914 | +#if defined(__CUDA) |
| 915 | + CHECK_CUFFT(cufftDestroy(pair.second)); |
| 916 | +#elif defined(__ROCM) |
| 917 | + CHECK_CUFFT(hipfftDestroy(pair.second)); |
| 918 | +#endif |
| 919 | + } |
| 920 | +} |
| 921 | + |
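| | +// Free the cached 3D auxiliary buffer and the shared FFT work area on the device. |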
| 922 | +template <typename FPTYPE> |
| 923 | +void BatchedFFT<FPTYPE>::clear_data() const |
| 924 | +{ |
| 925 | + if (this->auxr_3d){ |
| 926 | + base_device::memory::delete_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d); |
| 927 | + this->auxr_3d = nullptr; |
| 928 | + this->auxr_3d_size = 0; |
| 929 | + } |
| 930 | + |
| 931 | + if (this->sharedWorkArea){ |
| 932 | + base_device::memory::delete_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea); |
| 933 | + this->sharedWorkArea = nullptr; |
| 934 | + this->sharedWorkAreaSize = 0; |
| 935 | + } |
| 936 | +} |
| 937 | + |
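| | +// Return a device buffer large enough for batchSize nx*ny*nz grids, reallocating only when the cached buffer is too small. |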
| 938 | +template <typename FPTYPE> |
| 939 | +std::complex<FPTYPE>* BatchedFFT<FPTYPE>::get_auxr_3d_data(int batchSize) const |
| 940 | +{ |
| 941 | + if(this->auxr_3d_size >= sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz && this->auxr_3d != nullptr){ |
| 942 | + return this->auxr_3d; |
| 943 | + } |
| 944 | + |
| 945 | + base_device::memory::resize_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d, this->nx * this->ny * this->nz * batchSize); |
| 946 | + this->auxr_3d_size = sizeof(std::complex<FPTYPE>) * batchSize * this->nx * this->ny * this->nz; |
| 947 | + return this->auxr_3d; |
| 948 | +} |
| 949 | + |
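| | +// Batched forward 3D FFT on the device, dispatching to the single- (C2C) or double-precision (Z2Z) executor according to fftType. |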
| 950 | +template <typename FPTYPE> |
| 951 | +void BatchedFFT<FPTYPE>::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const |
| 952 | +{ |
| 953 | +#if defined(__CUDA) |
| 954 | + cufftHandle plan = this->get_plan_from_cache(batchSize); |
| 955 | + if (this->fftType == CUFFT_C2C){ |
| 956 | + CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out), |
| 957 | + CUFFT_FORWARD)); |
| 958 | + }else{ |
| 959 | + CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out), |
| 960 | + CUFFT_FORWARD)); |
| 961 | + } |
| 962 | +#elif defined(__ROCM) |
| 963 | + hipfftHandle plan = this->get_plan_from_cache(batchSize); |
| 964 | + if (this->fftType == HIPFFT_C2C){ |
| 965 | + CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out), |
| 966 | + HIPFFT_FORWARD)); |
| 967 | + }else{ |
| 968 | + CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out), |
| 969 | + HIPFFT_FORWARD)); |
| 970 | + } |
| 971 | +#endif |
| 972 | +} |
| 973 | + |
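| | +// Batched backward (inverse) 3D FFT on the device; the result is unnormalized, following cuFFT/hipFFT conventions. |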
| 974 | +template <typename FPTYPE> |
| 975 | +void BatchedFFT<FPTYPE>::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int batchSize) const |
| 976 | +{ |
| 977 | +#if defined(__CUDA) |
| 978 | + cufftHandle plan = this->get_plan_from_cache(batchSize); |
| 979 | + if (this->fftType == CUFFT_C2C){ |
| 980 | + CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out), |
| 981 | + CUFFT_INVERSE)); |
| 982 | + }else{ |
| 983 | + CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast<cufftDoubleComplex*>(in), reinterpret_cast<cufftDoubleComplex*>(out), |
| 984 | + CUFFT_INVERSE)); |
| 985 | + } |
| 986 | +#elif defined(__ROCM) |
| 987 | + hipfftHandle plan = this->get_plan_from_cache(batchSize); |
| 988 | + if (this->fftType == HIPFFT_C2C){ |
| 989 | + CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast<hipfftComplex*>(in), reinterpret_cast<hipfftComplex*>(out), |
| 990 | + HIPFFT_INVERSE)); |
| 991 | + }else{ |
| 992 | + CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast<hipfftDoubleComplex*>(in), reinterpret_cast<hipfftDoubleComplex*>(out), |
| 993 | + HIPFFT_INVERSE)); |
| 994 | + } |
| 995 | +#endif |
| 996 | + |
| 997 | +} |
| 998 | + |
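| | +// Choose the FFT batch size: use the user-specified fft_batch_size, or estimate it from a fraction (FREE_MEM_COEFF_FFT) of the free device memory divided by the per-batch memory requirement, clamped to [1, MAX_BATCH_SIZE_FFT]. |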
| 999 | +template <typename FPTYPE> |
| 1000 | +int BatchedFFT<FPTYPE>::estimate_batch_size(size_t additional_memory) |
| 1001 | +{ |
| 1002 | + int input_batchSize = PARAM.inp.fft_batch_size; |
| 1003 | + if (input_batchSize == 0){ |
| 1004 | + size_t free_mem, total_mem; |
| 1005 | +#if defined(__CUDA) |
| 1006 | + cudaMemGetInfo(&free_mem, &total_mem); |
| 1007 | +#elif defined(__ROCM) |
| 1008 | + hipMemGetInfo(&free_mem, &total_mem); |
| 1009 | +#endif |
| 1010 | + input_batchSize = free_mem * FREE_MEM_COEFF_FFT / additional_memory; |
| 1011 | + } |
| 1012 | + return std::max(1, std::min(MAX_BATCH_SIZE_FFT, input_batchSize)); |
| 1013 | +} |
| 1014 | + |
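| | +// Return the cuFFT/hipFFT plan for the given batch size, creating and caching it on first use; all cached plans share one work area that is grown on demand. |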
| 1015 | +template <typename FPTYPE> |
| 1016 | +typename BatchedFFT<FPTYPE>::fftHandleType BatchedFFT<FPTYPE>::get_plan_from_cache(int batchSize) const |
| 1017 | +{ |
| 1018 | + auto it = this->plans.find(batchSize); |
| 1019 | + if (it != this->plans.end()) { |
| 1020 | + return it->second; |
| 1021 | + } |
| 1022 | + |
| 1023 | + fftHandleType plan; |
| 1024 | + int rank = 3; |
| 1025 | + int n[3] = {this->nx, this->ny, this->nz}; |
| 1026 | + int *inembed = nullptr; |
| 1027 | + int *onembed = nullptr; |
| 1028 | + int istride = 1, ostride = 1; |
| 1029 | + int idist = this->nx * this->ny * this->nz; |
| 1030 | + int odist = idist; |
| 1031 | + |
| 1032 | + size_t workAreaSize; |
| 1033 | + |
| 1034 | +#if defined(__CUDA) |
| 1035 | + CHECK_CUFFT(cufftPlanMany(&plan, rank, n, inembed, istride, idist, |
| 1036 | + onembed, ostride, odist, |
| 1037 | + fftType, batchSize)); |
| 1038 | + CHECK_CUFFT(cufftGetSize(plan, &workAreaSize)); |
| 1039 | +#elif defined(__ROCM) |
| 1040 | + CHECK_CUFFT(hipfftPlanMany(&plan, rank, n, inembed, istride, idist, |
| 1041 | + onembed, ostride, odist, |
| 1042 | + fftType, batchSize)); |
| 1043 | + CHECK_CUFFT(hipfftGetSize(plan, &workAreaSize)); |
| 1044 | +#endif |
| 1045 | + |
| 1046 | + if (workAreaSize > this->sharedWorkAreaSize){ |
| 1047 | + base_device::memory::resize_memory_op<char, base_device::DEVICE_GPU>()(gpu_ctx, this->sharedWorkArea, workAreaSize); |
| 1048 | + this->sharedWorkAreaSize = workAreaSize; |
| 1049 | + } |
| 1050 | + |
| 1051 | + if (workAreaSize > 0){ |
| 1052 | +#if defined(__CUDA) |
| 1053 | + CHECK_CUFFT(cufftSetWorkArea(plan, this->sharedWorkArea)); |
| 1054 | +#elif defined(__ROCM) |
| 1055 | + CHECK_CUFFT(hipfftSetWorkArea(plan, this->sharedWorkArea)); |
| 1056 | +#endif |
| 1057 | + } |
| 1058 | + |
| 1059 | + this->plans[batchSize] = plan; |
| 1060 | + return plan; |
| 1061 | +} |
| 1062 | + |
| 1063 | +template class BatchedFFT<float>; |
| 1064 | +template class BatchedFFT<double>; |
| 1065 | + |
| 1066 | +#endif |
880 | 1067 | } // namespace ModulePW |