Source code for mtrl.agent.components.base

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Interface for the agent components."""
from typing import List

from torch import nn

from mtrl.utils.types import ModelType


[docs]class Component(nn.Module): """Basic component (for building the agent) that every other component should extend. It inherits `torch.nn.Module`. """ def __init__(self): super().__init__()
[docs] def get_last_shared_layers(self) -> List[ModelType]: """Get the list of last layers (for different sub-components) that are shared across tasks. This method should be implemented by the subclasses if the component is to be trained with gradnorm algorithm. Returns: List[ModelType]: list of layers. """ raise NotImplementedError( """Implement the `get_last_shared_layers` method if you want to train the agent with gradnorm algorithm.""" )