Source code for mtrl.utils.utils

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Collection of utility functions"""

import os
import pathlib
import random
import re
import subprocess  # noqa: S404
from typing import Any, Iterator, List, TypeVar, Union

import numpy as np
import torch

T = TypeVar("T")


def flatten_list(_list: List[List[Any]]) -> List[Any]:
    """Concatenate the items of every sublist into one flat list.

    Only one level of nesting is removed; items that are themselves
    lists are kept as-is.

    Args:
        _list (List[List[Any]]): list whose elements are lists.

    Returns:
        List[Any]: new list holding the items of each sublist, in order.
    """
    flattened: List[Any] = []
    for sublist in _list:
        flattened.extend(sublist)
    return flattened
def chunks(_list: List[T], n: int) -> Iterator[List[T]]:
    """Lazily split a list into consecutive slices of at most ``n`` items.

    The final chunk is shorter than ``n`` when the list length is not a
    multiple of ``n``.

    Taken from https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks

    Args:
        _list (List[T]): list to chunk.
        n (int): size of chunks.

    Yields:
        Iterator[List[T]]: iterable over the chunks
    """
    # Each start offset is a multiple of ``n``; slicing clamps the end
    # index, so the last chunk needs no special handling.
    yield from (
        _list[start : start + n]  # noqa: E203
        for start in range(0, len(_list), n)
    )
def make_dir(path: str) -> str:
    """Create a directory, including any missing parent directories.

    Calling this on a directory that already exists is not an error.

    Args:
        path (str): path of the directory to create.

    Returns:
        str: the same ``path`` that was passed in.
    """
    directory = pathlib.Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return path
def get_current_commit_id() -> str:
    """Return the commit id that ``HEAD`` currently points to.

    Runs ``git rev-parse HEAD`` in the current working directory.

    Returns:
        str: current commit id, with surrounding whitespace stripped.
    """
    git_args = "git rev-parse HEAD".split()
    raw_output = subprocess.check_output(git_args)  # noqa: S603
    return raw_output.strip().decode("utf-8")
def has_uncommitted_changes() -> bool:
    """Check if there are uncommitted changes in the working tree.

    Runs ``git status --porcelain`` in the current working directory and
    reports whether it printed anything. The porcelain format is used
    because the human-readable ``git status`` message differs between
    git versions and locales, so matching on its text is unreliable.

    Untracked files are listed by ``git status --porcelain`` and are
    therefore counted as uncommitted changes.

    Returns:
        bool: whether there are uncommitted changes.
    """
    command = "git status --porcelain"
    output = subprocess.check_output(command.split()).decode("utf-8")  # noqa: S603
    # A clean working tree produces no porcelain output at all.
    return bool(output.strip())
def set_seed(seed: int) -> None:
    """Seed the random number generators of python, numpy, and torch.

    Also stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed (int): seed to set.
    """
    # Same seeding calls, in the same order: python, numpy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
def split_on_caps(input_str: str) -> List[str]:
    """Break a string into pieces that each begin with an uppercase letter.

    Every piece starts at an ASCII uppercase letter (``A``-``Z``) and runs
    up to, but not including, the next one. Any characters before the
    first uppercase letter are dropped.

    Taken from: https://stackoverflow.com/questions/2277352/split-a-string-at-uppercase-letters

    Args:
        input_str (str): string to split.

    Returns:
        List[str]: splits of the given string.
    """
    pieces = re.finditer("[A-Z][^A-Z]*", input_str)
    return [piece.group(0) for piece in pieces]
def is_integer(n: Union[int, str, float]) -> bool:
    """Check if the given value can be interpreted as an integer.

    The check is whether ``int(n)`` succeeds. Finite floats with a
    fractional part (e.g. ``3.7``) therefore count as interpretable,
    since ``int`` truncates them.

    Args:
        n (Union[int, str, float]): value to check.

    Returns:
        bool: whether the value can be interpreted as an integer.
    """
    try:
        int(n)
    except (ValueError, OverflowError, TypeError):
        # ValueError: non-numeric strings and NaN.
        # OverflowError: infinite floats.
        # TypeError: objects ``int`` does not accept at all (e.g. None).
        return False
    return True