/*
    Copyright (C) 2016 University of the Basque Country, UPV/EHU.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/********************************************
 * Includes                                 *
 ********************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <assert.h>
#include <cuPoisson.h>

#include "test.h"

/********************************************
 * Private function prototypes              *
 ********************************************/
/* (void) makes these real prototypes: an empty () list is an obsolescent
 * non-prototype declaration that disables argument checking. */
static int test_num_devices( void );
static int test_lib_init( int num_devs );
static int test_grid( int* nps, double* size );
#ifdef MOD_SERIAL
static int test_fft_solver( int num_devs, char* data_path, char* grid_name );
static int read_header( FILE* fp, int* nps, double* spacing );
#endif

/********************************************
 * Main function                            *
 ********************************************/
/*
 * Test driver entry point. argv[1] selects the test:
 *   num_devs
 *   lib_init   <num_devs>
 *   grid       <np0> <np1> <np2> <size0> <size1> <size2>
 *   fft_solver <num_devs> <data_path> <grid_name>   (MOD_SERIAL builds only)
 * Returns the selected test's result, or 1 on bad usage / unknown test.
 */
int main( int argc, char* argv[] )
{
	/* Argument counts are checked explicitly rather than with assert():
	 * assert() is compiled out under NDEBUG, which would let a missing
	 * argument reach strcmp()/atoi() as a NULL pointer. */
	if( argc < 2 )
	{
		fprintf( stderr, "Usage: <test name> [test arguments...]\n" );
		return 1;
	}

	if( !strcmp( argv[1], "num_devs" ) )
		return test_num_devices();

	if( !strcmp( argv[1], "lib_init" ) )
	{
		if( argc < 3 )
		{
			fprintf( stderr, "Usage: lib_init <num_devs>\n" );
			return 1;
		}
		return test_lib_init( atoi( argv[2] ) );
	}

	if( !strcmp( argv[1], "grid" ) )
	{
		int nps[3];
		double size[3];
		int d;

		if( argc < 8 )
		{
			fprintf( stderr, "Usage: grid <np0> <np1> <np2> <size0> <size1> <size2>\n" );
			return 1;
		}
		/* argv[2..4] are the point counts, argv[5..7] the sizes. */
		for( d = 0; d < 3; d++ )
		{
			nps[d] = atoi( argv[2 + d] );
			size[d] = atof( argv[5 + d] );
		}
		return test_grid( nps, size );
	}
#ifdef MOD_SERIAL
	if( !strcmp( argv[1], "fft_solver" ) )
	{
		if( argc < 5 )
		{
			fprintf( stderr, "Usage: fft_solver <num_devs> <data_path> <grid_name>\n" );
			return 1;
		}
		return test_fft_solver( atoi( argv[2] ), argv[3], argv[4] );
	}
#endif
	/* Unknown test name: report it as a failure. */
	TEST( 0 );
	/* Make the failing exit status explicit in case TEST() does not return. */
	return 1;
}

/********************************************
 * Private functions                        *
 ********************************************/
/* Checks that cup_get_valid_devices() succeeds and reports at least one
 * usable device. */
static int test_num_devices( void )
{
	/* Initialised so the second check never reads an indeterminate value
	 * if the query fails and TEST() does not return. */
	int num_devices = 0;
	cup_error_t status;

	status = cup_get_valid_devices( NULL, &num_devices );
	TEST( status == CUP_SUCCESS );
	TEST( num_devices > 0 );

	END_TEST();
}

/* Smoke test for library start-up and shut-down: num_devs is passed
 * straight to cup_init(), then cup_finish() is called; both must return
 * CUP_SUCCESS. */
static int test_lib_init( int num_devs )
{
	cup_error_t rc = cup_init( num_devs, NULL );

	TEST( rc == CUP_SUCCESS );

	rc = cup_finish();
	TEST( rc == CUP_SUCCESS );

	END_TEST();
}

/* Creates a 3-D grid from the given point counts (nps[0..2]) and sizes
 * (size[0..2]), checks that creation succeeds, then releases the grid so
 * the test does not leak it. */
static int test_grid( int* nps, double* size )
{
	cup_error_t status;
	cup_grid_t grid;

	status = cup_create_grid( 3, nps, size, &grid );
	TEST( status == CUP_SUCCESS );

	/* Only release the grid if it was actually created. */
	if( status == CUP_SUCCESS )
		cup_destroy_grid( grid );

	END_TEST();
}

#ifdef MOD_SERIAL
/* Builds "<data_path>/<grid_name><suffix>.grd", opens it and reads its
 * header into nps/spacing. Returns the stream positioned just after the
 * header, or NULL after printing a diagnostic mentioning `what`; nothing is
 * left open on failure. */
static FILE* open_grid_file( const char* data_path, const char* grid_name,
                             const char* suffix, const char* what,
                             int* nps, double* spacing )
{
	char filename[FILENAME_MAX];
	int len;
	FILE* fp;

	/* snprintf instead of sprintf: a long data path must not overflow. */
	len = snprintf( filename, sizeof( filename ), "%s/%s%s.grd",
	                data_path, grid_name, suffix );
	if( len < 0 || (size_t) len >= sizeof( filename ) )
	{
		fprintf( stderr, "Path of %s grid file is too long.\n", what );
		return NULL;
	}

	fp = fopen( filename, "rb" );
	if( fp == NULL )
	{
		fprintf( stderr, "Error opening %s grid file: %s\n", what, filename );
		return NULL;
	}
	if( read_header( fp, nps, spacing ) )
	{
		fprintf( stderr, "Error reading %s grid file header.\n", what );
		fclose( fp );
		return NULL;
	}
	return fp;
}

/* Reads np doubles from fp into data and always closes fp. Returns 0 on
 * success, or 1 after printing a diagnostic mentioning `what` on a short
 * read. */
static int read_grid_data( FILE* fp, double* data, size_t np, const char* what )
{
	size_t num_items = fread( data, sizeof( double ), np, fp );

	fclose( fp );
	if( num_items < np )
	{
		fprintf( stderr, "Error reading %s grid data.\n", what );
		return 1;
	}
	return 0;
}

/*
 * End-to-end FFT solver test: loads <data_path>/<grid_name>.grd, runs the
 * periodic FFT solver on it, and compares the result point-by-point (and by
 * total sum) with <data_path>/<grid_name>_sol.grd.
 * num_devs is passed to cup_init(); when positive, at least that many valid
 * devices must be available.
 * Returns 1 on a file/data error (after releasing every resource acquired
 * so far), otherwise the END_TEST() result.
 */
static int test_fft_solver( int num_devs, char* data_path, char* grid_name )
{
	const double EPSILON = 1e-10;
	double size[3] = { 1.0, 1.0, 1.0 };
	double* input = NULL;
	double* output = NULL;
	double spacing,
	       spacing_sol,
	       sum_solver,
	       sum_file;
	int nps[3],
	    nps_sol[3],
	    num_devices = 0,
	    is_ready,
	    have_grid = 0,
	    have_solver = 0,
	    failed = 0;
	size_t np,
	       i;
	cup_error_t status,
	            exec_status;
	cup_grid_t grid;
	cup_solver_t solver;
	FILE* fp;

	status = cup_get_valid_devices( NULL, &num_devices );
	TEST( status == CUP_SUCCESS );
	TEST( num_devices > 0 );
	/* A TEST rather than assert() so the check survives NDEBUG builds. */
	TEST( num_devs <= 0 || num_devices >= num_devs );

	status = cup_init( num_devs, NULL );
	TEST( status == CUP_SUCCESS );

	fp = open_grid_file( data_path, grid_name, "", "input", nps, &spacing );
	if( fp == NULL )
	{
		failed = 1;
		goto cleanup;
	}
	/* Multiply in size_t so a large grid cannot overflow int. */
	np = (size_t) nps[0] * (size_t) nps[1] * (size_t) nps[2];

	status = cup_create_grid( 3, nps, size, &grid );
	TEST( status == CUP_SUCCESS );
	have_grid = ( status == CUP_SUCCESS );

	/* The result must be captured: the original checked a stale status. */
	status = cup_create_solver( CUP_FFT_SOLVER, grid, CUP_BND_PERIODIC, &solver );
	TEST( status == CUP_SUCCESS );
	have_solver = ( status == CUP_SUCCESS );

	status = cup_malloc( (void**) &input, np * sizeof( double ) );
	TEST( status == CUP_SUCCESS );

	status = cup_malloc( (void**) &output, np * sizeof( double ) );
	TEST( status == CUP_SUCCESS );

	if( read_grid_data( fp, input, np, "input" ) )
	{
		failed = 1;
		goto cleanup;
	}

	status = cup_exec_solver( solver, input, output );
	TEST( status == CUP_SUCCESS );

	/* Exercises the polling API; is_ready itself is not checked. */
	status = cup_is_solver_ready( solver, &is_ready );
	TEST( status == CUP_SUCCESS );

	status = cup_wait_solver( solver, &exec_status );
	TEST( status == CUP_SUCCESS );
	TEST( exec_status == CUP_SUCCESS );

	/* The reference solution is read into `input`, reusing that buffer
	 * once cup_wait_solver() has returned. */
	fp = open_grid_file( data_path, grid_name, "_sol", "solution",
	                     nps_sol, &spacing_sol );
	if( fp == NULL )
	{
		failed = 1;
		goto cleanup;
	}
	/* Exact comparison is intended: both spacings are read verbatim from
	 * the data files. */
	if( nps[0] != nps_sol[0] ||
	    nps[1] != nps_sol[1] ||
	    nps[2] != nps_sol[2] ||
	    spacing != spacing_sol )
	{
		fprintf( stderr, "Input grid and solution grid does not match.\n" );
		fclose( fp );
		failed = 1;
		goto cleanup;
	}
	if( read_grid_data( fp, input, np, "solution" ) )
	{
		failed = 1;
		goto cleanup;
	}

	sum_solver = sum_file = 0.0;
	for( i = 0; i < np; i++ )
	{
		TEST( fabs( input[i] - output[i] ) < EPSILON );
		sum_solver += output[i];
		sum_file += input[i];
	}
	/* Relative to |sum_file| (absolute when |sum_file| < 1). Dividing by
	 * sum_file itself passed vacuously for a negative sum and produced NaN
	 * for a zero sum. */
	TEST( fabs( sum_solver - sum_file ) <= EPSILON * fmax( 1.0, fabs( sum_file ) ) );

cleanup:
	/* Single release path for both success and error exits, in reverse
	 * order of acquisition. */
	if( output != NULL )
		cup_free( output );
	if( input != NULL )
		cup_free( input );
	if( have_solver )
	{
		/* The result must be captured: the original checked a stale status. */
		status = cup_destroy_solver( solver );
		TEST( status == CUP_SUCCESS );
	}
	/* Release the grid before shutting the library down. */
	if( have_grid )
		cup_destroy_grid( grid );

	status = cup_finish();
	TEST( status == CUP_SUCCESS );

	if( failed )
		return 1;

	END_TEST();
}

/*
 * Reads a grid file header with fread() in the host's native representation:
 *   int    dim      (must be 3)
 *   int    nps[3]   (each must be > 0)
 *   double spacing  (must be > 0)
 * Returns 0 on success, or 1 on a short read or an invalid value.
 */
static int read_header( FILE* fp, int* nps, double* spacing )
{
	int dim;

	/* Validated explicitly rather than with assert(): the values come from
	 * an external file and the checks must survive NDEBUG builds. */
	if( fread( &dim, sizeof( int ), 1, fp ) < 1 || dim != 3 )
		return 1;

	if( fread( nps, sizeof( int ), 3, fp ) < 3 ||
	    nps[0] <= 0 || nps[1] <= 0 || nps[2] <= 0 )
		return 1;

	/* Written as !(x > 0) so that a NaN spacing is rejected as well. */
	if( fread( spacing, sizeof( double ), 1, fp ) < 1 || !( *spacing > 0 ) )
		return 1;

	return 0;
}
#endif
