SCRNA API References

Trainer

protoplast.scrna.anndata.trainer.RayTrainRunner

A class that initializes training. It automatically starts a Ray cluster, or detects an existing cluster and connects to it if one is already running; refer to ray.init() behavior.

Parameters:

Model : type[LightningModule], required
    PyTorch Lightning model class.
Ds : type[DistributedAnnDataset], required
    DistributedAnnDataset class.
model_keys : list[str], required
    Keys from the metadata built by metadata_cb that are passed to the model constructor.
metadata_cb : Callable[[AnnData, dict], None], by default cell_line_metadata_cb
    Callback that mutates the metadata dict; recommended for passing data from obs or var, or any additional data your model requires.
before_dense_cb : Callable[[Tensor, str | int], Tensor], by default None
    Callback applied before densification of the sparse matrix, while the data is still a sparse CSR tensor.
after_dense_cb : Callable[[Tensor, str | int], Tensor], by default None
    Callback applied after densification of the sparse matrix, when the data is a dense tensor.
shuffle_strategy : ShuffleStrategy, by default SequentialShuffleStrategy
    Strategy used to split and randomize the data during training.
runtime_env_config : dict | None, by default None
    Runtime environment configuration passed to the Ray Train worker processes.
address : str | None, by default None
    Override the Ray cluster address.
ray_trainer_strategy : Strategy | None, by default None
    Override the Ray Trainer strategy; if None, defaults to RayDDPStrategy.
sparse_key : str, by default "X"
    AnnData key of the sparse matrix used for training.

Returns:

RayTrainRunner
    Use this instance to start the training.
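A minimal construction sketch, assuming the pytorch_lightning import name, a hypothetical model class MyModel that takes a gene_names argument, and a hypothetical metadata callback; only RayTrainRunner and DistributedAnnDataset are taken from the API described above.

import anndata
import pytorch_lightning as pl
import torch

from protoplast.scrna.anndata.trainer import RayTrainRunner
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset


class MyModel(pl.LightningModule):
    # Hypothetical model: one linear layer over the gene dimension.
    def __init__(self, gene_names: list[str]):
        super().__init__()
        self.layer = torch.nn.Linear(len(gene_names), 1)

    def training_step(self, batch, batch_idx):
        # With the default dataset and strategy, `batch` is the densified X tensor.
        return self.layer(batch).mean()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def my_metadata_cb(ad: anndata.AnnData, metadata: dict) -> None:
    # Hypothetical callback: copy whatever the model needs into the shared metadata dict.
    metadata["gene_names"] = list(ad.var_names)


runner = RayTrainRunner(
    MyModel,
    DistributedAnnDataset,
    model_keys=["gene_names"],   # forwarded from the metadata dict to MyModel(gene_names=...)
    metadata_cb=my_metadata_cb,
)

Constructing the runner calls ray.init(), so this connects to an existing Ray cluster or starts a local one.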

Source code in src/protoplast/scrna/anndata/trainer.py
class RayTrainRunner:
    """A class to initialize the training this class automatically initializes Ray cluster or
    detect whether an existing cluster exist if there is an existing cluster it will automatically
    connect to it refer to `ray.init()` behavior

    Parameters
    ----------
    Model : type[pl.LightningModule]
        PyTorch Lightning model class
    Ds : type[DistributedAnnDataset]
        DistributedAnnDataset class
    model_keys : list[str]
        Keys to pass to model from `metadata_cb`
    metadata_cb : Callable[[anndata.AnnData, dict], None], optional
        Callback to mutate metadata recommended for passing data from `obs` or `var`
        or any additional data your models required
        by default cell_line_metadata_cb
    before_dense_cb : Callable[[torch.Tensor, str  |  int], torch.Tensor], optional
        Callback to perform before densification of sparse matrix where the data at this point
        is still a sparse CSR Tensor, by default None
    after_dense_cb : Callable[[torch.Tensor, str  |  int], torch.Tensor], optional
        Callback to perform after densification of sparse matrix where the data at this point
        is a dense Tensor, by default None
    shuffle_strategy : ShuffleStrategy, optional
        Strategy to split or randomize the data during the training, by default SequentialShuffleStrategy
    runtime_env_config : dict | None, optional
        These env config is to pass the RayTrainer processes, by default None
    address : str | None, optional
        Override ray address, by default None
    ray_trainer_strategy : Strategy | None, optional
        Override Ray Trainer Strategy if this is None it will default to RayDDP, by default None
    sparse_key : str, optional
        AnnData key of the sparse matrix used for training, by default "X"
    Returns
    -------
    RayTrainRunner
        Use this class to start the training

    """

    @beartype
    def __init__(
        self,
        Model: type[pl.LightningModule],
        Ds: type[DistributedAnnDataset],
        model_keys: list[str],
        metadata_cb: Callable[[anndata.AnnData, dict], None] = cell_line_metadata_cb,
        before_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        after_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        shuffle_strategy: ShuffleStrategy = SequentialShuffleStrategy,
        runtime_env_config: dict | None = None,
        address: str | None = None,
        ray_trainer_strategy: Strategy | None = None,
        sparse_key: str = "X",
    ):
        self.Model = Model
        self.Ds = Ds
        self.model_keys = model_keys
        self.metadata_cb = metadata_cb
        self.shuffle_strategy = shuffle_strategy
        self.sparse_key = sparse_key
        self.before_dense_cb = before_dense_cb
        self.after_dense_cb = after_dense_cb
        if not ray_trainer_strategy:
            self.ray_trainer_strategy = ray.train.lightning.RayDDPStrategy()
        else:
            self.ray_trainer_strategy = ray_trainer_strategy

        # Init ray cluster
        DEFAULT_RUNTIME_ENV_CONFIG = {
            "LOG_LEVEL": os.getenv("LOG_LEVEL", "INFO"),
        }
        if runtime_env_config is None:
            runtime_env_config = DEFAULT_RUNTIME_ENV_CONFIG
        ray.init(
            address=address, runtime_env={**DEFAULT_RUNTIME_ENV_CONFIG, **runtime_env_config}, ignore_reinit_error=True
        )

        self.resources = ray.cluster_resources()

    def _worker_fn(self):
        warnings.filterwarnings(action="ignore", module="ray", category=DeprecationWarning)
        Model, Ds, model_keys = self.Model, self.Ds, self.model_keys

        def worker_fn(config):
            ctx = ray.train.get_context()
            if ctx:
                rank = ctx.get_world_rank()
            else:
                rank = 0
            indices = config.get("indices")
            ckpt_path = config.get("ckpt_path")
            scratch_path = config.get("scratch_path")
            scratch_content = config.get("scratch_content")
            logger.debug("Verifying storage path on worker node")
            try:
                file = get_fsspec(scratch_path, "r")
                read_content = file.read()
                file.close()
            except Exception as e:
                logger.error("Failed to access shared storage path: %s", scratch_path, exc_info=True)
                raise Exception("Cannot access the shared storage. Please check your storage path.") from e
            if scratch_content != read_content:
                logger.critical(
                    f"Content mismatch detected for path: {scratch_path}.Worker cannot read expected head node content."
                )
                raise Exception("Content mismatch detected. Please check your shared storage setup.")
            num_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count()))
            logger.debug(f"=========Starting the training on {rank} with num threads: {num_threads}=========")
            model_params = indices.metadata
            shuffle_strategy = config.get("shuffle_strategy")
            ann_dm = AnnDataModule(
                indices,
                Ds,
                self.prefetch_factor,
                self.sparse_key,
                shuffle_strategy,
                self.before_dense_cb,
                self.after_dense_cb,
                random_seed=config["random_seed"],
                **self.kwargs,
            )
            if model_keys:
                model_params = {k: v for k, v in model_params.items() if k in model_keys}
            model = Model(**model_params)
            trainer = pl.Trainer(
                max_epochs=self.max_epochs,
                devices="auto",
                accelerator=_get_accelerator(),
                strategy=self.ray_trainer_strategy,
                plugins=[ray.train.lightning.RayLightningEnvironment()],
                callbacks=[ray.train.lightning.RayTrainReportCallback()],
                enable_checkpointing=True,
                enable_progress_bar=config.get("enable_progress_bar", True),
            )
            trainer = ray.train.lightning.prepare_trainer(trainer)
            if config.get("worker_mode") == "inference":
                logger.debug("Starting inference mode")
                writer_cb = DistributedPredictionWriter(
                    output_dir=self.result_storage_path, rank=rank, format=config["prediction_format"]
                )
                trainer.callbacks.append(writer_cb)
                trainer.predict(model, datamodule=ann_dm, ckpt_path=ckpt_path)
            else:
                logger.debug("Starting training mode")
                trainer.fit(model, datamodule=ann_dm, ckpt_path=ckpt_path)

        return worker_fn

    def _setup(
        self,
        file_paths: list[str],
        batch_size: int,
        test_size: float,
        val_size: float,
        prefetch_factor: int,
        max_epochs: int,
        thread_per_worker: int | None,
        num_workers: int | None,
        result_storage_path: str,
        # read more here: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
        ckpt_path: str | None,
        is_gpu: bool,
        random_seed: int | None,
        resource_per_worker: dict | None,
        is_shuffled: bool,
        enable_progress_bar: bool,
        worker_mode: Literal["train", "inference"],
        **kwargs,
    ):
        self.result_storage_path = resolve_path_or_url(result_storage_path)
        self.prefetch_factor = prefetch_factor
        self.max_epochs = max_epochs
        self.kwargs = kwargs
        self.enable_progress_bar = enable_progress_bar
        if not resource_per_worker:
            if not thread_per_worker:
                logger.info("Setting thread_per_worker to half of the available CPUs capped at 4")
                thread_per_worker = min(int((self.resources.get("CPU", 2) - 1) / 2), 4)
            resource_per_worker = {"CPU": thread_per_worker}
        if is_gpu and self.resources.get("GPU", 0) == 0:
            logger.warning("`is_gpu = True` but there is no GPU found. Fallback to CPU.")
            is_gpu = False
        if is_gpu:
            if num_workers is None:
                num_workers = int(self.resources.get("GPU"))
            scaling_config = ray.train.ScalingConfig(
                num_workers=num_workers, use_gpu=True, resources_per_worker=resource_per_worker
            )
            resource_per_worker["GPU"] = 1
        else:
            if num_workers is None:
                num_workers = max(int((self.resources.get("CPU", 2) - 1) / thread_per_worker), 1)
            scaling_config = ray.train.ScalingConfig(
                num_workers=num_workers, use_gpu=False, resources_per_worker=resource_per_worker
            )
        logger.info(f"Using {num_workers} workers where each worker uses: {resource_per_worker}")
        start = time.time()

        shuffle_strategy = self.shuffle_strategy(
            [resolve_path_or_url(f) for f in file_paths],
            batch_size,
            num_workers * thread_per_worker,
            test_size,
            val_size,
            random_seed,
            metadata_cb=self.metadata_cb,
            is_shuffled=is_shuffled,
            **kwargs,
        )
        kwargs.pop("drop_last", None)
        kwargs.pop("pre_fetch_then_batch", None)
        indices = shuffle_strategy.split()
        logger.debug(f"Data splitting time: {time.time() - start:.2f} seconds")
        train_config = {
            "indices": indices,
            "ckpt_path": resolve_path_or_url(ckpt_path),
            "shuffle_strategy": shuffle_strategy,
            "enable_progress_bar": self.enable_progress_bar,
            "scratch_path": os.path.join(self.result_storage_path, "scratch.plt"),
            "scratch_content": str(uuid.uuid4()),
            "worker_mode": worker_mode,
            "random_seed": random_seed,
        }
        if worker_mode == "inference":
            train_config["prediction_format"] = kwargs["prediction_format"]
        par_trainer = ray.train.torch.TorchTrainer(
            self._worker_fn(),
            scaling_config=scaling_config,
            train_loop_config=train_config,
            run_config=ray.train.RunConfig(storage_path=self.result_storage_path),
        )

        logger.debug("Writing scratch content to share storage")
        scratch_path = train_config["scratch_path"]
        fs, path_on_fs = fsspec.core.url_to_fs(scratch_path)
        parent_dir = os.path.dirname(path_on_fs)
        if not fs.exists(parent_dir):
            logger.debug(f"Ensuring directory exists: {parent_dir}")
            fs.makedirs(parent_dir, exist_ok=True)
        file = get_fsspec(scratch_path, mode="w")
        file.write(train_config["scratch_content"])
        file.close()
        logger.debug("Spawning Ray worker and initiating distributed training")
        return par_trainer, indices

    @beartype
    def par_inference(
        self,
        file_paths: list[str],
        ckpt_path: str | None = None,
        result_storage_path: str = "~/protoplast_results",
        batch_size: int = 2000,
        prefetch_factor: int = 4,
        thread_per_worker: int | None = None,
        num_workers: int | None = None,
        is_gpu: bool = True,
        resource_per_worker: dict | None = None,
        enable_progress_bar: bool = True,
        prediction_format: Literal["csv", "parquet"] = "csv",
        **kwargs,
    ):
        """Start parallel inference the order of the result is not guaranteed to be the same as input file

        Parameters
        ----------
        file_paths : list[str]
            List of h5ad AnnData files
        batch_size : int, optional
            How much data to fetch from disk, by default to 2000
        prefetch_factor : int, optional
            Total data fetch is prefetch_factor * batch_size, by default 4
        thread_per_worker : int | None, optional
            Amount of worker for each dataloader, by default None
        num_workers : int | None, optional
            Override number of Ray processes default to number of GPU(s) in the cluster, by default None
        is_gpu : bool, optional
            By default True turn this off if your system don't have any GPU, by default True
        resource_per_worker : dict | None, optional
            This get pass to Ray you can specify how much CPU or GPU each Ray process get, by default None
        ckpt_path : str | None, optional
            Path of the checkpoint to load for inference, by default None
        enable_progress_bar : bool
            Whether to enable Trainer progress bar or not, by default True
        Returns
        -------
        Result
            The inference result from RayTrainer
        """
        par_trainer, _ = self._setup(
            file_paths,
            batch_size,
            0.0,
            0.0,
            prefetch_factor,
            1,
            thread_per_worker,
            num_workers,
            result_storage_path,
            ckpt_path,
            is_gpu,
            None,
            resource_per_worker,
            False,
            enable_progress_bar,
            prediction_format=prediction_format,
            worker_mode="inference",
            **kwargs,
        )
        # despite the confusing name we use fit to run inference here
        result = par_trainer.fit()
        # combine the result and order it correctly
        return result

    def inference(
        self,
        file_paths: list[str],
        result_storage_path: str,
        ckpt_path: str,
        prediction_format: Literal["csv", "parquet"] = "csv",
        enable_progress_bar: bool = True,
        batch_size=2000,
    ):
        """Start inference in a single process order is guarantee to be the same as input file
        don't use this in a distributed cluster
        Parameters
        ----------
        file_paths : list[str]
            List of h5ad AnnData files
        result_storage_path : str
            Path to store the prediction result
        ckpt_path : str
            Path of the checkpoint to run inference
        enable_progress_bar : bool, optional
            Whether to enable Trainer progress bar or not, by default True
        batch_size : int, optional
            How much data to fetch from disk, by default to 2000
        """
        if sys.platform in ("darwin", "win32"):
            override_thread = 0
        else:
            override_thread = 1
        shuffle_strategy = self.shuffle_strategy(
            [resolve_path_or_url(f) for f in file_paths],
            batch_size,
            override_thread,
            0,
            0,
            None,
            metadata_cb=self.metadata_cb,
            is_shuffled=False,
            prediction_format=prediction_format,
        )
        indices = shuffle_strategy.split()
        writer_cb = PredictionWriterCallback(
            output_path=resolve_path_or_url(result_storage_path), format=prediction_format
        )
        trainer = pl.Trainer(
            devices="auto",
            accelerator=_get_accelerator(),
            enable_progress_bar=enable_progress_bar,
        )
        trainer.callbacks.append(writer_cb)

        ann_dm = AnnDataModule(
            indices,
            self.Ds,
            4,
            self.sparse_key,
            SequentialShuffleStrategy,
            self.before_dense_cb,
            self.after_dense_cb,
            batch_size=batch_size,
            override_thread=override_thread,
        )
        model_params = indices.metadata
        if self.model_keys:
            model_params = {k: v for k, v in model_params.items() if k in self.model_keys}
        model = self.Model(**model_params)
        trainer.predict(model, datamodule=ann_dm, ckpt_path=resolve_path_or_url(ckpt_path))

    @beartype
    def train(
        self,
        file_paths: list[str],
        batch_size: int = 2000,
        test_size: float = 0.0,
        val_size: float = 0.2,
        prefetch_factor: int = 4,
        max_epochs: int = 1,
        thread_per_worker: int | None = None,
        num_workers: int | None = None,
        result_storage_path: str = "~/protoplast_results",
        # read more here: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
        ckpt_path: str | None = None,
        is_gpu: bool = True,
        random_seed: int | None = 42,
        resource_per_worker: dict | None = None,
        is_shuffled: bool = False,
        enable_progress_bar: bool = True,
        **kwargs,
    ):
        """Start the training

        Parameters
        ----------
        file_paths : list[str]
            List of h5ad AnnData files
        batch_size : int, optional
            How much data to fetch from disk, by default to 2000
        test_size : float, optional
            Fraction of test data for example 0.1 means 10% will be split for testing
            default to 0.0
        val_size : float, optional
            Fraction of validation data for example 0.2 means 20% will be split for validation,
            default to 0.2
        prefetch_factor : int, optional
            Total data fetch is prefetch_factor * batch_size, by default 4
        max_epochs : int, optional
            How many epoch(s) to train with, by default 1
        thread_per_worker : int | None, optional
            Amount of worker for each dataloader, by default None
        num_workers : int | None, optional
            Override number of Ray processes default to number of GPU(s) in the cluster, by default None
        result_storage_path : str, optional
            Path to store the loss, validation and checkpoint, by default "~/protoplast_results"
        ckpt_path : str | None, optional
            Path of the checkpoint if this is specified it will train from checkpoint otherwise it will start the
            training from scratch, by default None
        is_gpu : bool, optional
            By default True turn this off if your system don't have any GPU, by default True
        random_seed : int | None, optional
            Set this to None for real training but for benchmarking and result replication
            you can adjust the seed here, by default 42
        resource_per_worker : dict | None, optional
            This get pass to Ray you can specify how much CPU or GPU each Ray process get, by default None
        enable_progress_bar : bool
            Whether to enable Trainer progress bar or not, by default True
        Returns
        -------
        Result
            The training result from RayTrainer
        """
        par_trainer, _ = self._setup(
            file_paths,
            batch_size,
            test_size,
            val_size,
            prefetch_factor,
            max_epochs,
            thread_per_worker,
            num_workers,
            result_storage_path,
            ckpt_path,
            is_gpu,
            random_seed,
            resource_per_worker,
            is_shuffled,
            enable_progress_bar,
            worker_mode="train",
            **kwargs,
        )
        return par_trainer.fit()

inference(file_paths: list[str], result_storage_path: str, ckpt_path: str, prediction_format: Literal['csv', 'parquet'] = 'csv', enable_progress_bar: bool = True, batch_size=2000)

Start inference in a single process; the output order is guaranteed to match the input files. Do not use this in a distributed cluster.

Parameters:

file_paths : list[str], required
    List of h5ad AnnData files.
result_storage_path : str, required
    Path to store the prediction results.
ckpt_path : str, required
    Path of the checkpoint to run inference with.
enable_progress_bar : bool, by default True
    Whether to enable the Trainer progress bar.
batch_size : int, by default 2000
    How much data to fetch from disk per batch.
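A hedged usage sketch: single-process prediction with the runner built in the earlier example. The file and checkpoint paths are hypothetical, and the model is expected to implement predict_step so the prediction writer callback has output to store.

runner.inference(
    file_paths=["data/part_0.h5ad"],                 # hypothetical input file
    result_storage_path="~/protoplast_results/preds",
    ckpt_path="~/protoplast_results/last.ckpt",      # hypothetical checkpoint
    prediction_format="parquet",
    batch_size=2000,
)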
Source code in src/protoplast/scrna/anndata/trainer.py
def inference(
    self,
    file_paths: list[str],
    result_storage_path: str,
    ckpt_path: str,
    prediction_format: Literal["csv", "parquet"] = "csv",
    enable_progress_bar: bool = True,
    batch_size=2000,
):
    """Start inference in a single process order is guarantee to be the same as input file
    don't use this in a distributed cluster
    Parameters
    ----------
    file_paths : list[str]
        List of h5ad AnnData files
    result_storage_path : str
        Path to store the prediction result
    ckpt_path : str
        Path of the checkpoint to run inference
    enable_progress_bar : bool, optional
        Whether to enable Trainer progress bar or not, by default True
    batch_size : int, optional
        How much data to fetch from disk, by default to 2000
    """
    if sys.platform in ("darwin", "win32"):
        override_thread = 0
    else:
        override_thread = 1
    shuffle_strategy = self.shuffle_strategy(
        [resolve_path_or_url(f) for f in file_paths],
        batch_size,
        override_thread,
        0,
        0,
        None,
        metadata_cb=self.metadata_cb,
        is_shuffled=False,
        prediction_format=prediction_format,
    )
    indices = shuffle_strategy.split()
    writer_cb = PredictionWriterCallback(
        output_path=resolve_path_or_url(result_storage_path), format=prediction_format
    )
    trainer = pl.Trainer(
        devices="auto",
        accelerator=_get_accelerator(),
        enable_progress_bar=enable_progress_bar,
    )
    trainer.callbacks.append(writer_cb)

    ann_dm = AnnDataModule(
        indices,
        self.Ds,
        4,
        self.sparse_key,
        SequentialShuffleStrategy,
        self.before_dense_cb,
        self.after_dense_cb,
        batch_size=batch_size,
        override_thread=override_thread,
    )
    model_params = indices.metadata
    if self.model_keys:
        model_params = {k: v for k, v in model_params.items() if k in self.model_keys}
    model = self.Model(**model_params)
    trainer.predict(model, datamodule=ann_dm, ckpt_path=resolve_path_or_url(ckpt_path))

par_inference(file_paths: list[str], ckpt_path: str | None = None, result_storage_path: str = '~/protoplast_results', batch_size: int = 2000, prefetch_factor: int = 4, thread_per_worker: int | None = None, num_workers: int | None = None, is_gpu: bool = True, resource_per_worker: dict | None = None, enable_progress_bar: bool = True, prediction_format: Literal['csv', 'parquet'] = 'csv', **kwargs)

Start parallel inference; the order of the results is not guaranteed to match the input files.

Parameters:

file_paths : list[str], required
    List of h5ad AnnData files.
batch_size : int, by default 2000
    How much data to fetch from disk per batch.
prefetch_factor : int, by default 4
    Total data fetched is prefetch_factor * batch_size.
thread_per_worker : int | None, by default None
    Number of dataloader workers per Ray worker.
num_workers : int | None, by default None
    Override the number of Ray processes; defaults to the number of GPUs in the cluster.
is_gpu : bool, by default True
    Turn this off if your system does not have a GPU.
resource_per_worker : dict | None, by default None
    Passed to Ray; specifies how much CPU or GPU each Ray process gets.
ckpt_path : str | None, by default None
    Path of the checkpoint to load for inference.
enable_progress_bar : bool, by default True
    Whether to enable the Trainer progress bar.

Returns:

Result
    The inference result returned by the Ray TorchTrainer.
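A hedged sketch of distributed inference with the same runner; paths are hypothetical. Because each Ray worker writes its own predictions, the row order of the combined output is not guaranteed to match the input files.

result = runner.par_inference(
    file_paths=["data/part_0.h5ad", "data/part_1.h5ad"],  # hypothetical input files
    ckpt_path="~/protoplast_results/last.ckpt",           # hypothetical checkpoint
    prediction_format="csv",
    is_gpu=False,  # run on CPU workers if the cluster has no GPUs
)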

Source code in src/protoplast/scrna/anndata/trainer.py
@beartype
def par_inference(
    self,
    file_paths: list[str],
    ckpt_path: str | None = None,
    result_storage_path: str = "~/protoplast_results",
    batch_size: int = 2000,
    prefetch_factor: int = 4,
    thread_per_worker: int | None = None,
    num_workers: int | None = None,
    is_gpu: bool = True,
    resource_per_worker: dict | None = None,
    enable_progress_bar: bool = True,
    prediction_format: Literal["csv", "parquet"] = "csv",
    **kwargs,
):
    """Start parallel inference the order of the result is not guaranteed to be the same as input file

    Parameters
    ----------
    file_paths : list[str]
        List of h5ad AnnData files
    batch_size : int, optional
        How much data to fetch from disk, by default to 2000
    prefetch_factor : int, optional
        Total data fetch is prefetch_factor * batch_size, by default 4
    thread_per_worker : int | None, optional
        Amount of worker for each dataloader, by default None
    num_workers : int | None, optional
        Override number of Ray processes default to number of GPU(s) in the cluster, by default None
    is_gpu : bool, optional
        By default True turn this off if your system don't have any GPU, by default True
    resource_per_worker : dict | None, optional
        This get pass to Ray you can specify how much CPU or GPU each Ray process get, by default None
    ckpt_path : str | None, optional
        Path of the checkpoint to load for inference, by default None
    enable_progress_bar : bool
        Whether to enable Trainer progress bar or not, by default True
    Returns
    -------
    Result
        The inference result from RayTrainer
    """
    par_trainer, _ = self._setup(
        file_paths,
        batch_size,
        0.0,
        0.0,
        prefetch_factor,
        1,
        thread_per_worker,
        num_workers,
        result_storage_path,
        ckpt_path,
        is_gpu,
        None,
        resource_per_worker,
        False,
        enable_progress_bar,
        prediction_format=prediction_format,
        worker_mode="inference",
        **kwargs,
    )
    # despite the confusing name we use fit to run inference here
    result = par_trainer.fit()
    # combine the result and order it correctly
    return result

train(file_paths: list[str], batch_size: int = 2000, test_size: float = 0.0, val_size: float = 0.2, prefetch_factor: int = 4, max_epochs: int = 1, thread_per_worker: int | None = None, num_workers: int | None = None, result_storage_path: str = '~/protoplast_results', ckpt_path: str | None = None, is_gpu: bool = True, random_seed: int | None = 42, resource_per_worker: dict | None = None, is_shuffled: bool = False, enable_progress_bar: bool = True, **kwargs)

Start the training

Parameters:

file_paths : list[str], required
    List of h5ad AnnData files.
batch_size : int, by default 2000
    How much data to fetch from disk per batch.
test_size : float, by default 0.0
    Fraction of data held out for testing; for example, 0.1 means 10% is split off for testing.
val_size : float, by default 0.2
    Fraction of data held out for validation; for example, 0.2 means 20% is split off for validation.
prefetch_factor : int, by default 4
    Total data fetched is prefetch_factor * batch_size.
max_epochs : int, by default 1
    How many epochs to train for.
thread_per_worker : int | None, by default None
    Number of dataloader workers per Ray worker.
num_workers : int | None, by default None
    Override the number of Ray processes; defaults to the number of GPUs in the cluster.
result_storage_path : str, by default "~/protoplast_results"
    Path to store the loss, validation metrics, and checkpoints.
ckpt_path : str | None, by default None
    Checkpoint path; if specified, training resumes from this checkpoint, otherwise it starts from scratch.
is_gpu : bool, by default True
    Turn this off if your system does not have a GPU.
random_seed : int | None, by default 42
    Set this to None for real training; for benchmarking and result replication, adjust the seed here.
resource_per_worker : dict | None, by default None
    Passed to Ray; specifies how much CPU or GPU each Ray process gets.
enable_progress_bar : bool, by default True
    Whether to enable the Trainer progress bar.

Returns:

Result
    The training result returned by the Ray TorchTrainer.
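A hedged training sketch using the runner constructed above; file paths are hypothetical.

result = runner.train(
    file_paths=["data/part_0.h5ad", "data/part_1.h5ad"],  # hypothetical input files
    batch_size=2000,
    val_size=0.2,          # hold out 20% of batches for validation
    max_epochs=2,
    is_gpu=False,          # set True on a GPU cluster
    result_storage_path="~/protoplast_results",
)

Checkpoints and metrics are written under result_storage_path, and result is the ray.train Result returned by the underlying TorchTrainer.fit().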

Source code in src/protoplast/scrna/anndata/trainer.py
@beartype
def train(
    self,
    file_paths: list[str],
    batch_size: int = 2000,
    test_size: float = 0.0,
    val_size: float = 0.2,
    prefetch_factor: int = 4,
    max_epochs: int = 1,
    thread_per_worker: int | None = None,
    num_workers: int | None = None,
    result_storage_path: str = "~/protoplast_results",
    # read more here: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
    ckpt_path: str | None = None,
    is_gpu: bool = True,
    random_seed: int | None = 42,
    resource_per_worker: dict | None = None,
    is_shuffled: bool = False,
    enable_progress_bar: bool = True,
    **kwargs,
):
    """Start the training

    Parameters
    ----------
    file_paths : list[str]
        List of h5ad AnnData files
    batch_size : int, optional
        How much data to fetch from disk, by default to 2000
    test_size : float, optional
        Fraction of test data for example 0.1 means 10% will be split for testing
        default to 0.0
    val_size : float, optional
        Fraction of validation data for example 0.2 means 20% will be split for validation,
        default to 0.2
    prefetch_factor : int, optional
        Total data fetch is prefetch_factor * batch_size, by default 4
    max_epochs : int, optional
        How many epoch(s) to train with, by default 1
    thread_per_worker : int | None, optional
        Amount of worker for each dataloader, by default None
    num_workers : int | None, optional
        Override number of Ray processes default to number of GPU(s) in the cluster, by default None
    result_storage_path : str, optional
        Path to store the loss, validation and checkpoint, by default "~/protoplast_results"
    ckpt_path : str | None, optional
        Path of the checkpoint if this is specified it will train from checkpoint otherwise it will start the
        training from scratch, by default None
    is_gpu : bool, optional
        By default True turn this off if your system don't have any GPU, by default True
    random_seed : int | None, optional
        Set this to None for real training but for benchmarking and result replication
        you can adjust the seed here, by default 42
    resource_per_worker : dict | None, optional
        This get pass to Ray you can specify how much CPU or GPU each Ray process get, by default None
    enable_progress_bar : bool
        Whether to enable Trainer progress bar or not, by default True
    Returns
    -------
    Result
        The training result from RayTrainer
    """
    par_trainer, _ = self._setup(
        file_paths,
        batch_size,
        test_size,
        val_size,
        prefetch_factor,
        max_epochs,
        thread_per_worker,
        num_workers,
        result_storage_path,
        ckpt_path,
        is_gpu,
        random_seed,
        resource_per_worker,
        is_shuffled,
        enable_progress_bar,
        worker_mode="train",
        **kwargs,
    )
    return par_trainer.fit()

DataModule

A wrapper around the Dataset that controls how data is forwarded to the Lightning model, with hooks at various points in the lifecycle as batches are passed to the model.
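A small hedged sketch of the densification hooks: the two callbacks below are hypothetical, and they reach AnnDataModule through the RayTrainRunner constructor, which forwards them to the densify step shown in the source below.

import torch

def inspect_sparse(x: torch.Tensor, idx) -> torch.Tensor:
    # before_dense_cb: the batch is still a sparse CSR tensor here.
    return x

def log1p_after_dense(x: torch.Tensor, idx) -> torch.Tensor:
    # after_dense_cb: the batch is now a dense tensor, so elementwise ops are straightforward.
    return torch.log1p(x)

runner = RayTrainRunner(
    MyModel,                        # hypothetical model from the earlier sketch
    DistributedAnnDataset,
    model_keys=["gene_names"],
    before_dense_cb=inspect_sparse,
    after_dense_cb=log1p_after_dense,
)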

protoplast.scrna.anndata.torch_dataloader.AnnDataModule

Bases: LightningDataModule

Source code in src/protoplast/scrna/anndata/torch_dataloader.py
class AnnDataModule(pl.LightningDataModule):
    def __init__(
        self,
        indices: dict,
        dataset: DistributedAnnDataset,
        prefetch_factor: int,
        sparse_key: str,
        shuffle_strategy: ShuffleStrategy,
        before_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        after_dense_cb: Callable[[torch.Tensor, str | int], torch.Tensor] = None,
        override_thread: int | None = None,
        **kwargs,
    ):
        super().__init__()
        self.indices = indices
        self.dataset = dataset
        if override_thread is not None:
            num_threads = override_thread
        else:
            num_threads = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count()))
        self.loader_config = dict(
            num_workers=num_threads,
        )
        if num_threads > 0:
            self.loader_config["prefetch_factor"] = prefetch_factor
            self.loader_config["persistent_workers"] = True
        if shuffle_strategy.is_mixer():
            self.loader_config["batch_size"] = shuffle_strategy.mini_batch_size
            self.loader_config["collate_fn"] = shuffle_strategy.mixer
            self.loader_config["drop_last"] = True
        else:
            self.loader_config["batch_size"] = None
        self.sparse_key = sparse_key
        self.before_dense_cb = before_dense_cb
        self.after_dense_cb = after_dense_cb
        self.kwargs = kwargs

    def setup(self, stage):
        # this is not necessary but it is here in case we want to download data to local node in the future
        if stage == "fit":
            self.train_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, **self.kwargs)
            self.val_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, "val", **self.kwargs)
        if stage == "test":
            self.val_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, "test", **self.kwargs)
        if stage == "predict":
            self.predict_ds = self.dataset.create_distributed_ds(self.indices, self.sparse_key, **self.kwargs)

    def train_dataloader(self):
        return DataLoader(self.train_ds, **self.loader_config)

    def val_dataloader(self):
        return DataLoader(self.val_ds, **self.loader_config)

    def test_dataloader(self):
        # for now not support testing for splitting will support it soon in the future
        return DataLoader(self.val_ds, **self.loader_config)

    def predict_dataloader(self):
        return DataLoader(self.predict_ds, **self.loader_config)

    def densify(self, x, idx: str | int = None):
        if isinstance(x, torch.Tensor):
            if self.before_dense_cb:
                x = self.before_dense_cb(x, idx)
            if x.is_sparse or x.is_sparse_csr:
                x = x.to_dense()
            if self.after_dense_cb:
                x = self.after_dense_cb(x, idx)
        return x

    def on_after_batch_transfer(self, batch, dataloader_idx):
        if (type(batch) is list) or (type(batch) is tuple):
            return [self.densify(d, i) for i, d in enumerate(batch)]
        elif isinstance(batch, dict):
            return {k: self.densify(v, k) for k, v in batch.items()}
        elif isinstance(batch, torch.Tensor):
            return self.densify(batch)
        else:
            return batch

Dataset

For fetching and sending data to the model

protoplast.scrna.anndata.torch_dataloader.DistributedAnnDataset

Bases: IterableDataset

Dataset that supports multi-worker distribution; this version yields the data in a sequential manner.

Parameters:

file_paths : list[str], required
    List of files.
indices : list[list[int]], required
    List of indices from SplitInfo.
metadata : dict, required
    Metadata dictionary used to pass data to the model or for other logic.
sparse_key : str, required
    AnnData key for the sparse matrix; usually "X". For layers, use dot notation, e.g. "layers.attr", where attr is the key of the layer you want to read.
mini_batch_size : int, by default None
    How many observations to send to the model; must be less than batch_size. If not specified, the whole batch is sent.
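A hedged sketch of driving the dataset locally, outside Ray, which can help when debugging a transform: the file path is hypothetical, the trivial metadata callback stands in for a real one, and SequentialShuffleStrategy is used only to produce the SplitInfo the dataset expects.

from protoplast.scrna.anndata.strategy import SequentialShuffleStrategy
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset

strategy = SequentialShuffleStrategy(
    ["data/part_0.h5ad"],              # hypothetical input file
    batch_size=2000,
    total_workers=1,
    metadata_cb=lambda ad, md: None,   # stand-in callback that adds no metadata
    is_shuffled=False,
)
split = strategy.split()
ds = DistributedAnnDataset.create_distributed_ds(split, sparse_key="X", mode="train")
for X in ds:
    print(X.shape)  # sparse CSR tensor: (mini_batch_size, n_vars)
    break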
Source code in src/protoplast/scrna/anndata/torch_dataloader.py
class DistributedAnnDataset(torch.utils.data.IterableDataset):
    """Dataset that support multiworker distribution this version will yield the data
    in a sequential manner

    Parameters
    ----------
    file_paths : list[str]
        List of files
    indices : list[list[int]]
        List of indices from `SplitInfo`
    metadata : dict
        Metadata dictionary for sending data to the model or other logical purposes
    sparse_key : str
        AnnData key for the sparse matrix usually it is "X" if "layers" please use the dot notation for example
        "layers.attr" where attr is the key in the layer you want to refer to
    mini_batch_size : int, optional
        How many observation to send to the model must be less than `batch_size`, by default None
        and will send the whole batch if this is not specified
    """

    def __init__(
        self,
        file_paths: list[str],
        indices: list[list[int]],
        metadata: dict,
        sparse_key: str,
        mini_batch_size: int = None,
        **kwargs,  # FIXME: workaround for PROTO-23
    ):
        # use first file as reference first
        self.files = file_paths
        self.sparse_key = sparse_key
        self.X = None
        self.ad = None
        # map each gene to an index
        for k, v in metadata.items():
            setattr(self, k, v)
        self.metadata = metadata
        self.batches = indices
        self.mini_batch_size = mini_batch_size
        self.counter = 0
        if "random_seed" in kwargs:
            self.random_seed = kwargs["random_seed"]
        else:
            self.random_seed = None

    @classmethod
    def create_distributed_ds(cls, indices: SplitInfo, sparse_key: str, mode: str = "train", **kwargs):
        indices = indices.to_dict() if isinstance(indices, SplitInfo) else indices
        return cls(
            indices["files"],
            indices[f"{mode}_indices"],
            indices["metadata"],
            sparse_key,
            mini_batch_size=indices.get("mini_batch_size"),
            **kwargs,
        )

    def _init_rank(self):
        worker_info = get_worker_info()
        if worker_info is None:
            self.wid = 0
            self.nworkers = 1
        else:
            self.wid = worker_info.id
            self.nworkers = worker_info.num_workers
        try:
            w_rank = td.get_rank()
            w_size = td.get_world_size()
        except ValueError:
            w_rank = -1
            w_size = -1
        if w_rank >= 0:
            self.ray_rank = w_rank
            self.ray_size = w_size
        else:
            self.ray_rank = 0
            self.ray_size = 1
        self.global_rank = self.ray_rank * self.nworkers + self.wid
        self.total_workers = self.ray_size * self.nworkers

    def _process_sparse(self, mat) -> torch.Tensor:
        if sp.issparse(mat):
            return torch.sparse_csr_tensor(
                torch.from_numpy(mat.indptr).long(),
                torch.from_numpy(mat.indices).long(),
                torch.from_numpy(mat.data).float(),
                mat.shape,
            )
        return torch.from_numpy(mat).float()

    def _get_mat_by_range(self, ad: anndata.AnnData, start: int, end: int) -> sp.csr_matrix:
        if self.sparse_key == "X":
            return ad.X[start:end]
        elif "layers" in self.sparse_key:
            _, attr = self.sparse_key.split(".")
            return ad.layers[attr][start:end]
        else:
            raise Exception("Sparse key not supported")

    def transform(self, start: int, end: int):
        """The subclass should implement the logic to get more data for the cell. It can leverage this super function
        to efficiently get X as a sparse tensor. An example of how to get to more data from the cell is
        `self.ad.obs["key"][start:end]` where you must only fetch a subset of this data with `start` and `end`


        Parameters
        ----------
        start : int
            Starting index of this batch
        end : int
            Ending index of this batch

        Returns
        -------
        Any
            Usually a tensor, a list of tensor or dictionary with tensor value
        """
        # by default we just return the matrix
        # sometimes, the h5ad file stores X as the dense matrix,
        # so we have to make sure it is a sparse matrix before returning
        # the batch item
        if self.X is None:
            # we don't have the X upstream, so we have to incurr IO to fetch it
            self.X = self._get_mat_by_range(self.ad, start, end)
        X = self._process_sparse(self.X)
        return X

    def __len__(self):
        try:
            world_size = td.get_world_size()
        except ValueError:
            logging.warning("Not using tdd default to world size 1")
            world_size = 1
        if self.mini_batch_size:
            total_sample = sum(end - start for i in range(len(self.files)) for start, end in self.batches[i])
            return math.ceil(total_sample / self.mini_batch_size / world_size)
        return sum(1 for i in range(len(self.files)) for _ in self.batches[i]) / world_size

    def __iter__(self):
        self._init_rank()
        if self.random_seed:
            logger.debug(f"Counter value: {self.counter}, seed value: {self.random_seed}")
            random.seed(self.random_seed + self.counter)
        for fidx, f in enumerate(self.files):
            self.ad = anndata.read_h5ad(f, backed="r")
            # ensure each epoch have different data order
            random.shuffle(self.batches[fidx])
            total_mini_batches = 0
            if self.mini_batch_size is not None:
                total_mini_batches = sum((end - start) // self.mini_batch_size for start, end in self.batches[fidx])
            else:
                # Treat whole batch as one mini-batch
                self.mini_batch_size = self.batches[fidx][0][1] - self.batches[fidx][0][0]
                total_mini_batches = len(self.batches[fidx])

            # Find range of the batches assigned to this worker
            mini_batch_per_worker = (
                total_mini_batches // self.total_workers
            )  # This number is ALWAYS divisble by total_workers
            mini_batch_per_batch = (
                self.batches[fidx][0][1] - self.batches[fidx][0][0]
            ) // self.mini_batch_size  # Will be 1 if mini_batch_size is None
            if mini_batch_per_batch == 0:
                mini_batch_per_batch = 1  # Handle case when mini_batch_size > batch size

            start_mini_batch_gidx = self.global_rank * mini_batch_per_worker  # a.k.a offset
            end_mini_batch_gidx = start_mini_batch_gidx + mini_batch_per_worker  # exclusive

            start_batch_gidx = start_mini_batch_gidx // mini_batch_per_batch
            end_batch_gidx = end_mini_batch_gidx // mini_batch_per_batch

            # Adjust the index of the first and last mini-batch in the first and last batch respectively
            # Only apply when a batch contains multiple mini-batches
            current_worker_batches = self.batches[fidx][start_batch_gidx : end_batch_gidx + 1]
            if mini_batch_per_batch != 1:
                # Offset the index of first mini-batch
                current_worker_batches[0] = (
                    current_worker_batches[0][0]
                    + (start_mini_batch_gidx % mini_batch_per_batch) * self.mini_batch_size,
                    current_worker_batches[0][1],
                )

                if len(current_worker_batches) > 1:
                    # Offset the index of last mini-batch
                    total_mini_batches_exclude_last = sum(
                        (end - start) // self.mini_batch_size for start, end in current_worker_batches[:-1]
                    )
                    remainder = mini_batch_per_worker - total_mini_batches_exclude_last
                    current_worker_batches[-1] = (
                        current_worker_batches[-1][0],
                        current_worker_batches[-1][0] + remainder * self.mini_batch_size,
                    )

            # NOTE: Black magic to improve read performance during data yielding
            if len(current_worker_batches) > 1:
                current_worker_batches = current_worker_batches[1:] + current_worker_batches[:1]

            yielded_mini_batches = 0
            for i, (start, end) in enumerate(current_worker_batches):
                # Fetch the whole block & start yielding data
                X = self._get_mat_by_range(self.ad, start, end)
                for i in range(0, X.shape[0], self.mini_batch_size):
                    # index on the X coordinates
                    b_start, b_end = i, min(i + self.mini_batch_size, X.shape[0])

                    # index on the adata coordinates
                    global_start, global_end = start + i, min(start + i + self.mini_batch_size, end)
                    self.X = X[b_start:b_end]

                    yield self.transform(global_start, global_end)
                    yielded_mini_batches += 1

                    if yielded_mini_batches >= mini_batch_per_worker:
                        break

                if yielded_mini_batches >= mini_batch_per_worker:
                    break
        self.counter += 1

transform(start: int, end: int)

Subclasses should implement the logic for fetching additional per-cell data. They can call this super method to efficiently get X as a sparse tensor. For example, extra cell data can be read as self.ad.obs["key"][start:end]; only fetch the subset bounded by start and end.

Parameters:

start : int, required
    Starting index of this batch.
end : int, required
    Ending index of this batch.

Returns:

Any
    Usually a tensor, a list of tensors, or a dictionary with tensor values.
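A hedged sketch of a transform override that returns labels alongside X; the obs column "cell_line" is hypothetical and is assumed to be categorical.

import torch
from protoplast.scrna.anndata.torch_dataloader import DistributedAnnDataset

class LabelledAnnDataset(DistributedAnnDataset):
    def transform(self, start: int, end: int):
        # Reuse the parent logic to get X for this slice as a sparse tensor.
        X = super().transform(start, end)
        # Fetch only the [start, end) slice of the hypothetical obs column.
        labels = self.ad.obs["cell_line"].iloc[start:end]
        y = torch.tensor(labels.cat.codes.to_numpy(), dtype=torch.long)
        return X, y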

Source code in src/protoplast/scrna/anndata/torch_dataloader.py
def transform(self, start: int, end: int):
    """The subclass should implement the logic to get more data for the cell. It can leverage this super function
    to efficiently get X as a sparse tensor. An example of how to get to more data from the cell is
    `self.ad.obs["key"][start:end]` where you must only fetch a subset of this data with `start` and `end`


    Parameters
    ----------
    start : int
        Starting index of this batch
    end : int
        Ending index of this batch

    Returns
    -------
    Any
        Usually a tensor, a list of tensor or dictionary with tensor value
    """
    # by default we just return the matrix
    # sometimes, the h5ad file stores X as the dense matrix,
    # so we have to make sure it is a sparse matrix before returning
    # the batch item
    if self.X is None:
        # we don't have the X upstream, so we have to incurr IO to fetch it
        self.X = self._get_mat_by_range(self.ad, start, end)
    X = self._process_sparse(self.X)
    return X

protoplast.scrna.anndata.strategy.ShuffleStrategy

Bases: ABC

Strategy for how the data should be split and shuffled during training.

Parameters:

file_paths : list[str], required
    List of file paths.
batch_size : int, required
    How much data to fetch per batch.
total_workers : int, required
    Total number of workers; equal to the number of processes times the number of threads per process.
test_size : float | None, by default None
    Fraction of data held out for testing; for example, 0.1 means 10% is split off for testing.
validation_size : float | None, by default None
    Fraction of data held out for validation; for example, 0.2 means 20% is split off for validation.
random_seed : int | None, by default 42
    Seed used to randomize the split; set this to None if you want the split to be completely random.
metadata_cb : Callable[[AnnData, dict], None] | None, by default None
    Callback that mutates the metadata dict; recommended for passing data from obs or var, or any additional data your model requires.
is_shuffled : bool, by default True
    Whether to shuffle the data; this option will be deprecated soon.
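A minimal hedged sketch of a custom strategy, only to illustrate the contract: it treats each file as a single training batch and skips validation, test, and any work balancing across total_workers, which a real strategy (such as SequentialShuffleStrategy below) must handle.

import anndata
from protoplast.scrna.anndata.strategy import ShuffleStrategy, SplitInfo

class WholeFileStrategy(ShuffleStrategy):
    def split(self) -> SplitInfo:
        metadata = {}
        train_indices = []
        for path in self.file_paths:
            ad = anndata.read_h5ad(path, backed="r")
            if self.metadata_cb:
                self.metadata_cb(ad, metadata)
            train_indices.append([(0, ad.n_obs)])  # one (start, end) range per file
        return SplitInfo(
            files=self.file_paths,
            train_indices=train_indices,
            val_indices=[[] for _ in self.file_paths],
            test_indices=[[] for _ in self.file_paths],
            metadata=metadata,
        )

    def mixer(self, batch: list):
        # Unused here: the dataset already yields pre-batched tensors.
        raise NotImplementedError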
Source code in src/protoplast/scrna/anndata/strategy.py
class ShuffleStrategy(ABC):
    """Strategy on how to data should be split and shuffle during
    the training

    Parameters
    ----------
    file_paths : list[str]
        List of file paths
    batch_size : int
        How much data to fetch
    total_workers : int
        Total workers this is equal to number of processes times number of threads per process
    test_size : float | None, optional
        Fraction of test data for example 0.1 means 10% will be split for testing, by default None
    validation_size : float | None, optional
        Fraction of validation data for example 0.2 means 20% will be split for validation, by default None
    random_seed : int | None, optional
        Seed to randomize the split set this to None if you want this to be completely random, by default 42
    metadata_cb : Callable[[anndata.AnnData, dict], None] | None, optional
        Callback to mutate metadata recommended for passing data from `obs` or `var`
        or any additional data your models required
        by default cell_line_metadata_cb
    is_shuffled : bool, optional
        Whether to shuffle the data or not this will be deprecated soon, by default True
    """

    def __init__(
        self,
        file_paths: list[str],
        batch_size: int,
        total_workers: int,
        test_size: float | None = None,
        validation_size: float | None = None,
        random_seed: int | None = 42,
        metadata_cb: Callable[[anndata.AnnData, dict], None] | None = None,
        is_shuffled: bool = True,
    ) -> None:
        self.file_paths = file_paths
        self.batch_size = batch_size
        self.total_workers = total_workers
        self.test_size = test_size
        self.validation_size = validation_size
        self.random_seed = random_seed
        self.metadata_cb = metadata_cb
        self.is_shuffled = is_shuffled
        self.rng = random.Random(random_seed) if random_seed else random.Random()

    @staticmethod
    def is_mixer():
        return False

    @abstractmethod
    def split(self) -> SplitInfo:
        """
        How you want to split the data in each worker must return SplitInfo
        """
        pass

    @abstractmethod
    def mixer(self, batch: list) -> any:
        """
        If your Dataset only return 1 sample and not prebatched
        this need to be implemented
        """
        pass

mixer(batch: list) -> any abstractmethod

If your Dataset returns single samples rather than pre-batched data, this must be implemented; it is used as the DataLoader collate function.

Source code in src/protoplast/scrna/anndata/strategy.py
@abstractmethod
def mixer(self, batch: list) -> any:
    """
    If your Dataset only return 1 sample and not prebatched
    this need to be implemented
    """
    pass

split() -> SplitInfo abstractmethod

Defines how the data should be split for each worker; must return a SplitInfo.

Source code in src/protoplast/scrna/anndata/strategy.py
@abstractmethod
def split(self) -> SplitInfo:
    """
    How you want to split the data in each worker must return SplitInfo
    """
    pass

protoplast.scrna.anndata.strategy.SplitInfo dataclass

Source code in src/protoplast/scrna/anndata/strategy.py
@dataclass
class SplitInfo:
    files: list[str]
    train_indices: list[list[int]]
    val_indices: list[list[int]]
    test_indices: list[list[int]]
    metadata: dict[str, any]
    mini_batch_size: int | None = None
    """Information on how to split the data
    this will get pass to the Dataset to know which part of the data
    they need to access

    Parameters
    ----------
    files : list[str]
        List of files
    train_indices : list[list[str]]
        List of indices for training `train_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    val_indices : list[list[str]]
        List of indices for validation `val_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    test_indices : list[list[str]]
        List of indices for testing `test_indices[file_idx][batch_idx]` where `file_idx` must correspond
        to the idx of `files` parameter
    metadata : dict[str, any]
        Data to pass on to the Dataset and model
    mini_batch_size : int | None
        How much data to send to the model
    """

    def to_dict(self) -> dict[str, any]:
        return {
            "files": self.files,
            "train_indices": self.train_indices,
            "val_indices": self.val_indices,
            "test_indices": self.test_indices,
            "metadata": self.metadata,
            "mini_batch_size": self.mini_batch_size,
        }

mini_batch_size: int | None = None class-attribute instance-attribute

Information on how to split the data; this is passed to the Dataset so it knows which part of the data it needs to access.

Parameters:

files : list[str], required
    List of files.
train_indices : list[list[int]], required
    List of training indices, train_indices[file_idx][batch_idx], where file_idx corresponds to the index in files.
val_indices : list[list[int]], required
    List of validation indices, val_indices[file_idx][batch_idx], where file_idx corresponds to the index in files.
test_indices : list[list[int]], required
    List of test indices, test_indices[file_idx][batch_idx], where file_idx corresponds to the index in files.
metadata : dict[str, any], required
    Data to pass on to the Dataset and model.
mini_batch_size : int | None, by default None
    How much data to send to the model.
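A small hedged illustration of the indexing convention with made-up values: each inner entry is a (start, end) row range, and the outer index lines up with files.

from protoplast.scrna.anndata.strategy import SplitInfo

split = SplitInfo(
    files=["a.h5ad", "b.h5ad"],                           # hypothetical files
    train_indices=[[(0, 2000), (2000, 4000)], [(0, 1500)]],
    val_indices=[[(4000, 4500)], []],
    test_indices=[[], []],
    metadata={"gene_names": ["g1", "g2"]},                # hypothetical metadata
    mini_batch_size=500,                                  # rows per yielded mini-batch
)
print(split.to_dict()["train_indices"][0][1])             # (2000, 4000)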

protoplast.scrna.anndata.strategy.SequentialShuffleStrategy

Bases: ShuffleStrategy

Returns the data in sequential order; randomness is not guaranteed, and there is a high chance that samples in a batch will come from nearby rows. Depending on how the AnnData files are ordered, this may affect training accuracy; if this is an issue, you can pre-shuffle the data yourself.

Parameters:

file_paths : list[str], required
    List of file paths.
batch_size : int, required
    How much data to fetch per batch.
total_workers : int, required
    Total number of workers; equal to the number of processes times the number of threads per process.
test_size : float | None, by default None
    Fraction of data held out for testing; for example, 0.1 means 10% is split off for testing.
validation_size : float | None, by default None
    Fraction of data held out for validation; for example, 0.2 means 20% is split off for validation.
random_seed : int | None, by default 42
    Seed used to randomize the split; set this to None if you want the split to be completely random.
metadata_cb : Callable[[AnnData, dict], None] | None, by default None
    Callback that mutates the metadata dict; recommended for passing data from obs or var, or any additional data your model requires.
is_shuffled : bool, by default False
    Whether to shuffle the data; this option will be deprecated soon.
pre_fetch_then_batch : int | None, by default 16
    Prefetch factor; the total amount of data fetched equals pre_fetch_then_batch * batch_size.
drop_last : bool, by default True
    If True, drop the remainder; otherwise duplicate data so that it is evenly distributed across all workers.
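A hedged sketch of tuning this strategy through RayTrainRunner.train(): extra keyword arguments to train() are forwarded to the shuffle-strategy constructor in _setup(), so pre_fetch_then_batch and drop_last can be adjusted per run. The path is hypothetical and runner is the instance from the earlier construction example.

result = runner.train(
    file_paths=["data/part_0.h5ad"],   # hypothetical input file
    batch_size=1000,
    pre_fetch_then_batch=8,            # read 8 * 1000 rows per fetch, then mini-batch
    drop_last=True,                     # drop the remainder so workers stay balanced
)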
Source code in src/protoplast/scrna/anndata/strategy.py
class SequentialShuffleStrategy(ShuffleStrategy):
    """Return the data in a sequential way randomness is not guarantee
    there is a high chance the data will come from nearby rows this might
    affect your training accuracy depending on how the anndata are ordered you can
    overcome this by preshuffling the data manually yourself if this is an issue

    Parameters
    ----------
    file_paths : list[str]
        List of file paths
    batch_size : int
        How much data to fetch
    total_workers : int
        Total workers this is equal to number of processes times number of threads per process
    test_size : float | None, optional
        Fraction of test data for example 0.1 means 10% will be split for testing, by default None
    validation_size : float | None, optional
        Fraction of validation data for example 0.2 means 20% will be split for validation, by default None
    random_seed : int | None, optional
        Seed to randomize the split set this to None if you want this to be completely random, by default 42
    metadata_cb : Callable[[anndata.AnnData, dict], None] | None, optional
        Callback to mutate metadata recommended for passing data from `obs` or `var`
        or any additional data your models required
        by default cell_line_metadata_cb
    is_shuffled : bool, optional
        Whether to shuffle the data or not this will be deprecated soon, by default True
    pre_fetch_then_batch : int | None
        The prefetch factor the total size of data fetch will be equal to `pre_fetch_then_batch * batch_size`
    drop_last : bool
        If there is true drop the remainder, default to True otherwise duplicate the data to make sure the
        data is evenly distributed to all the workers
    """

    def __init__(
        self,
        file_paths: list[str],
        batch_size: int,
        total_workers: int,
        test_size: float | None = None,
        validation_size: float | None = None,
        random_seed: int | None = 42,
        metadata_cb: Callable[[anndata.AnnData, dict], None] | None = None,
        is_shuffled: bool = False,
        pre_fetch_then_batch: int | None = 16,
        drop_last: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(
            file_paths,
            batch_size,
            total_workers,
            test_size,
            validation_size,
            random_seed,
            metadata_cb,
            is_shuffled,
        )
        self.pre_fetch_then_batch = pre_fetch_then_batch
        self.drop_last = drop_last
        self.is_disable_balancing = kwargs.get("is_disable_balancing", False)

    def split(self) -> SplitInfo:
        split_dict = ann_split_data(
            self.file_paths,
            self.batch_size,
            self.total_workers,
            self.test_size,
            self.validation_size,
            self.rng,
            self.metadata_cb,
            self.is_shuffled,
            self.drop_last,
            prefetch_factor=self.pre_fetch_then_batch,
            is_disable_balancing=self.is_disable_balancing,
        )
        # this will be passed to the dataset, inorder to know the mini batch size
        if self.pre_fetch_then_batch:
            split_dict["mini_batch_size"] = self.batch_size
        else:
            # yield everything we read
            split_dict["mini_batch_size"] = None
        return SplitInfo(**split_dict)

    def mixer(self, batch: list):
        return super().mixer(batch)