import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Files produced by each stage of the pipeline. They are relative paths, so
# they are resolved against the current working directory.
csv_file = "out_all.csv"  # combined raw Lab measurements
stability_file = "out_stability.csv"  # repeat-measurement stability
degradation_file = "out_degradation.csv"  # colour change over time

# Human-readable label for each paper code. Codes sharing a digit share the
# paper name; the letter picks the "(open)" / "(close)" variant.
paper_names = dict(
    [
        ("1A", "Glossy Pro Platinum (open)"),
        ("1B", "Glossy Pro Platinum (close)"),
        ("2A", "Photo Fine-Grained Luster (open)"),
        ("2B", "Photo Fine-Grained Luster (close)"),
        ("3A", "Photo Premium Matte (open)"),
        ("3B", "Photo Premium Matte (close)"),
        ("4A", "Matte Photo (open)"),
        ("4B", "Matte Photo (close)"),
        ("5A", "ELECOM Matte Photo (open)"),
        ("5B", "ELECOM Matte Photo (close)"),
    ]
)

# UNC directory that is scanned for the raw measurement CSV files. A trailing
# backslash is appended because a raw string literal cannot end in one.
data_dir = (
    r"\\gabor\Project\内科\舌診\TIAS\開発\カラーチャート\TCC2025-01\退色調査3 マットフォト"
    + "\\"
)

# Matplotlib named colours used for the plotted series.
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
]


def make_all_data_csv(src_dir: "str | None" = None, out_path: "str | None" = None) -> None:
    """Combine every ``tcc6*.csv`` measurement file into one CSV.

    Metadata is parsed from each file name (extension stripped, split on
    ``_``):

    * paper code -- characters 5-6 of the first field;
    * date string -- the second and third fields joined with ``-``.

    Every row is tagged with ``paper``, ``date`` and ``iter``. ``iter`` is
    always 0 here because repeat-measurement detection is disabled. The row's
    ``no`` column is copied into an integer ``tcc`` column. The combined rows
    are sorted by date, paper, iter and tcc. Only the columns ``date``,
    ``paper``, ``iter``, ``tcc``, ``L*``, ``a*`` and ``b*`` are written.

    Args:
        src_dir: Directory scanned for input files. Defaults to the
            module-level ``data_dir``.
        out_path: Destination CSV path. Defaults to the module-level
            ``csv_file``.

    Raises:
        SystemExit: Status 1 when no matching file, or no row at all, is found.
    """
    if src_dir is None:
        src_dir = data_dir
    if out_path is None:
        out_path = csv_file

    print("Combining all data into a single CSV file...")
    frames = []
    # Sorted so the concatenation order does not depend on the filesystem.
    for file in sorted(os.listdir(src_dir)):
        if not (file.endswith(".csv") and file.startswith("tcc6")):
            continue
        file_path = os.path.join(src_dir, file)
        parts = os.path.splitext(file)[0].split("_")
        paper = parts[0][5:7]
        datestr = parts[1] + "-" + parts[2]
        # Repeat index. Every file is currently treated as the first (0th)
        # measurement.
        iteration = 0
        print(f"Reading {file_path} (Paper: {paper}, Date: {datestr}, Iteration: {iteration})")
        df = pd.read_csv(file_path, encoding="utf-8")
        df["tcc"] = df["no"].astype("int")
        df["paper"] = paper
        df["date"] = datestr
        df["iter"] = iteration
        frames.append(df)

    # Concatenate once; growing a DataFrame inside the loop is quadratic.
    all_data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    if all_data.empty:
        print("No data found in the specified directory.")
        # The builtin exit() is provided by the site module and is not
        # guaranteed to exist, so raise SystemExit directly.
        raise SystemExit(1)

    all_data.sort_values(by=["date", "paper", "iter", "tcc"], inplace=True)
    all_data[["date", "paper", "iter", "tcc", "L*", "a*", "b*"]].to_csv(
        out_path, index=False, encoding="utf-8"
    )


def calc_repeat_stability(in_path=None, out_path=None):
    """Compare the first (iter 0) and repeat (iter 1) Lab readings.

    Rows are paired on ``date``, ``paper`` and ``tcc``. For each pair the
    function computes:

    * the per-channel differences ``dL``, ``da`` and ``db`` (iter 1 - iter 0);
    * their Euclidean norm ``dE`` (distance in Lab space);
    * the channel means ``meanL``, ``meana`` and ``meanb``.

    When the input has no row with ``iter >= 1``, every difference is 0 and
    the means equal the iter-0 values.

    Args:
        in_path: Combined measurement CSV with columns ``date``, ``paper``,
            ``iter``, ``tcc``, ``L*``, ``a*`` and ``b*``. Defaults to the
            module-level ``csv_file``, which is generated via
            ``make_all_data_csv()`` if it does not exist yet.
        out_path: Destination CSV. Defaults to the module-level
            ``stability_file``.
    """
    if in_path is None:
        in_path = csv_file
        if not os.path.exists(in_path):
            make_all_data_csv()
    if out_path is None:
        out_path = stability_file

    print("Calculating repeat stability...")
    # NOTE(review): the date format is inferred by pandas; confirm it parses
    # the date strings written by make_all_data_csv as intended.
    df = pd.read_csv(in_path, encoding="utf-8", parse_dates=["date"])

    no_repeat = df["iter"].max() < 1
    if no_repeat:
        print("No repeat measurements found.")

    keys = ["date", "paper", "tcc"]

    def _reading(iteration, suffix):
        # Select one iteration's rows and rename Lab columns by NAME. The old
        # positional rename silently mislabelled data if the column order
        # ever changed.
        return df.loc[df["iter"] == iteration, keys + ["L*", "a*", "b*"]].rename(
            columns={"L*": f"L{suffix}", "a*": f"a{suffix}", "b*": f"b{suffix}"}
        )

    df0 = _reading(0, "0")
    if no_repeat:
        dfm = df0.copy()
        dfm["dL"] = 0.0
        dfm["da"] = 0.0
        dfm["db"] = 0.0
        dfm["dE"] = 0.0
        dfm["meanL"] = dfm["L0"]
        dfm["meana"] = dfm["a0"]
        dfm["meanb"] = dfm["b0"]
    else:
        # Inner join: only (date, paper, tcc) with both readings are kept.
        dfm = pd.merge(df0, _reading(1, "1"), on=keys)
        dfm["dL"] = dfm["L1"] - dfm["L0"]
        dfm["da"] = dfm["a1"] - dfm["a0"]
        dfm["db"] = dfm["b1"] - dfm["b0"]
        dfm["dE"] = np.sqrt(dfm["dL"] ** 2 + dfm["da"] ** 2 + dfm["db"] ** 2)
        dfm["meanL"] = (dfm["L1"] + dfm["L0"]) / 2
        dfm["meana"] = (dfm["a1"] + dfm["a0"]) / 2
        dfm["meanb"] = (dfm["b1"] + dfm["b0"]) / 2

    dfm.to_csv(out_path, index=False, encoding="utf-8")


def color_degradation(in_path=None, out_path=None):
    """Measure how each TCC's colour drifts from its first reading.

    For every paper, the earliest ``date`` in its rows is the baseline date.
    The baseline Lab of a (paper, TCC) pair is its first row on that date.
    Every row then gets:

    * ``changeL``, ``changea`` and ``changeb`` -- mean value minus baseline;
    * ``changeDE`` -- the Euclidean norm of those three changes;
    * ``days`` -- elapsed time since the paper's baseline date, in whole
      MINUTES despite the column name.

    A (paper, TCC) pair with no row on the paper's baseline date is reported
    and skipped. The previous implementation crashed with ``IndexError`` in
    that case.

    Args:
        in_path: Stability CSV containing ``date``, ``paper``, ``tcc``,
            ``meanL``, ``meana`` and ``meanb``. Defaults to the module-level
            ``stability_file``, which is generated via
            ``calc_repeat_stability()`` if missing.
        out_path: Destination CSV. Defaults to the module-level
            ``degradation_file``.
    """
    if in_path is None:
        in_path = stability_file
        if not os.path.exists(in_path):
            calc_repeat_stability()
    if out_path is None:
        out_path = degradation_file

    print("Calculating color degradation...")
    keys = ["paper", "tcc"]
    lab = ["meanL", "meana", "meanb"]
    dfm = pd.read_csv(in_path, encoding="utf-8", parse_dates=["date"])
    dfm = dfm[["date", *keys, *lab]].copy()

    # Earliest date of each paper, broadcast back to every row of that paper.
    dfm["begin_date"] = dfm.groupby("paper")["date"].transform("min")

    # Baseline Lab per (paper, TCC): its first row dated on the paper's begin date.
    baseline = (
        dfm[dfm["date"] == dfm["begin_date"]]
        .drop_duplicates(subset=keys, keep="first")
        .loc[:, keys + lab]
        .rename(columns={"meanL": "baseL", "meana": "basea", "meanb": "baseb"})
    )

    merged = dfm.merge(baseline, on=keys, how="left", indicator=True)
    no_base = merged["_merge"] == "left_only"
    if no_base.any():
        skipped = merged.loc[no_base, keys].drop_duplicates()
        print(
            f"Skipping {len(skipped)} paper/TCC combination(s) "
            "with no measurement on the paper's first date."
        )
    dfout = merged.loc[~no_base].copy()

    dfout["changeL"] = dfout["meanL"] - dfout["baseL"]
    dfout["changea"] = dfout["meana"] - dfout["basea"]
    dfout["changeb"] = dfout["meanb"] - dfout["baseb"]
    dfout["changeDE"] = np.sqrt(
        dfout["changeL"] ** 2 + dfout["changea"] ** 2 + dfout["changeb"] ** 2
    )
    # Whole minutes (not days) since the paper's first date. The column name
    # is kept for compatibility with readers of the output CSV.
    dfout["days"] = (
        (dfout["date"] - dfout["begin_date"]).dt.total_seconds() // 60
    ).astype("Int64")

    columns = ["date", *keys, *lab, "changeL", "changea", "changeb", "changeDE", "days"]
    dfout = dfout[columns].sort_values(by=["paper", "tcc", "date"]).reset_index(drop=True)
    dfout.to_csv(out_path, index=False, encoding="utf-8")


def visualize(show=False):
    """Plot each paper's colour change (``changeDE``) against elapsed days.

    Two PNG files are saved to the current working directory:

    * ``out_degradation_eachTCC.png`` -- a 4x6 grid with one subplot per
      TCC 1-24 and one line per paper;
    * ``out_degradation_meanTCC.png`` -- one line per paper, averaged over
      TCCs for each date.

    Papers at even positions in sorted order use circle markers and those at
    odd positions use triangles. Consecutive pairs share a colour.

    Args:
        show: When True, call ``plt.show()`` before the figures are closed.
    """
    if not os.path.exists(degradation_file):
        color_degradation()

    print("Visualizing color degradation...")
    df = pd.read_csv(degradation_file, encoding="utf-8", parse_dates=["date"])
    paper_list = df["paper"].drop_duplicates().sort_values()
    # Filter each paper's rows once instead of once per subplot.
    paper_rows = [(paper, df[df["paper"] == paper]) for paper in paper_list]
    # The "days" column stores elapsed minutes; divide to get days.
    minutes_per_day = 24 * 60

    def _series_style(pidx):
        # The modulo avoids IndexError when there are more papers than colour pairs.
        return {
            "marker": "o" if pidx % 2 == 0 else "^",
            "color": colors[(pidx // 2) % len(colors)],
        }

    fig = plt.figure(figsize=(18, 12))
    fig.suptitle("Color Degradation Over Time (each TCC)", fontsize=16)
    fig.subplots_adjust(wspace=0.3, hspace=0.4)
    fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
    for i in range(1, 25):
        ax = fig.add_subplot(4, 6, i)
        ax.set_title(f"TCC {i}")
        ax.set_xlabel("Days")
        ax.set_ylabel("Change in deltaE")
        for pidx, (paper, rows) in enumerate(paper_rows):
            if rows.empty:
                continue
            tcc_rows = rows[rows["tcc"] == i]
            ax.plot(
                tcc_rows["days"] / minutes_per_day,
                tcc_rows["changeDE"],
                markersize=4,
                label=f"Paper {paper}",
                **_series_style(pidx),
            )
        ax.legend()
    fig.savefig("out_degradation_eachTCC.png", dpi=300, bbox_inches="tight")

    # numeric_only keeps the result stable across pandas versions, whose
    # default for this argument changed.
    paper_mean = df.groupby(["paper", "date"]).mean(numeric_only=True).reset_index()
    fig2 = plt.figure(figsize=(9, 6))
    fig2.suptitle("Color Degradation Over Time (mean TCC)", fontsize=16)
    ax2 = fig2.add_subplot(1, 1, 1)
    for pidx, paper in enumerate(paper_list):
        rows = paper_mean[paper_mean["paper"] == paper]
        if rows.empty:
            continue
        # Unknown paper codes get no descriptive name instead of a KeyError.
        label = f"Paper {paper} {paper_names.get(paper, '')}".rstrip()
        ax2.plot(
            rows["days"] / minutes_per_day,
            rows["changeDE"],
            markersize=6,
            label=label,
            **_series_style(pidx),
        )
    ax2.legend(loc="upper left", borderaxespad=0)
    ax2.set_xlabel("Days since first measurement")
    ax2.set_ylabel("Change in deltaE")
    fig2.savefig("out_degradation_meanTCC.png", dpi=300, bbox_inches="tight")

    if show:
        plt.show()
    # Release figure memory; repeated calls would otherwise accumulate figures.
    plt.close(fig)
    plt.close(fig2)


if __name__ == "__main__":
    # The data directory is the module-level ``data_dir``; no command-line
    # selection is wired up. visualize() drives the whole pipeline, because
    # each stage regenerates its input CSV only when that file is missing.
    print("Starting analysis...")
    visualize()
    print("All done.")
