Importance-weighted Actor-Learner Architecture

Tags

  • rl
  • research

IMPALA (no, not the animal; no, not the music band). Also cf. Lilian Weng’s explanation (of much more than just IMPALA).

In older approaches (e.g. A3C, UNREAL), each actor computes its own gradients and sends them to the learner. This is slow, so in IMPALA the actors instead send their raw trajectories of experience and the learner alone computes the gradients. Separating acting from learning increases system throughput.
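A toy sketch of that decoupling (my own illustration, not the paper's actual distributed setup; rollout, vtrace_loss and get_learner_weights are hypothetical stand-ins): actors only produce trajectories, and only the learner touches gradients.

import queue

trajectory_queue = queue.Queue(maxsize=64)          # toy stand-in for the real transport layer

def actor_loop(env, policy, get_learner_weights):
    # each actor runs a slightly stale copy of the policy and only ships data
    while True:
        policy.load_state_dict(get_learner_weights())   # refresh; still lags behind pi
        trajectory_queue.put(rollout(env, policy))      # (x_t, a_t, r_t) plus mu(a_t|x_t)

def learner_loop(model, optimizer, batch_size=32):
    # the only place where gradients are computed
    while True:
        batch = [trajectory_queue.get() for _ in range(batch_size)]
        loss = vtrace_loss(model, batch)                # V-trace + losses, see below
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# actors and the learner would run concurrently, e.g.
# threading.Thread(target=actor_loop, args=(env, policy_copy, get_weights)).start()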

IMPALA vs batched A2C

The learning is off-policy because each actor’s behavior policy \(\mu\) lags behind the learner’s target policy \(\pi\). The V-trace algorithm presented in the paper corrects for this discrepancy. The setting is discounted infinite-horizon RL in an MDP: maximize \(V^{\pi}(x)=E_\pi[\sum_t \gamma^t r_t]\).

Importance sampling, V-trace target

This V-trace target gets computed for each actor following a behavior policy \(\mu\). The actor generates a trajectory \((x_t, a_t, r_t)\).

The n-step V-trace target for the value approximation \(V(x_s)\) is

\[v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} (\prod_{i=s}^{t-1} c_i) \delta_t V\]

with the temporal difference

\[\delta_t V = \rho_t \left(r_t + \gamma V(x_{t+1}) - V(x_t)\right)\]
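For \(n=1\) this collapses to a single importance-weighted TD step (my own sanity check, just instantiating the definitions):

\[v_s = V(x_s) + \rho_s\left(r_s + \gamma V(x_{s+1}) - V(x_s)\right)\]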

The ratio of the target policy to the behavior policy is \(R_i=\frac{\pi(a_i|x_i)}{\mu(a_i|x_i)}\); its truncated versions are used for importance sampling (IS):

\[\rho_i=\min(\bar\rho, R_i)\] \[c_i=\min(\bar c, R_i)\]
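A tiny numeric example (numbers picked arbitrarily): with \(\bar\rho=\bar c=1\), if \(\pi(a_i|x_i)=0.9\) and \(\mu(a_i|x_i)=0.3\) then \(R_i=3\) and \(\rho_i=c_i=1\) (clipped); if \(\pi(a_i|x_i)=0.1\) and \(\mu(a_i|x_i)=0.5\) then \(R_i=0.2\) and \(\rho_i=c_i=0.2\) (small ratios pass through untruncated).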

If the target and behavior policies coincide (on-policy) and \(\bar c \ge 1\), this reduces to the on-policy n-step Bellman target. In general, the fixed point of this update is \(V^{\pi_{\bar\rho}}\), the value function of a policy \(\pi_{\bar\rho}\) that lies between the behavior policy \(\mu\) and the target policy \(\pi\), with \(\bar\rho\) controlling where in between.
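For reference, writing that intermediate policy down from memory (worth double-checking against the paper):

\[\pi_{\bar\rho}(a|x) = \frac{\min\left(\bar\rho\,\mu(a|x),\ \pi(a|x)\right)}{\sum_{b}\min\left(\bar\rho\,\mu(b|x),\ \pi(b|x)\right)}\]

so \(\bar\rho \to \infty\) (no truncation) recovers \(\pi\) and \(\bar\rho \to 0\) recovers \(\mu\).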

Actor-critic algorithm

Policy parameter update in the direction of:

\[E_{a_s \sim \mu(\cdot|x_s)} \left[\frac{\pi_{\bar\rho}(a_s|x_s)}{\mu(a_s|x_s)} \nabla \log \pi_{\bar\rho}(a_s|x_s)\, q_s \,\Big|\, x_s \right]\]

where \(q_s = r_s + \gamma v_{s+1}\).

To reduce the variance of the policy gradient estimate, we usually subtract from \(q_s\) a state-dependent baseline (e.g. \(V(x_s)\)).
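With \(V(x_s)\) as the baseline, the direction above becomes (the usual score-function argument: the baseline does not depend on \(a_s\), so it does not bias the gradient):

\[E_{a_s \sim \mu(\cdot|x_s)} \left[\frac{\pi_{\bar\rho}(a_s|x_s)}{\mu(a_s|x_s)} \nabla \log \pi_{\bar\rho}(a_s|x_s)\, \left(q_s - V(x_s)\right) \,\Big|\, x_s \right]\]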

From Denny Britz: instead of measuring the absolute goodness of an action, we want to know how much better than “average” it is to take that action in a given state, since some states are naturally bad and always give negative reward. This quantity is called the advantage and is defined as \(Q(s, a) - V(s)\). We use it for the policy update, e.g. \(G_t - V(s_t)\) for REINFORCE.

Canonical V-trace actor-critic

This part of the paper gives a concrete overview.

At training time \(s\), the value parameters \(\theta\) are updated by gradient descent on the \(\ell_2\) loss to the target \(v_s\), i.e. in the direction of

\[(v_s - V_\theta(x_s)) \, (\nabla_\theta V_\theta)(x_s)\]

And the policy params \(\omega\) in the direction of the policy gradient:

\[\rho_s \nabla_\omega \log \pi_\omega(a_s|x_s)\left(r_s + \gamma v_{s+1} - V_\theta(x_s)\right)\]

To prevent premature convergence we may add an entropy bonus, as in A3C:

\[-\nabla_\omega \sum_{a}\pi_\omega(a|x_s) \log \pi_\omega(a|x_s)\]

The overall update is obtained by summing these three gradients rescaled by appropriate coefficients (hyperparameters of the algorithm).
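A hedged sketch of what “summing these three gradients rescaled by appropriate coefficients” might look like in PyTorch (the function name, the default coefficients, and the assumption that the V-trace outputs are already detached are all mine):

import torch
import torch.nn.functional as F

def impala_loss(logits, actions, values, vs, pg_advantages, clipped_rhos,
                baseline_cost=0.5, entropy_cost=0.01):
    # logits: [T, num_actions] from pi_omega; actions: [T]; values: V_theta(x_s);
    # vs, pg_advantages, clipped_rhos: V-trace outputs, assumed already detached.
    log_probs = F.log_softmax(logits, dim=-1)
    action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    pg_loss = -(clipped_rhos * action_log_probs * pg_advantages).sum()  # policy gradient term
    baseline_loss = 0.5 * ((vs - values) ** 2).sum()                    # l2 loss to the target v_s
    entropy_loss = (log_probs.exp() * log_probs).sum()                  # negative entropy

    return pg_loss + baseline_cost * baseline_loss + entropy_cost * entropy_loss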

Some kind-of-not-really pseudocode in Python

Cf. the PyTorch implementation TorchBeast for less-handwavy specifics.
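The backward scan in the code relies on the recursive form of the target, which the paper notes:

\[v_s = V(x_s) + \delta_s V + \gamma c_s \left(v_{s+1} - V(x_{s+1})\right)\]

so with \(xs_s := v_s - V(x_s)\) this is \(xs_s = \delta_s V + \gamma c_s\, xs_{s+1}\), with \(xs = 0\) past the end of the trajectory.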

import torch
import torch.nn.functional as F

def learner_step(env, learner_net):
  # "What Would Learner Do" in my place: actors roll out with a lagging copy of
  # the net (behavior policy mu); the learner re-scores the same trajectory with
  # its current weights (target policy pi). Shapes: one trajectory of length T.
  behavior_logits, actions, rewards, discounts, observations = many_actors_rollouts(env)
  target_logits, vt, vt_plus1 = learner_net(observations)  # logits, V(x_t), V(x_{t+1})

  # vtrace (they also do some grad clipping which I omit)
  nll = lambda logits, a: F.nll_loss(F.log_softmax(logits, dim=-1), a, reduction="none")
  # nll is a *negative* log-prob, so log rho_t = log pi - log mu = nll(mu) - nll(pi)
  log_rhos = nll(behavior_logits, actions) - nll(target_logits, actions)
  rhos = torch.exp(log_rhos)

  clipped_rhos = torch.clamp(rhos, max=1.0)  # rho_bar = 1
  cs = torch.clamp(rhos, max=1.0)            # c_bar = 1
  deltas = clipped_rhos * (rewards + discounts * vt_plus1 - vt)

  # backward scan: xs_k = delta_k + c_k * disc_k * xs_{k+1}, where xs_T = 0
  acc = torch.zeros_like(vt[-1])
  result = []
  for t in reversed(range(discounts.shape[0])):  # t in {T-1, ..., 1, 0}
    acc = deltas[t] + cs[t] * discounts[t] * acc
    result.append(acc)
  result.reverse()
  v_minus_v_xs = torch.stack(result)  # = v_s - V(x_s), one entry per step s
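
And roughly what you would do with v_minus_v_xs next, continuing inside learner_step (again a sketch; the variable names and the bootstrapping of the last step are mine, cf. TorchBeast for the real thing):

  vs = v_minus_v_xs + vt                                   # V-trace targets v_s
  vs_plus1 = torch.cat([vs[1:], vt_plus1[-1:]])            # v_{s+1}, bootstrapping the last step with V
  pg_advantages = clipped_rhos * (rewards + discounts * vs_plus1 - vt)  # rho_s (q_s - V(x_s))
  # vs feeds the l2 value loss, pg_advantages the policy gradient, as above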