Newer
Older
PyTicTacToe / reinforcement / QLearningAgent.py
@sato sato on 1 Mar 2022 3 KB 最初のコミット
import copy
import numpy as np

class QLearningAgent:
    """
        Tabular Q-learning agent for a board game (e.g. tic-tac-toe).

        States are keyed by ``str(observation)``; each key maps to a NumPy
        array holding one Q value per entry of ``actions`` (actions are
        identified by their index ``0 .. len(actions) - 1``). Action ``a`` is
        treated as legal on ``board`` when ``board[a] == 0``.
    """

    def __init__(
            self,
            alpha=.2,
            epsilon=.1,
            gamma=.99,
            actions=None,
            observation=None):
        """
            Args:
                alpha: learning rate used in the Q-value update.
                epsilon: probability of taking a random (exploratory) action
                    in ``act``.
                gamma: discount factor applied to ``max Q(s')``.
                actions: sequence of actions; only its length is used.
                    Must not be None.
                observation: initial observation; its ``str`` form is the
                    initial state key.
        """
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.reward_history = []
        self.actions = actions
        self.state = str(observation)
        self.ini_state = str(observation)
        self.previous_state = None
        self.previous_action = None
        self.q_values = self._init_q_values()

    def _init_q_values(self):
        """
            Create the Q table containing only the initial state, with every
            action's value set to 0.
        """
        return {self.state: np.zeros(len(self.actions))}

    def init_state(self):
        """
            Reset the current and previous state to the initial state and
            return it.
        """
        # State keys are immutable strings, so no copy is needed.
        self.previous_state = self.ini_state
        self.state = self.ini_state
        return self.state

    def _legal_actions(self, board):
        """
            Return the action indices whose board cell is empty (== 0).
        """
        return [a for a in range(len(self.actions)) if board[a] == 0]

    def _random_action(self, board):
        """
            Pick a legal action uniformly at random.

            Raises:
                ValueError: if no cell is empty.
        """
        legal = self._legal_actions(board)
        if not legal:
            raise ValueError("no legal action: every cell is occupied")
        # Uniform over legal cells == the old rejection-sampling distribution,
        # but always terminates and can reach every index.
        return legal[np.random.randint(len(legal))]

    def _greedy_action(self, board):
        """
            Pick the legal action with the highest Q value in the current
            state. Ties are broken in ``argsort()[::-1]`` order.

            Raises:
                ValueError: if no cell is empty.
        """
        # Walk actions from highest to lowest Q value, skipping occupied
        # cells; the sort is computed once rather than on every step.
        for action in self.q_values[self.state].argsort()[::-1]:
            if board[action] == 0:
                return int(action)
        raise ValueError("no legal action: every cell is occupied")

    def act(self, board):
        """
            Choose an action with the epsilon-greedy policy and remember it
            as ``previous_action`` for the next Q update.

            Args:
                board: indexable board; ``board[a] == 0`` marks action ``a``
                    as legal.
            Returns:
                The chosen action index (int).
            Raises:
                ValueError: if the board has no empty cell.
        """
        if np.random.uniform() < self.epsilon:
            action = self._random_action(board)   # explore
        else:
            action = self._greedy_action(board)   # exploit
        self.previous_action = action
        return action

    def observe(self, next_state, reward=None):
        """
            Move to ``next_state`` and, if ``reward`` is given, record it and
            update Q(previous_state, previous_action).

            States seen for the first time are added to the Q table with all
            action values set to 0.
        """
        next_state = str(next_state)
        if next_state not in self.q_values:  # first visit to this state
            self.q_values[next_state] = np.zeros(len(self.actions))

        self.previous_state = self.state
        self.state = next_state

        if reward is not None:
            self.reward_history.append(reward)
            self.learn(reward)

    def learn(self, reward):
        """
            Apply the Q-learning update

                Q(s, a) <- Q(s, a) + alpha * (r + gamma * max Q(s') - Q(s, a))

            where s / a are ``previous_state`` / ``previous_action`` and s' is
            the current ``state``.
        """
        q = self.q_values[self.previous_state][self.previous_action]
        max_q = self.q_values[self.state].max()
        self.q_values[self.previous_state][self.previous_action] = (
            q + self.alpha * (reward + self.gamma * max_q - q))

    def test_observe(self, next_state):
        """
            Move to ``next_state`` for evaluation: no learning and no new Q
            table entries.
        """
        self.previous_state = self.state
        self.state = str(next_state)

    def test_act(self, board):
        """
            Choose an action greedily (no exploration) for evaluation.

            Falls back to a uniformly random legal action when the current
            state is not in the Q table.

            Raises:
                ValueError: if the board has no empty cell.
        """
        if self.state in self.q_values:
            action = self._greedy_action(board)
        else:
            # Printed once per fallback (previously once per retry).
            print("ランダム行動を実行")
            action = self._random_action(board)
        self.previous_action = action
        return action