Source code for mtrl.agent.hipbmdp

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from mtrl.agent import deepmdp
from mtrl.agent.components import hipbmdp_theta
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.replay_buffer import ReplayBufferSample
from mtrl.utils.types import ConfigType, ParameterType, TensorType

LOG_FREQ = 10000


[docs]class Agent(deepmdp.Agent): """HiPBMDP Agent""" def __init__( self, env_obs_shape: List[int], action_shape: List[int], action_range: Tuple[int, int], device: torch.device, actor_cfg: ConfigType, critic_cfg: ConfigType, decoder_cfg: ConfigType, reward_decoder_cfg: ConfigType, transition_model_cfg: ConfigType, alpha_optimizer_cfg: ConfigType, actor_optimizer_cfg: ConfigType, critic_optimizer_cfg: ConfigType, multitask_cfg: ConfigType, decoder_optimizer_cfg: ConfigType, encoder_optimizer_cfg: ConfigType, reward_decoder_optimizer_cfg: ConfigType, transition_model_optimizer_cfg: ConfigType, discount: float = 0.99, init_temperature: float = 0.01, actor_update_freq: int = 2, critic_tau: float = 0.005, critic_target_update_freq: int = 2, encoder_tau: float = 0.005, loss_reduction: str = "mean", decoder_update_freq: int = 1, decoder_latent_lambda: float = 0.0, cfg_to_load_model: Optional[ConfigType] = None, should_complete_init: bool = True, ): super().__init__( env_obs_shape=env_obs_shape, action_shape=action_shape, action_range=action_range, device=device, actor_cfg=actor_cfg, critic_cfg=critic_cfg, decoder_cfg=decoder_cfg, transition_model_cfg=transition_model_cfg, reward_decoder_cfg=reward_decoder_cfg, discount=discount, init_temperature=init_temperature, alpha_optimizer_cfg=alpha_optimizer_cfg, actor_optimizer_cfg=actor_optimizer_cfg, critic_optimizer_cfg=critic_optimizer_cfg, multitask_cfg=multitask_cfg, decoder_optimizer_cfg=decoder_optimizer_cfg, encoder_optimizer_cfg=encoder_optimizer_cfg, reward_decoder_optimizer_cfg=reward_decoder_optimizer_cfg, transition_model_optimizer_cfg=transition_model_optimizer_cfg, actor_update_freq=actor_update_freq, critic_tau=critic_tau, critic_target_update_freq=critic_target_update_freq, encoder_tau=encoder_tau, loss_reduction=loss_reduction, decoder_update_freq=decoder_update_freq, decoder_latent_lambda=decoder_latent_lambda, cfg_to_load_model=None, should_complete_init=False, ) self._cache_theta_sampling_strategy: Dict[ str, List[hipbmdp_theta.ThetaSamplingStrategy] ] = {} if should_complete_init: self.complete_init(cfg_to_load_model=cfg_to_load_model) def _get_theta_sampling_strategy(self, modes: List[str]): if modes[0] not in self._cache_theta_sampling_strategy: if modes[0] == "train": strategy = self.multitask_cfg.task_encoder_cfg.sampling_strategy[ "train" ] theta_sampling_strategy = [ hipbmdp_theta.ThetaSamplingStrategy(strategy) for _ in modes ] elif modes[0] == "base": theta_sampling_strategy = [ hipbmdp_theta.ThetaSamplingStrategy( self.multitask_cfg.task_encoder_cfg.sampling_strategy["eval"][ submode ] ) for submode in modes ] else: raise ValueError(f"`mode`={modes[0]} is not supported") assert len(modes) == len(theta_sampling_strategy) self._cache_theta_sampling_strategy[modes[0]] = theta_sampling_strategy return self._cache_theta_sampling_strategy[modes[0]]
[docs] def get_task_encoding( self, env_index: TensorType, modes: List[str], disable_grad: bool ): theta_sampling_strategy = self._get_theta_sampling_strategy(modes=modes) if disable_grad: with torch.no_grad(): return self.task_encoder( env_index=env_index.to(self.device), theta_sampling_strategy=theta_sampling_strategy, modes=modes, ) return self.task_encoder( env_index=env_index.to(self.device), theta_sampling_strategy=theta_sampling_strategy, modes=modes, )
[docs] def update_task_encoder( self, batch: ReplayBufferSample, task_info: TaskInfo, logger, step, kwargs_to_compute_gradient: Dict[str, Any], ): if not self.multitask_cfg.contrastive.should_use: return h = self.critic.encoder(batch.env_obs, task_info=task_info) flipped_task_encoding = torch.flip(task_info.encoding, [0]) # type: ignore[arg-type] h_1 = torch.cat([h, task_info.encoding], dim=1) # type: ignore[list-item] h_2 = torch.cat([h, flipped_task_encoding], dim=1) pred_next_latent_mu_1, pred_next_latent_sigma_1 = self.transition_model( torch.cat([h_1, batch.action], dim=1) ) pred_next_latent_mu_2, pred_next_latent_sigma_2 = self.transition_model( torch.cat([h_2, batch.action], dim=1) ) loss = F.mse_loss( torch.norm(task_info.encoding - flipped_task_encoding), torch.norm(pred_next_latent_mu_1.detach() - pred_next_latent_mu_2.detach()), ) logger.log("train/contrastive_loss", loss, step) loss *= self.multitask_cfg.contrastive.alpha # optimize the task_encoder component_names = ["task_encoder"] parameters: List[ParameterType] = [] for name in component_names: parameters += self.get_parameters(name) self._compute_gradient( loss=loss, parameters=parameters, step=step, component_names=component_names, **kwargs_to_compute_gradient, ) self.task_encoder_optimizer.step()