from copy import deepcopy
import numpy as np
from tqdm import tqdm
from numba import njit, prange

@njit(cache=True)
def makeNextBoard(set_pos, board, cur_peaces_num):
    """Return a copy of ``board`` with a positive-side piece placed at ``set_pos``.

    Positive cells belong to the player placing the piece and hold values
    1..3, with 1 the oldest. If that player already has three pieces, every
    positive cell is decremented first. The oldest piece (value 1) therefore
    drops to 0 and leaves the board, and the new piece is stored as 3.
    Otherwise the new piece is stored as ``cur_peaces_num + 1``.
    ``board`` itself is never modified.
    """
    next_board = board.copy()
    if cur_peaces_num != 3:
        next_board[set_pos] = cur_peaces_num + 1
        return next_board

    # Age every existing positive piece by one; the oldest one vanishes.
    for cell in range(next_board.shape[0]):
        if next_board[cell] > 0:
            next_board[cell] -= 1
    next_board[set_pos] = 3
    return next_board


@njit(cache=True)
def evaluate(depth, board, is_me):
    """Score ``board`` for the player whose pieces are positive.

    Args:
        depth: Search ply at which ``board`` was reached. A win found at a
            smaller depth scores higher in magnitude, so faster wins and slower
            losses are preferred.
        board: Flat board indexed 0-8. Only cells with a value > 0 are inspected.
        is_me: True if the positive pieces belong to the searching player.

    Returns:
        ``10 - depth`` if the positive pieces fill a winning line and ``is_me``
        is true, ``depth - 10`` if they do and ``is_me`` is false, else 0.
    """
    # Three rows, three columns and both diagonals as flat indices. The
    # original built this as an np.array from nested lists on every call. A
    # tuple of tuples is a plain stack value in numba, so this leaf of the
    # recursive search no longer allocates each time it runs.
    lines = ((0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6))
    for line in lines:
        if board[line[0]] > 0 and board[line[1]] > 0 and board[line[2]] > 0:
            if is_me:
                return 10 - depth
            return depth - 10

    return 0


@njit(cache=True)
def minmax(depth, board, is_me, max_depth=10):
    """Return the minimax value of ``board``.

    ``board`` is oriented toward the player who just moved: their pieces are
    positive and the other player's are negative. ``is_me`` is True when that
    player is the searching side. The recursion flips the board each ply, so
    the side about to move always sees its own pieces as positive.

    Args:
        depth: Current ply (0 at the root move).
        board: Flat board indexed 0-8; not modified.
        is_me: Whether the player who just moved is the searching side.
        max_depth: Ply at which the search stops and returns the static score.

    Returns:
        The score from ``evaluate``'s scale. Positive favours the searching
        side. 0 means a draw, a depth cutoff, or no legal move.
    """
    eval_val = evaluate(depth, board, is_me)
    if eval_val != 0 or depth == max_depth:
        return eval_val

    # Pass the turn to the other player. Negating the board (a new array)
    # makes their pieces positive, the orientation that makeNextBoard and
    # evaluate expect.
    is_me = not is_me
    flipped = -board

    next_pos_arr = np.where(flipped == 0)[0]
    if next_pos_arr.shape[0] == 0:
        # No empty cell: score as a draw instead of leaking the ±10000 sentinel.
        return 0
    cur_my_peaces = np.where(flipped > 0)[0].shape[0]

    # The searching side maximises; the opponent minimises.
    value = -10000 if is_me else 10000
    for pos_cand in next_pos_arr:
        next_board = makeNextBoard(pos_cand, flipped, cur_my_peaces)
        child_val = minmax(depth + 1, next_board, is_me, max_depth)
        if is_me:
            value = max(value, child_val)
        else:
            value = min(value, child_val)

    return value


@njit(parallel=True, cache=True)
def minMaxAct_submodule(copy_board, cur_my_peaces, eval_arr, max_depth=20):
    """Write the minimax score of each candidate move into ``eval_arr``.

    For each empty cell (value 0) among indices 0-8 of ``copy_board``, a
    piece of the positive side is placed there. The resulting position is
    scored with ``minmax``, searching up to ``max_depth`` plies. Occupied
    cells keep whatever value ``eval_arr`` already holds.

    The candidates run in parallel via ``prange``. Each iteration writes
    only its own index of ``eval_arr``, and ``makeNextBoard`` works on a
    copy, so iterations share no mutable state.

    Args:
        copy_board: Flat board oriented toward the player to move.
        cur_my_peaces: Number of positive pieces on ``copy_board``.
        eval_arr: Output array of length >= 9, modified in place.
        max_depth: Search depth passed to ``minmax`` (default 20, the value
            previously hard-coded).

    Returns:
        ``eval_arr``.
    """
    for pos_cand in prange(9):
        if copy_board[pos_cand] == 0:
            next_board = makeNextBoard(pos_cand, copy_board, cur_my_peaces)
            eval_arr[pos_cand] = minmax(0, next_board, True, max_depth)

    return eval_arr


def minMaxAct(board):
    """Print the minimax score of every cell, then ask the user for a move.

    Args:
        board: Flat 9-cell board oriented toward the player to move. Their
            pieces are positive, the opponent's negative, and empty cells 0.

    Returns:
        The chosen cell as a 0-based index (0 = top-left, 8 = bottom-right).
        The prompt (Japanese for "Where will you place? (1: top-left -
        9: bottom-right)") repeats until the user enters an integer 1-9 that
        names an empty cell. Non-numeric input re-prompts instead of raising
        ValueError.
    """
    copy_board = board.copy()
    cur_my_peaces = np.where(copy_board > 0)[0].shape[0]

    # Occupied cells keep the initial 0, so a 0 in the printout is not
    # necessarily a draw.
    eval_arr = minMaxAct_submodule(copy_board, cur_my_peaces, np.zeros(9))
    print(eval_arr)

    while True:
        print("どこに置きますか？ (1:左上 - 9:右下): ")
        try:
            pos = int(input())
        except ValueError:
            continue
        if 1 <= pos <= 9 and board[pos - 1] == 0:
            return pos - 1