#!/usr/bin/env python3

# Copyright © 2010, 2012 marmuta <marmvta@gmail.com>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import random
from optparse import OptionParser

import matplotlib.pyplot as plt

import pypredict
from pypredict import (timeit, read_corpus, tokenize_text, split_sentences,
                       DynamicModelKN, CachedDynamicModel,
                       overlay, linint, loglinint)


class ModelConfig:
    """Result container for one evaluated model configuration.

    Every instance starts with three independent, empty result lists.
    The analysis functions of this script attach further attributes
    (smoothing, merge, label, marker, color, style) after construction.
    """

    def __init__(self):
        # One fresh list per measurement series; never shared between
        # attributes or instances.
        self.entropies, self.perplexities, self.ksrs = [], [], []


def main():
    """Parse the command line and run the requested analysis.

    The first positional argument selects the command (case-insensitive):
    "merging" expects three file arguments and calls analyze_merging(),
    "caching" expects two and calls analyze_caching(). The model order is
    taken from the -o/--order option and defaults to 3.

    On a missing command, too few arguments or an unknown command, a
    diagnostic is written to stderr and the process exits with status 1.
    """
    # The usage text is a column-0 literal so that its layout is printed
    # exactly as written here.
    parser = OptionParser(usage=  # noqa: flake8
"""Usage: %prog [options] merging <training text 1> <training text 2> <testing text>
           <training text 1> trains the base language model
           <training text 2> trains the second language model
           <testing text> text for entropy and ksr calculations

       %prog [options] caching <base model> <testing text>
           <base model>   the static base language model
           <testing text> text for entropy and ksr calculations,
                          incrementally trains the second language model""")  # noqa: flake8
    parser.add_option(
        "-o", "--order", type="int", dest="order", default=3,
        help="order of the language model, defaults to  %default")
    options, args = parser.parse_args()

    def exit_with_usage():
        # Failed invocations report on stderr, keeping stdout clean.
        parser.print_usage(sys.stderr)
        sys.exit(1)

    if not args:
        exit_with_usage()

    order = options.order
    command = args[0].lower()

    if command == "merging":
        if len(args) < 4:
            exit_with_usage()
        analyze_merging(args[1], args[2], args[3], order)
    elif command == "caching":
        if len(args) < 3:
            exit_with_usage()
        analyze_caching(args[1], args[2], order)
    else:
        print("unknown command '%s', exiting" % (args[0],), file=sys.stderr)
        sys.exit(1)


def analyze_merging(training1, training2, testing, order):
    """Plot entropy and ksr of two merged models across a weight sweep.

    One DynamicModelKN is trained per training corpus. For every pairing of
    a smoothing method with a merge method, the merged model is evaluated
    on the testing text at 20 evenly spaced weightings of the two models.
    The figure is redrawn after each weighting and is finally shown in a
    blocking window.

    Args:
        training1: path of the corpus that trains the base model.
        training2: path of the corpus that trains the second model.
        testing: path of the text used for the entropy and ksr measurements.
        order: order of the language models.

    Raises:
        ValueError: if a configuration names an unknown merge method.
    """
    models = []
    for i, filename in enumerate([training1, training2]):
        with timeit("creating model %d from '%s'" % (i + 1, filename)):
            text = read_corpus(filename)
            tokens, _spans = tokenize_text(text)
            model = DynamicModelKN(order)
            model.learn_tokens(tokens)
            models.append(model)

    with timeit("tokenizing '%s'" % (testing,)):
        testing_text = read_corpus(testing)
        testing_sentences, _spans = split_sentences(testing_text)
        testing_tokens, _spans = tokenize_text(testing_text)

    # model parameters
    smoothings = ["witten-bell", "abs-disc", "kneser-ney"]
    merges = ["overlay", "linint", "loglinint"]

    markers = ["o", "^", "*"]
    colors = ["r", "g", "b", "c", "m", "k"]

    # One configuration per (smoothing, merge) pair. The marker cycles with
    # the merge method, the color changes with the smoothing method.
    configurations = []
    for smoothing in smoothings:
        for merge in merges:
            config = ModelConfig()
            config.smoothing = smoothing
            config.merge = merge
            config.label = "%s, %s" % (smoothing, merge)
            config.marker = markers[len(configurations) % len(markers)]
            config.color = colors[len(configurations) // len(markers)]
            config.style = config.color + config.marker
            configurations.append(config)

    def report_progress(i, n, c, p):
        # Progress callback handed to pypredict.ksr; prints "<i + 1>/<n>".
        sys.stdout.write("%d/%d\n" % (i + 1, n))

    # x values of both plots: the weight given to the base model
    base_weights = []

    iterations = 20
    plt.ion()  # interactive mode on, so intermediate draws don't block
    for iteration in range(iterations):

        # evenly spread weights; the two always sum to 1
        base_weight = iteration / float(iterations - 1)
        weights = [base_weight, 1.0 - base_weight]
        base_weights.append(base_weight)

        for i, config in enumerate(configurations):
            print("iteration %d/%d, config %d/%d" %
                  (iteration + 1, iterations, i + 1, len(configurations)))

            # setup model to analyze
            for m in models:
                m.smoothing = config.smoothing

            if config.merge == "overlay":
                model = overlay(models)
            elif config.merge == "linint":
                model = linint(models, weights)
            elif config.merge == "loglinint":
                model = loglinint(models, weights)
            else:
                # Fail loudly instead of silently reusing a stale model.
                raise ValueError("unknown merge method '%s'" %
                                 (config.merge,))

            # get statistics
            entropy, perplexity = pypredict.entropy(model, testing_tokens,
                                                    order)
            # NOTE(review): the meaning of the literal 10 is defined by
            # pypredict.ksr and is not visible here -- confirm.
            ksr = pypredict.ksr(model, None, testing_sentences, 10,
                                report_progress)

            config.entropies.append(entropy)
            config.perplexities.append(perplexity)
            config.ksrs.append(ksr)

            print("entropy=%10f, perplexity=%10f, ksr=%6.2f, weights=" %
                  (entropy, perplexity, ksr), weights)

        # plot everything measured so far
        plt.figure(1)
        plt.clf()

        plt.subplot(211)
        for config in configurations:
            plt.plot(base_weights, config.entropies, config.style + "-",
                     label=config.label)
        plt.xlim(0, 1)
        plt.xlabel('base model weight')
        plt.ylabel('entropy [bit/word]')

        plt.subplot(212)
        handles = []
        labels = []
        for config in configurations:
            # plot() returns a list of lines; the legend needs the single
            # line object itself, hence the unpacking.
            line, = plt.plot(base_weights, config.ksrs, config.style + "-",
                             label=config.label)
            handles.append(line)
            labels.append(config.label)
        plt.xlim(0, 1)
        plt.xlabel('base model weight')
        plt.ylabel('ksr [%]')

        # loc is passed by keyword; newer matplotlib releases reject a third
        # positional argument.
        plt.figlegend(handles, labels, loc='upper right')
        plt.subplots_adjust(left=0.07, top=0.99, right=0.72, bottom=0.08)
        plt.draw()

    # Leave interactive mode so that show() blocks until the window closes.
    plt.ioff()
    plt.show()  # blocks; allows for interaction, saving images


def analyze_caching(base_model, testing, order):
    """Plot ksr of a cached model while sweeping its recency parameters.

    A DynamicModelKN and a CachedDynamicModel are combined, and the ksr on
    the testing text is measured while one recency parameter at a time is
    swept over its range in 20 evenly spaced steps; all other parameters
    stay at their defaults. Each parameter gets its own subplot. The figure
    is redrawn after each step and is finally shown in a blocking window.

    Args:
        base_model: name of the base model; currently only used in the
            timing message, because loading is disabled below.
        testing: path of the text used for the ksr measurements.
        order: order of the language models.

    Raises:
        ValueError: if a configuration names an unknown merge method.
    """
    models = []

    with timeit("loading base model '%s'" % (base_model,)):
        model = DynamicModelKN(order)
        # NOTE(review): loading is disabled, so the base model starts out
        # untrained -- confirm whether this is intended.
        # model.load(base_model)
        models.append(model)

    # the model that pypredict.ksr receives as its second argument
    learn_model = CachedDynamicModel(order)
    models.append(learn_model)

    with timeit("tokenizing '%s'" % (testing,)):
        testing_text = read_corpus(testing)
        testing_sentences, _spans = split_sentences(testing_text)

    # model parameters
    smoothings = ["abs-disc"]
    # merges = ["overlay", "linint"]
    merges = ["overlay"]
    # recency_smoothings = ["jelinek-mercer", "witten-bell"]
    recency_smoothings = ["jelinek-mercer"]

    markers = ["o", "^", "*"]
    colors = ["r", "g", "b", "c", "m", "k"]

    # attribute name, list index (None for scalars), default, min, max
    paramdesc = [["recency_ratio", None, 0.58, 0, 1],
                 ["recency_halflife", None, 80, 1, 120],
                 ["recency_lambdas", 0, 0.3, 0, 1],
                 ["recency_lambdas", 1, 0.3, 0, 1],
                 ["recency_lambdas", 2, 0.3, 0, 1],
                 ]

    class Parameter:
        """One swept attribute of the learning model."""

    parameters = []
    for pd in paramdesc:
        p = Parameter()
        p.name, p.index, p.default, p.min, p.max = pd
        p.xvalues = []
        p.label = p.name + (str(p.index) if p.index is not None else "") + \
            " (def. " + str(p.default) + ")"
        parameters.append(p)

    configurations = []
    for smoothing in smoothings:
        for merge in merges:
            for recency_smoothing in recency_smoothings:
                config = ModelConfig()
                config.smoothing = smoothing
                config.merge = merge
                config.recency_smoothing = recency_smoothing
                config.label = "%s (frequency), %s (recency), %s" % \
                    (smoothing, recency_smoothing, merge)
                config.marker = markers[len(configurations) % len(markers)]
                config.color = colors[len(configurations) // len(markers)]
                config.style = config.color + config.marker
                # one ksr series per swept parameter
                config.ksrs = [[] for x in parameters]
                configurations.append(config)

    def report_progress(i, n, c, p):
        # Progress callback handed to pypredict.ksr; prints "<i + 1>/<n>".
        sys.stdout.write("%d/%d\n" % (i + 1, n))

    iterations = 20
    plt.ion()  # interactive mode on, so intermediate draws don't block
    for iteration in range(iterations):

        # evenly spread position within each parameter's range, 0..1
        fraction = iteration / float(iterations - 1)
        # only needed by the interpolating merge methods
        weights = [fraction, 1.0 - fraction]

        for dimension, param in enumerate(parameters):
            swept_value = param.min + (param.max - param.min) * fraction
            param.xvalues.append(swept_value)

            for i, config in enumerate(configurations):
                print("iteration %d/%d, config %d/%d" %
                      (iteration + 1, iterations, i + 1, len(configurations)))

                # setup model to analyze
                for m in models:
                    m.smoothing = config.smoothing

                if config.merge == "overlay":
                    model = overlay(models)
                elif config.merge == "linint":
                    model = linint(models, weights)
                elif config.merge == "loglinint":
                    model = loglinint(models, weights)
                else:
                    # Fail loudly instead of silently reusing a stale model.
                    raise ValueError("unknown merge method '%s'" %
                                     (config.merge,))

                learn_model.clear()
                learn_model.recency_smoothing = config.recency_smoothing
                # learn_model.lambdas = [1]

                # Only the parameter of the current dimension takes the
                # swept value; every other one is set to its default.
                for d, p in enumerate(parameters):
                    value = swept_value if d == dimension else p.default
                    if p.index is None:
                        print(p.name, value)
                        setattr(learn_model, p.name, value)
                    else:
                        # NOTE(review): the appended 0 presumably keeps
                        # a[p.index] in range; how the model treats the
                        # extra element is not visible here -- confirm.
                        a = list(getattr(learn_model, p.name)) + [0]
                        a[p.index] = value
                        print(p.name, a)
                        setattr(learn_model, p.name, a)

                # get statistics
                # NOTE(review): the meaning of the literal 10 is defined by
                # pypredict.ksr and is not visible here -- confirm.
                ksr = pypredict.ksr(model, learn_model, testing_sentences,
                                    10, report_progress)
                # learn_model.save("out.lm")
                config.ksrs[dimension].append(ksr)

                print("ksr=%6.2f" % (ksr,))

        # plot everything measured so far, one subplot per parameter
        plt.figure(1)
        plt.clf()
        for config in configurations:
            for d, p in enumerate(parameters):
                plt.subplot(3, 2, d + 1)
                plt.plot(p.xvalues, config.ksrs[d], config.style + "-",
                         label=config.label)
                plt.xlim(p.min, p.max)
                # plt.ylim(0, 35)
                plt.xlabel(p.label)
                plt.ylabel('ksr [%]')

        plt.subplots_adjust(left=0.09, right=0.98, top=0.89, bottom=0.08,
                            hspace=0.38)
        plt.gcf().suptitle('Caching', fontsize=16)
        plt.draw()

    # Leave interactive mode so that show() blocks until the window closes.
    plt.ioff()
    plt.show()  # blocks; allows for interaction, saving images


# Run the command line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()

