/* File: IVV.java
 * Author: Jason Gookins, converted from an awk script written by Omid Jalali
 * Description: This class does the actual processing of the networks.
 */

import java.awt.Color;
import java.awt.Cursor;
import java.awt.Dimension;
import java.util.ArrayList;
import java.util.Random;
import javax.swing.SwingWorker;

public class IVV extends SwingWorker<Void, Void>
{
	/**********************
	 ** Global Constants **
	 **********************/

	/** Sentinel "larger than any score" value (Integer.MAX_VALUE expressed as a float). */
	private static final float INFINITY = 2147483647f;

	/** A random draw above this value makes a leaf bin reuse its best-so-far value instead of a fresh sample. */
	private static final float KEEP_BEST_THRESHOLD = 0.5f;



	/************************
	 ** Instance Variables **
	 ************************/

	// All nodes of the network. The per-node lists below are indexed by node.getIndex(),
	// so indices are assumed to be 0..n-1 matching list order (NOTE(review): confirm with Node's producer).
	private final ArrayList<Node> nodes;
	private final int totalNodes;
	private final int totalBins;
	private final int totalRuns;
	// One list per root node; the root is always element 0, followed by every node reachable from it.
	private final ArrayList<ArrayList<Node>> subGraphs;
	// Per-bin lowest score seen so far for the sub-graph currently being optimized.
	private final ArrayList<Float> minScore;
	// Per-node flag: node already handled during the current run.
	private final ArrayList<Boolean> processed;
	// Per-node flag: node's sub-graph is finished, so the node (possibly shared) is never recomputed.
	private final ArrayList<Boolean> processedFinal;
	// Per-node, per-bin best-so-far normalized value.
	private final ArrayList<ArrayList<Float>> bestBin;
	private final BeliefListener beliefListener;
	private final boolean minimizeChange;
	// One shared generator for leaf sampling (instead of allocating a new Random for every bin).
	private final Random rng = new Random();



	/**********************
	 ** Main Constructor **
	 **********************/

	/**
	 * Creates a worker that will optimize the given network when executed.
	 *
	 * @param nodes          all nodes of the network
	 * @param totalBins      number of bins in every distribution
	 * @param totalRuns      number of random runs performed per sub-graph
	 * @param minimizeChange forwarded unchanged to the listener when processing completes
	 * @param beliefListener notified via postProcessing(...) when the worker finishes
	 */
	public IVV(ArrayList<Node> nodes, int totalBins, int totalRuns, boolean minimizeChange, BeliefListener beliefListener)
	{
		this.nodes = nodes;
		this.totalNodes = nodes.size();
		this.totalBins = totalBins;
		this.totalRuns = totalRuns;
		this.minimizeChange = minimizeChange;
		this.subGraphs = new ArrayList<ArrayList<Node>>();
		this.minScore = new ArrayList<Float>();
		this.processed = new ArrayList<Boolean>();
		this.processedFinal = new ArrayList<Boolean>();
		this.bestBin = new ArrayList<ArrayList<Float>>();
		this.beliefListener = beliefListener;
	}



	/*************************
	 ** SwingWorker Methods **
	 *************************/

	/** Runs the whole optimization off the Event Dispatch Thread. */
	@Override
	public Void doInBackground()
	{
		setProgress(0);
		processNetwork(0);

		return null;
	}

	/**
	 * Hands the results to the listener.
	 * NOTE(review): an exception thrown by doInBackground is only observable through get(); it is not
	 * checked here, so the listener is still invoked with whatever results were computed.
	 */
	@Override
	public void done()
	{
		beliefListener.postProcessing(totalRuns, minimizeChange);
	}



	/**********************
	 ** Instance Methods **
	 **********************/

	/**
	 * Optimizes every sub-graph (in descending root-priority order), then computes per-node change
	 * and converts optimizable given distributions back to counts.
	 *
	 * @param progress ignored on entry; reused as a scratch variable for progress reporting
	 * @throws IllegalStateException if a sub-graph contains a dependency cycle among its internal nodes
	 */
	public void processNetwork(int progress)
	{
		initializeLists();

		normalizeGoalDists();

		//find all of the subgraphs based on root nodes (nodes that are not the children of any other node)
		findSubGraphs();

		//sort the roots based on their priority
		sortSubGraphs();

		int totalSubGraphs = subGraphs.size();

		for (int graphIndex = 0; graphIndex < totalSubGraphs; graphIndex++)
		{
			ArrayList<Node> subGraph = subGraphs.get(graphIndex);
			Node highestPriorityNode = findMaxPriorityNode(subGraph);

			resetBestBins();

			for (int run = 0; run < totalRuns; run++)
			{
				evaluateSubGraph(subGraph, run, highestPriorityNode);

				score(subGraph);

				progress = (int)(((float)(run + (totalRuns * graphIndex)) / (float)(totalRuns * totalSubGraphs)) * 100);
				setProgress(progress);
			}

			//After each root is completely processed, save the current nodes' information since the shared nodes can only be calculated once (when seen first)
			for (Node node : subGraph)
			{
				processedFinal.set(node.getIndex(), true);

				node.setNormalizedResultDistribution(bestBin.get(node.getIndex()));
			}
		}

		calculateNodeChange();

		denormalizeGivenResultDists();

		setProgress(100);
	}

	/** Clears the best-so-far bins of every not-yet-finalized node and resets the per-bin minimum scores. */
	private void resetBestBins()
	{
		for (int j = 0; j < totalNodes; j++)
		{
			if (!processedFinal.get(j))
			{
				ArrayList<Float> bins = bestBin.get(j);

				for (int i = 0; i < totalBins; i++)
				{
					bins.set(i, 0.0f);
				}
			}
		}

		for (int i = 0; i < totalBins; i++)
		{
			minScore.set(i, INFINITY);
		}
	}

	/**
	 * Performs one run over a sub-graph: leaves first, then each internal node once all of its
	 * children are done. Nodes finalized by an earlier sub-graph are marked processed but not recomputed.
	 *
	 * @throws IllegalStateException if a full pass makes no progress (the original loop would spin forever)
	 */
	private void evaluateSubGraph(ArrayList<Node> subGraph, int run, Node highestPriorityNode)
	{
		int nodesProcessed = 0;

		for (int i = 0; i < totalNodes; i++)
		{
			processed.set(i, false);
		}

		//process the leaves of the subgraph first
		for (Node node : subGraph)
		{
			if (node.getChildren().size() == 0 && !processed.get(node.getIndex()))
			{
				if (!processedFinal.get(node.getIndex()))
				{
					processLeafNode(node, run, highestPriorityNode);
				}

				processed.set(node.getIndex(), true);

				nodesProcessed++;
			}
		}

		//process the rest (all internal nodes including the root) for the subgraph
		while (nodesProcessed < subGraph.size())
		{
			int processedThisPass = 0;

			for (Node node : subGraph)
			{
				//if all of its children are processed, process it then
				if (!processed.get(node.getIndex()) && allChildrenProcessed(node))
				{
					if (!processedFinal.get(node.getIndex()))
					{
						processInternalNode(node);
					}

					processed.set(node.getIndex(), true);

					processedThisPass++;
				}
			}

			if (processedThisPass == 0)
			{
				throw new IllegalStateException("Cannot finish processing the sub-graph rooted at node "
						+ subGraph.get(0).getIndex() + ": no remaining node has all of its children processed (dependency cycle?)");
			}

			nodesProcessed += processedThisPass;
		}
	}

	/** Returns true when every child of the node has been processed in the current run. */
	private boolean allChildrenProcessed(Node node)
	{
		for (Node childNode : node.getChildren())
		{
			if (!processed.get(childNode.getIndex()))
			{
				return false;
			}
		}

		return true;
	}

	/**
	 * (Re)creates all per-node and per-bin working lists, and gives every node an all-zero normalized
	 * result distribution. Lists are cleared first so processNetwork can safely be called more than once.
	 */
	private void initializeLists()
	{
		subGraphs.clear();
		minScore.clear();
		processed.clear();
		processedFinal.clear();
		bestBin.clear();

		// One score slot per bin, independent of which node indices happen to exist.
		for (int j = 0; j < totalBins; j++)
		{
			minScore.add(0.0f);
		}

		for (Node node : nodes)
		{
			ArrayList<Float> bestBinList = new ArrayList<Float>();
			ArrayList<Float> resultDistribution = new ArrayList<Float>();

			for (int j = 0; j < totalBins; j++)
			{
				bestBinList.add(0.0f);
				resultDistribution.add(0.0f);
			}

			bestBin.add(bestBinList);
			node.setNormalizedResultDistribution(resultDistribution);
			processed.add(false);
			processedFinal.add(false);
		}
	}

	/** Stores a normalized copy of every goal distribution on its node. */
	private void normalizeGoalDists()
	{
		//Keep the goal dist normalized. This is not needed for given dist since it is normalized as needed.
		for (Node node : nodes)
		{
			if (node.getHasGoal() == 1)
			{
				ArrayList<Float> tempDistList = new ArrayList<Float>();

				for (int i = 0; i < totalBins; i++)
				{
					tempDistList.add((float)node.getGoalDistribution().get(i));
				}

				normalizeDist(tempDistList);

				node.setNormalizedGoalDistribution(tempDistList);
			}
		}
	}

	/** Scales the list in place so its entries sum to 1; an all-zero (sum 0) list is left untouched. */
	private void normalizeDist(ArrayList<Float> inputDist)
	{
		float sum = 0;

		for (Float bin : inputDist)
		{
			sum += bin;
		}

		if (sum != 0)
		{
			for (int i = 0; i < inputDist.size(); i++)
			{
				inputDist.set(i, inputDist.get(i) / sum);
			}
		}
	}

	/** Builds one sub-graph per root (a node without parents): the root followed by all its descendants. */
	private void findSubGraphs()
	{
		//find all the roots as the ones that are not a child of any other node
		for (Node node : nodes)
		{
			if (node.getParents().size() == 0)
			{
				ArrayList<Node> subGraph = new ArrayList<Node>();

				subGraph.add(node);
				subGraphs.add(subGraph);
			}
		}

		for (ArrayList<Node> subGraph : subGraphs)
		{
			findAllChildren(subGraph.get(0), subGraph);
		}
	}

	/** Depth-first appends every not-yet-listed descendant of root to subGraph (each node at most once). */
	private void findAllChildren(Node root, ArrayList<Node> subGraph)
	{
		for (Node childNode : root.getChildren())
		{
			if (!subGraph.contains(childNode))
			{
				subGraph.add(childNode);

				findAllChildren(childNode, subGraph);
			}
		}
	}

	/** Stable insertion sort of the sub-graphs by descending root priority. */
	private void sortSubGraphs()
	{
		for (int i = 1; i < subGraphs.size(); i++)
		{
			ArrayList<Node> value = subGraphs.get(i);
			int j = i - 1;

			// strict "<" keeps equal-priority sub-graphs in their original order
			while (j >= 0 && subGraphs.get(j).get(0).getPriority() < value.get(0).getPriority())
			{
				subGraphs.set(j + 1, subGraphs.get(j));

				j = j - 1;
			}

			subGraphs.set(j + 1, value);
		}
	}

	/**
	 * Picks the node whose goal distribution drives leaf sampling: the highest-priority node (last one wins
	 * ties). Falls back to the root when the best priority is 0 (unless that node has index 0), or when no
	 * node has a non-negative priority.
	 */
	private Node findMaxPriorityNode(ArrayList<Node> subGraph)
	{
		int tempMax = 0;
		Node tempNode = null;

		for (Node node : subGraph)
		{
			if (node.getPriority() >= tempMax)
			{
				tempMax = node.getPriority();
				tempNode = node;
			}
		}

		// every priority was negative: use the root instead of returning null
		if (tempNode == null)
		{
			return subGraph.get(0);
		}

		if (tempNode.getPriority() == 0 && tempNode.getIndex() != 0)
		{
			tempNode = subGraph.get(0);
		}

		return tempNode;
	}

	/**
	 * Sets a leaf's normalized result distribution: its given distribution when that is "fixed" (or on the
	 * first run), otherwise a randomly generated one.
	 */
	private void processLeafNode(Node leafNode, int run, Node highestPriorityNode)
	{
		ArrayList<Float> tempDistList = new ArrayList<Float>();

		if (leafNode.getHasGiven() == 1 && ("fixed".equals(leafNode.getGivenDistMutability()) || run == 0))
		{
			for (int i = 0; i < totalBins; i++)
			{
				tempDistList.add((float)leafNode.getGivenDistribution().get(i));
			}
		}
		else
		{
			tempDistList = generateDist(leafNode, run, highestPriorityNode);
		}

		normalizeDist(tempDistList);

		leafNode.setNormalizedResultDistribution(tempDistList);
	}

	/**
	 * Samples an (unnormalized) leaf distribution. After the first run each bin keeps its best-so-far value
	 * with probability ~1/2; otherwise it draws uniformly from [0, 2 * goal) of the highest-priority node
	 * (0 when that node has no goal).
	 */
	private ArrayList<Float> generateDist(Node leafNode, int run, Node highestPriorityNode)
	{
		ArrayList<Float> leafNodeDistribution = new ArrayList<Float>();

		for (int i = 0; i < totalBins; i++)
		{
			if (rng.nextFloat() > KEEP_BEST_THRESHOLD && run > 0)
			{
				leafNodeDistribution.add(bestBin.get(leafNode.getIndex()).get(i));
			}
			else
			{
				float bound = 0.0f;

				if (highestPriorityNode.getHasGoal() == 1)
				{
					bound = highestPriorityNode.getNormalizedGoalDistribution().get(i) * 2;
				}

				leafNodeDistribution.add(rng.nextFloat() * bound);
			}
		}

		return leafNodeDistribution;
	}

	/**
	 * Combines the children's normalized results according to the node's operation ("null": first child,
	 * "not": 1 - first child, "and": per-bin minimum, "or": per-bin maximum, anything else: all zeros) and
	 * stores the normalized outcome.
	 */
	private void processInternalNode(Node internalNode)
	{
		//Having a flag for internal nodes is meaningless since we have to be able to use distributions from child nodes.
		ArrayList<Node> children = new ArrayList<Node>();

		for (Node childNode : internalNode.getChildren())
		{
			children.add(childNode);
		}

		String operation = internalNode.getOperation();
		ArrayList<Float> tempDistList = new ArrayList<Float>();

		for (int i = 0; i < totalBins; i++)
		{
			float value = 0.0f;

			if ("null".equals(operation))
			{
				value = children.get(0).getNormalizedResultDistribution().get(i);
			}
			else if ("not".equals(operation))
			{
				value = 1.0f - children.get(0).getNormalizedResultDistribution().get(i);
			}
			else if ("and".equals(operation))
			{
				value = INFINITY;

				for (Node childNode : children)
				{
					value = Math.min(value, childNode.getNormalizedResultDistribution().get(i));
				}
			}
			else if ("or".equals(operation))
			{
				value = -INFINITY;

				for (Node childNode : children)
				{
					value = Math.max(value, childNode.getNormalizedResultDistribution().get(i));
				}
			}

			tempDistList.add(value);
		}

		normalizeDist(tempDistList);

		internalNode.setNormalizedResultDistribution(tempDistList);
	}

	/**
	 * Scores the current run per bin as sum over goal nodes of |goal - result| * goal. Whenever a bin beats
	 * its best score so far, that bin's values of every sub-graph node are recorded in bestBin; finally each
	 * node's bestBin is normalized.
	 */
	private void score(ArrayList<Node> subGraph)
	{
		ArrayList<Float> tempScore = new ArrayList<Float>();

		for (int i = 0; i < totalBins; i++)
		{
			tempScore.add(0.0f);
		}

		for (Node node : subGraph)
		{
			if (node.getHasGoal() == 1)
			{
				for (int i = 0; i < totalBins; i++)
				{
					float goalValue = node.getNormalizedGoalDistribution().get(i);
					float distance = Math.abs(goalValue - node.getNormalizedResultDistribution().get(i));

					tempScore.set(i, tempScore.get(i) + distance * goalValue);
				}
			}
		}

		for (int i = 0; i < totalBins; i++)
		{
			if (tempScore.get(i) < minScore.get(i))
			{
				minScore.set(i, tempScore.get(i));

				for (Node node : subGraph)
				{
					bestBin.get(node.getIndex()).set(i, node.getNormalizedResultDistribution().get(i));
				}
			}
		}

		//normalize the bestBin (in place)
		for (Node node : subGraph)
		{
			normalizeDist(bestBin.get(node.getIndex()));
		}
	}

	/**
	 * Records, as an integer percentage, how far each goal (or given) node's result drifted from its
	 * normalized goal (or given) distribution; goal nodes are also colored from green (no change) to red.
	 */
	private void calculateNodeChange()
	{
		for (Node node : nodes)
		{
			if (node.getHasGoal() == 1)
			{
				float change = 0.0f;
				float goal = 0.0f;

				for (int i = 0; i < totalBins; i++)
				{
					change += Math.abs(node.getNormalizedGoalDistribution().get(i) - node.getNormalizedResultDistribution().get(i));
					goal += node.getNormalizedGoalDistribution().get(i);
				}

				// an all-zero goal would otherwise produce NaN/Infinity
				if (goal != 0)
				{
					change /= goal;
				}

				node.setChange((int)(change * 100));

				// change can exceed 1 (up to 2 for two normalized distributions); Color rejects components outside [0, 1]
				float shade = Math.min(1.0f, Math.max(0.0f, change));

				node.setColor(new Color(shade, 1.0f - shade, 0.0f));
			}
			else if (node.getHasGiven() == 1)
			{
				float change = 0.0f;
				float given = 0.0f;
				ArrayList<Float> tempGivenDist = new ArrayList<Float>();

				for (Integer bin : node.getGivenDistribution())
				{
					tempGivenDist.add((float)bin);
				}

				normalizeDist(tempGivenDist);

				for (int i = 0; i < totalBins; i++)
				{
					change += Math.abs(tempGivenDist.get(i) - node.getNormalizedResultDistribution().get(i));
					given += tempGivenDist.get(i);
				}

				// an all-zero given distribution would otherwise produce NaN/Infinity
				if (given != 0)
				{
					change /= given;
				}

				node.setChange((int)(change * 100));
			}
		}
	}

	/**
	 * For given nodes whose mutability is "opt", rescales the normalized result back to counts using the
	 * total of the given distribution (rounded half up).
	 */
	private void denormalizeGivenResultDists()
	{
		for (Node node : nodes)
		{
			if (node.getHasGiven() == 1 && "opt".equals(node.getGivenDistMutability()))
			{
				ArrayList<Integer> resultDistribution = new ArrayList<Integer>();
				int sum = 0;

				for (Integer bin : node.getGivenDistribution())
				{
					sum += bin;
				}

				for (Float bin : node.getNormalizedResultDistribution())
				{
					resultDistribution.add((int)((bin * sum) + 0.5f));
				}

				node.setResultDistribution(resultDistribution);
			}
		}
	}
}
