Datasets#
This module contains dataset classes and utilities for loading and processing molecular data using a modern, composable architecture.
Core Dataset Classes#
AtomWorks Dataset classes and common APIs.
- At a high level, to train models with AtomWorks, we need a Dataset class that:
  1. Takes as input an item index and returns the corresponding example information, which typically includes:
     a. A path to a structural file saved on disk (e.g., /path/to/dataset/my_dataset_0.cif)
     b. Additional item-specific metadata (e.g., class labels)
  2. Pre-loads structural information from the returned example into an AtomArray and assembles inputs for the Transform pipeline
  3. Feeds the input dictionary through a Transform pipeline and returns the result
- Due to the heterogeneity of biomolecular data, we may also want to:
  - Fall back to a different example when a failure occurs during the Transform pipeline
For bespoke use cases, users may choose to write a custom Dataset that accomplishes these steps; downstream code makes no assumptions about the Dataset implementation.
- To accelerate development, we also provide an off-the-shelf, composable approach following common patterns:
MolecularDataset
: Base class that handles pre-loading structural data and executing the Transform pipeline, with error handling and debugging utilities

PandasDataset
: A subclass of MolecularDataset for tabular data stored as pandas DataFrames

FileDataset
: A subclass of MolecularDataset where each file is one example
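The three steps above can be sketched as plain Python. This is an illustrative stand-in for the composable pattern, not the actual atomworks implementation; the loader and transform names here are hypothetical:

```python
from pathlib import Path

def example_loader(row: dict) -> dict:
    # Steps 1-2: turn item-specific metadata into Transform-ready inputs.
    return {
        "example_id": row["example_id"],
        "path": Path(row["path"]),
        "extra_info": {"label": row.get("label")},
    }

def example_transform(inputs: dict) -> dict:
    # Step 3: a Transform pipeline consumes (and augments) the input dictionary.
    inputs["n_chars"] = len(inputs["example_id"])
    return inputs

class MinimalDataset:
    """Index -> metadata -> loader -> transform, the flow the docs describe."""

    def __init__(self, rows, loader, transform):
        self.rows, self.loader, self.transform = rows, loader, transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.transform(self.loader(self.rows[idx]))
```

The off-the-shelf classes below wrap this same flow and add error handling, fallback, and ID-based access on top of it.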
- class atomworks.ml.datasets.datasets.ConcatDatasetWithID(datasets: list[ExampleIDMixin])[source]#
Bases:
ConcatDataset
Equivalent to torch.utils.data.ConcatDataset, but allows accessing examples by ID. Provides ID-based access across multiple datasets that implement ExampleIDMixin.

- datasets: list[ExampleIDMixin]#
- get_dataset_by_id(example_id: str) Dataset [source]#
Retrieves the dataset containing the example ID.
- Parameters:
example_id – The ID to find.
- Returns:
The sub-dataset containing the ID.
Warning
Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned.
- get_dataset_by_idx(idx: int) Dataset [source]#
Retrieves the dataset containing the index.
- Parameters:
idx – The index to find.
- Returns:
The sub-dataset containing the index.
- Raises:
ValueError – If the index is out of bounds.
- id_to_idx(example_id: str) int [source]#
Retrieves the index corresponding to the example ID.
- Parameters:
example_id – The ID to convert.
- Returns:
The corresponding index.
- Raises:
ValueError – If the example ID is not found.
Warning
Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned.
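The ID-to-global-index logic can be illustrated with a minimal sketch. The `IDList` class below is a hypothetical stand-in for a dataset implementing the ExampleIDMixin protocol; it is not the atomworks class:

```python
class IDList:
    # Minimal stand-in for a dataset with ExampleIDMixin-style ID access.
    def __init__(self, ids):
        self.ids = list(ids)

    def __len__(self):
        return len(self.ids)

    def __contains__(self, example_id):
        return example_id in self.ids

    def id_to_idx(self, example_id):
        return self.ids.index(example_id)  # raises ValueError if missing

def concat_id_to_idx(datasets, example_id):
    """Global index of example_id across concatenated datasets.

    Mirrors the warning above: if the ID appears in several sub-datasets,
    the first occurrence wins.
    """
    offset = 0
    for ds in datasets:
        if example_id in ds:
            return offset + ds.id_to_idx(example_id)
        offset += len(ds)
    raise ValueError(f"Example ID {example_id!r} not found")
```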
- class atomworks.ml.datasets.datasets.ExampleIDMixin[source]#
Bases:
ABC
Mixin providing example ID functionality to a Dataset.
Provides methods for converting between example IDs and indices, and checking if an example ID exists in the dataset.
- class atomworks.ml.datasets.datasets.FallbackDatasetWrapper(dataset: Dataset, fallback_dataset: Dataset)[source]#
Bases:
Dataset
A wrapper around a dataset that allows for a fallback dataset to be used when an error occurs.
Meant to be used with a FallbackSamplerWrapper.
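The fallback pattern can be sketched as follows; this is a simplified illustration under the assumption that a failed fetch should be replaced by some example from the fallback dataset (the real class coordinates the choice of fallback index with a FallbackSamplerWrapper):

```python
import logging

class FallbackWrapperSketch:
    """Serve an example from a fallback dataset when fetching fails."""

    def __init__(self, dataset, fallback_dataset):
        self.dataset = dataset
        self.fallback_dataset = fallback_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            return self.dataset[idx]
        except Exception:
            logging.warning("Example %d failed; serving a fallback example", idx)
            return self.fallback_dataset[idx % len(self.fallback_dataset)]
```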
- class atomworks.ml.datasets.datasets.FileDataset(*, file_paths: list[str | PathLike], name: str, filter_fn: Callable[[PathLike], bool] | None = None, **kwargs: Any)[source]#
Bases:
MolecularDataset
, ExampleIDMixin
Dataset that loads molecular data from individual files.
Each file represents one example in the dataset. If creating a dataset from a directory, use the from_directory() class method instead of the default constructor.

- classmethod from_directory(*, directory: PathLike, name: str, max_depth: int = 3, **kwargs: Any) FileDataset [source]#
Create a FileDataset by scanning a directory for files.
- Parameters:
directory – Path to directory to scan for files.
name – Descriptive name for this dataset.
max_depth – Maximum depth to scan for files in subdirectories.
**kwargs – Additional arguments passed to FileDataset.
- Returns:
FileDataset instance with files discovered from the directory.
Example
- Create from directory:
>>> dataset = FileDataset.from_directory(directory="/path/to/files", name="my_dataset", max_depth=2)
- classmethod from_file_list(*, file_paths: list[str | PathLike], name: str, **kwargs: Any) FileDataset [source]#
Create a FileDataset from an explicit list of file paths.
This is an alias for the main constructor, for clarity and consistency with from_directory().

- Parameters:
file_paths – List of file paths for the dataset. Each file represents one example.
name – Descriptive name for this dataset.
**kwargs – Additional arguments passed to FileDataset.
- Returns:
FileDataset instance with the provided file paths.
- class atomworks.ml.datasets.datasets.MolecularDataset(*, name: str, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None)[source]#
Bases:
Dataset
Base class for AtomWorks molecular datasets.
Handles Transform pipelines and loader functionality for molecular data. Subclasses implement __getitem__() with their own data access patterns.
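The subclassing contract, including the save_failed_examples_to_dir debugging hook, can be sketched like this. The base class here is a hypothetical stand-in mirroring the documented constructor, not the actual atomworks implementation:

```python
import json
import traceback
from pathlib import Path

class MolecularDatasetSketch:
    """Shared logic: run loader + transform, record failures for debugging."""

    def __init__(self, *, name, transform=None, loader=None, save_failed_examples_to_dir=None):
        self.name = name
        self.transform = transform
        self.loader = loader
        self.save_failed_examples_to_dir = save_failed_examples_to_dir

    def process(self, raw, example_id="unknown"):
        try:
            inputs = self.loader(raw) if self.loader else raw
            return self.transform(inputs) if self.transform else inputs
        except Exception:
            if self.save_failed_examples_to_dir:
                debug_dir = Path(self.save_failed_examples_to_dir)
                debug_dir.mkdir(parents=True, exist_ok=True)
                (debug_dir / f"{example_id}.json").write_text(
                    json.dumps({"example_id": example_id, "traceback": traceback.format_exc()})
                )
            raise

class ListBackedDataset(MolecularDatasetSketch):
    # Subclasses implement __getitem__ with their own data access pattern.
    def __init__(self, rows, **kwargs):
        super().__init__(**kwargs)
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return self.process(row, example_id=row.get("example_id", str(idx)))
```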
- class atomworks.ml.datasets.datasets.PandasDataset(*, data: DataFrame | PathLike, name: str, id_column: str | None = 'example_id', filters: list[str] | None = None, columns_to_load: list[str] | None = None, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None, load_kwargs: dict | tuple | None = None)[source]#
Bases:
MolecularDataset
, ExampleIDMixin
Dataset for tabular data stored as pandas DataFrames.
Inherits all functionality from MolecularDataset, with additional DataFrame-specific features for filtering and ID-based access.
- atomworks.ml.datasets.datasets.StructuralDatasetWrapper(dataset_parser: Callable, transform: Callable | None = None, dataset: PandasDataset | None = None, cif_parser_args: dict | None = None, save_failed_examples_to_dir: str | Path | None = None, **kwargs) PandasDataset [source]#
Backwards-compatible wrapper for the deprecated StructuralDatasetWrapper.
This function is deprecated and will be removed in a future version. Use PandasDataset with the appropriate loader function instead.

- Parameters:
dataset_parser – The dataset parser to use (e.g., PNUnitsDFParser, InterfacesDFParser).
transform – Transform pipeline to apply to loaded data.
dataset – The underlying PandasDataset containing the tabular data.
cif_parser_args – Arguments to pass to the CIF parser.
save_failed_examples_to_dir – Directory to save failed examples for debugging.
**kwargs – Additional arguments passed to PandasDataset.
- Returns:
PandasDataset instance configured with the deprecated parameters.
- Raises:
ValueError – If dataset parameter is required but not provided.
- atomworks.ml.datasets.datasets.get_row_and_index_by_example_id(dataset: ExampleIDMixin, example_id: str) dict [source]#
Retrieve a row and its index from a nested dataset structure by its example ID.
- Parameters:
dataset – The dataset or concatenated dataset to search. Must have the id_to_idx method.
example_id – The example ID to search for.
- Returns:
Dictionary containing the row and global index corresponding to the example ID.
Functional Loaders#
Functional loader implementations for AtomWorks datasets.
Loaders are functions that process raw dataset output (e.g., a pandas Series) into a Transform-ready format. For example, a loader converts dataset-specific metadata into a standard format for use in AtomWorks Transform pipelines.
- atomworks.ml.datasets.loaders.create_base_loader(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', attrs: dict | None = None, base_path: str = '', extension: str = '', sharding_pattern: str | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]] [source]#
Factory function that creates a base loader with common logic for many AtomWorks datasets.
- Parameters:
example_id_colname – Name of column containing unique example identifiers
path_colname – Name of column containing paths to structure files
assembly_id_colname – Optional column name containing assembly IDs. If None, assembly_id defaults to “1” for all examples.
attrs – Additional attributes to merge with highest precedence into the metadata hierarchy (and ultimately included in the output dictionary’s “extra_info” key).
base_path – Base path to prepend to file paths if not already included in the path column
extension – File extension to add or replace if not already included in the path column
sharding_pattern – Pattern for how files are organized in subdirectories, if not specified in the path:
- “/1:2/”: Use characters 1-2 for the first directory level
- “/1:2/0:2/”: Use characters 1-2 for the first directory, then characters 0-2 for the second
- None: No sharding (default)
parser_args – Optional dictionary of arguments to pass to the CIF parser when loading the structure file.
- Returns:
A function that takes a pandas Series and returns a dictionary of the loaded structure.
- atomworks.ml.datasets.loaders.create_loader_with_interfaces_and_pn_units_to_score(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', interfaces_to_score_colname: str | None = 'interfaces_to_score', pn_units_to_score_colname: str | None = 'pn_units_to_score', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]] [source]#
Factory function that creates a loader that adds interfaces and pn_units to score for validation datasets.
Example
>>> loader = create_loader_with_interfaces_and_pn_units_to_score(
...     interfaces_to_score_colname="interfaces_to_score", pn_units_to_score_colname="pn_units_to_score"
... )
- atomworks.ml.datasets.loaders.create_loader_with_query_pn_units(example_id_colname: str = 'example_id', path_colname: str = 'path', pn_unit_iid_colnames: str | list[str] | None = None, assembly_id_colname: str | None = 'assembly_id', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]] [source]#
Factory function that creates a generic loader for pipelines with query pn_units (chains).
For instance, in the interfaces dataset, each sampled row contains two pn_unit instance IDs that should be included in the cropped structure.
Examples
- Interfaces dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], assembly_id_colname="assembly_id"
... )
- Chains dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames="pn_unit_iid", base_path="/data/structures", extension=".cif.gz"
... )