Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion source/source_base/module_fft/fft_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class FFT_BASE
* @brief Initialize the fft parameters as virtual function.
*
* The function is used to initialize the fft parameters.
* Batch FFT is supported only on GPU, so only the second overload takes the batch_size parameter.
*/
virtual __attribute__((weak)) void initfft(int nx_in,
int ny_in,
Expand All @@ -27,7 +28,7 @@ class FFT_BASE
bool gamma_only_in,
bool xprime_in = true);

virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in);
virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in, int batch_size = 0);

/**
* @brief Setup the fft plan and data as pure virtual function.
Expand Down
1 change: 1 addition & 0 deletions source/source_base/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void FFT_Bundle::initfft(int nx_in,
int nproc_in,
bool gamma_only_in,
bool xprime_in,
int batch_size,
bool mpifft_in)
{
assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp");
Expand Down
1 change: 1 addition & 0 deletions source/source_base/module_fft/fft_bundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class FFT_Bundle
int nproc_in,
bool gamma_only_in,
bool xprime_in = true,
int batch_size = 0,
bool mpifft_in = false);

/**
Expand Down
2 changes: 1 addition & 1 deletion source/source_base/module_fft/fft_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void FFT_CPU<FPTYPE>::initfft(int nx_in,
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool gamma_only_in,
bool xprime_in)
{
this->gamma_only = gamma_only_in;
Expand Down
2 changes: 1 addition & 1 deletion source/source_base/module_fft/fft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool gamma_only_in,
bool xprime_in = true) override;

__attribute__((weak))
Expand Down
44 changes: 39 additions & 5 deletions source/source_base/module_fft/fft_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,57 @@
namespace ModuleBase
{
/**
 * @brief Store the FFT grid dimensions and the batch size on this object.
 *
 * Only member assignments happen here; no cuFFT plan is created by this function.
 *
 * @param nx_in number of grid points in x direction
 * @param ny_in number of grid points in y direction
 * @param nz_in number of grid points in z direction
 * @param batch_size number of transforms per batched plan; zero means batch FFT is not used
 */
template <typename FPTYPE>
void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in, int batch_size)
{
    this->nx = nx_in;
    this->ny = ny_in;
    this->nz = nz_in;
    this->batch_size = batch_size;
}
/**
 * @brief Create the single-precision complex-to-complex cuFFT plan for the nx * ny * nz grid.
 *
 * With a positive batch_size, one batched plan is created through cufftPlanMany.
 * Otherwise a single 3D plan is created through cufftPlan3d and resmem_cd_op is
 * invoked on c_auxr_3d with the grid size.
 *
 * The plan is created exactly once per call, so the handle is not overwritten by a
 * second plan creation.
 */
template <>
void FFT_CUDA<float>::setupFFT()
{
    // Multiply in size_t so that a large grid does not overflow int arithmetic.
    const size_t grid_size = static_cast<size_t>(this->nx) * this->ny * this->nz;

    if (this->batch_size > 0)
    {
        int rank = 3; // three-dimensional transform
        int n[3] = {this->nx, this->ny, this->nz};
        // Embedding equal to n with unit stride: every transform is tightly packed,
        // and consecutive transforms are grid_size elements apart.
        int inembed[3] = {this->nx, this->ny, this->nz};
        int onembed[3] = {this->nx, this->ny, this->nz};
        int istride = 1;
        int ostride = 1;
        // cufftPlanMany takes the distances as int; make the conversion explicit.
        int idist = static_cast<int>(grid_size);
        int odist = static_cast<int>(grid_size);
        // NOTE(review): the returned cufftResult is not checked, and c_auxr_3d is not
        // resized on this path - confirm that both are intended.
        cufftPlanMany(&c_handle, rank, n,
                      inembed, istride, idist,
                      onembed, ostride, odist,
                      CUFFT_C2C, this->batch_size);
    }
    else
    {
        cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
        resmem_cd_op()(this->c_auxr_3d, grid_size);
    }
}
/**
 * @brief Create the double-precision complex-to-complex cuFFT plan for the nx * ny * nz grid.
 *
 * With a positive batch_size, one batched plan is created through cufftPlanMany.
 * Otherwise a single 3D plan is created through cufftPlan3d and resmem_zd_op is
 * invoked on z_auxr_3d with the grid size.
 *
 * The plan is created exactly once per call, so the handle is not overwritten by a
 * second plan creation.
 */
template <>
void FFT_CUDA<double>::setupFFT()
{
    // Multiply in size_t so that a large grid does not overflow int arithmetic.
    const size_t grid_size = static_cast<size_t>(this->nx) * this->ny * this->nz;

    if (this->batch_size > 0)
    {
        int rank = 3; // three-dimensional transform
        int n[3] = {this->nx, this->ny, this->nz};
        // Embedding equal to n with unit stride: every transform is tightly packed,
        // and consecutive transforms are grid_size elements apart.
        int inembed[3] = {this->nx, this->ny, this->nz};
        int onembed[3] = {this->nx, this->ny, this->nz};
        int istride = 1;
        int ostride = 1;
        // cufftPlanMany takes the distances as int; make the conversion explicit.
        int idist = static_cast<int>(grid_size);
        int odist = static_cast<int>(grid_size);
        // NOTE(review): the returned cufftResult is not checked, and z_auxr_3d is not
        // resized on this path - confirm that both are intended.
        cufftPlanMany(&z_handle, rank, n,
                      inembed, istride, idist,
                      onembed, ostride, odist,
                      CUFFT_Z2Z, this->batch_size);
    }
    else
    {
        cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
        resmem_zd_op()(this->z_auxr_3d, grid_size);
    }
}
template <>
void FFT_CUDA<float>::cleanFFT()
Expand Down
6 changes: 5 additions & 1 deletion source/source_base/module_fft/fft_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ class FFT_CUDA : public FFT_BASE<FPTYPE>
* @param nx_in number of grid points in x direction
* @param ny_in number of grid points in y direction
* @param nz_in number of grid points in z direction
* @param batch_size number of transforms per batched plan. Set to zero if batch FFT is not needed.
*
*/
void initfft(int nx_in,
int ny_in,
int nz_in) override;
int nz_in,
int batch_size) override;

/**
* @brief Get the real space data
Expand Down Expand Up @@ -61,6 +63,8 @@ class FFT_CUDA : public FFT_BASE<FPTYPE>
std::complex<float>* c_auxr_3d = nullptr; // fft space
std::complex<double>* z_auxr_3d = nullptr; // fft space

int batch_size = 0;

};

} // namespace ModuleBase
Expand Down
1 change: 1 addition & 0 deletions source/source_io/module_parameter/input_parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct Input_para
double ecutrho = 0; ///< energy cutoff for charge/potential

int nx = 0, ny = 0, nz = 0; ///< three dimension of FFT wavefunc
int fft_batch = 0; ///< batch size of FFT on GPU; set to zero if batch FFT is not needed
int ndx = 0, ndy = 0, ndz = 0; ///< three dimension of FFT smooth charge density

double cell_factor = 1.2; ///< LiuXh add 20180619
Expand Down
15 changes: 15 additions & 0 deletions source/source_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,21 @@ void ReadInput::item_system()
sync_int(input.nz);
this->add_item(item);
}
{
// Register the "fft_batch" input item; its value is stored in para.input.fft_batch.
Input_Item item("fft_batch");
item.annotation = "the batch size of FFT on GPU, probably makes cuFFT faster";
// NOTE(review): `intvalue` is not declared in this view - presumably a helper that
// yields the item's parsed value as an int; confirm against its definition.
item.read_value = [](const Input_Item& item, Parameter& para) {
para.input.fft_batch = intvalue;
};
// Reject negative batch sizes; zero and positive values pass the check.
item.check_value = [](const Input_Item& item, const Parameter& para) {
if (para.input.fft_batch < 0)
{
ModuleBase::WARNING_QUIT("ReadInput", "fft_batch should be set to no less than zero");
}
};
// NOTE(review): `sync_int` is defined elsewhere - presumably it wires up how this
// integer value is reported/synchronized; confirm against its definition.
sync_int(input.fft_batch);
this->add_item(item);
}
{
Input_Item item("ndx");
item.annotation = "number of points along x axis for FFT smooth grid";
Expand Down
Loading