#!/usr/bin/env seiscomp-python

################################################################################
# Copyright (C) 2012-2013, 2020 Helmholtz-Zentrum Potsdam - Deutsches GeoForschungsZentrum GFZ
#
# tabinvmodifier -- Tool for inventory modification using nettab files.
#
# This software is free software and comes with ABSOLUTELY NO WARRANTY.
#
# Author:  Marcelo Bianchi
# Email:   mbianchi@gfz-potsdam.de
################################################################################

from __future__ import print_function
import os
import sys
import datetime, time
from nettab.lineType import Nw, Sa, Na, Ia
from nettab.basesc3 import sc3
import seiscomp.datamodel, seiscomp.io, seiscomp.client, seiscomp.core, seiscomp.logging

class Rules(object):
    """Modification rules read from a nettab rules file.

    Network-scoped rules live in ``attributes``, keyed by the
    ``(code, start, end)`` tuple of their ``Nw`` line; each entry collects
    the ``Sa`` and ``Na`` items registered under that key. Instrument rules
    (``Ia`` items) are kept in the flat list ``iattributes``.
    """

    def __init__(self, relaxed = False):
        # Forwarded to each Sa item's match() by getStationAttributes().
        self.relaxed = relaxed
        # {(code, start, end): {"Sa": [Sa, ...], "Na": [Na, ...]}}
        self.attributes = {}
        # [Ia, ...] matched against instrument names/types.
        self.iattributes = []

    @staticmethod
    def _overlaps(pstart, pend, cstart, cend):
        """Return True if epoch [pstart, pend) overlaps epoch [cstart, cend).

        A false-y end (e.g. None) marks an open-ended epoch.
        """
        return (not pend or pend > cstart) and (not cend or pstart < cend)

    def _items(self, key, kind):
        """Return the list of *kind* ("Sa" or "Na") items stored under *key*.

        :raises Exception: if *key* is None (no Nw line preceded the item)
            or is not a registered network key.
        """
        if key is None:
            # Formatting None into "%s/%s-%s" would itself raise TypeError,
            # hiding the real problem; report it explicitly instead.
            raise Exception("%s line found before any Nw line" % kind)
        entry = self.attributes.get(key)
        if entry is None:
            # Raised outside an except clause so the traceback is not chained
            # to an internal KeyError.
            raise Exception("Nw %s/%s-%s not found in Ruleset" % key)
        return entry[kind]

    def Nw(self, nw):
        """Register a network rule for *nw* and return its key.

        :param nw: parsed Nw line; its code, start and end form the key.
        :returns: the ``(code, start, end)`` key.
        :raises Exception: if the same key is already registered.
        """
        key = (nw.code, nw.start, nw.end)
        if key in self.attributes:
            raise Exception("Nw (%s/%s-%s) is already defined." % key)
        self.attributes[key] = {"Sa": [], "Na": []}
        return key

    def Sa(self, key, sa):
        """Attach a station/location/channel attribute rule to network *key*."""
        self._items(key, "Sa").append(sa)

    def Na(self, key, na):
        """Attach a network attribute rule to network *key*."""
        self._items(key, "Na").append(na)

    def Ia(self, ia):
        """Register an instrument attribute rule."""
        self.iattributes.append(ia)

    def findKey(self, ncode, nstart, nend):
        """Return the first registered key with code *ncode* whose epoch
        overlaps [nstart, nend), or None if there is none."""
        for key in self.attributes:
            (code, start, end) = key
            if code == ncode and self._overlaps(start, end, nstart, nend):
                return key
        return None

    def getInstrumentsAttributes(self, elementId, elementType):
        """Return {Key: Value} of every Ia rule matching the instrument.

        Later matching rules override earlier ones for the same Key.
        """
        return {item.Key: item.Value for item in self.iattributes
                if item.match(elementId, elementType)}

    def getNetworkAttributes(self, key):
        """Return {Key: Value} of the Na rules of network *key*.

        Later rules override earlier ones for the same Key.
        """
        return {item.Key: item.Value for item in self.attributes[key]["Na"]}

    def getStationAttributes(self, key, ncode, scode, lcode, ccode, start, end):
        """Return {Key: Value} of the Sa rules of network *key* that match.

        *ncode* is accepted for interface symmetry but is not used; the
        network was already selected via *key*. Later matching rules override
        earlier ones for the same Key.
        """
        return {item.Key: item.Value for item in self.attributes[key]["Sa"]
                if item.match(scode, lcode, ccode, start, end, self.relaxed)}

class InventoryModifier(seiscomp.client.Application):
    """Apply nettab-style attribute rules to a SeisComP inventory.

    The modified inventory is either written to an XML file (--output) or
    sent to the messaging system as notifiers.
    """

    # nettab line types that are recognised but cannot be applied here.
    _UNSUPPORTED_TYPES = ("Sg", "Sr", "Se", "Dl", "Cl", "Ff", "If", "Pz")

    def __init__(self, argc, argv):
        seiscomp.client.Application.__init__(self, argc, argv)
        self.setMessagingUsername("iModify")

        self.rules = None
        self.relaxed = False
        self.outputFile = None

    def _digest(self, tabFilename, rules = None):
        """Parse a nettab rules file into a Rules object.

        :param tabFilename: path of the rules file.
        :param rules: existing Rules to extend; a new one (honouring
            self.relaxed) is created when not given.
        :returns: the populated Rules object.
        :raises Exception: for a missing file, a line without ':', an Sa/Na
            line before any Nw line, or an unsupported line type.
        """
        if not tabFilename or not os.path.isfile(tabFilename):
            raise Exception("Supplied filename is invalid.")

        if not rules:
            rules = Rules(self.relaxed)

        # Key of the most recent Nw line; Sa/Na lines attach to it.
        key = None
        # The context manager closes the file even on a parse error. It also
        # avoids the old "finally: if fd" UnboundLocalError when open() failed.
        with open(tabFilename) as fd:
            for line in fd:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                if ":" not in line:
                    raise Exception("Invalid line format '%s'" % line)
                (Type, Content) = line.split(":", 1)

                if Type == "Nw":
                    key = rules.Nw(Nw(Content))
                elif Type in ("Na", "Sa"):
                    if key is None:
                        raise Exception("%s line found before any Nw line" % Type)
                    if Type == "Na":
                        rules.Na(key, Na(Content))
                    else:
                        rules.Sa(key, Sa(Content))
                elif Type == "Ia":
                    rules.Ia(Ia(Content))
                elif Type in self._UNSUPPORTED_TYPES:
                    raise Exception("Type not supported.")
                # Any other type is ignored.
        return rules

    def validateParameters(self):
        """Check the command line and load the rules file.

        :returns: False (after printing a reason to stderr) when the options
            are inconsistent or the rules file cannot be read/parsed.
        """
        outputFile = None
        rulesFile  = None

        if self.commandline().hasOption("rules"):
            rulesFile = self.commandline().optionString("rules")

        if self.commandline().hasOption("output"):
            outputFile = self.commandline().optionString("output")

        if self.commandline().hasOption("relaxed"):
            self.relaxed = True

        if self.commandline().hasOption("inventory-db") and outputFile is None:
            print("Cannot send notifiers when loading inventory from file.", file=sys.stderr)
            return False

        if self.commandline().unrecognizedOptions():
            print("Invalid options: ", end=' ', file=sys.stderr)
            for i in self.commandline().unrecognizedOptions():
                print(i, end=' ', file=sys.stderr)
            print("", file=sys.stderr)
            return False

        if not rulesFile:
            print("No rule file was supplied for processing", file=sys.stderr)
            return False

        if not os.path.isfile(rulesFile):
            argv0 = os.path.basename(self.arguments()[0])
            print("%s: %s: No such file or directory" % (argv0, rulesFile), file=sys.stderr)
            return False

        if self.commandline().hasOption("inventory-db"):
            self.setDatabaseEnabled(False, False)
            self.setMessagingEnabled(False)

        # Report rule-file problems as a clean error, not a traceback.
        try:
            self.rules = self._digest(rulesFile, self.rules)
        except Exception as e:
            print("%s: %s" % (rulesFile, e), file=sys.stderr)
            return False
        self.outputFile = outputFile
        return True

    def createCommandLineDescription(self):
        """Add the Rules (--rules, --relaxed) and Dump (--output) options."""
        seiscomp.client.Application.createCommandLineDescription(self)

        self.commandline().addGroup("Rules")
        # The rules file is parsed by _digest() as nettab "Type:Content" lines, not XML.
        self.commandline().addStringOption("Rules", "rules,r", "Input rules filename (nettab format)")
        self.commandline().addOption("Rules", "relaxed,e", "Relax rules for matching NSLC items")

        self.commandline().addGroup("Dump")
        self.commandline().addStringOption("Dump", "output,o", "Output XML filename")

    def initConfiguration(self):
        """Enable stderr logging, database, messaging and inventory loading."""
        value = seiscomp.client.Application.initConfiguration(self)
        self.setLoggingToStdErr(True)
        self.setDatabaseEnabled(True, True)
        self.setMessagingEnabled(True)
        self.setLoadInventoryEnabled(True)
        return value

    def send(self, *args):
        """Send a message over the connection, retrying every second until it succeeds."""
        while not self.connection().send(*args):
            seiscomp.logging.warning("send failed, retrying")
            time.sleep(1)

    def send_notifiers(self, group):
        """Send all pending notifiers to *group* in batches of up to 100.

        :returns: the total number of notifiers sent (0 if none were pending).
        """
        Nsize = seiscomp.datamodel.Notifier.Size()

        if Nsize > 0:
            seiscomp.logging.info("trying to apply %d change%s" % (Nsize,"s" if Nsize != 1 else "" ))
        else:
            seiscomp.logging.info("no changes to apply")
            return 0

        Nmsg = seiscomp.datamodel.Notifier.GetMessage(True)
        it = Nmsg.iter()
        msg = seiscomp.datamodel.NotifierMessage()

        maxmsg = 100
        sent = 0
        mcount = 0

        try:
            try:
                while it.get():
                    msg.attach(seiscomp.datamodel.Notifier_Cast(it.get()))
                    mcount += 1
                    if msg and mcount == maxmsg:
                        sent += mcount
                        seiscomp.logging.debug("sending message (%5.1f %%)" % (sent / float(Nsize) * 100.0))
                        self.send(group, msg)
                        msg.clear()
                        mcount = 0
                    next(it)
            except StopIteration:
                # The iterator may signal its end this way.
                pass
            except Exception as e:
                # Keep best-effort behaviour (the finally still flushes what
                # was collected), but do not hide the failure.
                seiscomp.logging.error("failed while collecting notifiers: %s" % e)
        finally:
            if msg.size():
                seiscomp.logging.debug("sending message (%5.1f %%)" % 100.0)
                self.send(group, msg)
                msg.clear()
                sent += mcount
        seiscomp.logging.info("done")
        return sent

    @staticmethod
    def _loop(obj, count):
        """Return [obj(0), ..., obj(count - 1)] (indexed child accessor)."""
        return [ obj(i) for i in range(count) ]

    @staticmethod
    def _collect(obj):
        """Return (code, start, end) of an inventory object.

        start/end are converted to datetime through a string that has no
        fractional seconds, so sub-second precision is dropped. end is None
        when reading it fails (presumably because no end time is set).
        """
        fmt = "%Y %m %d %H %M %S"
        code = obj.code()
        start = datetime.datetime.strptime(obj.start().toString(fmt), fmt)
        try:
            end = datetime.datetime.strptime(obj.end().toString(fmt), fmt)
        except Exception:
            end = None
        return (code, start, end)

    @staticmethod
    def _modifyInventory(mode, obj, att):
        """Apply the attribute rules *att* to inventory object *obj*.

        :param mode: kind of *obj* ("network", "station", "location",
            "channel", "sensor" or "datalogger"). It is passed to
            sc3._findValidOnes() to get the allowed attributes and their
            validators.
        :param obj: the inventory object to modify. update() is called on it
            when *att* is non-empty.
        :param att: {attribute name: raw value}. "Comment" and "Pid" entries
            are added as comments; other names go through their validator and
            then obj.set<Name>(). Unknown names are reported on stderr.
        """
        if not att:
            return
        valid = sc3._findValidOnes(mode)

        # FIXME: duplicates sc3._fillSc3() in basesc3.py. Comment ids also
        # start at 0 regardless of comments already on obj, so they may
        # collide with existing ones.
        commentNum = 0
        for (k, p) in att.items():
            if k == 'Comment':
                if p.startswith('Grant'):
                    # 2020: These belong in DOI metadata, not here.
                    continue
                c = seiscomp.datamodel.Comment()
                c.setText(p)
                c.setId(str(commentNum))
                commentNum += 1
                obj.add(c)
                continue

            if k == 'Pid':
                # Expected as "type:value"; stored as an FDSNXML identifier comment.
                (typ, sep, val) = p.partition(':')
                if not sep:
                    print("Modifying %s: invalid Pid '%s' (expected 'type:value')" % (mode, p), file=sys.stderr)
                    continue
                seiscomp.logging.debug("Adding Pid as comment %s" % p)
                c = seiscomp.datamodel.Comment()
                c.setText('{"type":"%s", "value":"%s"}' % (typ.upper(), val))
                c.setId('FDSNXML:Identifier/' + str(commentNum))
                commentNum += 1
                obj.add(c)
                continue

            try:
                value = valid['attributes'][k]['validator'](p)
            except KeyError:
                hint = ''
                # string.lowercase no longer exists in Python 3; islower() is
                # portable, and k[:1] also copes with an empty key.
                if k[:1].islower():
                    hint = " (try '%s' instead)" % (k[0].upper() + k[1:])
                print('Modifying %s: \'%s\' is not a valid key%s' % (mode, k, hint), file=sys.stderr)
                continue
            getattr(obj, 'set' + k)(value)
        obj.update()

    def run(self):
        """Apply the loaded rules to the inventory.

        Networks are matched by code and epoch overlap (Rules.findKey).
        Stations, locations and channels get the matching Sa attributes of
        that network rule. Sensors and dataloggers get the matching Ia
        attributes. Without an output file, changes are recorded as
        notifiers, which done() then sends.

        :returns: False if no rules or no inventory are available, else True.
        """
        rules = self.rules
        iv = seiscomp.client.Inventory.Instance().inventory()

        if not rules:
            return False

        if not iv:
            return False

        seiscomp.logging.debug("Loaded %d networks" % iv.networkCount())
        if self.outputFile is None:
            seiscomp.datamodel.Notifier.Enable()
            self.setInterpretNotifierEnabled(True)

        for net in self._loop(iv.network, iv.networkCount()):
            (ncode, nstart, nend) = self._collect(net)
            key = rules.findKey(ncode, nstart, nend)
            if not key:
                continue
            att = rules.getNetworkAttributes(key)
            self._modifyInventory("network", net, att)
            seiscomp.logging.info("%s %s" % (ncode, att))
            for sta in self._loop(net.station, net.stationCount()):
                # "sstop" rather than "send" to avoid confusion with self.send().
                (scode, sstart, sstop) = self._collect(sta)
                att = rules.getStationAttributes(key, ncode, scode, None, None, sstart, sstop)
                self._modifyInventory("station", sta, att)
                if att:
                    seiscomp.logging.info(" %s %s" % (scode, att))
                for loc in self._loop(sta.sensorLocation, sta.sensorLocationCount()):
                    (lcode, lstart, lend) = self._collect(loc)
                    att = rules.getStationAttributes(key, ncode, scode, lcode, None, lstart, lend)
                    self._modifyInventory("location", loc, att)
                    if att:
                        seiscomp.logging.info("  %s %s" % (lcode, att))
                    for cha in self._loop(loc.stream, loc.streamCount()):
                        (ccode, cstart, cend) = self._collect(cha)
                        att = rules.getStationAttributes(key, ncode, scode, lcode, ccode, cstart, cend)
                        self._modifyInventory("channel", cha, att)
                        if att:
                            seiscomp.logging.info("   %s %s" % (ccode, att))

        for sensor in self._loop(iv.sensor, iv.sensorCount()):
            att = rules.getInstrumentsAttributes(sensor.name(), "Se")
            self._modifyInventory("sensor", sensor, att)

        for datalogger in self._loop(iv.datalogger, iv.dataloggerCount()):
            att = rules.getInstrumentsAttributes(datalogger.name(), "Dl")
            self._modifyInventory("datalogger", datalogger, att)

        return True

    def done(self):
        """Write the inventory to the output file, or send the notifiers."""
        if self.outputFile:
            ar = seiscomp.io.XMLArchive()
            # create() reports failure (e.g. unwritable path) via its return value.
            if ar.create(self.outputFile):
                ar.setFormattedOutput(True)
                ar.writeObject(seiscomp.client.Inventory.Instance().inventory())
                ar.close()
            else:
                seiscomp.logging.error("cannot create output file %s" % self.outputFile)
        else:
            self.send_notifiers("INVENTORY")
        seiscomp.client.Application.done(self)

if __name__ == "__main__":
    # Construct the application from the raw command line; its return value
    # becomes the process exit status.
    argv = sys.argv
    sys.exit(InventoryModifier(len(argv), argv)())
