# Source code for mtrl.agent.components.actor

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Implementation based on Denis Yarats' implementation of [SAC](https://github.com/denisyarats/pytorch_sac).
"""Actor component for the agent."""
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from mtrl.agent import utils as agent_utils
from mtrl.agent.components import base as base_component
from mtrl.agent.components import encoder, moe_layer
from mtrl.agent.components.soft_modularization import SoftModularizedMLP
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.utils.types import ConfigType, ModelType, TensorType


def check_if_should_use_multi_head_policy(multitask_cfg: ConfigType) -> bool:
    """Report whether the multitask config asks for a multi-head policy.

    Args:
        multitask_cfg (ConfigType): config for encoding the multitask knowledge.

    Returns:
        bool: value stored under `should_use_multi_head_policy` when that key
            is present in the config, `False` when the key is missing.
    """
    # A missing key means the feature is switched off.
    if "should_use_multi_head_policy" not in multitask_cfg:
        return False
    return multitask_cfg.should_use_multi_head_policy
def check_if_should_use_task_encoder(multitask_cfg: ConfigType) -> bool:
    """Report whether the multitask config asks for a task encoder.

    Args:
        multitask_cfg (ConfigType): config for encoding the multitask knowledge.

    Returns:
        bool: value stored under `should_use_task_encoder` when that key is
            present in the config, `False` when the key is missing.
    """
    # A missing key means the feature is switched off.
    if "should_use_task_encoder" not in multitask_cfg:
        return False
    return multitask_cfg.should_use_task_encoder
def _gaussian_logprob(noise: TensorType, log_std: TensorType) -> TensorType:
    """Compute the gaussian log probability.

    Args:
        noise (TensorType): noise that was scaled by the standard deviation
            to produce the sample.
        log_std (TensorType): log of the standard deviation of the gaussian.

    Returns:
        TensorType: log-probability of the sample, summed over the last
            dimension (the reduced dimension is kept with size 1).
    """
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    # Normalization constant of the gaussian, once per element of the last dim.
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)


def _squash(
    mu: TensorType, pi: TensorType, log_pi: TensorType
) -> Tuple[TensorType, TensorType, TensorType]:
    """Apply squashing function.

    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.

    The input tensors are never modified; new tensors are returned instead.

    Args:
        mu (TensorType): mean of the gaussian distribution.
        pi (TensorType): sample from the gaussian distribution. `None` is
            passed through unchanged.
        log_pi (TensorType): log probability. `None` is passed through
            unchanged. When it is given, `pi` must be given as well.

    Returns:
        Tuple[TensorType, TensorType, TensorType]: tuple of
            (squashed mean of the gaussian,
            squashed sample from the gaussian,
            squashed log-probability of the sample)
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        # `relu` clamps tiny negative values of `1 - pi^2` and the `1e-6`
        # keeps the argument of `log` strictly positive.
        correction = torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
        # Out-of-place on purpose: `log_pi -= ...` would silently overwrite
        # the tensor owned by the caller.
        log_pi = log_pi - correction
    return mu, pi, log_pi
class BaseActor(base_component.Component):
    """Interface shared by the actor components of the agent.

    Subclasses are expected to override `encode` and `forward`; the versions
    defined here only raise `NotImplementedError`.
    """

    def __init__(
        self,
        env_obs_shape: List[int],
        action_shape: List[int],
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
        *args,
        **kwargs,
    ):
        """Interface for the actor component for the agent.

        Only `multitask_cfg` is stored on the instance; the remaining
        arguments are accepted but not used by this base class.

        Args:
            env_obs_shape (List[int]): shape of the environment observation
                that the actor gets.
            action_shape (List[int]): shape of the action vector that the
                actor produces.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.
        """
        super().__init__()
        self.multitask_cfg = multitask_cfg

    def encode(self, mtobs: MTObs, detach: bool = False) -> TensorType:
        """Encode the input observation.

        Args:
            mtobs (MTObs): multi-task observation.
            detach (bool, optional): should detach the observation encoding
                from the computation graph. Defaults to False.

        Raises:
            NotImplementedError: always; subclasses provide the encoding.

        Returns:
            TensorType: encoding of the observation.
        """
        raise NotImplementedError

    def forward(
        self,
        mtobs: MTObs,
        detach_encoder: bool = False,
    ) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
        """Compute the predictions from the actor.

        Args:
            mtobs (MTObs): multi-task observation.
            detach_encoder (bool, optional): should detach the observation
                encoding from the computation graph. Defaults to False.

        Raises:
            NotImplementedError: always; subclasses provide the predictions.

        Returns:
            Tuple[TensorType, TensorType, TensorType, TensorType]: tuple of
                (mean of the gaussian,
                sample from the gaussian,
                log-probability of the sample,
                log of standard deviation of the gaussian).
        """
        raise NotImplementedError
class Actor(BaseActor):
    """Gaussian actor built from an observation encoder and a policy model."""

    def __init__(
        self,
        env_obs_shape: List[int],
        action_shape: List[int],
        hidden_dim: int,
        num_layers: int,
        log_std_bounds: Tuple[float, float],
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
    ):
        """Actor component for the agent.

        Args:
            env_obs_shape (List[int]): shape of the environment observation
                that the actor gets.
            action_shape (List[int]): shape of the action vector that the
                actor produces.
            hidden_dim (int): hidden dimensionality of the actor.
            num_layers (int): number of layers in the actor.
            log_std_bounds (Tuple[float, float]): bounds to clip log of
                standard deviation.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.
        """
        # When `encoder_cfg` carries a `type_to_select` entry, that entry
        # names the nested config that should actually be used.
        selector_key = "type_to_select"
        if selector_key in encoder_cfg:
            encoder_cfg = encoder_cfg[encoder_cfg[selector_key]]

        super().__init__(
            env_obs_shape=env_obs_shape,
            action_shape=action_shape,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )
        self.log_std_bounds = log_std_bounds

        self.should_use_multi_head_policy = check_if_should_use_multi_head_policy(
            multitask_cfg=multitask_cfg
        )
        if self.should_use_multi_head_policy:
            # Identity matrix: the mask of task `i` is one-hot at index `i`.
            self.moe_masks = moe_layer.MaskCache(
                task_index_to_mask=torch.eye(multitask_cfg.num_envs),
                **multitask_cfg.multi_head_policy_cfg.mask_cfg,
            )

        # Work out how the task information is used. Without a task encoder
        # everything stays off; with one, the encoder is conditioned on the
        # task information by default and `actor_cfg` may override all flags.
        condition_model = False
        condition_encoder = False
        concatenate_with_encoder = False
        if check_if_should_use_task_encoder(multitask_cfg):
            condition_encoder = True
            concatenate_with_encoder = True
            if "actor_cfg" in multitask_cfg and multitask_cfg.actor_cfg:
                actor_cfg = multitask_cfg.actor_cfg
                condition_model = actor_cfg.should_condition_model_on_task_info
                condition_encoder = actor_cfg.should_condition_encoder_on_task_info
                concatenate_with_encoder = (
                    actor_cfg.should_concatenate_task_info_with_encoder
                )
        self.should_condition_model_on_task_info = condition_model
        self.should_condition_encoder_on_task_info = condition_encoder
        self.should_concatenate_task_info_with_encoder = concatenate_with_encoder

        self.encoder = self._make_encoder(
            env_obs_shape=env_obs_shape,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )
        self.model = self.make_model(
            action_shape=action_shape,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )
        self.apply(agent_utils.weight_init)

    def _make_encoder(
        self,
        env_obs_shape: List[int],
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
    ) -> encoder.Encoder:
        """Make the encoder.

        Args:
            env_obs_shape (List[int]): shape of the environment observation.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            encoder.Encoder: encoder
        """
        return encoder.make_encoder(
            env_obs_shape=env_obs_shape,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )

    def _make_head(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Make the head of the actor.

        The head is a `moe_layer.FeedForward` with `multitask_cfg.num_envs`
        experts.

        Args:
            input_dim (int): number of input features.
            hidden_dim (int): number of hidden features.
            output_dim (int): number of output features.
            num_layers (int): number of layers.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            ModelType: head
        """
        return moe_layer.FeedForward(
            num_experts=multitask_cfg.num_envs,
            in_features=input_dim,
            out_features=output_dim,
            hidden_features=hidden_dim,
            num_layers=num_layers,
            bias=True,
        )

    def _make_trunk(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Make the trunk of the actor.

        A plain MLP is built unless `multitask_cfg.actor_cfg.moe_cfg` exists
        and has `should_use` set, in which case a mixture-of-experts trunk is
        built according to `moe_cfg.mode`.

        Args:
            input_dim (int): number of input features.
            hidden_dim (int): number of hidden features.
            output_dim (int): number of output features.
            num_layers (int): number of layers of the plain MLP.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Raises:
            NotImplementedError: if `moe_cfg.mode` is not
                `soft_modularization`.

        Returns:
            ModelType: trunk
        """
        should_use_moe = (
            "actor_cfg" in multitask_cfg
            and multitask_cfg.actor_cfg
            and "moe_cfg" in multitask_cfg.actor_cfg
            and multitask_cfg.actor_cfg.moe_cfg.should_use
        )
        if not should_use_moe:
            # Both nn.Module and SoftModularizedMLP are subtypes of ModelType.
            return agent_utils.build_mlp(
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                output_dim=output_dim,
                num_layers=num_layers,
            )

        moe_cfg = multitask_cfg.actor_cfg.moe_cfg
        if moe_cfg.mode != "soft_modularization":
            raise NotImplementedError(
                f"""`moe_cfg.mode` = {moe_cfg.mode} is not implemented."""
            )
        # The soft-modularized trunk always uses two layers.
        return SoftModularizedMLP(
            num_experts=moe_cfg.num_experts,
            in_features=input_dim,
            out_features=output_dim,
            num_layers=2,
            hidden_features=hidden_dim,
            bias=True,
        )

    def make_model(
        self,
        action_shape: List[int],
        hidden_dim: int,
        num_layers: int,
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Make the model for the actor.

        Args:
            action_shape (List[int]): shape of the action vector.
            hidden_dim (int): hidden dimensionality of the model.
            num_layers (int): number of layers.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            ModelType: model for the actor.
        """
        # One half of the output is the mean, the other half the log-std.
        model_output_dim = 2 * action_shape[0]

        # "moe"/"fmoe" encoders keep the feature size in a nested config.
        if encoder_cfg.type in ["moe", "fmoe"]:
            feature_cfg = encoder_cfg.encoder_cfg
        else:
            feature_cfg = encoder_cfg
        model_input_dim = feature_cfg.feature_dim

        if (
            check_if_should_use_task_encoder(multitask_cfg)
            and self.should_condition_encoder_on_task_info
            and self.should_concatenate_task_info_with_encoder
        ):
            # The task encoding is concatenated to the observation encoding.
            model_input_dim += multitask_cfg.task_encoder_cfg.model_cfg.output_dim

        if not self.should_use_multi_head_policy:
            return self._make_trunk(
                input_dim=model_input_dim,
                hidden_dim=hidden_dim,
                output_dim=model_output_dim,
                num_layers=num_layers,
                multitask_cfg=multitask_cfg,
            )

        if multitask_cfg.should_use_disjoint_policy:
            # Disjoint policies: no trunk, the heads are the whole model.
            return self._make_head(
                input_dim=model_input_dim,
                hidden_dim=hidden_dim,
                output_dim=model_output_dim,
                num_layers=num_layers,
                multitask_cfg=multitask_cfg,
            )

        # Trunk followed by two-layer heads. The heads are created first,
        # exactly as before, so parameters are created in the same order.
        heads = self._make_head(
            input_dim=hidden_dim,
            output_dim=model_output_dim,
            hidden_dim=hidden_dim,
            num_layers=2,
            multitask_cfg=multitask_cfg,
        )
        trunk = self._make_trunk(
            input_dim=model_input_dim,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
            num_layers=num_layers,
            multitask_cfg=multitask_cfg,
        )
        return nn.Sequential(trunk, nn.ReLU(), heads)

    def get_last_shared_layers(self) -> List[ModelType]:
        """Return the last layer of the shared part of the model.

        Returns:
            List[ModelType]: single-element list with that layer.
        """
        # With a multi-head policy the trunk is the first element in
        # `self.model` and is also the last shared component; otherwise the
        # whole model is shared.
        if self.should_use_multi_head_policy:
            shared = self.model[0]  # type: ignore[index]
        else:
            shared = self.model
        return [shared[-1]]  # type: ignore[index]

    def encode(self, mtobs: MTObs, detach: bool = False) -> TensorType:
        """Encode the input observation.

        Args:
            mtobs (MTObs): multi-task observation.
            detach (bool, optional): should detach the observation encoding
                from the computation graph. Defaults to False.

        Returns:
            TensorType: encoding of the observation, with the task encoding
                concatenated along dim 1 when
                `should_concatenate_task_info_with_encoder` is set.
        """
        encoding = self.encoder(mtobs=mtobs, detach=detach)
        task_info = mtobs.task_info
        if not self.should_concatenate_task_info_with_encoder:
            return encoding
        # mypy is raising a false alarm. task_info is not None
        return torch.cat((encoding, task_info.encoding), dim=1)  # type: ignore[arg-type, union-attr]

    def forward(
        self,
        mtobs: MTObs,
        detach_encoder: bool = False,
    ) -> Tuple[TensorType, TensorType, TensorType, TensorType]:
        """Compute the predictions from the actor.

        Args:
            mtobs (MTObs): multi-task observation. Its `task_info` must not
                be `None`.
            detach_encoder (bool, optional): should detach the observation
                encoding from the computation graph. Defaults to False.

        Returns:
            Tuple[TensorType, TensorType, TensorType, TensorType]: tuple of
                (squashed mean of the gaussian,
                squashed sample from the gaussian,
                log-probability of the squashed sample,
                log of standard deviation of the gaussian).
        """
        task_info = mtobs.task_info
        assert task_info is not None

        if self.should_condition_encoder_on_task_info:
            encoder_input = mtobs
        else:
            # making a new task_info since we do not want to condition on
            # the task encoding.
            task_info_without_encoding = TaskInfo(
                encoding=None,
                compute_grad=task_info.compute_grad,
                env_index=task_info.env_index,
            )
            encoder_input = MTObs(
                env_obs=mtobs.env_obs,
                task_obs=mtobs.task_obs,
                task_info=task_info_without_encoding,
            )
        obs = self.encode(mtobs=encoder_input, detach=detach_encoder)

        if self.should_condition_model_on_task_info:
            # The model receives the encoding wrapped together with the
            # task information.
            model_input = MTObs(
                env_obs=obs, task_obs=mtobs.task_obs, task_info=mtobs.task_info
            )
        else:
            model_input = obs
        mu_and_log_std = self.model(model_input)

        if self.should_use_multi_head_policy:
            # Average the per-head outputs selected by the task mask.
            policy_mask = self.moe_masks.get_mask(task_info=task_info)
            masked_total = (mu_and_log_std * policy_mask).sum(dim=0)
            mu_and_log_std = masked_total / policy_mask.sum(dim=0)

        mu, raw_log_std = mu_and_log_std.chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        squashed_log_std = torch.tanh(raw_log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (
            squashed_log_std + 1
        )

        # Sample by scaling fresh noise with the standard deviation.
        noise = torch.randn_like(mu)
        pi = mu + noise * log_std.exp()
        log_pi = _gaussian_logprob(noise, log_std)

        mu, pi, log_pi = _squash(mu, pi, log_pi)
        return mu, pi, log_pi, log_std