/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/********************************************
 * Includes                                 *
 ********************************************/
#ifdef MOD_EXCL_MPI
#include <mpi.h>
#endif

#include "core/globals.h"
#include "cuPoisson.h"
#include "core/grid.h"
#include "solvers/fft_utils.h"
#ifdef MOD_EXCL_SERIAL
#include "solvers/fft_ser.h"
#endif
#ifdef MOD_EXCL_MPI
#include "solvers/fft_mpi.h"
#endif
#include "solver.h"


/********************************************
 * Private function prototypes              *
 ********************************************/
#ifdef MOD_EXCL_SERIAL
/* Worker started by cup_exec_solver(); one detached thread per device. */
static void* solver_thread( void* args );
#endif
#ifdef MOD_EXCL_MPI
/* Worker used by cup_mpi_exec_solver(); run in a detached thread when
 * global_info.config.mpi_nonblocking is set, or called inline otherwise. */
static void* mpi_solver_thread( void* args );
#endif
/* Validate the arguments of a solver creation request (type, grid dimension,
 * boundary conditions and, for MPI, the number of devices). */
static cup_error_t check_constraints( cup_solver_type_t type,
                                      cup_grid_t        grid,
                                      cup_bnd_type      boundary_type );

/********************************************
 * Exported functions                       *
 ********************************************/
#ifdef __cplusplus
extern "C" {
#endif
/**
 * \ingroup public
 * Get a solver's device data pointers and their distribution across devices.
 * \param solver [in] The solver.
 * \param data [out] The pointer array. One pointer per device.
 */
#ifdef MOD_EXCL_MPI
/*
 * \param offset [out] Receives the solver's sd.fft.offset_2d value
 * (presumably this process' offset in the global 2D decomposition -- TODO
 * confirm). Only written on success.
 */
#endif
/*
 * \param sizes [out] The number of planes stored in each device stored in an
 * array. One size per device. The array must be allocated by the caller.
 * \return CUP_INVALID_ARGUMENT if any pointer is NULL, CUP_SUCCESS otherwise
 * (errors from get_solver_data are propagated).
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_get_solver_data( cup_solver_t solver,
                                 double***    data,
                                 int*         sizes )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_get_solver_data( cup_mpi_solver_t mpi_solver,
                                     double***        data,
                                     int*             offset,
                                     int*             sizes )
#endif
{
#ifdef MOD_EXCL_MPI
	cup_solver* solver;
#endif

	CHECK_INIT();

#ifdef MOD_EXCL_MPI
	/* Validate the handle before taking the address of its member:
	 * evaluating &mpi_solver->base on a NULL pointer is undefined behavior. */
	if( mpi_solver == NULL || offset == NULL )
		return CUP_INVALID_ARGUMENT;
	solver = &mpi_solver->base;
#endif

	if( solver == NULL || data == NULL || sizes == NULL )
		return CUP_INVALID_ARGUMENT;

	THROW_CUP_ERROR( get_solver_data( solver, data, sizes ) );

#ifdef MOD_EXCL_MPI
	/* Outputs are only written once the data query has succeeded. */
	*offset = mpi_solver->sd.fft.offset_2d;
	if( global_info.num_devices == 1 )
		sizes[0] = mpi_solver->sd.fft.size_2d;
#endif
	return CUP_SUCCESS;
}

/**
 * \ingroup public
 * Create a new solver.
 * \param [in] type The solver type.
 * \param [in] grid The grid structure for which the solver will be created.
 * The solver stores its own copy of it.
 * \param [in] boundary_type The boundary conditions.
 */
#ifdef MOD_EXCL_MPI
/*
 * \param [in] comm The MPI communicator including all the processes that must
 * participate in the solver. The solver works on a duplicate obtained with
 * MPI_Comm_dup, so the caller keeps ownership of \p comm.
 */
#endif
/*
 * \param [out] solver The new solver. Only written on success.
 * \return CUP_SUCCESS, CUP_OUT_OF_HOST_MEM, CUP_CUDA_ERROR, CUP_INVALID_ARGUMENT.
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_create_solver( cup_solver_type_t type,
                               cup_grid_t        grid,
                               cup_bnd_type      boundary_type,
                               cup_solver_t*     solver )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_create_solver( cup_solver_type_t type,
                                   cup_grid_t        grid,
                                   cup_bnd_type      boundary_type,
                                   MPI_Comm          comm,
                                   cup_mpi_solver_t* solver )
#endif
{
#ifdef MOD_EXCL_SERIAL
	struct cup_solver* _solver;
#endif
#ifdef MOD_EXCL_MPI
	struct cup_mpi_solver* mpi_solver;
	struct cup_solver* _solver;
#endif
	int i_dev,
	    nd;

	if( grid == NULL || solver == NULL )
		return CUP_INVALID_ARGUMENT;

	/* Check initialization first: in the MPI build check_constraints() reads
	 * global_info (the number of devices), which is only meaningful once the
	 * library has been initialized. */
	CHECK_INIT();

	THROW_CUP_ERROR( check_constraints( type, grid, boundary_type ) );

	/* NOTE(review): if MALLOC/PTHREAD/CUDA/THROW_CUP_ERROR return early on
	 * failure (as their use here suggests), the resources acquired before the
	 * failing call are leaked. A goto-based cleanup path would fix this --
	 * TODO confirm the macro semantics. */

	nd = global_info.num_devices;
#ifdef MOD_EXCL_SERIAL
	MALLOC( _solver, 1, struct cup_solver );
#endif
#ifdef MOD_EXCL_MPI
	MALLOC( mpi_solver, 1, struct cup_mpi_solver );
	_solver = &mpi_solver->base;
#endif

	_solver->type = type;
	_solver->state = SOLVER_READY;
	_solver->boundary_type = boundary_type;
	/* The wait function reports solver->status even before the first
	 * execution, so it must hold a defined value from the start. */
	_solver->status = CUP_SUCCESS;
	/* In the MPI blocking mode these are never allocated; keep them NULL
	 * rather than indeterminate. */
	_solver->state_sync = NULL;
	_solver->threads = NULL;

#ifdef MOD_EXCL_MPI
	if( global_info.config.mpi_nonblocking )
	{
#endif

	MALLOC( _solver->state_sync, 1, struct thread_sync );
	_solver->state_sync->counter = 0;
	PTHREAD( pthread_mutex_init( &_solver->state_sync->mutex, 0 ) );
	PTHREAD( pthread_cond_init( &_solver->state_sync->cond, 0 ) );

	MALLOC( _solver->threads, nd, pthread_t );
#ifdef MOD_EXCL_MPI
	}
#endif
	/* One CUDA stream per device. */
	MALLOC( _solver->streams, nd, cudaStream_t );

	for( i_dev = 0; i_dev < nd; i_dev++ )
	{
		CUDA( cudaSetDevice( global_info.devices[i_dev] ) );
		CUDA( cudaStreamCreate( _solver->streams + i_dev ) );
	}

	THROW_CUP_ERROR( copy_grid( &_solver->grid, grid ) );

#ifdef MOD_EXCL_MPI
	MPI_CALL( MPI_Comm_dup( comm, &mpi_solver->comm ) );
	MPI_CALL( MPI_Comm_rank( mpi_solver->comm, &mpi_solver->rank ) );
	MPI_CALL( MPI_Comm_size( mpi_solver->comm, &mpi_solver->size ) );
#endif

	switch( type )
	{
	case CUP_FFT_SOLVER:
#ifdef MOD_EXCL_SERIAL
		THROW_CUP_ERROR( create_fft_solver( _solver ) );
		print_msg( INFO, "FFT solver successfully created." );
#endif
#ifdef MOD_EXCL_MPI
		THROW_CUP_ERROR( mpi_create_fft_solver( mpi_solver ) );
		print_msg( INFO,
		           "FFT solver successfully created for %d MPI processes.",
		           mpi_solver->size );
#endif
		break;
	default:
		/* Unreachable: check_constraints() already rejected other types. */
		return CUP_INVALID_ARGUMENT;
	}

#ifdef MOD_EXCL_SERIAL
	*solver = _solver;
#endif
#ifdef MOD_EXCL_MPI
	*solver = mpi_solver;
#endif
	return CUP_SUCCESS;
}

/**
 * \ingroup public
 * Destroy a solver and release all associated resources.
 * NOTE(review): no wait is performed here; destroying a solver whose
 * execution is still in flight would free state used by its worker threads.
 * Callers should wait for the solver first.
 * \param [in,out] solver The solver to be destroyed.
 * \return CUP_SUCCESS, CUP_INVALID_ARGUMENT, CUP_CUDA_ERROR.
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_destroy_solver( cup_solver_t solver )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_destroy_solver( cup_mpi_solver_t mpi_solver )
#endif
{
#ifdef MOD_EXCL_SERIAL
	if ( solver == NULL )
		return CUP_INVALID_ARGUMENT;
#endif
#ifdef MOD_EXCL_MPI
	if( mpi_solver == NULL )
		return CUP_INVALID_ARGUMENT;
	struct cup_solver* solver = &mpi_solver->base;
#endif

	int i_dev;

	CHECK_INIT();

	switch( solver->type )
	{
	case CUP_FFT_SOLVER:
#ifdef MOD_EXCL_SERIAL
		THROW_CUP_ERROR( destroy_fft_solver( solver ) );
#endif
#ifdef MOD_EXCL_MPI
		THROW_CUP_ERROR( mpi_destroy_fft_solver( mpi_solver ) );
#endif
		break;
	default:
		return CUP_INVALID_ARGUMENT;
	}

#ifdef MOD_EXCL_MPI
	/* The synchronization objects only exist in nonblocking mode. */
	if( global_info.config.mpi_nonblocking )
	{
#endif

	PTHREAD( pthread_mutex_destroy( &solver->state_sync->mutex ) );
	PTHREAD( pthread_cond_destroy( &solver->state_sync->cond ) );
	free( solver->state_sync );
	free( solver->threads );

#ifdef MOD_EXCL_MPI
	}
#endif

	cup_destroy_grid( solver->grid );
	for( i_dev = 0; i_dev < global_info.num_devices; i_dev++ )
	{
		CUDA( cudaStreamDestroy( solver->streams[i_dev] ) );
	}
	free( solver->streams );

#ifdef MOD_EXCL_MPI
	MPI_CALL( MPI_Comm_free( &mpi_solver->comm ) );

	/* The allocation made at creation time is the whole cup_mpi_solver;
	 * free that pointer, not the address of its embedded base member
	 * (which is only the same pointer if base happens to be the first
	 * member). */
	free( mpi_solver );
#endif
#ifdef MOD_EXCL_SERIAL
	free( solver );
#endif

	return CUP_SUCCESS;
}

/**
 * \ingroup public
 */
#ifdef MOD_EXCL_SERIAL
/*
 * Execute a solver. The function will return immediately after launching the
 * execution. To check if the execution has finished use
 * \p cup_wait_solver or \p cup_is_solver_ready.
 * \param [in] solver The solver to execute.
 */
#endif
#ifdef MOD_EXCL_MPI
/*
 * Execute a MPI solver. If the CUP_MPI_NONBLOCKING environment variable is
 * nonzero and the MPI implementation provides MPI_THREAD_SERIALIZED
 * level of thread support, the function will return immediately after launching
 * the execution. To check if the execution has finished use
 * \p cup_mpi_wait_solver or \p cup_mpi_is_solver_ready. Otherwise, the function
 * will block until the solver has completed its execution.
 * \param [in,out] mpi_solver The solver to execute.
 */
#endif
/* \param [in,out] input The input data, in host memory, for the solver. If NULL,
 * it will be assumed that the input data is already in the devices' memory.
 * \param [out] output A pointer to memory on the host where the output will be
 * stored. It can be equal to \p input. If NULL, the results will not be copied
 * from devices' memory to host memory.
 * \return CUP_SUCCESS, CUP_OUT_OF_HOST_MEM.
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_exec_solver( cup_solver_t  solver,
                             const double* input,
                             double*       output )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_exec_solver( cup_mpi_solver_t mpi_solver,
                                 const double*    input,
                                 double*          output )
#endif
{
#ifdef MOD_EXCL_MPI
	if( mpi_solver == NULL )
		return CUP_INVALID_ARGUMENT;
	struct cup_solver* solver = &mpi_solver->base;
	struct mpi_solver_thread_parms* thread_parms;
#endif
#ifdef MOD_EXCL_SERIAL
	struct solver_thread_parms* thread_parms;
	if( solver == NULL )
		return CUP_INVALID_ARGUMENT;
#endif
	int i_dev;
	pthread_attr_t thread_attr;

	CHECK_INIT();

#ifdef MOD_EXCL_SERIAL
	cup_wait_solver( solver, NULL );
#endif
#ifdef MOD_EXCL_MPI
	cup_mpi_wait_solver( mpi_solver, NULL );
#endif

	solver->status = CUP_SUCCESS;
	solver->state = SOLVER_RUNNING;

#ifdef MOD_EXCL_MPI
	if( global_info.config.mpi_nonblocking )
	{
#endif
	solver->state_sync->counter = 0;
#ifdef MOD_EXCL_MPI
	}
#endif

	switch( solver->type )
	{
	case CUP_FFT_SOLVER:
#ifdef MOD_EXCL_SERIAL
		THROW_CUP_ERROR( init_fft_solver( solver ) );
#endif
		break;
	}

#ifdef MOD_EXCL_MPI
	if( global_info.config.mpi_nonblocking )
	{
#endif

	PTHREAD( pthread_attr_init( &thread_attr ) );
	PTHREAD( pthread_attr_setdetachstate( &thread_attr,
	                                      PTHREAD_CREATE_DETACHED ) );

#ifdef MOD_EXCL_MPI
	}
#endif

	for( i_dev = 0; i_dev < global_info.num_devices; i_dev++ )
	{
#ifdef MOD_EXCL_SERIAL
		MALLOC( thread_parms, 1, struct solver_thread_parms );
		thread_parms->i_dev = i_dev;
		thread_parms->solver = solver;
#endif
#ifdef MOD_EXCL_MPI
		MALLOC( thread_parms, 1, struct mpi_solver_thread_parms );
		thread_parms->solver = mpi_solver;
#endif
		thread_parms->input = input;
		thread_parms->output = output;

#ifdef MOD_EXCL_MPI
	if( global_info.config.mpi_nonblocking )
	{
#endif
		PTHREAD( pthread_create( solver->threads + i_dev,
						         &thread_attr,
#ifdef MOD_EXCL_SERIAL
						         solver_thread,
#endif
#ifdef MOD_EXCL_MPI
						         mpi_solver_thread,
#endif
						         thread_parms ) );

		print_msg( INFO, "Nonblocking FFT solver thread launched." );

#ifdef MOD_EXCL_MPI
	}
	else
	{
		mpi_solver_thread( thread_parms );
		print_msg( INFO, "Blocking FFT solver launched." );
	}
#endif
	}

	return CUP_SUCCESS;
}

/**
 * \ingroup public
 * Block until the solver's last execution has finished. Returns immediately
 * if the solver is already ready.
 * \param [in] solver The solver.
 * \param [out] status Optional (may be NULL). Receives the execution status;
 * when several devices failed, it holds the error of the last one to report.
 * \return CUP_SUCCESS, or CUP_INVALID_ARGUMENT if the solver is NULL (errors
 * from the initialization or threading checks are propagated).
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_wait_solver( cup_solver_t solver, cup_error_t* status )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_wait_solver( cup_mpi_solver_t mpi_solver, cup_error_t* status )
#endif
{
#ifdef MOD_EXCL_MPI
	struct cup_solver* solver;

	if( mpi_solver == NULL )
		return CUP_INVALID_ARGUMENT;
	solver = &mpi_solver->base;
#endif
#ifdef MOD_EXCL_SERIAL
	if( solver == NULL )
		return CUP_INVALID_ARGUMENT;
#endif

	CHECK_INIT();

	if( solver->state != SOLVER_READY )
	{
		struct thread_sync* sync = solver->state_sync;

		/* Sleep until every device worker has reported completion. */
		PTHREAD( pthread_mutex_lock( &sync->mutex ) );
		while( sync->counter < global_info.num_devices )
		{
			PTHREAD( pthread_cond_wait( &sync->cond, &sync->mutex ) );
		}
		PTHREAD( pthread_mutex_unlock( &sync->mutex ) );

		solver->state = SOLVER_READY;
	}

	if( status != NULL )
		*status = solver->status;

	return CUP_SUCCESS;
}

/**
 * \ingroup public
 * Non-blocking readiness check. A solver is ready when it has finished its
 * last execution or has never been executed. When completion is observed the
 * solver is switched back to the ready state.
 * \param [in] solver The solver.
 * \param [out] is_ready Set to 1 if the solver is ready for execution, 0
 *                       otherwise.
 * \return CUP_SUCCESS or CUP_INVALID_ARGUMENT if any parameter is NULL.
 */
#ifdef MOD_EXCL_SERIAL
cup_error_t cup_is_solver_ready( cup_solver_t solver, int* is_ready )
#endif
#ifdef MOD_EXCL_MPI
cup_error_t cup_mpi_is_solver_ready( cup_mpi_solver_t mpi_solver, int* is_ready )
#endif
{
#ifdef MOD_EXCL_MPI
	struct cup_solver* solver;

	if( mpi_solver == NULL )
		return CUP_INVALID_ARGUMENT;
	solver = &mpi_solver->base;
#endif
#ifdef MOD_EXCL_SERIAL
	if( solver == NULL )
		return CUP_INVALID_ARGUMENT;
#endif
	if( is_ready == NULL )
		return CUP_INVALID_ARGUMENT;

	CHECK_INIT();

	if( solver->state != SOLVER_READY )
	{
		struct thread_sync* sync = solver->state_sync;

		/* Ready only once every device worker has reported completion. */
		PTHREAD( pthread_mutex_lock( &sync->mutex ) );
		*is_ready = sync->counter == global_info.num_devices;
		PTHREAD( pthread_mutex_unlock( &sync->mutex ) );

		if( *is_ready )
			solver->state = SOLVER_READY;
		return CUP_SUCCESS;
	}

	*is_ready = 1;
	return CUP_SUCCESS;
}

#ifdef __cplusplus
} // extern "C"
#endif

/********************************************
 * Private functions                        *
 ********************************************/
/**
 * Worker executed once per device. In the serial build (and in the MPI build
 * with nonblocking mode) it runs in a detached thread; in the MPI blocking
 * mode it is called inline by the exec function. It runs the solver, records
 * a failing status in solver->status and reports completion through
 * solver->state_sync (nonblocking) or solver->state (blocking).
 * It takes ownership of \p args and frees it.
 * \param [in,out] args The standard argument type for POSIX threads: a
 * heap-allocated solver thread parameter structure.
 * \return Always NULL.
 */
#ifdef MOD_EXCL_SERIAL
static void* solver_thread( void* args )
#endif
#ifdef MOD_EXCL_MPI
static void* mpi_solver_thread( void* args )
#endif
{
	assert( args != NULL );

#ifdef MOD_EXCL_SERIAL
	struct solver_thread_parms* parms = (struct solver_thread_parms*) args;
	struct cup_solver* solver = parms->solver;

	assert( parms->i_dev >= 0 );
	assert( parms->solver != NULL );
#endif
#ifdef MOD_EXCL_MPI
	struct mpi_solver_thread_parms* parms = (struct mpi_solver_thread_parms*) args;
	assert( parms->solver != NULL );
	struct cup_solver* solver = &parms->solver->base;
#endif
	struct thread_sync* sync;
	cup_error_t status;

	switch( solver->type )
	{
	case CUP_FFT_SOLVER:
#ifdef MOD_EXCL_SERIAL
		status = exec_fft_solver( parms->output,
		                          parms->input,
		                          parms->solver,
		                          parms->i_dev );
#endif
#ifdef MOD_EXCL_MPI
		status = mpi_exec_fft_solver( parms->output,
		                              parms->input,
		                              parms->solver );
#endif
		break;
	default:
		status = CUP_INVALID_ARGUMENT;
		break;
	}

	/* The parameters are no longer needed; release them now so that none of
	 * the early returns below can leak them. */
	free( parms );

#ifdef MOD_EXCL_MPI
	if( !global_info.config.mpi_nonblocking )
	{
		/* Blocking mode: executed inline by the caller, so there is no
		 * concurrent waiter and no state_sync to update. */
		if( status != CUP_SUCCESS )
			solver->status = status;
		solver->state = SOLVER_READY;
		return NULL;
	}
#endif

	sync = solver->state_sync;
	if( pthread_mutex_lock( &sync->mutex ) != 0 )
	{
		solver->status = CUP_THREAD_ERROR;
		return NULL;
	}

	/* Publish the status while holding the mutex and BEFORE counting this
	 * device as finished: as soon as the counter reaches the number of
	 * devices, a waiter may read solver->status. Writing it after the unlock
	 * races with that read and can report success for a failed run. */
	if( status != CUP_SUCCESS )
		solver->status = status;
	sync->counter++;

	/* Broadcast rather than signal: more than one thread may be waiting on
	 * the same solver. On failure, still fall through to the unlock --
	 * returning with the mutex held would deadlock every waiter. */
	if( pthread_cond_broadcast( &sync->cond ) != 0 )
		solver->status = CUP_THREAD_ERROR;

	if( pthread_mutex_unlock( &sync->mutex ) != 0 )
		solver->status = CUP_THREAD_ERROR;

	return NULL;
}

/**
 * Function to check the validity of the parameters passed to a solver creation.
 * The most important validity checks should be centralized here.
 * \param [in] type The solver type. Only CUP_FFT_SOLVER is accepted.
 * \param [in] grid The grid properties. Must be non-NULL and 3-dimensional.
 * \param [in] boundary_type The boundary conditions. Only CUP_BND_PERIODIC is
 * accepted.
 * In the MPI build, at most one device per process is accepted.
 * \return CUP_SUCCESS or CUP_INVALID_ARGUMENT.
 */
static cup_error_t check_constraints( cup_solver_type_t type,
                                      cup_grid_t        grid,
                                      cup_bnd_type      boundary_type )
{
	/* Defensive: the public callers already reject NULL grids, but this
	 * function dereferences it below. */
	if( grid == NULL )
		return CUP_INVALID_ARGUMENT;

	if ( type != CUP_FFT_SOLVER )
	{
		print_msg( ERROR,
		           "Invalid solver type: %d. "
		           "Currently just a FFT solver (CUP_FFT_SOLVER) is allowed.",
		           type );
		return CUP_INVALID_ARGUMENT;
	}

	if( grid->dim != 3 )
	{
		print_msg( ERROR,
		           "Invalid number of dimensions: %d. "
		           "Currently just 3 dimensions are allowed.",
		           grid->dim );
		return CUP_INVALID_ARGUMENT;
	}

	if( boundary_type != CUP_BND_PERIODIC )
	{
		print_msg( ERROR,
		           "Currently just periodic boundary conditions are allowed." );
		return CUP_INVALID_ARGUMENT;
	}

#ifdef MOD_EXCL_MPI
	if( global_info.num_devices > 1 )
	{
		print_msg( ERROR,
		           "The MPI solver cannot work with more than one device per "
		           "process." );
		return CUP_INVALID_ARGUMENT;
	}
#endif

	return CUP_SUCCESS;
}
