#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Tatu Huovilainen
email: tatu.huovilainen@helsinki.fi
"""

#Init
import os
import pandas as pd
import argparse


# Absolute directory of this script. The frequency tables are read from its
# 'tables' subdirectory, independent of the current working directory.
script_dir = os.path.split(os.path.abspath(__file__))[0]
tables_dir = '{}/tables'.format(script_dir)

#Functions
def check_inputfile(parser, arg):
    """Argparse ``type`` callback: validate *arg* and open it for reading.

    :param parser: the ArgumentParser; its ``error`` method prints the usage
        message and exits, so an invalid path never returns from here.
    :param arg: input path given on the command line (relative paths are
        resolved against the current working directory).
    :returns: a text file handle opened for reading; the caller closes it.
    """
    if not os.path.exists(arg):
        parser.error('The file {} does not exist!'.format(arg))
    # A directory passes os.path.exists() but cannot be opened; report it as a
    # usage error instead of letting IsADirectoryError escape from argparse.
    if os.path.isdir(arg):
        parser.error('{} is a directory, not a file!'.format(arg))
    print('Reading input from file: {}'.format(arg))
    try:
        # Decode explicitly: the words (Finnish letters ä, ö, å ...) must match
        # the frequency tables, which pandas reads as UTF-8. 'utf-8-sig' also
        # drops a leading byte-order mark that would otherwise stick to the
        # first word.
        return open(arg, 'r', encoding='utf-8-sig')
    except OSError as err:
        parser.error("Can't read file {}: {}".format(arg, err))
    
def check_outputfile(parser, arg):
    """Argparse ``type`` callback: open *arg* for writing.

    The file is created (or truncated) immediately, at argument-parsing time.

    :param parser: the ArgumentParser; its ``error`` method prints the usage
        message and exits when the file cannot be opened.
    :param arg: output path given on the command line.
    :returns: a text file handle opened in 'w+' mode; the caller closes it.
    """
    try:
        # newline='' as pandas requires for file handles passed to to_csv
        # (otherwise Windows output gets doubled '\r'); UTF-8 so Finnish
        # letters are written regardless of the platform default encoding.
        return open(arg, 'w+', encoding='utf-8', newline='')
    except OSError:
        # Only I/O problems (missing directory, permissions, ...) become a
        # usage error; interrupts and other exceptions propagate normally.
        parser.error('Can\'t write to file: {}'.format(arg))
        


def parseCorpString(corp_string):
    """Parse the ``-c`` option into the corpora to use.

    :param corp_string: 'ALL' or a comma-separated list of corpus names (any
        of S24, KLK, WIKI, OPENSUB, REDDIT, LEHDET). Leading/trailing commas,
        quotes and spaces are ignored, as is whitespace around each name.
    :returns: the string 'ALL' when 'ALL' was requested; otherwise a list of
        the selected corpus names in the canonical order above (the complete
        list when every corpus was named).
    :raises ValueError: if no corpus or an unknown corpus name is given.
    """
    print('Checking which corpora to consider')
    # A fixed order (rather than iterating a set) keeps the printed messages
    # and the downstream column selection identical from run to run.
    all_corpora = ['S24', 'KLK', 'WIKI', 'OPENSUB', 'REDDIT', 'LEHDET']
    corp_string = corp_string.strip('," ')
    if corp_string == 'ALL':
        print('Using all corpora available')
        return 'ALL'

    requested = {name.strip() for name in corp_string.split(',') if name.strip()}
    if not requested:
        raise ValueError('No corpora given in input')
    unknown = requested - set(all_corpora)
    if unknown:
        raise ValueError('Unknown corpus/corpora in input: {}'.format(unknown))

    used = [name for name in all_corpora if name in requested]
    unused = [name for name in all_corpora if name not in requested]
    if unused:
        print('Using: {}\tLeaving out: {}'.format(', '.join(used), ', '.join(unused)))
    else:
        print('Using all corpora available')
    return used

def parsePOSString(pos_string):
    """Parse the ``-pc`` option into the part-of-speech classes to use.

    :param pos_string: 'ALL', 'IGNORE', or a comma-separated list of POS
        classes (any of N, V, A, Adp, Adv, C, Interj, Num, Pron).
        Leading/trailing commas, quotes and spaces are ignored, as is
        whitespace around each class name.
    :returns: 'ALL' or 'IGNORE' when requested; otherwise a list of the
        selected classes in the canonical order above (the complete list
        when every class was named).
    :raises ValueError: if no class or an unknown class is given.
    """
    print('Checking which POS classes to consider')
    # A fixed order (rather than iterating a set) keeps the printed messages
    # identical from run to run.
    all_classes = ['N', 'V', 'A', 'Adp', 'Adv', 'C', 'Interj', 'Num', 'Pron']
    pos_string = pos_string.strip('," ')
    if pos_string == 'ALL':
        print('Using all POS classes available')
        return 'ALL'
    if pos_string == 'IGNORE':
        print('Ignoring POS class information. Different POS class instances of written forms will be collapsed')
        return 'IGNORE'

    requested = {name.strip() for name in pos_string.split(',') if name.strip()}
    if not requested:
        raise ValueError('No POS classes given in input')
    unknown = requested - set(all_classes)
    if unknown:
        raise ValueError('Unknown POS classes in input: {}'.format(unknown))

    used = [name for name in all_classes if name in requested]
    unused = [name for name in all_classes if name not in requested]
    if unused:
        print('Using: {}\tLeaving out: {}'.format(', '.join(used), ', '.join(unused)))
    else:
        print('Using all POS classes available')
    return used

def update_dataset(daframe,corpora,morpho,pos_classes):
    """Restrict the frequency table to the chosen corpora and POS classes.

    The frame passed in is left unmodified; a new, filtered frame is returned.

    :param daframe: frequency table containing the *morpho* column,
        'POS_CLASS' and per-corpus 'ABS_FREQ_<corpus>' /
        'RELATIVE_FREQ_<corpus>' columns.
    :param corpora: 'ALL', or an iterable of corpus names (see
        parseCorpString).
    :param morpho: name of the word-form column ('LEMMA' or 'WORD').
    :param pos_classes: 'ALL', 'IGNORE', or an iterable of POS class names
        (see parsePOSString).
    :returns: the filtered table with every value rounded to 4 decimals.
    """
    if corpora != 'ALL':
        # Keep the word, its POS class and the abs/rel columns of the chosen
        # corpora; drop() leaves the kept columns in their original order.
        cols_to_keep = [morpho, 'POS_CLASS']
        for corpus in corpora:
            cols_to_keep += ['ABS_FREQ_' + corpus, 'RELATIVE_FREQ_' + corpus]
        daframe = daframe.drop(columns=[col for col in daframe.columns if col not in cols_to_keep])

        abs_val_cols = [col for col in daframe.columns if col.startswith('ABS')]
        rel_val_cols = [col for col in daframe.columns if col.startswith('REL')]

        # Drop words (rows) that never occur in the remaining corpora. A
        # boolean mask is used so rows are selected by position, not by label.
        total_abs = daframe[abs_val_cols].sum(axis=1)
        occurs = total_abs != 0
        daframe = daframe[occurs]
        total_abs = total_abs[occurs]

        # Recompute the totals over the remaining corpora only: TOTAL_RELATIVE
        # is TOTAL_ABS per million of the summed absolute counts, and
        # AVERAGE_RELATIVE is the mean of the per-corpus relative columns.
        counted_millions = total_abs.sum() / 1000000
        daframe = daframe.assign(
            TOTAL_ABS=total_abs,
            TOTAL_RELATIVE=total_abs / counted_millions,
            AVERAGE_RELATIVE=daframe[rel_val_cols].sum(axis=1) / len(rel_val_cols),
        )

    # Update the frame to reflect only the chosen POS classes
    if pos_classes == 'IGNORE':
        # Collapse rows sharing a written form across POS classes by summing
        # every remaining value column.
        daframe = daframe.drop(columns=['POS_CLASS']).groupby([morpho], as_index=False).sum()
    elif pos_classes != 'ALL':
        daframe = daframe[daframe['POS_CLASS'].isin(pos_classes)]

    # Round to keep "accuracy" reasonable
    return daframe.round(4)

def add_freqs(inset,daframe,morpho):
    """Look up the frequency rows of every input word.

    Each input word contributes, in input order, one of the following:

    * all of its rows in *daframe* (e.g. one per POS class), in table order;
    * when the word is absent, a single row holding only the word, with
      every other column missing.

    Absolute counts (columns starting with 'ABS' and 'TOTAL_ABS') are
    dropped from the result.

    :param inset: frame with the input words in column *morpho*.
    :param daframe: frequency table (see update_dataset).
    :param morpho: name of the word-form column.
    :returns: a new frame with the columns of *daframe* minus absolute counts
        and a fresh 0..n-1 index (so later column assignments align cleanly).
    """
    words = inset[morpho].tolist()
    # Filter the (large) table once; per-word lookups then only scan the rows
    # of words that actually occur in the input.
    candidates = daframe[daframe[morpho].isin(words)]

    pieces = []
    for word in words:
        rows = candidates[candidates[morpho] == word]
        if rows.empty:
            # Placeholder row: the word itself, every other column missing.
            rows = pd.DataFrame({morpho: [word]}).reindex(columns=daframe.columns)
        pieces.append(rows)

    if pieces:
        result = pd.concat(pieces, ignore_index=True)
    else:
        # No input words: keep the table's columns but no rows.
        result = daframe.iloc[0:0].reset_index(drop=True)

    abs_cols = [col for col in result.columns if col.startswith('ABS')]
    return result.drop(columns=abs_cols + ['TOTAL_ABS'])

def add_ngram_freqs(inset,morpho,table_dir=None):
    """Add the average letter 2-gram and 3-gram frequency of each word.

    Reads 'letter2gram_freqs.csv' and 'letter3gram_freqs.csv' (first column
    is the n-gram, value column 'AVERAGE_RELATIVE_FREQ'). It then adds the
    columns '2GRAM_AVG_FREQ' and '3GRAM_AVG_FREQ' to *inset* in place.

    How unusual inputs are handled:

    * An n-gram missing from its table counts as frequency 0, the same rule
      add_syll_freqs uses for unknown syllables.
    * A word with no n-gram of the given length (e.g. a two-letter word for
      3-grams) gets NaN instead of raising ZeroDivisionError.
    * A non-string entry also gets NaN.

    :param inset: frame with the words in column *morpho*.
    :param morpho: name of the word column.
    :param table_dir: directory holding the tables; defaults to the module's
        ``tables_dir``.
    :returns: *inset*, modified in place.
    """
    if table_dir is None:
        table_dir = tables_dir

    def read_freq_table(filename):
        # keep_default_na=False: otherwise pandas parses n-grams such as
        # 'nan' as missing values and they can never be looked up. Only
        # empty cells are treated as missing.
        table = pd.read_csv(os.path.join(table_dir, filename), index_col=0,
                            keep_default_na=False, na_values=[''])
        return table['AVERAGE_RELATIVE_FREQ'].to_dict()

    def average_ngram_freq(word, size, freqs):
        # Mean table frequency over all overlapping letter n-grams of *word*.
        if not isinstance(word, str) or len(word) < size:
            return float('nan')
        grams = [word[i:i + size] for i in range(len(word) - size + 1)]
        return sum(freqs.get(gram, 0) for gram in grams) / len(grams)

    twogram_freqs = read_freq_table('letter2gram_freqs.csv')
    threegram_freqs = read_freq_table('letter3gram_freqs.csv')

    # Assign whole columns directly (no int placeholder first), so the float
    # results are not written into an integer column.
    inset['2GRAM_AVG_FREQ'] = inset[morpho].map(lambda word: average_ngram_freq(word, 2, twogram_freqs))
    inset['3GRAM_AVG_FREQ'] = inset[morpho].map(lambda word: average_ngram_freq(word, 3, threegram_freqs))
    return inset

def add_syll_freqs(inset,morpho,syllabifier=None,table_dir=None):
    """Add syllables and the average syllable frequency of each word.

    Each word is syllabified once: the first result of
    ``syllabifier.syllabify(word)`` is split on '.' and '-'. The following
    columns are added to *inset* in place:

    * 'SYLLABLES': the list of syllables;
    * 'SYLLABLES_N': the number of syllables;
    * 'SYLLABLE_FREQS': the table frequency of each syllable, 0 for
      syllables missing from 'syll_freqs.csv';
    * 'SYLLABLE_AVG_FREQ': the mean of SYLLABLE_FREQS.

    :param inset: frame with the words in column *morpho*.
    :param morpho: name of the word column.
    :param syllabifier: object with a ``syllabify(word)`` method returning a
        sequence of syllabified strings (e.g. ``finnsyll.FinnSyll()``).
        Defaults to the module-level ``fs`` created by the script's main
        section, or to a new FinnSyll instance if no such name exists.
    :param table_dir: directory containing 'syll_freqs.csv'; defaults to the
        module's ``tables_dir``.
    :returns: *inset*, modified in place.
    """
    import re

    if syllabifier is None:
        # Backward compatibility: the main section stores its FinnSyll
        # instance in the module-level name ``fs``.
        syllabifier = globals().get('fs')
    if syllabifier is None:
        from finnsyll import FinnSyll
        syllabifier = FinnSyll()
    if table_dir is None:
        table_dir = tables_dir

    # keep_default_na=False: keys such as 'nan' must stay strings instead of
    # being parsed as missing values; only empty cells count as missing.
    syll_freqs = pd.read_csv(os.path.join(table_dir, 'syll_freqs.csv'), index_col=0,
                             keep_default_na=False, na_values=[''])['AVERAGE_RELATIVE_FREQ'].to_dict()

    # Syllabify every word exactly once and derive all columns from that.
    sylls = inset[morpho].map(lambda word: re.split(r'\.|\-', syllabifier.syllabify(word)[0]))
    freqs = sylls.map(lambda syll_list: [syll_freqs.get(syll, 0) for syll in syll_list])

    # Whole-column assignment avoids writing lists/floats into int columns.
    inset['SYLLABLES'] = sylls
    inset['SYLLABLES_N'] = sylls.map(len)
    inset['SYLLABLE_FREQS'] = freqs
    # re.split always yields at least one element, so len(f) > 0.
    inset['SYLLABLE_AVG_FREQ'] = freqs.map(lambda f: sum(f) / len(f))
    return inset

def findNeigh(inset,whole_data,wpm_lim,morpho):
    """Add the orthographic neighbours of each word.

    Candidates are the *morpho* values (as strings) of rows in *whole_data*
    whose 'AVERAGE_RELATIVE' is >= *wpm_lim*. A neighbour is a candidate of
    the same length that differs from the word in exactly one position, the
    replacement letter coming from ``finnish_alphabet``.

    Adds the columns 'NEIGHBOURS' and 'NEIGHBOURS_N'. 'NEIGHBOURS' is a list
    ordered by position, then by alphabet order. Non-string entries get an
    empty list.

    :returns: *inset*, modified in place.
    """
    finnish_alphabet = 'abcdefghijklmnopqrsštuvwxyzžåäö'

    candidates = whole_data.loc[whole_data['AVERAGE_RELATIVE'] >= wpm_lim, morpho].astype(str)
    # Index the candidates by length once, as sets: membership tests are then
    # O(1) instead of rescanning the whole candidate list for every variant.
    by_length = {}
    for candidate in candidates:
        by_length.setdefault(len(candidate), set()).add(candidate)

    def neighbours(word):
        if not isinstance(word, str):
            return []
        same_len_words = by_length.get(len(word), set())
        found = []
        for pos, current in enumerate(word):
            for letter in finnish_alphabet:
                if letter == current:
                    continue  # the word itself is not its own neighbour
                variant = word[:pos] + letter + word[pos + 1:]
                if variant in same_len_words:
                    found.append(variant)
        return found

    # Whole-column assignment (no '' / 0 placeholders needed).
    inset['NEIGHBOURS'] = inset[morpho].map(neighbours)
    inset['NEIGHBOURS_N'] = inset['NEIGHBOURS'].map(len)
    return inset

#Main
if __name__ == '__main__':
    #Parse input arguments
    parser = argparse.ArgumentParser(description='Gather descriptives for list of lemmas or words',formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    required_args = parser.add_argument_group('Required arguments')
    required_args.add_argument('-i', dest='inputfile', required=True,
                        help='input file. Either a filename in psychlingdesc folder or full path. Format is one word per line',
                        type=lambda x: check_inputfile(parser, x))

    required_args.add_argument('-o', dest='outputfile', required=True,
                        help='output file. Either a filename in psychlingdesc folder or full path.',
                        type=lambda x: check_outputfile(parser, x))

    required_args.add_argument('-t', dest='morpho', required=True, type=str,
                        choices=['LEMMA','WORD'], help='input word type, descriptives for lemmas or words')

    optional_args = parser.add_argument_group('Optional outputs')
    optional_args.add_argument('-c', dest='corpora', type=str, default='ALL',
                        help='which corpora to consider. Can be any combination of "S24,KLK,WIKI,OPENSUB,REDDIT,LEHDET"')

    optional_args.add_argument('-pc', dest='pos_class', type=str, default='ALL',
                        help='restrict search to given part-of-speech classes, or ignore part-of-speech class completely and collapse all identical written forms. Can be any combination of "N,V,A,Adp,Adv,C,Interj,Num,Pron" or "IGNORE"')

    optional_args.add_argument('-f', dest='freqs', default=None, action='store_true',
                        help='find word/lemma frequencies')

    # LIMIT is mandatory: with an optional value, a bare "-on" produced None
    # and the neighbour search was silently skipped.
    optional_args.add_argument('-on', dest='wpm_lim', type=float, metavar='LIMIT', default=None,
                        help='find orthographic neighbours, LIMIT sets lower bound above which to consider candidate words (can be 0.0). Recommended 3.0 or above as lower frequency items even after filtering contain many non-words.')

    optional_args.add_argument('-ng', dest='ngrams', default=None, action='store_true',
                        help='calculate average letter 2-gram and 3-gram frequencies')

    optional_args.add_argument('-sy', dest='syll', default=None, action='store_true',
                        help='calculate average syllable frequencies')

    args = parser.parse_args()

    #Load tables
    print('Loading tables to memory')
    # One word per line. Read the lines directly instead of with read_csv,
    # which would turn words such as "NA", "nan" or "null" into missing values
    # and parse all-digit lines as numbers. Blank lines are skipped.
    with args.inputfile as infile:
        words = [line.strip() for line in infile]
    inset = pd.DataFrame({args.morpho: [word for word in words if word]})
    dataset = pd.read_csv(tables_dir + '/{}_freqs.csv'.format(args.morpho.lower()))

    #Determine wanted tables & pos classes
    corpora = parseCorpString(args.corpora)
    pos_classes = parsePOSString(args.pos_class)

    #Update dataset to reflect only wanted tables
    dataset = update_dataset(dataset,corpora,args.morpho,pos_classes)

    #Find frequencies and turn input wordset to pandas frame
    if args.freqs:
        print('Adding {} frequency information'.format(args.morpho.lower()))
        inset = add_freqs(inset,dataset,args.morpho)

    #Find letter ngram information
    if args.ngrams:
        print('Adding letter n-gram information')
        inset = add_ngram_freqs(inset,args.morpho)

    #Find syllable information
    if args.syll:
        from finnsyll import FinnSyll
        # Kept as a module-level name on purpose: add_syll_freqs uses the
        # global ``fs``.
        fs = FinnSyll()
        print('Adding syllable information')
        inset = add_syll_freqs(inset,args.morpho)

    #Find orthographic neighbours (0.0 is a valid limit, so compare with None)
    if args.wpm_lim is not None:
        print('Adding orthographic neighbours information. Only considering candidate words with relative frequency above {}'.format(args.wpm_lim))
        inset = findNeigh(inset,dataset,args.wpm_lim,args.morpho)

    #Save to outfile; the with-block flushes and closes the handle.
    print('Saving output to {}'.format(args.outputfile.name))
    with args.outputfile as outfile:
        inset.to_csv(outfile, index=False, na_rep='NaN')
    
