from gdx import gdx
import threading
import time
import serial
import serial.tools.list_ports
import sys
import csv
from abc import abstractmethod
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import datetime


class DataWorker(threading.Thread):
    """Background thread that polls a device and records timestamped samples.

    Subclasses supply three hooks -- ``preprocess()``, ``read()`` and
    ``postprocess()`` -- and must set ``self._cols`` (the number of columns in
    a complete row, *including* the leading timestamp) before the thread is
    started.

    After ``stop()`` returns, ``self._data`` is a 2-D NumPy array with one row
    per complete sample: column 0 is the number of seconds since the
    acquisition loop started, the remaining columns are the values returned
    by ``read()``.

    NOTE: the hooks are decorated with ``@abstractmethod``, but this class
    does not use ``ABCMeta``, so the decorator is advisory only and is not
    enforced when a subclass is instantiated.

    Args:
        out_filename: Optional path of a CSV file that receives every row as
            it is read. An empty string disables the raw CSV log.
    """

    def __init__(self, out_filename=""):
        threading.Thread.__init__(self)
        self._running = False
        self._start_time = 0
        self._f = None
        self._writer = None
        # Rows are accumulated in a plain list and converted to an array once
        # in stop(); growing an array with np.vstack per sample copies the
        # whole buffer every time.
        self._rows = []
        self._data = np.empty((0, 0))
        # A dedicated event makes a stop() request stick even when it arrives
        # before run() has set ``_running`` to True.
        self._stop_requested = threading.Event()
        if out_filename:
            self._f = open(out_filename, "w", newline="")
            self._writer = csv.writer(self._f)

    @abstractmethod
    def preprocess(self):
        """Hook called once in the worker thread before the read loop."""
        pass

    @abstractmethod
    def read(self):
        """Hook returning one sample as a list of values (without timestamp).

        Returning ``None`` means "no sample this time"; the loop skips it.
        """
        pass

    @abstractmethod
    def postprocess(self):
        """Hook called once from ``stop()`` to release the device."""
        pass

    def run(self):
        """Poll ``read()`` until stopped, timestamping and storing each row."""
        self.preprocess()
        self._rows = []
        self._start_time = time.perf_counter()
        self._running = True
        while self._running and not self._stop_requested.is_set():
            measurements = self.read()
            if measurements is None:
                # A failed read must not kill the acquisition thread.
                continue
            elapsed = time.perf_counter() - self._start_time
            row = [elapsed, *measurements]
            if self._writer:
                self._writer.writerow(row)
            # Only complete rows go into the in-memory data set; short or
            # long rows are still logged to the CSV above.
            if len(row) == self._cols:
                self._rows.append(row)
        self._running = False

    def stop(self):
        """Request the loop to end, release resources and build ``_data``.

        Waits up to one second for the thread to finish. It is safe to call
        on a worker that was never started, and calling it again does not
        drop any recorded rows.
        """
        self._stop_requested.set()
        self._running = False
        # join() raises RuntimeError on a thread that was never started.
        if self.is_alive():
            self.join(1.0)
        try:
            self.postprocess()
        finally:
            if self._f:
                self._f.close()
            rows = list(self._rows)
            if rows:
                self._data = np.array(rows, dtype=float)
            else:
                self._data = np.empty((0, getattr(self, "_cols", 0)))


class GoDirectWorker(DataWorker):
    """Worker that samples sensor 1 of a USB-connected GoDirect device.

    Each complete row has two columns: the host timestamp added by the base
    class and the single sensor reading.

    Args:
        out_filename: Optional CSV path for the raw log ("" disables it).
        period: Sampling period forwarded to ``gdx.start()`` (milliseconds
            according to the Vernier gdx API). Defaults to 50, the value that
            was previously hard-coded.
    """

    def __init__(self, out_filename, period=50):
        super().__init__(out_filename)
        # Connect GoDirect device
        self._gdx = gdx.gdx()
        self._gdx.open(connection="usb")
        self._gdx.select_sensors([1])
        self._cols = 2
        self._period = period
        print("GoDirect connected.")

    def preprocess(self):
        """Start data collection and write the CSV header if logging."""
        self._gdx.start(self._period)
        if self._writer:
            self._writer.writerow(["time", "force"])

    def read(self):
        """Return the latest sensor readings as a list.

        ``gdx.read()`` can return ``None`` instead of a list (for example when
        a device read fails or collection has been stopped). The base loop
        prepends a timestamp to whatever is returned, so ``None`` is replaced
        by an empty list; such a row has the wrong length and is therefore
        left out of the in-memory data.
        """
        measurements = self._gdx.read()
        return [] if measurements is None else measurements

    def postprocess(self):
        """Stop data collection and disconnect from the device."""
        self._gdx.stop()
        self._gdx.close()


class SerialWorker(DataWorker):
    """Worker that reads comma-separated numeric lines from a serial port.

    A complete row has four columns: the host timestamp added by the base
    class followed by the three values sent by the device.

    Args:
        out_filename: Optional CSV path for the raw log ("" disables it).
        port: Name of the serial device to open.
        baudrate: Line speed; defaults to 115200, the value that was
            previously hard-coded.
    """

    def __init__(self, out_filename, port, baudrate=115200):
        super().__init__(out_filename)
        # The 1 s timeout keeps readline() from blocking forever so the
        # acquisition loop can notice a stop request.
        self._serial = serial.Serial(port, baudrate, timeout=1)
        self._cols = 4
        print("Seiral port connected on ", port)

    def preprocess(self):
        """Write the CSV header if logging is enabled."""
        if self._writer:
            # Four names to match the four columns of each row (_cols = 4):
            # host time plus the three values sent by the device.
            self._writer.writerow(["time", "device time(ms)", "IF-I", "IF-Q"])

    def read(self):
        """Read one line and return its comma-separated fields as floats.

        Returns:
            ``[0, 0, 0]`` when the port is closed, the read times out or the
            line is blank; an empty list when the line cannot be decoded or
            parsed (such a row is too short to be kept in the in-memory
            data); otherwise the parsed values.
        """
        if not self._serial.is_open:
            return [0, 0, 0]
        raw = self._serial.readline()
        if not raw:
            return [0, 0, 0]
        try:
            text = raw.decode("utf-8").strip()
            if text == "":
                return [0, 0, 0]
            return [float(field) for field in text.split(",")]
        except ValueError:
            # Covers UnicodeDecodeError too. Partial or garbled lines are
            # common right after opening a port and must not kill the thread.
            return []

    def postprocess(self):
        """Close the serial port."""
        self._serial.close()


def align_recordings(ser_data, gd_data, gd_offset=0.0):
    """Resample GoDirect readings onto the serial timestamps.

    The GoDirect timestamps are first replaced by an evenly spaced grid
    between the first and last timestamp, then a cubic interpolant is
    evaluated at every serial timestamp that lies inside that span.

    Args:
        ser_data: 2-D array whose column 0 holds the serial timestamps.
        gd_data: 2-D array with columns (time, value); needs at least four
            rows spanning a non-zero time interval for cubic interpolation.
        gd_offset: Value added to the GoDirect timestamps to express them on
            the serial time base.

    Returns:
        The rows of ``ser_data`` that fall inside the GoDirect time span with
        the interpolated GoDirect value appended as the last column.

    Raises:
        ValueError: If ``gd_data`` is not usable for cubic interpolation.
    """
    ser_data = np.asarray(ser_data, dtype=float)
    gd_data = np.asarray(gd_data, dtype=float)
    if gd_data.ndim != 2 or gd_data.shape[0] < 4:
        raise ValueError("cubic interpolation needs at least 4 GoDirect samples")
    if not gd_data[-1, 0] > gd_data[0, 0]:
        raise ValueError("GoDirect samples must span a non-zero time interval")

    # Equalize time interval
    gd_t = np.linspace(gd_data[0, 0], gd_data[-1, 0], gd_data.shape[0]) + gd_offset

    # Keep only serial samples inside the interpolation range; interp1d
    # raises for points outside it.
    ser_t = ser_data[:, 0]
    inside = (ser_t >= gd_t[0]) & (ser_t <= gd_t[-1])
    ser_inside = ser_data[inside, :]

    resample = interp1d(gd_t, gd_data[:, 1], kind="cubic")
    gd_on_serial = resample(ser_inside[:, 0])
    return np.hstack((ser_inside, gd_on_serial.reshape(-1, 1)))


if __name__ == "__main__":
    # Initialize devices
    ports = serial.tools.list_ports.comports()
    if not ports:
        print("Can't find serial port")
        sys.exit(1)
    serial_worker = SerialWorker("", ports[0].device)
    godirect_worker = GoDirectWorker("")

    # Launch worker threads and start measurements
    serial_worker.start()
    godirect_worker.start()

    # Wait for stop; the workers are shut down even if the wait is
    # interrupted (e.g. Ctrl-C) so the devices are always released.
    try:
        print("measurement start.")
        record_duration = 3  # sec
        for second in range(1, record_duration + 1):
            time.sleep(1)
            print(second, "/", record_duration, "sec")
        print("measurement end.")
    finally:
        # Closing
        try:
            serial_worker.stop()
        finally:
            godirect_worker.stop()

    # Data arrangement
    gd_data = godirect_worker._data
    ser_data = serial_worker._data
    if gd_data.shape[0] < 4:
        print("Not enough GoDirect samples to interpolate:", gd_data.shape[0])
        sys.exit(1)

    # Each worker timestamps relative to its own start, so shift the
    # GoDirect times onto the serial worker's time base before merging.
    gd_offset = godirect_worker._start_time - serial_worker._start_time
    combined = align_recordings(ser_data, gd_data, gd_offset)

    filename = datetime.datetime.now().strftime("Rec_%Y%m%d_%H%M%S.csv")
    np.savetxt(
        filename,
        combined,
        delimiter=",",
        fmt="%.3f",
        header="time(s),device time(ms),IF-I,IF-Q,RespBelt",
    )
    print("Recoding data was saved to ", filename)
