使用 MCTS 连接四个游戏实现,我如何处理错误?

Connect four game implementation using MCTS, how do i take care of the errors?

提问人:Bhanu Teja Pogiri 提问时间:10/30/2023 最后编辑:MSaltersBhanu Teja Pogiri 更新时间:10/30/2023 访问量:27

问:

我正在努力设计和开发用于连接四游戏的 MCTS。但是我的方式上有无数的错误,我克服了很多,但不能弄清楚这一点,有人可以帮忙吗?

这是我写的代码,其他代码如 Minmax 等是正确的,因为它们提供了估计的结果,所以我不在这里发布它们。


# Starter Code: MCTS Implementation
import time


class Tree():
    def __init__(self, *, start_state=None, parent=None, move=None):
        if parent is None:
            self.parent = None
            self.move = None
            self.state = start_state
        else:
            self.parent = parent
            self.move = move
            self.state = parent.state.copy()
            self.state.play_move(move)
        
        self.values = []
        self.n = 0
        if self.is_terminal_state:
            self.unexplored_moves = set()
        else:
            self.unexplored_moves = set(self.state.get_moves())
        self.children = set([])
            
    @property
    def fully_expanded(self):
        return len(self.unexplored_moves) == 0
    
    @property
    def is_terminal_state(self):
        return (self.state.winner is not None)
    def uct_score(self, C=5):
        """Pick the best action according to the UCB/UCT algorithm"""

        child = Tree(parent=self, move=self.move)
        Q = sum(child.values) / child.n if child.n > 0 else 0
        U = C * np.sqrt(np.log(self.n) / (child.n + 1))
        return Q + U


def monte_carlo_tree_search(start_state, num_iterations=1000):
    """MCTS core loop"""
    # Start by creating the root of the tree.
    root = Tree(start_state=start_state)
    
    # Loop through MCTS iterations.
    for _ in range(num_iterations):
        # One step of MCTS iteration
        node = traverse(root)
        simulation_result = rollout(node, start_state)
        backpropagate(node, simulation_result)

    # When done iterating, return the 'best' child of the root node.
    return best_child(root)

def best_child(node):
    """When done sampling, pick the child visited the most."""

    best_child = None
    max_visits = -1
    for child in node.children:
        if child.n > max_visits:
            best_child = child
            max_visits = child.n

    return best_child

def best_uct(node, C=5):
    """Pick the best action according to the UCB/UCT algorithm"""

    best_child = None
    max_uct = -1
    for child in node.children:
        uct_score = child.uct_score(C)
        if uct_score > max_uct:
            best_child = child
            max_uct = uct_score

    return best_child

def traverse(node):
    # If fully explored, pick one of the children
    while node.fully_expanded and not node.is_terminal_state:
        node = best_uct(node)
    # If the node is terminal, return it
    if node.is_terminal_state:
        return node
    
    # If the node is not terminal:
    # 1. pick a new move from 'unexplored_moves'
    move = node.unexplored_moves.pop()
    # 2. create a new child
    new_child = Tree(parent=node, move=move)
    # 3. add that child to the list of children
    node.children.add(new_child)
    # 4. return that new child
    return new_child


def rollout(node, start_state):
    winner = node.state.copy().play_random_moves_until_done()
    if winner == 0:
        return 0
    elif winner == start_state.current_player:
        return 1
    else:
        return -1


def backpropagate(node, simulation_result):
    """Update the node and its parent (via recursion)."""
    if node is None:
        return
    node.values.append(simulation_result)
    node.n += 1

    backpropagate(node.parent, simulation_result)
    

# Evaluation Code

def print_wins(wins):
    print("======")
    print(f"Total Plays: {len(wins)}")
    print(f"MiniMax Wins: {len([w for w in wins if w == 1])}")
    print(f"MCTS Wins: {len([w for w in wins if w == 2])}")
    print(f"Draws: {len([w for w in wins if w == 0])}")

wins = []
for _ in range(25):
    tot_time_minimax = 0
    tot_time_mcts = 0
    board = ConnectFourBoard(nrows=6, ncols=7)
    board.current_player = random.choice([1, 2])
    while board.winner is None:
        if board.current_player == 1:
            stime = time.time()
            action = minimax(board, depth=5, verbose=False)
            tot_time_minimax += time.time() - stime
        else:
            stime = time.time()
            action = monte_carlo_tree_search(board)
            tot_time_mcts += time.time() - stime
        print(action)
        print("end")
        board.play_move(action)
    
    print(board)
    print(f"Winner: {board.winner}")
    print(tot_time_minimax, tot_time_mcts)
    wins.append(board.winner)
    print_wins(wins)

我得到了下面的输出(我在打印操作后添加了一个打印语句)

end
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[7], line 92, in ConnectFourBoard.play_move(self, col)
     91 try:
---> 92     row = np.where(self.board[:, col] == 0)[0][-1]
     93 except IndexError:

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[33], line 27
     25     print(action)
     26     print("end")
---> 27     board.play_move(action)
     29 print(board)
     30 print(f"Winner: {board.winner}")

Cell In[7], line 94, in ConnectFourBoard.play_move(self, col)
     92     row = np.where(self.board[:, col] == 0)[0][-1]
     93 except IndexError:
---> 94     raise ValueError(f"Cannot play column '{col}'.")
     95 self.board[row, col] = self.current_player
     97 # Check for a winner

ValueError: Cannot play column '<__main__.Tree object at 0x0000017E5CC3BA10>'.

输出状态不断显示,有时我收到此错误:


 ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[7], line 92, in ConnectFourBoard.play_move(self, col)
     91 try:
---> 92     row = np.where(self.board[:, col] == 0)[0][-1]
     93 except IndexError:

IndexError: index -1 is out of bounds for axis 0 with size 0

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[36], line 23
     21 else:
     22     stime = time.time()
---> 23     action = monte_carlo_tree_search(board)
     24     tot_time_mcts += time.time() - stime
     25 print(action)

Cell In[35], line 49, in monte_carlo_tree_search(start_state, num_iterations)
     46 # Loop through MCTS iterations.
     47 for _ in range(num_iterations):
     48     # One step of MCTS iteration
---> 49     node = traverse(root)
     50     simulation_result = rollout(node, start_state)
     51     backpropagate(node, simulation_result)

Cell In[35], line 84, in traverse(node)
     81 def traverse(node):
     82     # If fully explored, pick one of the children
     83     while node.fully_expanded and not node.is_terminal_state:
---> 84         node = best_uct(node)
     85     # If the node is terminal, return it
     86     if node.is_terminal_state:

Cell In[35], line 74, in best_uct(node, C)
     72 max_uct = -1
     73 for child in node.children:
---> 74     uct_score = child.uct_score(C)
     75     if uct_score > max_uct:
     76         best_child = child

Cell In[35], line 35, in Tree.uct_score(self, C)
     32 def uct_score(self, C=5):
     33     """Pick the best action according to the UCB/UCT algorithm"""
---> 35     child = Tree(parent=self, move=self.move)
     36     Q = sum(child.values) / child.n if child.n > 0 else 0
     37     U = C * np.sqrt(np.log(self.n) / (child.n + 1))

Cell In[35], line 15, in Tree.__init__(self, start_state, parent, move)
     13     self.move = move
     14     self.state = parent.state.copy()
---> 15     self.state.play_move(move)
     17 self.values = []
     18 self.n = 0

Cell In[7], line 94, in ConnectFourBoard.play_move(self, col)
     92     row = np.where(self.board[:, col] == 0)[0][-1]
     93 except IndexError:
---> 94     raise ValueError(f"Cannot play column '{col}'.")
     95 self.board[row, col] = self.current_player
     97 # Check for a winner

ValueError:无法播放列“3”错误不断显示(因为列值正在更改)。有人可以帮助我哪里出错,以及我应该替换的任何代码吗?谢谢。

蟒蛇 蒙特卡洛树搜索 连接四

评论

0赞 picobit 10/30/2023
您设置的地方有两个地方,其中一个或两个都是您的问题。您传入了一个无效的索引,并且错误消息告诉您您正在尝试访问不存在的元素。print_winsactionboard.play_move(action)

答: 暂无答案