/*
 * This source was translated from a C++ implementation found in
 * the MaART (Music and Audio Retrieval Tools) software package.
 * MaART can be found on sourceforge at http://maart.sourceforge.net
 * 
 * Translated by Andrew Matheny 
 */

import java.lang.Double;
import java.lang.Integer;
import java.lang.Math;
import java.lang.String;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Locale;
import java.util.Scanner;


/**
 * Command-line tool that reads an ARFF data set from standard input, maps
 * every instance to {@code k} coordinates with the FastMap algorithm of
 * Faloutsos and Lin, and writes the reduced data set to standard output.
 *
 * <p>Typical use is the sequence performed by {@link #main(String[])}:
 * {@link #processCmdLine(String[])}, {@link #loadData()}, allocation of
 * {@link #_X} and {@link #_pivotArray}, {@link #doMap(int, ArrayList, int)}
 * and finally {@link #printData()}.
 */
public class FastMap
{
  private static final int EUCLIDEAN = 0;
  private static final int SUM = 1;
  private static final int BASIC = 2;

  /** Number of decimals written for every coordinate by {@link #printData()}. */
  private static final int OUTPUT_DECIMALS = 4;

  /** Upper bound on the refinement rounds of the pivot search. */
  private static final int PIVOT_ITERATIONS = 5;

  /**
   * Fixed number symbols for the output. The values are written into a
   * comma-separated file, so the decimal separator must not follow the
   * platform locale (a comma-decimal locale would corrupt the rows).
   */
  private static final DecimalFormatSymbols OUTPUT_SYMBOLS = new DecimalFormatSymbols(Locale.US);

  /** Reduced coordinates: one row per instance, one column per output dimension. */
  public double[][] _X;

  /**
   * Pivot pairs picked by {@link #doMap(int, ArrayList, int)}. The pair chosen
   * by a call with {@code k} dimensions remaining is stored in row {@code k-1}.
   */
  public int[][] _pivotArray;

  /** Number of output dimensions, set by {@link #processCmdLine(String[])}. */
  public int _k;

  private int _numInstances;
  private boolean _keepClass = true;

  /** The last {@code @attribute} line of the input, re-emitted as the class declaration. */
  private String _classString = "";

  /** Input lines before {@code @data} that are not attribute declarations. */
  private ArrayList<String> _header = new ArrayList<String>();

  /** Class label (last field) of every accepted data row, in row order. */
  private ArrayList<String> _classes = new ArrayList<String>();

  private int _distanceFn;

  private Scanner in;

  /** Cached formatters for 0 to 5 decimals, indexed by the number of decimals. */
  private final DecimalFormat[] floatFormatters = new DecimalFormat[]{
    new DecimalFormat("###0", OUTPUT_SYMBOLS),
    new DecimalFormat("###0.0", OUTPUT_SYMBOLS),
    new DecimalFormat("###0.00", OUTPUT_SYMBOLS),
    new DecimalFormat("###0.000", OUTPUT_SYMBOLS),
    new DecimalFormat("###0.0000", OUTPUT_SYMBOLS),
    new DecimalFormat("###0.00000", OUTPUT_SYMBOLS)
  };

  /**
   * Program entry point: parses the arguments, loads the data from standard
   * input, runs the mapping and prints the result. Prints the help text and
   * returns when the arguments are not usable.
   *
   * @param args command-line arguments, see {@link #printHelp()}
   */
  public static void main(String[] args)
  {
    FastMap fm = new FastMap();

    if (!fm.processCmdLine(args))
    {
      fm.printHelp();
      return;
    }

    ArrayList<ArrayList<Double>> objects = fm.loadData();

    fm._X = new double[objects.size()][fm._k];
    fm._pivotArray = new int[fm._k][2];

    fm.doMap(fm._k, objects, 0);

    fm.printData();
  }

  /**
   * Formats a value in plain (non-scientific) notation with a fixed number
   * of decimals, always using '.' as the decimal separator.
   *
   * @param value       the number to format
   * @param numDecimals the number of decimals; negative values are treated as 0
   * @return the formatted number
   */
  private String formatFloatNonSci(double value, int numDecimals)
  {
    int decimals = Math.max(0, numDecimals);
    if (decimals < floatFormatters.length)
    {
      return floatFormatters[decimals].format(value);
    }

    // Same shape as the cached patterns: no grouping, fixed decimals.
    StringBuilder pattern = new StringBuilder("###0.");
    for (int i = 0; i < decimals; i++)
    {
      pattern.append('0');
    }
    return new DecimalFormat(pattern.toString(), OUTPUT_SYMBOLS).format(value);
  }

  /**
   * Parses the command line. Recognised arguments (case-insensitive) are
   * {@code -k <n>}, {@code -d <euclidean|sum_of_squares|basic>} and
   * {@code -nc}. The distance function is reset to euclidean on every call.
   *
   * @param args the raw command-line arguments
   * @return {@code true} only when every argument was recognised and a
   *         dimension count of at least 1 was supplied with {@code -k};
   *         {@code false} for a missing or non-numeric value, an unknown
   *         flag or an unknown distance name
   */
  public boolean processCmdLine(String[] args)
  {
    _distanceFn = EUCLIDEAN;
    boolean dimensionGiven = false;

    for (int i = 0; i < args.length; i++)
    {
      String flag = args[i];

      if (flag.equalsIgnoreCase("-k"))
      {
        if (++i >= args.length)
        {
          return false; // -k without a value
        }
        try
        {
          int requested = Integer.parseInt(args[i]);
          if (requested < 1)
          {
            return false; // would lead to a negative or empty coordinate array
          }
          _k = requested;
          dimensionGiven = true;
        }
        catch (NumberFormatException ex)
        {
          return false;
        }
      }
      else if (flag.equalsIgnoreCase("-d"))
      {
        if (++i >= args.length)
        {
          return false; // -d without a value
        }
        String name = args[i];
        if (name.equalsIgnoreCase("euclidean"))
        {
          _distanceFn = EUCLIDEAN;
        }
        else if (name.equalsIgnoreCase("sum_of_squares"))
        {
          _distanceFn = SUM;
        }
        else if (name.equalsIgnoreCase("basic"))
        {
          _distanceFn = BASIC;
        }
        else
        {
          return false; // do not silently fall back to euclidean
        }
      }
      else if (flag.equalsIgnoreCase("-nc"))
      {
        _keepClass = false;
      }
      else
      {
        return false;
      }
    }

    return dimensionGiven;
  }

  /** Prints the usage text to standard output. */
  public void printHelp()
  {
    System.out.println("\n" +
                       "============ FastMap ============== \n" +
                       "\n" +
                       "This program accepts an arff file from std:in and reduces its \n" +
                       "dimensionality to k using the fastMap algorithm presented by \n" +
                       "Faloutsos and Lin. \n \n" +
                       "By default this program will assume the last column is the class, will \n" +
                       "ignore this when computing the reduced space, and append it to the final list\n" +
                       "to disable this, give the -nc flag" +
                       "\n" +
                       "\n" +
                       "USAGE: \n" +
                       "\t -k <intended number of dimensions> \n" +
                       "\t -d <distance function> (euclidean <default>, basic, sum_of_squares)\n" +
                       "\t -nc Assume no class is given and compute reduced space on all fields");
  }

  /**
   * Writes the reduced data set to standard output: the copied header lines,
   * one {@code real} attribute per output dimension, the class declaration
   * (unless {@code -nc} was given), {@code @data}, and one row per instance
   * with four decimals per coordinate followed by the class label.
   */
  public void printData()
  {
    for (String headerLine : _header)
    {
      System.out.println(headerLine);
    }
    for (int i = 0; i < _k; i++)
    {
      System.out.println("@attribute " + attributeName(i) + " real");
    }
    if (_keepClass)
    {
      System.out.println(_classString);
    }
    System.out.println("@data");

    for (int i = 0; i < _numInstances; i++)
    {
      StringBuilder row = new StringBuilder();
      for (int j = 0; j < _k; j++)
      {
        if (j > 0)
        {
          row.append(", ");
        }
        row.append(formatFloatNonSci(_X[i][j], OUTPUT_DECIMALS));
      }
      if (_keepClass)
      {
        row.append(", ").append(_classes.get(i));
      }
      System.out.println(row);
    }
  }

  /**
   * Builds the name of the output attribute at the given position:
   * {@code a} to {@code z} for the first 26, then {@code aa}, {@code ab}, ...
   * so that names stay alphabetic for any number of dimensions.
   *
   * @param index zero-based attribute position
   * @return the attribute name
   */
  private static String attributeName(int index)
  {
    StringBuilder name = new StringBuilder();
    int rest = index;
    do
    {
      name.insert(0, (char) ('a' + rest % 26));
      rest = rest / 26 - 1;
    }
    while (rest >= 0);
    return name.toString();
  }

  /**
   * Reads an ARFF data set from standard input.
   *
   * <p>Before {@code @data}: blank lines are skipped, {@code @attribute}
   * lines (any letter case) are dropped except that the last one is
   * remembered as the class declaration, and every other line is kept as
   * header. After {@code @data}: blank lines and {@code %} comment lines are
   * skipped, each remaining line is split on commas, and rows containing a
   * missing value ({@code ?} or {@code NaN}) among the numeric fields are
   * discarded. Unless {@code -nc} was given, the last field of a row is
   * stored as its class label and not parsed as a number.
   *
   * @return the numeric fields of every accepted row, in input order
   * @throws NumberFormatException if a numeric field cannot be parsed
   */
  public ArrayList<ArrayList<Double>> loadData()
  {
    ArrayList<ArrayList<Double>> objects = new ArrayList<ArrayList<Double>>();
    _header = new ArrayList<String>();
    _classes.clear();
    _classString = "";
    _numInstances = 0;

    in = new Scanner(System.in);
    boolean foundData = false;
    String lastAttributeLine = "";

    while (in.hasNextLine())
    {
      String line = in.nextLine();
      String trimmed = line.trim();

      if (trimmed.length() == 0)
      {
        continue;
      }

      if (!foundData)
      {
        if (trimmed.equalsIgnoreCase("@data"))
        {
          foundData = true;
          _classString = lastAttributeLine;
        }
        else if (trimmed.regionMatches(true, 0, "@attribute", 0, "@attribute".length()))
        {
          lastAttributeLine = line;
        }
        else
        {
          _header.add(line);
        }
        continue;
      }

      if (trimmed.charAt(0) == '%')
      {
        continue; // comment inside the data section
      }

      String[] fields = trimmed.split(",");
      if (fields.length == 0)
      {
        continue; // the line held nothing but separators
      }

      int featureCount = _keepClass ? fields.length - 1 : fields.length;
      ArrayList<Double> instance = new ArrayList<Double>();
      if (!parseFeatures(fields, featureCount, instance))
      {
        continue; // row has a missing value
      }

      objects.add(instance);
      if (_keepClass)
      {
        _classes.add(fields[fields.length - 1].trim());
      }
      _numInstances++;
    }

    return objects;
  }

  /**
   * Parses the first {@code count} fields of a data row as numbers.
   *
   * @param fields the comma-separated fields of the row
   * @param count  how many leading fields are numeric
   * @param target receives the parsed values
   * @return {@code false} as soon as a field is {@code ?} or {@code NaN}
   *         (surrounding whitespace and letter case ignored), else {@code true}
   */
  private static boolean parseFeatures(String[] fields, int count, ArrayList<Double> target)
  {
    for (int i = 0; i < count; i++)
    {
      String token = fields[i].trim();
      if (token.equals("?") || token.equalsIgnoreCase("NaN"))
      {
        return false;
      }
      target.add(Double.parseDouble(token));
    }
    return true;
  }

  /**
   * Computes {@code k} coordinates for every object, filling columns
   * {@code column} to {@code column + k - 1} of {@link #_X}. Each pass picks
   * a pivot pair, projects every object onto the line through the pivots and
   * then recurses on the next column with the residual distance.
   *
   * <p>{@link #_X} and {@link #_pivotArray} must have been allocated by the
   * caller. Nothing is done when {@code k} is not positive or there are no
   * objects.
   *
   * @param k       number of coordinates still to compute
   * @param objects the feature vectors, one per instance
   * @param column  the column of {@link #_X} to fill in this pass
   */
  public void doMap(int k, ArrayList<ArrayList<Double>> objects, int column)
  {
    if (k <= 0 || objects.isEmpty())
    {
      return;
    }

    int[] a = new int[1];
    int[] b = new int[1];
    chooseDistantObjects(objects, a, b, column);

    _pivotArray[k-1][0] = a[0];
    _pivotArray[k-1][1] = b[0];

    double dab = fmDist(objects, a[0], b[0], column);

    if (dab == 0.0)
    {
      // The pivots coincide, so no residual distance was found: every
      // remaining coordinate is 0.
      for (int i = 0; i < objects.size(); i++)
      {
        for (int n = 0; n < k; n++)
        {
          _X[i][column+n] = 0.0;
        }
      }
      return;
    }

    for (int i = 0; i < objects.size(); i++)
    {
      double dai = fmDist(objects, a[0], i, column);
      double dbi = fmDist(objects, b[0], i, column);

      // Cosine-law projection of object i onto the line through the pivots.
      _X[i][column] = ((dai*dai) + (dab*dab) - (dbi*dbi)) / (dab * 2.0);
    }

    // Continue on the hyper-plane perpendicular to the pivot line; fmDist
    // supplies the distance between the projections for the next column.
    doMap(k-1, objects, column+1);
  }

  /**
   * Heuristic search for two objects that are far apart. Starts from object
   * 0 (no randomness) and alternates "farthest from b" / "farthest from a"
   * for at most {@link #PIVOT_ITERATIONS} rounds, stopping as soon as a
   * round does not increase the pivot distance.
   *
   * @param objects the feature vectors
   * @param a       one-element array receiving the first pivot index
   * @param b       one-element array receiving the second pivot index
   * @param column  the column being computed (selects the residual distance)
   */
  private void chooseDistantObjects(ArrayList<ArrayList<Double>> objects, int[] a, int[] b, int column)
  {
    b[0] = 0;

    double bestSoFar = 0.0;
    for (int round = 0; round < PIVOT_ITERATIONS; round++)
    {
      a[0] = farthestFrom(objects, b[0], column);
      b[0] = farthestFrom(objects, a[0], column);

      // When no object is farther than 0 the pivots coincide.
      double reach = (b[0] == a[0]) ? 0.0 : fmDist(objects, a[0], b[0], column);

      if (reach > bestSoFar)
      {
        bestSoFar = reach;
      }
      else
      {
        break; // same pair selected again, nothing left to gain
      }
    }
  }

  /**
   * Finds the object with the largest distance from {@code origin}.
   *
   * @return the index of the first object at the largest distance, or
   *         {@code origin} itself when no distance is greater than 0
   */
  private int farthestFrom(ArrayList<ArrayList<Double>> objects, int origin, int column)
  {
    int farthest = origin;
    double maxDistance = 0.0;
    for (int n = 0; n < objects.size(); n++)
    {
      double distance = fmDist(objects, origin, n, column);
      if (distance > maxDistance)
      {
        farthest = n;
        maxDistance = distance;
      }
    }
    return farthest;
  }

  /**
   * Distance between two objects as seen when computing {@code column}: the
   * selected base distance for column 0, otherwise the base distance with
   * the contribution of every already computed coordinate removed.
   *
   * <p>The squared residual is clamped at 0 after each column. Rounding
   * errors or a base distance that does not embed in Euclidean space can
   * make it negative, and the square root would then be NaN and end up in
   * the output coordinates.
   */
  private double fmDist(ArrayList<ArrayList<Double>> objects, int a, int b, int column)
  {
    double base = baseDistance(objects.get(a), objects.get(b));
    if (column == 0)
    {
      return base;
    }

    double squared = base * base;
    for (int c = 0; c < column; c++)
    {
      double delta = _X[a][c] - _X[b][c];
      squared = Math.max(0.0, squared - delta * delta);
    }
    return Math.sqrt(squared);
  }

  /** Applies the distance function selected on the command line (euclidean when unset). */
  private double baseDistance(ArrayList<Double> a, ArrayList<Double> b)
  {
    switch (_distanceFn)
    {
      case BASIC:
        return basicDistance(a, b);
      case SUM:
        return sumOfSquaresDistance(a, b);
      case EUCLIDEAN:
      default:
        return euclideanDistance(a, b);
    }
  }

  /** Square root of the sum of squared component differences. */
  private double euclideanDistance(ArrayList<Double> a, ArrayList<Double> b)
  {
    return Math.sqrt(sumOfSquaresDistance(a, b));
  }

  /** Sum of squared component differences, iterating over the components of {@code a}. */
  private double sumOfSquaresDistance(ArrayList<Double> a, ArrayList<Double> b)
  {
    double total = 0;

    for (int n = 0; n < a.size(); n++)
    {
      double delta = a.get(n) - b.get(n);
      total += delta * delta;
    }

    return total;
  }

  /**
   * Mean absolute component difference, iterating over the components of
   * {@code a}. Returns 0 for empty vectors instead of dividing by zero.
   */
  private double basicDistance(ArrayList<Double> a, ArrayList<Double> b)
  {
    if (a.isEmpty())
    {
      return 0.0;
    }

    double total = 0;

    for (int n = 0; n < a.size(); n++)
    {
      total += Math.abs(a.get(n) - b.get(n));
    }

    return total / a.size();
  }

}
