Actor-Critic: когда две головы лучше одной (но сложнее)
Вот представьте: у вас есть нейросеть, которая выбирает действия (актор). И вторая нейросеть, которая оценивает эти действия (критик). Вместе они должны научиться играть в Atari или управлять роботом. В теории — элегантно. На практике — недели дебага, кривые графики и непонятно, почему агент ходит кругами вместо того, чтобы собирать монеты.
Я потратил три месяца, чтобы заставить Actor-Critic работать стабильно. Не на бумаге, а в реальном коде. Собрал все грабли, на которые можно наступить. И сейчас покажу, как обойти их все.
Ошибка #1: Путаница в loss функциях (или "почему всё расходится")
Самый частый вопрос на Stack Overflow про Actor-Critic: "Почему мои loss значения улетают в бесконечность?" Ответ почти всегда один — неправильные loss функции.
Не делайте так. Никогда. Это гарантированный путь к NaN'ам:
# КАК НЕ НАДО ДЕЛАТЬ
actor_loss = -critic_value.mean() # Ошибка: нет advantage
critic_loss = F.mse_loss(critic_value, rewards) # Ошибка: нет дисконтирования
Проблема в том, что многие туториалы показывают упрощённые формулы. А в реальности нужно учитывать advantage, baseline, и дисконтирование будущих наград.
Правильная реализация: от первого нейрона до работающего агента
1 Архитектура сетей (без лишних слоёв)
Первое правило Actor-Critic: актор и критик могут (и должны) делить часть слоёв. Особенно при работе с изображениями.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
# Общие слои для извлечения признаков
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Голова актора (возвращает логарифмы вероятностей)
self.actor = nn.Linear(hidden_dim, action_dim)
# Голова критика (возвращает оценку состояния)
self.critic = nn.Linear(hidden_dim, 1)
def forward(self, state):
features = self.shared(state)
action_logits = self.actor(features)
state_value = self.critic(features)
return action_logits, state_value
def get_action(self, state):
with torch.no_grad():
logits, value = self.forward(state)
probs = F.softmax(logits, dim=-1)
action = torch.multinomial(probs, 1).item()
log_prob = F.log_softmax(logits, dim=-1)[action]
return action, log_prob, value
2 Сбор траекторий (где чаще всего ошибаются)
Вот типичная ошибка новичков — они собирают траектории неправильно. Забывают про done флаги. Не сохраняют log_prob'ы. И потом удивляются, почему advantage вычисляется криво.
class TrajectoryBuffer:
def __init__(self, gamma=0.99, gae_lambda=0.95):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = []
self.values = []
self.gamma = gamma
self.gae_lambda = gae_lambda
def add(self, state, action, reward, done, log_prob, value):
# ВАЖНО: сохраняем всё, что понадобится позже
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.dones.append(done)
self.log_probs.append(log_prob)
self.values.append(value)
def compute_advantages(self):
"""Вычисление advantage с GAE (Generalized Advantage Estimation)"""
advantages = []
last_advantage = 0
# Идём с конца траектории
for t in reversed(range(len(self.rewards))):
if t == len(self.rewards) - 1:
next_value = 0 # Конец эпизода
else:
next_value = self.values[t + 1] * (1 - self.dones[t])
delta = self.rewards[t] + self.gamma * next_value - self.values[t]
# GAE формула
advantage = delta + self.gamma * self.gae_lambda * last_advantage * (1 - self.dones[t])
advantages.insert(0, advantage)
last_advantage = advantage
return advantages
Обратите внимание на (1 - self.dones[t]) в формуле next_value. Если done[t] == True (конец эпизода), то следующего состояния не существует. Без этого условия advantage будет считать будущие награды из следующего эпизода — полная ерунда.
Ошибка #2: Неправильный дисконт-фактор
Gamma = 0.99. Все так пишут. Все так делают. А потом агент в CartPole падает через 10 шагов. Почему?
| Задача | Рекомендованный gamma | Почему |
|---|---|---|
| CartPole (короткие эпизоды) | 0.95 - 0.98 | Эпизоды длятся 200-500 шагов. Большой discount заставляет агента думать о далёком будущем, которого нет |
| Atari игры | 0.99 - 0.999 | Длинные эпизоды, нужно планировать на сотни шагов вперёд |
| Роботика (continuous control) | 0.995 - 0.999 | Очень длинные эпизоды, плавные движения требуют долгосрочного планирования |
Gamma — это не гиперпараметр, который можно просто скопировать из статьи. Это выражение того, насколько далеко в будущее смотрит агент. В CartPole, если шест упал на шаге 50, то шаги 1-49 уже не важны — эпизод завершён. Зачем тогда gamma=0.99?
3 Обновление весов (где теряют градиенты)
Самый болезненный момент. Вы всё сделали правильно: собрали траектории, вычислили advantages. А потом в один backward() call обновляете и актора, и критика. И получаете... ничего. Нулевые градиенты. Или расходящиеся loss.
def update(self, buffer, optimizer, clip_param=0.2, value_coef=0.5, entropy_coef=0.01):
"""PPO-style update с clipping"""
# Преобразуем в тензоры
states = torch.stack(buffer.states)
actions = torch.tensor(buffer.actions, dtype=torch.long)
old_log_probs = torch.stack(buffer.log_probs).detach()
returns = torch.tensor(buffer.returns, dtype=torch.float32)
advantages = torch.tensor(buffer.advantages, dtype=torch.float32)
# Нормализуем advantages (КРИТИЧЕСКИ ВАЖНО!)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Несколько эпох обновления
for _ in range(4): # PPO epochs
# Получаем новые предсказания
logits, values = self.forward(states)
new_log_probs = F.log_softmax(logits, dim=-1)
new_log_probs = new_log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
# Entropy для исследования
probs = F.softmax(logits, dim=-1)
entropy = -(probs * new_log_probs).sum(-1).mean()
# Ratio для PPO
ratio = torch.exp(new_log_probs - old_log_probs)
# Clipped surrogate loss для актора
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_param, 1 + clip_param) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
# Value loss для критика
value_loss = F.mse_loss(values.squeeze(-1), returns)
# Общий loss
loss = actor_loss + value_coef * value_loss - entropy_coef * entropy
# Backward
optimizer.zero_grad()
loss.backward()
# Gradient clipping (ещё одна важная штука!)
torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5)
optimizer.step()
Обратите внимание на три ключевых момента:
- Нормализация advantages — без этого градиенты будут нестабильными
- Clipping градиентов — защита от взрывных обновлений
- Entropy term — чтобы агент не застревал в одном действии
Ошибка #3: Reward engineering (или "почему агент стоит на месте")
Самая коварная проблема. Ваш код идеален, архитектура правильная, гиперпараметры подобраны. Агент запускается... и ничего не делает. Стоит на месте. Получает нулевые награды, но зато не получает отрицательных.
Это проблема sparse rewards. Агент не получает сигнала, пока не сделает что-то полезное. А как сделать что-то полезное, если он не знает, что это?
Неправильно: Давать +1 за выживание на каждом шаге. Агент научится избегать любых действий, чтобы не умереть.
Правильно: Использовать shaping rewards — маленькие награды за движение в правильном направлении.
# Пример для MountainCar
class MountainCarReward:
def __init__(self):
self.max_position = -0.4 # Стартовая позиция
def compute_reward(self, position, velocity, done):
"""Custom reward shaping"""
if done:
return 100.0 # Большая награда за успех
# Награда за движение вправо (к цели)
position_reward = max(0, position - self.max_position) * 10
# Обновляем максимальную достигнутую позицию
self.max_position = max(self.max_position, position)
# Маленькая награда за скорость (чтобы не стоял на месте)
velocity_reward = abs(velocity) * 0.1
# Штраф за действие (чтобы не дёргался без нужды)
action_penalty = -0.01
return position_reward + velocity_reward + action_penalty
Shaping rewards — это искусство. Слишком большие shaping rewards — агент будет оптимизировать их вместо реальной цели. Слишком маленькие — не сработают. Нужно найти баланс.
Реальные баги, которые я находил в продакшн-коде
Не все ошибки очевидны. Некоторые проявляются только через несколько тысяч итераций.
Баг #1: Не сбрасывается скрытое состояние RNN
# БАГ
for episode in range(num_episodes):
state = env.reset()
hidden = None # НЕПРАВИЛЬНО: нужно сбрасывать каждый эпизод
while not done:
action, hidden = model(state, hidden)
# ...
# ФИКС
for episode in range(num_episodes):
state = env.reset()
hidden = model.init_hidden() # ПРАВИЛЬНО
while not done:
action, hidden = model(state, hidden)
# ...
Баг #2: Не детачится в нужных местах
# БАГ
advantages = compute_advantages(buffer)
loss = compute_loss(buffer, advantages) # advantages не detached!
# ФИКС
advantages = compute_advantages(buffer).detach() # ОТСЕКАЕМ ГРАДИЕНТЫ
loss = compute_loss(buffer, advantages)
Если не сделать .detach() у advantages, вы получите двойные градиенты через critic. Loss будет считать одно, а обновлять другое.
Отладка Actor-Critic: что смотреть в tensorboard
Если графики обучения выглядят странно, вот на что смотреть в первую очередь:
- Value loss: Должен уменьшаться, но не до нуля. Если упал до нуля — критика переобучился на шум
- Policy entropy: Должен медленно уменьшаться. Резкий спад — агент перестал исследовать
- Advantage mean/std: Mean около нуля, std не должна взрываться
- Gradient norms: Резкие скачки — нужен gradient clipping
Производительность: когда PyTorch тормозит без видимой причины
Вы написали идеальный код, но обучение идёт медленно. 10 шагов в секунду вместо 1000. Где узкое место?
# МЕДЛЕННО: Частые .to(device) вызовы
for state, action in zip(states, actions):
state_tensor = torch.FloatTensor(state).to(device) # БАГ!
action_tensor = torch.LongTensor([action]).to(device)
# ...
# БЫСТРО: Батчим всё сразу
states_tensor = torch.FloatTensor(np.array(states)).to(device)
actions_tensor = torch.LongTensor(actions).to(device)
# ...
Каждый вызов .to(device) — это синхронизация CPU-GPU. Делайте это один раз для всего батча. Как и в случае с TraceML для даталоадеров, профилирование показывает неочевидные bottlenecks.
Чеклист перед запуском обучения
- Проверить, что advantages нормализованы (mean=0, std=1)
- Убедиться, что gradient clipping включён (значение 0.5-1.0)
- Проверить entropy coefficient (0.01 для начала)
- Убедиться, что learning rate не слишком высок (3e-4 для Adam)
- Проверить, что rewards не взрывные (можно нормализовать)
- Убедиться, что done flags правильно обрабатываются
- Проверить, что нет утечек памяти (особенно в replay buffer)
Что делать, если всё равно не работает
Бывает. Код правильный, гиперпараметры разумные, а агент тупит. Вот план действий:
- Упростите задачу: CartPole вместо Atari. Если не работает на CartPole — проблема в коде
- Визуализируйте политику: Запишите, какие действия выбирает агент. Может, он просто дёргается туда-сюда?
- Проверьте advantages: Они должны коррелировать с качеством действий. Если нет — проблема в critic
- Постепенно усложняйте: Добавляйте сложность только когда простые задачи решены
И последнее: Actor-Critic — не silver bullet. Для некоторых задач проще работает PPO или SAC. Но понимание Actor-Critic даёт базу для всех policy gradient методов. Как и в случае с продакшен AI-агентами, начинать нужно с основ, а потом уже добавлять сложность.
Мой главный совет: не копируйте код слепо из статей (даже из этой). Пишите с нуля, дебажьте каждый шаг. Только так вы поймёте, что на самом деле происходит внутри. И тогда Actor-Critic из чёрного ящика превратится в инструмент, который вы действительно контролируете.