from copy import deepcopy
import numpy as np
from tqdm import tqdm
from numba import njit, prange
@njit(cache=True)
def makeNextBoard(set_pos, board, cur_peaces_num):
    """Return a copy of ``board`` with the mover's piece placed at ``set_pos``.

    Positive cells belong to the player who is moving. Each piece stores its
    placement order: a new piece gets ``cur_peaces_num + 1``. When the mover
    already has 3 pieces, every positive cell is decremented by one (so a
    cell holding 1 becomes empty) and the new piece is stored as 3.

    Parameters:
        set_pos: index of the (empty) cell to play.
        board: 1-D board array; it is not modified.
        cur_peaces_num: number of positive cells on ``board``.

    Returns:
        A new board array.
    """
    nxt = board.copy()
    if cur_peaces_num != 3:
        nxt[set_pos] = cur_peaces_num + 1
        return nxt

    # Full hand: age every own piece by one; the piece holding 1 drops to 0.
    for cell in range(nxt.shape[0]):
        if nxt[cell] > 0:
            nxt[cell] -= 1
    nxt[set_pos] = 3
    return nxt
@njit(cache=True)
def evaluate(depth, board, is_me, win_score=100):
    """Score ``board`` if the positive cells form a completed line.

    The 3x3 board is indexed row-major (0..8). Only positive cells are
    checked. If any row, column or diagonal is fully positive, the position
    is a win for the positive side, and earlier wins score higher.

    Parameters:
        depth: ply depth at which ``board`` was reached.
        board: 1-D board array.
        is_me: True if the positive side is the maximizing (root) player.
        win_score: base magnitude of a win. The score is clamped to at
            least 1 so that it never becomes 0 (which callers treat as
            "no win") and never changes sign, however deep the search goes.

    Returns:
        ``max(win_score - depth, 1)`` for a win when ``is_me`` is True, its
        negation when ``is_me`` is False, and 0 when there is no completed
        line.
    """
    lines = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                      [0, 3, 6], [1, 4, 7], [2, 5, 8],
                      [0, 4, 8], [2, 4, 6]])
    for line in lines:
        if board[line[0]] > 0 and board[line[1]] > 0 and board[line[2]] > 0:
            # The original `10 - depth` hit 0 at depth 10 and went negative
            # beyond it, so deep wins were scored as draws or losses.
            score = max(win_score - depth, 1)
            if is_me:
                return score
            return -score
    return 0
@njit(cache=True)
def minmax(depth, board, is_me, max_depth=10):
    """Return the depth-limited minimax value of ``board``.

    ``board`` is seen from the side that just moved: that side's pieces are
    positive and the opponent's are negative. ``is_me`` is True when the
    side that just moved is the maximizing (root) player.

    Parameters:
        depth: current ply depth.
        board: 1-D board array; it is not modified.
        is_me: whether the side that just moved is the maximizing player.
        max_depth: search stops (returning the static score) at this depth.

    Returns:
        The score from ``evaluate`` at a terminal or depth-limited node.
        Otherwise it is the max (for the maximizing player) or min (for the
        opponent) over the children's values.
    """
    score = evaluate(depth, board, is_me)
    # `>=` rather than `==` so a call started past max_depth still stops.
    if score != 0 or depth >= max_depth:
        return score

    # Hand the turn over: flip the board so the side to move owns the
    # positive cells.
    is_me = not is_me
    flipped = -board
    empty_cells = np.where(flipped == 0)[0]
    n_pieces = np.count_nonzero(flipped > 0)

    value = -10000 if is_me else 10000
    for pos_cand in empty_cells:
        child = makeNextBoard(pos_cand, flipped, n_pieces)
        child_val = minmax(depth + 1, child, is_me, max_depth)
        if is_me:
            if child_val > value:
                value = child_val
        else:
            if child_val < value:
                value = child_val
    return value
@njit(parallel=True, cache=True)
def minMaxAct_submodule(copy_board, cur_my_peaces, eval_arr, max_depth=20):
    """Fill ``eval_arr`` with the minimax score of playing each empty cell.

    The cells are evaluated in parallel via ``prange``. Each empty cell gets
    the score of placing the current (positive) player's piece there. Entries
    for occupied cells keep whatever value ``eval_arr`` already held.

    Parameters:
        copy_board: 1-D board array with the current player's pieces positive.
        cur_my_peaces: number of positive cells on ``copy_board``.
        eval_arr: output array, at least as long as ``copy_board``; it is
            written in place.
        max_depth: depth limit passed to ``minmax`` (previously hard-coded
            to 20).

    Returns:
        ``eval_arr``.
    """
    for pos_cand in prange(copy_board.shape[0]):
        if copy_board[pos_cand] == 0:
            next_board = makeNextBoard(pos_cand, copy_board, cur_my_peaces)
            eval_arr[pos_cand] = minmax(0, next_board, True, max_depth)
    return eval_arr
def minMaxAct(board):
    """Print minimax scores for every cell, then read a move from the user.

    Despite the name, the returned move is the one the user types in. The
    minimax scores are only printed, presumably as a hint or for debugging.

    Parameters:
        board: 1-D board array with the current player's pieces positive.

    Returns:
        The 0-based index of an empty cell chosen by the user.
    """
    copy_board = board.copy()
    cur_my_peaces = np.where(copy_board > 0)[0].shape[0]
    eval_arr = minMaxAct_submodule(copy_board, cur_my_peaces, np.zeros(9))
    print(eval_arr)

    while True:
        print("どこに置きますか? (1:左上 - 9:右下): ")
        try:
            pos = int(input())
        except ValueError:
            # Non-numeric input: ask again instead of crashing.
            continue
        if 1 <= pos <= 9 and board[pos - 1] == 0:
            return pos - 1