/* SCE CONFIDENTIAL
 * PlayStation(R)3 Programmer Tool Runtime Library 475.001
 * Copyright (C) 2011 Sony Computer Entertainment Inc.
 * All Rights Reserved.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <cell/face.h>
#include <cell/spurs.h>

#include "memory.h"

#include <new>

#include "sample_face_util.h"
#include "TargetInfo.h"
#include "LocalSearchInfo.h"
#include "GlobalFaceSearch.h"


// Set to 1 to print per-face diagnostics from the #if VERVOSE_DEBUG sections
// to stderr. (The misspelling is the name those #if checks use.)
#define VERVOSE_DEBUG 0


// File-local helper macros.
// NOTE: these are function-like macros, so an argument may be evaluated twice;
// do not pass expressions with side effects. They may also clash with
// std::max/std::min if a standard header is included after this point.
#define max(a, b) (((a) > (b)) ? (a) : (b))
#define min(a, b) (((a) < (b)) ? (a) : (b))

////////////////////////////////////////////////
// Constructor
////////////////////////////////////////////////
GlobalFaceSearch::
GlobalFaceSearch()
{
	// Not usable until Initialize() succeeds.
	mInit = false;
	mDetecting = false;

	// Buffer pointers are assigned into the caller-supplied work area by
	// Initialize(); keep all four in a known state until then.
	mSaveImage = NULL;
	mSaveResult = NULL;
	mSaveStatus = NULL;
	mWorkingArea = NULL;

	// Geometry / capacity are set by Initialize(); zero them so they are never
	// read uninitialized.
	mWidth = 0;
	mHeight = 0;
	mRowstride = 0;
	mMaxTarget = 0;
}

////////////////////////////////////////////////
// Destructor
////////////////////////////////////////////////
GlobalFaceSearch::~GlobalFaceSearch()
{
	// Deliberately empty: this destructor does not call Finalize(), so the
	// SPURS taskset and any running detection task are only released by an
	// explicit Finalize() call made before destruction.
}

////////////////////////////////////////////////
// GetAlign
////////////////////////////////////////////////
int GlobalFaceSearch::
GetAlign(void)
{
	// Result is the strictest (largest) of the alignment requirements below.
	int strictest = GLOBAL_FACE_SEARCH_ALIGN;

	// SPURS taskset requirement.
	if (CELL_SPURS_TASKSET_ALIGN > strictest) {
		strictest = CELL_SPURS_TASKSET_ALIGN;
	}

	// Requirement reported by the sample face-utility task helpers.
	const int taskAlign = sampleFaceUtilTask2GetAlign();
	if (taskAlign > strictest) {
		strictest = taskAlign;
	}

	// At least 16 bytes for the CellFace... structures.
	if (16 > strictest) {
		strictest = 16;
	}

	return strictest;
}

////////////////////////////////////////////////
// GetWorkingMemorySize
////////////////////////////////////////////////
int GlobalFaceSearch::
GetWorkingMemorySize(int size, int width, int height, int rowstride, int maxTarget)
{
	// 'size' is the offset at which this object's area begins; each
	// sizeof_byte_align() call is given the running offset, presumably so the
	// alignment padding it adds matches the real placement (TODO confirm).
	// The order and alignments must mirror the layout in Initialize().
	int offset = size;

	// saved Y image: rowstride * height bytes, 128-byte aligned
	offset += sizeof_byte_align(128, offset, sizeof(unsigned char) * rowstride * height);

	// libface utility work area, 128-byte aligned
	offset += sizeof_byte_align(128, offset,
								CELL_FACE_UTIL_WORK_SIZE(width, height, rowstride));

	// one saved CellFaceDetectionResult per target, 16-byte aligned
	offset += sizeof_byte_align(16, offset, sizeof(CellFaceDetectionResult) * maxTarget);

	// one saved status int per target, 4-byte aligned
	offset += sizeof_byte_align(4, offset, sizeof(int) * maxTarget);

	// Report only the bytes this object needs, excluding the starting offset.
	const int required = offset - size;
	return required;
}

////////////////////////////////////////////////
// Initialize
////////////////////////////////////////////////
// Creates the SPURS taskset and lays out this object's buffers inside the
// caller-supplied work area, then prepares the libface parameter blocks.
//   work      : buffer the image/work/result/status buffers are carved from;
//               presumably sized via GetWorkingMemorySize() and aligned to
//               GetAlign() by the caller -- TODO confirm
//   spurs     : SPURS2 instance passed to sampleFaceUtilInitializeEx()
//   width, height, rowstride : geometry of the Y images given to ExecFaceSearch()
//   maxTarget : number of tracked-target slots
// Returns the number of bytes of 'work' consumed, GLOBAL_FACE_SEARCH_ERROR if
// already initialized or an argument is invalid, or the error code returned
// by sampleFaceUtilInitializeEx().
int GlobalFaceSearch::
Initialize(const void *work, const CellSpurs2 *spurs, 
		   int width, int height, int rowstride, int maxTarget)
{
	if (mInit) return GLOBAL_FACE_SEARCH_ERROR;

	// Every buffer is carved out of 'work' and the taskset lives on 'spurs'.
	// A negative maxTarget would turn sizeof(T) * maxTarget into a huge size.
	// Reject these before anything is created.
	if (work == NULL || spurs == NULL || maxTarget < 0) {
		return GLOBAL_FACE_SEARCH_ERROR;
	}
	
	unsigned char *p = (unsigned char*)work;

	////////////////////////////////////////
	// spurs taskset initialize
	////////////////////////////////////////
	const unsigned char priorities[8] = {1, 1, 1, 1, 1, 1, 1, 1};
	int ret = sampleFaceUtilInitializeEx((CellSpurs2*)spurs, &mSpursTaskset, priorities, 2);
	if (ret != CELL_OK) return ret;

	////////////////////////////////////////
	// memory assignment
	// Order and alignments must match GetWorkingMemorySize().
	////////////////////////////////////////
	p += assign_memory_byte_align(128, (void *)p, (void *)&mSaveImage,
								  sizeof(unsigned char) * rowstride * height);
	p += assign_memory_byte_align(128, (void *)p, (void *)&mWorkingArea,
								  CELL_FACE_UTIL_WORK_SIZE(width, height, rowstride));
	p += assign_memory_byte_align(16, (void *)p, (void *)&mSaveResult,
								  sizeof(CellFaceDetectionResult) * maxTarget);
	p += assign_memory_byte_align(4, (void *)p, (void *)&mSaveStatus,
								  sizeof(int) * maxTarget);


	/////////////////////////////////////
	//set parameter
	/////////////////////////////////////
	mWidth = width;
	mHeight = height;
	mRowstride = rowstride;
	mMaxTarget = maxTarget;
	// No target is known yet, so the first ExecFaceSearch() skips all
	// overlap checks.
	for(int i=0; i<maxTarget; i++)
		mSaveStatus[i] = TARGET_STATUS_EMPTY;
	
	for(int i=0; i<CELL_FACE_DETECT_NUM_MAX; i++) {
		mRegisteredTargetNum[i] = 0;
	}
	
	//set initialize flag
	mInit = true;
	mDetecting = false;

	/////////////////////////////////////
	// set libface parameter
	/////////////////////////////////////
	// Global search parameter set. The detection result buffer is registered
	// with CELL_FACE_DETECT_NUM_MAX entries.
	cellFaceUtilDetection3DParamInitialize(
		&mDetectParam,
		mSaveImage, mWidth, mHeight, mRowstride,
		mWorkingArea,
		mDetectResult, CELL_FACE_DETECT_NUM_MAX );

	// initialize parameter for parts
	cellFaceUtilPartsParamInitialize(
		&mPartsParam,
		mSaveImage, mWidth, mHeight, mRowstride,
		mWorkingArea,
		mDetectResult,
		mPartsResult, &mPosition );
	
	// initialize parameter for feature
	cellFaceUtilFeature2ParamInitialize(
		&mFeatureParam,
		mSaveImage, mWidth, mHeight, mRowstride,
		mWorkingArea,
		&mPosition, &mFeature);
	
	// Similarity parameters: the registered-feature array is filled in per
	// call by ExecFaceSearch().
	cellFaceUtilSimilarity2ParamInitialize(
		&mSimilarityParam,
		&mFeature, NULL, 0, 0, NULL);

	return (int)(p - (unsigned char*)work);
}


////////////////////////////////////////////////
// Finalize
////////////////////////////////////////////////
int GlobalFaceSearch::
Finalize(void)
{
	if (!mInit) return GLOBAL_FACE_SEARCH_ERROR;

	if(mDetecting) {
		int numFace; //dummy
		sampleFaceUtilDetection3DTaskEnd(&mSpursTaskset, &mDetectTask, &numFace);
	}
	
	int ret = sampleFaceUtilFinalizeEx(&mSpursTaskset);
	if (ret != CELL_OK) return ret;

	mSaveImage = NULL;
	mSaveResult = NULL;
	mWorkingArea = NULL;
	
	mInit = false;
	mDetecting = false;
	
	return GLOBAL_FACE_SEARCH_OK;
}


////////////////////////////////////////////////
// ExecFaceSearch
////////////////////////////////////////////////
int GlobalFaceSearch::
ExecFaceSearch(const unsigned char *image, //input Y image
			   const int *targetStatus, //input 
			   const FaceRegisterInfo *frInfo, //input
			   const LocalSearchInfo *lsInfo, //input
			   const CellFaceTrackerCallbackFunc callbackFunc, //input
			   int *numRegisteredNewFace //output
			   )
{
	if (!mInit) return GLOBAL_FACE_SEARCH_ERROR;
	
	int numFace = 0;
	*numRegisteredNewFace = 0;
	int ret = 0;

	//check process finish
	if(mDetecting) {
		//check finish
		ret = sampleFaceUtilDetection3DTaskTryEnd(&mSpursTaskset, &mDetectTask, &numFace);
		if (ret == CELL_SPURS_TASK_ERROR_AGAIN) { // still in process
			return ret;
		}
		else { // end of process
			mDetecting = false;
		}
	}

	for(int gnum = 0; gnum < numFace; gnum++) { //detection face loop
		int lnum;
		//////////////////////////////
		//overlap check
		//////////////////////////////
		for(lnum = 0; lnum < mMaxTarget; lnum++) { //local target face loop
			if(mSaveStatus[lnum] == TARGET_STATUS_NULL ||
			   mSaveStatus[lnum] == TARGET_STATUS_EMPTY) { //skip check target
				continue;
			}
			if (CheckOverlap(mDetectResult[gnum], mSaveResult[lnum],
							 GLOBAL_FACE_SEARCH_OVERLAP_THRE)) { //overlap
				break;
			}
		}
		if (lnum < mMaxTarget) { //overlap or NULL/EMPTY
			continue;
		}
		
		
#if VERVOSE_DEBUG
		fprintf(stderr, "#GlobalFaceSearch::ExecFaceSearch() face (%f, %f)-(%f, %f), (%f, %f, %f)\n",
				mDetectResult[gnum].faceX,
				mDetectResult[gnum].faceY,
				mDetectResult[gnum].faceW,
				mDetectResult[gnum].faceH,
				mDetectResult[gnum].faceRoll,
				mDetectResult[gnum].facePitch,
				mDetectResult[gnum].faceYaw
				);
		
		fprintf(stderr, "#GlobalFaceSearch::ExecFaceSearch() detect parts (%d) ...\n", gnum);
#endif //VERVOSE_DEBUG
		
		///////////////////////////
		//parts detection
		///////////////////////////
		mPartsParam.eaImage = (uintptr_t)mSaveImage;
		mPartsParam.faceX = mDetectResult[gnum].faceX;
		mPartsParam.faceY = mDetectResult[gnum].faceY;
		mPartsParam.faceW = mDetectResult[gnum].faceW;
		mPartsParam.faceH = mDetectResult[gnum].faceH;
		mPartsParam.faceRoll  = mDetectResult[gnum].faceRoll;
		mPartsParam.facePitch = mDetectResult[gnum].facePitch;
		mPartsParam.faceYaw   = mDetectResult[gnum].faceYaw;
		ret = sampleFaceUtilPartsTaskBegin(&mSpursTaskset, &mPartsTask, &mPartsParam);
		if(ret != CELL_FACE_OK) { //can't detect parts
			continue;
		}
		ret = sampleFaceUtilPartsTaskEnd(&mSpursTaskset, &mPartsTask);
		if (ret != CELL_FACE_OK) { //can't detect parts
			continue;
		}
		
#if VERVOSE_DEBUG
		for(int i=0; i<CELL_FACE_PARTS_NUM_MAX; i++) {
			fprintf(stderr, "#GlobalFaceSearch::ExecFaceSearch() PartsResult[%d] : %d, %f, %f, %f\n",
					i, 
					mPartsResult[i].partsId,
					mPartsResult[i].partsX,
					mPartsResult[i].partsY,
					mPartsResult[i].score
					);
		}

		fprintf(stderr, "#GlobalFaceSearch::ExecFaceSearch() detect feature (%d) ...\n", gnum);
#endif //VERVOSE_DEBUG
		
		///////////////////////////
		//feature detection
		///////////////////////////
		mFeatureParam.eaImage = (uintptr_t)mSaveImage;
		ret = sampleFaceUtilFeature2TaskBegin(&mSpursTaskset, &mFeatureTask, &mFeatureParam);
		if(ret != CELL_FACE_OK) { //can't detect feature
			continue;
		}
		ret = sampleFaceUtilFeature2TaskEnd(&mSpursTaskset, &mFeatureTask);
		if(ret != CELL_FACE_OK) { //can't detect feature
			continue;
		}
		
		///////////////////////////
		//check feature similarity
		///////////////////////////
		mSimilarityParam.eaRegFeatureArray =
			(uintptr_t)frInfo[0].GetFeature(); //all (include empty) target feature
		mSimilarityParam.strideRegFeature =
			sizeof(FaceRegisterInfo) - sizeof(CellFaceFeature2);
		mSimilarityParam.numRegFeature = mMaxTarget;
		mSimilarityParam.eaScoreResultArray = NULL;

		float simScore;
		int simID;
		
		ret = sampleFaceUtilSimilarity2TaskBegin(&mSpursTaskset, &mSimilarityTask, &mSimilarityParam);
		if (ret != CELL_FACE_OK) {
			continue;
		}
		ret = sampleFaceUtilSimilarity2TaskEnd(&mSpursTaskset, &mSimilarityTask, &simScore, &simID);
		if (ret != CELL_FACE_OK) {
			continue;
		}
		
#if VERVOSE_DEBUG
		fprintf(stderr, "#GlobalFaceSearch::ExecFaceSearch() similarity score = %f, id = %d\n",
				simScore, simID);
#endif //VERVOSE_DEBUG
		
		///////////////////////////
		//append new face result
		///////////////////////////
		if(simScore > GLOBAL_FACE_SEARCH_SIMILARITY_THRE) { //registered new face
			mRegisteredTargetNum[*numRegisteredNewFace] = simID;
			memcpy(&mGlobalDetectResult[*numRegisteredNewFace], &mDetectResult[gnum],
				   sizeof(CellFaceDetectionResult));
			*numRegisteredNewFace = *numRegisteredNewFace + 1;
		}
		else { //not registered new face
			//callback
			if(callbackFunc != NULL) { //if callback function is set
				callbackFunc((const CellFaceFeature2 *)&mFeature,
							  (const void *)&mDetectResult[gnum] //tentative
							  );
			}
		}
	} //detection face loop

	// save image to local buffer.
	memcpy(mSaveImage, image, sizeof(unsigned char) * mRowstride * mHeight);
	memcpy(mSaveStatus, targetStatus, sizeof(int) * mMaxTarget);
	for(int i=0; i<mMaxTarget; i++) {
		memcpy(&mSaveResult[i], lsInfo[i].GetDetectResult(), sizeof(CellFaceDetectionResult));
	}

	/////////////////////////////////
	//restart global search
	/////////////////////////////////
	mDetectParam.eaImage = (uintptr_t)mSaveImage;
	ret = sampleFaceUtilDetection3DTaskBegin(&mSpursTaskset, &mDetectTask, &mDetectParam);
	if (ret != CELL_FACE_OK) {
		return ret;
	}
	mDetecting = true;
	return GLOBAL_FACE_SEARCH_OK;
}

////////////////////////////////////////////////////////////////////////
//private
////////////////////////////////////////////////////////////////////////
/////////////////////////////////
// CheckOverlap
/////////////////////////////////
bool GlobalFaceSearch::
CheckOverlap(
	const CellFaceDetectionResult& a,
	const CellFaceDetectionResult& b,
	const float thresh
)
{
	// Offset of b's top-left corner relative to a's.
	const float dx = b.faceX - a.faceX;
	const float dy = b.faceY - a.faceY;

	// The x-extents intersect when b starts before a ends (a.faceW - dx > 0)
	// and a starts before b ends (b.faceW + dx > 0). For non-negative sizes,
	// the product of the two terms is positive exactly when both hold.
	// The same test applies to y.
	const bool overlapX = (a.faceW - dx) * (b.faceW + dx) > 0.0f;
	const bool overlapY = (a.faceH - dy) * (b.faceH + dy) > 0.0f;
	if (!overlapX || !overlapY) {
		return false; // disjoint or merely touching
	}

	// Edges of the intersection rectangle.
	const float left   = max(a.faceX, b.faceX);
	const float top    = max(a.faceY, b.faceY);
	const float right  = min(a.faceX + a.faceW, b.faceX + b.faceW);
	const float bottom = min(a.faceY + a.faceH, b.faceY + b.faceH);
	const float interArea = (right - left) * (bottom - top);

	// Overlapping when the intersection covers strictly more than 'thresh'
	// of either rectangle's area.
	const float ratioA = interArea / (a.faceW * a.faceH);
	const float ratioB = interArea / (b.faceW * b.faceH);
	return (ratioA > thresh) || (ratioB > thresh);
}



