package feedback;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.FilterIndexReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queryParser.QueryParser;
import org.apache.lucene.search.HitCollector;
import org.apache.lucene.search.Hits;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryTermVector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Searcher;
import org.apache.lucene.search.TopDocCollector;
import org.apache.lucene.search.Similarity;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermFreqVector;

/**
 *
 * @author greg
 */
public class Main {
    
    // Latest streaming-search score for each document, indexed by Lucene doc id.
    // Sized/initialized in main() (numDocs + 1 zero entries), written by
    // doStreamingSearch(), read by sortAndDisplay(). Not thread-safe.
    public static ArrayList<Float> scores = new ArrayList<Float>();

    /**
     * Entry point for the interactive relevance-feedback search demo.
     *
     * Flow: parse command-line flags, open the index, read an initial query
     * from the console (or a file via -queries), show results, then repeatedly
     * let the user rate documents as relevant ("R"), irrelevant ("I") or
     * skipped ("S") and rebuild the query with a Rocchio-style update
     * (weights a, b, c) until the user quits ("Q").
     *
     * @param args see the usage string below; notable flags are -index, -field,
     *             -a/-b/-c (feedback weights), -results, -numDocs and -paging
     * @throws Exception on index access or query parse failures
     */
    public static void main(String[] args) throws Exception {
        String usage =
                "Usage:\tjava org.apache.lucene.demo.SearchFiles [-index dir] [-field f] [-repeat n] [-queries file] [-raw] [-norms field] [-paging hitsPerPage]";
        usage += "\n\tSpecify 'false' for hitsPerPage to use streaming instead of paging search.";
        if (args.length > 0 && ("-h".equals(args[0]) || "-help".equals(args[0]))) {
            System.out.println(usage);
            System.exit(0);
        }

        String index = "index";
        String field = "contents";
        String queries = null;
        boolean raw = false;
        String normsField = null;
        boolean paging = true;
        int hitsPerPage = 10;
        String relevant = "";
        String irrelevant = "";
        String skipped = "";
        int numRel = 0;
        int numIr = 0;
        // Rocchio weights: a = original query, b = liked docs, c = disliked docs.
        float a = 1;
        float b = 1;
        float c = 1;
        int maxResults = 100;
        int numDocsOverride = -1; // set by -numDocs; otherwise taken from the index

        for (int i = 0; i < args.length; i++) {
            if ("-index".equals(args[i])) {
                index = args[i + 1];
                i++;
            } else if ("-a".equals(args[i])) {
                a = Float.parseFloat(args[i + 1]);
                i++;
            } else if ("-b".equals(args[i])) {
                b = Float.parseFloat(args[i + 1]);
                i++;
            } else if ("-c".equals(args[i])) {
                c = Float.parseFloat(args[i + 1]);
                i++;
            } else if ("-results".equals(args[i])) {
                maxResults = Integer.parseInt(args[i + 1]);
                i++;
            } else if ("-field".equals(args[i])) {
                field = args[i + 1];
                i++;
            } else if ("-queries".equals(args[i])) { //Doesn't work right now!
                queries = args[i + 1];
                i++;
            } else if ("-raw".equals(args[i])) {
                raw = true;
            } else if ("-numDocs".equals(args[i])) {
                numDocsOverride = Integer.parseInt(args[i + 1]);
                i++;
            } else if ("-norms".equals(args[i])) {
                normsField = args[i + 1];
                i++;
            } else if ("-paging".equals(args[i])) {
                if (args[i + 1].equals("false")) {
                    paging = false;
                } else {
                    hitsPerPage = Integer.parseInt(args[i + 1]);
                    if (hitsPerPage == 0) {
                        paging = false;
                    }
                }
                i++;
            }
        }

        // Open the index only AFTER the arguments have been parsed so that
        // -index is honored. (Previously the reader was opened on the default
        // "index" directory before the flags were read, making -index a no-op.)
        IndexReader reader = IndexReader.open(index);
        int numDocs = (numDocsOverride >= 0) ? numDocsOverride : reader.numDocs();

        // Per-document rating flags (mutually exclusive); sized numDocs + 1 so
        // they can be indexed directly by document number.
        int[] rel = new int[numDocs + 1];
        int[] ir = new int[numDocs + 1];
        int[] skip = new int[numDocs + 1];

        //intialize rating arrays and the shared score list
        for (int i = 0; i <= numDocs; i++) {
            rel[i] = 0;
            ir[i] = 0;
            skip[i] = 0;
            scores.add((float) 0.0);
        }

        if (normsField != null) {
            reader = new OneNormsReader(reader, normsField);
        }

        Searcher searcher = new IndexSearcher(reader);
        Analyzer analyzer = new StandardAnalyzer();

        BufferedReader in = null;
        if (queries != null) {
            in = new BufferedReader(new FileReader(queries));
        } else {
            in = new BufferedReader(new InputStreamReader(System.in, "UTF-8"));
        }
        QueryParser parser = new QueryParser(field, analyzer);

        if (queries == null) // prompt the user
        {
            System.out.println("Please enter your initial query: ");
        }
        String line = in.readLine();
        if (line == null) { // EOF before any query was entered; nothing to do
            reader.close();
            in.close();
            return;
        }
        line = line.trim();

        //Parse query.
        Query query = parser.parse(line);
        QueryTermVector queryTermVector = new QueryTermVector(line, analyzer);
        String[] queryTerms = queryTermVector.getTerms();
        int[] queryFreqs = queryTermVector.getTermFrequencies();

        ArrayList<String> newTerms = new ArrayList<String>();
        ArrayList<Integer> newFreqs = new ArrayList<Integer>();

        //Remove terms with df of 0 (aka: aren't in the corpus)
        for (int i = 0; i < queryTerms.length; i++) {
            int df = reader.docFreq(new Term("contents", queryTerms[i]));
            if (df > 0) {
                newTerms.add(queryTerms[i]);
                newFreqs.add(queryFreqs[i]);
            }
        }
        // Rebuild the query text, repeating each surviving term freq times.
        String revisedQuery = "";
        for (int j = 0; j < newTerms.size(); j++) {
            for (int k = 0; k < newFreqs.get(j); k++) {
                revisedQuery = revisedQuery + newTerms.get(j) + " ";
            }
        }

        query = parser.parse(revisedQuery);
        queryTermVector = new QueryTermVector(revisedQuery, analyzer);
        System.out.println("Searching for: " + query.toString(field));

        System.out.println("\nHere are your initial results:");
        if (paging) {
            doPagingSearch(in, searcher, query, hitsPerPage, raw, queries == null);
        } else {
            doStreamingSearch(searcher, query);
            sortAndDisplay(numDocs, reader, maxResults, rel, ir, skip);
        }

        boolean firstRound = true;
        String command = "";
        String rate = "";
        int docNum = 0;
        String rating = "";

        //Feedback loop: rate documents, rebuild the query, search again.
        do {

            if (firstRound) {
                System.out.println("\nThis is your first round. There is a total of " + numDocs + " documents.\nLet's rate a few of these results to get started.");
                command = "R";
                firstRound = false;
            } else {
                System.out.println("\nThis is a new round. Type \"R\" to rate additional documents or \"Q\" to quit.");
                command = in.readLine();
                if (command == null) { // EOF: treat as quit (was an NPE)
                    command = "Q";
                }
            }

            if (!command.equals("Q")) {
                if (command.equals("R")) {
                    //Rating loop: one "<docNum> <R|I|S>" pair per line.
                    do {
                        relevant = "";
                        irrelevant = "";
                        skipped = "";
                        System.out.println("Please list the document number you would like to rate, followed by an \"R\" for relevant or \"I\" for irrelevant. Type \"Q\" to quit rating.");
                        rate = in.readLine();
                        if (rate == null) { // EOF: treat as quit (was an NPE)
                            rate = "Q";
                        }
                        if (!rate.equals("Q")) {
                            // Guard malformed input: previously a line without a
                            // space threw StringIndexOutOfBoundsException and a
                            // non-numeric doc number threw NumberFormatException.
                            int sep = rate.indexOf(" ");
                            if (sep < 0) {
                                System.out.println("This document does not exist");
                                continue;
                            }
                            try {
                                docNum = Integer.parseInt(rate.substring(0, sep));
                            } catch (NumberFormatException e) {
                                System.out.println("This document does not exist");
                                continue;
                            }
                            rating = rate.substring(sep + 1);

                            // Apply the rating. The three states are mutually
                            // exclusive, so setting one clears the other two.
                            // (This block now only runs for real ratings; it
                            // previously also ran after "Q" with stale values.)
                            if ((docNum > numDocs) || (docNum < 0)) {
                                System.out.println("This document does not exist");
                            } else if (rating.equals("I")) {
                                if (ir[docNum] == 0) {
                                    numIr++;
                                    ir[docNum] = 1;
                                    rel[docNum] = 0;
                                    skip[docNum] = 0;
                                }
                            } else if (rating.equals("R")) {
                                if (rel[docNum] == 0) {
                                    numRel++;
                                    rel[docNum] = 1;
                                    ir[docNum] = 0;
                                    skip[docNum] = 0;
                                }
                            } else if (rating.equals("S")) {
                                if (skip[docNum] == 0) {
                                    skip[docNum] = 1;
                                    rel[docNum] = 0;
                                    ir[docNum] = 0;
                                }
                            }
                        } else {
                            System.out.println("Done rating");
                            for (int j = 0; j <= numDocs; j++) {
                                if (rel[j] == 1) {
                                    relevant = relevant + " " + j;
                                }
                                if (ir[j] == 1) {
                                    irrelevant = irrelevant + " " + j;
                                }
                                if (skip[j] == 1) {
                                    skipped = skipped + " " + j;
                                }
                            }
                            System.out.println("Relevant Documents:" + relevant);
                            System.out.println("Irrelevant Documents:" + irrelevant);
                            System.out.println("Skipped Documents:" + skipped);

                            //Build new query from the accumulated feedback
                            String newQuery = buildNewQuery(queryTermVector, reader, rel, ir, a, b, c, numDocs, numRel, numIr);
                            System.out.println("Your new query is:\n" + newQuery);
                            newQuery = newQuery.trim();

                            query = parser.parse(newQuery);
                            queryTermVector = new QueryTermVector(newQuery, analyzer);

                            System.out.println("\nHere are your new results:");
                            if (paging) {
                                doPagingSearch(in, searcher, query, hitsPerPage, raw, queries == null);
                            } else {
                                doStreamingSearch(searcher, query);
                                sortAndDisplay(numDocs, reader, maxResults, rel, ir, skip);
                            }
                        }

                    } while (!rate.equals("Q"));
                }
            }
        } while (!command.equals("Q"));


        reader.close();
        in.close();
    }

    /** Use the norms from one field for all fields.  Norms are read into memory,
     * using a byte of memory per document per searched field.  This can cause
     * search of large collections with a large number of fields to run out of
     * memory.  If all of the fields contain only a single token, then the norms
     * are all identical, so a single norm vector may be shared. */
    private static class OneNormsReader extends FilterIndexReader {

        // Name of the field whose norms are substituted for every field.
        private String field;

        public OneNormsReader(IndexReader in, String field) {
            super(in);
            this.field = field;
        }

        // Ignores the requested field name and always returns the norms of
        // the field chosen at construction time (delegates to the wrapped reader).
        public byte[] norms(String field) throws IOException {
            return in.norms(this.field);
        }
    }

    /**
     * Displays the top-scoring documents from {@link #scores} in descending
     * score order, skipping documents the user has already rated.
     *
     * When {@code maxResults > 1000} the listing is written to the file
     * "results.temp" instead of the console (too long to read on screen).
     *
     * @param numDocs    highest document number (arrays are sized numDocs + 1)
     * @param reader     used to fetch each document's "title" field
     * @param maxResults maximum number of documents to emit
     * @param rel        rel[i] == 1 iff doc i was rated relevant
     * @param ir         ir[i] == 1 iff doc i was rated irrelevant
     * @param skip       skip[i] == 1 iff doc i was skipped
     * @throws Exception on index access or file I/O failure
     */
    public static void sortAndDisplay(int numDocs, IndexReader reader, int maxResults, int[] rel, int[] ir, int[] skip) throws Exception {
        ArrayList<Float> copy = (ArrayList<Float>) scores.clone();
        ArrayList<Integer> excluded = new ArrayList<Integer>();

        // Previously rated documents are never shown again.
        for (int i = 0; i <= numDocs; i++) {
            if (rel[i] == 1 || ir[i] == 1 || skip[i] == 1) {
                excluded.add(i);
            }
        }

        // The two original branches were identical except for the output sink,
        // so the selection loop is factored into emitResults().
        if (maxResults > 1000) {
            PrintWriter outFile = new PrintWriter("results.temp");
            try {
                emitResults(outFile, copy, excluded, numDocs, maxResults, reader);
            } finally {
                outFile.close();
            }
        } else {
            // Wrap System.out but never close it; flush instead.
            PrintWriter console = new PrintWriter(System.out, true);
            emitResults(console, copy, excluded, numDocs, maxResults, reader);
            console.flush();
        }
    }

    /**
     * Repeated selection pass: on each iteration find the highest-scoring
     * document not yet excluded, print it, and exclude it. Stops after
     * maxResults documents (fixed off-by-one: the old {@code <=} bound emitted
     * maxResults + 1) or when every document has been excluded.
     */
    private static void emitResults(PrintWriter out, ArrayList<Float> scoresCopy, ArrayList<Integer> excluded,
            int numDocs, int maxResults, IndexReader reader) throws Exception {
        int numDisplayed = 0;
        while ((excluded.size() <= numDocs) && (numDisplayed < maxResults)) {
            int maxDoc = -1;
            float max = 0;
            for (int i = 0; i < numDocs; i++) {
                if ((scoresCopy.get(i) > max) && !excluded.contains(i)) {
                    max = scoresCopy.get(i);
                    maxDoc = i;
                }
            }
            if (maxDoc != -1) {
                Document doc = reader.document(maxDoc);
                out.println("Document #" + maxDoc + " " + doc.get("title") + ", Score: " + max);
                numDisplayed++;
            }
            // Adding -1 when nothing scored > 0 still grows the list, which is
            // what terminates the outer loop once no positive scores remain.
            excluded.add(maxDoc);
        }
    }
    
    /**
     * Builds a new query string from the old query plus the user's liked and
     * disliked documents (a Rocchio-style update).
     *
     * Each term accumulates a weight: a * freq from the original query,
     * plus (b / numLiked) * freq for every liked document containing it,
     * minus (c / numDisliked) * freq for every disliked document containing it.
     * Terms with a positive final weight are repeated floor(weight) times in
     * the returned query text, but only if their document frequency is below
     * a quarter of the corpus (very common terms add noise, not signal).
     *
     * @param oldQuery     term vector of the current query
     * @param reader       index reader for term statistics and term vectors
     * @param rel          rel[j] == 1 iff doc j was rated relevant
     * @param ir           ir[j] == 1 iff doc j was rated irrelevant
     * @param a            weight of the original query terms (default 1)
     * @param b            weight of liked-document terms (default 1)
     * @param c            weight of disliked-document terms (default 1)
     * @param numDocs      number of documents to scan
     * @param numLiked     count of liked documents (divisor for b)
     * @param numDisliked  count of disliked documents (divisor for c)
     * @return the new space-separated query string (may be empty)
     * @throws Exception on index access failure
     */
    public static String buildNewQuery(QueryTermVector oldQuery, IndexReader reader, int[] rel, int[] ir, float a, float b, float c, int numDocs, int numLiked, int numDisliked) throws Exception {

        String[] queryTerms = oldQuery.getTerms();
        int[] queryFreqs = oldQuery.getTermFrequencies();

        ArrayList<String> newTerms = new ArrayList<String>();
        ArrayList<Float> newFreqs = new ArrayList<Float>();

        // Seed the weight vector with the original query terms, scaled by a.
        // (An idf-weighted variant, a*freq*(1/df)+1, was tried and abandoned.)
        for (int i = 0; i < queryTerms.length; i++) {
            newTerms.add(queryTerms[i]);
            newFreqs.add(a * queryFreqs[i]);
        }

        // Add weight for terms occurring in liked documents.
        if (numLiked > 0) {
            for (int j = 0; j < numDocs; j++) {
                if (rel[j] == 1) {
                    // Grab term list/freqs for this document.
                    TermFreqVector vector = reader.getTermFreqVector(j, "contents");
                    if (vector == null) {
                        // NOTE(review): getTermFreqVector returns null when the
                        // index stored no term vector for this doc; previously
                        // that was an unguarded NullPointerException.
                        continue;
                    }
                    String[] docTerms = vector.getTerms();
                    int[] freqs = vector.getTermFrequencies();

                    for (int k = 0; k < docTerms.length; k++) {
                        float delta = (b / numLiked) * freqs[k];
                        int index = newTerms.indexOf(docTerms[k]);
                        if (index == -1) {
                            // Term not seen before: append it.
                            newTerms.add(docTerms[k]);
                            newFreqs.add(delta);
                            // Fixed: the debug print used newFreqs.get(k), which
                            // indexes docTerms, not newFreqs (wrong value or
                            // IndexOutOfBoundsException).
                            System.out.println("new term: " + newFreqs.get(newFreqs.size() - 1));
                        } else {
                            // Term already present: bump its weight.
                            newFreqs.set(index, newFreqs.get(index) + delta);
                            System.out.println("old term: " + newFreqs.get(index));
                        }
                    }
                }
            }
        }

        // Subtract weight for terms occurring in disliked documents.
        if (numDisliked > 0) {
            for (int l = 0; l < numDocs; l++) {
                if (ir[l] == 1) {
                    TermFreqVector vector = reader.getTermFreqVector(l, "contents");
                    if (vector == null) {
                        continue; // see note above: no stored term vector
                    }
                    String[] docTerms = vector.getTerms();
                    int[] freqs = vector.getTermFrequencies();

                    for (int m = 0; m < docTerms.length; m++) {
                        // Only existing terms lose weight; disliked-only terms
                        // are never added.
                        int index = newTerms.indexOf(docTerms[m]);
                        if (index != -1) {
                            float numToRm = newFreqs.get(index) - ((c / numDisliked) * freqs[m]);
                            newFreqs.set(index, numToRm);
                            System.out.println("Removing " + numToRm);
                        }
                    }
                }
            }
        }

        // Assemble the query text: repeat each positively-weighted, sufficiently
        // rare term floor(weight) times.
        String newQuery = "";
        for (int n = 0; n < newTerms.size(); n++) {
            String term = newTerms.get(n);
            float freq = newFreqs.get(n);
            if (freq > 0) {
                // Only keep terms appearing in fewer than a quarter of the
                // documents (the old comment said "1/3rd" but the code has
                // always used numDocs/4).
                int df = docFreqAtLeastOne(reader, term);
                if (df < (numDocs / 4)) {
                    for (int o = 1; o <= (int) freq; o++) {
                        newQuery = newQuery + " " + term;
                    }
                }
            }
        }

        return newQuery;
    }

    /**
     * Document frequency of {@code term} in the "contents" field, clamped to a
     * minimum of 1 so callers can safely divide by it.
     */
    private static int docFreqAtLeastOne(IndexReader reader, String term) throws IOException {
        int df = reader.docFreq(new Term("contents", term));
        return (df == 0) ? 1 : df;
    }
    /**
     * Runs the query in streaming mode: every matching document's score is
     * recorded in the shared {@code scores} list, indexed by document id.
     *
     * Nothing is printed here; this simulates the streaming use case where all
     * hits are processed regardless of relevance, and the caller displays the
     * results separately (see sortAndDisplay).
     */
    public static void doStreamingSearch(final Searcher searcher, Query query) throws IOException {
        // Capture each hit's score as it streams past.
        searcher.search(query, new HitCollector() {
            public void collect(int docId, float docScore) {
                scores.set(docId, docScore);
            }
        });
    }

    /**
     * Typical paging search: presents pages of {@code hitsPerPage} hits and
     * lets the user page through them with p/n/q or a page number.
     *
     * Only enough results for 5 pages are collected on the first execution; if
     * the user pages beyond that limit the query is run again and all hits are
     * collected.
     *
     * @param in          console input for paging commands
     * @param searcher    searcher to run the query against
     * @param query       the query to execute
     * @param hitsPerPage page size
     * @param raw         print raw "doc=... score=..." lines instead of paths
     * @param interactive when false, print only the first page and return
     * @throws IOException if reading from the console or index fails
     */
    public static void doPagingSearch(BufferedReader in, Searcher searcher, Query query,
            int hitsPerPage, boolean raw, boolean interactive) throws IOException {

        // Collect enough docs to show 5 pages
        TopDocCollector collector = new TopDocCollector(5 * hitsPerPage);
        searcher.search(query, collector);
        ScoreDoc[] hits = collector.topDocs().scoreDocs;

        int numTotalHits = collector.getTotalHits();
        System.out.println(numTotalHits + " total matching documents");

        int start = 0;
        int end = Math.min(numTotalHits, hitsPerPage);

        while (true) {
            if (end > hits.length) {
                System.out.println("Only results 1 - " + hits.length + " of " + numTotalHits + " total matching documents collected.");
                System.out.println("Collect more (y/n) ?");
                String line = in.readLine();
                // null = EOF on stdin; previously this threw NullPointerException.
                if (line == null || line.length() == 0 || line.charAt(0) == 'n') {
                    break;
                }

                // Re-run the search, this time collecting every hit.
                collector = new TopDocCollector(numTotalHits);
                searcher.search(query, collector);
                hits = collector.topDocs().scoreDocs;
            }

            end = Math.min(hits.length, start + hitsPerPage);

            for (int i = start; i < end; i++) {
                if (raw) {                              // output raw format
                    System.out.println("doc=" + hits[i].doc + " score=" + hits[i].score);
                    continue;
                }

                Document doc = searcher.doc(hits[i].doc);
                String path = doc.get("path");
                if (path != null) {
                    System.out.println((i + 1) + ". " + path);
                    String title = doc.get("title");
                    if (title != null) {
                        System.out.println("   Title: " + doc.get("title"));
                    }
                } else {
                    System.out.println((i + 1) + ". " + "No path for this document");
                }

            }

            if (!interactive) {
                break;
            }

            if (numTotalHits >= end) {
                boolean quit = false;
                while (true) {
                    System.out.print("Press ");
                    if (start - hitsPerPage >= 0) {
                        System.out.print("(p)revious page, ");
                    }
                    if (start + hitsPerPage < numTotalHits) {
                        System.out.print("(n)ext page, ");
                    }
                    System.out.println("(q)uit or enter number to jump to a page.");

                    String line = in.readLine();
                    // null = EOF on stdin; treat like 'q' (previously an NPE).
                    if (line == null || line.length() == 0 || line.charAt(0) == 'q') {
                        quit = true;
                        break;
                    }
                    if (line.charAt(0) == 'p') {
                        start = Math.max(0, start - hitsPerPage);
                        break;
                    } else if (line.charAt(0) == 'n') {
                        if (start + hitsPerPage < numTotalHits) {
                            start += hitsPerPage;
                        }
                        break;
                    } else {
                        int page;
                        try {
                            page = Integer.parseInt(line);
                        } catch (NumberFormatException e) {
                            // Non-numeric input previously crashed the program.
                            System.out.println("No such page");
                            continue;
                        }
                        if ((page - 1) * hitsPerPage < numTotalHits) {
                            start = (page - 1) * hitsPerPage;
                            break;
                        } else {
                            System.out.println("No such page");
                        }
                    }
                }
                if (quit) {
                    break;
                }
                end = Math.min(numTotalHits, start + hitsPerPage);
            }

        }

    }
}
