
SCRNA API References

Trainer

protoplast.scrna.anndata.trainer.RayTrainRunner

A class that initializes training. It automatically starts a Ray cluster, or detects an existing cluster and connects to it (see ray.init() behavior).

Parameters:

Name Type Description Default
Model type[LightningModule]

PyTorch Lightning model class

required
Ds type[DistributedAnnDataset]

DistributedAnnDataset class

required
model_keys list[str]

Keys selected from the metadata produced by metadata_cb and passed to the model constructor

required
metadata_cb Callable[[AnnData, dict], None]

Callback that mutates the metadata dictionary; recommended for passing data from obs or var, or any additional data your model requires, by default cell_line_metadata_cb

cell_line_metadata_cb
before_dense_cb Callable[[Tensor, str | int], Tensor]

Callback applied before densification of the sparse matrix, while the data is still a sparse CSR tensor, by default None

None
after_dense_cb Callable[[Tensor, str | int], Tensor]

Callback applied after densification of the sparse matrix, when the data is a dense tensor, by default None

None
shuffle_strategy ShuffleStrategy

Strategy for splitting and shuffling the data during training, by default SequentialShuffleStrategy

SequentialShuffleStrategy
runtime_env_config dict | None

Runtime environment config passed to the Ray Train worker processes, by default None

None
address str | None

Override ray address, by default None

None
ray_trainer_strategy Strategy | None

Override the Ray Train strategy; if None, RayDDPStrategy is used, by default None

None
sparse_key str

AnnData key for the sparse matrix, by default "X"

'X'

Returns:

Type Description
RayTrainRunner

A runner instance; call train() to start the training
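A minimal construction sketch is shown below; MyLitModel, my_metadata_cb and num_genes are hypothetical placeholders standing in for your own LightningModule, metadata callback and model argument, not part of the library.

from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset

def my_metadata_cb(ad, metadata):
    # copy whatever the model constructor needs out of the AnnData object
    metadata["num_genes"] = ad.n_vars

runner = RayTrainRunner(
    Model=MyLitModel,              # your pl.LightningModule subclass
    Ds=DistributedAnnDataset,
    model_keys=["num_genes"],      # forwarded from the metadata to MyLitModel(num_genes=...)
    metadata_cb=my_metadata_cb,
)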

Source code in src/protoplast/scrna/anndata/trainer.py
class RayTrainRunner:
    """A class to initialize the training this class automatically initializes Ray cluster or
    detect whether an existing cluster exist if there is an existing cluster it will automatically
    connect to it refer to `ray.init()` behavior

    Parameters
    ----------
    Model : type[pl.LightningModule]
        PyTorch Lightning model class
    Ds : type[DistributedAnnDataset]
        DistributedAnnDataset class
    model_keys : list[str]
        Keys selected from the metadata produced by `metadata_cb` and passed to the model constructor
    metadata_cb : Callable[[anndata.AnnData, dict], None], optional
        Callback that mutates the metadata dictionary; recommended for passing data from `obs` or `var`,
        or any additional data your model requires, by default cell_line_metadata_cb
    before_dense_cb : Callable[[torch.Tensor, str | int], torch.Tensor], optional
        Callback applied before densification of the sparse matrix, while the data at this point
        is still a sparse CSR tensor, by default None
    after_dense_cb : Callable[[torch.Tensor, str | int], torch.Tensor], optional
        Callback applied after densification of the sparse matrix, when the data at this point
        is a dense tensor, by default None
    shuffle_strategy : ShuffleStrategy, optional
        Strategy for splitting and shuffling the data during training, by default SequentialShuffleStrategy
    runtime_env_config : dict | None, optional
        Runtime environment config passed to the Ray Train worker processes, by default None
    address : str | None, optional
        Override the Ray cluster address, by default None
    ray_trainer_strategy : Strategy | None, optional
        Override the Ray Train strategy; if None, RayDDPStrategy is used, by default None
    sparse_key : str, optional
        AnnData key for the sparse matrix, by default "X"

    Returns
    -------
    RayTrainRunner
        A runner instance; call `train()` to start the training

    """

    @beartype
    def __init__(
        self,
        Model: type[pl.LightningModule],
        Ds: type[DistributedAnnDataset],
        model_keys: list[str],
        metadata_cb: Callable[[anndata.AnnData, dict], None] = cell_line_metadata_cb,
        before_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        after_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        shuffle_strategy: ShuffleStrategy = SequentialShuffleStrategy,
        runtime_env_config: dict | None = None,
        address: str | None = None,
        ray_trainer_strategy: Strategy | None = None,
        sparse_key: str = "X",
    ):
        self.Model = Model
        self.Ds = Ds
        self.model_keys = model_keys
        self.metadata_cb = metadata_cb
        self.shuffle_strategy = shuffle_strategy
        self.sparse_key = sparse_key
        self.before_dense_cb = before_dense_cb
        self.after_dense_cb = after_dense_cb
        if not ray_trainer_strategy:
            self.ray_trainer_strategy = ray.train.lightning.RayDDPStrategy()
        else:
            self.ray_trainer_strategy = ray_trainer_strategy

        # Init ray cluster
        DEFAULT_RUNTIME_ENV_CONFIG = {
            "working_dir": os.getenv("PWD"),  # Allow ray workers to inherit venv at $PWD if there is any
        }
        if runtime_env_config is None:
            runtime_env_config = DEFAULT_RUNTIME_ENV_CONFIG
        ray.init(
            address=address, runtime_env={**DEFAULT_RUNTIME_ENV_CONFIG, **runtime_env_config}, ignore_reinit_error=True
        )

        self.resources = ray.cluster_resources()

    @beartype
    def train(
        self,
        file_paths: list[str],
        batch_size: int = 2000,
        test_size: float = 0.0,
        val_size: float = 0.2,
        prefetch_factor: int = 4,
        max_epochs: int = 1,
        thread_per_worker: int | None = None,
        num_workers: int | None = None,
        result_storage_path: str = "~/protoplast_results",
        # read more here: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
        ckpt_path: str | None = None,
        is_gpu: bool = True,
        random_seed: int | None = 42,
        resource_per_worker: dict | None = None,
        is_shuffled: bool = False,
        **kwargs,
    ):
        """Start the training

        Parameters
        ----------
        file_paths : list[str]
            List of h5ad AnnData files
        batch_size : int, optional
            Number of rows to fetch from disk per batch, by default 2000
        test_size : float, optional
            Fraction of data held out for testing; for example, 0.1 means 10% is split
            for testing, by default 0.0
        val_size : float, optional
            Fraction of data held out for validation; for example, 0.2 means 20% is split
            for validation, by default 0.2
        prefetch_factor : int, optional
            Dataloader prefetch factor; the total data prefetched is prefetch_factor * batch_size, by default 4
        max_epochs : int, optional
            Number of epochs to train for, by default 1
        thread_per_worker : int | None, optional
            Number of dataloader workers (CPU threads) per Ray worker, by default None
        num_workers : int | None, optional
            Override the number of Ray worker processes, which defaults to the number of GPUs
            in the cluster, by default None
        result_storage_path : str, optional
            Path where losses, validation metrics and checkpoints are stored, by default "~/protoplast_results"
        ckpt_path : str | None, optional
            Path to a checkpoint; if specified, training resumes from that checkpoint, otherwise
            training starts from scratch, by default None
        is_gpu : bool, optional
            Whether to train on GPUs; turn this off if your system does not have any GPUs, by default True
        random_seed : int | None, optional
            Random seed for the data split; set this to None for real training, or keep a fixed seed
            for benchmarking and result replication, by default 42
        resource_per_worker : dict | None, optional
            Passed to Ray; specifies how many CPUs or GPUs each Ray worker process gets, by default None
        is_shuffled : bool, optional
            Whether to shuffle the data, by default False

        Returns
        -------
        Result
            The training result from RayTrainer
        """
        self.result_storage_path = result_storage_path
        self.prefetch_factor = prefetch_factor
        self.max_epochs = max_epochs
        self.kwargs = kwargs
        if not resource_per_worker:
            if not thread_per_worker:
                print("Setting thread_per_worker to half of the available CPUs capped at 4")
                thread_per_worker = min(int((self.resources.get("CPU", 2) - 1) / 2), 4)
            resource_per_worker = {"CPU": thread_per_worker}
        if is_gpu and self.resources.get("GPU", 0) == 0:
            warnings.warn("`is_gpu = True` but there is no GPU found. Fallback to CPU.", UserWarning, stacklevel=2)
            is_gpu = False
        if is_gpu:
            if num_workers is None:
                num_workers = int(self.resources.get("GPU"))
            scaling_config = ray.train.ScalingConfig(
                num_workers=num_workers, use_gpu=True, resources_per_worker=resource_per_worker
            )
        else:
            if num_workers is None:
                num_workers = max(int((self.resources.get("CPU", 2) - 1) / thread_per_worker), 1)
            scaling_config = ray.train.ScalingConfig(
                num_workers=num_workers, use_gpu=False, resources_per_worker=resource_per_worker
            )
        print(f"Using {num_workers} workers with {resource_per_worker} each")
        start = time.time()
        shuffle_strategy = self.shuffle_strategy(
            file_paths,
            batch_size,
            num_workers * thread_per_worker,
            test_size,
            val_size,
            random_seed,
            metadata_cb=self.metadata_cb,
            is_shuffled=is_shuffled,
            **kwargs,
        )
        kwargs.pop("drop_last", None)
        kwargs.pop("pre_fetch_then_batch", None)
        indices = shuffle_strategy.split()
        print(f"Data splitting time: {time.time() - start:.2f} seconds")
        train_config = {"indices": indices, "ckpt_path": ckpt_path, "shuffle_strategy": shuffle_strategy}
        my_train_func = self._trainer()
        par_trainer = ray.train.torch.TorchTrainer(
            my_train_func,
            scaling_config=scaling_config,
            train_loop_config=train_config,
            run_config=ray.train.RunConfig(storage_path=self.result_storage_path),
        )
        print("Spawning Ray worker and initiating distributed training")
        return par_trainer.fit()

    def _trainer(self):
        Model, Ds, model_keys = self.Model, self.Ds, self.model_keys

        def anndata_train_func(config):
            ctx = ray.train.get_context()
            if ctx:
                rank = ctx.get_world_rank()
            else:
                rank = 0
            indices = config.get("indices")
            ckpt_path = config.get("ckpt_path")
            num_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count()))
            print(f"=========Starting the training on {rank} with num threads: {num_threads}=========")
            model_params = indices.metadata
            shuffle_strategy = config.get("shuffle_strategy")
            ann_dm = AnnDataModule(
                indices,
                Ds,
                self.prefetch_factor,
                self.sparse_key,
                shuffle_strategy,
                self.before_dense_cb,
                self.after_dense_cb,
                **self.kwargs,
            )
            if model_keys:
                model_params = {k: v for k, v in model_params.items() if k in model_keys}
            model = Model(**model_params)
            trainer = pl.Trainer(
                max_epochs=self.max_epochs,
                devices="auto",
                accelerator="auto",
                strategy=self.ray_trainer_strategy,
                plugins=[ray.train.lightning.RayLightningEnvironment()],
                callbacks=[ray.train.lightning.RayTrainReportCallback()],
                enable_checkpointing=True,
            )
            trainer = ray.train.lightning.prepare_trainer(trainer)
            trainer.fit(model, datamodule=ann_dm, ckpt_path=ckpt_path)

        return anndata_train_func

train(file_paths: list[str], batch_size: int = 2000, test_size: float = 0.0, val_size: float = 0.2, prefetch_factor: int = 4, max_epochs: int = 1, thread_per_worker: int | None = None, num_workers: int | None = None, result_storage_path: str = '~/protoplast_results', ckpt_path: str | None = None, is_gpu: bool = True, random_seed: int | None = 42, resource_per_worker: dict | None = None, is_shuffled: bool = False, **kwargs)

Start the training

Parameters:

Name Type Description Default
file_paths list[str]

List of h5ad AnnData files

required
batch_size int

Number of rows to fetch from disk per batch, by default 2000

2000
test_size float

Fraction of data held out for testing; for example, 0.1 means 10% is split for testing, by default 0.0

0.0
val_size float

Fraction of data held out for validation; for example, 0.2 means 20% is split for validation, by default 0.2

0.2
prefetch_factor int

Dataloader prefetch factor; the total data prefetched is prefetch_factor * batch_size, by default 4

4
max_epochs int

Number of epochs to train for, by default 1

1
thread_per_worker int | None

Number of dataloader workers (CPU threads) per Ray worker, by default None

None
num_workers int | None

Override the number of Ray worker processes, which defaults to the number of GPUs in the cluster, by default None

None
result_storage_path str

Path where losses, validation metrics and checkpoints are stored, by default "~/protoplast_results"

'~/protoplast_results'
ckpt_path str | None

Path to a checkpoint; if specified, training resumes from that checkpoint, otherwise training starts from scratch, by default None

None
is_gpu bool

Whether to train on GPUs; turn this off if your system does not have any GPUs, by default True

True
random_seed int | None

Random seed for the data split; set this to None for real training, or keep a fixed seed for benchmarking and result replication, by default 42

42
resource_per_worker dict | None

Passed to Ray; specifies how many CPUs or GPUs each Ray worker process gets, by default None

None
is_shuffled bool

Whether to shuffle the data, by default False

False

Returns:

Type Description
Result

The training result from RayTrainer
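A hedged invocation sketch, assuming a runner constructed as in the example above; the file names and epoch count are placeholders.

result = runner.train(
    file_paths=["plate1.h5ad", "plate2.h5ad"],
    batch_size=2000,
    val_size=0.2,
    max_epochs=5,
    result_storage_path="~/protoplast_results",
)
# result is the Ray Train Result; if no GPU is detected, training falls back to CPU with a warning.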

Source code in src/protoplast/scrna/anndata/trainer.py
@beartype
def train(
    self,
    file_paths: list[str],
    batch_size: int = 2000,
    test_size: float = 0.0,
    val_size: float = 0.2,
    prefetch_factor: int = 4,
    max_epochs: int = 1,
    thread_per_worker: int | None = None,
    num_workers: int | None = None,
    result_storage_path: str = "~/protoplast_results",
    # read more here: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
    ckpt_path: str | None = None,
    is_gpu: bool = True,
    random_seed: int | None = 42,
    resource_per_worker: dict | None = None,
    is_shuffled: bool = False,
    **kwargs,
):
    """Start the training

    Parameters
    ----------
    file_paths : list[str]
        List of h5ad AnnData files
    batch_size : int, optional
        Number of rows to fetch from disk per batch, by default 2000
    test_size : float, optional
        Fraction of data held out for testing; for example, 0.1 means 10% is split
        for testing, by default 0.0
    val_size : float, optional
        Fraction of data held out for validation; for example, 0.2 means 20% is split
        for validation, by default 0.2
    prefetch_factor : int, optional
        Dataloader prefetch factor; the total data prefetched is prefetch_factor * batch_size, by default 4
    max_epochs : int, optional
        Number of epochs to train for, by default 1
    thread_per_worker : int | None, optional
        Number of dataloader workers (CPU threads) per Ray worker, by default None
    num_workers : int | None, optional
        Override the number of Ray worker processes, which defaults to the number of GPUs
        in the cluster, by default None
    result_storage_path : str, optional
        Path where losses, validation metrics and checkpoints are stored, by default "~/protoplast_results"
    ckpt_path : str | None, optional
        Path to a checkpoint; if specified, training resumes from that checkpoint, otherwise
        training starts from scratch, by default None
    is_gpu : bool, optional
        Whether to train on GPUs; turn this off if your system does not have any GPUs, by default True
    random_seed : int | None, optional
        Random seed for the data split; set this to None for real training, or keep a fixed seed
        for benchmarking and result replication, by default 42
    resource_per_worker : dict | None, optional
        Passed to Ray; specifies how many CPUs or GPUs each Ray worker process gets, by default None
    is_shuffled : bool, optional
        Whether to shuffle the data, by default False

    Returns
    -------
    Result
        The training result from RayTrainer
    """
    self.result_storage_path = result_storage_path
    self.prefetch_factor = prefetch_factor
    self.max_epochs = max_epochs
    self.kwargs = kwargs
    if not resource_per_worker:
        if not thread_per_worker:
            print("Setting thread_per_worker to half of the available CPUs capped at 4")
            thread_per_worker = min(int((self.resources.get("CPU", 2) - 1) / 2), 4)
        resource_per_worker = {"CPU": thread_per_worker}
    if is_gpu and self.resources.get("GPU", 0) == 0:
        warnings.warn("`is_gpu = True` but there is no GPU found. Fallback to CPU.", UserWarning, stacklevel=2)
        is_gpu = False
    if is_gpu:
        if num_workers is None:
            num_workers = int(self.resources.get("GPU"))
        scaling_config = ray.train.ScalingConfig(
            num_workers=num_workers, use_gpu=True, resources_per_worker=resource_per_worker
        )
    else:
        if num_workers is None:
            num_workers = max(int((self.resources.get("CPU", 2) - 1) / thread_per_worker), 1)
        scaling_config = ray.train.ScalingConfig(
            num_workers=num_workers, use_gpu=False, resources_per_worker=resource_per_worker
        )
    print(f"Using {num_workers} workers with {resource_per_worker} each")
    start = time.time()
    shuffle_strategy = self.shuffle_strategy(
        file_paths,
        batch_size,
        num_workers * thread_per_worker,
        test_size,
        val_size,
        random_seed,
        metadata_cb=self.metadata_cb,
        is_shuffled=is_shuffled,
        **kwargs,
    )
    kwargs.pop("drop_last", None)
    kwargs.pop("pre_fetch_then_batch", None)
    indices = shuffle_strategy.split()
    print(f"Data splitting time: {time.time() - start:.2f} seconds")
    train_config = {"indices": indices, "ckpt_path": ckpt_path, "shuffle_strategy": shuffle_strategy}
    my_train_func = self._trainer()
    par_trainer = ray.train.torch.TorchTrainer(
        my_train_func,
        scaling_config=scaling_config,
        train_loop_config=train_config,
        run_config=ray.train.RunConfig(storage_path=self.result_storage_path),
    )
    print("Spawning Ray worker and initiating distributed training")
    return par_trainer.fit()

DataModule

Wrapper around the Dataset that controls how data is forwarded to the Lightning model; it supports hooks at various points in the lifecycle as the data is passed to the model.

protoplast.scrna.anndata.torch_dataloader.AnnDataModule

Bases: LightningDataModule
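For illustration, a hedged sketch of the densification hooks described above; the log1p transform is only an example, not a library default. These callables match the before_dense_cb / after_dense_cb signatures and can be passed to RayTrainRunner or AnnDataModule.

import torch

def before_dense(x: torch.Tensor, idx: str | int) -> torch.Tensor:
    # x is still a sparse CSR tensor at this point; return it (possibly modified)
    return x

def after_dense(x: torch.Tensor, idx: str | int) -> torch.Tensor:
    # x is dense here; idx identifies the batch element (list position or dict key)
    return torch.log1p(x)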

Source code in src/protoplast/scrna/anndata/torch_dataloader.py
class AnnDataModule(pl.LightningDataModule):
    def __init__(
        self,
        indices: dict,
        dataset: DistributedAnnDataset,
        prefetch_factor: int,
        sparse_key: str,
        shuffle_strategy: ShuffleStrategy,
        before_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        after_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        **kwargs,
    ):
        super().__init__()
        self.indices = indices
        self.dataset = dataset
        num_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count()))
        self.loader_config = dict(
            num_workers=num_threads,
        )
        if num_threads > 0:
            self.loader_config["prefetch_factor"] = prefetch_factor
            self.loader_config["persistent_workers"] = True
        if shuffle_strategy.is_mixer:
            self.loader_config["batch_size"] = shuffle_strategy.mini_batch_size
            self.loader_config["collate_fn"] = shuffle_strategy.mixer
            self.loader_config["drop_last"] = True
        else:
            self.loader_config["batch_size"] = None
        self.sparse_key = sparse_key
        self.before_dense_cb = before_dense_cb
        self.after_dense_cb = after_dense_cb
        self.kwargs = kwargs

    def setup(self, stage):
        # this is not necessary but it is here in case we want to download data to local node in the future
        if stage == "fit":
            self.train_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, **self.kwargs)
            self.val_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, "val", **self.kwargs)
        if stage == "test":
            self.val_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, "test", **self.kwargs)
        if stage == "predict":
            self.predict_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, **self.kwargs)

    def train_dataloader(self):
        return DataLoader(self.train_ds, **self.loader_config)

    def val_dataloader(self):
        return DataLoader(self.val_ds, **self.loader_config)

    def test_dataloader(self):
        # testing splits are not supported yet; reuse the validation dataset for now
        return DataLoader(self.val_ds, **self.loader_config)

    def predict_dataloader(self):
        return DataLoader(self.predict_ds, **self.loader_config)

    def densify(self, x, idx: str | int = None):
        if isinstance(x, torch.Tensor):
            if self.before_dense_cb:
                x = self.before_dense_cb(x, idx)
            if x.is_sparse or x.is_sparse_csr:
                x = x.to_dense()
            if self.after_dense_cb:
                x = self.after_dense_cb(x, idx)
        return x

    def on_after_batch_transfer(self, batch, dataloader_idx):
        if (type(batch) is list) or (type(batch) is tuple):
            return [self.densify(d, i) for i, d in enumerate(batch)]
        elif isinstance(batch, dict):
            return {k: self.densify(v, k) for k, v in batch.items()}
        elif isinstance(batch, torch.Tensor):
            return self.densify(batch)
        else:
            return batch

Dataset

For fetching and sending data to the model

protoplast.scrna.anndata.torch_dataloader.DistributedAnnDataset

Bases: IterableDataset

Dataset that supports multi-worker distribution; this version yields the data in a sequential manner

Parameters:

Name Type Description Default
file_paths list[str]

List of files

required
indices list[list[int]]

List of indices from SplitInfo

required
metadata dict

Metadata dictionary used to pass data to the model or for other purposes

required
sparse_key str

AnnData key for the sparse matrix, usually "X"; for layers, use dot notation, for example "layers.attr" where attr is the key of the layer you want to read

required
mini_batch_size int

How many observations to send to the model per step; must be less than batch_size, by default None, in which case the whole fetched batch is sent

None
Source code in src/protoplast/scrna/anndata/torch_dataloader.py
class DistributedAnnDataset(torch.utils.data.IterableDataset):
    """Dataset that support multiworker distribution this version will yield the data
    in a sequential manner

    Parameters
    ----------
    file_paths : list[str]
        List of files
    indices : list[list[int]]
        List of indices from `SplitInfo`
    metadata : dict
        Metadata dictionary used to pass data to the model or for other purposes
    sparse_key : str
        AnnData key for the sparse matrix, usually "X"; for layers, use dot notation, for example
        "layers.attr" where attr is the key of the layer you want to read
    mini_batch_size : int, optional
        How many observations to send to the model per step; must be less than `batch_size`, by default None,
        in which case the whole fetched batch is sent
    """

    def __init__(
        self,
        file_paths: list[str],
        indices: list[list[int]],
        metadata: dict,
        sparse_key: str,
        mini_batch_size: int = None,
        **kwargs,  # FIXME: workaround for PROTO-23
    ):
        # use the first file as the reference
        self.files = file_paths
        self.sparse_key = sparse_key
        self.X = None
        self.ad = None
        # expose each metadata entry as an attribute
        for k, v in metadata.items():
            setattr(self, k, v)
        self.metadata = metadata
        self.batches = indices
        self.mini_batch_size = mini_batch_size

    @classmethod
    def create_distributed_ds(cls, indices: SplitInfo, sparse_key: str, mode: str = "train", **kwargs):
        indices = indices.to_dict() if isinstance(indices, SplitInfo) else indices
        return cls(
            indices["files"],
            indices[f"{mode}_indices"],
            indices["metadata"],
            sparse_key,
            mini_batch_size=indices.get("mini_batch_size"),
            **kwargs,
        )

    def _init_rank(self):
        worker_info = get_worker_info()
        if worker_info is None:
            self.wid = 0
            self.nworkers = 1
        else:
            self.wid = worker_info.id
            self.nworkers = worker_info.num_workers
        try:
            w_rank = td.get_rank()
            w_size = td.get_world_size()
        except ValueError:
            w_rank = -1
            w_size = -1
        if w_rank >= 0:
            self.ray_rank = w_rank
            self.ray_size = w_size
        else:
            self.ray_rank = 0
            self.ray_size = 1
        self.global_rank = self.ray_rank * self.nworkers + self.wid
        self.total_workers = self.ray_size * self.nworkers

    def _process_sparse(self, mat) -> torch.Tensor:
        if sp.issparse(mat):
            return torch.sparse_csr_tensor(
                torch.from_numpy(mat.indptr).long(),
                torch.from_numpy(mat.indices).long(),
                torch.from_numpy(mat.data).float(),
                mat.shape,
            )
        return torch.from_numpy(mat).float()

    def _get_mat_by_range(self, ad: anndata.AnnData, start: int, end: int) -> sp.csr_matrix:
        if self.sparse_key == "X":
            return ad.X[start:end]
        elif "layers" in self.sparse_key:
            _, attr = self.sparse_key.split(".")
            return ad.layers[attr][start:end]
        else:
            raise Exception("Sparse key not supported")

    def transform(self, start: int, end: int):
        """The subclass should implement the logic to get more data for the cell. It can leverage this super function
        to efficiently get X as a sparse tensor. An example of how to get to more data from the cell is
        `self.ad.obs["key"][start:end]` where you must only fetch a subset of this data with `start` and `end`


        Parameters
        ----------
        start : int
            Starting index of this batch
        end : int
            Ending index of this batch

        Returns
        -------
        Any
            Usually a tensor, a list of tensors, or a dictionary with tensor values
        """
        # by default we just return the matrix
        # sometimes, the h5ad file stores X as the dense matrix,
        # so we have to make sure it is a sparse matrix before returning
        # the batch item
        if self.X is None:
            # we don't have the X upstream, so we have to incur I/O to fetch it
            self.X = self._get_mat_by_range(self.ad, start, end)
        X = self._process_sparse(self.X)
        return X

    def __len__(self):
        try:
            world_size = td.get_world_size()
        except ValueError:
            print("Not using tdd default to world size 1")
            world_size = 1
        if self.mini_batch_size:
            total_sample = sum(end - start for i in range(len(self.files)) for start, end in self.batches[i])
            return math.ceil(total_sample / self.mini_batch_size / world_size)
        return math.ceil(sum(1 for i in range(len(self.files)) for _ in self.batches[i]) / world_size)

    def __iter__(self):
        self._init_rank()
        gidx = 0
        total_iter = 0
        for fidx, f in enumerate(self.files):
            self.ad = anndata.read_h5ad(f, backed="r")
            for start, end in self.batches[fidx]:
                if not (gidx % self.total_workers) == self.global_rank:
                    gidx += 1
                    continue
                X = self._get_mat_by_range(self.ad, start, end)
                self.X = X
                if self.mini_batch_size is None:
                    # not fetch-then-batch approach, we yield everything
                    yield self.transform(start, end)
                    total_iter += 1
                else:
                    # fetch-then-batch approach
                    for i in range(0, X.shape[0], self.mini_batch_size):
                        # index on the X coordinates
                        b_start, b_end = i, min(i + self.mini_batch_size, X.shape[0])
                        # index on the adata coordinates
                        global_start, global_end = start + i, min(start + i + self.mini_batch_size, end)
                        self.X = X[b_start:b_end]
                        yield self.transform(global_start, global_end)
                        total_iter += 1
                gidx += 1

transform(start: int, end: int)

Subclasses should implement the logic to fetch additional data for each cell. They can leverage this super method to efficiently get X as a sparse tensor. An example of fetching additional data is self.ad.obs["key"][start:end], where you must only fetch the subset of the data between start and end

Parameters:

Name Type Description Default
start int

Starting index of this batch

required
end int

Ending index of this batch

required

Returns:

Type Description
Any

Usually a tensor, a list of tensors, or a dictionary with tensor values
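A minimal subclass sketch of the pattern described above, assuming your files have an obs column named "cell_line" stored as a categorical; the column name is a placeholder.

import torch
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset

class CellLineDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        X = super().transform(start, end)  # sparse CSR tensor for rows [start, end)
        # only slice the subset between start and end, as recommended above
        codes = self.ad.obs["cell_line"].iloc[start:end].cat.codes.to_numpy()
        labels = torch.as_tensor(codes, dtype=torch.long)
        return X, labels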

Source code in src/protoplast/scrna/anndata/torch_dataloader.py
def transform(self, start: int, end: int):
    """The subclass should implement the logic to get more data for the cell. It can leverage this super function
    to efficiently get X as a sparse tensor. An example of how to get to more data from the cell is
    `self.ad.obs["key"][start:end]` where you must only fetch a subset of this data with `start` and `end`


    Parameters
    ----------
    start : int
        Starting index of this batch
    end : int
        Ending index of this batch

    Returns
    -------
    Any
        Usually a tensor, a list of tensors, or a dictionary with tensor values
    """
    # by default we just return the matrix
    # sometimes, the h5ad file stores X as the dense matrix,
    # so we have to make sure it is a sparse matrix before returning
    # the batch item
    if self.X is None:
        # we don't have the X upstream, so we have to incur I/O to fetch it
        self.X = self._get_mat_by_range(self.ad, start, end)
    X = self._process_sparse(self.X)
    return X

protoplast.scrna.anndata.strategy.ShuffleStrategy

Bases: ABC

Strategy for how the data should be split and shuffled during training

Parameters:

Name Type Description Default
file_paths list[str]

List of file paths

required
batch_size int

How many rows to fetch per batch

required
total_workers int

Total number of workers; equal to the number of processes times the number of threads per process

required
test_size float | None

Fraction of data held out for testing; for example, 0.1 means 10% is split for testing, by default None

None
validation_size float | None

Fraction of data held out for validation; for example, 0.2 means 20% is split for validation, by default None

None
random_seed int | None

Seed for randomizing the split; set this to None if you want the split to be completely random, by default 42

42
metadata_cb Callable[[AnnData, dict], None] | None

Callback that mutates the metadata dictionary; recommended for passing data from obs or var, or any additional data your model requires, by default cell_line_metadata_cb

None
is_shuffled bool

Whether to shuffle the data; this option will be deprecated soon, by default True

True
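As orientation for subclassing, here is a hedged sketch (not part of the library) of a custom strategy that assigns every row to training; it only assumes the SplitInfo layout documented below, where indices are per-file lists of (start, end) row ranges.

import anndata
from protoplast.scrna.anndata.strategy import ShuffleStrategy, SplitInfo

class AllTrainStrategy(ShuffleStrategy):
    def split(self) -> SplitInfo:
        train_indices, metadata = [], {}
        for path in self.file_paths:
            ad = anndata.read_h5ad(path, backed="r")
            if self.metadata_cb:
                self.metadata_cb(ad, metadata)
            # per-file list of (start, end) row ranges, batch_size rows each
            train_indices.append(
                [(s, min(s + self.batch_size, ad.n_obs)) for s in range(0, ad.n_obs, self.batch_size)]
            )
        return SplitInfo(
            files=self.file_paths,
            train_indices=train_indices,
            val_indices=[[] for _ in self.file_paths],
            test_indices=[[] for _ in self.file_paths],
            metadata=metadata,
            mini_batch_size=None,  # yield each fetched range whole
        )

    def mixer(self, batch: list):
        # not used: is_mixer stays False, so the dataset yields prebatched data
        return super().mixer(batch)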
Source code in src/protoplast/scrna/anndata/strategy.py
class ShuffleStrategy(ABC):
    """Strategy on how to data should be split and shuffle during
    the training

    Parameters
    ----------
    file_paths : list[str]
        List of file paths
    batch_size : int
        How many rows to fetch per batch
    total_workers : int
        Total number of workers; equal to the number of processes times the number of threads per process
    test_size : float | None, optional
        Fraction of data held out for testing; for example, 0.1 means 10% is split for testing, by default None
    validation_size : float | None, optional
        Fraction of data held out for validation; for example, 0.2 means 20% is split for validation, by default None
    random_seed : int | None, optional
        Seed for randomizing the split; set this to None if you want the split to be completely random, by default 42
    metadata_cb : Callable[[anndata.AnnData, dict], None] | None, optional
        Callback that mutates the metadata dictionary; recommended for passing data from `obs` or `var`,
        or any additional data your model requires, by default cell_line_metadata_cb
    is_shuffled : bool, optional
        Whether to shuffle the data; this option will be deprecated soon, by default True
    """

    def __init__(
        self,
        file_paths: list[str],
        batch_size: int,
        total_workers: int,
        test_size: float | None = None,
        validation_size: float | None = None,
        random_seed: int | None = 42,
        metadata_cb: Callable[[anndata.AnnData, dict], None] | None = None,
        is_shuffled: bool = True,
    ) -> None:
        self.file_paths = file_paths
        self.batch_size = batch_size
        self.total_workers = total_workers
        self.test_size = test_size
        self.validation_size = validation_size
        self.random_seed = random_seed
        self.metadata_cb = metadata_cb
        self.is_shuffled = is_shuffled
        self.rng = random.Random(random_seed) if random_seed else random.Random()

    @property
    def is_mixer(self):
        return False

    @abstractmethod
    def split(self) -> SplitInfo:
        """
        Defines how the data should be split across workers; must return a SplitInfo
        """
        pass

    @abstractmethod
    def mixer(self, batch: list) -> any:
        """
        If your Dataset returns one sample at a time rather than prebatched data,
        this needs to be implemented
        """
        pass

mixer(batch: list) -> any abstractmethod

If your Dataset returns one sample at a time rather than prebatched data, this needs to be implemented

Source code in src/protoplast/scrna/anndata/strategy.py
@abstractmethod
def mixer(self, batch: list) -> any:
    """
    If your Dataset returns one sample at a time rather than prebatched data,
    this needs to be implemented
    """
    pass

split() -> SplitInfo abstractmethod

Defines how the data should be split across workers; must return a SplitInfo

Source code in src/protoplast/scrna/anndata/strategy.py
@abstractmethod
def split(self) -> SplitInfo:
    """
    Defines how the data should be split across workers; must return a SplitInfo
    """
    pass

protoplast.scrna.anndata.strategy.SplitInfo dataclass

Source code in src/protoplast/scrna/anndata/strategy.py
@dataclass
class SplitInfo:
    files: list[str]
    train_indices: list[list[int]]
    val_indices: list[list[int]]
    test_indices: list[list[int]]
    metadata: dict[str, any]
    mini_batch_size: int | None = None
    """Information on how to split the data
    this will get pass to the Dataset to know which part of the data
    they need to access

    Parameters
    ----------
    files : list[str]
        List of files
    train_indices : list[list[int]]
        List of indices for training `train_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    val_indices : list[list[int]]
        List of indices for validation `val_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    test_indices : list[list[int]]
        List of indices for testing `test_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    metadata : dict[str, any]
        Data to pass on to the Dataset and model
    mini_batch_size : int | None
        Number of rows to send to the model per step; if None, each fetched batch is sent whole
    """

    def to_dict(self) -> dict[str, any]:
        return {
            "files": self.files,
            "train_indices": self.train_indices,
            "val_indices": self.val_indices,
            "test_indices": self.test_indices,
            "metadata": self.metadata,
            "mini_batch_size": self.mini_batch_size,
        }

mini_batch_size: int | None = None class-attribute instance-attribute

Information on how to split the data; this is passed to the Dataset so it knows which part of the data it needs to access

Parameters:

Name Type Description Default
files list[str]

List of files

required
train_indices list[list[int]]

List of indices for training train_indices[file_idx][batch_idx] where file_idx must correspond to the idx of files parameter

required
val_indices list[list[int]]

List of indices for validation val_indices[file_idx][batch_idx] where file_idx must correspond to the idx of files parameter

required
test_indices list[list[int]]

List of indices for testing test_indices[file_idx][batch_idx] where file_idx must correspond to the idx of files parameter

required
metadata dict[str, any]

Data to pass on to the Dataset and model

required
mini_batch_size int | None

Number of rows to send to the model per step; if None, each fetched batch is sent whole

required
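For orientation, a hand-built SplitInfo for a single hypothetical 5,000-cell file could look like this (file name, ranges and metadata are placeholders):

from protoplast.scrna.anndata.strategy import SplitInfo

info = SplitInfo(
    files=["plate1.h5ad"],
    train_indices=[[(0, 2000), (2000, 4000)]],  # one list of (start, end) row ranges per file
    val_indices=[[(4000, 5000)]],
    test_indices=[[]],
    metadata={"num_genes": 2000},
    mini_batch_size=None,  # None: each fetched range is sent to the model whole
)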

protoplast.scrna.anndata.strategy.SequentialShuffleStrategy

Bases: ShuffleStrategy

Returns the data sequentially; randomness is not guaranteed and there is a high chance the data will come from nearby rows. This might affect your training accuracy depending on how the AnnData is ordered; you can overcome this by preshuffling the data yourself if it is an issue.

Parameters:

Name Type Description Default
file_paths list[str]

List of file paths

required
batch_size int

How many rows to fetch per batch

required
total_workers int

Total number of workers; equal to the number of processes times the number of threads per process

required
test_size float | None

Fraction of data held out for testing; for example, 0.1 means 10% is split for testing, by default None

None
validation_size float | None

Fraction of data held out for validation; for example, 0.2 means 20% is split for validation, by default None

None
random_seed int | None

Seed for randomizing the split; set this to None if you want the split to be completely random, by default 42

42
metadata_cb Callable[[AnnData, dict], None] | None

Callback that mutates the metadata dictionary; recommended for passing data from obs or var, or any additional data your model requires, by default cell_line_metadata_cb

None
is_shuffled bool

Whether to shuffle the data; this option will be deprecated soon, by default False

False
pre_fetch_then_batch int | None

Prefetch factor; the total size of each data fetch is pre_fetch_then_batch * batch_size, by default 16

16
drop_last bool

If True, drop the remainder batches; otherwise duplicate data to make sure it is evenly distributed to all workers, by default True

True
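Note that pre_fetch_then_batch and drop_last are not explicit parameters of RayTrainRunner.train(); they travel through its **kwargs to this strategy. A hedged tuning sketch, assuming a runner constructed as in the Trainer section (file name and values are placeholders):

result = runner.train(
    ["plate1.h5ad"],
    batch_size=2000,
    pre_fetch_then_batch=8,  # each disk read covers 8 * 2000 = 16000 rows
    drop_last=False,         # duplicate the remainder instead of dropping it
)
# the model still receives 2000-row mini batches (mini_batch_size = batch_size)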
Source code in src/protoplast/scrna/anndata/strategy.py
class SequentialShuffleStrategy(ShuffleStrategy):
    """Return the data in a sequential way randomness is not guarantee
    there is a high chance the data will come from nearby rows this might
    affect your training accuracy depending on how the anndata are ordered you can
    overcome this by preshuffling the data manually yourself if this is an issue

    Parameters
    ----------
    file_paths : list[str]
        List of file paths
    batch_size : int
        How many rows to fetch per batch
    total_workers : int
        Total number of workers; equal to the number of processes times the number of threads per process
    test_size : float | None, optional
        Fraction of data held out for testing; for example, 0.1 means 10% is split for testing, by default None
    validation_size : float | None, optional
        Fraction of data held out for validation; for example, 0.2 means 20% is split for validation, by default None
    random_seed : int | None, optional
        Seed for randomizing the split; set this to None if you want the split to be completely random, by default 42
    metadata_cb : Callable[[anndata.AnnData, dict], None] | None, optional
        Callback that mutates the metadata dictionary; recommended for passing data from `obs` or `var`,
        or any additional data your model requires, by default cell_line_metadata_cb
    is_shuffled : bool, optional
        Whether to shuffle the data; this option will be deprecated soon, by default False
    pre_fetch_then_batch : int | None, optional
        Prefetch factor; the total size of each data fetch is `pre_fetch_then_batch * batch_size`, by default 16
    drop_last : bool, optional
        If True, drop the remainder batches; otherwise duplicate data to make sure it is
        evenly distributed to all workers, by default True
    """

    def __init__(
        self,
        file_paths: list[str],
        batch_size: int,
        total_workers: int,
        test_size: float | None = None,
        validation_size: float | None = None,
        random_seed: int | None = 42,
        metadata_cb: Callable[[anndata.AnnData, dict], None] | None = None,
        is_shuffled: bool = False,
        pre_fetch_then_batch: int | None = 16,
        drop_last: bool = True,
    ) -> None:
        super().__init__(
            file_paths,
            batch_size,
            total_workers,
            test_size,
            validation_size,
            random_seed,
            metadata_cb,
            is_shuffled,
        )
        self.pre_fetch_then_batch = pre_fetch_then_batch
        self.drop_last = drop_last

    def split(self) -> SplitInfo:
        if self.pre_fetch_then_batch:
            batch_size = self.batch_size * self.pre_fetch_then_batch
        else:
            batch_size = self.batch_size
        split_dict = ann_split_data(
            self.file_paths,
            batch_size,
            self.total_workers,
            self.test_size,
            self.validation_size,
            self.rng,
            self.metadata_cb,
            self.is_shuffled,
            self.drop_last,
        )
        # this will be passed to the dataset, in order to know the mini batch size
        if self.pre_fetch_then_batch:
            split_dict["mini_batch_size"] = self.batch_size
        else:
            # yield everything we read
            split_dict["mini_batch_size"] = None
        return SplitInfo(**split_dict)

    def mixer(self, batch: list):
        return super().mixer(batch)