# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random
from typing import List
import numpy as np
import torch
import torch.nn as nn
from mtrl.agent.components import moe_layer
from mtrl.utils.types import ModelType, TensorType
[docs]class eval_mode(object):
def __init__(self, *models):
"""Put the agent in the eval mode"""
self.models = models
def __enter__(self):
self.prev_states = []
for model in self.models:
self.prev_states.append(model.training)
model.train(False)
def __exit__(self, *args):
for model, state in zip(self.models, self.prev_states):
model.train(state)
return False
[docs]def soft_update_params(net: ModelType, target_net: ModelType, tau: float) -> None:
"""Perform soft udpate on the net using target net.
Args:
net ([ModelType]): model to update.
target_net (ModelType): model to update with.
tau (float): control the extent of update.
"""
for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
[docs]def set_seed_everywhere(seed: int) -> None:
"""Set seed for reproducibility.
Args:
seed (int): seed.
"""
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
[docs]def preprocess_obs(obs: TensorType, bits=5) -> TensorType:
"""Preprocessing image, see https://arxiv.org/abs/1807.03039."""
bins = 2 ** bits
assert obs.dtype == torch.float32
if bits < 8:
obs = torch.floor(obs / 2 ** (8 - bits))
obs = obs / bins
obs = obs + torch.rand_like(obs) / bins
obs = obs - 0.5
return obs
[docs]def weight_init_linear(m: ModelType):
assert isinstance(m.weight, TensorType)
nn.init.xavier_uniform_(m.weight)
assert isinstance(m.bias, TensorType)
nn.init.zeros_(m.bias)
[docs]def weight_init_conv(m: ModelType):
# delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
assert isinstance(m.weight, TensorType)
assert m.weight.size(2) == m.weight.size(3)
m.weight.data.fill_(0.0)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0) # type: ignore[operator]
mid = m.weight.size(2) // 2
gain = nn.init.calculate_gain("relu")
assert isinstance(m.weight, TensorType)
nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
[docs]def weight_init_moe_layer(m: ModelType):
assert isinstance(m.weight, TensorType)
for i in range(m.weight.shape[0]):
nn.init.xavier_uniform_(m.weight[i])
assert isinstance(m.bias, TensorType)
nn.init.zeros_(m.bias)
[docs]def weight_init(m: ModelType):
"""Custom weight init for Conv2D and Linear layers."""
if isinstance(m, nn.Linear):
weight_init_linear(m)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
weight_init_conv(m)
elif isinstance(m, moe_layer.Linear):
weight_init_moe_layer(m)
def _get_list_of_layers(
input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> List[nn.Module]:
"""Utility function to get a list of layers. This assumes all the hidden
layers are using the same dimensionality.
Args:
input_dim (int): input dimension.
hidden_dim (int): dimension of the hidden layers.
output_dim (int): dimension of the output layer.
num_layers (int): number of layers in the mlp.
Returns:
ModelType: [description]
"""
mods: List[nn.Module]
if num_layers == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
for _ in range(num_layers - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
mods.append(nn.Linear(hidden_dim, output_dim))
return mods
[docs]def build_mlp_as_module_list(
input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> ModelType:
"""Utility function to build a module list of layers. This assumes all
the hidden layers are using the same dimensionality.
Args:
input_dim (int): input dimension.
hidden_dim (int): dimension of the hidden layers.
output_dim (int): dimension of the output layer.
num_layers (int): number of layers in the mlp.
Returns:
ModelType: [description]
"""
mods: List[nn.Module] = _get_list_of_layers(
input_dim=input_dim,
hidden_dim=hidden_dim,
output_dim=output_dim,
num_layers=num_layers,
)
sequential_layers = []
new_layer = []
for index, current_layer in enumerate(mods):
if index % 2 == 0:
new_layer = [current_layer]
else:
new_layer.append(current_layer)
sequential_layers.append(nn.Sequential(*new_layer))
sequential_layers.append(nn.Sequential(*new_layer))
return nn.ModuleList(sequential_layers)
[docs]def build_mlp(
input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> ModelType:
"""Utility function to build a mlp model. This assumes all the hidden
layers are using the same dimensionality.
Args:
input_dim (int): input dimension.
hidden_dim (int): dimension of the hidden layers.
output_dim (int): dimension of the output layer.
num_layers (int): number of layers in the mlp.
Returns:
ModelType: [description]
"""
mods: List[nn.Module] = _get_list_of_layers(
input_dim=input_dim,
hidden_dim=hidden_dim,
output_dim=output_dim,
num_layers=num_layers,
)
return nn.Sequential(*mods)