import numpy as np
from QLearningAgent import QLearningAgent
from AdaptiveTTT import TTTConsole
from QLearningUtils import ReferRewrads, decodeResult
import pickle
# Q-learning hyperparameters.
NB_EPISODE = 5_000_000   # number of training episodes
EPSILON = 0.1            # exploration rate for epsilon-greedy action selection
ALPHA = 0.1              # learning rate
GAMMA = 0.90             # discount factor
ACTIONS = np.arange(0, 9)  # action set: one action per cell of the 3x3 board
if __name__ == '__main__':
    # Train a tabular Q-learning agent on tic-tac-toe, then evaluate it
    # greedily (epsilon = 0) and persist the learned Q-table to disk.
    ttt_env = TTTConsole(Is_shown=False)
    ini_state = (0, 0, 0, 0, 0, 0, 0, 0, 0)  # empty 3x3 board, flattened
    agent = QLearningAgent(
        alpha=ALPHA,
        gamma=GAMMA,
        epsilon=EPSILON,
        actions=ACTIONS,
        observation=ini_state
    )

    def _fresh_result_dict():
        # One counter per possible episode outcome (keys produced by decodeResult).
        return {"win": 0, "draw": 0, "lose": 0, "overlap": 0}

    rewards = []
    is_end_episode = False
    ttt_env.init_TTTenv()
    myplayer_num = ttt_env._ttt.Player  # which side the agent is playing this episode
    result_dict = _fresh_result_dict()

    # --- Training phase: epsilon-greedy play with Q-value updates ---
    for episode in range(NB_EPISODE):
        episode_reward = []
        while not is_end_episode:
            # BUG FIX: original read `agent.act` without calling it, which would
            # pass the bound method object to step() instead of a cell index.
            action = agent.act()
            judge, player, board = ttt_env.step(action)
            state, reward, is_end_episode = ReferRewrads(judge, player == myplayer_num, board)
            agent.observe(state, reward)  # Q-update from (state, reward)
            episode_reward.append(reward)
        # Tally the final outcome of the episode just finished.
        result = decodeResult(judge, player == myplayer_num)
        result_dict[result] += 1
        if episode % 10000 == 9999:  # progress report every 10,000 episodes
            print(f"[episode: {episode + 1}] win: {result_dict['win']} draw: {result_dict['draw']} lose: {result_dict['lose']} overlap: {result_dict['overlap']}")
            result_dict = _fresh_result_dict()
        rewards.append(np.sum(episode_reward))
        is_end_episode = False
        state = ttt_env.reset_randTTT_env()  # new episode from a randomized start
        myplayer_num = ttt_env._ttt.Player
        agent.observe(state)

    # --- Evaluation phase: pure greedy policy, no exploration ---
    # test_act/test_observe presumably select/observe without learning or
    # exploration — TODO confirm against QLearningAgent.
    agent.epsilon = 0.0
    result_dict = _fresh_result_dict()
    for episode in range(100000):
        while not is_end_episode:
            # BUG FIX: same missing-call defect as the training loop.
            action = agent.test_act()
            judge, player, board = ttt_env.step(action)
            state, reward, is_end_episode = ReferRewrads(judge, player == myplayer_num, board)
            agent.test_observe(state)
        result = decodeResult(judge, player == myplayer_num)
        result_dict[result] += 1
        if episode % 10000 == 9999:  # progress report every 10,000 episodes
            print(
                f"[episode: {episode + 1}] win: {result_dict['win']} draw: {result_dict['draw']} lose: {result_dict['lose']} overlap: {result_dict['overlap']}")
            result_dict = _fresh_result_dict()
        is_end_episode = False
        state = ttt_env.reset_randTTT_env()
        myplayer_num = ttt_env._ttt.Player
        agent.test_observe(state)

    # Persist the learned Q-table for later reuse.
    with open("q_values.pkl", "wb") as f:
        pickle.dump(agent.q_values, f)