Dataset Architecture

AtomWorks provides a modern, composable dataset architecture that separates data loading, processing, and transformation concerns. This approach replaces the legacy parser-based system with functional loaders and transform pipelines.

Warning

The metadata parser system (atomworks.ml.datasets.parsers) is deprecated and will be removed in a future version. Use the new loader-based approach with FileDataset and PandasDataset instead.

Modern Dataset Architecture

The current AtomWorks dataset system consists of three main components:

  1. Datasets: Container classes that manage data access and indexing

  2. Loaders: Functions that process raw data into transform-ready format

  3. Transforms: Pipelines that convert loaded data into model inputs

Dataset Classes

AtomWorks Dataset classes and common APIs.

At a high level, to train models with AtomWorks, we need a Dataset class that:
  1. Takes an item index as input and returns the corresponding example information, which typically includes: (a) the path to a structural file saved on disk (e.g., /path/to/dataset/my_dataset_0.cif) and (b) additional item-specific metadata (e.g., class labels)

  2. Pre-loads structural information from the returned example into an AtomArray and assembles inputs for the Transform pipeline

  3. Feeds the input dictionary through a Transform pipeline and returns the result

Due to the heterogeneity of biomolecular data, in many cases we also want the Dataset to:
  1. Fall back to a different example in the event of a failure during the Transform pipeline

For bespoke use cases, users may choose to write a custom Dataset that accomplishes these steps, as in the sketch below; downstream code makes no assumptions about the implementation.
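
A minimal custom Dataset along these lines might look like the following sketch (the class, file paths, and example IDs are hypothetical; the parse() call mirrors the usage examples later on this page):

from pathlib import Path

from atomworks.io import parse
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    """Toy Dataset illustrating the three steps above."""

    def __init__(self, file_paths: list[Path], transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> dict:
        # (1) Look up the example information for this index
        path = self.file_paths[idx]
        # (2) Pre-load the structure into an AtomArray and assemble Transform inputs
        parse_output = parse(path)
        data = {"atom_array": parse_output["assemblies"]["1"][0], "example_id": path.stem}
        # (3) Feed the input dictionary through the Transform pipeline
        return self.transform(data) if self.transform is not None else data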

To accelerate development, we also provide an off-the-shelf, composable approach following common patterns:
  • MolecularDataset: Base class that handles pre-loading structural data and executing the Transform pipeline with error handling and debugging utilities

  • PandasDataset: A subclass of MolecularDataset for tabular data stored as pandas DataFrames

  • FileDataset: A subclass of MolecularDataset where each file is one example

class atomworks.ml.datasets.datasets.ConcatDatasetWithID(datasets: list[ExampleIDMixin])

Bases: ConcatDataset

Equivalent to torch.utils.data.ConcatDataset but allows accessing examples by ID.

Provides ID-based access across multiple datasets that implement ExampleIDMixin.
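
Example

A sketch, assuming two sub-datasets that implement ExampleIDMixin (e.g., two FileDataset instances) and a hypothetical example ID:

>>> combined = ConcatDatasetWithID([dataset_a, dataset_b])
>>> idx = combined.id_to_idx("some_example_id")  # global index for the ID
>>> combined.idx_to_id(idx)
'some_example_id'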

cumulative_sizes: list[int]
datasets: list[ExampleIDMixin]
get_dataset_by_id(example_id: str) → Dataset

Retrieves the dataset containing the example ID.

Parameters:

example_id – The ID to find.

Returns:

The sub-dataset containing the ID.

Warning

Assumes that the example ID is unique within the dataset. If not, the dataset containing the first occurrence is returned.

get_dataset_by_idx(idx: int) → Dataset

Retrieves the dataset containing the index.

Parameters:

idx – The index to find.

Returns:

The sub-dataset containing the index.

Raises:

ValueError – If the index is out of bounds.

id_to_idx(example_id: str) → int

Retrieves the index corresponding to the example ID.

Parameters:

example_id – The ID to convert.

Returns:

The corresponding index.

Raises:

ValueError – If the example ID is not found.

Warning

Assumes that the example ID is unique within the dataset. If not, the index of the first occurrence is returned.

idx_to_id(idx: int) → str

Retrieves the example ID corresponding to the index.

Parameters:

idx – The index to convert.

Returns:

The corresponding example ID.

Raises:

ValueError – If the index is out of bounds.

class atomworks.ml.datasets.datasets.ExampleIDMixin

Bases: ABC

Mixin providing example ID functionality to a Dataset.

Provides methods for converting between example IDs and indices, and checking if an example ID exists in the dataset.
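
A minimal concrete implementation might look like the sketch below (illustration only; it assumes id_to_idx and idx_to_id are the only abstract methods — in practice, the provided FileDataset and PandasDataset implementations cover most needs):

from atomworks.ml.datasets.datasets import ExampleIDMixin

class ListBackedIDs(ExampleIDMixin):
    """Toy ID registry backed by an in-memory list."""

    def __init__(self, ids: list[str]):
        self._ids = list(ids)

    def id_to_idx(self, example_id: str | list[str]) -> int | list[int]:
        # Return the position(s) of the given ID(s) in the list
        if isinstance(example_id, list):
            return [self._ids.index(i) for i in example_id]
        return self._ids.index(example_id)

    def idx_to_id(self, idx: int | list[int]) -> str | list[str]:
        # Return the ID(s) stored at the given position(s)
        if isinstance(idx, list):
            return [self._ids[i] for i in idx]
        return self._ids[idx]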

abstract id_to_idx(example_id: str | list[str]) → int | list[int]

Convert example ID(s) to index(es).

Parameters:

example_id – Single ID or list of IDs to convert.

Returns:

Corresponding index or list of indices.

abstract idx_to_id(idx: int | list[int]) → str | list[str]

Convert index(es) to example ID(s).

Parameters:

idx – Single index or list of indices to convert.

Returns:

Corresponding ID or list of IDs.

class atomworks.ml.datasets.datasets.FallbackDatasetWrapper(dataset: Dataset, fallback_dataset: Dataset)

Bases: Dataset

A wrapper around a dataset that allows for a fallback dataset to be used when an error occurs.

Meant to be used with a FallbackSamplerWrapper.
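
Example

A sketch; primary and fallback are hypothetical Dataset instances (e.g., two MolecularDataset subclasses over different data sources):

>>> wrapped = FallbackDatasetWrapper(dataset=primary, fallback_dataset=fallback)
>>> # Typically paired with a FallbackSamplerWrapper during training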

class atomworks.ml.datasets.datasets.FileDataset(*, file_paths: list[str | PathLike], name: str, filter_fn: Callable[[PathLike], bool] | None = None, **kwargs: Any)

Bases: MolecularDataset, ExampleIDMixin

Dataset that loads molecular data from individual files.

Each file represents one example in the dataset. If creating a dataset from a directory, use the from_directory() class method instead of the default constructor.

classmethod from_directory(*, directory: PathLike, name: str, max_depth: int = 3, **kwargs: Any) → FileDataset

Create a FileDataset by scanning a directory for files.

Parameters:
  • directory – Path to directory to scan for files.

  • name – Descriptive name for this dataset.

  • max_depth – Maximum depth to scan for files in subdirectories.

  • **kwargs – Additional arguments passed to FileDataset.

Returns:

FileDataset instance with files discovered from the directory.

Example

Create from directory:
>>> dataset = FileDataset.from_directory(directory="/path/to/files", name="my_dataset", max_depth=2)

classmethod from_file_list(*, file_paths: list[str | PathLike], name: str, **kwargs: Any) → FileDataset

Create a FileDataset from an explicit list of file paths.

This is an alias for the main constructor for clarity and consistency with from_directory().

Parameters:
  • file_paths – List of file paths for the dataset. Each file represents one example.

  • name – Descriptive name for this dataset.

  • **kwargs – Additional arguments passed to FileDataset.

Returns:

FileDataset instance with the provided file paths.

id_to_idx(example_id: str | list[str]) → int | list[int]

Convert example ID(s) to index(es).

idx_to_id(idx: int | list[int]) → str | list[str]

Convert index(es) to example ID(s).

class atomworks.ml.datasets.datasets.MolecularDataset(*, name: str, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None)

Bases: Dataset

Base class for AtomWorks molecular datasets.

Handles Transform pipelines and loader functionality for molecular data. Subclasses implement __getitem__() with their own data access patterns.

class atomworks.ml.datasets.datasets.PandasDataset(*, data: DataFrame | PathLike, name: str, id_column: str | None = 'example_id', filters: list[str] | None = None, columns_to_load: list[str] | None = None, transform: Callable | None = None, loader: Callable | None = None, save_failed_examples_to_dir: str | Path | None = None, load_kwargs: dict | tuple | None = None)

Bases: MolecularDataset, ExampleIDMixin

Dataset for tabular data stored as pandas DataFrames.

Inherits all functionality from MolecularDataset with additional DataFrame-specific features for filtering and ID-based access.
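
Example

A sketch, assuming a parquet file whose example_id column (the default id_column) contains the hypothetical ID shown:

>>> dataset = PandasDataset(data="metadata.parquet", name="my_dataset")
>>> idx = dataset.id_to_idx("some_example_id")  # ID-based access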

id_to_idx(example_id: str | list[str]) → int | list[int]

Convert an example ID to the corresponding local index.

idx_to_id(idx: int | list[int]) → str | ndarray

Convert a local index to the corresponding example ID.

atomworks.ml.datasets.datasets.StructuralDatasetWrapper(dataset_parser: Callable, transform: Callable | None = None, dataset: PandasDataset | None = None, cif_parser_args: dict | None = None, save_failed_examples_to_dir: str | Path | None = None, **kwargs) → PandasDataset

Backwards-compatible replacement for the deprecated StructuralDatasetWrapper class.

This function is deprecated and will be removed in a future version. Use PandasDataset with the appropriate loader function instead.

Parameters:
  • dataset_parser – The dataset parser to use (e.g., PNUnitsDFParser, InterfacesDFParser).

  • transform – Transform pipeline to apply to loaded data.

  • dataset – The underlying PandasDataset containing the tabular data.

  • cif_parser_args – Arguments to pass to the CIF parser.

  • save_failed_examples_to_dir – Directory to save failed examples for debugging.

  • **kwargs – Additional arguments passed to PandasDataset.

Returns:

PandasDataset instance configured with the deprecated parameters.

Raises:

ValueError – If the required dataset parameter is not provided.
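
Example

A migration sketch (the file layout and column names are assumptions), replacing the deprecated wrapper with PandasDataset plus a loader factory:

>>> from atomworks.ml.datasets.datasets import PandasDataset
>>> from atomworks.ml.datasets.loaders import create_base_loader
>>> dataset = PandasDataset(
...     data="metadata.parquet",
...     name="migrated_dataset",
...     loader=create_base_loader(base_path="/data/structures", extension=".cif"),
... )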

atomworks.ml.datasets.datasets.get_row_and_index_by_example_id(dataset: ExampleIDMixin, example_id: str) → dict

Retrieve a row and its index from a nested dataset structure by its example ID.

Parameters:
  • dataset – The dataset or concatenated dataset to search. Must have the id_to_idx method.

  • example_id – The example ID to search for.

Returns:

Dictionary containing the row and global index corresponding to the example ID.
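
Example

A sketch with a hypothetical concatenated dataset and example ID; the returned dictionary contains the matching row and its global index, as described above:

>>> result = get_row_and_index_by_example_id(dataset, "some_example_id")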

Functional Loaders

Loaders are functions that process raw dataset output (e.g., a pandas Series) into a Transform-ready format; that is, they convert potentially dataset-specific metadata into the standard format expected by AtomWorks Transform pipelines. They replace the legacy parser classes with a more flexible, functional approach.

atomworks.ml.datasets.loaders.create_base_loader(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', attrs: dict | None = None, base_path: str = '', extension: str = '', sharding_pattern: str | None = None, parser_args: dict | None = None) → Callable[[Series], dict[str, Any]]

Factory function that creates a base loader with common logic for many AtomWorks datasets.

Parameters:
  • example_id_colname – Name of column containing unique example identifiers

  • path_colname – Name of column containing paths to structure files

  • assembly_id_colname – Optional column name containing assembly IDs. If None, assembly_id defaults to “1” for all examples.

  • attrs – Additional attributes to merge with highest precedence into the metadata hierarchy (and ultimately included in the output dictionary’s “extra_info” key).

  • base_path – Base path to prepend to file paths if not included in path column

  • extension – File extension to add/replace if not included in path column

  • sharding_pattern – Pattern for how files are organized in subdirectories, if not specified in the path: "/1:2/" uses characters 1-2 for the first directory level; "/1:2/0:2/" uses characters 1-2 for the first directory, then characters 0-2 for the second; None means no sharding (default)

  • parser_args – Optional dictionary of arguments to pass to the CIF parser when loading the structure file.

Returns:

A function that takes a pandas Series and returns a dictionary of the loaded structure.
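
Example

A sketch, assuming a metadata row with the default column names (example_id, path, assembly_id) and files sharded by characters 1-2 of the path value:

>>> loader = create_base_loader(
...     base_path="/data/structures",
...     extension=".cif.gz",
...     sharding_pattern="/1:2/",
... )
>>> inputs = loader(row)  # row: a pandas Series with the expected columns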

atomworks.ml.datasets.loaders.create_loader_with_interfaces_and_pn_units_to_score(example_id_colname: str = 'example_id', path_colname: str = 'path', assembly_id_colname: str | None = 'assembly_id', interfaces_to_score_colname: str | None = 'interfaces_to_score', pn_units_to_score_colname: str | None = 'pn_units_to_score', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) → Callable[[Series], dict[str, Any]]

Factory function that creates a loader that adds interfaces and pn_units to score for validation datasets.

Example

>>> loader = create_loader_with_interfaces_and_pn_units_to_score(
...     interfaces_to_score_colname="interfaces_to_score", pn_units_to_score_colname="pn_units_to_score"
... )

atomworks.ml.datasets.loaders.create_loader_with_query_pn_units(example_id_colname: str = 'example_id', path_colname: str = 'path', pn_unit_iid_colnames: str | list[str] | None = None, assembly_id_colname: str | None = 'assembly_id', base_path: str = '', extension: str = '', sharding_pattern: str | None = None, attrs: dict | None = None, parser_args: dict | None = None) → Callable[[Series], dict[str, Any]]

Factory function that creates a generic loader for pipelines with query pn_units (chains).

For instance, in the interfaces dataset, each sampled row contains two pn_unit instance IDs that should be included in the cropped structure.

Examples

Interfaces dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], assembly_id_colname="assembly_id"
... )
Chains dataset:
>>> loader = create_loader_with_query_pn_units(
...     pn_unit_iid_colnames="pn_unit_iid", base_path="/data/structures", extension=".cif.gz"
... )

Basic Usage Examples

File-based datasets (replacing simple file parsers):

from atomworks.ml.datasets.datasets import FileDataset
from atomworks.io import parse

def simple_loading_fn(raw_data) -> dict:
    """Parse structural data from a file path (each FileDataset example is one file)."""
    parse_output = parse(raw_data)
    return {"atom_array": parse_output["assemblies"]["1"][0]}

dataset = FileDataset.from_directory(
    directory="/path/to/structures",
    name="my_dataset",
    loader=simple_loading_fn
)

Tabular datasets (replacing metadata parsers):

from atomworks.ml.datasets.datasets import PandasDataset
from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units

dataset = PandasDataset(
    data="metadata.parquet",
    name="interfaces_dataset",
    loader=create_loader_with_query_pn_units(
        pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]
    )
)

Custom loaders for specialized use cases:

import pandas as pd
from pathlib import Path

from atomworks.io import parse

def custom_loader(row: pd.Series) -> dict:
    """Custom loader with specific processing logic."""
    # Load structure
    structure_path = Path(row["path"])
    parse_output = parse(structure_path)

    # Extract specific metadata
    metadata = {
        "resolution": row.get("resolution", None),
        "method": row.get("method", "unknown"),
        "custom_field": row.get("custom_field", "default_value")
    }

    return {
        "atom_array": parse_output["assemblies"]["1"][0],
        "extra_info": metadata,
        "example_id": row["example_id"]
    }

dataset = PandasDataset(
    data=my_dataframe,
    name="custom_dataset",
    loader=custom_loader
)

Common Loader Patterns

Base loader for standard structure loading:

from atomworks.ml.datasets.loaders import create_base_loader

loader = create_base_loader(
    example_id_colname="example_id",
    path_colname="path",
    assembly_id_colname="assembly_id",
    base_path="/data/structures",
    extension=".cif"
)

Interface loader for protein-protein interfaces:

from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units

loader = create_loader_with_query_pn_units(
    pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"],
    base_path="/data/pdb",
    extension=".cif.gz"
)

Validation loader with scoring targets:

from atomworks.ml.datasets.loaders import create_loader_with_interfaces_and_pn_units_to_score

loader = create_loader_with_interfaces_and_pn_units_to_score(
    interfaces_to_score_colname="interfaces_to_score",
    pn_units_to_score_colname="pn_units_to_score"
)

Integration with Transform Pipelines

Loaders work seamlessly with AtomWorks transform pipelines. The loader output becomes the input to the transform pipeline:

from atomworks.ml.datasets.datasets import PandasDataset
from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units
from atomworks.ml.transforms.base import Compose
from atomworks.ml.transforms.crop import CropSpatialLikeAF3
from atomworks.ml.transforms.atom_array import AddGlobalAtomIdAnnotation

# Create a transform pipeline
transform_pipeline = Compose([
    AddGlobalAtomIdAnnotation(),
    CropSpatialLikeAF3(crop_size=256),
])

# Create dataset with both loader and transforms
dataset = PandasDataset(
    data="metadata.parquet",
    name="my_dataset",
    loader=create_loader_with_query_pn_units(
        pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]
    ),
    transform=transform_pipeline
)

# Access processed data
example = dataset[0]  # Returns transformed data ready for model input

Data Flow

The complete data flow in the new architecture is:

  1. Raw Data: File paths or DataFrame rows

  2. Loader: Processes raw data into standardized format with AtomArray

  3. Transform Pipeline: Converts loaded data into model-ready tensors

  4. Model Input: Final processed data ready for training/inference
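
Putting the stages together (a sketch assuming a metadata parquet with the default column names):

from atomworks.ml.datasets.datasets import PandasDataset
from atomworks.ml.datasets.loaders import create_base_loader
from atomworks.ml.transforms.base import Compose
from atomworks.ml.transforms.atom_array import AddGlobalAtomIdAnnotation

dataset = PandasDataset(
    data="metadata.parquet",                           # 1. Raw data (DataFrame rows)
    name="demo_dataset",
    loader=create_base_loader(extension=".cif"),       # 2. Loader -> dict with AtomArray
    transform=Compose([AddGlobalAtomIdAnnotation()]),  # 3. Transform pipeline
)
example = dataset[0]                                   # 4. Model-ready input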

This separation allows for:
  • Reusable loaders across different datasets

  • Composable transforms that can be mixed and matched

  • Easy testing of individual components

  • Clear debugging when issues arise