/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/** \file
 * FFT based Poisson solver.
 *
 * \author Ibai Gurrutxaga
 */

/********************************************
 * Includes                                 *
 ********************************************/
#include <cufft.h>
#include <mpi.h>

#include "core/globals.h"
#include "cuPoisson.h"
#include "core/grid.h"
#include "solvers/solver.h"
#include "utils/alg.h"
#include "utils/utils.h"
#include "solvers/fft_utils.h"
#include "solvers/fft_utils_mpi.h"
#include "fft_mpi.h"

/********************************************
 * Private function prototypes              *
 ********************************************/
/** Non-pipelined 2D FFT stage (used when the segment size is 0);
 *  \p direction is CUFFT_FORWARD or CUFFT_INVERSE. */
static cup_error_t exec_fft2d_stage( cup_mpi_solver* mpi_solver, int direction );
/** 1D stage: forward Z2Z FFT on \p data, GPU Poisson solve, inverse Z2Z FFT
 *  on \p data; both transforms are executed in place. */
static cup_error_t exec_fft1d_stage( cup_mpi_solver*     mpi_solver,
                                     cufftDoubleComplex* data );
/** Pipelined (segmented) 2D FFT stage (used when the segment size is not 0);
 *  \p direction is CUFFT_FORWARD or CUFFT_INVERSE. */
static cup_error_t exec_fft2d_pipeline( cup_mpi_solver* mpi_solver, int direction );

/********************************************
 * Public functions                         *
 ********************************************/
/**
 * Runs one complete solve with the MPI FFT solver.
 *
 * Sequence: optional host-to-device upload of \p input, forward 2D stage,
 * 1D stage (which contains the Poisson solve), inverse 2D stage and optional
 * device-to-host download into \p output.  The pipelined 2D stages are used
 * when the solver's segment size is not zero; otherwise the single-shot 2D
 * stages run directly on solver->dev_data[0].
 *
 * \param output      Host buffer that receives the result, or NULL to leave
 *                    the result on the device; in that case the device is
 *                    synchronized before returning.
 * \param input       Host buffer to upload into solver->dev_data[0], or NULL
 *                    to use whatever is already stored there.
 * \param mpi_solver  Solver to execute; must not be NULL.
 * \return CUP_SUCCESS when every step completed; stage errors are handed to
 *         THROW_CUP_ERROR.
 */
cup_error_t mpi_exec_fft_solver( double*         output,
                                 const double*   input,
                                 cup_mpi_solver* mpi_solver )
{
	size_t main_bytes;
	struct mpi_fft_solver_data* mpi_sd;
	struct cup_solver* solver;

	assert( mpi_solver != NULL );

	/* Check the device selection like every other CUDA call: a failure here
	 * would otherwise only surface later as an unrelated-looking error. */
	CUDA( cudaSetDevice( global_info.devices[0] ) );

	solver = &mpi_solver->base;

	mpi_sd = &mpi_solver->sd.fft;

	/* Widen to size_t before multiplying so that the product of the grid
	 * extents cannot overflow a narrower integer type on large grids. */
	main_bytes = (size_t) solver->grid->nps[0] *
	             (size_t) solver->grid->nps[1] *
	             (size_t) mpi_sd->size_2d *
	             sizeof( double );

	if( input != NULL )
	{
		CUDA( cudaMemcpyAsync( solver->dev_data[0],
		                       input,
		                       main_bytes,
		                       cudaMemcpyHostToDevice ) );
	}

	if( mpi_sd->seg_size == 0 )
	{
		/* Single-shot path: all stages work on the main device buffer. */
		THROW_CUP_ERROR( exec_fft2d_stage( mpi_solver, CUFFT_FORWARD ) );
		THROW_CUP_ERROR( exec_fft1d_stage( mpi_solver,
		                                  (cufftDoubleComplex*) solver->dev_data[0] ) );
		THROW_CUP_ERROR( exec_fft2d_stage( mpi_solver, CUFFT_INVERSE ) );
	}
	else
	{
		/* Pipelined path: the 1D stage runs on its dedicated buffer. */
		THROW_CUP_ERROR( exec_fft2d_pipeline( mpi_solver, CUFFT_FORWARD ) );
		THROW_CUP_ERROR( exec_fft1d_stage( mpi_solver, mpi_sd->dev_data_1d ) );
		THROW_CUP_ERROR( exec_fft2d_pipeline( mpi_solver, CUFFT_INVERSE ) );
	}

	if( output != NULL )
	{
		CUDA( cudaMemcpy( output,
		                  solver->dev_data[0],
		                  main_bytes,
		                  cudaMemcpyDeviceToHost ) );
	}
	else
	{
		/* No download requested: wait for the outstanding device work. */
		CUDA( cudaDeviceSynchronize() );
	}

	return CUP_SUCCESS;
}

/**
 * Sets up the MPI FFT solver: parameters, cuFFT plans, device buffers,
 * page-locked host buffers and, for the pipelined variant, the additional
 * streams and events.
 *
 * Two layouts are allocated depending on the segment size computed by
 * mpi_set_fft_parms():
 *  - segment size > 0 (pipelined): separate device buffers for the 2D and
 *    the 1D data, two host buffers and two device "transposed" buffers per
 *    stage (all sized per segment), two auxiliary device buffers, three
 *    newly created streams and four events;
 *  - segment size == 0: one device buffer and one "transposed" buffer, each
 *    large enough for either the 2D or the 1D layout, plus one host buffer
 *    per stage.
 * In both cases streams[0] is borrowed from the base solver, not created.
 *
 * Only a single device per process is supported (asserted).
 *
 * \param mpi_solver  Solver to initialise; must not be NULL.
 * \return CUP_SUCCESS on success; errors of the helper routines are handed
 *         to THROW_CUP_ERROR.
 */
cup_error_t mpi_create_fft_solver( struct cup_mpi_solver* mpi_solver )
{
	struct cup_solver* solver;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;
	size_t buf2d_bytes,
	       buf1d_bytes;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;
	sd = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;

	assert( global_info.num_devices == 1 );

	/* One device only, hence a single entry in the device pointer table. */
	MALLOC( solver->dev_data, 1, double* );

	THROW_CUP_ERROR( mpi_set_fft_parms( mpi_solver ) );
	print_msg( INFO, "Characteristics of the FFT solver:" );
	print_msg( INFO,
	           "\t2D -> Block size: %d Offset: %d.",
	           mpi_sd->size_2d,
	           mpi_sd->offset_2d );
	print_msg( INFO,
	           "\t1D -> Block size: %d Offset: %d.",
	           mpi_sd->size_1d,
	           mpi_sd->offset_1d );
	if( mpi_sd->seg_size > 0 )
	{
		print_msg( INFO, "\tPipeline segment size: %d.", mpi_sd->seg_size );
		print_msg( INFO, "\tPipeline segments: %d.", mpi_sd->num_segs );
	}
	else
		print_msg( INFO, "\tNot pipelined solver." );

	THROW_CUP_ERROR( mpi_set_fft_plans( mpi_solver ) );

	/* Sizes of the complete local 2D and 1D data sets.  The first factor is
	 * widened to size_t so the whole product is evaluated in size_t and
	 * cannot overflow a narrower integer type on large grids. */
	buf2d_bytes = (size_t) sd->parms.np_sym_dimension
	              * (size_t) solver->grid->nps[1]
	              * (size_t) mpi_sd->size_2d
	              * sizeof( cufftDoubleComplex );
	buf1d_bytes = (size_t) mpi_sd->size_1d
	              * (size_t) solver->grid->nps[1]
	              * (size_t) mpi_sd->padded_np2
	              * sizeof( cufftDoubleComplex );

	if( mpi_sd->seg_size > 0 )
	{
		/* Full-size device buffers for the 2D and the 1D stage. */
		CUDA( cudaMalloc( solver->dev_data, buf2d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_data_1d, buf1d_bytes ) );

		/* From here on the sizes refer to a single pipeline segment. */
		buf2d_bytes = (size_t) sd->parms.np_sym_dimension
		              * (size_t) solver->grid->nps[1]
		              * (size_t) mpi_sd->seg_size
		              * sizeof( cufftDoubleComplex );
		buf1d_bytes = (size_t) mpi_sd->size_1d
		              * (size_t) solver->grid->nps[1]
		              * (size_t) mpi_sd->seg_size
		              * (size_t) mpi_solver->size
		              * sizeof( cufftDoubleComplex );

		/* Two buffers of each kind, one pair per stage. */
		CUDA( cudaMallocHost( &mpi_sd->buffer2d[0], buf2d_bytes ) );
		CUDA( cudaMallocHost( &mpi_sd->buffer2d[1], buf2d_bytes ) );
		CUDA( cudaMallocHost( &mpi_sd->buffer1d[0], buf1d_bytes ) );
		CUDA( cudaMallocHost( &mpi_sd->buffer1d[1], buf1d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_transposed[0], buf2d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_transposed[1], buf2d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_transposed[2], buf1d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_transposed[3], buf1d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_aux[0], buf2d_bytes) );
		CUDA( cudaMalloc( &mpi_sd->dev_aux[1], buf2d_bytes) );

		/* streams[0] is shared with the base solver; 1..3 are owned here. */
		mpi_sd->streams[0] = solver->streams[0];
		CUDA( cudaStreamCreate( &mpi_sd->streams[1] ) );
		CUDA( cudaStreamCreate( &mpi_sd->streams[2] ) );
		CUDA( cudaStreamCreate( &mpi_sd->streams[3] ) );
		CUDA( cudaEventCreate( &mpi_sd->events[0] ) );
		CUDA( cudaEventCreate( &mpi_sd->events[1] ) );
		CUDA( cudaEventCreate( &mpi_sd->events[2] ) );
		CUDA( cudaEventCreate( &mpi_sd->events[3] ) );
	}
	else
	{
		/* The same device buffers serve both stages, so they must be able to
		 * hold the larger of the two layouts. */
		CUDA( cudaMalloc( solver->dev_data, MAX( buf2d_bytes, buf1d_bytes) ) );
		CUDA( cudaMallocHost( &mpi_sd->buffer2d[0], buf2d_bytes ) );
		CUDA( cudaMalloc( &mpi_sd->dev_transposed[0],
		                   MAX( buf2d_bytes, buf1d_bytes ) ) );
		CUDA( cudaMallocHost( &mpi_sd->buffer1d[0], buf1d_bytes ) );
		mpi_sd->streams[0] = solver->streams[0];
	}

	return CUP_SUCCESS;
}

/**
 * Releases the resources owned by the MPI FFT solver: device buffers,
 * page-locked host buffers, cuFFT plans, the streams and events created for
 * the pipelined variant, and the count/displacement arrays.
 *
 * streams[0] is not destroyed here because it is only a copy of the base
 * solver's stream handle.
 *
 * \param mpi_solver  Solver to tear down; must not be NULL.
 * \return CUP_SUCCESS.
 */
cup_error_t mpi_destroy_fft_solver( struct cup_mpi_solver* mpi_solver )
{
	assert( mpi_solver != NULL );

	struct cup_solver* solver = &mpi_solver->base;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;

	int i;
	int num_sets;

	sd = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;

	/* Buffers that exist in both the pipelined and the single-shot layout. */
	CUDA( cudaFree( solver->dev_data[0] ) );
	CUDA( cudaFree( mpi_sd->dev_transposed[0] ) );
	if( mpi_sd->seg_size > 0 )
	{
		/* Pipeline-only device buffers, streams and events. */
		CUDA( cudaFree( mpi_sd->dev_data_1d ) );
		CUDA( cudaFree( mpi_sd->dev_transposed[1] ) );
		CUDA( cudaFree( mpi_sd->dev_transposed[2] ) );
		CUDA( cudaFree( mpi_sd->dev_transposed[3] ) );
		CUDA( cudaFree( mpi_sd->dev_aux[0] ) );
		CUDA( cudaFree( mpi_sd->dev_aux[1] ) );
		CUDA( cudaStreamDestroy( mpi_sd->streams[1] ) );
		CUDA( cudaStreamDestroy( mpi_sd->streams[2] ) );
		CUDA( cudaStreamDestroy( mpi_sd->streams[3] ) );
		/* The pipelined setup creates four events; destroy all of them
		 * (events 2 and 3 used to be leaked). */
		CUDA( cudaEventDestroy( mpi_sd->events[0] ) );
		CUDA( cudaEventDestroy( mpi_sd->events[1] ) );
		CUDA( cudaEventDestroy( mpi_sd->events[2] ) );
		CUDA( cudaEventDestroy( mpi_sd->events[3] ) );
	}
	/* NOTE(review): the 2D/1D plans and host buffers are released only when
	 * the respective block size is positive; presumably they are not created
	 * otherwise -- confirm against mpi_set_fft_plans(). */
	if( mpi_sd->size_2d > 0)
	{
		CUDA( cudaFreeHost( mpi_sd->buffer2d[0] ) );
		CUFFT( cufftDestroy( sd->plan_main_fw[0] ) );
		CUFFT( cufftDestroy( sd->plan_main_inv[0] ) );
		if( mpi_sd->seg_size > 0 )
		{
			CUDA( cudaFreeHost( mpi_sd->buffer2d[1] ) );
			CUFFT( cufftDestroy( mpi_sd->plan_2d_fw_last ) );
			CUFFT( cufftDestroy( mpi_sd->plan_2d_inv_last ) );
		}
		free( sd->plan_main_fw );
		free( sd->plan_main_inv );
	}
	if( mpi_sd->size_1d > 0)
	{
		CUDA( cudaFreeHost( mpi_sd->buffer1d[0] ) );
		CUFFT( cufftDestroy( sd->plan_sec[0] ) );
		free( sd->plan_sec );
		if( mpi_sd->seg_size > 0 )
		{
			CUDA( cudaFreeHost( mpi_sd->buffer1d[1] ) );
		}
	}

	free( solver->dev_data );

	/* Two sets of counts/displacements in the pipelined case, one otherwise. */
	num_sets = ( mpi_sd->seg_size > 0 ) ? 2 : 1;
	for( i = 0; i < num_sets; i++ )
	{
		free( mpi_sd->counts_2d[i] );
		free( mpi_sd->displs_2d[i] );
		free( mpi_sd->counts_1d[i] );
		free( mpi_sd->displs_1d[i] );
	}

	return CUP_SUCCESS;
}

/********************************************
 * Private functions                        *
 ********************************************/
/**
 * Single-shot (non-pipelined) 2D FFT stage.
 *
 * Forward: in-place real-to-complex transform of dev_data[0] (only when this
 * process has a positive 2D block size), then the data exchange.
 * Inverse: the data exchange, then in-place complex-to-real transform of
 * dev_data[0] followed by scaling with 1/np (again only for a positive 2D
 * block size).
 *
 * \param mpi_solver  Solver to operate on; must not be NULL.
 * \param direction   CUFFT_FORWARD or CUFFT_INVERSE.
 * \return CUP_SUCCESS on success.
 */
static cup_error_t exec_fft2d_stage( cup_mpi_solver* mpi_solver, int direction )
{
	assert( mpi_solver != NULL );

	struct cup_solver* base = &mpi_solver->base;
	struct fft_solver_data* fft = &base->sd.fft;
	struct mpi_fft_solver_data* mpi_fft = &mpi_solver->sd.fft;

	/* Number of real values held locally; used for the final scaling. */
	int n_local = base->grid->nps[0] * base->grid->nps[1]* mpi_fft->size_2d;

	if( direction == CUFFT_FORWARD && mpi_fft->size_2d > 0 )
	{
		CUFFT( cufftExecD2Z( fft->plan_main_fw[0],
		                     base->dev_data[0],
		                     (cufftDoubleComplex*) base->dev_data[0] ) );
	}

	/* The exchange is executed for both directions, between the transforms. */
	THROW_CUP_ERROR( mpi_exchange_fft_data( mpi_solver, direction ) );

	if( direction == CUFFT_INVERSE && mpi_fft->size_2d > 0 )
	{
		CUFFT( cufftExecZ2D( fft->plan_main_inv[0],
		                     (cufftDoubleComplex*) base->dev_data[0],
		                     base->dev_data[0] ) );
		/* Scale the result of the inverse transform by 1/np. */
		dscalv( 1.0/base->grid->np,
		        base->dev_data[0],
		        base->dev_data[0],
		        n_local,
		        0,
		        base->streams[0] );
	}

	return CUP_SUCCESS;
}

/**
 * 1D stage of the solver.
 *
 * When this process has a positive 1D block size, \p data is transformed
 * forward in place with the secondary plan, handed to
 * mpi_solve_poisson_in_GPU(), and transformed back in place.  Afterwards
 * streams[0] is synchronized, regardless of the block size.
 *
 * \param mpi_solver  Solver to operate on; must not be NULL.
 * \param data        Complex buffer to transform (passed to cuFFT, so it is
 *                    expected to be device memory).
 * \return CUP_SUCCESS on success.
 */
static cup_error_t exec_fft1d_stage( cup_mpi_solver*     mpi_solver,
                                     cufftDoubleComplex* data )
{
	assert( mpi_solver != NULL );

	struct fft_solver_data* fft = &mpi_solver->base.sd.fft;
	struct mpi_fft_solver_data* mpi_fft = &mpi_solver->sd.fft;

	if( mpi_fft->size_1d > 0 )
	{
		/* Forward transform, solve, inverse transform. */
		CUFFT( cufftExecZ2Z( fft->plan_sec[0], data, data, CUFFT_FORWARD ) );
		mpi_solve_poisson_in_GPU( mpi_solver, data );
		CUFFT( cufftExecZ2Z( fft->plan_sec[0], data, data, CUFFT_INVERSE ) );
	}

	/* Wait for the work queued on the main stream before returning. */
	CUDA( cudaStreamSynchronize( mpi_fft->streams[0] ) );

	return CUP_SUCCESS;
}

/**
 * Pipelined (segmented) 2D FFT stage.
 *
 * Forward: the 2D transforms of segments 0 and 1 are issued first; then, for
 * every segment s, the all-to-all of s is run, the 2D transform of segment
 * s + 2 is issued (if it exists) and the host-to-device step of s follows.
 * Streams 2 and 3 are synchronized at the end.
 *
 * Inverse: the device-to-host steps of segments 0 and 1 are issued first;
 * then, for every segment s, the all-to-all of s is run, the device-to-host
 * step of segment s + 2 is issued (if it exists) and the inverse 2D
 * transform of s follows.
 *
 * NOTE(review): segments 0 and 1 are primed unconditionally, so this code
 * presumably relies on num_segs >= 2 -- confirm against mpi_set_fft_parms().
 *
 * \param mpi_solver  Solver to operate on; must not be NULL.
 * \param direction   CUFFT_FORWARD or CUFFT_INVERSE; any other value is a
 *                    no-op.
 * \return CUP_SUCCESS on success.
 */
static cup_error_t exec_fft2d_pipeline( cup_mpi_solver* mpi_solver, int direction )
{
	assert( mpi_solver != NULL );

	struct mpi_fft_solver_data* pipe = &mpi_solver->sd.fft;
	int seg;

	if( direction == CUFFT_FORWARD )
	{
		/* Prime the pipeline with the first two segments. */
		THROW_CUP_ERROR( mpi_pipeline_fft2d_fw( mpi_solver, 0 ) );
		THROW_CUP_ERROR( mpi_pipeline_fft2d_fw( mpi_solver, 1 ) );

		for( seg = 0; seg < pipe->num_segs; ++seg )
		{
			THROW_CUP_ERROR( mpi_pipeline_alltoall_fw( mpi_solver, seg ) );
			if( seg < pipe->num_segs - 2 )
			{
				/* Keep the pipeline two segments ahead. */
				THROW_CUP_ERROR( mpi_pipeline_fft2d_fw( mpi_solver, seg + 2 ) );
			}
			THROW_CUP_ERROR( mpi_pipeline_h2d_fw( mpi_solver, seg ) );
		}

		CUDA( cudaStreamSynchronize( pipe->streams[2] ) );
		CUDA( cudaStreamSynchronize( pipe->streams[3] ) );
	}
	else if( direction == CUFFT_INVERSE )
	{
		/* Prime the pipeline with the first two segments. */
		THROW_CUP_ERROR( mpi_pipeline_d2h_inv( mpi_solver, 0 ) );
		THROW_CUP_ERROR( mpi_pipeline_d2h_inv( mpi_solver, 1 ) );

		for( seg = 0; seg < pipe->num_segs; ++seg )
		{
			THROW_CUP_ERROR( mpi_pipeline_alltoall_inv( mpi_solver, seg ) );
			if( seg < pipe->num_segs - 2 )
			{
				/* Keep the pipeline two segments ahead. */
				THROW_CUP_ERROR( mpi_pipeline_d2h_inv( mpi_solver, seg + 2) );
			}
			THROW_CUP_ERROR( mpi_pipeline_fft2d_inv( mpi_solver, seg ) );
		}
	}

	return CUP_SUCCESS;
}
