Source code for atomworks.common
"""Common utility functions used throughout the project."""
import copy
import hashlib
from collections.abc import Callable
from functools import lru_cache, wraps
from typing import Any
import numpy as np
from toolz.curried import compose, reduce
[docs]
def exists(obj: Any) -> bool:
"""Check that obj is not None.
Args:
obj: The object to check.
Returns:
True if obj is not None, False otherwise.
"""
return obj is not None
[docs]
def default(obj: Any, default: Any) -> Any:
"""Return obj if not None, otherwise return default.
Args:
obj: The primary object to return.
default: The fallback value if obj is None.
Returns:
obj if it is not None, otherwise default.
"""
return obj if exists(obj) else default
[docs]
def to_hashable(element: Any) -> Any:
"""Convert an element to a hashable type.
Args:
element: The element to convert.
Returns:
The element if already hashable, otherwise converted to a tuple.
"""
return element if isinstance(element, int | str | np.integer | np.str_) else tuple(element)
[docs]
def string_to_md5_hash(s: str, truncate: int = 32) -> str:
"""Generate an MD5 hash of a string and return the first truncate characters.
Args:
s: The string to hash.
truncate: Number of characters to return from the hash.
Returns:
The truncated MD5 hash as a string.
"""
full_hash = hashlib.md5(s.encode("utf-8")).hexdigest()
return full_hash[:truncate]
[docs]
def sum_string_arrays(*objs: np.ndarray | str) -> np.ndarray:
"""Sum a list of string arrays or strings into a single string array.
Concatenates the arrays and determines the shortest string length to set as dtype.
Args:
*objs: Variable number of string arrays or strings to sum.
Returns:
A single concatenated string array.
"""
return reduce(np.char.add, objs).astype(object).astype(str)
[docs]
def not_isin(element: np.ndarray, array: np.ndarray, **isin_kwargs) -> np.ndarray:
"""Like ~np.isin, but more efficient.
Args:
element: The array to test.
array: The array of values to test against.
**isin_kwargs: Additional keyword arguments for np.isin.
Returns:
Boolean array indicating which elements are not in the array.
"""
return np.isin(element, array, invert=True, **isin_kwargs)
[docs]
def listmap(func: Callable, *iterables) -> list:
"""Like map, but returns a list instead of an iterator.
Args:
func: The function to apply.
*iterables: Variable number of iterables to map over.
Returns:
A list containing the results of applying func to the iterables.
"""
return compose(list, map)(func, *iterables)
[docs]
def as_list(value: Any) -> list:
"""Convert a value to a list.
Handles various types using duck typing:
- Iterable objects (lists, tuples, strings, etc.): converted to list
- Single values: wrapped in a list
Args:
value: The value to convert to a list.
Returns:
A list containing the value(s).
"""
try:
# Try to iterate over the value (duck typing approach)
# Exclude strings since they're iterable but we want to treat them as single values
if isinstance(value, str):
return [value]
return list(value)
except TypeError:
# If it's not iterable, wrap it in a list
return [value]
[docs]
def immutable_lru_cache(maxsize: int = 128, typed: bool = False, deepcopy: bool = True) -> Callable:
"""An immutable version of lru_cache for caching functions that return mutable objects.
Args:
maxsize: Maximum number of items to cache.
typed: Whether to treat different types as separate cache entries.
deepcopy: Whether to use deep copy for immutable caching.
Returns:
A decorator that provides immutable caching functionality.
"""
copy_func = copy.deepcopy if deepcopy else copy.copy
def decorator(func: Callable) -> Callable:
cached_func = lru_cache(maxsize=maxsize, typed=typed)(func)
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
return copy_func(cached_func(*args, **kwargs))
return wrapper
return decorator
[docs]
class KeyToIntMapper:
"""Maps keys to unique integers based on the order of the first appearance of the key.
This is useful for mapping id's such as chain_id, chain_entity, molecule_iid, etc.
to integers.
Example:
>>> chain_id_to_int = KeyToIntMapper()
>>> chain_id_to_int("A") # 0
>>> chain_id_to_int("C") # 1
>>> chain_id_to_int("A") # 0
>>> chain_id_to_int("B") # 2
"""
def __init__(self):
"""Initialize KeyToIntMapper with empty mapping."""
self.key_to_id = {}
self.next_id = 0
def __call__(self, value: Any) -> int:
"""Map a key to a unique integer.
Args:
value: The key to map.
Returns:
The unique integer assigned to the key.
"""
if value not in self.key_to_id:
self.key_to_id[value] = self.next_id
self.next_id += 1
return self.key_to_id[value]