/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/** \file
 * FFT based Poisson solver.
 *
 * \author Ibai Gurrutxaga
 * \author Javier Mediavilla
 */

/********************************************
 * Includes                                 *
 ********************************************/
#include <cufft.h>

#include "core/globals.h"
#include "cuPoisson.h"
#include "core/grid.h"
#include "solvers/solver.h"
#include "utils/alg.h"
#include "utils/utils.h"
#include "solvers/fft_utils.h"
#include "solvers/fft_utils_ser.h"
#include "fft_ser.h"

/********************************************
 * Private function prototypes              *
 ********************************************/
static cup_error_t exec_fft3d_solver( cup_solver* solver );
static cup_error_t exec_fft2d1d_solver( cup_solver* solver, int i_dev );

/********************************************
 * Public functions                         *
 ********************************************/
/**
 * Reset the synchronization counters of the FFT solver.
 *
 * When more than one device is in use, the counter of every entry of
 * sync1d (one per 1D FFT block) and of sync2d (one per 2D FFT block) is
 * set to zero. With a single device nothing is done.
 *
 * \param solver Solver whose FFT synchronization counters are reset.
 *               Must not be NULL.
 * \return Always CUP_SUCCESS.
 */
cup_error_t init_fft_solver( cup_solver* solver )
{
	struct fft_solver_data* fft;
	int blk;

	assert( solver != NULL );

	// The counters are only reset in the multi-device case.
	if( global_info.num_devices <= 1 )
		return CUP_SUCCESS;

	fft = &solver->sd.fft;

	blk = 0;
	while( blk < fft->parms.fft1d.num_blocks )
	{
		fft->sync1d[blk].counter = 0;
		++blk;
	}

	blk = 0;
	while( blk < fft->parms.fft2d.num_blocks )
	{
		fft->sync2d[blk].counter = 0;
		++blk;
	}

	return CUP_SUCCESS;
}

/**
 * Run the FFT based solver for one device.
 *
 * With a single device the whole grid is processed by a 3D FFT on
 * global_info.devices[0]. With several devices this call handles the
 * block that belongs to device \p i_dev: it uploads the block, runs the
 * 2D + 1D FFT pipeline and downloads the block again.
 *
 * The function only returns once the work it issued has finished: every
 * exit path ends with a blocking copy or an explicit synchronization, so
 * \p output can be read by the caller as soon as the call returns.
 *
 * \param output Host buffer that receives the result, or NULL to leave
 *               the result in device memory.
 * \param input  Host buffer copied to the device before solving, or NULL
 *               to solve with the data already stored on the device.
 * \param solver Solver that owns the device buffers, streams and plans.
 *               Must not be NULL.
 * \param i_dev  Index into global_info.devices of the device served by
 *               this call; must be in [0, global_info.num_devices).
 * \return CUP_SUCCESS when the end of the function is reached. Failures
 *         are handled by the CUDA / THROW_CUP_ERROR macros, which are
 *         defined elsewhere.
 */
cup_error_t exec_fft_solver( double*       output,
                             const double* input,
                             cup_solver*   solver,
                             int           i_dev )
{
	int* devices;
	size_t np_main,
	       main_bytes;
	struct fft_solver_data* sd;

	assert( solver != NULL );
	assert( i_dev >= 0 && i_dev < global_info.num_devices );

	devices = global_info.devices;
	sd = &solver->sd.fft;

	if( global_info.num_devices == 1 )
	{
		CUDA( cudaSetDevice( devices[0] ) );

		if( input != NULL )
		{
			CUDA( cudaMemcpyAsync( solver->dev_data[0],
			                       input,
			                       solver->grid->np * sizeof( double ),
			                       cudaMemcpyHostToDevice ) );
		}
		THROW_CUP_ERROR( exec_fft3d_solver( solver ) );
		if( output != NULL )
		{
			// Blocking copy: the result is complete when it returns.
			CUDA( cudaMemcpy( output,
			                  solver->dev_data[0],
			                  solver->grid->np * sizeof( double ),
			                  cudaMemcpyDeviceToHost ) );
		}
		else
		{
			CUDA( cudaDeviceSynchronize() );
		}
	}
	else
	{
		CUDA( cudaSetDevice( devices[i_dev] ) );

		// Number of points of a full block. Computed in size_t so that
		// neither the product nor the host offsets below overflow an int
		// on large grids.
		np_main = (size_t) solver->grid->nps[0]
		          * solver->grid->nps[1]
		          * sd->parms.fft2d.block_size;

		// Devices without a 2D block transfer nothing.
		main_bytes = 0;
		if( i_dev < sd->parms.fft2d.num_full_blocks )
			main_bytes = np_main * sizeof( double );
		else if( i_dev < sd->parms.fft2d.num_blocks )
			main_bytes = (size_t) solver->grid->nps[0]
			             * solver->grid->nps[1]
			             * sd->parms.fft2d.rest_block_size
			             * sizeof( double );

		if( input != NULL && i_dev < sd->parms.fft2d.num_blocks )
		{
			CUDA( cudaMemcpyAsync( solver->dev_data[i_dev],
			                       input + i_dev * np_main,
			                       main_bytes,
			                       cudaMemcpyHostToDevice,
			                       solver->streams[i_dev] ) );
		}

		THROW_CUP_ERROR( exec_fft2d1d_solver( solver, i_dev ) );

		if( output != NULL && i_dev < sd->parms.fft2d.num_blocks )
		{
			CUDA( cudaMemcpyAsync( output + i_dev * np_main,
			                       solver->dev_data[i_dev],
			                       main_bytes,
			                       cudaMemcpyDeviceToHost,
			                       solver->streams[i_dev] ) );
			// The copy above may still be in flight (e.g. with pinned
			// host memory); wait for it so that output is valid on
			// return, as in every other path of this function.
			CUDA( cudaStreamSynchronize( solver->streams[i_dev] ) );
		}
		else
		{
			CUDA( cudaDeviceSynchronize() );
		}
	}

	return CUP_SUCCESS;
}

/**
 * Allocate the buffers and set up the FFT plans needed by the solver.
 *
 * Single device: one device buffer is allocated on
 * global_info.devices[0], sized for (nps[0] / 2 + 1) * nps[1] * nps[2]
 * complex values; dev_buffer and plan_sec are set to NULL.
 *
 * Several devices: the FFT parameters and plans are set, then one main
 * buffer per 2D FFT block and one intermediate buffer per 1D FFT block
 * are allocated, each on the device with the same index as the block.
 *
 * \param solver Solver to set up. Must not be NULL; solver->grid is read
 *               to size the buffers.
 * \return CUP_SUCCESS when the end of the function is reached. Failures
 *         are handled by the MALLOC / CUDA / THROW_CUP_ERROR macros,
 *         which are defined elsewhere.
 */
cup_error_t create_fft_solver( struct cup_solver* solver )
{
	struct fft_solver_data* sd;
	cup_grid* grid;
	int i_dev;
	size_t freq_size_in_bytes,
	       interm_size_in_bytes;

	assert( solver != NULL );

	sd = &solver->sd.fft;
	grid = solver->grid;

	if( global_info.num_devices == 1 )
	{
		MALLOC( solver->dev_data, 1, double* );

		THROW_CUP_ERROR( ser_set_fft_plans( solver ) );

		sd->dev_buffer = NULL;
		sd->plan_sec = NULL;

		// The leading cast makes the whole product be evaluated in
		// size_t, so that it cannot overflow an int on large grids.
		freq_size_in_bytes = (size_t) (grid->nps[0] / 2 + 1)
		                     * grid->nps[1]
		                     * grid->nps[2]
		                     * sizeof( cufftDoubleComplex );
		CUDA( cudaSetDevice( global_info.devices[0] ) );
		CUDA( cudaMalloc( solver->dev_data,
		                  freq_size_in_bytes ) );
	}
	else
	{
		THROW_CUP_ERROR( ser_set_fft_parms( solver ) );
		THROW_CUP_ERROR( ser_set_fft_plans( solver ) );

		MALLOC( solver->dev_data,
		        solver->sd.fft.parms.fft2d.num_blocks,
		        double* );

		MALLOC( sd->dev_buffer,
		        solver->sd.fft.parms.fft1d.num_blocks,
		        cufftDoubleComplex* );

		// Main buffers: every 2D block gets the size of a full block.
		freq_size_in_bytes = (size_t) sd->parms.np_sym_dimension
		                     * grid->nps[1]
		                     * sd->parms.fft2d.block_size
		                     * sizeof( cufftDoubleComplex );
		for( i_dev = 0; i_dev < sd->parms.fft2d.num_blocks; i_dev++ )
		{
			CUDA( cudaSetDevice( global_info.devices[i_dev] ) );
			CUDA( cudaMalloc( solver->dev_data + i_dev,
			                  freq_size_in_bytes ) );
		}

		// Intermediate buffers used by the 1D FFTs.
		interm_size_in_bytes = (size_t) sd->parms.fft1d.block_size
		                       * grid->nps[1]
		                       * grid->nps[2]
		                       * sizeof( cufftDoubleComplex );
		for( i_dev = 0; i_dev < sd->parms.fft1d.num_blocks; i_dev++ )
		{
			CUDA( cudaSetDevice( global_info.devices[i_dev] ) );
			CUDA( cudaMalloc( sd->dev_buffer + i_dev,
			                  interm_size_in_bytes ) );
		}
	}

	return CUP_SUCCESS;
}

/**
 * Release the resources held by the FFT solver.
 *
 * Single device: the device buffer and the forward / inverse main plans
 * are released on global_info.devices[0].
 *
 * Several devices: for every 2D FFT block the main buffer, both main
 * plans and the sync2d mutex / condition variable are released; for
 * every 1D FFT block the intermediate buffer, the secondary plan and the
 * sync1d mutex / condition variable are released. The host arrays
 * dev_buffer, plan_sec, sync2d and sync1d are then freed.
 *
 * In both cases the host arrays dev_data, plan_main_fw and plan_main_inv
 * are freed at the end.
 *
 * \param solver Solver to tear down. Must not be NULL.
 * \return CUP_SUCCESS when the end of the function is reached. Failures
 *         are handled by the CUDA / CUFFT / PTHREAD macros, which are
 *         defined elsewhere.
 */
cup_error_t destroy_fft_solver( struct cup_solver* solver )
{
	struct fft_solver_data* sd;
	int i_dev;

	assert( solver != NULL );

	sd = &solver->sd.fft;

	if( global_info.num_devices == 1 )
	{
		// Select the device the buffer was allocated on, as done in
		// create_fft_solver and in the multi-device branch below.
		CUDA( cudaSetDevice( global_info.devices[0] ) );
		CUDA( cudaFree( solver->dev_data[0] ) );
		CUFFT( cufftDestroy( sd->plan_main_fw[0] ) );
		CUFFT( cufftDestroy( sd->plan_main_inv[0] ) );
	}
	else
	{
		// Resources owned by the 2D FFT blocks.
		for( i_dev = 0; i_dev < sd->parms.fft2d.num_blocks; i_dev++ )
		{
			CUDA( cudaSetDevice( global_info.devices[i_dev] ) );
			CUDA( cudaFree( solver->dev_data[i_dev] ) );
			CUFFT( cufftDestroy( sd->plan_main_fw[i_dev] ) );
			CUFFT( cufftDestroy( sd->plan_main_inv[i_dev] ) );
			PTHREAD( pthread_mutex_destroy( &sd->sync2d[i_dev].mutex ) );
			PTHREAD( pthread_cond_destroy( &sd->sync2d[i_dev].cond ) );
		}
		// Resources owned by the 1D FFT blocks.
		for( i_dev = 0; i_dev < sd->parms.fft1d.num_blocks; i_dev++ )
		{
			CUDA( cudaSetDevice( global_info.devices[i_dev] ) );
			CUDA( cudaFree( sd->dev_buffer[i_dev] ) );
			CUFFT( cufftDestroy( sd->plan_sec[i_dev] ) );
			PTHREAD( pthread_mutex_destroy( &sd->sync1d[i_dev].mutex ) );
			PTHREAD( pthread_cond_destroy( &sd->sync1d[i_dev].cond ) );
		}
		free( sd->dev_buffer );
		free( sd->plan_sec );
		free( sd->sync2d );
		free( sd->sync1d );
	}

	free( solver->dev_data );
	free( sd->plan_main_fw );
	free( sd->plan_main_inv );

	return CUP_SUCCESS;
}

/********************************************
 * Private functions                        *
 ********************************************/
// Execute the solver on a single GPU by computing a 3D FFT directly.
//
// All steps work in place on solver->dev_data[0]: forward real-to-complex
// transform, solve_poisson_in_GPU() for device 0, inverse complex-to-real
// transform, and finally a scaling of grid->np values by 1 / grid->np.
//
// Returns CUP_SUCCESS when the end of the function is reached; cuFFT
// failures are handled by the CUFFT macro, which is defined elsewhere.
static cup_error_t exec_fft3d_solver( cup_solver* solver )
{
	struct fft_solver_data* fft;
	double* data;

	assert( solver != NULL );

	fft = &solver->sd.fft;
	data = solver->dev_data[0];

	// Forward transform; the complex output overwrites the real input.
	CUFFT( cufftExecD2Z( fft->plan_main_fw[0],
	                     data,
	                     (cufftDoubleComplex*) data ) );

	solve_poisson_in_GPU( solver, 0 );

	// Inverse transform back into the same buffer.
	CUFFT( cufftExecZ2D( fft->plan_main_inv[0],
	                     (cufftDoubleComplex*) data,
	                     data ) );

	// Scale the result by the inverse of the number of grid points.
	dscalv( 1.0/solver->grid->np,
	        data,
	        data,
	        solver->grid->np,
	        0 );

	return CUP_SUCCESS;
}

// Execute the solver on several GPUs, computing 2D FFTs followed by a data
// communication and 1D FFTs.
//
// This call performs the part of the work assigned to device i_dev:
//   1. If the device owns a 2D block: forward 2D transform in place on
//      dev_data[i_dev], then ser_send_fft_data() in the forward direction.
//   2. If the device owns a 1D block: wait until sync1d[i_dev].counter
//      reaches fft2d.num_blocks - 1, then forward 1D transform, Poisson
//      solve and inverse 1D transform in place on dev_buffer[i_dev], then
//      ser_send_fft_data() in the inverse direction.
//   3. If the device owns a 2D block: wait until sync2d[i_dev].counter
//      reaches fft1d.num_blocks - 1, then inverse 2D transform and scaling
//      by 1 / grid->np on stream streams[i_dev].
// The counters are presumably advanced by the other devices while they
// send their data (not visible in this file).
//
// Returns CUP_SUCCESS when the end of the function is reached; failures
// are handled by the CUFFT / PTHREAD / THROW_CUP_ERROR macros, which are
// defined elsewhere.
static cup_error_t exec_fft2d1d_solver( cup_solver* solver, int i_dev )
{
	struct fft_solver_data* fft;
	int owns_2d_block,
	    owns_1d_block,
	    n_scaled;

	assert( solver != NULL );
	assert( i_dev >= 0 && i_dev < global_info.num_devices );

	fft = &solver->sd.fft;
	owns_2d_block = i_dev < fft->parms.fft2d.num_blocks;
	owns_1d_block = i_dev < fft->parms.fft1d.num_blocks;

	// Number of values scaled at the end: the size of a full 2D block.
	n_scaled = solver->grid->nps[0] * solver->grid->nps[1]
	           * fft->parms.fft2d.block_size;

	// Step 1: forward 2D transform and distribution of the result.
	if( owns_2d_block )
	{
		double* slab = solver->dev_data[i_dev];

		CUFFT( cufftExecD2Z( fft->plan_main_fw[i_dev],
		                     slab,
		                     (cufftDoubleComplex*) slab ) );
		THROW_CUP_ERROR( ser_send_fft_data( solver, i_dev, CUFFT_FORWARD ) );
	}

	// Step 2: 1D transforms and Poisson solve on the intermediate buffer.
	if( owns_1d_block )
	{
		cufftDoubleComplex* pencil;

		// Block until the counter reports fft2d.num_blocks - 1 arrivals.
		PTHREAD( pthread_mutex_lock( &fft->sync1d[i_dev].mutex ) );
		while( fft->sync1d[i_dev].counter < fft->parms.fft2d.num_blocks - 1 )
		{
			PTHREAD( pthread_cond_wait( &fft->sync1d[i_dev].cond,
			                            &fft->sync1d[i_dev].mutex ) );
		}
		PTHREAD( pthread_mutex_unlock( &fft->sync1d[i_dev].mutex ) );

		pencil = fft->dev_buffer[i_dev];

		CUFFT( cufftExecZ2Z( fft->plan_sec[i_dev],
		                     pencil,
		                     pencil,
		                     CUFFT_FORWARD ) );
		solve_poisson_in_GPU( solver, i_dev );
		CUFFT( cufftExecZ2Z( fft->plan_sec[i_dev],
		                     pencil,
		                     pencil,
		                     CUFFT_INVERSE ) );
		THROW_CUP_ERROR( ser_send_fft_data( solver, i_dev, CUFFT_INVERSE ) );
	}

	// Step 3: inverse 2D transform and scaling.
	if( owns_2d_block )
	{
		double* slab;

		// Block until the counter reports fft1d.num_blocks - 1 arrivals.
		PTHREAD( pthread_mutex_lock( &fft->sync2d[i_dev].mutex ) );
		while( fft->sync2d[i_dev].counter < fft->parms.fft1d.num_blocks - 1 )
		{
			PTHREAD( pthread_cond_wait( &fft->sync2d[i_dev].cond,
			                            &fft->sync2d[i_dev].mutex ) );
		}
		PTHREAD( pthread_mutex_unlock( &fft->sync2d[i_dev].mutex ) );

		slab = solver->dev_data[i_dev];

		CUFFT( cufftExecZ2D( fft->plan_main_inv[i_dev],
		                     (cufftDoubleComplex*) slab,
		                     slab ) );
		dscalv( 1.0/solver->grid->np,
		        slab,
		        slab,
		        n_scaled,
		        i_dev,
		        solver->streams[i_dev] );
	}

	return CUP_SUCCESS;
}
