/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/********************************************
 * Includes                                 *
 ********************************************/
#include <cufft.h>
#include <cuda_runtime.h>

#include "cuPoisson.h"
#include "core/globals.h"
#include "core/grid.h"
#include "solvers/solver.h"
#include "fft_utils.h"
#include "fft_utils_ser.h"

/********************************************
 * Public functions                         *
 ********************************************/
/**
 * Compute the FFT work distribution among devices and create the host-side
 * synchronization objects.
 *
 * Stores nps[0] / 2 + 1 as the size of the symmetric dimension, calls
 * distribute_blocks() once for nps[2] (result in parms.fft2d) and once for the
 * symmetric dimension (result in parms.fft1d), and allocates one
 * struct thread_sync per resulting block, initializing its mutex and
 * condition variable with default attributes.
 *  \param[in,out] solver The solver whose sd.fft data is filled in.
 *  \return CUP_SUCCESS when the end of the function is reached.
 */
cup_error_t ser_set_fft_parms( cup_solver* solver )
{
	struct fft_solver_data* fft;
	int num_devices,
	    blk;

	assert( solver != NULL );

	fft = &solver->sd.fft;
	num_devices = global_info.num_devices;

	/* Size of the symmetric dimension, then the two block distributions. */
	fft->parms.np_sym_dimension = solver->grid->nps[0] / 2 + 1;
	distribute_blocks( solver->grid->nps[2],
	                   num_devices,
	                   &fft->parms.fft2d );
	distribute_blocks( fft->parms.np_sym_dimension,
	                   num_devices,
	                   &fft->parms.fft1d );

	/* One synchronization record per block of each distribution. */
	MALLOC( fft->sync2d, fft->parms.fft2d.num_blocks, struct thread_sync );
	MALLOC( fft->sync1d, fft->parms.fft1d.num_blocks, struct thread_sync );

	for( blk = 0; blk < fft->parms.fft2d.num_blocks; ++blk )
	{
		PTHREAD( pthread_mutex_init( &fft->sync2d[blk].mutex, NULL ) );
		PTHREAD( pthread_cond_init( &fft->sync2d[blk].cond, NULL ) );
	}

	for( blk = 0; blk < fft->parms.fft1d.num_blocks; ++blk )
	{
		/* This loop also makes the block's device the current one. */
		CUDA( cudaSetDevice( global_info.devices[blk] ) );
		PTHREAD( pthread_mutex_init( &fft->sync1d[blk].mutex, NULL ) );
		PTHREAD( pthread_cond_init( &fft->sync1d[blk].cond, NULL ) );
	}

	return CUP_SUCCESS;
}

/**
 * Allocate and create the cuFFT plans used by the solver.
 *
 * With a single device, one 3D D2Z plan and one 3D Z2D plan are created.
 * With several devices, each block of the fft2d distribution gets a batched
 * 2D D2Z plan and a batched 2D Z2D plan, and each block of the fft1d
 * distribution gets a batched, strided 1D Z2Z plan; every multi-device plan
 * is bound to the stream of its device (solver->streams).
 *  \param[in,out] solver The solver whose sd.fft plans are created.
 *  \return CUP_SUCCESS when the end of the function is reached.
 */
cup_error_t ser_set_fft_plans( cup_solver* solver )
{
	struct fft_solver_data* fft;
	int  dev,
	     num_devices,
	   * device_ids,
	     dims[3],
	     stride_1d;

	assert( solver != NULL );

	fft = &solver->sd.fft;
	num_devices = global_info.num_devices;
	device_ids = global_info.devices;

	/* Grid dimensions in reverse order, as they are handed to cuFFT. */
	dims[0] = solver->grid->nps[2];
	dims[1] = solver->grid->nps[1];
	dims[2] = solver->grid->nps[0];

	/* Single device: one full 3D transform in each direction. */
	if( num_devices == 1 )
	{
		MALLOC( fft->plan_main_fw, 1, cufftHandle );
		MALLOC( fft->plan_main_inv, 1, cufftHandle );

		CUFFT( cufftPlan3d( fft->plan_main_fw,
		                    dims[0], dims[1], dims[2],
		                    CUFFT_D2Z ) );
		CUFFT( cufftPlan3d( fft->plan_main_inv,
		                    dims[0], dims[1], dims[2],
		                    CUFFT_Z2D ) );
		CUFFT( cufftSetCompatibilityMode( *fft->plan_main_fw,
		                                  CUFFT_COMPATIBILITY_NATIVE ) );
		CUFFT( cufftSetCompatibilityMode( *fft->plan_main_inv,
		                                  CUFFT_COMPATIBILITY_NATIVE ) );

		return CUP_SUCCESS;
	}

	/* Several devices: batched 2D plans plus batched 1D plans. */
	MALLOC( fft->plan_main_fw, fft->parms.fft2d.num_blocks, cufftHandle );
	MALLOC( fft->plan_main_inv, fft->parms.fft2d.num_blocks, cufftHandle );
	MALLOC( fft->plan_sec, fft->parms.fft1d.num_blocks, cufftHandle );

	for( dev = 0; dev < fft->parms.fft2d.num_blocks; ++dev )
	{
		CUDA( cudaSetDevice( device_ids[dev] ) );

		/* Forward 2D transforms over (dims[1], dims[2]), one per batch entry. */
		CUFFT( cufftPlanMany( &fft->plan_main_fw[dev],
		                      2, &dims[1],
		                      NULL, 0, 0,
		                      NULL, 0, 0,
		                      CUFFT_D2Z,
		                      fft->parms.fft2d.block_size ) );
		CUFFT( cufftSetCompatibilityMode( fft->plan_main_fw[dev],
		                                  CUFFT_COMPATIBILITY_NATIVE ) );
		CUFFT( cufftSetStream( fft->plan_main_fw[dev],
		                       solver->streams[dev] ) );

		/* Matching inverse 2D transforms. */
		CUFFT( cufftPlanMany( &fft->plan_main_inv[dev],
		                      2, &dims[1],
		                      NULL, 0, 0,
		                      NULL, 0, 0,
		                      CUFFT_Z2D,
		                      fft->parms.fft2d.block_size ) );
		CUFFT( cufftSetCompatibilityMode( fft->plan_main_inv[dev],
		                                  CUFFT_COMPATIBILITY_NATIVE ) );
		CUFFT( cufftSetStream( fft->plan_main_inv[dev],
		                       solver->streams[dev] ) );
	}

	/*
	 * The 1D plans use the same value (nps[1] * fft1d.block_size) as input
	 * stride, output stride and batch count; consecutive batch entries are
	 * one element apart.
	 */
	stride_1d = dims[1] * fft->parms.fft1d.block_size;

	for( dev = 0; dev < fft->parms.fft1d.num_blocks; ++dev )
	{
		CUDA( cudaSetDevice( device_ids[dev] ) );
		CUFFT( cufftPlanMany( &fft->plan_sec[dev],
		                      1, dims,
		                      dims, stride_1d, 1,
		                      dims, stride_1d, 1,
		                      CUFFT_Z2Z,
		                      stride_1d ) );
		CUFFT( cufftSetStream( fft->plan_sec[dev],
		                       solver->streams[dev] ) );
	}

	return CUP_SUCCESS;
}

/**
 * Send this device's share of the FFT data to every destination device.
 *
 * In the forward direction data is copied from solver->dev_data[src_dev]
 * into sd.fft.dev_buffer[dst_dev] for every block of the fft1d distribution;
 * in the inverse direction it is copied from sd.fft.dev_buffer[src_dev] back
 * into solver->dev_data[dst_dev] for every block of the fft2d distribution.
 * Each copy is a 3D peer copy issued on solver->streams[src_dev]. When the
 * destination is a different device, the stream is synchronized and the
 * destination's counter (sync1d forward, sync2d inverse) is incremented and
 * its condition variable signalled.
 *  \param[in,out] solver The solver owning the device arrays and streams.
 *  \param[in] src_dev Index (into global_info.devices) of the sending device.
 *  \param[in] direction CUFFT_FORWARD or CUFFT_INVERSE.
 *  \return CUP_SUCCESS when the end of the function is reached.
 */
cup_error_t ser_send_fft_data( struct cup_solver* solver, int src_dev, int direction )
{
	struct fft_solver_data* sd;
	struct thread_sync* dst_sync;
	int dst_dev,
	    num_dst,
	    dev2d,
	    dev1d;
	size_t width_1d,
	       depth_2d;
	cudaPitchedPtr data_ptr,
	               buffer_ptr;
	cudaPos data_pos,
	        buffer_pos;
	cudaMemcpy3DPeerParms exchange_parms = {0};
	const int forward = ( direction == CUFFT_FORWARD );

	assert( solver != NULL );
	assert( src_dev >= 0 && src_dev < global_info.num_devices );
	assert( direction == CUFFT_FORWARD || direction == CUFFT_INVERSE );

	sd = &solver->sd.fft;

	/* Forward sends to the fft1d blocks, inverse to the fft2d blocks. */
	num_dst = forward ? sd->parms.fft1d.num_blocks
	                  : sd->parms.fft2d.num_blocks;

	exchange_parms.srcDevice = global_info.devices[src_dev];

	for( dst_dev = 0; dst_dev < num_dst; dst_dev++ )
	{
		/*
		 * dev2d is the device on the dev_data (fft2d) side of the copy and
		 * dev1d the device on the dev_buffer (fft1d) side; which one is the
		 * source depends on the direction.
		 */
		dev2d = forward ? src_dev : dst_dev;
		dev1d = forward ? dst_dev : src_dev;

		/* dev_data rows hold np_sym_dimension complex values. */
		data_ptr = make_cudaPitchedPtr( solver->dev_data[dev2d],
		                                sd->parms.np_sym_dimension
		                                * sizeof( cufftDoubleComplex ),
		                                sd->parms.np_sym_dimension
		                                * sizeof( cufftDoubleComplex ),
		                                solver->grid->nps[1] );
		/* dev_buffer rows hold fft1d.block_size complex values. */
		buffer_ptr = make_cudaPitchedPtr( sd->dev_buffer[dev1d],
		                                  sd->parms.fft1d.block_size
		                                  * sizeof( cufftDoubleComplex ),
		                                  sd->parms.fft1d.block_size
		                                  * sizeof( cufftDoubleComplex ),
		                                  solver->grid->nps[1] );

		/* Byte offset of the dev1d block inside a dev_data row. */
		data_pos = make_cudaPos( dev1d *
		                         sd->parms.fft1d.block_size *
		                         sizeof( cufftDoubleComplex ),
		                         0,
		                         0 );
		/* Depth offset of the dev2d block inside dev_buffer. */
		buffer_pos = make_cudaPos( 0,
		                           0,
		                           dev2d * sd->parms.fft2d.block_size );

		/*
		 * Blocks past the full ones only hold rest_block_size elements, so
		 * the extent must shrink accordingly; otherwise it overflows the
		 * source data array. Width and depth are chosen independently, which
		 * covers the four full/rest combinations in both directions. (The
		 * inverse rest/rest case used to copy a width of a single element.)
		 */
		width_1d = ( dev1d < sd->parms.fft1d.num_full_blocks )
		           ? sd->parms.fft1d.block_size
		           : sd->parms.fft1d.rest_block_size;
		depth_2d = ( dev2d < sd->parms.fft2d.num_full_blocks )
		           ? sd->parms.fft2d.block_size
		           : sd->parms.fft2d.rest_block_size;
		exchange_parms.extent = make_cudaExtent(
		                        width_1d * sizeof( cufftDoubleComplex ),
		                        solver->grid->nps[1],
		                        depth_2d );

		exchange_parms.dstDevice = global_info.devices[dst_dev];
		if( forward )
		{
			exchange_parms.srcPtr = data_ptr;
			exchange_parms.srcPos = data_pos;
			exchange_parms.dstPtr = buffer_ptr;
			exchange_parms.dstPos = buffer_pos;
			dst_sync = &sd->sync1d[dst_dev];
		}
		else // direction == CUFFT_INVERSE
		{
			exchange_parms.srcPtr = buffer_ptr;
			exchange_parms.srcPos = buffer_pos;
			exchange_parms.dstPtr = data_ptr;
			exchange_parms.dstPos = data_pos;
			dst_sync = &sd->sync2d[dst_dev];
		}

		CUDA( cudaMemcpy3DPeerAsync( &exchange_parms,
		                             solver->streams[src_dev] ) );
		if( src_dev != dst_dev )
		{
			/* Wait for the copy before telling the destination about it. */
			CUDA( cudaStreamSynchronize( solver->streams[src_dev] ) );
			PTHREAD( pthread_mutex_lock( &dst_sync->mutex ) );
			dst_sync->counter++;
			PTHREAD( pthread_cond_signal( &dst_sync->cond ) );
			PTHREAD( pthread_mutex_unlock( &dst_sync->mutex ) );
		}
	}

	return CUP_SUCCESS;
}

/**
 * Solve the Poisson's equation of a 3D grid transformed to frequency space.
 *
 * The original input is supposed to be real valued, so the grid must be
 * symmetric. Therefore, just nps[0]/2 + 1 points are stored in the first
 * dimension.
 * The actual arithmetic is done by the dev_solve_poisson kernel, which is not
 * defined in this file (as previously documented here, each point is divided
 * by -(2*pi*k)**2). This function uploads the kernel parameters (the point
 * counts and the squared grid sizes), launches the kernel and releases the
 * temporary device arrays again.
 *
 * With a single device the kernel works on solver->dev_data[0] in the default
 * stream. With several devices it works on solver->sd.fft.dev_buffer[i_dev]
 * in solver->streams[i_dev], using fft1d.block_size points in the first
 * dimension.
 *
 * The function blocks until the kernel has finished, because the temporary
 * parameter arrays can only be freed afterwards.
 *
 * NOTE(review): the function returns void, so failures of the CUDA calls
 * cannot be reported to the caller; their return codes are not inspected.
 * NOTE(review): the device for i_dev is not selected here; presumably the
 * caller has already made it current. Confirm against the caller.
 *  \param[in,out] solver The solver holding the grid and the device arrays.
 *  \param[in] i_dev Index of the device whose data is processed.
 */
void solve_poisson_in_GPU( cup_solver* solver, int i_dev )
{
	int  num_blocks,
	     block_size,
	     num_threads,
	     nps[3],
	   * dev_nps = NULL;
	double* dev_size2 = NULL,
	        size2[3];
	cudaStream_t launch_stream = 0;

	assert( solver != NULL );
	assert( i_dev >= 0 && i_dev < global_info.num_devices );

	/* Temporary device copies of the kernel parameters. */
	cudaMalloc( &dev_nps, 3 * sizeof( int ) );
	cudaMalloc( &dev_size2, 3 * sizeof( double ) );
	size2[0] = solver->grid->size[0] * solver->grid->size[0];
	size2[1] = solver->grid->size[1] * solver->grid->size[1];
	size2[2] = solver->grid->size[2] * solver->grid->size[2];
	cudaMemcpy( dev_size2, size2, 3 * sizeof( double ), cudaMemcpyHostToDevice );

	nps[1] = solver->grid->nps[1];
	nps[2] = solver->grid->nps[2];
	if( global_info.num_devices == 1 )
	{
		/* The whole symmetric grid lives on the only device. */
		nps[0] = solver->grid->nps[0] / 2 + 1;
		num_threads = solver->grid->np;
		set_grid_dims( num_threads, &num_blocks, &block_size );
		cudaMemcpy( dev_nps, nps, 3 * sizeof( int ), cudaMemcpyHostToDevice );
		dev_solve_poisson<<< num_blocks, block_size >>>
		                 ( (cufftDoubleComplex *) solver->dev_data[0],
		                   dev_nps,
		                   nps[0],
		                   0,
		                   0,
		                   dev_size2 );
	}
	else
	{
		/* Each device only holds fft1d.block_size points of dimension 0. */
		launch_stream = solver->streams[i_dev];
		nps[0] = solver->sd.fft.parms.fft1d.block_size;
		num_threads = solver->sd.fft.parms.fft1d.block_size
		              * solver->grid->nps[1]
		              * solver->grid->nps[2];
		set_grid_dims( num_threads, &num_blocks, &block_size );
		cudaMemcpy( dev_nps, nps, 3 * sizeof( int ), cudaMemcpyHostToDevice );
		dev_solve_poisson<<< num_blocks, block_size, 0, launch_stream >>>
		                 ( solver->sd.fft.dev_buffer[i_dev],
		                   dev_nps,
		                   nps[0],
		                   0,
		                   i_dev * solver->sd.fft.parms.fft1d.block_size,
		                   dev_size2 );
	}

	/*
	 * The kernel reads dev_nps and dev_size2, so wait for it to finish
	 * before releasing them. Without these frees every call leaked both
	 * allocations.
	 */
	cudaStreamSynchronize( launch_stream );
	cudaFree( dev_nps );
	cudaFree( dev_size2 );
}
