package Learners; import java.io.File; import java.io.FileNotFoundException; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Scanner; import java.util.Vector; import Exceptions.UnsupportedFileTypeException; import Experiment.Dataset; import Experiment.ResultsTable.Doublet; public class kNearestNeighbors implements Learner { private int numNeighbors; private String name; public kNearestNeighbors(int numNeighbors) { this.numNeighbors = numNeighbors; this.name = "kNN" + numNeighbors; } @Override public void learn(String trainingData, String testingData, String output) throws IOException, UnsupportedFileTypeException { File f = new File(trainingData); Scanner reader = new Scanner(f); ArrayList training = new ArrayList(); Vector tempD = new Vector(); String[] temp; Double[] tempDA; while (reader.hasNext()) { temp = reader.next().split(","); for (int i =0; i < temp.length; i++) { tempD.add(Double.valueOf(temp[i])); } if (temp.length > 0) { tempDA = new Double[tempD.size()]; tempD.copyInto(tempDA); training.add(tempDA); } } reader.close(); f = new File(testingData); reader = new Scanner(f); ArrayList testing = new ArrayList(); while (reader.hasNext()) { temp = reader.next().split(","); for (int i =0; i < temp.length; i++) { tempD.add(Double.valueOf(temp[i])); } if (temp.length > 0) { tempDA = new Double[tempD.size()]; tempD.copyInto(tempDA); testing.add(tempDA); } } f = new File(output); FileWriter fw = new FileWriter(f); ArrayList nearestNeighbors; Double estimateSum; Double estimateNumber; for (Double[] dA : testing) { nearestNeighbors = new ArrayList(); for (Double[] dA2 : training) { if (nearestNeighbors.size() < numNeighbors) { nearestNeighbors.add(new Doublet(distance(dA,dA2), dA2[dA2.length - 1])); Collections.sort(nearestNeighbors); } else if (nearestNeighbors.get(numNeighbors - 1).one > distance(dA,dA2)) { nearestNeighbors.set(numNeighbors - 1, new Doublet(distance(dA,dA2), dA2[dA2.length - 1])); Collections.sort(nearestNeighbors); } } estimateSum = 0.0; estimateNumber = 0.0; for (Doublet dt : nearestNeighbors) { estimateSum += dt.two; estimateNumber++; } fw.write(dA[dA.length -1] + "," + estimateSum/estimateNumber); } fw.close(); } public double distance(Double[] one, Double[] two) { double distance = 0.0; int max = one.length - 1; if (one.length > two.length) max = two.length - 1; for (int i = 0; i < max; i++) { distance += (one[i] - two[i])*(one[i] - two[i]); } return distance; } @Override public void setSettings(String settings) { // TODO Auto-generated method stub } @Override public String getSettings() { // TODO Auto-generated method stub return null; } @Override public String getName() { return name; } @Override public double estimate(String trainingData, String testingData) throws FileNotFoundException { File f = new File(trainingData); Scanner reader = new Scanner(f); ArrayList training = new ArrayList(); Vector tempD = new Vector(); String[] temp; Double[] tempDA; while (reader.hasNext()) { temp = reader.next().split(","); for (int i =0; i < temp.length; i++) { tempD.add(Double.valueOf(temp[i])); } if (temp.length > 0) { tempDA = new Double[tempD.size()]; tempD.copyInto(tempDA); tempD = new Vector(); training.add(tempDA); } } reader.close(); f = new File(testingData); reader = new Scanner(f); ArrayList testing = new ArrayList(); while (reader.hasNext()) { temp = reader.next().split(","); for (int i =0; i < temp.length; i++) { tempD.add(Double.valueOf(temp[i])); } if (temp.length > 0) { tempDA = new Double[tempD.size()]; tempD.copyInto(tempDA); tempD = new Vector(); testing.add(tempDA); } } ArrayList nearestNeighbors; Double estimateSum; Double estimateNumber; for (Double[] dA : testing) { nearestNeighbors = new ArrayList(); for (Double[] dA2 : training) { if (nearestNeighbors.size() < numNeighbors) { nearestNeighbors.add(new Doublet(distance(dA,dA2), dA2[dA2.length - 1])); Collections.sort(nearestNeighbors); } else if (nearestNeighbors.get(numNeighbors - 1).one > distance(dA,dA2)) { nearestNeighbors.set(numNeighbors - 1, new Doublet(distance(dA,dA2), dA2[dA2.length - 1])); Collections.sort(nearestNeighbors); } } estimateSum = 0.0; estimateNumber = 0.0; for (Doublet dt : nearestNeighbors) { estimateSum += dt.two; estimateNumber++; } return estimateSum / estimateNumber; } return 0.0; } public class Doublet implements Comparable { private double one; private double two; public Doublet(double one, double two) { this.one = one; this.two = two; } @Override public int compareTo(Object o) { return Double.compare(this.one, ((Doublet) o).one); } public String getOne() { return Double.toString(one);} public String getTwo() { return Double.toString(two);} } }