import os.path as osp
import random
import torch
import torch.nn as nn
import pandas as pd
import torchvision
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import Conv2dStaticSamePadding
from transformers import DeiTModel, DeiTConfig
from coatnet import *
from effnetv2 import *


# Sentinel stored in the last column of the bagging CSVs for rows that are not
# tied to one particular split: load_dsets_for_bagging keeps such rows in the
# training set regardless of the requested split_num.
NO_SPLITTED: int = 999


def train_test_split_by_sequence(data, test_sequence_num, balanced_test=False, random_seed=999):
    """Split rows into train/test sets by the sequence directory of each sample.

    The first element of every row is a file path. The name of that path's
    parent directory is compared with ``f"sequence{test_sequence_num}"``.
    Matching rows form the test set and all other rows form the training set.
    Both sets keep the input order.

    When *balanced_test* is true, the test set is reduced so that rows with a
    truthy and a falsy last element are equally represented. Each group is
    shuffled and truncated to the size of the smaller group, and the result is
    the truthy rows followed by the falsy rows.

    Note: balancing calls ``random.seed(random_seed)``, which reseeds the
    global ``random`` module state as a side effect.

    Returns:
        ``(train_sets, test_sets)`` as lists of the original row objects.
    """
    wanted = f"sequence{test_sequence_num}"
    train_rows, test_rows = [], []
    for entry in data:
        seq_dir = osp.basename(osp.dirname(entry[0]))
        (test_rows if seq_dir == wanted else train_rows).append(entry)

    if not balanced_test:
        return train_rows, test_rows

    # Reseeds the module-global RNG so the selection below is reproducible.
    random.seed(random_seed)
    positives, negatives = [], []
    for entry in test_rows:
        (positives if entry[-1] else negatives).append(entry)
    random.shuffle(positives)
    random.shuffle(negatives)

    keep = min(len(positives), len(negatives))
    return train_rows, positives[:keep] + negatives[:keep]


def load_dsets_for_bagging(test_sequence_num, split_num, balanced_test=False,
                           data_dir=r"../input/bagging", random_seed=None):
    """Load the bagging CSVs and split them into train/test rows for one bag.

    Rows are read, without a header, from ``corrects.csv`` followed by
    ``incorrects.csv`` in *data_dir*. Column 0 is a sample file path. The name
    of its parent directory identifies the sample's sequence. Only the first
    four columns of each row are returned.

    Args:
        test_sequence_num: Rows whose parent directory is named
            ``sequence{test_sequence_num}`` go to the test set.
        split_num: For non-test rows, the last CSV column is compared with
            this value. A row is kept for training if the column equals
            *split_num* or ``NO_SPLITTED``.
        balanced_test: If True, down-sample the test set so that rows whose
            last retained column is truthy and falsy occur equally often. The
            result lists the truthy rows first, then the falsy rows.
            (That column is presumably the label; confirm.)
        data_dir: Directory holding the two CSV files.
        random_seed: Seed for the balancing shuffle. ``None`` (default) keeps
            the historical behaviour of shuffling with the global ``random``
            state. Any other value uses a private ``random.Random`` and
            leaves the global state untouched.

    Returns:
        ``(train_sets, test_sets)``, lists of rows truncated to 4 columns.
    """
    corrects = pd.read_csv(osp.join(data_dir, "corrects.csv"), header=None).values.tolist()
    incorrects = pd.read_csv(osp.join(data_dir, "incorrects.csv"), header=None).values.tolist()
    all_data = corrects + incorrects

    test_dir_name = f"sequence{test_sequence_num}"
    train_sets, test_sets = [], []
    for row in all_data:
        sequence_name = osp.basename(osp.dirname(row[0]))
        if sequence_name == test_dir_name:
            test_sets.append(row[:4])
        elif (row[-1] == split_num) or (row[-1] == NO_SPLITTED):
            # The split column is checked on the full row, before truncation.
            train_sets.append(row[:4])

    if balanced_test:
        # The ``random`` module exposes ``shuffle`` itself, so it can stand in
        # for a Random instance when no seed is requested.
        rng = random if random_seed is None else random.Random(random_seed)
        true_test_sets = [row for row in test_sets if row[-1]]
        false_test_sets = [row for row in test_sets if not row[-1]]
        rng.shuffle(true_test_sets)
        rng.shuffle(false_test_sets)

        min_data_num = min(len(true_test_sets), len(false_test_sets))
        test_sets = true_test_sets[:min_data_num] + false_test_sets[:min_data_num]

    return train_sets, test_sets


def load_efficientnet(model_name, image_height, image_width, in_channels=9):
    """Build an AdvProp-pretrained EfficientNet with a 2-class head and a stem
    that accepts *in_channels*-channel input.

    ``EfficientNet.from_pretrained(..., num_classes=2)`` creates the 2-class
    head. The stem convolution is then replaced by a freshly initialised
    ``Conv2dStaticSamePadding`` whose static "same" padding is computed for
    ``(image_height, image_width)`` input. The pretrained stem weights are
    therefore discarded.

    Args:
        model_name: efficientnet_pytorch model name, e.g. ``"efficientnet-b4"``.
        image_height: Input height used to precompute the stem padding.
        image_width: Input width used to precompute the stem padding.
        in_channels: Number of input channels of the new stem (default 9).

    Returns:
        The modified ``EfficientNet`` model.
    """
    model = EfficientNet.from_pretrained(model_name, advprop=True, num_classes=2)
    old_stem = model._conv_stem
    # Reuse the existing stem's width and geometry instead of hard-coding them.
    # The stem width depends on the variant's width multiplier (e.g. b0 uses
    # 32), and the layer that follows the stem expects exactly that many
    # channels.
    model._conv_stem = Conv2dStaticSamePadding(
        in_channels,
        old_stem.out_channels,
        old_stem.kernel_size,
        old_stem.stride,
        image_size=(image_height, image_width),
        bias=False,
    )

    return model


def load_resnet50():
    """Build an ImageNet-pretrained ResNet-50 adapted to 9-channel input and
    2 output classes.

    ``conv1`` is replaced by a freshly initialised 9->64 channel 7x7 stride-2
    convolution. The final fully-connected layer ``fc`` is replaced by a new
    2-way ``nn.Linear`` with the same number of input features.
    """
    model = torchvision.models.resnet50(pretrained=True)
    model.conv1 = nn.Conv2d(9, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Assigning ``fc.out_features`` would only change an attribute: the weight
    # and bias tensors would keep their 1000-class shape. Rebuild the layer.
    model.fc = nn.Linear(model.fc.in_features, 2)
    return model


def load_densenet161():
    """Build an ImageNet-pretrained DenseNet-161 adapted to 9-channel input and
    2 output classes.

    ``features[0]`` (the first convolution) is replaced by a freshly
    initialised 9->96 channel 7x7 stride-2 convolution. The classifier is
    replaced by a new 2-way ``nn.Linear`` with the same number of input
    features.
    """
    model = torchvision.models.densenet161(pretrained=True)
    model.features[0] = nn.Conv2d(9, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # Setting ``classifier.out_features`` would only relabel the layer; its
    # weight and bias would still produce 1000 outputs, so rebuild it instead.
    model.classifier = nn.Linear(model.classifier.in_features, 2, bias=True)

    return model


def load_deit(image_height, image_width, in_channels=9):
    """Build a randomly initialised DeiT encoder for
    ``(in_channels, image_height, image_width)`` input.

    The model is constructed from a config, not loaded from a checkpoint, so
    it has no pretrained weights.

    NOTE(review): ``DeiTModel`` is the bare encoder (no classification head);
    confirm that callers attach one for the 2-class task.

    Args:
        image_height: Input image height.
        image_width: Input image width.
        in_channels: Number of input channels (default 9).

    Returns:
        A ``DeiTModel`` instance.
    """
    # ``image_size`` (not ``size``) is the DeiTConfig field that sizes the
    # patch and position embeddings. Unknown keyword arguments are stored on
    # the config without effect. A square input is passed as a plain int.
    if image_height == image_width:
        image_size = image_height
    else:
        image_size = (image_height, image_width)
    configuration = DeiTConfig(image_size=image_size, num_channels=in_channels)
    model = DeiTModel(configuration)

    return model


def load_CoAtNet():
    """Build a ``coatnet_2`` model adapted to 9-channel input and 2 classes.

    ``s0[0]`` is swapped for a new 9->128 channel 3x3 stride-2 convolution, and
    ``fc`` for a new 1026->2 linear layer. Both new layers are freshly
    initialised.
    """
    net = coatnet_2()
    # NOTE(review): if ``s0[0]`` in coatnet.py is a conv+norm+activation
    # Sequential rather than a bare Conv2d, this assignment also drops the norm
    # and activation; replacing ``s0[0][0]`` would keep them. Confirm.
    first_stage_conv = nn.Conv2d(9, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    net.s0[0] = first_stage_conv
    # 1026 is presumably coatnet_2's final channel width; verify in coatnet.py.
    head = nn.Linear(in_features=1026, out_features=2, bias=True)
    net.fc = head
    return net


def load_effnetv2():
    """Return an ``effnetv2_m`` network built with a 2-class output."""
    # NOTE(review): unlike the other loaders in this file, no 9-channel input
    # stem is installed here; confirm effnetv2_m's input channels match the data.
    return effnetv2_m(num_classes=2)


def load_mobilenet_v2():
    """Build an ImageNet-pretrained MobileNetV2 for 9-channel, 2-class input.

    ``features[0][0]`` (the first convolution) is swapped for a new 9->32
    channel 3x3 stride-2 convolution, and ``classifier[1]`` for a new 1280->2
    linear layer. Both replacements are freshly initialised.
    """
    net = torchvision.models.mobilenet_v2(pretrained=True)
    first_conv = nn.Conv2d(9, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    net.features[0][0] = first_conv
    head = nn.Linear(1280, 2, bias=True)
    net.classifier[1] = head
    return net
