MuZero

Tags

  • rl

Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model. Great stuff. Before studying RL in more detail I wondered why more people don’t try this kind of hybrid approach; the answer, as with 99% of RL, is “it’s complicated”.

Cf. also Julian Schrittwieser’s explanation, DeepMind’s post, Julian’s presentation, and the pseudocode.

MuZero figure

The delta from AlphaZero is that this planning approach assumes “no knowledge” of the environment’s rules, i.e. the simulator (dynamics), state representations, and rewards are all learned by NNs rather than given.

Neural networks

The idea here is to plan entirely in NN-generated latent space while acting according to MCTS visit counts. The NNs consist of (a sketch of how they fit together follows the list):

  • Representation network \(h\) converts the current observation history into a latent state \(s_0\).
  • Dynamics (“unroll”) network \(g\) converts a latent state and action \((s_{k-1}, a_k)\) into the next latent state and reward \((s_k, r_k)\).
  • Prediction network \(f\) produces a policy and value \((p_k, v_k)\) given a latent state \(s_k\).
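
A minimal sketch using toy PyTorch MLPs (the real networks are residual convolutional towers; all names and sizes here are my own, not the paper’s):

```python
import torch
import torch.nn as nn

class Nets(nn.Module):
    """Toy stand-ins for MuZero's h (representation), g (dynamics), f (prediction)."""
    def __init__(self, obs_dim, n_actions, latent=64):
        super().__init__()
        self.h = nn.Linear(obs_dim, latent)                   # obs history -> s_0
        self.g_state = nn.Linear(latent + n_actions, latent)  # (s_{k-1}, a_k) -> s_k
        self.g_reward = nn.Linear(latent + n_actions, 1)      # ... and predicted reward r_k
        self.f_policy = nn.Linear(latent, n_actions)          # s_k -> policy logits p_k
        self.f_value = nn.Linear(latent, 1)                   # s_k -> value v_k
        self.n_actions = n_actions

    def initial_inference(self, obs):
        """Used once at the search root: encode the observation, predict (p_0, v_0)."""
        s = torch.relu(self.h(obs))
        return s, self.f_policy(s), self.f_value(s)

    def recurrent_inference(self, s, action):
        """Used at every expanded tree node: step the learned dynamics in latent space."""
        a = nn.functional.one_hot(action, self.n_actions).float()
        x = torch.cat([s, a], dim=-1)
        s_next = torch.relu(self.g_state(x))
        return s_next, self.g_reward(x), self.f_policy(s_next), self.f_value(s_next)
```

After the root, MCTS never queries the real environment; it only calls recurrent_inference on latent states.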

Training: planning as an improvement operator

Loss function consists of 4 components (written out after the list), namely:

  • Matching loss (MCTS policy versus NN policy)
  • Value loss (NN versus long-term reward)
  • Reward loss (NN versus immediate reward)
  • L2 regularization
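
Roughly, in the paper’s notation, with MCTS policies \(\pi\), value targets \(z\), and observed rewards \(u\) supervising the predictions at each of the \(K\) unroll steps:

\[l_t(\theta) = \sum_{k=0}^{K} \left[ l^p(\pi_{t+k}, p^k_t) + l^v(z_{t+k}, v^k_t) + l^r(u_{t+k}, r^k_t) \right] + c \lVert \theta \rVert^2\]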

In board games there is no discounting, since the only value signal is the final result (win/loss).

NNs are trained to match targets recorded from self-play episodes: observed rewards, value targets derived from the episode, and the MCTS visit distributions.

Acting

Cf. run_mcts in the pseudocode.

Some hyperparameters were picked according to previous work on AlphaZero or general RL “conventions” (e.g. ALE settings). They don’t really say how they chose the others, so I assume empirically (for instance, the model is unrolled for \(K=5\) steps in all experiments).

Due to the smaller action space, 50 simulations per move suffice for Atari, and each chosen action is repeated for 4 frames (the usual ALE action repeat). For board games, an action is chosen every move after 800 MCTS simulations.

Real actions are sampled according to the visit counts of the MCTS policy; within the tree, actions are selected by PUCT, which trades off the policy prior against the value estimate (the more an action has been visited, the more weight shifts from its prior to its estimated value). \(c_1=1.25\), \(c_2=19652\) seem oddly specific:

\[a^k = \mathrm{arg}\max_a \left\{ Q(s,a) + P(s,a) \frac{\sqrt{\sum_b N(s,b)}}{N(s,a)+1} \left [ c_1 + \log \left (\frac{\sum_bN(s,b) + c_2 + 1}{c_2}\right) \right ] \right\}\]

In order to stay agnostic to the scale of rewards across environments, Q-values are normalized to \([0, 1]\) using the minimum and maximum values observed in the search tree:

\[\bar Q(s^{k-1}, a^k) = \frac{Q(s^{k-1}, a^k) - \min_{s,a \in \mathrm{Tree}} Q(s,a) } { \max_{s,a \in \mathrm{Tree}} Q(s,a) - \min_{s,a \in \mathrm{Tree}} Q(s,a) }\]
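
A sketch along the lines of the pseudocode’s ucb_score, combining the PUCT term with the min-max normalized \(\bar Q\) above (the Node fields — prior, visit_count, reward, value() — and the default discount are assumed):

```python
import math

PB_C_BASE, PB_C_INIT = 19652, 1.25   # c_2 and c_1 in the formula above

class MinMaxStats:
    """Tracks the min/max value seen in the tree, for normalizing Q into [0, 1]."""
    def __init__(self):
        self.minimum, self.maximum = float("inf"), float("-inf")
    def update(self, value):
        self.minimum, self.maximum = min(self.minimum, value), max(self.maximum, value)
    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value

def ucb_score(parent, child, min_max_stats, discount=0.997):
    # Exploration term: prior scaled by parent visits, damped as the child gets visited.
    pb_c = math.log((parent.visit_count + PB_C_BASE + 1) / PB_C_BASE) + PB_C_INIT
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior
    # Exploitation term: normalized Q of the child (0 if never visited).
    if child.visit_count > 0:
        value_score = min_max_stats.normalize(child.reward + discount * child.value())
    else:
        value_score = 0.0
    return prior_score + value_score
```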

Dirichlet exploration noise is added at the tree root; it boosts the prior of roughly 1–2 random actions high enough for them to be explored. The visit-count temperature starts high and is decayed over time to ensure convergence.
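
A minimal sketch of both mechanisms (function names are mine; \(\alpha = 0.3\) and the 25% mixing fraction are roughly the chess-like defaults from the pseudocode):

```python
import numpy as np

def add_exploration_noise(priors, alpha=0.3, frac=0.25, rng=np.random):
    """Mix Dirichlet noise into the root priors; a couple of actions get boosted."""
    noise = rng.dirichlet([alpha] * len(priors))
    return [(1 - frac) * p + frac * n for p, n in zip(priors, noise)]

def select_action(visit_counts, temperature=1.0, rng=np.random):
    """Sample an action from the root visit counts; temperature -> 0 becomes greedy."""
    counts = np.array(visit_counts, dtype=np.float64)
    if temperature == 0:
        return int(counts.argmax())
    probs = counts ** (1 / temperature)
    probs /= probs.sum()
    return int(rng.choice(len(counts), p=probs))
```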

Representation

Moves in Go were represented as a 19x19x1 board (one-hot stone placement). Moves in chess use 8x8x8 planes (starting square, ending square, whether the move is valid; the remaining planes are reserved for piece promotion, e.g. rook, queen). Moves in shogi are encoded similarly to chess.

Results

Curiously, it seems that using a different number of simulations during evaluation time doesn’t alter performance.

Using “Reanalyze” they show that training can also be done on stored episodes, by re-running MCTS on old trajectories with the latest network parameters to obtain fresh targets.

“Don’t try this at home” since you won’t be able to, anyway:

  • Board games: 16 TPUs training / 1000 TPUs self-play. 5 hours training.
  • Atari (fewer simulations): 8 TPUs training / 32 TPUs self-play.

Within a million training frames, results are SOTA.

Future work suggests generalizations of MuZero: to stochastic, continuous, non-stationary, and temporally extended environments, and to imperfect-information and general-sum games; also zero-shot generalization and model sharing between environments (a single model for multiple games).

Code exploration from pseudocode

(Update 2021-06-20: Implementation in the Google Research repo)

Game is a stateful engine running the “real world” actions, instantiated before playing an episode. It keeps track of the action history, the reward history, and visit statistics for tree nodes. [NOTE: Why the last one? Seems kinda out of place.]
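
Roughly what the pseudocode’s Game holds (fields paraphrased); note that the stored visit statistics later become the policy targets for training:

```python
class Game:
    """Paraphrased container: one self-play episode in the real environment."""
    def __init__(self, action_space_size, discount):
        self.history = []        # actions actually taken
        self.rewards = []        # rewards observed from the environment
        self.child_visits = []   # per-step root visit distributions (future policy targets)
        self.root_values = []    # per-step root value estimates (used for value targets)
        self.action_space_size = action_space_size
        self.discount = discount
```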

A new root node is created before taking an action. Node expansion is potentially inefficient for large action spaces (it attempts all actions, regardless of validity). It seems to be called in one of two ways (see the sketch after the list):

  • for the first simulation step, expansion is limited to legal actions only;
  • for all subsequent steps, the full action space (taken from the action history object) is used; I guess the network gradually learns to prune illegal actions in the limit.
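
A sketch along the lines of the pseudocode’s expand_node (the Node and network_output fields are assumed); the actions argument is the legal-action set at the root and the full action space everywhere else, matching the two cases above:

```python
import math

def expand_node(node, to_play, actions, network_output):
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    # Softmax over the policy logits, restricted to the given actions.
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    total = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(prior=p / total)
```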

In backpropagation, the discounted value is propagated back up the simulation path, updating visit counts and value sums along the way.
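
Roughly what that looks like, following the pseudocode’s backpropagate (the two-player sign flip covers board games):

```python
def backpropagate(search_path, value, to_play, discount, min_max_stats):
    # Walk back from the leaf to the root, updating statistics with the
    # progressively more discounted value.
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        value = node.reward + discount * value
```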

Training is standard: a number of actors generate episodes in self-play mode. The training code samples the replay buffer and, for each sample, runs initial inference on the raw observation and then unrolls the model along the actions that were actually taken, comparing its predictions against the targets. The replay buffer sampling constructs (image, action history, targets) tuples, where the targets consist of the observed rewards, value targets computed from the episode, and the stored MCTS visit distributions.
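
A sketch of one such gradient step, reusing the toy Nets class from above; the individual loss terms and target shapes are simplified assumptions (e.g. the L2 term would typically come from the optimizer’s weight decay):

```python
import torch
import torch.nn.functional as F

def update_weights(nets, optimizer, batch, K=5):
    """batch: list of (observation, actions_taken, targets) samples, where targets
    holds one (target_value, target_reward, target_policy) tuple per unroll step."""
    loss = 0.0
    for obs, actions, targets in batch:
        s, policy_logits, value = nets.initial_inference(obs)
        predictions = [(value, torch.zeros_like(value), policy_logits)]  # no reward prediction at k=0
        for a in actions[:K]:
            s, reward, policy_logits, value = nets.recurrent_inference(s, a)
            predictions.append((value, reward, policy_logits))
        for (value, reward, policy_logits), (tv, tr, tp) in zip(predictions, targets):
            loss = loss + F.mse_loss(value, tv)                                # value vs episode-derived target
            loss = loss + F.mse_loss(reward, tr)                               # reward vs observed reward
            loss = loss - (tp * F.log_softmax(policy_logits, dim=-1)).sum()    # policy vs MCTS visit distribution
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```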