# Source code for mtrl.env.vec_env
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict
import torch
from gym.vector.async_vector_env import AsyncVectorEnv
class VecEnv(AsyncVectorEnv):
    """Vectorized multitask environment.

    Thin wrapper over gym's ``AsyncVectorEnv`` that (a) records task metadata
    (``mode`` and per-env ``ids``) and (b) casts the multitask observation
    dict returned by ``reset``/``step`` into ``torch.Tensor`` values.
    """

    def __init__(
        self,
        env_metadata: Dict[str, Any],
        env_fns,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        """Build a vectorized env from ``env_fns``.

        Args:
            env_metadata: metadata dict; must contain the keys ``"mode"``
                and ``"ids"`` (exposed via the ``mode``/``ids`` properties).
            env_fns: iterable of callables, each constructing one sub-env.

        All remaining arguments are forwarded unchanged to
        ``gym.vector.async_vector_env.AsyncVectorEnv``.
        """
        super().__init__(
            env_fns=env_fns,
            observation_space=observation_space,
            action_space=action_space,
            shared_memory=shared_memory,
            copy=copy,
            context=context,
            daemon=daemon,
            worker=worker,
        )
        # AsyncVectorEnv also derives this, but keep the explicit assignment
        # so the attribute is guaranteed regardless of gym version.
        self.num_envs = len(env_fns)
        assert "mode" in env_metadata, "env_metadata must contain key 'mode'"
        assert "ids" in env_metadata, "env_metadata must contain key 'ids'"
        self._metadata = env_metadata

    @property
    def mode(self):
        """Mode of the environments (e.g. train/eval) from the metadata."""
        return self._metadata["mode"]

    @property
    def ids(self):
        """Task ids of the environments from the metadata."""
        return self._metadata["ids"]

    def reset(self):
        """Reset all sub-envs and return the observation dict as tensors."""
        multitask_obs = super().reset()
        return _cast_multitask_obs(multitask_obs=multitask_obs)

    def step(self, actions):
        """Step all sub-envs; observations are cast to tensors, the
        (reward, done, info) triple is passed through unchanged."""
        multitask_obs, reward, done, info = super().step(actions)
        return _cast_multitask_obs(multitask_obs=multitask_obs), reward, done, info
def _cast_multitask_obs(multitask_obs):
return {key: torch.tensor(value) for key, value in multitask_obs.items()}
class MetaWorldVecEnv(AsyncVectorEnv):
    """Vectorized environment for MetaWorld tasks.

    Wraps gym's ``AsyncVectorEnv`` and pairs each raw environment
    observation with a fixed per-env task index (``task_obs``), returning
    both as torch tensors.
    """

    def __init__(
        self,
        env_metadata: Dict[str, Any],
        env_fns,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        """Build a vectorized MetaWorld env from ``env_fns``.

        Args:
            env_metadata: metadata dict; must contain the keys ``"mode"``
                and ``"ids"`` (exposed via the ``mode``/``ids`` properties).
            env_fns: iterable of callables, each constructing one sub-env.

        All remaining arguments are forwarded unchanged to
        ``gym.vector.async_vector_env.AsyncVectorEnv``.
        """
        super().__init__(
            env_fns=env_fns,
            observation_space=observation_space,
            action_space=action_space,
            shared_memory=shared_memory,
            copy=copy,
            context=context,
            daemon=daemon,
            worker=worker,
        )
        self.num_envs = len(env_fns)
        # Task index i is permanently associated with the i-th sub-env.
        self.task_obs = torch.arange(self.num_envs)
        assert "mode" in env_metadata, "env_metadata must contain key 'mode'"
        assert "ids" in env_metadata, "env_metadata must contain key 'ids'"
        self._metadata = env_metadata

    @property
    def mode(self):
        """Mode of the environments (e.g. train/eval) from the metadata."""
        return self._metadata["mode"]

    @property
    def ids(self):
        """Task ids of the environments from the metadata."""
        return self._metadata["ids"]

    def _check_observation_spaces(self):
        # Deliberate no-op: skips AsyncVectorEnv's check that all sub-envs
        # share one observation space. NOTE(review): presumably the MetaWorld
        # sub-envs report spaces that fail gym's equality check even though
        # their observations are compatible — confirm against the env builders.
        return

    def step(self, actions):
        """Step all sub-envs; returns (multitask_obs, reward, done, info)
        where multitask_obs bundles env observations with task indices."""
        env_obs, reward, done, info = super().step(actions)
        return self.create_multitask_obs(env_obs=env_obs), reward, done, info

    def create_multitask_obs(self, env_obs):
        """Pack raw env observations and the fixed task indices into a dict
        of torch tensors with keys ``"env_obs"`` and ``"task_obs"``."""
        return {"env_obs": torch.tensor(env_obs), "task_obs": self.task_obs}