Source code for mtrl.agent.components.critic

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Implementation based on Denis Yarats' implementation of [SAC](https://github.com/denisyarats/pytorch_sac).
"""Critic component for the agent."""

from typing import List, Tuple

import torch
from torch import nn

from mtrl.agent import utils as agent_utils
from mtrl.agent.components import base as base_component
from mtrl.agent.components import encoder, moe_layer
from mtrl.agent.components.actor import (
    check_if_should_use_multi_head_policy,
    check_if_should_use_task_encoder,
)
from mtrl.agent.components.soft_modularization import SoftModularizedMLP
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.utils.types import ConfigType, ModelType, TensorType


class QFunction(base_component.Component):
    """Q-function used by the critic, optionally with per-task heads."""

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ):
        """Q-function implemented as a MLP.

        Args:
            obs_dim (int): size of the observation.
            action_dim (int): size of the action vector.
            hidden_dim (int): size of the hidden layer of the model.
            num_layers (int): number of layers in the model.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.
        """
        super().__init__()

        # Defaults apply when no (non-empty) `critic_cfg` is provided.
        if "critic_cfg" in multitask_cfg and multitask_cfg.critic_cfg:
            critic_cfg = multitask_cfg.critic_cfg
            condition_model = critic_cfg.should_condition_model_on_task_info
            condition_encoder = critic_cfg.should_condition_encoder_on_task_info
        else:
            condition_model, condition_encoder = False, True
        self.should_condition_model_on_task_info = condition_model
        self.should_condition_encoder_on_task_info = condition_encoder

        self.should_use_multi_head_policy = check_if_should_use_multi_head_policy(
            multitask_cfg=multitask_cfg
        )

        # The model is built before the projection layer so that the order in
        # which sub-modules are created and registered stays the same.
        self.model = self.build_model(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            multitask_cfg=multitask_cfg,
        )

        if self.should_condition_model_on_task_info:
            # layer to project obs_action to shape of obs_dim
            self.obs_action_projection_layer = nn.Linear(
                in_features=obs_dim + action_dim, out_features=obs_dim
            )

    def _make_head(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Make the heads for the Q-function.

        One expert is created per environment (`multitask_cfg.num_envs`) and
        every expert outputs a single value.

        Args:
            input_dim (int): size of the input.
            hidden_dim (int): size of the hidden layer of the head.
            num_layers (int): number of layers in the model.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            ModelType: the heads.
        """
        return moe_layer.FeedForward(
            num_experts=multitask_cfg.num_envs,
            in_features=input_dim,
            out_features=1,
            hidden_features=hidden_dim,
            num_layers=num_layers,
            bias=True,
        )

    def _make_trunk(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Make the trunk for the Q-function.

        Args:
            obs_dim (int): size of the observation.
            action_dim (int): size of the action vector.
            hidden_dim (int): size of the hidden layer of the trunk.
            output_dim (int): size of the output.
            num_layers (int): number of layers in the model.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            ModelType: the trunk.

        Raises:
            NotImplementedError: if the mixture-of-experts trunk is enabled
                with a mode other than "soft_modularization".
        """
        should_use_moe = (
            "critic_cfg" in multitask_cfg
            and multitask_cfg.critic_cfg
            and "moe_cfg" in multitask_cfg.critic_cfg
            and multitask_cfg.critic_cfg.moe_cfg.should_use
        )
        if not should_use_moe:
            # Plain MLP over the concatenated observation and action.
            return agent_utils.build_mlp(
                input_dim=obs_dim + action_dim,
                hidden_dim=hidden_dim,
                output_dim=output_dim,
                num_layers=num_layers,
            )

        moe_cfg = multitask_cfg.critic_cfg.moe_cfg
        if moe_cfg.mode != "soft_modularization":
            raise NotImplementedError(
                f"""`moe_cfg.mode` = {moe_cfg.mode} is not implemented."""
            )
        # NOTE(review): `num_layers` is not forwarded here; the depth is fixed
        # at 2 and the input size is `obs_dim` (not `obs_dim + action_dim`).
        return SoftModularizedMLP(
            num_experts=moe_cfg.num_experts,
            in_features=obs_dim,
            out_features=output_dim,
            num_layers=2,
            hidden_features=hidden_dim,
            bias=True,
        )

    def build_model(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int,
        num_layers: int,
        multitask_cfg: ConfigType,
    ) -> ModelType:
        """Build the Q-Function.

        Without a multi-head policy the model is a single trunk with a scalar
        output. With a disjoint multi-head policy it is only the heads, fed
        with the observation-action input. Otherwise it is a trunk followed by
        a ReLU and the heads.

        Args:
            obs_dim (int): size of the observation.
            action_dim (int): size of the action vector.
            hidden_dim (int): size of the hidden layer of the trunk.
            num_layers (int): number of layers in the model.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            ModelType: the assembled model.
        """
        if not self.should_use_multi_head_policy:
            return self._make_trunk(
                obs_dim=obs_dim,
                action_dim=action_dim,
                hidden_dim=hidden_dim,
                output_dim=1,
                num_layers=num_layers,
                multitask_cfg=multitask_cfg,
            )

        if multitask_cfg.should_use_disjoint_policy:
            return self._make_head(
                input_dim=obs_dim + action_dim,
                hidden_dim=hidden_dim,
                num_layers=num_layers,
                multitask_cfg=multitask_cfg,
            )

        # Heads are created before the trunk, matching the original
        # construction order.
        task_heads = self._make_head(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            num_layers=2,
            multitask_cfg=multitask_cfg,
        )
        shared_trunk = self._make_trunk(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
            num_layers=num_layers,
            multitask_cfg=multitask_cfg,
        )
        return nn.Sequential(shared_trunk, nn.ReLU(), task_heads)

    def get_last_shared_layers(self) -> List[ModelType]:
        """Return the last layer that is shared across tasks, as a list."""
        if self.should_use_multi_head_policy:
            # the trunk is the first element in `self.model` and is also the
            # last shared component.
            # NOTE(review): with a disjoint policy `build_model` returns only
            # the heads (no `nn.Sequential`); confirm this indexing is valid
            # in that configuration.
            last_shared = self.model[0][-1]  # type: ignore[index]
        else:
            last_shared = self.model[-1]  # type: ignore[index]
        return [last_shared]

    def forward(self, mtobs: MTObs) -> TensorType:
        """Compute the Q-value(s) for the given multi-task observation.

        Args:
            mtobs (MTObs): multi-task observation; `env_obs` is passed to the
                model as the observation-action input.

        Returns:
            TensorType: output of the underlying model.
        """
        obs_action = mtobs.env_obs
        if not self.should_condition_model_on_task_info:
            return self.model(obs_action)

        # The task-conditioned model receives a full `MTObs` whose `env_obs`
        # has been projected down to `obs_dim`.
        projected_mtobs = MTObs(
            env_obs=self.obs_action_projection_layer(obs_action),
            task_obs=mtobs.task_obs,
            task_info=mtobs.task_info,
        )
        return self.model(projected_mtobs)
class Critic(base_component.Component):
    """Critic made of an encoder and two Q-functions (`Q1` and `Q2`)."""

    def __init__(
        self,
        env_obs_shape: List[int],
        action_shape: List[int],
        hidden_dim: int,
        num_layers: int,
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
    ):
        """Critic component for the agent.

        Args:
            env_obs_shape (List[int]): shape of the environment observation
                passed to the encoder.
            action_shape (List[int]): shape of the action vector; only the
                first entry is used as the action size.
            hidden_dim (int): hidden dimensionality of the Q-functions.
            num_layers (int): number of layers in the Q-functions.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.
        """
        encoder_cfg = self._select_encoder_cfg(encoder_cfg)

        super().__init__()

        if not check_if_should_use_task_encoder(multitask_cfg):
            condition_model = condition_encoder = concatenate = False
        elif "critic_cfg" in multitask_cfg and multitask_cfg.critic_cfg:
            critic_cfg = multitask_cfg.critic_cfg
            condition_model = critic_cfg.should_condition_model_on_task_info
            condition_encoder = critic_cfg.should_condition_encoder_on_task_info
            concatenate = critic_cfg.should_concatenate_task_info_with_encoder
        else:
            # Task encoder is used but no critic-specific overrides are given.
            condition_model, condition_encoder, concatenate = False, True, True
        self.should_condition_model_on_task_info = condition_model
        self.should_condition_encoder_on_task_info = condition_encoder
        self.should_concatenate_task_info_with_encoder = concatenate

        self.encoder = self._make_encoder(
            env_obs_shape=env_obs_shape,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )

        self.should_use_multi_head_policy = check_if_should_use_multi_head_policy(
            multitask_cfg=multitask_cfg
        )
        if self.should_use_multi_head_policy:
            # Identity matrix: row `i` selects only the head of task `i`.
            self.moe_masks = moe_layer.MaskCache(
                task_index_to_mask=torch.eye(multitask_cfg.num_envs),
                **multitask_cfg.multi_head_policy_cfg.mask_cfg,
            )

        qfunction_kwargs = dict(
            action_shape=action_shape,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )
        # Q1 is created before Q2, as in the original construction order.
        self.Q1 = self._make_qfunction(**qfunction_kwargs)
        self.Q2 = self._make_qfunction(**qfunction_kwargs)

        self.apply(agent_utils.weight_init)

    @staticmethod
    def _select_encoder_cfg(encoder_cfg: ConfigType) -> ConfigType:
        """Return the sub-config named by `type_to_select`, if that key exists."""
        key = "type_to_select"
        if key in encoder_cfg:
            return encoder_cfg[encoder_cfg[key]]
        return encoder_cfg

    def _make_encoder(
        self,
        env_obs_shape: List[int],
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
        **kwargs: ConfigType,
    ) -> encoder.Encoder:
        """Make the encoder.

        Args:
            env_obs_shape (List[int]): shape of the environment observation.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.
            **kwargs (ConfigType): accepted but not used by this
                implementation.

        Returns:
            encoder.Encoder: encoder
        """
        return encoder.make_encoder(
            env_obs_shape=env_obs_shape,
            encoder_cfg=encoder_cfg,
            multitask_cfg=multitask_cfg,
        )

    def _make_qfunction(
        self,
        action_shape: List[int],
        hidden_dim: int,
        num_layers: int,
        encoder_cfg: ConfigType,
        multitask_cfg: ConfigType,
    ) -> QFunction:
        """Make the QFunction.

        The observation size is the encoder's `feature_dim`, increased by the
        task encoder's output size when the task encoder is used and the
        encoder is conditioned on the task info.

        Args:
            action_shape (List[int]): shape of the action vector.
            hidden_dim (int): size of the hidden layers.
            num_layers (int): number of layers.
            encoder_cfg (ConfigType): config for the encoder.
            multitask_cfg (ConfigType): config for encoding the multitask
                knowledge.

        Returns:
            QFunction: the Q-function.
        """
        encoder_cfg = self._select_encoder_cfg(encoder_cfg)

        # Mixture-of-experts encoders keep `feature_dim` in a nested config.
        if encoder_cfg.type in ["moe", "fmoe"]:
            obs_dim = encoder_cfg.encoder_cfg.feature_dim
        else:
            obs_dim = encoder_cfg.feature_dim

        if (
            multitask_cfg.should_use_task_encoder
            and self.should_condition_encoder_on_task_info
        ):
            obs_dim = obs_dim + multitask_cfg.task_encoder_cfg.model_cfg.output_dim

        return QFunction(
            obs_dim=obs_dim,
            action_dim=action_shape[0],
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            multitask_cfg=multitask_cfg,
        )

    def encode(
        self,
        mtobs: MTObs,
        detach: bool = False,
    ) -> TensorType:
        """Encode the input observation.

        Args:
            mtobs (MTObs): multi-task observation.
            detach (bool, optional): should detach the observation encoding
                from the computation graph. Defaults to False.

        Returns:
            TensorType: encoding of the observation, concatenated with the
                task encoding along dim 1 when
                `should_concatenate_task_info_with_encoder` is set.
        """
        encoding = self.encoder(mtobs=mtobs, detach=detach)
        if not self.should_concatenate_task_info_with_encoder:
            return encoding
        task_info = mtobs.task_info
        # mypy is raising a false alarm. task_info is not None
        return torch.cat((encoding, task_info.encoding), dim=1)  # type: ignore[arg-type, union-attr]

    def get_last_shared_layers(self) -> List[ModelType]:
        """Collect the last shared layers of `Q1` followed by those of `Q2`."""
        return [
            layer
            for qfunction in (self.Q1, self.Q2)
            for layer in qfunction.get_last_shared_layers()
        ]

    def forward(
        self,
        mtobs: MTObs,
        action: TensorType,
        detach_encoder: bool = False,
    ) -> Tuple[TensorType, TensorType]:
        """Compute both Q-value estimates for an observation-action pair.

        Args:
            mtobs (MTObs): multi-task observation; its `task_info` must not be
                None.
            action (TensorType): actions; the first dimension must match that
                of the encoded observation.
            detach_encoder (bool, optional): stop gradient propagation to the
                encoder. Defaults to False.

        Returns:
            Tuple[TensorType, TensorType]: outputs of `Q1` and `Q2`. With a
                multi-head policy, each is the mask-weighted average over
                dim 0.
        """
        task_info = mtobs.task_info
        assert task_info is not None

        if self.should_condition_encoder_on_task_info:
            encoder_input = mtobs
        else:
            # making a new task_info since we do not want to condition on
            # the task encoding.
            encoder_input = MTObs(
                env_obs=mtobs.env_obs,
                task_obs=mtobs.task_obs,
                task_info=TaskInfo(
                    encoding=None,
                    compute_grad=task_info.compute_grad,
                    env_index=task_info.env_index,
                ),
            )
        # detach_encoder allows to stop gradient propagation to encoder
        obs = self.encode(mtobs=encoder_input, detach=detach_encoder)
        assert obs.size(0) == action.size(0)

        # The Q-functions always receive the original task information.
        mtobs_for_q = MTObs(
            env_obs=torch.cat([obs, action], dim=-1),
            task_obs=mtobs.task_obs,
            task_info=mtobs.task_info,
        )
        q1 = self.Q1(mtobs=mtobs_for_q)
        q2 = self.Q2(mtobs=mtobs_for_q)

        if self.should_use_multi_head_policy:
            q_mask = self.moe_masks.get_mask(task_info=task_info)
            mask_total = q_mask.sum(dim=0)
            q1, q2 = [(q * q_mask).sum(dim=0) / mask_total for q in (q1, q2)]
        return q1, q2