/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/********************************************
 * Includes                                 *
 ********************************************/
#include <cufft.h>
#include <cuda_runtime.h>
#include <mpi.h>

#include "cuPoisson.h"
#include "core/globals.h"
#include "core/grid.h"
#include "solvers/solver.h"
#include "utils/alg.h"
#include "utils/utils.h"
#include "fft_utils_mpi.h"

/********************************************
 * Private function prototypes              *
 ********************************************/
/* Allocates and fills the MPI_Alltoallv count/displacement arrays
 * (counts_2d/displs_2d and counts_1d/displs_1d) of mpi_solver->sd.fft. */
static cup_error_t set_all2all_parms( struct cup_mpi_solver* mpi_solver );
/* Non-pipelined shuffle between dev_transposed[0] (all-to-all layout, one
 * chunk per rank) and dev_data[0] (1D-FFT layout), launched on streams[0].
 * direction is CUFFT_FORWARD (gather) or anything else (scatter back). */
static cup_error_t shuffle_data( struct cup_mpi_solver* mpi_solver,
                                 int                    direction );
/* Kernel used by shuffle_data() in the forward direction: copies each
 * rank's chunk of the input into its plane range of rows of length np2. */
static __global__
void dev_shuffle_data_fw( cufftDoubleComplex*       output,
                          const cufftDoubleComplex* input,
                          int                       np0,
                          int                       np1,
                          int                       np2,
                          int                       num_full_blocks,
                          int                       block_size,
                          int                       rest_block_size );
/* Kernel used by shuffle_data() in the inverse direction: the exact
 * reverse copy of dev_shuffle_data_fw. */
static __global__
void dev_shuffle_data_inv( cufftDoubleComplex*       output,
                           const cufftDoubleComplex* input,
                           int                       np0,
                           int                       np1,
                           int                       np2,
                           int                       num_full_blocks,
                           int                       block_size,
                           int                       rest_block_size );
/* Pipelined counterpart of shuffle_data(): shuffles segment i_seg between
 * dev_transposed[(i_seg % 2) + 2] and dev_data_1d on
 * streams[(i_seg % 2) + 2]. */
static cup_error_t shuffle_seg_data( struct cup_mpi_solver* mpi_solver,
                                     int                    i_seg,
                                     int                    direction );
/* Kernel used by shuffle_seg_data() in the forward direction: copies one
 * segment per rank into rows of length np2 at plane
 * blockIdx.x * block_size + seg_offset. */
static __global__
void dev_shuffle_seg_data_fw( cufftDoubleComplex*       output,
                              const cufftDoubleComplex* input,
                              int                       np0,
                              int                       np1,
                              int                       np2,
                              int                       block_size,
                              int                       seg_offset,
                              int                       seg_size );
/* Kernel used by shuffle_seg_data() in the inverse direction: the exact
 * reverse copy of dev_shuffle_seg_data_fw. */
static __global__
void dev_shuffle_seg_data_inv( cufftDoubleComplex*       output,
                               const cufftDoubleComplex* input,
                               int                       np0,
                               int                       np1,
                               int                       np2,
                               int                       block_size,
                               int                       seg_offset,
                               int                       seg_size );
/********************************************
 * Public functions                         *
 ********************************************/
/*
 * Distributes the nps[2] planes of the 2D-FFT stage among the ranks and
 * stores the result in mpi_solver->sd.fft (distr_2d, padded_np2, offset_2d,
 * size_2d).
 *
 * padded != 0: distribute_blocks_padded() is used and padded_np2 becomes
 *              distr_2d.num_blocks * distr_2d.block_size.
 * padded == 0: the exact distribute_blocks() is used and padded_np2 equals
 *              nps[2].
 */
static void set_2d_distribution( struct cup_mpi_solver* mpi_solver,
                                 int                    padded )
{
	struct cup_solver* solver = &mpi_solver->base;
	struct mpi_fft_solver_data* mpi_sd = &mpi_solver->sd.fft;

	if( padded )
	{
		distribute_blocks_padded( solver->grid->nps[2],
		                          mpi_solver->size,
		                          &mpi_sd->distr_2d );
		mpi_sd->padded_np2 = mpi_sd->distr_2d.num_blocks * mpi_sd->distr_2d.block_size;
	}
	else
	{
		distribute_blocks( solver->grid->nps[2],
		                   mpi_solver->size,
		                   &mpi_sd->distr_2d );
		mpi_sd->padded_np2 = solver->grid->nps[2];
	}
	mpi_sd->offset_2d = get_block_offset( &mpi_sd->distr_2d, mpi_solver->rank );
	mpi_sd->size_2d = get_block_size( &mpi_sd->distr_2d, mpi_solver->rank );
}

/*
 * Computes the distribution, segmentation and element-count parameters of
 * the MPI FFT solver, then the MPI_Alltoallv parameters.
 *
 * - 2D stage: planes along nps[2] (see set_2d_distribution()).
 * - 1D stage: the np_sym_dimension = nps[0] / 2 + 1 complex columns.
 * - Pipelining: when global_info.config.mpi_segment_size != 0, the local 2D
 *   block is processed in segments of seg_size planes, where seg_size is the
 *   requested byte size divided (rounded up) by the size of one complex
 *   plane. If a segment would not be smaller than a 2D block, pipelining is
 *   disabled (seg_size = 0) and the exact, unpadded 2D distribution is
 *   restored. This is needed because the non-pipelined path (the 1D plan
 *   and shuffle_data) addresses lines of length nps[2].
 *
 * num_segs and last_seg_size are 0 when the solver is not pipelined.
 *
 * Returns CUP_SUCCESS, or the error propagated from set_all2all_parms().
 */
cup_error_t mpi_set_fft_parms( cup_mpi_solver* mpi_solver )
{
	struct cup_solver* solver;
	struct mpi_fft_solver_data* mpi_sd;
	size_t plane_size;
	int pipeline_requested;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;
	mpi_sd = &mpi_solver->sd.fft;
	pipeline_requested = ( global_info.config.mpi_segment_size != 0 );

	solver->sd.fft.parms.np_sym_dimension = ( solver->grid->nps[0] / 2 ) + 1;

	/* 2D stage: padded distribution only when pipelining is requested. */
	set_2d_distribution( mpi_solver, pipeline_requested );

	/* 1D stage: distribute the symmetric (complex) dimension. */
	distribute_blocks( solver->sd.fft.parms.np_sym_dimension,
	                   mpi_solver->size,
	                   &mpi_sd->distr_1d );
	mpi_sd->offset_1d = get_block_offset( &mpi_sd->distr_1d, mpi_solver->rank );
	mpi_sd->size_1d = get_block_size( &mpi_sd->distr_1d, mpi_solver->rank );

	mpi_sd->seg_size = 0;
	if( pipeline_requested )
	{
		/* Size in bytes of one complex plane of the 2D-FFT output. */
		plane_size = solver->sd.fft.parms.np_sym_dimension *
		             solver->grid->nps[1] *
		             sizeof( double ) * 2;
		mpi_sd->seg_size = div_up( global_info.config.mpi_segment_size, plane_size );
		if( mpi_sd->seg_size % 2 != 0 &&
		    ( solver->grid->nps[0] * solver->grid->nps[1] ) % 2 != 0 )
		{
			print_msg( WARNING,
			           "An odd segment size is not allowed for this grid size. "
			           "The size is incremented.");
			mpi_sd->seg_size++;
		}
		if( mpi_sd->seg_size >= mpi_sd->distr_2d.block_size )
		{
			print_msg( WARNING,
			           "The segment size is too big. The solver won't be pipelined." );
			mpi_sd->seg_size = 0;
			/* The non-pipelined path works on unpadded lines of nps[2]
			 * elements, so drop the padded distribution. */
			set_2d_distribution( mpi_solver, 0 );
		}
	}

	if( mpi_sd->seg_size > 0 )
	{
		mpi_sd->num_segs = div_up( mpi_sd->size_2d, mpi_sd->seg_size );
		mpi_sd->last_seg_size = mpi_sd->size_2d -
		                       (mpi_sd->num_segs - 1) * mpi_sd->seg_size;
	}
	else
	{
		/* Not pipelined: keep the segment bookkeeping well defined. */
		mpi_sd->num_segs = 0;
		mpi_sd->last_seg_size = 0;
	}

	/* Element counts of a regular segment, of the last segment and of a
	 * whole 1D block. */
	mpi_sd->np.real_seg =      solver->grid->nps[0] *
	                           solver->grid->nps[1] *
	                           mpi_sd->seg_size;
	mpi_sd->np.real_last_seg = solver->grid->nps[0] *
	                           solver->grid->nps[1] *
	                           mpi_sd->last_seg_size;
	mpi_sd->np.c2d_seg =       solver->sd.fft.parms.np_sym_dimension *
	                           solver->grid->nps[1] * mpi_sd->seg_size;
	mpi_sd->np.c2d_last_seg =  solver->sd.fft.parms.np_sym_dimension *
	                           solver->grid->nps[1] *
	                           mpi_sd->last_seg_size;
	mpi_sd->np.c1d_seg =       mpi_sd->size_1d *
	                           solver->grid->nps[1] *
	                           mpi_sd->seg_size;
	mpi_sd->np.c1d_last_seg =  mpi_sd->size_1d *
	                           solver->grid->nps[1] *
	                           mpi_sd->last_seg_size;
	mpi_sd->np.c1d_block =     mpi_sd->size_1d *
	                           solver->grid->nps[1] *
	                           mpi_sd->distr_2d.block_size;

	THROW_CUP_ERROR( set_all2all_parms( mpi_solver ) );

	return CUP_SUCCESS;
}

/*
 * Creates one batched 2D cuFFT plan of the given type over planes of
 * nps_rev[1] x nps_rev[2] elements, using the default contiguous layout
 * (NULL embed arrays). The plan is put in native compatibility mode.
 * Failures are handled by the file's CUFFT() check macro.
 */
static cup_error_t make_batched_plan_2d( cufftHandle* plan,
                                         int*         nps_rev,
                                         cufftType    type,
                                         int          batch )
{
	CUFFT( cufftPlanMany( plan,
	                      2,
	                      nps_rev + 1,
	                      NULL, 0, 0,
	                      NULL, 0, 0,
	                      type,
	                      batch ) );
	CUFFT( cufftSetCompatibilityMode( *plan,
	                                  CUFFT_COMPATIBILITY_NATIVE ) );
	return CUP_SUCCESS;
}

/*
 * Creates the cuFFT plans used by the MPI FFT solver.
 *
 * - Ranks owning 2D planes (size_2d > 0) get a forward D2Z plan and an
 *   inverse Z2D batched 2D plan. The batch is seg_size when pipelining and
 *   size_2d otherwise. Both plans are bound to streams[0].
 * - When pipelining, extra plans for the last segment (last_seg_size
 *   planes) are created. They are bound to a stream when executed.
 * - Ranks owning part of the 1D distribution (size_1d > 0) get a batched
 *   Z2Z 1D plan of length nps[2].
 *   - The lines are contiguous unless the pipelined layout is padded.
 *   - In the padded case consecutive lines are padded_np2 elements apart.
 */
cup_error_t mpi_set_fft_plans( cup_mpi_solver* mpi_solver )
{
	struct cup_solver* solver;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;
	int  nps_rev[3],
	     batch;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;
	sd = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;

	/* cuFFT expects the slowest-varying dimension first. */
	nps_rev[0] = solver->grid->nps[2];
	nps_rev[1] = solver->grid->nps[1];
	nps_rev[2] = solver->grid->nps[0];

	if( mpi_sd->size_2d > 0 )
	{
		batch = (mpi_sd->seg_size > 0) ? mpi_sd->seg_size : mpi_sd->size_2d;

		MALLOC( sd->plan_main_fw, 1, cufftHandle );
		MALLOC( sd->plan_main_inv, 1, cufftHandle );

		/* Forward (real-to-complex) 2D plans. */
		THROW_CUP_ERROR( make_batched_plan_2d( sd->plan_main_fw,
		                                       nps_rev,
		                                       CUFFT_D2Z,
		                                       batch ) );
		CUFFT( cufftSetStream( sd->plan_main_fw[0],
		                       solver->streams[0] ) );
		if( mpi_sd->seg_size > 0 )
		{
			THROW_CUP_ERROR( make_batched_plan_2d( &mpi_sd->plan_2d_fw_last,
			                                       nps_rev,
			                                       CUFFT_D2Z,
			                                       mpi_sd->last_seg_size ) );
		}

		/* Inverse (complex-to-real) 2D plans. */
		THROW_CUP_ERROR( make_batched_plan_2d( sd->plan_main_inv,
		                                       nps_rev,
		                                       CUFFT_Z2D,
		                                       batch ) );
		CUFFT( cufftSetStream( sd->plan_main_inv[0],
		                       solver->streams[0] ) );
		if( mpi_sd->seg_size > 0 )
		{
			THROW_CUP_ERROR( make_batched_plan_2d( &mpi_sd->plan_2d_inv_last,
			                                       nps_rev,
			                                       CUFFT_Z2D,
			                                       mpi_sd->last_seg_size ) );
		}
	}

	if( mpi_sd->size_1d > 0 )
	{
		MALLOC( sd->plan_sec, 1, cufftHandle );
		if( mpi_sd->seg_size == 0 || solver->grid->nps[2] == mpi_sd->padded_np2 )
		{
			/* Contiguous lines of nps[2] elements. */
			CUFFT( cufftPlan1d( sd->plan_sec,
			                    solver->grid->nps[2],
			                    CUFFT_Z2Z,
			                    mpi_sd->size_1d * solver->grid->nps[1] ) );
		}
		else
		{
			/* Lines of nps[2] elements stored padded_np2 apart. */
			CUFFT( cufftPlanMany( sd->plan_sec,
			                      1,
			                      &solver->grid->nps[2],
			                      &mpi_sd->padded_np2,
			                      1,
			                      mpi_sd->padded_np2,
			                      &mpi_sd->padded_np2,
			                      1,
			                      mpi_sd->padded_np2,
			                      CUFFT_Z2Z,
			                      mpi_sd->size_1d * solver->grid->nps[1] ) );
		}
		CUFFT( cufftSetStream( sd->plan_sec[0],
		                       solver->streams[0] ) );
	}

	return CUP_SUCCESS;
}

/*
 * Non-pipelined redistribution of the FFT data between the 2D decomposition
 * (planes along nps[2]) and the 1D decomposition (columns of the symmetric
 * dimension).
 *
 * CUFFT_FORWARD: transpose the local 2D block into dev_transposed[0], copy
 *   it to buffer2d[0], MPI_Alltoallv into buffer1d[0], copy that back to
 *   dev_transposed[0] and shuffle it into dev_data[0].
 * CUFFT_INVERSE: the same steps in reverse order.
 *
 * Every host<->device copy is issued on streams[0], the stream the
 * transpose and shuffle kernels run on, and the host then waits for that
 * stream. This orders the copies with those kernels and guarantees the host
 * buffers are complete before MPI reads them, however streams[0] was
 * created. CUDA and MPI failures are handled by the CUDA()/MPI_CALL()
 * macros; shuffle_data() errors are propagated.
 */
cup_error_t mpi_exchange_fft_data( struct cup_mpi_solver* mpi_solver,
                                   int direction )
{
	struct cup_solver* solver;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;
	cudaStream_t stream;
	size_t bytes_2d,
	       bytes_1d;

	assert( mpi_solver != NULL );
	assert( direction == CUFFT_FORWARD || direction == CUFFT_INVERSE );

	solver = &mpi_solver->base;
	sd = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;
	stream = solver->streams[0];

	/* Byte sizes of the local 2D and 1D blocks, computed in size_t so the
	 * products cannot overflow int on large grids. */
	bytes_2d = (size_t) sd->parms.np_sym_dimension
	           * solver->grid->nps[1]
	           * mpi_sd->size_2d
	           * sizeof( cufftDoubleComplex );
	bytes_1d = (size_t) mpi_sd->size_1d
	           * solver->grid->nps[1]
	           * solver->grid->nps[2]
	           * sizeof( cufftDoubleComplex );

	if( direction == CUFFT_FORWARD )
	{
		if( mpi_sd->size_2d > 0 )
		{
			transpose_210( mpi_sd->dev_transposed[0],
			               (cufftDoubleComplex *) solver->dev_data[0],
			               sd->parms.np_sym_dimension,
			               solver->grid->nps[1],
			               mpi_sd->size_2d,
			               stream );
			CUDA( cudaMemcpyAsync( mpi_sd->buffer2d[0],
			                       mpi_sd->dev_transposed[0],
			                       bytes_2d,
			                       cudaMemcpyDeviceToHost,
			                       stream ) );
			CUDA( cudaStreamSynchronize( stream ) );
		}
		PROFILE_PUSH( "MPI_Alltoall (FW)" );
		MPI_CALL( MPI_Alltoallv( mpi_sd->buffer2d[0],
		                         mpi_sd->counts_2d[0],
		                         mpi_sd->displs_2d[0],
		                         MPI_DOUBLE,
		                         mpi_sd->buffer1d[0],
		                         mpi_sd->counts_1d[0],
		                         mpi_sd->displs_1d[0],
		                         MPI_DOUBLE,
		                         mpi_solver->comm ) );
		PROFILE_POP();

		if( mpi_sd->size_1d > 0 )
		{
			CUDA( cudaMemcpyAsync( mpi_sd->dev_transposed[0],
			                       mpi_sd->buffer1d[0],
			                       bytes_1d,
			                       cudaMemcpyHostToDevice,
			                       stream ) );
			CUDA( cudaStreamSynchronize( stream ) );
			THROW_CUP_ERROR( shuffle_data( mpi_solver, CUFFT_FORWARD ) );
		}
	}
	else // direction == CUFFT_INVERSE
	{
		if( mpi_sd->size_1d > 0 )
		{
			THROW_CUP_ERROR( shuffle_data( mpi_solver, CUFFT_INVERSE ) );
			CUDA( cudaMemcpyAsync( mpi_sd->buffer1d[0],
			                       mpi_sd->dev_transposed[0],
			                       bytes_1d,
			                       cudaMemcpyDeviceToHost,
			                       stream ) );
			CUDA( cudaStreamSynchronize( stream ) );
		}
		PROFILE_PUSH( "MPI_Alltoall (INV)" );
		MPI_CALL( MPI_Alltoallv( mpi_sd->buffer1d[0],
		                         mpi_sd->counts_1d[0],
		                         mpi_sd->displs_1d[0],
		                         MPI_DOUBLE,
		                         mpi_sd->buffer2d[0],
		                         mpi_sd->counts_2d[0],
		                         mpi_sd->displs_2d[0],
		                         MPI_DOUBLE,
		                         mpi_solver->comm ) );
		PROFILE_POP();

		if( mpi_sd->size_2d > 0 )
		{
			CUDA( cudaMemcpyAsync( mpi_sd->dev_transposed[0],
			                       mpi_sd->buffer2d[0],
			                       bytes_2d,
			                       cudaMemcpyHostToDevice,
			                       stream ) );
			CUDA( cudaStreamSynchronize( stream ) );
			transpose_210( (cufftDoubleComplex *) solver->dev_data[0],
			               mpi_sd->dev_transposed[0],
			               mpi_sd->size_2d,
			               solver->grid->nps[1],
			               sd->parms.np_sym_dimension,
			               stream );
		}
	}
	return CUP_SUCCESS;
}

/*
 * Launches dev_solve_poisson on streams[0] over this rank's share of the 1D
 * decomposition, which is size_1d * nps[1] * nps[2] elements of data.
 *
 * Two small temporary device arrays pass parameters to the kernel:
 *   - the local extents in reversed axis order (nps[2], nps[1], size_1d);
 *   - the squared physical sizes in reversed order (size[2]^2, size[1]^2,
 *     size[0]^2).
 * padded_np2 and offset_1d are passed by value.
 *
 * The temporaries are released after the kernel completes, so this call
 * waits for streams[0]. Ranks with size_1d <= 0 return without launching.
 */
void mpi_solve_poisson_in_GPU( cup_mpi_solver* mpi_solver,
                               cufftDoubleComplex* data )
{
	struct cup_solver* solver;
	int  num_blocks,
	     block_size,
	     num_threads,
	     nps[3],
	   * dev_nps = NULL;
	double  size2[3],
	      * dev_size2 = NULL;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;

	/* No local 1D columns: nothing to solve, avoid an empty launch. */
	if( mpi_solver->sd.fft.size_1d <= 0 )
		return;

	num_threads = mpi_solver->sd.fft.size_1d
	              * solver->grid->nps[1]
	              * solver->grid->nps[2];
	set_grid_dims( num_threads, &num_blocks, &block_size );

	cudaMalloc( &dev_nps, 3 * sizeof( int ) );
	cudaMalloc( &dev_size2, 3 * sizeof( double ) );

	nps[0] = solver->grid->nps[2];
	nps[1] = solver->grid->nps[1];
	nps[2] = mpi_solver->sd.fft.size_1d;
	cudaMemcpy( dev_nps, nps, 3 * sizeof( int ), cudaMemcpyHostToDevice );

	size2[0] = solver->grid->size[2] * solver->grid->size[2];
	size2[1] = solver->grid->size[1] * solver->grid->size[1];
	size2[2] = solver->grid->size[0] * solver->grid->size[0];
	cudaMemcpy( dev_size2, size2, 3 * sizeof( double ), cudaMemcpyHostToDevice );

	dev_solve_poisson<<< num_blocks, block_size, 0, solver->streams[0] >>>
	                 ( data,
	                   dev_nps,
	                   mpi_solver->sd.fft.padded_np2,
	                   2,
	                   mpi_solver->sd.fft.offset_1d,
	                   dev_size2 );

	/* The kernel reads dev_nps/dev_size2 asynchronously: wait for it before
	 * releasing the temporaries (they were previously never freed). */
	cudaStreamSynchronize( solver->streams[0] );
	cudaFree( dev_size2 );
	cudaFree( dev_nps );
}

/*
 * Pipeline step 1 (forward) for segment i_seg.
 *
 * Everything runs on streams[i_seg % 2], using the double-buffered slot
 * i_seg % 2:
 *   - 2D D2Z FFT of the segment's planes of dev_data[0] into dev_aux;
 *   - transpose_210 into dev_transposed;
 *   - asynchronous copy to the host buffer buffer2d;
 *   - record of events[i_seg % 2], which mpi_pipeline_alltoall_fw waits on.
 *
 * The last segment uses the plan and element count sized for
 * last_seg_size planes.
 */
cup_error_t mpi_pipeline_fft2d_fw( struct cup_mpi_solver* mpi_solver, int i_seg )
{
	struct cup_solver* solver;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;
	cudaStream_t stream;
	cufftHandle plan;
	int slot,
	    planes,
	    num_complex;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;
	sd     = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;

	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	slot   = i_seg % 2;
	stream = mpi_sd->streams[slot];

	if( i_seg >= mpi_sd->num_segs - 1 )
	{
		/* Last (possibly shorter) segment. */
		plan        = mpi_sd->plan_2d_fw_last;
		planes      = mpi_sd->last_seg_size;
		num_complex = mpi_sd->np.c2d_last_seg;
	}
	else
	{
		plan        = sd->plan_main_fw[0];
		planes      = mpi_sd->seg_size;
		num_complex = mpi_sd->np.c2d_seg;
	}

	CUFFT( cufftSetStream( plan, stream ) );
	CUFFT( cufftExecD2Z( plan,
	                     solver->dev_data[0] + i_seg * mpi_sd->np.real_seg,
	                     mpi_sd->dev_aux[slot] ) );

	transpose_210( mpi_sd->dev_transposed[slot],
	               mpi_sd->dev_aux[slot],
	               sd->parms.np_sym_dimension,
	               solver->grid->nps[1],
	               planes,
	               stream );

	CUDA( cudaMemcpyAsync( mpi_sd->buffer2d[slot],
	                       mpi_sd->dev_transposed[slot],
	                       num_complex * sizeof( cufftDoubleComplex ),
	                       cudaMemcpyDeviceToHost,
	                       stream ) );
	CUDA( cudaEventRecord( mpi_sd->events[slot], stream ) );

	return CUP_SUCCESS;
}

/*
 * Pipeline step 2 (forward): MPI_Alltoallv of segment i_seg from
 * buffer2d[i_seg % 2] into buffer1d[i_seg % 2].
 *
 * Before communicating, the host waits for:
 *   - events[i_seg % 2], recorded by mpi_pipeline_fft2d_fw after this
 *     segment's device-to-host copy;
 *   - for i_seg > 1, events[(i_seg % 2) + 2], recorded by
 *     mpi_pipeline_h2d_fw after segment i_seg - 2 was copied out of
 *     buffer1d[i_seg % 2]. That buffer is overwritten here.
 *
 * The last segment uses the count/displacement set [1], all others set [0].
 */
cup_error_t mpi_pipeline_alltoall_fw( struct cup_mpi_solver* mpi_solver, int i_seg )
{
	struct mpi_fft_solver_data* mpi_sd;
	int is_last;

	assert( mpi_solver != NULL );

	mpi_sd = &mpi_solver->sd.fft;

	/* Same precondition as the other pipeline steps. */
	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	is_last = ( i_seg < mpi_sd->num_segs - 1 ) ? 0 : 1;

	CUDA( cudaEventSynchronize( mpi_sd->events[i_seg % 2] ) );
	if( i_seg > 1 )
	{
		CUDA( cudaEventSynchronize( mpi_sd->events[(i_seg % 2) + 2] ) );
	}
	PROFILE_PUSH( "MPI_Alltoall (FW)" );
	MPI_CALL( MPI_Alltoallv( mpi_sd->buffer2d[i_seg % 2],
	                         mpi_sd->counts_2d[is_last],
	                         mpi_sd->displs_2d[is_last],
	                         MPI_DOUBLE,
	                         mpi_sd->buffer1d[i_seg % 2],
	                         mpi_sd->counts_1d[is_last],
	                         mpi_sd->displs_1d[is_last],
	                         MPI_DOUBLE,
	                         mpi_solver->comm ) );
	PROFILE_POP();

	return CUP_SUCCESS;
}

/*
 * Pipeline step 3 (forward) for segment i_seg.
 *
 * The segment was received by the all-to-all in buffer1d[i_seg % 2]. This
 * step, all on streams[(i_seg % 2) + 2]:
 *   - copies it to dev_transposed[(i_seg % 2) + 2];
 *   - records events[(i_seg % 2) + 2], so that mpi_pipeline_alltoall_fw
 *     knows when the host buffer may be reused;
 *   - shuffles the segment into dev_data_1d.
 *
 * Returns CUP_SUCCESS or the error propagated from shuffle_seg_data().
 */
cup_error_t mpi_pipeline_h2d_fw( struct cup_mpi_solver* mpi_solver,
                                      int i_seg )
{
	struct mpi_fft_solver_data* mpi_sd;
	int c1d_seg;
	size_t bytes;

	assert( mpi_solver != NULL );

	mpi_sd = &mpi_solver->sd.fft;

	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	c1d_seg = (i_seg < mpi_sd->num_segs - 1) ? mpi_sd->np.c1d_seg : mpi_sd->np.c1d_last_seg;
	/* One chunk of c1d_seg elements was received from every rank. */
	bytes = (size_t) c1d_seg * mpi_solver->size * sizeof( cufftDoubleComplex );

	CUDA( cudaMemcpyAsync( mpi_sd->dev_transposed[(i_seg % 2) + 2],
	                       mpi_sd->buffer1d[i_seg % 2],
	                       bytes,
	                       cudaMemcpyHostToDevice,
	                       mpi_sd->streams[(i_seg % 2) + 2] ) );
	CUDA( cudaEventRecord( mpi_sd->events[(i_seg % 2) + 2],
	                       mpi_sd->streams[(i_seg % 2) + 2]) );
	THROW_CUP_ERROR( shuffle_seg_data( mpi_solver, i_seg, CUFFT_FORWARD ) );

	return CUP_SUCCESS;
}

/*
 * Pipeline step 1 (inverse) for segment i_seg.
 *
 * All on streams[(i_seg % 2) + 2], this step:
 *   - shuffles the segment out of dev_data_1d into
 *     dev_transposed[(i_seg % 2) + 2];
 *   - copies it asynchronously to buffer1d[i_seg % 2];
 *   - records events[i_seg % 2], which mpi_pipeline_alltoall_inv waits on.
 *
 * Returns CUP_SUCCESS or the error propagated from shuffle_seg_data().
 */
cup_error_t mpi_pipeline_d2h_inv( struct cup_mpi_solver* mpi_solver, int i_seg )
{
	struct mpi_fft_solver_data* mpi_sd;
	int c1d_seg;
	size_t bytes;

	assert( mpi_solver != NULL );

	mpi_sd = &mpi_solver->sd.fft;

	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	c1d_seg = (i_seg < mpi_sd->num_segs - 1) ? mpi_sd->np.c1d_seg : mpi_sd->np.c1d_last_seg;
	/* One chunk of c1d_seg elements is sent to every rank. */
	bytes = (size_t) c1d_seg * mpi_solver->size * sizeof( cufftDoubleComplex );

	THROW_CUP_ERROR( shuffle_seg_data( mpi_solver, i_seg, CUFFT_INVERSE ) );
	CUDA( cudaMemcpyAsync( mpi_sd->buffer1d[i_seg % 2],
	                       mpi_sd->dev_transposed[(i_seg % 2) + 2],
	                       bytes,
	                       cudaMemcpyDeviceToHost,
	                       mpi_sd->streams[(i_seg % 2) + 2] ) );
	CUDA( cudaEventRecord( mpi_sd->events[i_seg % 2],
	                       mpi_sd->streams[(i_seg % 2) + 2]) );

	return CUP_SUCCESS;
}

/*
 * Pipeline step 2 (inverse): MPI_Alltoallv of segment i_seg from
 * buffer1d[i_seg % 2] into buffer2d[i_seg % 2].
 *
 * The host first waits for:
 *   - events[i_seg % 2], recorded by mpi_pipeline_d2h_inv after this
 *     segment's device-to-host copy;
 *   - for i_seg > 1, events[(i_seg % 2) + 2], recorded by
 *     mpi_pipeline_fft2d_inv after segment i_seg - 2 was copied out of
 *     buffer2d[i_seg % 2].
 */
cup_error_t mpi_pipeline_alltoall_inv( struct cup_mpi_solver* mpi_solver,
                                       int i_seg )
{
	struct mpi_fft_solver_data* mpi_sd;
	int slot,
	    layout;

	assert( mpi_solver != NULL );

	mpi_sd = &mpi_solver->sd.fft;

	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	slot = i_seg % 2;
	/* Count/displacement set: [1] for the last segment, [0] otherwise. */
	layout = ( i_seg >= mpi_sd->num_segs - 1 ) ? 1 : 0;

	CUDA( cudaEventSynchronize( mpi_sd->events[slot] ) );
	if( i_seg > 1 )
	{
		CUDA( cudaEventSynchronize( mpi_sd->events[slot + 2] ) );
	}

	PROFILE_PUSH( "MPI_Alltoall (INV)" );
	MPI_CALL( MPI_Alltoallv( mpi_sd->buffer1d[slot],
	                         mpi_sd->counts_1d[layout],
	                         mpi_sd->displs_1d[layout],
	                         MPI_DOUBLE,
	                         mpi_sd->buffer2d[slot],
	                         mpi_sd->counts_2d[layout],
	                         mpi_sd->displs_2d[layout],
	                         MPI_DOUBLE,
	                         mpi_solver->comm ) );
	PROFILE_POP();

	return CUP_SUCCESS;
}

/*
 * Pipeline step 3 (inverse) for segment i_seg.
 *
 * Everything runs on streams[i_seg % 2]:
 *   - copy buffer2d[i_seg % 2] to dev_transposed[i_seg % 2];
 *   - record events[(i_seg % 2) + 2], used by mpi_pipeline_alltoall_inv
 *     before reusing the host buffer;
 *   - transpose back into dev_aux;
 *   - 2D Z2D FFT into the segment's planes of dev_data[0];
 *   - scale those real values by 1 / grid->np via dscalv.
 *
 * The last segment uses the plan and counts sized for last_seg_size planes.
 */
cup_error_t mpi_pipeline_fft2d_inv( struct cup_mpi_solver* mpi_solver,
                                    int i_seg )
{
	struct cup_solver* solver;
	struct fft_solver_data* sd;
	struct mpi_fft_solver_data* mpi_sd;
	cudaStream_t stream;
	cufftHandle plan;
	int slot,
	    planes,
	    complex_count,
	    real_count;

	assert( mpi_solver != NULL );

	solver = &mpi_solver->base;
	sd     = &solver->sd.fft;
	mpi_sd = &mpi_solver->sd.fft;

	assert( i_seg >= 0 && i_seg < mpi_sd->num_segs );

	slot   = i_seg % 2;
	stream = mpi_sd->streams[slot];

	if( i_seg >= mpi_sd->num_segs - 1 )
	{
		/* Last (possibly shorter) segment. */
		plan          = mpi_sd->plan_2d_inv_last;
		planes        = mpi_sd->last_seg_size;
		complex_count = mpi_sd->np.c2d_last_seg;
		real_count    = mpi_sd->np.real_last_seg;
	}
	else
	{
		plan          = sd->plan_main_inv[0];
		planes        = mpi_sd->seg_size;
		complex_count = mpi_sd->np.c2d_seg;
		real_count    = mpi_sd->np.real_seg;
	}

	CUDA( cudaMemcpyAsync( mpi_sd->dev_transposed[slot],
	                       mpi_sd->buffer2d[slot],
	                       complex_count * sizeof( cufftDoubleComplex ),
	                       cudaMemcpyHostToDevice,
	                       stream ) );
	CUDA( cudaEventRecord( mpi_sd->events[slot + 2], stream ) );

	transpose_210( mpi_sd->dev_aux[slot],
	               mpi_sd->dev_transposed[slot],
	               planes,
	               solver->grid->nps[1],
	               sd->parms.np_sym_dimension,
	               stream );

	CUFFT( cufftSetStream( plan, stream ) );
	CUFFT( cufftExecZ2D( plan,
	                     mpi_sd->dev_aux[slot],
	                     solver->dev_data[0] + i_seg * mpi_sd->np.real_seg ) );

	/* Normalize the unscaled inverse transform in place. */
	dscalv( 1.0/solver->grid->np,
	        solver->dev_data[0] + i_seg * mpi_sd->np.real_seg,
	        solver->dev_data[0] + i_seg * mpi_sd->np.real_seg,
	        real_count,
	        0,
	        stream );

	return CUP_SUCCESS;
}

/********************************************
 * Private functions                        *
 ********************************************/
/*
 * Allocates (with MALLOC) and fills the count/displacement arrays used by
 * the MPI_Alltoallv calls of this file. All values are in MPI_DOUBLE
 * elements, hence the factor 2 per complex value.
 *
 * counts_2d/displs_2d: layout of the 2D-side buffer (buffer2d).
 *   - One chunk per rank i.
 *   - The chunk holds rank i's 1D block (distr_1d, block_size or
 *     rest_block_size columns) x nps[1], for every local plane.
 *   - The local planes are size_2d (non-pipelined) or one segment
 *     (pipelined).
 * counts_1d/displs_1d: layout of the 1D-side buffer (buffer1d).
 *   - One chunk per rank i.
 *   - The chunk holds nps[1] x size_1d values for each plane owned by
 *     rank i (distr_2d), or for one segment when pipelined.
 *
 * Index [0] describes a whole block (non-pipelined) or a regular segment
 * (pipelined). Index [1] is allocated only when pipelining and describes
 * the last segment (last_seg_size planes). Ranks at or beyond num_blocks of
 * the relevant distribution get zero counts and displacements.
 */
static cup_error_t set_all2all_parms( struct cup_mpi_solver* mpi_solver )
{
	struct mpi_fft_solver_data* mpi_sd = &mpi_solver->sd.fft;

	int i,
	    pipeline;

	assert( mpi_solver != NULL );

	pipeline = mpi_sd->seg_size > 0;

	for( i = 0; i < (pipeline ? 2 : 1); i++ )
	{
		MALLOC( mpi_sd->counts_2d[i], mpi_solver->size, int );
		MALLOC( mpi_sd->displs_2d[i], mpi_solver->size, int );
	}

	/* 2D side, ranks with a full 1D block. */
	for( i = 0; i < mpi_sd->distr_1d.num_full_blocks; i++ )
	{
		/* Per-plane count first; scaled by the number of planes below. */
		mpi_sd->counts_2d[0][i] = 2
		                          * mpi_sd->distr_1d.block_size
		                          * mpi_solver->base.grid->nps[1];
		if( pipeline )
		{
			/* Must use the per-plane value, i.e. run before the *= below. */
			mpi_sd->counts_2d[1][i] = mpi_sd->counts_2d[0][i] * mpi_sd->last_seg_size;
			mpi_sd->displs_2d[1][i] = i * mpi_sd->counts_2d[1][0];
		}
		mpi_sd->counts_2d[0][i] *= ( pipeline ? mpi_sd->seg_size : mpi_sd->size_2d );
		/* All full chunks are equal, so chunk i starts at i * counts[0]. */
		mpi_sd->displs_2d[0][i] = i * mpi_sd->counts_2d[0][0];
	}

	/* 2D side, ranks with a rest-size 1D block: they follow all full chunks. */
	for( ; i < mpi_sd->distr_1d.num_blocks; i++ )
	{
		mpi_sd->counts_2d[0][i] = 2
		                          * mpi_sd->distr_1d.rest_block_size
		                          * mpi_solver->base.grid->nps[1];
		if( pipeline )
		{
			mpi_sd->counts_2d[1][i] = mpi_sd->counts_2d[0][i] * mpi_sd->last_seg_size;
			mpi_sd->displs_2d[1][i] = mpi_sd->distr_1d.num_full_blocks * mpi_sd->counts_2d[1][0] +
		                          (i - mpi_sd->distr_1d.num_full_blocks) *
		                          mpi_sd->counts_2d[1][mpi_sd->distr_1d.num_full_blocks];
		}
		mpi_sd->counts_2d[0][i] *= ( pipeline ? mpi_sd->seg_size : mpi_sd->size_2d );
		mpi_sd->displs_2d[0][i] = mpi_sd->distr_1d.num_full_blocks * mpi_sd->counts_2d[0][0] +
		                          (i - mpi_sd->distr_1d.num_full_blocks) *
		                          mpi_sd->counts_2d[0][mpi_sd->distr_1d.num_full_blocks];
	}

	/* 2D side, ranks owning no 1D columns. */
	for( ; i < mpi_solver->size; i++ )
	{
		mpi_sd->counts_2d[0][i] = mpi_sd->displs_2d[0][i] = 0;
		if( pipeline )
			mpi_sd->counts_2d[1][i] = mpi_sd->displs_2d[1][i] = 0;
	}

	for( i = 0; i < (mpi_sd->seg_size > 0 ? 2 : 1); i++ )
	{
		MALLOC( mpi_sd->counts_1d[i], mpi_solver->size, int );
		MALLOC( mpi_sd->displs_1d[i], mpi_solver->size, int );
	}

	/* 1D side, ranks with a full 2D block. */
	for( i = 0; i < mpi_sd->distr_2d.num_full_blocks; i++ )
	{
		mpi_sd->counts_1d[0][i] = 2
		                          * mpi_solver->base.grid->nps[1]
		                          * mpi_sd->size_1d;
		if( !pipeline )
		{
			mpi_sd->counts_1d[0][i] *= mpi_sd->distr_2d.block_size;
		}
		else
		{
			mpi_sd->counts_1d[1][i] = mpi_sd->counts_1d[0][i] * mpi_sd->last_seg_size;
			mpi_sd->counts_1d[0][i] *= mpi_sd->seg_size;
			mpi_sd->displs_1d[1][i] = i * mpi_sd->counts_1d[1][i];
		}
		mpi_sd->displs_1d[0][i] = i * mpi_sd->counts_1d[0][i];
	}

	/* 1D side, ranks with a rest-size 2D block (non-pipelined only: the
	 * pipelined layout requires every rank to own a full 2D block). */
	for( ; i < mpi_sd->distr_2d.num_blocks; i++ )
	{
		assert( !pipeline ); // The pipelined version shouldn't reach this line.
		mpi_sd->counts_1d[0][i] = 2
		                          * mpi_sd->distr_2d.rest_block_size
		                          * mpi_solver->base.grid->nps[1]
		                          * mpi_sd->size_1d;
		mpi_sd->displs_1d[0][i] = mpi_sd->distr_2d.num_full_blocks * mpi_sd->counts_1d[0][0] +
		                          (i - mpi_sd->distr_2d.num_full_blocks) *
		                          mpi_sd->counts_1d[0][mpi_sd->distr_2d.num_full_blocks];
	}

	/* 1D side, ranks owning no 2D planes. */
	for( ; i < mpi_solver->size; i++ )
	{
		assert( !pipeline ); // The pipelined version shouldn't reach this line.
		mpi_sd->counts_1d[0][i] = mpi_sd->displs_1d[0][i] = 0;
	}

	return CUP_SUCCESS;
}

/*
 * Non-pipelined shuffle between the all-to-all layout in dev_transposed[0]
 * (one contiguous chunk per source/destination rank) and the 1D-FFT layout
 * in dev_data[0] (rows of nps[2] elements).
 *
 * Launch configuration, on streams[0]:
 *   grid.x  = distr_2d.num_blocks (one column of blocks per rank that owns
 *             2D planes),
 *   grid.y  = ceil(nps[1] / SHUFFLE_TILE_SIZE),
 *   grid.z  = size_1d,
 *   block   = SHUFFLE_TILE_SIZE x SHUFFLE_TILE_SIZE.
 *
 * direction == CUFFT_FORWARD gathers into dev_data[0]; any other value
 * scatters back into dev_transposed[0]. Launch errors are reported through
 * the CUDA() macro.
 */
static cup_error_t shuffle_data( struct cup_mpi_solver* mpi_solver,
                                 int                    direction )
{
	dim3 num_blocks,
	     block_size;
	cup_grid *grid;
	struct mpi_fft_solver_data* mpi_sd;
	cudaStream_t stream;

	assert( mpi_solver != NULL );

	grid = mpi_solver->base.grid;
	mpi_sd = &mpi_solver->sd.fft;
	stream = mpi_solver->base.streams[0];

	/* Only ranks below distr_2d.num_blocks contributed data: the kernels
	 * treat every blockIdx.x >= num_full_blocks as a rest-size block, so
	 * launching one per MPI rank would index past the end of the rows when
	 * some ranks own no planes. */
	num_blocks.x = mpi_sd->distr_2d.num_blocks;
	num_blocks.y = div_up( grid->nps[1], SHUFFLE_TILE_SIZE );
	num_blocks.z = mpi_sd->size_1d;

	block_size.x = block_size.y = SHUFFLE_TILE_SIZE;
	block_size.z = 1;

	if( direction == CUFFT_FORWARD )
	{
		dev_shuffle_data_fw<<<num_blocks, block_size, 0, stream>>>
		                   ( (cufftDoubleComplex *) mpi_solver->base.dev_data[0],
		                     mpi_sd->dev_transposed[0],
		                     mpi_sd->size_1d,
		                     grid->nps[1],
		                     grid->nps[2],
		                     mpi_sd->distr_2d.num_full_blocks,
		                     mpi_sd->distr_2d.block_size,
		                     mpi_sd->distr_2d.rest_block_size );
	}
	else
	{
		dev_shuffle_data_inv<<<num_blocks, block_size, 0, stream>>>
		                    ( mpi_sd->dev_transposed[0],
		                      (cufftDoubleComplex *) mpi_solver->base.dev_data[0],
		                      mpi_sd->size_1d,
		                      grid->nps[1],
		                      grid->nps[2],
		                      mpi_sd->distr_2d.num_full_blocks,
		                      mpi_sd->distr_2d.block_size,
		                      mpi_sd->distr_2d.rest_block_size );
	}
	/* Kernel launches do not return errors; catch bad configurations here. */
	CUDA( cudaGetLastError() );

	return CUP_SUCCESS;
}

/*
 * Gathers the all-to-all chunks into rows of length np2.
 *
 * input : concatenation over source ranks r of [np0][np1][planes_r] blocks.
 *         planes_r = block_size for r < num_full_blocks, rest_block_size
 *         otherwise. Chunks are stored in rank order.
 * output: [np0][np1][np2]. Rank r's planes start at plane_offset(r), the sum
 *         of the plane counts of the ranks before r.
 *
 * Expected launch (see shuffle_data()):
 *   blockIdx.x = source rank,
 *   blockIdx.y * SHUFFLE_TILE_SIZE + threadIdx.y = index in np1,
 *   blockIdx.z = index in np0,
 *   threadIdx.x strides the planes in steps of SHUFFLE_TILE_SIZE, so
 *   blockDim.x must equal SHUFFLE_TILE_SIZE.
 * Indices are computed in size_t to avoid int overflow on large grids.
 */
static __global__
void dev_shuffle_data_fw( cufftDoubleComplex*       output,
                          const cufftDoubleComplex* input,
                          int                       np0,
                          int                       np1,
                          int                       np2,
                          int                       num_full_blocks,
                          int                       block_size,
                          int                       rest_block_size )
{
	const int rank = blockIdx.x;
	const int y = blockIdx.y * SHUFFLE_TILE_SIZE + threadIdx.y;
	int x,
	    my_block_size,
	    plane_offset;
	size_t ind_in,
	       ind_out;

	if( y >= np1 )
		return;

	if( rank < num_full_blocks )
	{
		my_block_size = block_size;
		plane_offset  = rank * block_size;
	}
	else
	{
		/* Correct also when more than one rest-size block exists. */
		my_block_size = rest_block_size;
		plane_offset  = num_full_blocks * block_size
		              + ( rank - num_full_blocks ) * rest_block_size;
	}

	ind_in  = (size_t) plane_offset * np0 * np1
	        + (size_t) blockIdx.z * my_block_size * np1
	        + (size_t) y * my_block_size;

	ind_out = (size_t) blockIdx.z * np1 * np2
	        + (size_t) y * np2
	        + plane_offset;

	for( x = threadIdx.x; x < my_block_size; x += SHUFFLE_TILE_SIZE )
		output[ind_out + x] = input[ind_in + x];
}

/*
 * Inverse of dev_shuffle_data_fw: scatters rows of length np2 back into the
 * all-to-all chunk layout.
 *
 * input : [np0][np1][np2]. Rank r's planes start at plane_offset(r), the sum
 *         of the plane counts of the ranks before r.
 * output: concatenation over destination ranks r of [np0][np1][planes_r]
 *         blocks. planes_r = block_size for r < num_full_blocks,
 *         rest_block_size otherwise. Chunks are stored in rank order.
 *
 * Expected launch (see shuffle_data()):
 *   blockIdx.x = destination rank,
 *   blockIdx.y * SHUFFLE_TILE_SIZE + threadIdx.y = index in np1,
 *   blockIdx.z = index in np0,
 *   threadIdx.x strides the planes in steps of SHUFFLE_TILE_SIZE, so
 *   blockDim.x must equal SHUFFLE_TILE_SIZE.
 * Indices are computed in size_t to avoid int overflow on large grids.
 */
static __global__
void dev_shuffle_data_inv( cufftDoubleComplex*       output,
                           const cufftDoubleComplex* input,
                           int                       np0,
                           int                       np1,
                           int                       np2,
                           int                       num_full_blocks,
                           int                       block_size,
                           int                       rest_block_size )
{
	const int rank = blockIdx.x;
	const int y = blockIdx.y * SHUFFLE_TILE_SIZE + threadIdx.y;
	int x,
	    my_block_size,
	    plane_offset;
	size_t ind_chunk,
	       ind_row;

	if( y >= np1 )
		return;

	if( rank < num_full_blocks )
	{
		my_block_size = block_size;
		plane_offset  = rank * block_size;
	}
	else
	{
		/* Correct also when more than one rest-size block exists. */
		my_block_size = rest_block_size;
		plane_offset  = num_full_blocks * block_size
		              + ( rank - num_full_blocks ) * rest_block_size;
	}

	/* Position inside the destination chunk layout. */
	ind_chunk = (size_t) plane_offset * np0 * np1
	          + (size_t) blockIdx.z * my_block_size * np1
	          + (size_t) y * my_block_size;

	/* Position inside the source rows of length np2. */
	ind_row   = (size_t) blockIdx.z * np1 * np2
	          + (size_t) y * np2
	          + plane_offset;

	for( x = threadIdx.x; x < my_block_size; x += SHUFFLE_TILE_SIZE )
		output[ind_chunk + x] = input[ind_row + x];
}

/*
 * Pipelined shuffle of segment i_seg between
 * dev_transposed[(i_seg % 2) + 2] (all-to-all layout, one seg_size chunk per
 * rank) and dev_data_1d (rows of padded_np2 elements).
 *
 * The segment has seg_size planes, or last_seg_size for the last segment.
 *
 * Launch on streams[(i_seg % 2) + 2]:
 *   block.x = min(seg_size, SHUFFLE_TILE_SIZE),
 *   block.y = 512 / block.x,
 *   grid    = (ranks, ceil(nps[1] / block.y), size_1d).
 *
 * direction == CUFFT_FORWARD gathers into dev_data_1d; any other value
 * scatters back. Launch errors are reported through the CUDA() macro.
 */
static cup_error_t shuffle_seg_data( struct cup_mpi_solver* mpi_solver,
                                     int                    i_seg,
                                     int                    direction )
{
	dim3 num_blocks,
	     block_size;
	cup_grid *grid;
	struct mpi_fft_solver_data* mpi_sd;
	cudaStream_t stream;
	int seg_size;

	assert( mpi_solver != NULL );

	grid = mpi_solver->base.grid;
	mpi_sd = &mpi_solver->sd.fft;
	stream = mpi_sd->streams[(i_seg % 2) + 2];

	seg_size = (i_seg < mpi_sd->num_segs - 1) ? mpi_sd->seg_size : mpi_sd->last_seg_size;
	/* block_size.y below divides by block_size.x, which is <= seg_size. */
	assert( seg_size > 0 );

	block_size.x = ( seg_size < SHUFFLE_TILE_SIZE ) ? seg_size : SHUFFLE_TILE_SIZE;
	block_size.y = 512 / block_size.x;
	block_size.z = 1;

	num_blocks.x = mpi_solver->size;
	num_blocks.y = div_up( grid->nps[1], block_size.y );
	num_blocks.z = mpi_sd->size_1d;

	if( direction == CUFFT_FORWARD )
	{
		dev_shuffle_seg_data_fw<<<num_blocks, block_size, 0, stream>>>
		                       ( mpi_sd->dev_data_1d,
		                         mpi_sd->dev_transposed[(i_seg % 2) + 2],
		                         mpi_sd->size_1d,
		                         grid->nps[1],
		                         mpi_sd->padded_np2,
		                         mpi_sd->distr_2d.block_size,
		                         i_seg * mpi_sd->seg_size,
		                         seg_size );
	}
	else
	{
		dev_shuffle_seg_data_inv<<<num_blocks, block_size, 0, stream>>>
		                        ( mpi_sd->dev_transposed[(i_seg % 2) + 2],
		                          mpi_sd->dev_data_1d,
		                          mpi_sd->size_1d,
		                          grid->nps[1],
		                          mpi_sd->padded_np2,
		                          mpi_sd->distr_2d.block_size,
		                          i_seg * mpi_sd->seg_size,
		                          seg_size );
	}
	/* Kernel launches do not return errors; catch bad configurations here. */
	CUDA( cudaGetLastError() );

	return CUP_SUCCESS;
}

/*
 * Gathers one pipeline segment into rows of length np2.
 *
 * input : per rank blockIdx.x, a [np0][np1][seg_size] chunk, placed at
 *         blockIdx.x * seg_size * np0 * np1.
 * output: [np0][np1][np2]. The segment of rank blockIdx.x goes to planes
 *         blockIdx.x * block_size + seg_offset onward.
 *
 * Thread mapping:
 *   blockIdx.y * blockDim.y + threadIdx.y = index in np1,
 *   blockIdx.z = index in np0,
 *   threadIdx.x walks the seg_size planes in steps of SHUFFLE_TILE_SIZE.
 */
static __global__
void dev_shuffle_seg_data_fw( cufftDoubleComplex*       output,
                              const cufftDoubleComplex* input,
                              int                       np0,
                              int                       np1,
                              int                       np2,
                              int                       block_size,
                              int                       seg_offset,
                              int                       seg_size )
{
	const int row = blockIdx.y * blockDim.y + threadIdx.y;
	int col,
	    src,
	    dst;

	if( row < np1 )
	{
		/* Start of this (rank, np0 index, row) line in the chunk layout. */
		src = blockIdx.x * seg_size * np0 * np1
		    + blockIdx.z * seg_size * np1
		    + row * seg_size;

		/* Start of the rank's segment inside the destination row. */
		dst = blockIdx.z * np1 * np2
		    + row * np2
		    + blockIdx.x * block_size + seg_offset;

		for( col = threadIdx.x; col < seg_size; col += SHUFFLE_TILE_SIZE )
			output[dst + col] = input[src + col];
	}
}

/*
 * Inverse of dev_shuffle_seg_data_fw: scatters one pipeline segment from
 * rows of length np2 back into the per-rank chunk layout.
 *
 * input : [np0][np1][np2]. The segment of rank blockIdx.x starts at plane
 *         blockIdx.x * block_size + seg_offset.
 * output: per rank blockIdx.x, a [np0][np1][seg_size] chunk, placed at
 *         blockIdx.x * seg_size * np0 * np1.
 *
 * Thread mapping:
 *   blockIdx.y * blockDim.y + threadIdx.y = index in np1,
 *   blockIdx.z = index in np0,
 *   threadIdx.x walks the seg_size planes in steps of SHUFFLE_TILE_SIZE.
 */
static __global__
void dev_shuffle_seg_data_inv( cufftDoubleComplex*       output,
                               const cufftDoubleComplex* input,
                               int                       np0,
                               int                       np1,
                               int                       np2,
                               int                       block_size,
                               int                       seg_offset,
                               int                       seg_size )
{
	const int row = blockIdx.y * blockDim.y + threadIdx.y;
	int col,
	    chunk_pos,
	    row_pos;

	if( row < np1 )
	{
		/* Destination: this (rank, np0 index, row) line in the chunk layout. */
		chunk_pos = blockIdx.x * seg_size * np0 * np1
		          + blockIdx.z * seg_size * np1
		          + row * seg_size;

		/* Source: the rank's segment inside the row of length np2. */
		row_pos   = blockIdx.z * np1 * np2
		          + row * np2
		          + blockIdx.x * block_size + seg_offset;

		for( col = threadIdx.x; col < seg_size; col += SHUFFLE_TILE_SIZE )
			output[chunk_pos + col] = input[row_pos + col];
	}
}
