/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/********************************************
 * Includes                                 *
 ********************************************/
#include <cufft.h>
#include <cuda_runtime.h>

#include "cuPoisson.h"
#include "core/globals.h"
#include "solvers/solver.h"
#include "fft_utils.h"

/********************************************
 * Private function prototypes              *
 ********************************************/
/* Out-of-place 2,1,0 transpose (reverses the order of the three dimensions)
 * of a complex grid; launched by transpose_210(). */
static __global__
void dev_transpose_210( cufftDoubleComplex*       out,
                        const cufftDoubleComplex* in,
                        int                       np0,
                        int                       np1,
                        int                       np2 );
/* Out-of-place 2D transpose of a complex matrix with np0 elements per row
 * (innermost dimension) and np1 rows; launched by transpose_210() when np0 or
 * np2 is 1. */
static __global__
void dev_transpose2d( cufftDoubleComplex*       odata,
                      const cufftDoubleComplex* idata,
                      int                       np0,
                      int                       np1 );

/********************************************
 * Public functions                         *
 ********************************************/
/**
 * Retrieve a solver's per-device data pointers and how many planes each
 * device holds.
 * \param[in]  solver The solver to query.
 * \param[out] data   Receives the solver's array of device pointers (one per
 *                    device). The array itself is not copied.
 * \param[out] sizes  Caller-allocated array, one entry per device, that
 *                    receives the number of planes stored on each device.
 * \return CUP_SUCCESS
 */
cup_error_t get_solver_data( const cup_solver* solver,
                             double***         data,
                             int*              sizes )
{
	int dev = 0;

	assert( solver != NULL );
	assert( data != NULL );
	assert( sizes != NULL );

	*data = solver->dev_data;

	if( global_info.num_devices == 1 )
	{
		/* A single device holds all nps[2] planes of the grid. */
		sizes[0] = solver->grid->nps[2];
	}
	else
	{
		/* Devices [0, num_full_blocks) get block_size planes, devices up to
		 * num_blocks get rest_block_size planes, and any remaining device
		 * gets none. */
		while( dev < solver->sd.fft.parms.fft2d.num_full_blocks )
			sizes[dev++] = solver->sd.fft.parms.fft2d.block_size;
		while( dev < solver->sd.fft.parms.fft2d.num_blocks )
			sizes[dev++] = solver->sd.fft.parms.fft2d.rest_block_size;
		while( dev < global_info.num_devices )
			sizes[dev++] = 0;
	}

	return CUP_SUCCESS;
}

/** \ingroup utils
 * Get a descriptive string from a cufft library error code.
 *
 * Codes that have no dedicated entry in the table map to "Unknown error." rather
 * than reading past the end of the table. This covers negative values and codes
 * added by newer cuFFT releases.
 *  \param[in] error An error code from the cufft library.
 *  \return A pointer to a statically allocated string describing the error.
 */
const char *cufft_error_string( cufftResult_t error )
{
	static const char* error_strings[] =
	{
		"Success.",
		"Invalid plan.",
		"CPU or GPU memory allocation failed.",
		"Invalid type.",
		"Invalid user specified value.",
		"Internal error.",
		"FFT execution error.",
		"Initialization error.",
		"Invalid transform size.",
		"Unaligned data.",
		"Missing parameters.",
		"Execution of a plan was on different GPU than plan creation.",
		"Internal plan database error.",
		"No workspace has been provided prior to plan execution.",
		"Unknown error."
	};
	/* The last entry is the catch-all for codes without a dedicated message. */
	const int unknown_index =
		(int)( sizeof( error_strings ) / sizeof( error_strings[0] ) ) - 1;
	const int code = (int)error;

	if( code < 0 || code >= unknown_index )
		return error_strings[unknown_index];

	return error_strings[code];
}

/**
 * Solve Poisson's equation, in place, for a 3D grid that has already been
 * transformed to frequency space.
 *
 * The data may be only a slab of a larger grid. \p sym_dim_offset gives the
 * global index of the first stored slice along dimension \p sym_dim.
 *
 * The original spatial input is assumed to be real valued, so the transform is
 * symmetric. Only the non-redundant \p nps[\p sym_dim] points are stored along
 * \p sym_dim, and no negative-frequency wrap-around is applied in that
 * dimension. In every other dimension i, an index k greater than nps[i]/2 is
 * treated as the negative frequency k - nps[i].
 *
 * Every stored point (padding included) is replaced by -value / ksum, where
 *   ksum = 4*pi^2 * sum_i k_i^2 / size2[i].
 * At the zero frequency that sum would be 0. To avoid dividing by zero, the
 * value 4*pi^2 * sum_i 1 / size2[i] is used instead.
 *
 * Only the x components of the grid/block dimensions are used. A grid-stride
 * loop covers all padded_np0 * nps[1] * nps[2] stored points, with at most
 * MAX_NUMBER_OF_THREADS threads per block.
 *
 * \param[in,out] dev_sym_grid Input and output grid (in-place computation).
 * \param[in] nps Number of points in each of the three dimensions; dimension 0
 *            is the innermost. It is dereferenced on the device, so it must
 *            point to device-accessible memory.
 * \param[in] padded_np0 Storage length of the innermost dimension. It may be
 *            larger than nps[0] because of padding.
 * \param[in] sym_dim Index of the symmetric dimension.
 * \param[in] sym_dim_offset Global index of the first stored slice along
 *            \p sym_dim.
 * \param[in] size2 Square of the size of the grid in each dimension
 *            (device-accessible).
 */
__global__ void __launch_bounds__((MAX_NUMBER_OF_THREADS))
           dev_solve_poisson( cufftDoubleComplex* dev_sym_grid,
                              const int         * nps,
                              int                 padded_np0,
                              int                 sym_dim,
                              int                 sym_dim_offset,
                              const double      * size2
                            )
{
	const double four_pi_sq = 4 * M_PI * M_PI;
	const int    plane      = padded_np0 * nps[1];
	const int    n_total    = plane * nps[2];
	const int    step       = gridDim.x * blockDim.x;

	for( int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n_total; idx += step )
	{
		int k[3];
		const int in_plane = idx % plane;

		/* Decompose the linear storage index into (innermost, middle, outer). */
		k[2] = idx / plane;
		k[1] = in_plane / padded_np0;
		k[0] = in_plane % padded_np0;
		k[sym_dim] += sym_dim_offset;

		double denom;
		if( k[0] != 0 || k[1] != 0 || k[2] != 0 )
		{
			denom = 0.0;
			for( int d = 0; d < 3; d++ )
			{
				const bool wrap = d != sym_dim && k[d] > nps[d] / 2;
				const int  f    = wrap ? k[d] - nps[d] : k[d];
				denom += f * f / size2[d];
			}
			denom *= four_pi_sq;
		}
		else
		{
			denom = four_pi_sq / size2[0] +
			        four_pi_sq / size2[1] +
			        four_pi_sq / size2[2];
		}

		cufftDoubleComplex value = dev_sym_grid[idx];
		value.x = -value.x / denom;
		value.y = -value.y / denom;
		dev_sym_grid[idx] = value;
	}
}

/**
 * Out-of-place 2,1,0 transpose: reverses the order of the three dimensions of
 * a complex grid. Element (x, y, z) of the input, with x innermost, ends up at
 * position (z, y, x) of the output, with z innermost.
 *
 * If np0 or np2 equals 1, the permutation reduces to a plain 2D transpose and
 * dev_transpose2d is launched. Otherwise dev_transpose_210 is launched with one
 * z block per index of the middle dimension.
 *
 * The kernels are enqueued asynchronously on \p stream. Launch errors are not
 * checked here.
 * \param[out] out The output array.
 * \param[in] in The input array.
 * \param[in] np0 The size of the innermost dimension of the input array.
 * \param[in] np1 The size of the middle dimension of the input array.
 * \param[in] np2 The size of the outermost dimension of the input array.
 * \param[in] stream The CUDA stream to run the kernel in.
 */
void transpose_210( cufftDoubleComplex*       out,
                    const cufftDoubleComplex* in,
                    int                       np0,
                    int                       np1,
                    int                       np2,
                    cudaStream_t              stream )
{
	assert( out != NULL && in != NULL );
	assert( np0 > 0 && np1 > 0 && np2 > 0 );

	if( np0 <= 1 || np2 <= 1 )
	{
		/* One of the outer dimensions is degenerate, so transpose the remaining
		 * 2D matrix: (np1 x np2) when np0 is 1, otherwise (np0 x np1). */
		const int inner = ( np0 == 1 ) ? np1 : np0;
		const int outer = ( np0 == 1 ) ? np2 : np1;

		const dim3 threads( TILE_SIZE, 4, 1 );
		const dim3 tiles( inner / TILE_SIZE + ( inner % TILE_SIZE != 0 ),
		                  outer / TILE_SIZE + ( outer % TILE_SIZE != 0 ),
		                  1 );

		dev_transpose2d<<< tiles, threads, 0, stream >>>( out, in, inner, outer );
		return;
	}

	/* Full 3D case: tiles over (np0, np2), one z block per middle index. */
	const dim3 threads( TILE_SIZE, TILE_SIZE / 4, 1 );
	const dim3 tiles( np0 / TILE_SIZE + ( np0 % TILE_SIZE != 0 ),
	                  np2 / TILE_SIZE + ( np2 % TILE_SIZE != 0 ),
	                  np1 );

	dev_transpose_210<<< tiles, threads, 0, stream >>>( out, in , np0, np1, np2 );
}

/********************************************
 * Private functions                        *
 ********************************************/
/**
 * Kernel that performs a 2,1,0 transpose (order inversion) out of place.
 *
 * Each block handles one TILE_SIZE x TILE_SIZE tile of the (dimension 0,
 * dimension 2) plane for a single index of the middle dimension. The tile is
 * staged in a TILE_SIZE x (TILE_SIZE + 1) __shared__ array; the extra column is
 * padding that reduces shared-memory bank conflicts on the transposed access.
 *
 * Launch layout:
 *  - blockDim.x must equal TILE_SIZE.
 *  - blockDim.y may be any value (transpose_210 uses TILE_SIZE / 4). Each
 *    thread walks tile rows threadIdx.y, threadIdx.y + blockDim.y, ... so the
 *    whole tile is covered regardless of TILE_SIZE.
 *  - gridDim.x = ceil(np0 / TILE_SIZE), gridDim.y = ceil(np2 / TILE_SIZE), and
 *    gridDim.z must equal np1 (blockIdx.z is the middle-dimension index and is
 *    not bounds-checked).
 * The 0 dimension must match the innermost dimension of the input grid, and
 * the 2 dimension must match the innermost dimension of the transposed grid.
 * \param[out] out The output array.
 * \param[in] in The input array.
 * \param[in] np0 The size of the innermost dimension of the input array.
 * \param[in] np1 The size of the middle dimension of the input array.
 * \param[in] np2 The size of the outermost dimension of the input array.
 */
static __global__
void dev_transpose_210( cufftDoubleComplex*       out,
                        const cufftDoubleComplex* in,
                        int                       np0,
                        int                       np1,
                        int                       np2 )
{
	__shared__ cufftDoubleComplex tile[TILE_SIZE][TILE_SIZE + 1];

	const int lx     = threadIdx.x;
	const int y      = blockIdx.z;
	const int x_base = TILE_SIZE * blockIdx.x;
	const int z_base = TILE_SIZE * blockIdx.y;
	int r;

	/* Load phase: threadIdx.x walks dimension 0 (coalesced reads), and r walks
	 * the tile's rows along dimension 2. tile is indexed [x_local][z_local]. */
	const int x_in = x_base + lx;
	if( x_in < np0 )
	{
		for( r = threadIdx.y; r < TILE_SIZE && z_base + r < np2; r += blockDim.y )
			tile[lx][r] = in[x_in + (y + (z_base + r) * np1) * np0];
	}

	/* Reached by every thread of the block (outside any branch). */
	__syncthreads();

	/* Store phase: threadIdx.x now walks dimension 2, which is innermost in the
	 * output (coalesced writes), and r walks dimension 0. */
	const int z_out = z_base + lx;
	if( z_out < np2 )
	{
		for( r = threadIdx.y; r < TILE_SIZE && x_base + r < np0; r += blockDim.y )
			out[z_out + (y + (x_base + r) * np1) * np2] = tile[r][lx];
	}
}

/**
 * Kernel that performs a 2-dimensional transpose out of place.
 *
 * The input has np1 rows of np0 elements (np0 innermost). The output has np0
 * rows of np1 elements. Each block transposes one TILE_SIZE x TILE_SIZE tile,
 * staged in a TILE_SIZE x (TILE_SIZE + 1) __shared__ array; the extra column
 * is padding against bank conflicts.
 *
 * Launch layout:
 *  - blockDim.x must equal TILE_SIZE.
 *  - blockDim.y may be any value (transpose_210 uses 4). Each thread walks tile
 *    rows threadIdx.y, threadIdx.y + blockDim.y, ...
 *  - gridDim.x = ceil(np0 / TILE_SIZE), gridDim.y = ceil(np1 / TILE_SIZE),
 *    gridDim.z = 1.
 * \param[out] odata The output array.
 * \param[in] idata The input array.
 * \param[in] np0 The size of the innermost dimension of the input array.
 * \param[in] np1 The size of the outermost dimension of the input array.
 */
static __global__
void dev_transpose2d( cufftDoubleComplex*       odata,
                      const cufftDoubleComplex* idata,
                      int                       np0,
                      int                       np1 )
{
	__shared__ cufftDoubleComplex tile[TILE_SIZE][TILE_SIZE + 1];
	int r;

	/* Load phase: threadIdx.x walks the input's innermost dimension. */
	const int col_in  = blockIdx.x * TILE_SIZE + threadIdx.x;
	const int row0_in = blockIdx.y * TILE_SIZE;

	if( col_in < np0 )
	{
		for( r = threadIdx.y; r < TILE_SIZE && row0_in + r < np1; r += blockDim.y )
			tile[r][threadIdx.x] = idata[col_in + (row0_in + r) * np0];
	}

	/* Reached by every thread of the block (outside any branch). */
	__syncthreads();

	/* Store phase: block coordinates are swapped, and threadIdx.x walks the
	 * output's innermost dimension (the input's rows). */
	const int col_out  = blockIdx.y * TILE_SIZE + threadIdx.x;
	const int row0_out = blockIdx.x * TILE_SIZE;

	if( col_out < np1 )
	{
		for( r = threadIdx.y; r < TILE_SIZE && row0_out + r < np0; r += blockDim.y )
			odata[col_out + (row0_out + r) * np1] = tile[threadIdx.x][r];
	}
}

