#!/usr/bin/env python3
# -*- coding: utf-8 -*-
############################################################################
# Copyright (C) GFZ Potsdam                                                #
# All rights reserved.                                                     #
#                                                                          #
# GNU Affero General Public License Usage                                  #
# This file may be used under the terms of the GNU Affero                  #
# Public License version 3.0 as published by the Free Software Foundation  #
# and appearing in the file LICENSE included in the packaging of this      #
# file. Please review the following information to ensure the GNU Affero   #
# Public License version 3.0 requirements will be met:                     #
# https://www.gnu.org/licenses/agpl-3.0.html.                              #
############################################################################

import sys
import struct
import json
import datetime
import argparse
import zmq
import numpy as np
from plugin import Seedlink


# Version string appended to the program name by the --version option.
VERSION = "0.1 (2024.020)"


def readHeader(text, stationTemplate):
    """Parse a JSON header message and derive station codes and unpack format.

    Args:
        text: JSON document (str) providing the keys "nChannels",
            "nPackagesPerMessage", "dataType" and "roiTable". Every entry
            of "roiTable" provides "roiStart", "roiEnd" and "roiDec".
        stationTemplate: body of an f-string that is evaluated once per
            channel with the name ``channel`` bound to the channel number,
            e.g. "{channel:05d}".

    Returns:
        Tuple ``(header, stations, unpackFormat)`` where ``header`` is the
        decoded JSON object, ``stations`` is the list of station codes in
        roiTable order and ``unpackFormat`` is a ``struct`` format covering
        nChannels * nPackagesPerMessage values ("f" for dataType "float",
        "h" for anything else).

    Raises:
        Exception: if the number of channels enumerated from roiTable
            differs from header["nChannels"].
    """
    header = json.loads(text)
    sampleCode = "f" if header["dataType"] == "float" else "h"
    unpackFormat = "%d%s" % (header["nChannels"] * header["nPackagesPerMessage"],
                             sampleCode)

    # SECURITY: the template is evaluated as Python code (f-string), so it
    # must only ever come from a trusted source such as the operator's
    # command line. Never feed it from the network.
    # Compile once; the template does not change between channels.
    templateCode = compile("f'''" + stationTemplate + "'''",
                           "<station template>", "eval")

    stations = []

    for roi in header["roiTable"]:
        # roiEnd is inclusive, hence the + 1.
        for c in range(roi["roiStart"], roi["roiEnd"] + 1, roi["roiDec"]):
            stations.append(eval(templateCode, {"channel": c}))

    if len(stations) != header["nChannels"]:
        raise Exception("Invalid number of channels in roiTable: %d != %d" %
                        (len(stations), header["nChannels"]))

    return header, stations, unpackFormat


def flushData(seedlink, header, network, stations, startTime, packageData):
    """Send the buffered samples of every channel through ``seedlink``.

    Args:
        seedlink: object providing ``send_raw3(id, channel, time, correction,
            quality, samples)``.
        header: header dict; only "trustedTimeSource" is read here.
        network: network code used as prefix of every "<network>.<station>" ID.
        stations: station codes, one per data column.
        startTime: POSIX timestamp (seconds, UTC) of the first buffered sample.
        packageData: non-empty list of 2-D arrays (rows = samples,
            columns = channels) that are concatenated along the first axis.
    """
    # Naive UTC datetime, identical to the deprecated utcfromtimestamp().
    t = datetime.datetime.fromtimestamp(
        startTime, datetime.timezone.utc).replace(tzinfo=None)

    # Presumably a timing-quality percentage: 100 for a trusted clock,
    # 10 otherwise — confirm against the Seedlink plugin API.
    tq = 100 if header["trustedTimeSource"] else 10

    # Samples are sent as integers; astype(int) truncates toward zero.
    data = np.concatenate(packageData).astype(int)

    for i, station in enumerate(stations):
        seedlink.send_raw3("%s.%s" % (network, station), "X", t, 0, tq, data[:, i])


def main():
    """Receive data messages over ZeroMQ and forward them via Seedlink.

    The first message received is parsed as a JSON header. Every later
    message whose size equals bytesPerPackage * nPackagesPerMessage + 8 is
    treated as data: an 8-byte little-endian nanosecond timestamp followed
    by the samples. A message of any other size is tried as a new header.
    Samples are buffered and flushed once at least ``sample_rate`` packages
    are collected, when the spacing of consecutive timestamps no longer
    matches the sample rate, or before a new header is applied.

    Runs until the process is interrupted or an error propagates.
    """
    parser = argparse.ArgumentParser()

    parser.set_defaults(
        address="tcp://localhost:3333",
        sample_rate=100,
        gain=1.0,
        network="XX",
        station="{channel:05d}"
    )

    parser.add_argument("--version",
        action="version",
        version="%(prog)s " + VERSION
    )

    parser.add_argument("-a", "--address",
        help="ZeroMQ address (default %(default)s)"
    )

    parser.add_argument("-r", "--sample-rate",
        type=int,
        help="sample rate (default %(default)s)"
    )

    parser.add_argument("-g", "--gain",
        type=float,
        help="gain (default %(default)s)"
    )

    parser.add_argument("-n", "--network",
        help="network code (default %(default)s)"
    )

    parser.add_argument("-s", "--station",
        help="station code template (default %(default)s)"
    )

    parser.add_argument("plugin_id",
        help="plugin ID"
    )

    args = parser.parse_args()

    # A non-positive rate would trigger a flush of an empty buffer on the
    # very first data message.
    if args.sample_rate <= 0:
        parser.error("sample rate must be a positive integer")

    seedlink = Seedlink()
    sock = zmq.Context().socket(zmq.SUB)
    sock.connect(args.address)
    sock.setsockopt(zmq.SUBSCRIBE, b"")  # empty prefix: receive everything

    header, stations, unpackFormat = readHeader(sock.recv().decode("utf-8"), args.station)
    packageData = []
    nPackages = 0
    startTime = None
    timestamp = None

    while True:
        message = sock.recv()

        # Anything that is not exactly one data message is tried as a header.
        if len(message) != header["bytesPerPackage"] * header["nPackagesPerMessage"] + 8:
            try:
                # Buffered samples belong to the old layout; send them first.
                if nPackages > 0:
                    flushData(seedlink, header, args.network, stations, startTime, packageData)
                    packageData = []
                    nPackages = 0

                header, stations, unpackFormat = readHeader(message.decode("utf-8"),
                                                            args.station)
                print("%s: Parameters Changed. Updated header, ROIs and format" %
                      args.plugin_id, file=sys.stderr)

            except Exception:
                # Best effort: report and keep running with the old header.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                print("%s: Unknown msg. Size: %d" % (args.plugin_id, len(message)),
                      file=sys.stderr)

            continue

        prevTimestamp = timestamp
        timestamp = struct.unpack("<Q", message[:8])[0] * 1e-9  # ns -> s

        # NOTE(review): the rate is derived from the spacing of consecutive
        # message timestamps, so it matches only if messages are one sample
        # interval apart — confirm behaviour for nPackagesPerMessage > 1.
        discontinuity = False

        if nPackages > 0:
            interval = timestamp - prevTimestamp
            # Repeated or backward timestamps count as a discontinuity
            # instead of raising ZeroDivisionError.
            discontinuity = (interval <= 0 or
                             int(round(1 / interval)) != args.sample_rate)

        if nPackages >= args.sample_rate or discontinuity:
            flushData(seedlink, header, args.network, stations, startTime, packageData)
            packageData = []
            nPackages = 0

        if nPackages == 0:
            startTime = timestamp

        # One row per package, one column per channel.
        data = np.array(struct.unpack(unpackFormat, message[8:])) \
                .reshape((header["nPackagesPerMessage"], header["nChannels"]))

        # Gain is applied to float data only; integer data passes unchanged.
        if header["dataType"] == "float":
            data *= args.gain

        packageData.append(data)
        nPackages += header["nPackagesPerMessage"]


# Run the plugin only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

