Importance-weighted Actor-Learner Architecture

Tags

  • rl
  • research

IMPALA (no, not the animal; no, not the music band). Also cf. Lilian Weng’s explanation (of much more than just IMPALA).

In older approaches (e.g. A3C, UNREAL), each actor computes its own gradients and sends them to the learner. This is slow, so IMPALA instead has actors send their trajectories of experience to a centralized learner. Separating acting from learning increases system throughput.
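
A minimal sketch of that decoupling, assuming a single machine and a shared queue; actor_rollout, latest_learner_params and learner_update are hypothetical placeholders, not names from the paper or TorchBeast:

import queue

trajectory_queue = queue.Queue()  # actors push trajectories, the learner pops them

def actor_loop(env, policy_snapshot, unroll_length=20):
  # Actors ship raw trajectories (not gradients), generated by a possibly stale policy copy.
  while True:
    trajectory = actor_rollout(env, policy_snapshot, unroll_length)  # hypothetical helper
    trajectory_queue.put((policy_snapshot, trajectory))
    policy_snapshot = latest_learner_params()  # hypothetical; may lag behind the learner

def learner_loop(batch_size=32):
  # The learner consumes batches of trajectories and does the V-trace-corrected update.
  while True:
    batch = [trajectory_queue.get() for _ in range(batch_size)]
    learner_update(batch)  # hypothetical helper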

IMPALA vs batched A2C

The learning is off-policy because each actor's behavior policy μ lags behind the learner's target policy π. The V-trace algorithm presented in this paper accounts for this discrepancy. The setting is discounted infinite-horizon RL in an MDP: find a policy maximizing the expected discounted return, \max_\pi E_\pi [\sum_{t \ge 0} \gamma^t r_t].

Importance sampling, V-trace target

Each actor follows a behavior policy \mu and generates a trajectory (x_t, a_t, r_t); the V-trace target is computed from these trajectories.

The n-step V-trace target for the value approximation V(x_s) is

v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} (\prod_{i=s}^{t-1} c_i) \delta_t V

With temporal difference

\delta_t V = \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))

The ratio of target to behavior policy is R_i = \frac{\pi(a_i|x_i)}{\mu(a_i|x_i)}; its truncated versions are used as importance sampling (IS) weights:

\rho_i = \min(\bar\rho, R_i), \qquad c_i = \min(\bar c, R_i)

If \pi = \mu (the on-policy case, with \bar c \ge 1), this reduces to the on-policy n-step Bellman target. The truncation level \bar\rho controls the fixed point of the V-trace operator: it is the value function of a policy \pi_{\bar\rho} that lies between the behavior policy \mu (for small \bar\rho) and the target policy \pi (for large \bar\rho). The update has converged when v_s = V(x_s).
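
A direct translation of the n-step target above for a single trajectory window (plain NumPy; values, rewards, rhos, cs are assumed to be precomputed arrays, not the paper's reference code):

import numpy as np

def vtrace_target(values, rewards, rhos, cs, gamma=0.99):
  # values:  V(x_s), ..., V(x_{s+n})        -- shape (n+1,)
  # rewards: r_s, ..., r_{s+n-1}            -- shape (n,)
  # rhos, cs: truncated IS weights per step -- shape (n,)
  n = len(rewards)
  # delta_t V = rho_t (r_t + gamma V(x_{t+1}) - V(x_t))
  deltas = rhos * (rewards + gamma * values[1:] - values[:-1])
  v_s = values[0]
  for t in range(n):
    # gamma^{t-s} * prod_{i=s}^{t-1} c_i (the empty product is 1 when t = s)
    v_s = v_s + (gamma ** t) * np.prod(cs[:t]) * deltas[t]
  return v_s

The recursion in the pseudocode at the end of this note computes the same quantity for every s in the unroll at once.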

Actor-critic algorithm

Policy parameter update in the direction of:

E_{a_s \sim \mu(\cdot|x_s)} [\frac{\pi_{\bar\rho}(a_s|x_s)}{\mu(a_s|x_s)} \nabla \log \pi_{\bar\rho}(a_s|x_s) q_s | x_s ]

where q_s = r_s + \gamma v_{s+1} is an estimate of Q^{\pi_{\bar\rho}}(x_s, a_s).

To reduce the variance of the policy gradient estimate, we usually subtract from q_s a state-dependent baseline (e.g. V(x_s)).

From Denny Britz: instead of measuring the absolute goodness of an action, we want to know how much better than "average" it is to take that action in a given state. Some states are naturally bad and always give negative reward. This quantity is called the advantage and is defined as Q(s, a) - V(s). We use it for the policy update, e.g. G_t - V(s) for REINFORCE.
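
As a tiny illustration of the baseline trick (a hypothetical REINFORCE-style loss on precomputed returns and log-probs, not part of IMPALA itself):

import torch

def reinforce_with_baseline_loss(log_prob_a, returns, value):
  # advantage: how much better the observed return was than the critic expected
  advantage = returns - value.detach()                 # ~ Q(s, a) - V(s), with G_t estimating Q
  policy_loss = -(advantage * log_prob_a).mean()       # REINFORCE with baseline
  value_loss = 0.5 * (returns - value).pow(2).mean()   # fit the baseline V(s) itself
  return policy_loss + value_loss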

Canonical V-trace actor-critic

This part of the paper gives a concrete overview.

At training time s, the value parameters \theta are updated by gradient descent on the L2 loss to the target v_s, i.e. in the direction of

(v_s - V_\theta(x_s)) \nabla_\theta V_\theta(x_s)

And the policy params \omega in the direction of the policy gradient:

\rho_s \nabla_\omega \log \pi_\omega(a_s|x_s) (r_s + \gamma v_{s+1} - V_\theta(x_s))

To prevent premature convergence we may add an entropy bonus, as in A3C:

-\nabla_\omega \sum_a \pi_\omega(a|x_s) \log \pi_\omega(a|x_s)

The overall update is obtained by summing these three gradients rescaled by appropriate coefficients (hyperparameters of the algorithm).
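
Putting the three terms together as a single loss, as a hedged sketch: the tensor names (vs, values, clipped_rho, log_pi_a, logits, rewards, discounts) are assumed to be precomputed for one unroll, the final-step bootstrap with the value estimate is my assumption, and baseline_cost / entropy_cost are the rescaling coefficients mentioned above:

import torch
import torch.nn.functional as F

def impala_loss(vs, values, clipped_rho, log_pi_a, logits,
                rewards, discounts, baseline_cost=0.5, entropy_cost=0.01):
  # 1) value loss: L2 distance of V_theta(x_s) to the V-trace targets v_s
  baseline_loss = 0.5 * (vs.detach() - values).pow(2).sum()

  # 2) policy gradient: rho_s * log pi(a_s|x_s) * (r_s + gamma * v_{s+1} - V(x_s))
  vs_plus_1 = torch.cat([vs[1:], values[-1:]])  # assumption: bootstrap the last step with V
  pg_advantage = (rewards + discounts * vs_plus_1 - values).detach()
  pg_loss = -(clipped_rho.detach() * log_pi_a * pg_advantage).sum()

  # 3) entropy bonus: -sum_a pi(a|x_s) log pi(a|x_s)
  probs = F.softmax(logits, dim=-1)
  entropy = -(probs * F.log_softmax(logits, dim=-1)).sum()

  return pg_loss + baseline_cost * baseline_loss - entropy_cost * entropy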

Some kind-of-not-really pseudocode in Python

Cf. the PyTorch implementation TorchBeast for less-handwavy specifics.

import torch
import torch.nn.functional as F

def learner(*args):
  # "What Would Learner Do" in my place
  behavior_policy, replay_buffer = many_actors_rollouts(env)
  # action, rewards, discounts, vt (= V(x_t)) and vt_plus1 (= V(x_{t+1})) also come out of the buffer
  target_policy = learner_net(replay_buffer)  # forward pass of the learner's current network (handwavy)

  # vtrace (they also do some grad clipping which I omit)
  nll = lambda policy, action: F.nll_loss(F.log_softmax(policy, dim=-1), action, reduction="none")
  # log rho = log pi(a|x) - log mu(a|x)
  log_rho = nll(behavior_policy, action) - nll(target_policy, action)
  rho = torch.exp(log_rho)

  clipped_rho = torch.clamp(rho, max=1.0)  # rho-bar = 1
  cs = torch.clamp(rho, max=1.0)           # c-bar = 1
  deltas = clipped_rho * (rewards + discounts * vt_plus1 - vt)

  # xs_k = d_k + c_k * disc_k * xs_{k+1} where xs_T = 0
  acc = 0
  result = []
  for t in reversed(range(discounts.shape[0])):  # t in {T-1, ..., 1, 0}
    acc = deltas[t] + cs[t] * discounts[t] * acc
    result.append(acc)
  result.reverse()
  v_minus_v_xs = torch.stack(result)
  vs = v_minus_v_xs + vt  # the V-trace targets v_s
  return vs
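
From the returned vs, the canonical update above follows directly: vs (treated as a constant) is the regression target for the value head, clipped_rho times (rewards + discounts * next-step vs - vt) gives the policy-gradient advantage, and the entropy bonus is computed from the target policy.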