/*******************************************************************************
!   Copyright(C) 2011-2012 Intel Corporation. All Rights Reserved.
!
!   The source code, information  and  material ("Material") contained herein is
!   owned  by Intel Corporation or its suppliers or licensors, and title to such
!   Material remains  with Intel Corporation  or its suppliers or licensors. The
!   Material  contains proprietary information  of  Intel or  its  suppliers and
!   licensors. The  Material is protected by worldwide copyright laws and treaty
!   provisions. No  part  of  the  Material  may  be  used,  copied, reproduced,
!   modified, published, uploaded, posted, transmitted, distributed or disclosed
!   in any way  without Intel's  prior  express written  permission. No  license
!   under  any patent, copyright  or  other intellectual property rights  in the
!   Material  is  granted  to  or  conferred  upon  you,  either  expressly,  by
!   implication, inducement,  estoppel or  otherwise.  Any  license  under  such
!   intellectual  property  rights must  be express  and  approved  by  Intel in
!   writing.
!
!   *Third Party trademarks are the property of their respective owners.
!
!   Unless otherwise  agreed  by Intel  in writing, you may not remove  or alter
!   this  notice or  any other notice embedded  in Materials by Intel or Intel's
!   suppliers or licensors in any way.
!
!*******************************************************************************
!   Content:
!       SGEMM, DGEMM, CGEMM, ZGEMM benchmarks driver
!******************************************************************************/

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <math.h>

#include "bench.h"
#include "utils.h"
#include "cross_timer.h"

#include "omp.h"

/* Element counts of the A, B and C buffers (set by initialize_bench). */
static size_t elemsA;
static size_t elemsB;
static size_t elemsC;
/* Operand buffers shared by initialize_bench, xgemm_bench and finalize_bench. */
static fptype_mtx_t *A;
static fptype_mtx_t *B;
static fptype_mtx_c_t *C;
/* Timestamps sampled around each GEMM call. External linkage is kept as in
 * the original; NOTE(review): other translation units may reference these. */
cross_timer_t timeIn;
cross_timer_t timeOut;

// counts padded matrix dimension
#define padd_dim(dim)  (((dim+127ll)/128)*128+16)
// counts size of padded matrix
#define padd_size(dim_p, dim_q) (padd_dim(dim_p)*dim_q)

#ifndef KNL_SNC_MODE
/*
 * Allocates and fills the A, B and C operand buffers for the largest problem
 * the benchmark will run (maxM x maxN x maxK) and configures the OpenMP
 * thread count.
 *
 * The buffers are column-major with row counts padded by padd_dim(); their
 * sizes must cover the leading dimensions chosen in xgemm_bench():
 *   A: LDA = padd_dim(M), K columns   (transposed: padd_dim(K), M columns)
 *   B: LDB = padd_dim(K), N columns   (transposed: padd_dim(N), K columns)
 *   C: LDC = padd_dim(M), N columns
 *
 * num_threads > 0 forces that thread count. Otherwise OMP_NUM_THREADS is
 * honoured if set; if not, one thread per core (max procs / THREADS_PER_CORE).
 * Exits the process if any buffer cannot be allocated (bench_malloc is
 * assumed to return NULL on failure — TODO confirm).
 */
void initialize_bench(int maxM, int maxN, int maxK, int num_threads,
	const char transa, const char transb)
{
	/* op(A) is M x K: stored M x K, or K x M when transposed. */
	elemsA = (transa == 'N' ? padd_size(maxM, maxK) : padd_size(maxK, maxM));
	/* op(B) is K x N: stored K x N, or N x K when transposed. */
	elemsB = (transb == 'N' ? padd_size(maxK, maxN) : padd_size(maxN, maxK));
	/* C is M x N with LDC = padd_dim(M). */
	elemsC = padd_size(maxM, maxN);

	A = bench_malloc(elemsA * sizeof(fptype_mtx_t));
	B = bench_malloc(elemsB * sizeof(fptype_mtx_t));
	C = bench_malloc(elemsC * sizeof(fptype_mtx_c_t));
	if (A == NULL || B == NULL || C == NULL) {
		fprintf(stderr, "initialize_bench: failed to allocate matrices\n");
		exit(EXIT_FAILURE);
	}

	fill_matrix(A, elemsA);
	fill_matrix(B, elemsB);
	fill_matrix_c(C, elemsC);

	int max_threads = kmp_get_affinity_max_proc();
	int num_cores = max_threads / THREADS_PER_CORE;
	/* Fewer visible procs than THREADS_PER_CORE would give 0 cores, leading
	 * to omp_set_num_threads(0) and a division by zero below. */
	if (num_cores < 1)
		num_cores = 1;

	if (getenv("OMP_NUM_THREADS") == NULL || num_threads > 0) {
		if (num_threads <= 0) {
			num_threads = num_cores;
			printf("threads used: %d (autodetected)\n", num_threads);
		} else {
			printf("threads used: %d (set by user)\n", num_threads);
		}
		fflush(NULL);
		omp_set_num_threads(num_threads);
	} else {
		/* Record the environment-selected count so the threads/core
		 * report below uses the real value, not the unset argument. */
		num_threads = omp_get_max_threads();
		printf("threads used: %d (OMP_NUM_THREADS)\n", num_threads);
		fflush(NULL);
	}

	printf("threads/core: %d\n",
		num_threads > num_cores ? num_threads/num_cores : 1);
	fflush(NULL);
}


#else	// KNL_SNC_MODE

#include <mpi.h>

void initialize_bench(int maxM, int maxN, int maxK, int num_threads,
	const char transa, const char transb, int hbw_memory_numa_node,
	int rank, int world_size)
{
	int max_threads;
	int num_cores;

        elemsA = (transa == 'N' ? padd_size(maxM, maxK) : padd_size(maxK, maxN));
        elemsB = (transb == 'N' ? padd_size(maxK, maxN) : padd_size(maxN, maxK));
	elemsC = padd_size(maxN, maxM);

	A = bench_malloc(elemsA * sizeof(fptype_t), hbw_memory_numa_node);
	B = bench_malloc(elemsB * sizeof(fptype_t), hbw_memory_numa_node);
	C = bench_malloc(elemsC * sizeof(fptype_t), hbw_memory_numa_node);

	fill_matrix(A, elemsA);
	fill_matrix(B, elemsB);
	fill_matrix_c(C, elemsC);

	max_threads = kmp_get_affinity_max_proc();
	num_cores = max_threads / THREADS_PER_CORE;

	if (getenv("KMP_HW_SUBSET") == NULL || num_threads > 0) {
		if (num_threads <= 0) {
			num_threads = num_cores / world_size;
			printf("MPI rank %d  : using %d threads (autodetected)\n",
				rank, num_threads);
		} else {
			printf("MPI rank %d  : using %d threads (set by user)\n",
				rank, num_threads);
		}
		fflush(NULL);
		omp_set_num_threads(num_threads);
	} else {
		printf("MPI rank %d  : using %d threads (KMP_HW_SUBSET)\n",
			rank, omp_get_max_threads());
		fflush(NULL);
	}

	MPI_Barrier(MPI_COMM_WORLD);
	if (rank == 0) {
	        printf("threads/core: %d\n",
		        num_threads > num_cores ? num_threads/num_cores : 1);
		fflush(NULL);
	}
}


/* Reports whether all three operand buffers are currently allocated:
 * returns 1 when A, B and C are all non-NULL, 0 otherwise. */
int is_bench_initialized(){
	return (A != NULL) && (B != NULL) && (C != NULL);
}

#endif

/*
 * Releases the A, B and C buffers allocated by initialize_bench(). The same
 * byte counts used for allocation are passed to bench_free().
 *
 * Afterwards the globals are reset to NULL/0. This has two effects:
 * is_bench_initialized() (SNC builds) reports 0, and calling this function
 * a second time does not free the same memory twice. Buffers that were never
 * allocated (NULL) are skipped; NOTE(review): bench_free's handling of NULL
 * is not visible here.
 */
void finalize_bench()
{
	if (A != NULL)
		bench_free(A, elemsA * sizeof(fptype_mtx_t));
	if (B != NULL)
		bench_free(B, elemsB * sizeof(fptype_mtx_t));
	if (C != NULL)
		bench_free(C, elemsC * sizeof(fptype_mtx_c_t));

	A = NULL;
	B = NULL;
	C = NULL;
	elemsA = elemsB = elemsC = 0;
}

/*
 * Runs one GEMM on the preallocated global buffers,
 *   C = alpha * op(A) * op(B) + beta * C,
 * and returns the time elapsed between the two cross_timer samples taken
 * around the GEMM call.
 *
 * Leading dimensions use padd_dim() and must match the buffer sizes chosen
 * in initialize_bench() for the same transa/transb.
 *
 * In INTEGER builds, gemm_s16s16s32 is used with a fixed ('F') C offset.
 * The offset vector is allocated and filled before the first timer sample
 * and freed after the second. That way the setup is not counted as GEMM
 * time. Exits the process if the offset vector cannot be allocated.
 */
double xgemm_bench(char transa, char transb,
		int M, int N, int K, fptype_t alpha, fptype_t beta)
{
	int LDA = (transa == 'N' ? padd_dim(M) : padd_dim(K));
	int LDB = (transb == 'N' ? padd_dim(K) : padd_dim(N));
	int LDC = padd_dim(M);

#if defined(INTEGER)
	/* igemm zero offset: A/B offsets are 0; with offsetc 'F' the C offset
	 * is the single value co[0] (= 0) per the MKL gemm_s16s16s32 docs. */
	char offsetc = 'F';
	MKL_INT16 xo = 0;
	/* Allocate at least one element so M == 0 cannot yield a zero-size
	 * (possibly NULL) allocation. */
	size_t n_co = (M > 0) ? (size_t)M : 1;
	fptype_mtx_c_t *co = _mm_malloc(sizeof(fptype_mtx_c_t) * n_co,
			2*1024*1024);
	if (co == NULL) {
		fprintf(stderr, "xgemm_bench: failed to allocate offset vector\n");
		exit(EXIT_FAILURE);
	}
	for (size_t i = 0; i < n_co; ++i)
		co[i] = (fptype_mtx_c_t)i;

	cross_timer_sample(&timeIn);
	gemm_s16s16s32(&transa, &transb, &offsetc,
			&M, &N, &K, &alpha,
			A, &LDA, &xo,
			B, &LDB, &xo,
			&beta, C, &LDC, co);
	cross_timer_sample(&timeOut);

	_mm_free(co);
#else
	cross_timer_sample(&timeIn);
	xgemm(&transa, &transb, &M, &N, &K, &alpha,
		A, &LDA, B, &LDB, &beta, C, &LDC);
	cross_timer_sample(&timeOut);
#endif
	return cross_timer_diff(timeIn, timeOut);
}

