mtrl package

Subpackages

Submodules

mtrl.logger module

class mtrl.logger.AverageMeter[source]

Bases: mtrl.logger.Meter

update(value, n=1)[source]
value()[source]
class mtrl.logger.CurrentMeter[source]

Bases: mtrl.logger.Meter

update(value, n=1)[source]
value()[source]
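Both meter variants share the `update(value, n)` / `value()` interface documented above. A minimal stdlib sketch of the presumed semantics (a weighted running average vs. the most recent value); the real `mtrl` implementations may differ in detail:

```python
class Meter:
    """Base interface: accumulate values, then read them back."""

    def update(self, value, n=1):
        raise NotImplementedError

    def value(self):
        raise NotImplementedError


class AverageMeter(Meter):
    """Tracks a running average, weighting each update by n."""

    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def update(self, value, n=1):
        self._sum += value * n
        self._count += n

    def value(self):
        # Guard against reading an empty meter.
        return self._sum / max(1, self._count)


class CurrentMeter(Meter):
    """Keeps only the most recently logged value."""

    def __init__(self):
        self._value = 0.0

    def update(self, value, n=1):
        self._value = value

    def value(self):
        return self._value
```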
class mtrl.logger.Logger(log_dir, config, retain_logs: bool = False)[source]

Bases: object

dump(step)[source]
log(key, value, step, n=1)[source]
class mtrl.logger.Meter[source]

Bases: object

update(value, n=1)[source]
value()[source]
class mtrl.logger.MetersGroup(file_name, formating, mode: str, retain_logs: bool)[source]

Bases: object

dump(step, prefix)[source]
log(key, value, n=1)[source]
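The `Logger` / `MetersGroup` pair suggests a two-level design: the logger routes each key to a per-mode group of meters, and `dump(step)` flushes the accumulated averages. The sketch below assumes a `"<mode>/<name>"` key convention (e.g. `"train/reward"`); that convention and all internals are assumptions for illustration, not `mtrl`'s actual code:

```python
from collections import defaultdict


class MetersGroup:
    """Collects named running averages and flushes them per step."""

    def __init__(self):
        self._meters = defaultdict(lambda: [0.0, 0])  # key -> [sum, count]

    def log(self, key, value, n=1):
        entry = self._meters[key]
        entry[0] += value * n
        entry[1] += n

    def dump(self, step, prefix):
        # Build the averages, then reset the meters for the next interval.
        averages = {k: s / c for k, (s, c) in self._meters.items()}
        self._meters.clear()
        return {"step": step, "prefix": prefix, **averages}


class Logger:
    def __init__(self):
        self._groups = defaultdict(MetersGroup)

    def log(self, key, value, step, n=1):
        mode, name = key.split("/", 1)
        self._groups[mode].log(name, value, n=n)

    def dump(self, step):
        return [g.dump(step, prefix=mode) for mode, g in self._groups.items()]
```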
mtrl.logger.np_float32(val)[source]
mtrl.logger.np_int64(val)[source]
mtrl.logger.serialize_log(val)[source]
mtrl.logger.serialize_log(val: numpy.float32)
mtrl.logger.serialize_log(val: numpy.int64)

Single-dispatch serializer. The generic implementation is used by default; the numpy.float32 and numpy.int64 overloads handle numpy scalar values.
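The multiple signatures above are how Sphinx renders a `functools.singledispatch` function: numpy scalars are not JSON-serializable, so type-specific handlers convert them to builtin types before logging. A stdlib-only sketch of the pattern, with `Decimal` and `Fraction` standing in for `numpy.float32` and `numpy.int64` so it runs without numpy:

```python
from decimal import Decimal
from fractions import Fraction
from functools import singledispatch


@singledispatch
def serialize_log(val):
    # Used by default when no type-specific handler is registered.
    return val


@serialize_log.register
def _(val: Decimal):
    # Stand-in for the numpy.float32 handler: convert to builtin float.
    return float(val)


@serialize_log.register
def _(val: Fraction):
    # Stand-in for the numpy.int64 handler: convert to builtin int.
    return int(val)
```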

mtrl.replay_buffer module

class mtrl.replay_buffer.ReplayBuffer(env_obs_shape, task_obs_shape, action_shape, capacity, batch_size, device)[source]

Bases: object

Buffer to store environment transitions.

add(env_obs, action, reward, next_env_obs, done, task_obs)[source]
delete_from_filesystem(dir_to_delete_from: str)[source]
is_empty()[source]
load(save_dir)[source]
reset()[source]
sample(index=None) → mtrl.replay_buffer.ReplayBufferSample[source]
sample_an_index(index, total_number_of_environments) → mtrl.replay_buffer.ReplayBufferSample[source]

Return env_observations only for the given index

save(save_dir, size_per_chunk: int, num_samples_to_save: int)[source]
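The documented interface is a standard circular replay buffer: `add` overwrites the oldest transition once `capacity` is reached, and `sample` draws a batch. A stdlib sketch of that behavior, using plain Python lists instead of the pre-allocated tensor storage, device handling, and filesystem persistence the real class presumably uses:

```python
import random


class ReplayBuffer:
    """Buffer to store environment transitions (circular-buffer sketch)."""

    def __init__(self, capacity, batch_size):
        self.capacity = capacity
        self.batch_size = batch_size
        self._storage = []
        self._idx = 0  # next write position, wraps around at capacity

    def add(self, env_obs, action, reward, next_env_obs, done, task_obs):
        # Store not_done (rather than done), matching ReplayBufferSample.
        transition = (env_obs, action, reward, next_env_obs, not done, task_obs)
        if len(self._storage) < self.capacity:
            self._storage.append(transition)
        else:
            self._storage[self._idx] = transition  # overwrite oldest slot
        self._idx = (self._idx + 1) % self.capacity

    def is_empty(self):
        return len(self._storage) == 0

    def reset(self):
        self._storage = []
        self._idx = 0

    def sample(self):
        # Uniform random batch (with replacement, for simplicity).
        return [random.choice(self._storage) for _ in range(self.batch_size)]
```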
class mtrl.replay_buffer.ReplayBufferSample(env_obs: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, next_env_obs: torch.Tensor, not_done: torch.Tensor, task_obs: torch.Tensor, buffer_index: torch.Tensor)[source]

Bases: object

action
buffer_index
env_obs
next_env_obs
not_done
reward
task_obs
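The constructor signature and attribute list above describe a simple container of batched tensors; note it carries `not_done` (presumably `1 - done`, which is convenient for bootstrapped TD targets) even though `ReplayBuffer.add` takes `done`. A dataclass sketch, with `Any` in place of `torch.Tensor` so it stays stdlib-only:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ReplayBufferSample:
    env_obs: Any        # torch.Tensor in the real class
    action: Any
    reward: Any
    next_env_obs: Any
    not_done: Any       # 1 - done
    task_obs: Any
    buffer_index: Any   # indices of the sampled transitions in the buffer
```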

Module contents