提问人:Bhanu Teja Pogiri 提问时间:10/30/2023 最后编辑:MSaltersBhanu Teja Pogiri 更新时间:10/30/2023 访问量:27
使用 MCTS 连接四个游戏实现,我如何处理错误?
Connect four game implementation using MCTS, how do i take care of the errors?
问:
我正在努力设计和开发用于连接四游戏的 MCTS。但是我的方式上有无数的错误,我克服了很多,但不能弄清楚这一点,有人可以帮忙吗?
这是我写的代码,其他代码如 Minmax 等是正确的,因为它们提供了估计的结果,所以我不在这里发布它们。
# Starter Code: MCTS Implementation
import time
class Tree():
def __init__(self, *, start_state=None, parent=None, move=None):
if parent is None:
self.parent = None
self.move = None
self.state = start_state
else:
self.parent = parent
self.move = move
self.state = parent.state.copy()
self.state.play_move(move)
self.values = []
self.n = 0
if self.is_terminal_state:
self.unexplored_moves = set()
else:
self.unexplored_moves = set(self.state.get_moves())
self.children = set([])
@property
def fully_expanded(self):
return len(self.unexplored_moves) == 0
@property
def is_terminal_state(self):
return (self.state.winner is not None)
def uct_score(self, C=5):
"""Pick the best action according to the UCB/UCT algorithm"""
child = Tree(parent=self, move=self.move)
Q = sum(child.values) / child.n if child.n > 0 else 0
U = C * np.sqrt(np.log(self.n) / (child.n + 1))
return Q + U
def monte_carlo_tree_search(start_state, num_iterations=1000):
"""MCTS core loop"""
# Start by creating the root of the tree.
root = Tree(start_state=start_state)
# Loop through MCTS iterations.
for _ in range(num_iterations):
# One step of MCTS iteration
node = traverse(root)
simulation_result = rollout(node, start_state)
backpropagate(node, simulation_result)
# When done iterating, return the 'best' child of the root node.
return best_child(root)
def best_child(node):
"""When done sampling, pick the child visited the most."""
best_child = None
max_visits = -1
for child in node.children:
if child.n > max_visits:
best_child = child
max_visits = child.n
return best_child
def best_uct(node, C=5):
"""Pick the best action according to the UCB/UCT algorithm"""
best_child = None
max_uct = -1
for child in node.children:
uct_score = child.uct_score(C)
if uct_score > max_uct:
best_child = child
max_uct = uct_score
return best_child
def traverse(node):
# If fully explored, pick one of the children
while node.fully_expanded and not node.is_terminal_state:
node = best_uct(node)
# If the node is terminal, return it
if node.is_terminal_state:
return node
# If the node is not terminal:
# 1. pick a new move from 'unexplored_moves'
move = node.unexplored_moves.pop()
# 2. create a new child
new_child = Tree(parent=node, move=move)
# 3. add that child to the list of children
node.children.add(new_child)
# 4. return that new child
return new_child
def rollout(node, start_state):
winner = node.state.copy().play_random_moves_until_done()
if winner == 0:
return 0
elif winner == start_state.current_player:
return 1
else:
return -1
def backpropagate(node, simulation_result):
"""Update the node and its parent (via recursion)."""
if node is None:
return
node.values.append(simulation_result)
node.n += 1
backpropagate(node.parent, simulation_result)
# Evaluation Code
def print_wins(wins):
print("======")
print(f"Total Plays: {len(wins)}")
print(f"MiniMax Wins: {len([w for w in wins if w == 1])}")
print(f"MCTS Wins: {len([w for w in wins if w == 2])}")
print(f"Draws: {len([w for w in wins if w == 0])}")
wins = []
for _ in range(25):
tot_time_minimax = 0
tot_time_mcts = 0
board = ConnectFourBoard(nrows=6, ncols=7)
board.current_player = random.choice([1, 2])
while board.winner is None:
if board.current_player == 1:
stime = time.time()
action = minimax(board, depth=5, verbose=False)
tot_time_minimax += time.time() - stime
else:
stime = time.time()
action = monte_carlo_tree_search(board)
tot_time_mcts += time.time() - stime
print(action)
print("end")
board.play_move(action)
print(board)
print(f"Winner: {board.winner}")
print(tot_time_minimax, tot_time_mcts)
wins.append(board.winner)
print_wins(wins)
我得到了下面的输出(我在打印操作后添加了一个打印语句)
end
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[7], line 92, in ConnectFourBoard.play_move(self, col)
91 try:
---> 92 row = np.where(self.board[:, col] == 0)[0][-1]
93 except IndexError:
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[33], line 27
25 print(action)
26 print("end")
---> 27 board.play_move(action)
29 print(board)
30 print(f"Winner: {board.winner}")
Cell In[7], line 94, in ConnectFourBoard.play_move(self, col)
92 row = np.where(self.board[:, col] == 0)[0][-1]
93 except IndexError:
---> 94 raise ValueError(f"Cannot play column '{col}'.")
95 self.board[row, col] = self.current_player
97 # Check for a winner
ValueError: Cannot play column '<__main__.Tree object at 0x0000017E5CC3BA10>'.
输出状态不断显示,有时我收到此错误:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[7], line 92, in ConnectFourBoard.play_move(self, col)
91 try:
---> 92 row = np.where(self.board[:, col] == 0)[0][-1]
93 except IndexError:
IndexError: index -1 is out of bounds for axis 0 with size 0
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[36], line 23
21 else:
22 stime = time.time()
---> 23 action = monte_carlo_tree_search(board)
24 tot_time_mcts += time.time() - stime
25 print(action)
Cell In[35], line 49, in monte_carlo_tree_search(start_state, num_iterations)
46 # Loop through MCTS iterations.
47 for _ in range(num_iterations):
48 # One step of MCTS iteration
---> 49 node = traverse(root)
50 simulation_result = rollout(node, start_state)
51 backpropagate(node, simulation_result)
Cell In[35], line 84, in traverse(node)
81 def traverse(node):
82 # If fully explored, pick one of the children
83 while node.fully_expanded and not node.is_terminal_state:
---> 84 node = best_uct(node)
85 # If the node is terminal, return it
86 if node.is_terminal_state:
Cell In[35], line 74, in best_uct(node, C)
72 max_uct = -1
73 for child in node.children:
---> 74 uct_score = child.uct_score(C)
75 if uct_score > max_uct:
76 best_child = child
Cell In[35], line 35, in Tree.uct_score(self, C)
32 def uct_score(self, C=5):
33 """Pick the best action according to the UCB/UCT algorithm"""
---> 35 child = Tree(parent=self, move=self.move)
36 Q = sum(child.values) / child.n if child.n > 0 else 0
37 U = C * np.sqrt(np.log(self.n) / (child.n + 1))
Cell In[35], line 15, in Tree.__init__(self, start_state, parent, move)
13 self.move = move
14 self.state = parent.state.copy()
---> 15 self.state.play_move(move)
17 self.values = []
18 self.n = 0
Cell In[7], line 94, in ConnectFourBoard.play_move(self, col)
92 row = np.where(self.board[:, col] == 0)[0][-1]
93 except IndexError:
---> 94 raise ValueError(f"Cannot play column '{col}'.")
95 self.board[row, col] = self.current_player
97 # Check for a winner
ValueError:无法播放列“3”错误不断显示(因为列值正在更改)。有人可以帮助我哪里出错,以及我应该替换的任何代码吗?谢谢。
答: 暂无答案
评论
print_wins
action
board.play_move(action)