Datasets#

This module contains dataset classes and utilities for loading and processing molecular data using a modern, composable architecture.

Core Dataset Classes#

AtomWorks Dataset classes and common APIs.

At a high level, to train models with AtomWorks, we need a Dataset class that:
  1. Takes as input an item index and returns the corresponding example information, which typically includes:
     a. A path to a structural file saved on disk (e.g., /path/to/dataset/my_dataset_0.cif)
     b. Additional item-specific metadata (e.g., class labels)

  2. Pre-loads structural information from the returned example into an AtomArray and assembles inputs for the Transform pipeline

  3. Feeds the input dictionary through a Transform pipeline and returns the result

Due to the heterogeneity of biomolecular data, in many cases we may also want to:
  1. Fall back to a different example in the event of a failure during the Transform pipeline

For bespoke use cases, users may choose to write a custom Dataset that accomplishes these steps; downstream code makes no assumptions.
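As a sketch, a bespoke Dataset following the three steps above might look like the following. This is illustrative only: the class name, field names, and the toy loader/transform are hypothetical stand-ins, not part of the AtomWorks API.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class CustomStructureDataset:
    """Illustrative Dataset implementing the three steps (names are hypothetical)."""

    examples: list[dict[str, Any]]     # step 1: index -> {"path": ..., metadata...}
    loader: Callable[[dict], dict]     # step 2: pre-load structure into pipeline inputs
    transform: Callable[[dict], dict]  # step 3: the Transform pipeline

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        example = self.examples[idx]    # step 1: look up example metadata
        inputs = self.loader(example)   # step 2: assemble Transform inputs
        return self.transform(inputs)   # step 3: run the pipeline


# Toy loader/transform stand-ins for demonstration:
dataset = CustomStructureDataset(
    examples=[{"path": "/path/to/dataset/my_dataset_0.cif", "label": "kinase"}],
    loader=lambda ex: {"path": ex["path"], "extra_info": {"label": ex["label"]}},
    transform=lambda inputs: {**inputs, "transformed": True},
)
item = dataset[0]
```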

To accelerate development, we also provide an off-the-shelf, composable approach following common patterns:
  • MolecularDataset: Base class that handles pre-loading structural data and executing the Transform pipeline with error handling and debugging utilities

  • PandasDataset: A subclass of MolecularDataset for tabular data stored as pandas DataFrames

  • FileDataset: A subclass of MolecularDataset where each file is one example

class atomworks.ml.datasets.datasets.ConcatDatasetWithID(datasets: list[ExampleIDMixin])[source]#

Bases: ConcatDataset

Equivalent to torch.utils.data.ConcatDataset but allows accessing examples by ID.

Provides ID-based access across multiple datasets that implement ExampleIDMixin.
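The ID-based lookup can be sketched with cumulative sub-dataset sizes. The helper below mirrors the behavior described here (first match wins, ValueError on a miss) but is not the library's implementation; the toy dataset class stands in for anything implementing ExampleIDMixin.

```python
from itertools import accumulate


def find_dataset_and_global_idx(datasets, example_id):
    """Return (sub-dataset, global index) for the first dataset containing the ID.

    Each element of `datasets` is assumed to expose `id_to_idx` (raising
    ValueError on a miss) and `__len__`, as in ExampleIDMixin.
    """
    offsets = [0, *accumulate(len(d) for d in datasets)]
    for dataset, offset in zip(datasets, offsets):
        try:
            # Local index within the sub-dataset, shifted by its offset.
            return dataset, offset + dataset.id_to_idx(example_id)
        except ValueError:
            continue
    raise ValueError(f"Example ID {example_id!r} not found in any dataset")


class _ToyDataset:
    def __init__(self, ids):
        self.ids = list(ids)

    def __len__(self):
        return len(self.ids)

    def id_to_idx(self, example_id):
        if example_id not in self.ids:
            raise ValueError(example_id)
        return self.ids.index(example_id)


a, b = _ToyDataset(["x", "y"]), _ToyDataset(["z"])
found, global_idx = find_dataset_and_global_idx([a, b], "z")
```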

datasets: list[ExampleIDMixin]#
get_dataset_by_id(example_id: str) Dataset[source]#

Retrieves the dataset containing the example ID.

Parameters:

example_id – The ID to find.

Returns:

The sub-dataset containing the ID.

Warning

Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned.

get_dataset_by_idx(idx: int) Dataset[source]#

Retrieves the dataset containing the index.

Parameters:

idx – The index to find.

Returns:

The sub-dataset containing the index.

Raises:

ValueError – If the index is out of bounds.

id_to_idx(example_id: str) int[source]#

Retrieves the index corresponding to the example ID.

Parameters:

example_id – The ID to convert.

Returns:

The corresponding index.

Raises:

ValueError – If the example ID is not found.

Warning

Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned.

idx_to_id(idx: int) str[source]#

Retrieves the example ID corresponding to the index.

Parameters:

idx – The index to convert.

Returns:

The corresponding example ID.

Raises:

ValueError – If the index is out of bounds.

class atomworks.ml.datasets.datasets.ExampleIDMixin[source]#

Bases: ABC

Mixin providing example ID functionality to a Dataset.

Provides methods for converting between example IDs and indices, and checking if an example ID exists in the dataset.
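A concrete implementation of this mixin's contract might be backed by a simple list of IDs; the sketch below is illustrative (the class name and storage scheme are assumptions), showing how both single values and lists are handled, as in the abstract signatures.

```python
class ListBackedIDMixin:
    """Illustrative ExampleIDMixin-style implementation backed by a list of IDs."""

    def __init__(self, ids: list[str]):
        self._ids = list(ids)
        # Pre-build the reverse mapping for O(1) ID -> index lookups.
        self._index_of = {example_id: i for i, example_id in enumerate(self._ids)}

    def id_to_idx(self, example_id):
        if isinstance(example_id, list):
            return [self._index_of[e] for e in example_id]
        return self._index_of[example_id]

    def idx_to_id(self, idx):
        if isinstance(idx, list):
            return [self._ids[i] for i in idx]
        return self._ids[idx]

    def __contains__(self, example_id: str) -> bool:
        # Supports checking if an example ID exists in the dataset.
        return example_id in self._index_of


mixin = ListBackedIDMixin(["7abc_1", "7xyz_1"])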

abstract id_to_idx(example_id: str | list[str]) int | list[int][source]#

Convert example ID(s) to index(es).

Parameters:

example_id – Single ID or list of IDs to convert.

Returns:

Corresponding index or list of indices.

abstract idx_to_id(idx: int | list[int]) str | list[str][source]#

Convert index(es) to example ID(s).

Parameters:

idx – Single index or list of indices to convert.

Returns:

Corresponding ID or list of IDs.

class atomworks.ml.datasets.datasets.FallbackDatasetWrapper(dataset: Dataset, fallback_dataset: Dataset)[source]#

Bases: Dataset

A wrapper around a dataset that allows for a fallback dataset to be used when an error occurs.

Meant to be used with a FallbackSamplerWrapper.
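The fallback pattern can be sketched as a try/except around item access. This is a simplified stand-in: the real wrapper coordinates with a FallbackSamplerWrapper, whereas the sketch below simply draws a random index from the fallback dataset.

```python
import random


class FallbackWrapperSketch:
    """Illustrative fallback pattern: on failure, serve a substitute example."""

    def __init__(self, dataset, fallback_dataset, seed: int = 0):
        self.dataset = dataset
        self.fallback_dataset = fallback_dataset
        self._rng = random.Random(seed)

    def __getitem__(self, idx):
        try:
            return self.dataset[idx]
        except Exception:
            # Primary example failed (e.g., during the Transform pipeline);
            # fall back to a randomly chosen example instead.
            fallback_idx = self._rng.randrange(len(self.fallback_dataset))
            return self.fallback_dataset[fallback_idx]


class _Flaky:
    def __getitem__(self, idx):
        raise RuntimeError("parse failure")


wrapped = FallbackWrapperSketch(_Flaky(), ["ok_0", "ok_1", "ok_2"])
result = wrapped[1]
```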

class atomworks.ml.datasets.datasets.FileDataset(*, file_paths: list[str | PathLike], name: str, filter_fn: Callable[[PathLike], bool] | None = None, **kwargs: Any)[source]#

Bases: MolecularDataset, ExampleIDMixin

Dataset that loads molecular data from individual files.

Each file represents one example in the dataset. If creating a dataset from a directory, use the from_directory() class method instead of the default constructor.
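The file-to-example mapping can be sketched as follows; note that using the file stem as the example ID is an assumption here for illustration, and the library's actual ID scheme may differ. The `filter_fn` mirrors the constructor parameter above.

```python
from pathlib import PurePath


def build_file_index(file_paths, filter_fn=None):
    """Map each file to one example, assuming (hypothetically) that the
    filename stem serves as the example ID."""
    paths = [PurePath(p) for p in file_paths]
    if filter_fn is not None:
        # Keep only files the predicate accepts, as with filter_fn above.
        paths = [p for p in paths if filter_fn(p)]
    # One file == one example; example_id here is the name without suffixes.
    return [{"example_id": p.name.split(".")[0], "path": str(p)} for p in paths]


examples = build_file_index(
    ["/data/1abc.cif.gz", "/data/2xyz.cif.gz", "/data/readme.txt"],
    filter_fn=lambda p: ".cif" in p.suffixes,
)
```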

classmethod from_directory(*, directory: PathLike, name: str, max_depth: int = 3, **kwargs: Any) FileDataset[source]#

Create a FileDataset by scanning a directory for files.

Parameters:
  • directory – Path to directory to scan for files.

  • name – Descriptive name for this dataset.

  • max_depth – Maximum depth to scan for files in subdirectories.

  • **kwargs – Additional arguments passed to FileDataset.

Returns:

FileDataset instance with files discovered from the directory.

Example

Create from directory:
>>> dataset = FileDataset.from_directory(directory="/path/to/files", name="my_dataset", max_depth=2)
classmethod from_file_list(*, file_paths: list[str | PathLike], name: str, **kwargs: Any) FileDataset[source]#

Create a FileDataset from an explicit list of file paths.

This is an alias for the main constructor for clarity and consistency with from_directory().

Parameters:
  • file_paths – List of file paths for the dataset. Each file represents one example.

  • name – Descriptive name for this dataset.

  • **kwargs – Additional arguments passed to FileDataset.

Returns:

FileDataset instance with the provided file paths.

id_to_idx(example_id: str | list[str]) int | list[int][source]#

Convert example ID(s) to index(es).

idx_to_id(idx: int | list[int]) str | list[str][source]#

Convert index(es) to example ID(s).

class atomworks.ml.datasets.datasets.MolecularDataset(*, name: str, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None)[source]#

Bases: Dataset

Base class for AtomWorks molecular datasets.

Handles Transform pipelines and loader functionality for molecular data. Subclasses implement __getitem__() with their own data access patterns.

class atomworks.ml.datasets.datasets.PandasDataset(*, data: DataFrame | PathLike, name: str, id_column: str | None = 'example_id', filters: list[str] | None = None, columns_to_load: list[str] | None = None, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None, load_kwargs: dict | tuple | None = None)[source]#

Bases: MolecularDataset, ExampleIDMixin

Dataset for tabular data stored as pandas DataFrames.

Inherits all functionality from MolecularDataset with additional DataFrame-specific features for filtering and ID-based access.

id_to_idx(example_id: str | list[str]) int | list[int][source]#

Convert an example ID to the corresponding local index.

idx_to_id(idx: int | list[int]) str | ndarray[source]#

Convert a local index to the corresponding example ID.

atomworks.ml.datasets.datasets.StructuralDatasetWrapper(dataset_parser: Callable, transform: Callable | None = None, dataset: PandasDataset | None = None, cif_parser_args: dict | None = None, save_failed_examples_to_dir: str | Path | None = None, **kwargs) PandasDataset[source]#

Backwards-compatible wrapper for the deprecated StructuralDatasetWrapper.

This function is deprecated and will be removed in a future version. Use PandasDataset with the appropriate loader function instead.

Parameters:
  • dataset_parser – The dataset parser to use (e.g., PNUnitsDFParser, InterfacesDFParser).

  • transform – Transform pipeline to apply to loaded data.

  • dataset – The underlying PandasDataset containing the tabular data.

  • cif_parser_args – Arguments to pass to the CIF parser.

  • save_failed_examples_to_dir – Directory to save failed examples for debugging.

  • **kwargs – Additional arguments passed to PandasDataset.

Returns:

PandasDataset instance configured with the deprecated parameters.

Raises:

ValueError – If dataset parameter is required but not provided.

atomworks.ml.datasets.datasets.get_row_and_index_by_example_id(dataset: ExampleIDMixin, example_id: str) dict[source]#

Retrieve a row and its index from a nested dataset structure by its example ID.

Parameters:
  • dataset – The dataset or concatenated dataset to search. Must have the id_to_idx method.

  • example_id – The example ID to search for.

Returns:

Dictionary containing the row and global index corresponding to the example ID.

Functional Loaders#

Functional loader implementations for AtomWorks datasets.

Loaders are functions that process raw dataset output (e.g., a pandas Series) into a Transform-ready format. For example, a loader converts dataset-specific metadata into the standard format expected by AtomWorks Transform pipelines.
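A minimal loader in this style might look like the following sketch; the row fields and output keys are illustrative, patterned on the create_base_loader parameters documented below (a plain dict stands in for the pandas Series).

```python
def make_example_loader(base_path="", extension=""):
    """Return a loader that standardizes a row into a Transform-ready dict.

    Illustrative only: field names mirror create_base_loader's defaults.
    """

    def loader(row):
        path = row["path"]
        if base_path:
            # Prepend the base path when the path column holds relative paths.
            path = f"{base_path.rstrip('/')}/{path}"
        if extension and not path.endswith(extension):
            # Add the extension when the path column omits it.
            path += extension
        return {
            "example_id": row["example_id"],
            "path": path,
            "assembly_id": row.get("assembly_id", "1"),  # "1" default, per the docs
        }

    return loader


loader = make_example_loader(base_path="/data/structures", extension=".cif.gz")
inputs = loader({"example_id": "7abc", "path": "7abc"})
```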

atomworks.ml.datasets.loaders.create_base_loader(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', attrs: dict | None = None, base_path: str = '', extension: str = '', sharding_pattern: str | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]][source]#

Factory function that creates a base loader with common logic for many AtomWorks datasets.

Parameters:
  • example_id_colname – Name of column containing unique example identifiers

  • path_colname – Name of column containing paths to structure files

  • assembly_id_colname – Optional column name containing assembly IDs. If None, assembly_id defaults to “1” for all examples.

  • attrs – Additional attributes to merge with highest precedence into the metadata hierarchy (and ultimately included in the output dictionary’s “extra_info” key).

  • base_path – Base path to prepend to file paths if not included in path column

  • extension – File extension to add/replace if not included in path column

  • sharding_pattern – Pattern for how files are organized in subdirectories, if not specified in the path:
    - “/1:2/”: Use characters 1-2 for first directory level
    - “/1:2/0:2/”: Use chars 1-2 for first dir, then chars 0-2 for second dir
    - None: No sharding (default)

  • parser_args – Optional dictionary of arguments to pass to the CIF parser when loading the structure file.

Returns:

A function that takes a pandas Series and returns a dictionary of the loaded structure.

atomworks.ml.datasets.loaders.create_loader_with_interfaces_and_pn_units_to_score(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', interfaces_to_score_colname: str | None = 'interfaces_to_score', pn_units_to_score_colname: str | None = 'pn_units_to_score', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]][source]#

Factory function that creates a loader that adds interfaces and pn_units to score for validation datasets.

Example

>>> loader = create_loader_with_interfaces_and_pn_units_to_score(
...     interfaces_to_score_colname="interfaces_to_score", pn_units_to_score_colname="pn_units_to_score"
... )
atomworks.ml.datasets.loaders.create_loader_with_query_pn_units(example_id_colname: str = 'example_id', path_colname: str = 'path', pn_unit_iid_colnames: str | list[str] | None = None, assembly_id_colname: str | None = 'assembly_id', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) Callable[[Series], dict[str, Any]][source]#

Factory function that creates a generic loader for pipelines with query pn_units (chains).

For instance, in the interfaces dataset, each sampled row contains two pn_unit instance IDs that should be included in the cropped structure.

Examples

Interfaces dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], assembly_id_colname="assembly_id"
... )
Chains dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames="pn_unit_iid", base_path="/data/structures", extension=".cif.gz"
... )

Dataset Architecture and Migration Guide#