# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Implementation based on Denis Yarats' implementation of [SAC](https://github.com/denisyarats/pytorch_sac).
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from mtrl.agent import utils as agent_utils
from mtrl.agent.abstract import Agent as AbstractAgent
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.env.types import ObsType
from mtrl.logger import Logger
from mtrl.replay_buffer import ReplayBuffer, ReplayBufferSample
from mtrl.utils.types import ConfigType, ModelType, ParameterType, TensorType
class Agent(AbstractAgent):
    """Soft Actor-Critic (SAC) agent with optional multitask extensions.

    The agent owns an actor, a critic, a target critic and one learnable
    log-temperature (``log_alpha``) entry per environment. When
    ``multitask_cfg.should_use_task_encoder`` is set, it additionally owns a
    task encoder whose output is handed to the actor/critic through a
    :class:`TaskInfo` object.
    """

    def __init__(
        self,
        env_obs_shape: List[int],
        action_shape: List[int],
        action_range: Tuple[int, int],
        device: torch.device,
        actor_cfg: ConfigType,
        critic_cfg: ConfigType,
        alpha_optimizer_cfg: ConfigType,
        actor_optimizer_cfg: ConfigType,
        critic_optimizer_cfg: ConfigType,
        multitask_cfg: ConfigType,
        discount: float,
        init_temperature: float,
        actor_update_freq: int,
        critic_tau: float,
        critic_target_update_freq: int,
        encoder_tau: float,
        loss_reduction: str = "mean",
        cfg_to_load_model: Optional[ConfigType] = None,
        should_complete_init: bool = True,
    ):
        """Build the networks, the temperature parameter and the optimizers.

        Args:
            env_obs_shape (List[int]): shape of the environment observation.
            action_shape (List[int]): shape of the action.
            action_range (Tuple[int, int]): bounds used to clamp the actions
                returned by :meth:`act`.
            device (torch.device): device on which the components are placed.
            actor_cfg (ConfigType): hydra config to instantiate the actor.
            critic_cfg (ConfigType): hydra config to instantiate the critic
                and the target critic.
            alpha_optimizer_cfg (ConfigType): hydra config for the optimizer
                of ``log_alpha``.
            actor_optimizer_cfg (ConfigType): hydra config for the actor
                optimizer.
            critic_optimizer_cfg (ConfigType): hydra config for the critic
                optimizer.
            multitask_cfg (ConfigType): multitask specific config.
            discount (float): discount factor used in the critic target.
            init_temperature (float): initial value of alpha (its log is
                stored in ``log_alpha``).
            actor_update_freq (int): the actor and alpha are updated when
                ``step % actor_update_freq == 0``.
            critic_tau (float): rate passed to the soft update of the target
                Q-functions.
            critic_target_update_freq (int): the target critic is updated
                when ``step % critic_target_update_freq == 0``.
            encoder_tau (float): rate passed to the soft update of the target
                critic's encoder.
            loss_reduction (str, optional): either ``"mean"`` or ``"none"``.
                Defaults to "mean".
            cfg_to_load_model (Optional[ConfigType], optional): if truthy, it
                is unpacked as keyword arguments for ``self.load``. Defaults
                to None.
            should_complete_init (bool, optional): if `True`, call
                :meth:`complete_init` at the end of the constructor. Defaults
                to True.

        Raises:
            ValueError: if ``loss_reduction`` is not a supported value.
        """
        # Fail fast: reject a bad configuration before any model is built.
        if loss_reduction not in ["mean", "none"]:
            raise ValueError(
                f"{loss_reduction} is not a supported value for `loss_reduction`."
            )

        super().__init__(
            env_obs_shape=env_obs_shape,
            action_shape=action_shape,
            action_range=action_range,
            multitask_cfg=multitask_cfg,
            device=device,
        )
        # NOTE(review): `self.multitask_cfg`, `self.device` and `self.num_envs`
        # are not assigned in this class; presumably the parent constructor
        # sets them -- confirm against `AbstractAgent`.
        self.should_use_task_encoder = self.multitask_cfg.should_use_task_encoder
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq

        self.actor = hydra.utils.instantiate(
            actor_cfg, env_obs_shape=env_obs_shape, action_shape=action_shape
        ).to(self.device)
        self.critic = hydra.utils.instantiate(
            critic_cfg, env_obs_shape=env_obs_shape, action_shape=action_shape
        ).to(self.device)
        self.critic_target = hydra.utils.instantiate(
            critic_cfg, env_obs_shape=env_obs_shape, action_shape=action_shape
        ).to(self.device)

        # One log-temperature entry per environment, all starting from
        # log(init_temperature).
        self.log_alpha = torch.nn.Parameter(
            torch.tensor(
                [
                    np.log(init_temperature, dtype=np.float32)
                    for _ in range(self.num_envs)
                ]
            ).to(self.device)
        )
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        self._components = {
            "actor": self.actor,
            "critic": self.critic,
            "critic_target": self.critic_target,
            "log_alpha": self.log_alpha,  # type: ignore[dict-item]
        }

        # optimizers
        self.actor_optimizer = hydra.utils.instantiate(
            actor_optimizer_cfg, params=self.get_parameters(name="actor")
        )
        self.critic_optimizer = hydra.utils.instantiate(
            critic_optimizer_cfg, params=self.get_parameters(name="critic")
        )
        self.log_alpha_optimizer = hydra.utils.instantiate(
            alpha_optimizer_cfg, params=self.get_parameters(name="log_alpha")
        )
        self.loss_reduction = loss_reduction

        self._optimizers = {
            "actor": self.actor_optimizer,
            "critic": self.critic_optimizer,
            "log_alpha": self.log_alpha_optimizer,
        }

        if self.should_use_task_encoder:
            self.task_encoder = hydra.utils.instantiate(
                self.multitask_cfg.task_encoder_cfg.model_cfg,
            ).to(self.device)
            name = "task_encoder"
            self._components[name] = self.task_encoder
            self.task_encoder_optimizer = hydra.utils.instantiate(
                self.multitask_cfg.task_encoder_cfg.optimizer_cfg,
                params=self.get_parameters(name=name),
            )
            self._optimizers[name] = self.task_encoder_optimizer

        if should_complete_init:
            self.complete_init(cfg_to_load_model=cfg_to_load_model)

    def complete_init(self, cfg_to_load_model: Optional[ConfigType]) -> None:
        """Finish the initialization that depends on all components existing.

        Optionally loads a saved model, copies the critic weights into the
        target critic, ties the actor encoder to the critic encoder and puts
        the agent in training mode.

        Args:
            cfg_to_load_model (Optional[ConfigType]): if truthy, it is
                unpacked as keyword arguments for ``self.load``.
        """
        if cfg_to_load_model:
            self.load(**cfg_to_load_model)

        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
        self.train()

    def train(self, training: bool = True) -> None:
        """Set the training flag on the agent and on its module components.

        Args:
            training (bool, optional): mode to set. Defaults to True.
        """
        self.training = training
        for name, component in self._components.items():
            # `log_alpha` is a bare `torch.nn.Parameter`, which has no
            # `train` method, so it is skipped.
            if name != "log_alpha":
                component.train(training)

    def get_alpha(self, env_index: TensorType) -> TensorType:
        """Get the alpha value for the given environments.

        Args:
            env_index (TensorType): environment index.

        Returns:
            TensorType: alpha values. When
                ``multitask_cfg.should_use_disentangled_alpha`` is falsy, the
                first entry of ``log_alpha`` is used for every environment.
        """
        if self.multitask_cfg.should_use_disentangled_alpha:
            return self.log_alpha[env_index].exp()
        return self.log_alpha[0].exp()

    def get_task_encoding(
        self, env_index: TensorType, modes: List[str], disable_grad: bool
    ) -> TensorType:
        """Get the task encoding for the different environments.

        Args:
            env_index (TensorType): environment index.
            modes (List[str]): not used by this implementation.
            disable_grad (bool): should disable tracking gradient.

        Returns:
            TensorType: task encodings.
        """
        if disable_grad:
            with torch.no_grad():
                return self.task_encoder(env_index.to(self.device))
        return self.task_encoder(env_index.to(self.device))

    def act(
        self,
        multitask_obs: ObsType,
        modes: List[str],
        sample: bool,
    ) -> np.ndarray:
        """Select/sample the action to perform.

        Args:
            multitask_obs (ObsType): Observation from the multitask
                environment. The `env_obs` and `task_obs` keys are read.
            modes (List[str]): modes in which to select the action. Only
                forwarded to :meth:`get_task_encoding`.
            sample (bool): sample (if `True`) or select (if `False`) an action.

        Returns:
            np.ndarray: selected/sample action, clamped to `action_range`.
        """
        env_obs = multitask_obs["env_obs"]
        env_index = multitask_obs["task_obs"]
        env_index = env_index.to(self.device, non_blocking=True)
        with torch.no_grad():
            if self.should_use_task_encoder:
                task_encoding = self.get_task_encoding(
                    env_index=env_index, modes=modes, disable_grad=True
                )
            else:
                task_encoding = None  # type: ignore[assignment]
            task_info = self.get_task_info(
                task_encoding=task_encoding, component_name="", env_index=env_index
            )
            obs = env_obs.float().to(self.device)
            if len(obs.shape) == 1 or len(obs.shape) == 3:
                obs = obs.unsqueeze(0)  # Make a batch
            mtobs = MTObs(env_obs=obs, task_obs=env_index, task_info=task_info)
            mu, pi, _, _ = self.actor(mtobs=mtobs)
            action = pi if sample else mu
            action = action.clamp(*self.action_range)
            return action.detach().cpu().numpy()

    def select_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
        """Return the actor's `mu` output (no sampling) as the action."""
        return self.act(multitask_obs=multitask_obs, modes=modes, sample=False)

    def sample_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
        """Return the actor's sampled output (`pi`) as the action."""
        return self.act(multitask_obs=multitask_obs, modes=modes, sample=True)

    def get_last_shared_layers(self, component_name: str) -> Optional[List[ModelType]]:
        """Get the last shared layers of a component.

        Args:
            component_name (str): name of the component.

        Returns:
            Optional[List[ModelType]]: whatever the component's own
                `get_last_shared_layers` returns for the model-like
                components, `None` otherwise.

        Raises:
            ValueError: if the name is neither a known special name nor a
                registered component.
        """
        if component_name in [
            "actor",
            "critic",
            "transition_model",
            "reward_decoder",
            "decoder",
        ]:
            # The mypy error is because self._components can contain a tensor as well.
            return self._components[component_name].get_last_shared_layers()  # type: ignore[operator]
        if component_name in ["log_alpha", "encoder", "task_encoder"]:
            return None
        if component_name not in self._components:
            raise ValueError(f"""Component named {component_name} does not exist""")
        # Registered component without shared layers (e.g. `critic_target`).
        return None

    def _compute_gradient(
        self,
        loss: TensorType,
        parameters: List[ParameterType],
        step: int,
        component_names: List[str],
        retain_graph: bool = False,
    ):
        """Method to override the gradient computation.

        Useful for algorithms like PCGrad and GradNorm.

        Args:
            loss (TensorType): loss to backpropagate. It may be unreduced
                when `loss_reduction` is "none".
            parameters (List[ParameterType]): not used by this implementation.
            step (int): step for tracking the training of the agent.
            component_names (List[str]): not used by this implementation.
            retain_graph (bool, optional): if it should retain graph.
                Defaults to False.
        """
        # `backward()` without an explicit gradient only works for
        # single-element tensors. An unreduced loss is averaged here so that
        # the plain agent optimizes the same objective as with the "mean"
        # reduction instead of raising.
        if loss.numel() > 1:
            loss = loss.mean()
        loss.backward(retain_graph=retain_graph)

    def _prepare_gradient_step(
        self,
        component_name: str,
        task_info: TaskInfo,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> Tuple[List[str], List[ParameterType], Dict[str, Any]]:
        """Zero the component's gradients and collect what backward needs.

        Args:
            component_name (str): component being optimized.
            task_info (TaskInfo): task_info object.
            kwargs_to_compute_gradient (Dict[str, Any]): caller supplied
                kwargs. This dict is never modified.

        Returns:
            Tuple[List[str], List[ParameterType], Dict[str, Any]]: component
                names, their parameters and a private copy of the kwargs.
        """
        self._optimizers[component_name].zero_grad()
        component_names = [component_name]
        parameters: List[ParameterType] = list(self.get_parameters(component_name))
        # Work on a copy so that the caller's dict is left untouched.
        grad_kwargs = dict(kwargs_to_compute_gradient)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            grad_kwargs["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")
        return component_names, parameters, grad_kwargs

    def _log_and_reduce(
        self, key: str, per_sample_loss: TensorType, logger: Logger, step: int
    ) -> TensorType:
        """Log the mean of a loss and reduce it as per `loss_reduction`.

        Args:
            key (str): key under which the mean loss is logged.
            per_sample_loss (TensorType): unreduced loss.
            logger (Logger): logger object.
            step (int): step for tracking the training of the agent.

        Returns:
            TensorType: the mean loss if `loss_reduction` is "mean", the
                unreduced loss otherwise.
        """
        mean_loss = per_sample_loss.mean()
        logger.log(key, mean_loss, step)
        if self.loss_reduction == "mean":
            return mean_loss
        return per_sample_loss

    def _get_target_V(
        self, batch: ReplayBufferSample, task_info: TaskInfo
    ) -> TensorType:
        """Compute the target values.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.

        Returns:
            TensorType: target values.
        """
        mtobs = MTObs(env_obs=batch.next_env_obs, task_obs=None, task_info=task_info)
        _, policy_action, log_pi, _ = self.actor(mtobs=mtobs)
        target_Q1, target_Q2 = self.critic_target(mtobs=mtobs, action=policy_action)
        # Minimum of the two target Q-values minus the entropy term.
        return (
            torch.min(target_Q1, target_Q2)
            - self.get_alpha(env_index=batch.task_obs).detach() * log_pi
        )

    def update_critic(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the critic component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): extra kwargs for
                :meth:`_compute_gradient`. The dict is not modified.

        Raises:
            RuntimeError: if the (mean) critic loss exceeds 1e8.
        """
        with torch.no_grad():
            target_V = self._get_target_V(batch=batch, task_info=task_info)
            target_Q = batch.reward + (batch.not_done * self.discount * target_V)

        # get current Q estimates
        mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=task_info)
        current_Q1, current_Q2 = self.critic(
            mtobs=mtobs,
            action=batch.action,
            detach_encoder=False,
        )
        critic_loss = F.mse_loss(
            current_Q1, target_Q, reduction=self.loss_reduction
        ) + F.mse_loss(current_Q2, target_Q, reduction=self.loss_reduction)

        loss_to_log = critic_loss
        if self.loss_reduction == "none":
            loss_to_log = loss_to_log.mean()
        logger.log("train/critic_loss", loss_to_log, step)

        if loss_to_log > 1e8:
            raise RuntimeError(
                f"critic_loss = {loss_to_log} is too high. Stopping training."
            )

        component_names, parameters, grad_kwargs = self._prepare_gradient_step(
            component_name="critic",
            task_info=task_info,
            kwargs_to_compute_gradient=kwargs_to_compute_gradient,
        )

        self._compute_gradient(
            loss=critic_loss,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **grad_kwargs,
        )

        # Optimize the critic
        self.critic_optimizer.step()

    def update_actor_and_alpha(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the actor and alpha component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): extra kwargs for
                :meth:`_compute_gradient`. The dict is not modified.
        """
        # detach encoder, so we don't update it with the actor loss
        mtobs = MTObs(
            env_obs=batch.env_obs,
            task_obs=None,
            task_info=task_info,
        )
        _, pi, log_pi, log_std = self.actor(mtobs=mtobs, detach_encoder=True)
        actor_Q1, actor_Q2 = self.critic(mtobs=mtobs, action=pi, detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        # Alpha is detached: the actor loss must not train the temperature.
        actor_loss = self._log_and_reduce(
            key="train/actor_loss",
            per_sample_loss=self.get_alpha(batch.task_obs).detach() * log_pi - actor_Q,
            logger=logger,
            step=step,
        )

        logger.log("train/actor_target_entropy", self.target_entropy, step)

        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(
            dim=-1
        )
        logger.log("train/actor_entropy", entropy.mean(), step)

        # optimize the actor
        component_names, parameters, grad_kwargs = self._prepare_gradient_step(
            component_name="actor",
            task_info=task_info,
            kwargs_to_compute_gradient=kwargs_to_compute_gradient,
        )

        self._compute_gradient(
            loss=actor_loss,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **grad_kwargs,
        )
        self.actor_optimizer.step()

        # optimize alpha; here the policy term is detached instead.
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = self._log_and_reduce(
            key="train/alpha_loss",
            per_sample_loss=(
                self.get_alpha(batch.task_obs)
                * (-log_pi - self.target_entropy).detach()
            ),
            logger=logger,
            step=step,
        )
        # The same kwargs as for the actor step are reused, including
        # `retain_graph` when the task encoder is being trained.
        self._compute_gradient(
            loss=alpha_loss,
            parameters=self.get_parameters(name="log_alpha"),
            step=step,
            component_names=["log_alpha"],
            **grad_kwargs,
        )
        self.log_alpha_optimizer.step()

    def get_task_info(
        self, task_encoding: TensorType, component_name: str, env_index: TensorType
    ) -> TaskInfo:
        """Encode task encoding into task info.

        The encoding keeps its gradient only when the task encoder is used
        and `component_name` is listed in
        `multitask_cfg.task_encoder_cfg.losses_to_train`.

        Args:
            task_encoding (TensorType): encoding of the task.
            component_name (str): name of the component.
            env_index (TensorType): index of the environment.

        Returns:
            TaskInfo: TaskInfo object.
        """
        if not self.should_use_task_encoder:
            return TaskInfo(
                encoding=task_encoding, compute_grad=False, env_index=env_index
            )
        if component_name in self.multitask_cfg.task_encoder_cfg.losses_to_train:
            return TaskInfo(
                encoding=task_encoding, compute_grad=True, env_index=env_index
            )
        return TaskInfo(
            encoding=task_encoding.detach(),
            compute_grad=False,
            env_index=env_index,
        )

    def update_transition_reward_model(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the transition model and reward decoder.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]):

        Raises:
            NotImplementedError: always, for the SAC agent.
        """
        raise NotImplementedError("This method is not implemented for SAC agent.")

    def update_task_encoder(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the task encoder component.

        Only steps the task encoder optimizer; no loss is computed here. The
        gradients are the ones accumulated by the earlier updates.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]):
        """
        self.task_encoder_optimizer.step()

    def update_decoder(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the decoder component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]):

        Raises:
            NotImplementedError: always, for the SAC agent.
        """
        raise NotImplementedError("This method is not implemented for SAC agent.")

    def update(
        self,
        replay_buffer: ReplayBuffer,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
        buffer_index_to_sample: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Update the agent.

        Args:
            replay_buffer (ReplayBuffer): replay buffer to sample the data.
            logger (Logger): logger for logging.
            step (int): step for tracking the training progress.
            kwargs_to_compute_gradient (Optional[Dict[str, Any]], optional): Defaults
                to None.
            buffer_index_to_sample (Optional[np.ndarray], optional): if this parameter
                is specified, use these indices instead of sampling from the replay
                buffer. If this is set to `None`, sample from the replay buffer.
                Defaults to None.

        Returns:
            np.ndarray: index sampled (from the replay buffer) to train the model. If
                buffer_index_to_sample is not set to None, return buffer_index_to_sample.
        """
        if kwargs_to_compute_gradient is None:
            kwargs_to_compute_gradient = {}

        if buffer_index_to_sample is None:
            batch = replay_buffer.sample()
        else:
            batch = replay_buffer.sample(buffer_index_to_sample)

        logger.log("train/batch_reward", batch.reward.mean(), step)

        if self.should_use_task_encoder:
            # Gradients from every update below accumulate in the task
            # encoder; it is stepped once in `update_task_encoder`.
            self.task_encoder_optimizer.zero_grad()
            task_encoding = self.get_task_encoding(
                env_index=batch.task_obs.squeeze(1),
                disable_grad=False,
                modes=["train"],
            )
        else:
            task_encoding = None  # type: ignore[assignment]

        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="critic",
            env_index=batch.task_obs,
        )
        self.update_critic(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )

        if step % self.actor_update_freq == 0:
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="actor",
                env_index=batch.task_obs,
            )
            self.update_actor_and_alpha(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
            )

        if step % self.critic_target_update_freq == 0:
            # The Q-functions and the encoder use separate rates.
            agent_utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            agent_utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            agent_utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder, self.encoder_tau
            )

        if (
            "transition_model" in self._components
            and "reward_decoder" in self._components
        ):
            # some of the logic is a bit sketchy here. We will get to it soon.
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="transition_reward",
                env_index=batch.task_obs,
            )
            self.update_transition_reward_model(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
            )

        if (
            "decoder" in self._components  # should_update_decoder
            and self.decoder is not None  # type: ignore[attr-defined]
            and step % self.decoder_update_freq == 0  # type: ignore[attr-defined]
        ):
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="decoder",
                env_index=batch.task_obs,
            )
            self.update_decoder(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
            )

        if self.should_use_task_encoder:
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="task_encoder",
                env_index=batch.task_obs,
            )
            self.update_task_encoder(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
            )

        return batch.buffer_index

    def get_parameters(self, name: str) -> List[torch.nn.parameter.Parameter]:
        """Get parameters corresponding to a given component.

        Args:
            name (str): name of the component.

        Returns:
            List[torch.nn.parameter.Parameter]: list of parameters.
        """
        if name == "actor":
            # Only the actor's `model` is returned, not its encoder.
            return list(self.actor.model.parameters())
        if name in ["log_alpha", "alpha"]:
            return [self.log_alpha]
        if name == "encoder":
            return list(self.critic.encoder.parameters())
        return list(self._components[name].parameters())