# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf as OC
from mtrl.agent import abstract, wrapper
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.agent.ds.task_info import NoneTaskInfo, TaskInfo
from mtrl.env.types import ObsType
from mtrl.logger import Logger
from mtrl.replay_buffer import ReplayBuffer, ReplayBufferSample
from mtrl.utils.types import (
ComponentType,
ConfigType,
ModelType,
OptimizerType,
ParameterType,
TensorType,
)
ComponentOrOptimizerType = Union[ComponentType, OptimizerType]
[docs]def gaussian_kld(
mean1: TensorType, logvar1: TensorType, mean2: TensorType, logvar2: TensorType
) -> TensorType:
"""Compute KL divergence between a bunch of univariate Gaussian
distributions with the given means and log-variances.
ie `KL(N(mean1, logvar1) || N(mean2, logvar2))`
Args:
mean1 (TensorType):
logvar1 (TensorType):
mean2 (TensorType):
logvar2 (TensorType):
Returns:
TensorType: [description]
"""
gauss_klds = 0.5 * (
(logvar2 - logvar1)
+ ((torch.exp(logvar1) + (mean1 - mean2) ** 2.0) / torch.exp(logvar2))
- 1.0
)
assert len(gauss_klds.size()) == 2
return gauss_klds
[docs]class Agent(abstract.Agent):
def __init__(
self,
env_obs_shape: List[int],
action_shape: List[int],
action_range: Tuple[int, int],
multitask_cfg: ConfigType,
device: torch.device,
distral_alpha: float,
distral_beta: float,
agent_index_to_task_index: List[str],
distilled_agent_cfg: ConfigType,
task_agent_cfg: ConfigType,
cfg_to_load_model: Optional[ConfigType] = None,
should_complete_init: bool = True,
):
"""Distral algorithm."""
super().__init__(
env_obs_shape=env_obs_shape,
action_shape=action_shape,
action_range=action_range,
multitask_cfg=multitask_cfg,
device=device,
)
self.distral_alpha = distral_alpha
self.distral_beta = distral_beta
self.agent_index_to_task_index = agent_index_to_task_index
# eventually, this will be done via OC.
self.num_task_agents = len(self.agent_index_to_task_index)
self.task_index_to_agent_index = {
task_index: agent_index
for agent_index, task_index in enumerate(self.agent_index_to_task_index)
}
self.distilled_agent = hydra.utils.instantiate(
distilled_agent_cfg,
env_obs_shape=self.env_obs_shape,
action_shape=self.action_shape,
action_range=action_range,
multitask_cfg=OC.create({"num_envs": 1}),
device=self.device,
cfg_to_load_model=cfg_to_load_model,
should_complete_init=True,
)
self.task_agents = [
hydra.utils.instantiate(
task_agent_cfg,
env_obs_shape=self.env_obs_shape,
action_shape=self.action_shape,
action_range=action_range,
multitask_cfg=OC.create({"num_envs": 1}),
device=self.device,
index=agent_index,
env_index=task_index,
distilled_agent=self.distilled_agent,
cfg_to_load_model=cfg_to_load_model,
should_complete_init=True,
)
for agent_index, task_index in enumerate(self.agent_index_to_task_index)
]
if should_complete_init:
self.complete_init(cfg_to_load_model=cfg_to_load_model)
[docs] def complete_init(self, cfg_to_load_model: Optional[ConfigType]) -> None:
self.train()
[docs] def train(self, training: bool = True) -> None:
self.training = training
self.distilled_agent.train(training)
[_task_agent.train(training) for _task_agent in self.task_agents]
[docs] def select_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
"""Used during testing"""
return self.distilled_agent.select_action(
multitask_obs=multitask_obs, modes=modes
)
[docs] def sample_action(self, multitask_obs: ObsType, modes: List[str]) -> np.ndarray:
"""Used during training"""
obs = multitask_obs["env_obs"]
env_index = multitask_obs["task_obs"]
actions = [
self.task_agents[self.task_index_to_agent_index[index]].sample_action(
multitask_obs={
"env_obs": obs[self.task_index_to_agent_index[index]],
"task_obs": torch.LongTensor(
[
[index],
]
), # not used in the actor.
},
modes=modes,
)
for index in env_index.numpy()
]
actions = np.concatenate(actions, axis=0)
return actions
[docs] def update(
self,
replay_buffer: ReplayBuffer,
logger: Logger,
step: int,
kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
buffer_index_to_sample: Optional[np.ndarray] = None,
) -> np.ndarray:
for _task_agent in self.task_agents:
_task_agent.update(
replay_buffer=replay_buffer,
logger=logger,
step=step,
)
[docs] def save(
self,
model_dir: str,
step: int,
retain_last_n: int,
should_save_metadata: bool = True,
) -> None:
self.distilled_agent.save(
model_dir=model_dir,
step=step,
retain_last_n=retain_last_n,
should_save_metadata=False,
)
for agent in self.task_agents:
agent.save(
model_dir=model_dir,
step=step,
retain_last_n=retain_last_n,
should_save_metadata=False,
)
if should_save_metadata:
self.save_metadata(model_dir, step)
[docs] def load(self, model_dir: Optional[str], step: Optional[int]) -> None:
self.distilled_agent.load(model_dir, step)
for agent in self.task_agents:
agent.load(model_dir, step)
[docs] def load_latest_step(self, model_dir: str) -> int:
latest_step = -1
metadata = self.load_metadata(model_dir=model_dir)
if metadata is None:
return latest_step + 1
latest_step = metadata["step"]
self.distilled_agent.load(model_dir, step=latest_step)
for agent in self.task_agents:
agent.load(model_dir, step=latest_step)
return latest_step + 1
[docs]class DistilledAgent(abstract.Agent):
def __init__(
self,
env_obs_shape: List[int],
action_shape: List[int],
action_range: Tuple[int, int],
multitask_cfg: ConfigType,
device: torch.device,
actor_cfg: ConfigType,
actor_optimizer_cfg: ConfigType,
cfg_to_load_model: Optional[ConfigType] = None,
should_complete_init: bool = True,
):
"""Centroid policy for distral"""
super().__init__(
env_obs_shape=env_obs_shape,
action_shape=action_shape,
action_range=action_range,
multitask_cfg=multitask_cfg,
# num_envs=1,
device=device,
)
self._name = "distilled_agent"
self.actor: ModelType = hydra.utils.instantiate(
actor_cfg, env_obs_shape=env_obs_shape, action_shape=action_shape
).to(self.device)
self._components: Dict[str, ModelType] = {
"actor": self.actor,
}
self.actor_optimizer = hydra.utils.instantiate(
actor_optimizer_cfg, self.actor.parameters()
)
self._optimizers: Dict[str, OptimizerType] = {
"actor": self.actor_optimizer,
}
if should_complete_init:
self.complete_init(cfg_to_load_model=cfg_to_load_model)
[docs] def train(self, training=True) -> None:
self.training = training
self.actor.train(training)
[docs] def complete_init(self, cfg_to_load_model: Optional[ConfigType]) -> None:
self.train()
[docs] def select_action(self, multitask_obs: ObsType, modes: List[str]):
with torch.no_grad():
env_obs = multitask_obs["env_obs"].float().to(self.device)
if len(env_obs.shape) == 3:
env_obs = env_obs.unsqueeze(0) # Make a batch
mtobs = MTObs(env_obs=env_obs, task_obs=None, task_info=NoneTaskInfo)
mu, _, _, _ = self.actor(mtobs=mtobs)
return mu.cpu().numpy()
[docs] def sample_action(self, multitask_obs: ObsType, modes: List[str]):
with torch.no_grad():
env_obs = multitask_obs["env_obs"].float().to(self.device)
if len(env_obs.shape) == 3:
env_obs = env_obs.unsqueeze(0) # Make a batch
mtobs = MTObs(env_obs=env_obs, task_obs=None, task_info=NoneTaskInfo)
mu, pi, _, _ = self.actor(mtobs=mtobs, compute_log_pi=False)
return pi.cpu().data.numpy().flatten()
[docs] def save(
self,
model_dir: str,
step: int,
retain_last_n: int,
should_save_metadata: bool = True,
) -> None:
return super().save(
model_dir=os.path.join(model_dir, self._name),
step=step,
retain_last_n=retain_last_n,
should_save_metadata=should_save_metadata,
)
[docs] def load(self, model_dir: Optional[str], step: Optional[int]) -> None:
if model_dir is not None:
return super().load(
model_dir=os.path.join(model_dir, self._name), step=step
)
return
[docs] def load_latest_step(self, model_dir: str) -> int:
latest_step = -1
metadata = self.load_metadata(model_dir=model_dir)
if metadata is None:
return latest_step + 1
latest_step = metadata["step"]
self.load(model_dir=os.path.join(model_dir, self._name), step=latest_step)
return latest_step + 1
[docs] def update(
self,
replay_buffer: ReplayBuffer,
logger: Logger,
step: int,
kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
buffer_index_to_sample: Optional[np.ndarray] = None,
):
raise NotImplementedError(
"`update` method is not implemented for distral algorithm."
)
[docs]class TaskAgent(wrapper.Agent):
def __init__(
self,
env_obs_shape: List[int],
action_shape: List[int],
action_range: Tuple[int, int],
multitask_cfg: ConfigType,
device: torch.device,
agent_cfg: ConfigType,
index: int,
env_index: int,
distral_alpha: float,
distral_beta: float,
distilled_agent: DistilledAgent,
cfg_to_load_model: Optional[ConfigType] = None,
should_complete_init: bool = True,
):
"""Wrapper class for the task specific agent"""
super().__init__(
env_obs_shape=env_obs_shape,
action_shape=action_shape,
action_range=action_range,
multitask_cfg=multitask_cfg,
agent_cfg=agent_cfg,
device=device,
cfg_to_load_model=cfg_to_load_model,
should_complete_init=should_complete_init,
)
self.index = index
self.env_index = env_index
self.distral_alpha = distral_alpha
self.distral_beta = distral_beta
self._name = f"task_agent_{self.index}"
self.patch_agent()
self.distilled_agent = distilled_agent
if should_complete_init:
self.complete_init(cfg_to_load_model=cfg_to_load_model)
[docs] def patch_agent(self) -> None:
"""Change some function definitions at runtime."""
self.agent.update_actor_and_alpha = self.update_actor_and_alpha
self.agent._get_target_V = self._get_target_V
def _get_target_V(
self, batch: ReplayBufferSample, task_info: TaskInfo
) -> TensorType:
"""Compute the target values.
Args:
batch (ReplayBufferSample): batch from the replay buffer.
task_info (TaskInfo): task_info object.
Returns:
TensorType: target values.
"""
mtobs = MTObs(env_obs=batch.next_env_obs, task_obs=None, task_info=task_info)
_, policy_action, log_pi, _ = self.agent.actor(mtobs=mtobs)
_, _, distral_log_pi, _ = self.distilled_agent.actor(mtobs=mtobs)
target_Q1, target_Q2 = self.agent.critic_target(
mtobs=mtobs, action=policy_action
)
agent_alpha = self.agent.get_alpha(batch.task_obs).detach()
alpha_from_paper = self.distral_alpha / (self.distral_alpha + agent_alpha)
beta_from_paper = 1.0 / (self.distral_alpha + agent_alpha)
return (
torch.min(target_Q1, target_Q2)
+ (alpha_from_paper * distral_log_pi - log_pi) / beta_from_paper
)
[docs] def update_actor_and_alpha(
self,
batch: ReplayBufferSample,
task_info: TaskInfo,
logger: Logger,
step: int,
kwargs_to_compute_gradient: Dict[str, Any],
) -> None:
"""Update the actor and alpha component.
Args:
batch (ReplayBufferSample): batch from the replay buffer.
task_info (TaskInfo): task_info object.
logger ([Logger]): logger object.
step (int): step for tracking the training of the agent.
kwargs_to_compute_gradient (Dict[str, Any]):
"""
# detach encoder, so we don't update it with the actor loss
suffix = f"_agent_index_{self.index}"
mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=task_info)
mu, pi, log_pi, log_std = self.agent.actor(mtobs=mtobs, detach_encoder=True)
actor_Q1, actor_Q2 = self.agent.critic(
mtobs=mtobs, action=pi, detach_encoder=True
)
actor_Q = torch.min(actor_Q1, actor_Q2)
if self.agent.loss_reduction == "mean":
actor_loss = (
self.agent.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
).mean()
logger.log(f"train/actor_loss{suffix}", actor_loss, step)
elif self.agent.loss_reduction == "none":
actor_loss = (
self.agent.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
)
logger.log(f"train/actor_loss{suffix}", actor_loss.mean(), step)
logger.log(
f"train/actor_target_entropy{suffix}", self.agent.target_entropy, step
)
entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(
dim=-1
)
logger.log(f"train/actor_entropy{suffix}", entropy.mean(), step)
mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=NoneTaskInfo)
distral_mu, _, distral_log_pi, distral_log_std = self.distilled_agent.actor(
mtobs=mtobs, detach_encoder=False
)
distilled_agent_loss = gaussian_kld(
mean1=distral_mu,
logvar1=2 * distral_log_std,
mean2=mu.detach(),
logvar2=2 * log_std.detach(),
)
batch_size = distilled_agent_loss.shape[0]
distilled_agent_loss = torch.sum(distilled_agent_loss) / batch_size
logger.log(
f"train/actor_distilled_agent_loss{suffix}",
distilled_agent_loss.mean(),
step,
)
distilled_agent_loss = distilled_agent_loss * self.distral_alpha
# optimize the actor
component_names = ["actor"]
parameters: List[ParameterType] = []
for name in component_names:
self.agent._optimizers[name].zero_grad()
parameters += self.agent.get_parameters(name)
if task_info.compute_grad:
component_names.append("task_encoder")
kwargs_to_compute_gradient["retain_graph"] = True
parameters += self.agent.get_parameters("task_encoder")
self.agent._compute_gradient(
loss=actor_loss,
parameters=parameters,
step=step,
component_names=component_names,
**kwargs_to_compute_gradient,
)
self.agent.actor_optimizer.step()
self.agent.log_alpha_optimizer.zero_grad()
if self.agent.loss_reduction == "mean":
alpha_loss = (
self.agent.get_alpha(batch.task_obs)
* (-log_pi - self.agent.target_entropy).detach()
).mean()
logger.log(f"train/alpha_loss{suffix}", alpha_loss, step)
elif self.agent.loss_reduction == "none":
alpha_loss = (
self.agent.get_alpha(batch.task_obs)
* (-log_pi - self.agent.target_entropy).detach()
)
logger.log(f"train/alpha_loss{suffix}", alpha_loss.mean(), step)
# logger.log("train/alpha_value", self.get_alpha(batch.task_obs), step)
self.agent._compute_gradient(
loss=alpha_loss,
parameters=self.agent.get_parameters(name="log_alpha"),
step=step,
component_names=["log_alpha"],
**kwargs_to_compute_gradient,
)
self.agent.log_alpha_optimizer.step()
self.distilled_agent._optimizers["actor"].zero_grad()
distilled_agent_loss.backward()
self.distilled_agent._optimizers["actor"].step()
[docs] def save(
self,
model_dir: str,
step: int,
retain_last_n: int,
should_save_metadata: bool = True,
) -> None:
return super().save(
model_dir=os.path.join(model_dir, self._name),
step=step,
retain_last_n=retain_last_n,
should_save_metadata=should_save_metadata,
)
[docs] def load(self, model_dir: Optional[str], step: Optional[int]) -> None:
if model_dir is None or step is None:
return
return super().load(model_dir=os.path.join(model_dir, self._name), step=step)
[docs] def load_latest_step(self, model_dir: str) -> int:
latest_step = -1
metadata = self.load_metadata(model_dir=model_dir)
if metadata is None:
return latest_step + 1
latest_step = metadata["step"]
super().load(model_dir=os.path.join(model_dir, self._name), step=latest_step)
return latest_step + 1