#include "tool.h"

/*
########################################################################
#
#  KEYS: DDP model optimization tool
#  Version: keysCo (auto sets "correlated" mitigations)
#  Copyright (C) 2009 Gregory Gay <gregoryg@csee.wvu.edu)
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.
########################################################################
*/

/* Option flags: cleared in main() and set to 1 when the matching
   command-line option (-c, -a, -d, -r, -t, -f) was supplied. */
int costFlag;
int attFlag;
int displayFlag;
int randomFlag;
int runFlag;
int maxFutileFlag;

/* User limits parsed from -c (cost upper limit) and -a (attainment lower limit). */
float costLimit;
float attLimit;

/* Random seed (-r) and number of model runs per era (-t). */
int Seed;
int RunTotal;

/* Per-mitigation vectors, indexed 1..TotalMitigations (slot 0 unused).
   FixedMitigations: -1 means "not fixed yet", otherwise the fixed value.
   mArray: the mitigation values handed to model(). */
float FixedMitigations[TotalMitigations+1];
float mArray[TotalMitigations+1];

/* Shared loop counters and run state. */
int mCounter;
int era;
int run;
int TotalInstances;
int instanceCounter;
int displayMode;

/* Observed cost/attainment bounds, widened by addInstance();
   sweetSpot() may replace MaxCost / MinAtt with the user limits. */
float MinCost;
float MaxCost;
float MinAtt;
float MaxAtt;

/* Large and tiny constants, set in main() to 1e20 and 1e-20. */
float infinity;
float small;

/* Mitigation fixed most recently by rankMitigations() and its value. */
int recentSetMitigation;
float recentSetMitigationStatus;

/* MinCost / MaxAtt at the last era that counted as an improvement. */
float LastMinCost;
float LastMaxAtt;

int main(int argc, char **argv)
{
	//this is required to set up the ddp model.
	setupModel();
	
	Seed = 0;

	char *costValue = NULL;
	char *attValue = NULL;
	char *displayValue = NULL;
	char *randomValue = NULL;
	char *runValue = NULL;
	char *futileValue = NULL;
	int c;
	int futile = 0;
	int MaxFutile= 1000000;

	displayFlag = 0;
	costFlag = 0;
	attFlag = 0;
	randomFlag = 0;
	runFlag = 0;
	maxFutileFlag = 0;
	
	infinity = pow(10,20);
	small = pow(10,-20);
	
	opterr = 0;

	while ((c = getopt(argc, argv, "a:c:d:f:h:r:t:")) != -1)
	{
		switch (c)
		{
			case 'a':
				attFlag = 1;
				attValue = optarg;
				break;
			case 'c':
				costFlag = 1;
				costValue = optarg;
				break;
			case 'd':
				displayFlag = 1;
				displayValue = optarg;
				break;
			case 'f':
				maxFutileFlag = 1;
				futileValue = optarg;
				break;
			case 'r':
				randomFlag = 1;
				randomValue = optarg;
				break;
			case 't':
				runFlag = 1;
				runValue = optarg;
				break;
			case 'h':
			case '?':
				printf("\nThe options must have the format -a AttainmentLowerLimit -c CostUpperLimit -d DisplayMode -f Futile -r Seed -t TotalRuns.\n");
				printf("\nAttainmentLowerLimit:\n\tThis is the lower limit that the user can set for attainment.\n");
				printf("\tThe tool tries to find a mitigation set that gives an attainment equal to or higher than this value.\n");
				printf("\tHowever, this is not guaranteed especially for limits set too high.\n");
				printf("\nCostUpperLimit:\n\tThis is the upper limit that the user can set for cost.\n");
				printf("\tThe tool tries to find a mitigation set that gives a cost equal to or lower than this value.\n");
				printf("\tHowever, this is not guaranteed especially for limits set too low.\n");
				printf("\nDisplay Mode:\n\tDisplay mode 1 prints the final results.\n");
				printf("\tDisplay mode 2 prints median and spread of each mitigation-fixing round.\n");
				printf("\tDisplay mode 3 is for debugging purposes only and prints everything needed.\n");
				printf("\tDisplay mode 4 is for debugging purposes only and prints cost and attainment of each run.\n");
				printf("\nFutile:\n\tThis is the number of trials before the tool decides that no improvements were seen.\n");
				printf("\tIt can be a number between 1 and the number of mitigations used.\n");
				printf("\tIt can be turned off by setting it the number of mitigations used. The default value is 10.\n");
				printf("\nSeed:\n\tSeed is used in random number generation. By default, it is generated using the system time.\n");
				printf("\nTotalRuns:\n\tThe number of total runs that is used internally. The default value is 100.\n");
				printf("\n");
				return(1);
				break;
			default:
				break;
		}
	}

	if (attFlag == 1 && attValue != NULL)
		attLimit = (float)atof(attValue);
	else
		attLimit = -infinity;

	if (costFlag == 1 && costValue != NULL)
		costLimit = (float)atof(costValue);
	else
		costLimit = infinity;

	if (randomFlag == 1 && randomValue != NULL)
		Seed = atoi(randomValue);
	else
		Seed = 1;

	if (runFlag == 1 && runValue != NULL)
		RunTotal = atoi(runValue);
	else
		RunTotal = 100;

	if (displayFlag == 1 && displayValue != NULL)
		displayMode = atoi(displayValue);
	else
		displayMode = 1;

	if (maxFutileFlag == 1 && futileValue != NULL)
		MaxFutile = atoi(futileValue);
	else
		MaxFutile = 1000000;

	float att, cost;

	float *Distance = new float[RunTotal+1];

	float **Data = new float*[RunTotal+1];
	for (int i = 0; i < RunTotal + 1; i++)
		Data[i] = new float[TotalMitigations+2];

	if (randomFlag == 1)
		srand(Seed);
	else
		srand((unsigned int)time(NULL));

	//set all mitigations to non-fixed value of -1
	for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
		FixedMitigations[mCounter] = -1;

	if (displayMode == 2)
		printf("Era, Mitigation Number, Mitigation Value, Median of Cost, Spread of Cost, Median of Attainment, Spread of Attainment\n");

	MinCost = infinity;
	MaxCost = -infinity;

	MinAtt = infinity;
	MaxAtt = -infinity;

	LastMinCost = MinCost;
	LastMaxAtt = MaxAtt;

	for (era = 1; era <= TotalMitigations; era++)
	{
		TotalInstances = 0;

		for (run = 1; run <= RunTotal; run++)
		{
			for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
			{
				//if in the previous runs the mitigation is fixed to a certain value (0 or 1) then use that value. otherwise, select it at random
				if (FixedMitigations[mCounter] == -1)
					mArray[mCounter] = selectValue(0,1);
				else
					mArray[mCounter] = FixedMitigations[mCounter];
				
			}	
			//find the cost and att using these mitigations
			model(&cost, &att, mArray);

			if (displayMode == 4)
				printf("%.1f,%.5f\n", cost, att);

			//store the current instance
			addInstance(cost, att, Data);
		}

		//find the sweet spot and distances from sweet spot for each distance
		sweetSpot(Data, Distance);
		
		//rank the mitigations and find the mitigation that should be fixed
		rankMitigations(Data, Distance);
		

		if (displayMode == 2)
			reportMedianAndSpread(Data);

//		printf("%d,futile:%d,%f,last:%f,%f,last:%f\n",era,futile,MinCost,LastMinCost,MaxAtt,LastMaxAtt);
		if (MinCost < LastMinCost && MaxAtt >= LastMaxAtt)
		{
			LastMinCost = MinCost;
			LastMaxAtt = MaxAtt;
			futile = 0;
		}
		else if (futile > MaxFutile)
		{
			if (displayMode == 1 || displayMode == 3)
				printf("Terminated at era: %d\n",era);
			era = TotalMitigations + 1;
			futile++;
		}
		else
			futile++;
	}

	//show the final results
	if (displayMode == 1 || displayMode == 3)
	{
		for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
		{
			mArray[mCounter] = FixedMitigations[mCounter];
			if (FixedMitigations[mCounter] == -1)
				mArray[mCounter] = 0;
		}
		model(&cost,&att,mArray);
		for (mCounter=1; mCounter<=TotalMitigations; mCounter++)
			printf("m[%d],",mCounter);
		printf("cost,attainment\n");
		for (mCounter=1; mCounter<=TotalMitigations; mCounter++)
			printf("%.0f,", mArray[mCounter]);
		printf("%.1f,%.5f\n",cost,att);
	}
}

/*
 * Display mode 2: print one CSV row for the current era. The row holds:
 *   - the era number;
 *   - the mitigation fixed most recently and its value;
 *   - the median and spread of cost over this era's instances;
 *   - the median and spread of attainment over the same instances.
 *
 * Data is indexed by instance from 1: Data[i][1] is cost and Data[i][2] is
 * attainment. The median/spread values come from findMedianAndSpread. They
 * are printed as 0 when no instance has been collected.
 */
void reportMedianAndSpread(float** Data)
{
	float costMedian = 0, costSpread = 0, attMedian = 0, attSpread = 0;

	if (TotalInstances > 0)
	{
		// Values live at indices 1..TotalInstances, so the scratch arrays need
		// TotalInstances + 1 slots (the old size overran by one element).
		float tempCostArray[TotalInstances + 1], tempAttArray[TotalInstances + 1];

		//copy cost and att (individually) so each can be sorted for median and spread
		for (instanceCounter = 1; instanceCounter <= TotalInstances; instanceCounter++)
		{
			tempCostArray[instanceCounter] = Data[instanceCounter][1];
			tempAttArray[instanceCounter] = Data[instanceCounter][2];
		}
		findMedianAndSpread(tempCostArray,TotalInstances,&costMedian,&costSpread);
		findMedianAndSpread(tempAttArray,TotalInstances,&attMedian,&attSpread);
	}

	printf("%d,%d,%.0f,%.5f,%.5f,%.5f,%.5f\n",era,recentSetMitigation,recentSetMitigationStatus,costMedian,costSpread,attMedian,attSpread);
}

/*
 * Compute the median and spread of a 1-indexed array.
 *
 * inputArray[1..size] holds the values; index 0 is ignored, matching the
 * 1-based layout used throughout this file. The input is not modified.
 *   *median: the nearest-rank median, i.e. the value at rank ceil(size/2).
 *   *spread: the nearest-rank upper quartile (rank ceil(3*size/4)) minus
 *            the median.
 * When size <= 0, both outputs are set to 0.
 */
void findMedianAndSpread(float inputArray[], int size, float *median, float *spread)
{
	if (size <= 0)
	{
		*median = 0;
		*spread = 0;
		return;
	}

	// size + 1 slots: values occupy indices 1..size (the old size overran by one).
	float tempArray[size + 1];
	int i, j;

	for (i = 1; i <= size; i++)
		tempArray[i] = inputArray[i];

	//insertion sort over indices 1..size
	for (i = 2; i <= size; i++)
	{
		float tempValue = tempArray[i];
		j = i;

		while ((j > 1) && (tempArray[j-1] > tempValue))
		{
			tempArray[j] = tempArray[j-1];
			j--;
		}
		tempArray[j] = tempValue;
	}

	// Nearest-rank positions on a 1-based array. When size is a multiple of 4
	// (e.g. the default 100 runs), these match the old size/2 and 3*size/4.
	// For other sizes, the old formulas picked ranks that were too low, and
	// for size 1 they read the unused slot 0.
	int medianRank = (size + 1) / 2;
	int upperRank = (3 * size + 3) / 4;

	*median = tempArray[medianRank];
	*spread = tempArray[upperRank] - tempArray[medianRank];
}

/*
 * Return the cut-off distance for the "best" instances.
 *
 * inputArray[1..size] holds the distances; index 0 is ignored, and the input
 * is not modified. The values are sorted ascending, and the function returns
 * the one at rank (int)(0.1 * size), clamped to at least 1. That is roughly
 * the 10th-percentile distance, and the smallest value when size < 10.
 * Returns 0 when size <= 0.
 */
float findBestDistance(float inputArray[], int size)
{
	if (size <= 0)
		return 0;

	// size + 1 slots: values occupy indices 1..size (the old size overran by one).
	float tempArray[size + 1];
	int i, j;

	for (i = 1; i <= size; i++)
		tempArray[i] = inputArray[i];

	//insertion sort over indices 1..size
	for (i = 2; i <= size; i++)
	{
		float tempValue = tempArray[i];
		j = i;

		while ((j > 1) && (tempArray[j-1] > tempValue))
		{
			tempArray[j] = tempArray[j-1];
			j--;
		}
		tempArray[j] = tempValue;
	}

	// For size < 10, (int)(0.1 * size) is 0, which is the unused, uninitialized
	// slot. Clamp to the first real rank.
	int rank = (int)(0.1 * size);
	if (rank < 1)
		rank = 1;

	return tempArray[rank];
}

/*
 * Choose the next mitigation (and its value) to fix, based on how often each
 * value occurs among the instances closest to the sweet spot.
 *
 * Instances split into two sets:
 *   - "best": Distance[i] <= the value returned by findBestDistance;
 *   - "rest": all other instances.
 * For every still-unfixed mitigation m and each value v in {0, 1}:
 *   best  = (#best instances with m == v) / TotalInstances
 *   rest  = (#rest instances with m == v) / TotalInstances
 *   score = best^2 / (best + rest)        (0 when both are 0)
 * The (mitigation, value) pair with the highest score is written into
 * FixedMitigations. Every other unfixed mitigation whose best-set count for
 * that value equals the winner's count is fixed to the same value.
 *
 * The winner is recorded in recentSetMitigation / recentSetMitigationStatus.
 * They are 0 / -1 if no mitigation was left unfixed.
 *
 * Data and Distance are indexed by instance from 1. Data[i][m + 2] holds the
 * value of mitigation m.
 */
void rankMitigations(float** Data, float* Distance)
{
	// Mitigations are indexed 1..TotalMitigations, so every per-mitigation
	// array needs TotalMitigations + 1 slots (the old sizes overran by one).
	float tempScoreOff[TotalMitigations + 1], tempScoreOn[TotalMitigations + 1];
	float tempBestFreqCount[TotalMitigations + 1][2], tempRestFreqCount[TotalMitigations + 1][2];

	for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
	{
		tempBestFreqCount[mCounter][0] = 0;
		tempBestFreqCount[mCounter][1] = 0;
		tempRestFreqCount[mCounter][0] = 0;
		tempRestFreqCount[mCounter][1] = 0;
	}

	float bestValue = findBestDistance(Distance,TotalInstances);

	for (instanceCounter = 1; instanceCounter <= TotalInstances; instanceCounter++)
	{
		if (displayMode == 3)
		{
			for (mCounter=1; mCounter<=TotalMitigations; mCounter++)
				printf("%.0f,",Data[instanceCounter][mCounter+2]);
			printf("%.3f,%.3f,\t", Data[instanceCounter][1],Data[instanceCounter][2]);
		}

		//if it is in the Best distance from the sweet spot, count the frequency of each mitigation (0 and 1) for the best instances
		if (Distance[instanceCounter] <= bestValue)
		{
			if (displayMode == 3) printf("%.3f best from %.3f\n", Distance[instanceCounter],bestValue);
			for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
			{
				//keep track of the mitigation's counts if it is not already fixed
				if (FixedMitigations[mCounter] == -1)
				{
					if (Data[instanceCounter][mCounter+2] == 0)
						tempBestFreqCount[mCounter][0]++;
					else if (Data[instanceCounter][mCounter+2] == 1)
						tempBestFreqCount[mCounter][1]++;
				}
			}
		}
		//else it is in the Rest distance from the sweet spot and so count the frequency of each mitigation (0 and 1) for the rest instances
		else
		{
			if (displayMode == 3) printf("%.3f rest from %.3f\n", Distance[instanceCounter],bestValue);
			for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
			{
				//keep track of the mitigation's counts if it is not already fixed
				if (FixedMitigations[mCounter] == -1)
				{
					if (Data[instanceCounter][mCounter+2] == 0)
						tempRestFreqCount[mCounter][0]++;
					else if (Data[instanceCounter][mCounter+2] == 1)
						tempRestFreqCount[mCounter][1]++;
				}
			}
		}
	}

	float maxScore = -infinity;
	int maxScoredMitigation = 0;
	float maxScoredMitigationStatus = -1;
	float best,rest;

	//normalize each frequency count by dividing it by the total number of instances and score each mitigation using the best^2/(best+rest) and keep min and max
	for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
	{
		//do this only if mitigation is not fixed already
		if (FixedMitigations[mCounter] == -1)
		{
			//find the score of the mitigation when it is off
			best = tempBestFreqCount[mCounter][0]/TotalInstances;
			rest = tempRestFreqCount[mCounter][0]/TotalInstances;

			if (best == 0 && rest == 0)
				tempScoreOff[mCounter] = 0;
			else
				tempScoreOff[mCounter] = pow(best,2)/(best+rest);

			if (displayMode == 3) printf( "m%d with best:%.3f and rest:%.3f\n",mCounter,best,rest);

			//keep its information if it is the max score seen so far
			if (tempScoreOff[mCounter] > maxScore)
			{
				maxScore = tempScoreOff[mCounter];
				maxScoredMitigation = mCounter;
				maxScoredMitigationStatus = 0;
			}

			//find the score of the mitigation when it is on
			best = tempBestFreqCount[mCounter][1]/TotalInstances;
			rest = tempRestFreqCount[mCounter][1]/TotalInstances;

			if (best == 0 && rest == 0)
				tempScoreOn[mCounter] = 0;
			else
				tempScoreOn[mCounter] = pow(best,2)/(best+rest);

			if (displayMode == 3) printf( "m%d with best:%.3f and rest:%.3f\n",mCounter,best,rest);

			//keep its information if it is the max score seen so far
			if (tempScoreOn[mCounter] > maxScore)
			{
				maxScore = tempScoreOn[mCounter];
				maxScoredMitigation = mCounter;
				maxScoredMitigationStatus = 1;
			}
			if (displayMode == 3) printf( "score of m%d 0:%.3f 1:%.3f\n",mCounter,tempScoreOff[mCounter],tempScoreOn[mCounter]);
		}
	}

	if (displayMode == 3) printf( "chosen mitigation is m%d with status %.0f has score %.3f\n",maxScoredMitigation,maxScoredMitigationStatus,maxScore);

	// Only fix something when an unfixed mitigation was actually scored.
	// Otherwise maxScoredMitigation is 0, which is not a real mitigation slot.
	if (maxScoredMitigation != 0)
	{
		//Set those with the same frequency count in "best"
		for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
		{
			if ((mCounter != maxScoredMitigation) && (FixedMitigations[mCounter] == -1))
			{
				if ((maxScoredMitigationStatus == 1) && (tempBestFreqCount[mCounter][1] == tempBestFreqCount[maxScoredMitigation][1]))
					FixedMitigations[mCounter] = 1;
				else if ((maxScoredMitigationStatus == 0) && (tempBestFreqCount[mCounter][0] == tempBestFreqCount[maxScoredMitigation][0]))
					FixedMitigations[mCounter] = 0;
			}
		}

		FixedMitigations[maxScoredMitigation] = maxScoredMitigationStatus;
	}

	recentSetMitigation = maxScoredMitigation;
	recentSetMitigationStatus = maxScoredMitigationStatus;
}

/*
 * Fill Distance[1..TotalInstances] with each instance's Euclidean distance
 * from the "sweet spot", i.e. normalized cost 0 and normalized attainment 1.
 *
 * Cost and attainment are min-max normalized using the global bounds.
 * When the user supplied -c or -a, MaxCost or MinAtt is first overwritten
 * with that limit. `small` keeps each denominator non-zero when the two
 * bounds coincide.
 */
void sweetSpot(float** Data, float* Distance)
{
	// User limits stand in for the observed lower-attainment / upper-cost bounds.
	if (attFlag == 1)
		MinAtt = attLimit;
	if (costFlag == 1)
		MaxCost = costLimit;

	if (displayMode == 3)
		printf("MIN and MAX %.3f,%.3f,%.3f,%.3f\n",MinCost,MaxCost,MinAtt,MaxAtt);

	for (instanceCounter = 1; instanceCounter <= TotalInstances; ++instanceCounter)
	{
		const float *row = Data[instanceCounter];
		float costNorm = (row[1] - MinCost)/(MaxCost - MinCost + small);
		float attNorm = (row[2] - MinAtt)/(MaxAtt - MinAtt + small);

		// distance to the ideal point (cost 0, attainment 1)
		Distance[instanceCounter] = pow(pow(costNorm, 2) + pow(attNorm - 1, 2), 0.5);
	}
}

/*
 * Return val1 or val2 with equal probability.
 * Each call consumes one value from rand(), scaled into [0, 1).
 */
int selectValue(int val1, int val2)
{
	const double draw = rand() / ((double)RAND_MAX + 1.0);
	return (draw < 0.5) ? val1 : val2;
}

/*
 * Append one model evaluation as the next row of Data (rows start at 1).
 *
 * Row layout:
 *   [1]                     = costVar
 *   [2]                     = attVar
 *   [3..TotalMitigations+2] = the current mArray values
 * The function increments TotalInstances and widens the global
 * MinCost/MaxCost/MinAtt/MaxAtt bounds to cover the new values.
 */
void addInstance(float costVar, float attVar, float** Data)
{
	float *row = Data[++TotalInstances];

	row[1] = costVar;
	row[2] = attVar;
	for (mCounter = 1; mCounter <= TotalMitigations; mCounter++)
		row[mCounter + 2] = mArray[mCounter];

	if (row[1] < MinCost)
		MinCost = row[1];
	if (row[1] > MaxCost)
		MaxCost = row[1];

	if (row[2] < MinAtt)
		MinAtt = row[2];
	if (row[2] > MaxAtt)
		MaxAtt = row[2];
}

/*
 * Return the smaller of two floats.
 * If val1 < val2 is false (equal values, or either value is NaN), val2 is returned.
 */
float minValue(float val1, float val2)
{
	return (val1 < val2) ? val1 : val2;
}
