Source code for mtrl.agent.components.moe_layer

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Layers for parallelizing computation with mixture of experts.

A mixture of experts(models) can be easily simulated by maintaining a list
of models and iterating over them. However, this can be slow in practice.
We provide some additional modules which makes it easier to create mixture
of experts without slowing down training/inference.
"""
from typing import Dict, List, Optional

import torch
from torch import nn

from mtrl.agent import utils as agent_utils
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.utils.types import ConfigType, TensorType


[docs]class Linear(nn.Module): def __init__( self, num_experts: int, in_features: int, out_features: int, bias: bool = True ): """torch.nn.Linear layer extended for use as a mixture of experts. Args: num_experts (int): number of experts in the mixture. in_features (int): size of each input sample for one expert. out_features (int): size of each output sample for one expert. bias (bool, optional): if set to ``False``, the layer will not learn an additive bias. Defaults to True. """ super().__init__() self.num_experts = num_experts self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter( torch.rand(self.num_experts, self.in_features, self.out_features) ) if bias: self.bias = nn.Parameter(torch.rand(self.num_experts, 1, self.out_features)) self.use_bias = True else: self.use_bias = False
[docs] def forward(self, x: TensorType) -> TensorType: if self.use_bias: return x.matmul(self.weight) + self.bias else: return x.matmul(self.weight)
[docs] def extra_repr(self) -> str: return f"num_experts={self.num_experts}, in_features={self.in_features}, out_features={self.out_features}, bias={self.use_bias}"
[docs]class FeedForward(nn.Module): def __init__( self, num_experts: int, in_features: int, out_features: int, num_layers: int, hidden_features: int, bias: bool = True, ): """A feedforward model of mixture of experts layers. Args: num_experts (int): number of experts in the mixture. in_features (int): size of each input sample for one expert. out_features (int): size of each output sample for one expert. num_layers (int): number of layers in the feedforward network. hidden_features (int): dimensionality of hidden layer in the feedforward network. bias (bool, optional): if set to ``False``, the layer will not learn an additive bias. Defaults to True. """ super().__init__() layers: List[nn.Module] = [] current_in_features = in_features for _ in range(num_layers - 1): linear = Linear( num_experts=num_experts, in_features=current_in_features, out_features=hidden_features, bias=bias, ) layers.append(linear) layers.append(nn.ReLU()) current_in_features = hidden_features linear = Linear( num_experts=num_experts, in_features=current_in_features, out_features=out_features, bias=bias, ) layers.append(linear) self._model = nn.Sequential(*layers)
[docs] def forward(self, x: TensorType) -> TensorType: return self._model(x)
def __repr__(self) -> str: return str(self._model)
[docs]class MaskCache: def __init__( self, num_tasks: int, num_eval_episodes: int, batch_size: int, task_index_to_mask: TensorType, ): """In multitask learning, using a mixture of models, different tasks can be mapped to different combination of models. This utility class caches these mappings so that they do not have to be revaluated. For example, when the model is training over 10 tasks, and the tasks are always ordered, the mapping of task index to encoder indices will be the same and need not be recomputed. We take a very simple approach here: cache using the number of tasks, since in our case, the task ordering during training and evaluation does not change. In more complex cases, a mode (train/eval..) based key could be used. This gets a little trickier during evaluation. We assume that we are running multiple evaluation episodes (per task) at once. So during evaluation, the agent is inferring over num_tasks*num_eval_episodes at once. We have to be careful about not caching the mapping during update because neither the task distribution, nor the task ordering, is pre-determined during update. So we explicitly exclude the `batch_size` from the list of keys being cached. Args: num_tasks (int): number of tasks. num_eval_episodes (int): number of episodes run during evaluation. batch_size (int): batch size for update. task_index_to_mask (TensorType): mapping of task index to mask. """ self.masks: Dict[int, TensorType] = {} self.task_index_to_mask = task_index_to_mask keys_to_cache = [num_tasks, num_tasks * num_eval_episodes] self.keys_to_cache = {key for key in keys_to_cache if key != batch_size} # This is a hack to get some speed up, by reusing the mask
[docs] def get_mask(self, task_info: TaskInfo) -> TensorType: """Get the mask corresponding to a given task info. Args: task_info (TaskInfo): Returns: TensorType: encoder mask. """ key = len(task_info.env_index) if key in self.keys_to_cache: if key not in self.masks: self.masks[key] = self._make_mask(task_info) return self.masks[key] return self._make_mask(task_info)
def _make_mask(self, task_info: TaskInfo): env_index = task_info.env_index encoder_mask = self.task_index_to_mask[env_index.squeeze()] if len(encoder_mask.shape) == 1: encoder_mask = encoder_mask.unsqueeze(0) return encoder_mask.t().unsqueeze(2).to(env_index.device)
[docs]class MixtureOfExperts(nn.Module): def __init__( self, multitask_cfg: ConfigType, ): """Class for interfacing with a mixture of experts. Args: multitask_cfg (ConfigType): config for multitask training. """ super().__init__() self.multitask_cfg = multitask_cfg self.mask_cache: MaskCache
[docs] def forward(self, task_info: TaskInfo) -> TensorType: return self.mask_cache.get_mask(task_info=task_info)
[docs]class AttentionBasedExperts(MixtureOfExperts): def __init__( self, num_tasks: int, num_experts: int, embedding_dim: int, hidden_dim: int, num_layers: int, temperature: bool, should_use_soft_attention: bool, task_encoder_cfg: ConfigType, multitask_cfg: ConfigType, topk: Optional[int] = None, ): super().__init__(multitask_cfg=multitask_cfg) self.should_use_soft_attention = should_use_soft_attention self.temperature = temperature self.should_use_task_encoding = task_encoder_cfg.should_use_task_encoding self.should_detach_task_encoding = ( self.should_use_task_encoding and task_encoder_cfg.should_detach_task_encoding ) if not self.should_use_task_encoding: self.emb = nn.Embedding( num_embeddings=num_tasks, embedding_dim=embedding_dim, ) self.emb.apply(agent_utils.weight_init) else: embedding_dim = multitask_cfg.task_encoder_cfg.model_cfg.embedding_dim self.trunk = agent_utils.build_mlp( input_dim=embedding_dim, hidden_dim=hidden_dim, output_dim=num_experts, num_layers=num_layers, ) self.trunk.apply(agent_utils.weight_init) self.topk = topk self._softmax = torch.nn.Softmax(dim=1)
[docs] def forward(self, task_info: TaskInfo) -> TensorType: if self.should_use_task_encoding: emb = task_info.encoding if self.should_detach_task_encoding: emb = emb.detach() # type: ignore[union-attr] else: env_index = task_info.env_index if len(env_index.shape) == 2: env_index = env_index.squeeze(1) emb = self.emb(env_index) output = self.trunk(emb) gate = self._softmax(output / self.temperature) if not self.should_use_soft_attention: topk_attention = gate.topk(self.topk, dim=1) topk_attention_indices = topk_attention[1] hard_attention_mask = torch.zeros_like(gate).scatter_( dim=1, index=topk_attention_indices, value=1.0 ) gate = gate * hard_attention_mask gate = gate / gate.sum(dim=1).unsqueeze(1) if len(gate.shape) > 2: breakpoint() return gate.t().unsqueeze(2)
[docs]class ClusterOfExperts(MixtureOfExperts): def __init__( self, num_tasks: int, num_experts: int, num_eval_episodes: int, batch_size: int, multitask_cfg: ConfigType, env_name: str, task_description: Dict[str, str], ordered_task_list: List[str], mapping_cfg: ConfigType, ): """Map the ith task to a subset (cluster) of experts. Args: num_tasks (int): number of tasks. num_experts (int): number of experts in the mixture of experts. num_eval_episodes (int): number of episodes run during evaluation. batch_size (int): batch size for update. multitask_cfg (ConfigType): config for multitask training. env_name (str): name of the environment. This is used with the mapping configuration. task_description (Dict[str, str]): dictionary mapping task names to descriptions. ordered_task_list (List[str]): ordered list of tasks. This is needed because the task description is not always ordered. mapping_cfg (ConfigType): config for mapping the tasks to subset of experts. """ super().__init__(multitask_cfg=multitask_cfg) assert env_name in ["metaworld-mt10", "metaworld-mt50"] env_name = env_name.split("-")[1] clusters = mapping_cfg[env_name]["cluster"] task_index_to_encoder_index = torch.zeros( (len(task_description), min(num_experts, len(clusters))) ) for task_index, task_name in enumerate(ordered_task_list): description = task_description[task_name] for encoder_index, cluster_values in enumerate(clusters.values()): for word in cluster_values: if word in description: task_index_to_encoder_index[task_index][encoder_index] = 1.0 self.mask_cache = MaskCache( num_tasks=num_tasks, num_eval_episodes=num_eval_episodes, batch_size=batch_size, task_index_to_mask=task_index_to_encoder_index, )
[docs]class OneToOneExperts(MixtureOfExperts): def __init__( self, num_tasks: int, num_experts: int, num_eval_episodes: int, batch_size: int, multitask_cfg: ConfigType, ): """Map the output of ith expert with the ith task. Args: num_tasks (int): number of tasks. num_experts (int): number of experts in the mixture of experts. num_eval_episodes (int): number of episodes run during evaluation. batch_size (int): batch size for update. multitask_cfg (ConfigType): config for multitask training. """ super().__init__(multitask_cfg=multitask_cfg) assert num_tasks == num_experts self.mask_cache = MaskCache( num_tasks=num_tasks, num_eval_episodes=num_eval_episodes, batch_size=batch_size, task_index_to_mask=torch.eye(num_tasks), )
[docs]class EnsembleOfExperts(MixtureOfExperts): def __init__( self, num_tasks: int, num_experts: int, num_eval_episodes: int, batch_size: int, multitask_cfg: ConfigType, ): """Ensemble of all the experts. Args: num_tasks (int): number of tasks. num_experts (int): number of experts in the mixture of experts. num_eval_episodes (int): number of episodes run during evaluation. batch_size (int): batch size for update. multitask_cfg (ConfigType): config for multitask training. """ super().__init__(multitask_cfg=multitask_cfg) self.mask_cache = MaskCache( num_tasks=num_tasks, num_eval_episodes=num_eval_episodes, batch_size=batch_size, task_index_to_mask=torch.ones(num_tasks, num_experts), )