Source code for mtrl.env.builder
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
import mtenv
from gym.vector.async_vector_env import AsyncVectorEnv
from mtrl.env.vec_env import MetaWorldVecEnv, VecEnv
from mtrl.utils.types import ConfigType
[docs]def build_dmcontrol_vec_env(
domain_name: str,
task_name: str,
prefix: str,
make_kwargs: ConfigType,
env_id_list: List[int],
seed_list: List[int],
mode_list: List[str],
) -> VecEnv:
def get_func_to_make_envs(seed: int, initial_task_state: int):
def _func() -> mtenv.MTEnv:
kwargs = deepcopy(make_kwargs)
kwargs["seed"] += seed
kwargs["initial_task_state"] = initial_task_state
return mtenv.make(
f"MT-HiPBMDP-{domain_name.capitalize()}-{task_name.capitalize()}-vary-{prefix.replace('_', '-')}-v0",
**kwargs,
)
return _func
funcs_to_make_envs = [
get_func_to_make_envs(seed=seed, initial_task_state=task_state)
for (seed, task_state) in zip(seed_list, env_id_list)
]
env_metadata = {"ids": env_id_list, "mode": mode_list}
env = VecEnv(env_metadata=env_metadata, env_fns=funcs_to_make_envs, context="spawn")
return env
[docs]def build_metaworld_vec_env(
config: ConfigType,
benchmark: "metaworld.Benchmark", # type: ignore[name-defined] # noqa: F821
mode: str,
env_id_to_task_map: Optional[Dict[str, "metaworld.Task"]], # type: ignore[name-defined] # noqa: F821
) -> Tuple[AsyncVectorEnv, Optional[Dict[str, Any]]]:
from mtenv.envs.metaworld.env import (
get_list_of_func_to_make_envs as get_list_of_func_to_make_metaworld_envs,
)
benchmark_name = config.env.benchmark._target_.replace("metaworld.", "")
num_tasks = int(benchmark_name.replace("MT", ""))
make_kwargs = {
"benchmark": benchmark,
"benchmark_name": benchmark_name,
"env_id_to_task_map": env_id_to_task_map,
"num_copies_per_env": 1,
"should_perform_reward_normalization": True,
}
funcs_to_make_envs, env_id_to_task_map = get_list_of_func_to_make_metaworld_envs(
**make_kwargs
)
env_metadata = {
"ids": list(range(num_tasks)),
"mode": [mode for _ in range(num_tasks)],
}
env = MetaWorldVecEnv(
env_metadata=env_metadata,
env_fns=funcs_to_make_envs,
context="spawn",
shared_memory=False,
)
return env, env_id_to_task_map