#! /usr/bin/env python
# -*- coding: utf-8 -*-


"""
Extract tagged names and their context information from named entity
recognizer output.

Usage: extract-tagged-names.py [options] [input file ...] > output

For more information, run extract-tagged-names.py --help

Author: Jyrki Niemi (jyrki.niemi@helsinki.fi) 2013
"""


import sys
import codecs
import re

from optparse import OptionParser


class NameInfo(object):
    """A single extracted name: position, optional stage marker, type, name.

    Instances are built either from keyword arguments (for a name found in
    tagged input) or by parsing one line of this script's own tab-separated
    output via ``line`` (used when reading a previous stage's output).

    The tab-separated line format, as written by ``format``, is: position
    (``linenr offset``), an optional stage field starting with ``+`` or
    ``-``, the name type, the name, and any number of extra fields.
    """

    def __init__(self, line=None, linenr=None, offset=None, stage=None,
                 type=None, name=None, rest=None):
        # ``type`` shadows the builtin but is kept for keyword compatibility.
        if line is not None:
            self._split_line(line)
        else:
            self._linenr = linenr
            self._offset = offset
            self._stage = stage
            self._type = type
            self._name = name
            self._rest = rest

    def _split_line(self, line):
        """Parse one tab-separated output line into this object's fields."""
        # Strip only the line terminator: ``line[:-1]`` would chop a real
        # character off a final line that lacks a trailing newline.
        fields = line.rstrip('\r\n').split('\t')
        self._linenr, self._offset = self._split_pos_field(fields[0])
        fieldnr = 1
        # The stage field is optional and recognized by its +/- prefix.
        # startswith() also copes with an empty field, where [0] would raise.
        if fields[fieldnr].startswith(('+', '-')):
            self._stage = fields[fieldnr]
            fieldnr += 1
        else:
            self._stage = None
        self._type = fields[fieldnr]
        self._name = fields[fieldnr + 1]
        self._rest = fields[fieldnr + 2:]

    def _split_pos_field(self, pos):
        """Return ``(linenr, offset)`` parsed from ``pos``, or (None, None)."""
        mo = re.match(r'\s*(\d+)\s+(\d+)', pos)
        if mo:
            return (int(mo.group(1)), int(mo.group(2)))
        else:
            return (None, None)

    def is_removed(self):
        """Return True if the stage field marks this name as removed."""
        return bool(self._stage) and self._stage[0] == '-'

    def precedes(self, other):
        """Return True if this name's position is before ``other``'s."""
        return (other is not None
                and (self._linenr < other._linenr
                     or (self._linenr == other._linenr
                         and self._offset < other._offset)))

    def has_same_pos(self, other):
        """Return True if ``other`` has the same line number and offset."""
        return (other is not None and self._linenr == other._linenr
                and self._offset == other._offset)

    def equals(self, other):
        """Return True if ``other`` has the same position, type and name."""
        return (other is not None and self.has_same_pos(other)
                and self._type == other._type and self._name == other._name)

    def remove_at_stage(self, stage_name):
        """Mark this name as removed at ``stage_name``.

        Any existing stage marker is kept after a space, so that e.g. a
        stage of ``+s1`` becomes ``-s2 +s1``.  A missing stage name or a
        missing previous stage no longer raises TypeError.
        """
        stage = '-' + (stage_name or '')
        if self._stage:
            stage += ' ' + self._stage
        self._stage = stage

    def add_at_stage(self, stage_name):
        """Mark this name as added at ``stage_name`` (replaces any marker)."""
        self._stage = '+' + stage_name

    def extend_rest(self, items):
        """Append ``items`` to the extra fields written after the name."""
        if self._rest is None:
            self._rest = []
        self._rest.extend(items)

    def format(self, hide_name=False):
        """Return this name as a newline-terminated tab-separated line.

        If ``hide_name`` is true, the name field is omitted.
        """
        result = [u'{linenr:6d} {offset:4d}'.format(linenr=self._linenr,
                                                    offset=self._offset)]
        if self._stage:
            result.append(self._stage)
        result.append(self._type)
        if not hide_name:
            result.append(self._name)
        # _rest stays None for keyword-built names never given extra fields.
        result.extend(self._rest or [])
        return '\t'.join(result) + '\n'


class NameExtractor(object):
    """Extract tagged names from NER output and write them to stdout.

    ``opts`` is the options object from ``getopts``.  The attributes read
    here are tags, show_contiguous_words, show_tag_names, show_lines,
    show_context_tokens, hide_names, remove_context_tags, match_markers,
    stage_name and previous_stage_output.
    """

    # Regular expressions for the supported input tag styles, keyed by the
    # value of --tags.  Each defines the groups pre, tag, name and post, plus
    # either type/sbt/ani ('attributes') or combtype ('plaintags').
    _tagged_name_res = {
        'attributes':
            r"""(?x)
               (?P<pre>\S*)
               <(?P<tag>(?:ENA|TI|NU)MEX)\s+
                TYPE="(?P<type>[^\"]+)"\s+
                SBT="(?P<sbt>[^\"]+)"
                (?:\s+ANI="(?P<ani>[^\"]+)")?
               \s*>
               (?P<name>[^<]+)
               </(?:ENA|TI|NU)MEX>
               (?P<post>\S*)""",
        'plaintags':
            r"""(?xi)
               (?P<pre>\S*)
               <(?P<tag>(?:Ena|Ti|Nu)mex)
                (?P<combtype>(?:[a-z]{3})+)
               >
               (?P<name>[^<]+)
               </(?:Ena|Ti|Nu)mex[^>]*>
               (?P<post>\S*)"""
        }

    def __init__(self, opts):
        self._opts = opts
        self._tagged_re = re.compile(self._tagged_name_res[self._opts.tags],
                                     re.UNICODE)
        if self._opts.show_contiguous_words:
            self._format_name = self._format_name_contiguous
        else:
            self._format_name = self._format_name_bare
        # Match groups joined (upper-cased) into the output type field.
        self._type_fields = ['tag'] if self._opts.show_tag_names else []
        self._type_fields.extend(['type', 'sbt', 'ani'])

    def extract_names(self, fnames):
        """Extract names from each file in ``fnames`` ('-' means stdin)."""
        self._init_read_previous_stage()
        for fname in fnames:
            if fname == '-':
                self._extract_from_file(sys.stdin)
            else:
                with codecs.open(fname, encoding='utf-8') as f:
                    self._extract_from_file(f)
        # Previous-stage names following the last name found in the input
        # were removed at this stage too; they would otherwise be lost.
        self._flush_previous_stage()

    def _init_read_previous_stage(self):
        """Open the previous-stage output (if any) and read its first name."""
        self._previous_stage_nameinfo = None
        self._read_previous_stage = bool(self._opts.previous_stage_output)
        if self._opts.previous_stage_output:
            self._previous_stage_reader = self._read_names_from_file(
                self._opts.previous_stage_output)
            self._read_next_previous_stage_nameinfo()

    def _read_next_previous_stage_nameinfo(self):
        """Advance to the next previous-stage name; None when exhausted."""
        if self._read_previous_stage:
            try:
                self._previous_stage_nameinfo = \
                    next(self._previous_stage_reader)
            except StopIteration:
                self._read_previous_stage = False
                self._previous_stage_nameinfo = None

    def _flush_previous_stage(self):
        """Write every remaining previous-stage name as removed."""
        while self._previous_stage_nameinfo is not None:
            self._write_removed_previous_stage_nameinfo()

    def _write_removed_previous_stage_nameinfo(self):
        """Write the current previous-stage name as removed and advance."""
        nameinfo = self._previous_stage_nameinfo
        # Names already removed at an earlier stage are passed through as is.
        if not nameinfo.is_removed():
            nameinfo.remove_at_stage(self._opts.stage_name or '')
        self._write_nameinfo(nameinfo)
        self._read_next_previous_stage_nameinfo()

    def _extract_from_file(self, file_):
        """Process every tagged name in the lines of ``file_``."""
        # NOTE(review): line numbers restart at 1 for each input file, so
        # comparison with --previous-stage-output presumably only works
        # for a single input file.
        for linenr, line in enumerate(file_, 1):
            # Characters taken by tags so far on this line; subtracted to
            # get offsets in the untagged text.
            tag_char_count = 0
            for match in self._tagged_re.finditer(line):
                tag_char_count = self._process_match(match, line, linenr,
                                                     tag_char_count)

    def _process_match(self, match, line, linenr, tag_char_count):
        """Output one matched name; return the updated tag char count."""
        nameinfo, tag_char_count = self._extract_match_info(match, linenr,
                                                            tag_char_count)
        if self._read_previous_stage:
            self._compare_with_previous_stage_nameinfo(nameinfo)
        if nameinfo.equals(self._previous_stage_nameinfo):
            # Unchanged since the previous stage: keep its record as is.
            self._write_nameinfo(self._previous_stage_nameinfo)
            self._read_next_previous_stage_nameinfo()
        else:
            if self._opts.stage_name:
                nameinfo.add_at_stage(self._opts.stage_name)
            self._nameinfo_add_extra(nameinfo, match, line)
            self._write_nameinfo(nameinfo)
        return tag_char_count

    def _nameinfo_add_extra(self, nameinfo, match, line):
        """Add the optional extra field (whole line or context tokens)."""
        if self._opts.show_lines:
            # Drop the line terminator: format() adds its own newline, and
            # keeping it would emit a spurious blank output line.
            rest = [line.rstrip('\r\n')]
        elif self._opts.show_context_tokens:
            rest = [self._format_context(match, line)]
        else:
            rest = []
        nameinfo.extend_rest(rest)

    def _extract_match_info(self, match, linenr, tag_char_count):
        """Return ``(NameInfo, new tag_char_count)`` for ``match``."""
        match_pos = self._get_match_pos(match, tag_char_count)
        match_type = self._format_type(match)
        match_name = self._format_name(match)
        nameinfo = NameInfo(linenr=linenr, offset=match_pos, type=match_type,
                            name=match_name)
        # Add the lengths of the opening and closing tags of this match.
        tag_char_count += ((match.start('name') - match.end('pre'))
                           + match.start('post') - match.end('name'))
        return (nameinfo, tag_char_count)

    def _compare_with_previous_stage_nameinfo(self, nameinfo):
        """Write previous-stage names not kept in ``nameinfo`` as removed.

        Consumes previous-stage names positioned before ``nameinfo`` and
        those at the same position that differ in type or name.
        """
        # The None check is needed: the reader may run out inside the loop.
        while (self._previous_stage_nameinfo is not None
               and (self._previous_stage_nameinfo.precedes(nameinfo)
                    or (self._previous_stage_nameinfo.has_same_pos(nameinfo)
                        and not self._previous_stage_nameinfo.equals(
                            nameinfo)))):
            self._write_removed_previous_stage_nameinfo()

    def _write_nameinfo(self, nameinfo):
        """Write ``nameinfo`` to stdout."""
        sys.stdout.write(nameinfo.format(hide_name=self._opts.hide_names))

    def _get_match_pos(self, match, tag_char_count):
        """Return the 1-based offset of the match, disregarding tags."""
        return match.end('pre') - tag_char_count + 1

    def _format_type(self, match):
        """Return the space-separated, upper-cased type fields of match."""
        type_field_vals = match.groupdict('')
        if 'combtype' in type_field_vals:
            self._split_combined_type(type_field_vals)
        return ' '.join([type_field_vals.get(type_, '').upper()
                         for type_ in self._type_fields]).strip()

    def _split_combined_type(self, type_field_vals):
        """Split 3-letter chunks of 'combtype' into type/sbt(/ani) in place.

        Exactly three chunks give type, subtype and ani; otherwise the
        even-indexed chunks form the type and the odd ones the subtype,
        each joined with '/'.
        """
        combtype = type_field_vals.get('combtype', '')
        parts = [combtype[start:start + 3]
                 for start in range(0, len(combtype), 3)]
        if len(parts) == 3:
            type_field_vals.update({
                'type': parts[0], 'sbt': parts[1], 'ani': parts[2]})
        else:
            type_field_vals['type'] = '/'.join(parts[0::2])
            type_field_vals['sbt'] = '/'.join(parts[1::2])

    def _format_name_bare(self, match):
        """Return just the tagged name."""
        return match.group('name')

    def _format_name_contiguous(self, match):
        """Return the name in markers, with adjacent non-space text."""
        return (match.group('pre') + self._opts.match_markers[0]
                + match.group('name') + self._opts.match_markers[1]
                + match.group('post'))

    def _format_context(self, match, line):
        """Return the marked name with surrounding context tokens."""
        return (self._find_preceding_tokens(line, match.start('pre'))
                + self._format_name_contiguous(match)
                + self._find_following_tokens(line, match.end('post'))).strip()

    def _find_preceding_tokens(self, line, endpos):
        """Return the context tokens before position ``endpos``."""
        return self._find_adjacent_tokens(line, endpos - 1, 0, -1, '>', '<')

    def _find_following_tokens(self, line, startpos):
        """Return the context tokens from position ``startpos`` on."""
        return self._find_adjacent_tokens(
            line, startpos, len(line) - 1, 1, '<', '>')

    def _find_adjacent_tokens(self, line, startpos, endpos, step, tagbegin,
                              tagend):
        """Return text of up to show_context_tokens tokens next to startpos.

        Scans ``line`` from ``startpos`` towards ``endpos`` in direction
        ``step`` (1 or -1).  A token boundary is counted at each whitespace
        character whose previously scanned neighbour is not whitespace;
        text between ``tagbegin`` and ``tagend`` (markup as seen in scan
        direction) is skipped when counting.
        """
        if startpos < 0 or startpos > len(line) - 1:
            return ''
        pos = startpos
        tokencount = 0
        intag = False
        while pos != endpos and tokencount < self._opts.show_context_tokens:
            pos += step
            if intag:
                if line[pos] == tagend:
                    intag = False
            elif line[pos] == tagbegin:
                intag = True
            elif line[pos].isspace() and not line[pos - step].isspace():
                tokencount += 1
        if step < 0:
            result = line[pos:startpos+1]
        elif tokencount < self._opts.show_context_tokens:
            # Reached the end of the line first: take the rest of it.
            # line[startpos:pos] would drop the last character, which is a
            # real one when the line has no trailing newline.
            result = line[startpos:]
        else:
            result = line[startpos:pos]
        if self._opts.remove_context_tags:
            result = re.sub(r'<.*?>', '', result)
        return result

    def _read_names_from_file(self, fname):
        """Yield a NameInfo for each non-blank line of UTF-8 file fname."""
        with codecs.open(fname, 'r', encoding='utf-8') as f:
            for line in f:
                # NameInfo cannot parse blank lines (e.g. those left by
                # older --show-lines output); skip them.
                if line.strip():
                    yield NameInfo(line=line)
        # No explicit StopIteration: under PEP 479 raising it inside a
        # generator becomes RuntimeError; falling off the end suffices.



def getopts(argv=None):
    """Parse the command line.

    Args:
        argv: list of arguments to parse, without the program name.  None
            (the default) parses ``sys.argv[1:]``, as optparse does.

    Returns:
        ``(opts, args)`` as returned by ``OptionParser.parse_args``, except
        that ``opts.match_markers`` is converted to a list [LEFT, RIGHT].
    """
    usage = """%prog: [options] [input file ...] > output
Extract tagged names from named entity recognizer output

Input files contain named entities tagged with XML elements. The elements may
specify the type of the name in the tag name (Pmatch-style; --tags=plaintags)
or in tag name and attributes (--tags=attributes).

The output lists the names in the input along with their position and type.
The default output fields are line number, offset, name type, name. Offset
is the offset of the name in the original input, disregarding name tags.
An additional field may contain the whole input line (--show-lines) or tokens
preceding and following the name (--show-context-tokens)."""
    optparser = OptionParser(usage=usage)
    optparser.add_option(
        '--tags', type='choice',
        choices=['attributes', 'plaintags'], default='attributes',
        help=('assume input tags of TYPE where TYPE is "attributes" or'
              ' "plaintags" (default: %default)'), metavar='TYPE')
    optparser.add_option(
        '--show-contiguous-words', action='store_true',
        help=('show words immediately preceding or following tagged names,'
              ' without intervening spaces'))
    optparser.add_option(
        '--show-lines', action='store_true',
        help='show an extra field with whole input lines')
    optparser.add_option(
        '--show-tag-names', action='store_true',
        help='show main tag names in addition to type and subtype')
    optparser.add_option(
        '--show-context-tokens', type='int',
        help=('show an extra field with NUM tokens preceding and following the'
              ' marked name'), metavar='NUM')
    optparser.add_option(
        '--hide-names', action='store_true',
        help='do not show names, only the positions and possibly lines')
    optparser.add_option(
        '--remove-context-tags', action='store_true',
        help='remove tags around context tokens')
    optparser.add_option(
        '--match-markers', default='<<,>>',
        help=('precede marked names with the string LEFT, follow with RIGHT'
              ' (default: %default)'), metavar='LEFT,RIGHT')
    optparser.add_option(
        '--stage-name',
        help=('Add NAME as the stage name for the names added at this'
              ' stage, compared with the names in the file specified with'
              ' --previous-stage-output'), metavar='NAME')
    optparser.add_option(
        '--previous-stage-output', '--previous-stage-file',
        help=('Compare the names extracted with the previous stage output'
              ' in FILE'), metavar='FILE')
    (opts, args) = optparser.parse_args(argv)
    # LEFT and RIGHT may be separated by a comma and/or whitespace.
    opts.match_markers = re.split(r'[\s,]+', opts.match_markers, maxsplit=2)
    if len(opts.match_markers) == 1:
        # A single marker is used on both sides.  Repeat it as a list item:
        # repeating the string (2 * 'xx' -> 'xxxx') would make the markers
        # its first and second characters instead.
        opts.match_markers = 2 * [opts.match_markers[0]]
    return (opts, args)


def main():
    """Extract names from the files given on the command line, or stdin.

    Standard input, output and error are wrapped to read/write UTF-8.
    """
    input_encoding = output_encoding = 'utf-8'
    (opts, args) = getopts()
    # The codecs reader/writer must wrap byte streams.  Python 3 text
    # streams expose theirs as .buffer; Python 2 file objects are byte
    # streams already (and have no .buffer), so they are used unchanged.
    sys.stdin = codecs.getreader(input_encoding)(
        getattr(sys.stdin, 'buffer', sys.stdin))
    sys.stdout = codecs.getwriter(output_encoding)(
        getattr(sys.stdout, 'buffer', sys.stdout))
    sys.stderr = codecs.getwriter(output_encoding)(
        getattr(sys.stderr, 'buffer', sys.stderr))
    extractor = NameExtractor(opts)
    extractor.extract_names(args or ['-'])


if __name__ == "__main__":
    main()
