1.0 - Tutorial 8
Monte Carlo Tree Search
- MCTS belongs to a family of algorithms for fast planning in online settings
- The longer MCTS is allowed to run, the higher the quality of the resulting policy (on average)
- Rather than iterating over every connection in the state graph representation, we build a tree (subset of state graph), and prioritise the important states
- The MCTS algorithm consists of four components/steps:
- Selection
- Expansion
- Simulation
- Back-Propagation
- At each node, we store the following statistics:
- Q(s,a) for each action in A - this is the average reward for the (s,a) pair over all of our trials
- N(s,a) for each action in A - this is the number of times action a has been performed from state s
- N(s) - this is the number of times this state has been visited, with any action performed
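One simple way to hold these statistics (and the one used by the dictionary-based implementation later in this tutorial) is a set of plain dictionaries; the variable names below are illustrative:

```python
Ns = {}     # N(s): number of times state s has been visited
Nsa = {}    # N(s, a): number of times action a has been performed from state s
Qsa = {}    # Q(s, a): average reward for the (s, a) pair over all trials
```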
MCTS - Selection
- Given that we are at some current state, how do we choose what action to perform?
- We aim to compromise between:
- Exploration (visiting under-explored branches), and;
- Exploitation (visiting branches with higher average reward)
- Selection strategies include:
- Random Choice - if any actions have never been tried, choose an untried action uniformly at random (this tends to perform better than trying untried actions in some fixed order).
- Epsilon Greedy - Choose the highest-Q action with probability ϵ, and otherwise choose one of the remaining actions uniformly at random (each with probability (1−ϵ)/n, where n is the number of remaining actions).
- UCB - compute a confidence interval for the true average reward, based on the number of trials, and choose the action with the highest upper confidence bound (UCB).
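As a rough sketch (not required code for the tutorial), the last two strategies can be written over the dictionary statistics described above; the exploration constant C and the default ϵ are illustrative choices:

```python
import math
import random

def epsilon_greedy(state, actions, Qsa, epsilon=0.8):
    # Convention used in these notes: highest-Q action with probability epsilon,
    # otherwise one of the remaining actions uniformly at random.
    best = max(actions, key=lambda a: Qsa.get((state, a), 0.0))
    if random.random() < epsilon or len(actions) == 1:
        return best
    return random.choice([a for a in actions if a != best])

def ucb(state, actions, Qsa, Nsa, Ns, C=math.sqrt(2)):
    # Untried actions first (uniformly at random); otherwise the action with the
    # highest upper confidence bound on its average reward.
    untried = [a for a in actions if Nsa.get((state, a), 0) == 0]
    if untried:
        return random.choice(untried)
    return max(actions, key=lambda a: Qsa[(state, a)]
               + C * math.sqrt(math.log(Ns[state]) / Nsa[(state, a)]))
```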
MCTS - Expansion
- Convert a leaf node into a non-leaf node
- When a leaf node is reached
- Set N(s)←1 (that is, initialise the node count to 1)
- Estimate V (the future expected value of the state) via simulation
MCTS - Simulation
- Estimate the future expected value of a state without building up a tree
- Random Roll-out: Choose actions at random until some maximum horizon is reached, keeping a running total of the reward received
- For example, look 20 steps into the future, discounted by γ
- This is not necessarily close to the optimum value for the state, but it is a "good enough" estimate
- It indicates whether the state leads to potentially dangerous states, or states with high reward.
- Can average this over a number of random rollouts
- Can use a heuristic to choose the actions during roll-out rather than choosing purely at random
- Return the estimated value V.
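A minimal roll-out sketch is shown below. It assumes the environment is wrapped in an object with get_actions(state), sample_next_state(state, action), reward(state) and is_terminal(state) methods; that interface is an assumption for illustration, not part of the tutorial code:

```python
import random

def rollout(state, env, gamma=0.9, horizon=20):
    """Estimate V(state) by following a purely random policy for a fixed horizon."""
    total, discount = 0.0, 1.0
    for _ in range(horizon):
        if env.is_terminal(state):
            break
        total += discount * env.reward(state)            # running total of discounted reward
        action = random.choice(env.get_actions(state))   # choose an action at random
        state = env.sample_next_state(state, action)     # sample the stochastic outcome
        discount *= gamma
    return total
```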
MCTS - Backpropagation
- We want to use the results from our simulations to update our node statistics and values.
- Update the average total discounted reward and node/action counts for each branch visited.
- Let state st be the state visited at time step t, and let at be the action performed at time step t.
- At time step t, the discounted future reward is given by the following equation:
Rt = r(st) + γ r(st+1) + ⋯ + γ^n r(st+n) + γ^(n+1) V
- For all time steps, we compute the Q-value and update the node statistics:
Q(st, at) ← (N(st, at) · Q(st, at) + Rt) / (N(st, at) + 1)
N(st, at) ← N(st, at) + 1,   N(st) ← N(st) + 1
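As a sketch, the same update can be written over the dictionary statistics, walking backwards along the visited path so that Rt = r(st) + γ·Rt+1; env.reward(s) is the same assumed interface as in the simulation sketch above:

```python
def backpropagate(path, V, Ns, Nsa, Qsa, env, gamma=0.9):
    """path: (state, action) pairs visited from the root down to the parent of the leaf;
    V: the simulated value of the leaf state reached at the end of the path."""
    R = V                                    # value of the state after the final step
    for (s, a) in reversed(path):
        R = env.reward(s) + gamma * R        # R_t = r(s_t) + gamma * R_{t+1}
        n = Nsa.get((s, a), 0)
        Qsa[(s, a)] = (n * Qsa.get((s, a), 0.0) + R) / (n + 1)   # incremental average
        Nsa[(s, a)] = n + 1
        Ns[s] = Ns.get(s, 0) + 1
```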
MCTS Implementation - Iterative
- Create a Tree Node class that stores:
- N(s)
- Q(s,a) and N(s,a) for each available action
- Stores a list of child nodes for each available action
- Stores a reference to the parent node
mcts_search(current_state)
- node ← current_state
- While the node is not a leaf node, select an action and sample a next state (and set node← next_state)
- Expand the leaf node and estimate the value via simulation
- Create a new TreeNode instance for it
- While the node has a parent, update Q(s,a), N(s,a) and N(s)
- And set node ← node.parent
- This is our backpropagation step (where we move backward in time, and update values)
- We keep repeating this step until we reach the root node
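A possible sketch of the iterative version is given below, using the same assumed environment interface as the earlier sketches (get_actions, sample_next_state, reward, is_terminal) and UCB for selection; it is one way to realise the steps above, not the only one:

```python
import math
import random

class TreeNode:
    def __init__(self, state, parent=None, action_from_parent=None):
        self.state = state
        self.parent = parent                   # reference to the parent node
        self.action_from_parent = action_from_parent
        self.N = 0                             # N(s)
        self.Nsa = {}                          # N(s, a) for each action tried so far
        self.Qsa = {}                          # Q(s, a) for each action tried so far
        self.children = {}                     # action -> list of child nodes

    def is_leaf(self):
        return self.N == 0

    def select_action(self, actions, C=math.sqrt(2)):
        untried = [a for a in actions if self.Nsa.get(a, 0) == 0]
        if untried:
            return random.choice(untried)
        return max(actions, key=lambda a: self.Qsa[a]
                   + C * math.sqrt(math.log(self.N) / self.Nsa[a]))

def rollout(state, env, gamma, horizon=20):
    # Random roll-out, as in the simulation step
    total, discount = 0.0, 1.0
    for _ in range(horizon):
        if env.is_terminal(state):
            break
        total += discount * env.reward(state)
        state = env.sample_next_state(state, random.choice(env.get_actions(state)))
        discount *= gamma
    return total

def mcts_search(root, env, gamma=0.9):
    # Selection: walk down the tree until a leaf node is reached
    node = root
    while not node.is_leaf():
        action = node.select_action(env.get_actions(node.state))
        next_state = env.sample_next_state(node.state, action)
        children = node.children.setdefault(action, [])
        child = next((c for c in children if c.state == next_state), None)
        if child is None:                      # first time this outcome has been sampled
            child = TreeNode(next_state, parent=node, action_from_parent=action)
            children.append(child)
        node = child

    # Expansion: initialise the leaf's visit count, then estimate its value by simulation
    node.N = 1
    R = rollout(node.state, env, gamma)

    # Backpropagation: walk back up to the root, updating Q(s, a), N(s, a) and N(s)
    while node.parent is not None:
        parent, a = node.parent, node.action_from_parent
        R = env.reward(parent.state) + gamma * R    # discount and add the parent's reward
        n = parent.Nsa.get(a, 0)
        parent.Qsa[a] = (n * parent.Qsa.get(a, 0.0) + R) / (n + 1)
        parent.Nsa[a] = n + 1
        parent.N += 1
        node = parent
```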
MCTS Implementation - Recursive
- Dictionaries are used to store node statistics
- Ns[s] - number of times a state 's' has been visited
- Nsa[(s, a)] - number of times an action 'a' has been performed from state 's'
- Qsa[(s, a)] - average reward from performing action a at state s
mcts_search(current_state)
- If the current state is a leaf node, estimate the value V from simulation and return the value (this is the recursion base case)
- Otherwise, select an action for the current state using our dictionaries
- Action can be selected using epsilon-greedy or UCB
- Sample the outcome of the next state and set
- V = immediate_reward + γ × mcts_search(next_state)
- Increment Ns[s] and Nsa[(s, a)], and update Qsa[(s, a)] using the value V
- Return the value V (so that the next level above can use the value)
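A sketch of the recursive version, again over the assumed environment interface used in the earlier sketches; the UCB constant and the use of module-level dictionaries are illustrative choices:

```python
import math
import random

Ns, Nsa, Qsa = {}, {}, {}              # N(s), N(s, a), Q(s, a)

def mcts_search(state, env, gamma=0.9):
    if env.is_terminal(state):
        return 0.0
    # Base case: an unvisited state is a leaf - expand it and estimate V by simulation
    if state not in Ns:
        Ns[state] = 1
        return rollout(state, env, gamma)

    # Selection: choose an action using UCB over the stored statistics
    actions = env.get_actions(state)
    untried = [a for a in actions if Nsa.get((state, a), 0) == 0]
    if untried:
        action = random.choice(untried)
    else:
        action = max(actions, key=lambda a: Qsa[(state, a)]
                     + math.sqrt(2 * math.log(Ns[state]) / Nsa[(state, a)]))

    # Sample the next state and recurse: V = immediate reward + gamma * value of next state
    next_state = env.sample_next_state(state, action)
    V = env.reward(state) + gamma * mcts_search(next_state, env, gamma)

    # Backpropagation: update Q(s, a), N(s, a) and N(s) using the value V
    n = Nsa.get((state, action), 0)
    Qsa[(state, action)] = (n * Qsa.get((state, action), 0.0) + V) / (n + 1)
    Nsa[(state, action)] = n + 1
    Ns[state] += 1
    return V

def rollout(state, env, gamma, horizon=20):
    # Random roll-out, as in the simulation step
    total, discount = 0.0, 1.0
    for _ in range(horizon):
        if env.is_terminal(state):
            break
        total += discount * env.reward(state)
        state = env.sample_next_state(state, random.choice(env.get_actions(state)))
        discount *= gamma
    return total
```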
For both approaches, the mcts_select_action function should:
- Call mcts_search(current_state) while time/memory/iteration limits have not been reached
- Action ← argmax(Q(s,a)) over all actions
- Return the action
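Building on the recursive sketch above (which stores its statistics in the Qsa dictionary), one possible mcts_select_action with a simple time budget might look like this; the time_limit parameter is an illustrative choice:

```python
import time

def mcts_select_action(current_state, env, time_limit=1.0):
    deadline = time.time() + time_limit
    while time.time() < deadline:          # keep searching while the time budget lasts
        mcts_search(current_state, env)
    actions = env.get_actions(current_state)
    # argmax of Q(s, a) over all actions; unexplored actions default to -infinity
    return max(actions, key=lambda a: Qsa.get((current_state, a), float("-inf")))
```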