
Model

The model module can also be accessed as ml; for example, concord.model.ConcordModel is equivalently available as concord.ml.ConcordModel.
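
A minimal import sketch illustrating the alias described above; it assumes both module paths are importable as stated.

```python
from concord.model import ConcordModel
# from concord.ml import ConcordModel   # equivalent import via the "ml" alias
```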

ConcordModel

Bases: Module

A contrastive learning model for domain-aware and covariate-aware latent representations.

This model consists of an encoder, decoder, and optional classifier head. It supports probabilistic augmentation and domain/covariate embeddings.

Attributes:

- domain_embedding_dim (int): Dimensionality of domain embeddings.
- input_dim (int): Input feature dimension.
- augmentation_mask (Dropout): Dropout layer for augmentation masking.
- use_classifier (bool): Whether to include a classifier head.
- use_decoder (bool): Whether to include a decoder head.
- use_importance_mask (bool): Whether to include an importance mask for feature selection.
- encoder (Sequential): Encoder layers.
- decoder (Sequential): Decoder layers.
- classifier (Sequential): Classifier head.
- importance_mask (Parameter): Learnable importance mask.

__init__(input_dim, hidden_dim, num_domains, num_classes, domain_embedding_dim=0, covariate_embedding_dims={}, covariate_num_categories={}, encoder_dims=[], decoder_dims=[], augmentation_mask_prob=0.3, dropout_prob=0.1, norm_type='layer_norm', use_decoder=True, decoder_final_activation='leaky_relu', use_classifier=False, use_importance_mask=False)

Initializes the Concord model.

Parameters:

- input_dim (int): Number of input features. Required.
- hidden_dim (int): Latent representation dimensionality. Required.
- num_domains (int): Number of unique domains for embeddings. Required.
- num_classes (int): Number of unique classes for classification. Required.
- domain_embedding_dim (int): Dimensionality of domain embeddings. Defaults to 0.
- covariate_embedding_dims (dict): Dictionary mapping covariate keys to embedding dimensions. Defaults to {}.
- covariate_num_categories (dict): Dictionary mapping covariate keys to category counts. Defaults to {}.
- encoder_dims (list): List of encoder layer sizes. Defaults to [].
- decoder_dims (list): List of decoder layer sizes. Defaults to [].
- augmentation_mask_prob (float): Dropout probability for the augmentation mask. Defaults to 0.3.
- dropout_prob (float): Dropout probability for encoder/decoder layers. Defaults to 0.1.
- norm_type (str): Normalization type ('layer_norm' or 'batch_norm'). Defaults to 'layer_norm'.
- use_decoder (bool): Whether to include a decoder. Defaults to True.
- decoder_final_activation (str): Activation function for the decoder output. Defaults to 'leaky_relu'.
- use_classifier (bool): Whether to include a classifier head. Defaults to False.
- use_importance_mask (bool): Whether to learn an importance mask over input features. Defaults to False.
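
A minimal construction sketch assembled from the signature above; the feature count, latent size, domain/class counts, and layer sizes are illustrative assumptions, not recommended settings.

```python
from concord.model import ConcordModel

# Hypothetical dimensions, chosen only for illustration.
model = ConcordModel(
    input_dim=2000,          # number of input features (e.g., genes)
    hidden_dim=32,           # latent representation size
    num_domains=4,           # number of unique domains
    num_classes=10,          # number of unique classes
    domain_embedding_dim=8,
    encoder_dims=[512, 128],
    decoder_dims=[128, 512],
    use_decoder=True,
    use_classifier=False,
)
```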

_initialize_weights()

Initializes model weights using Kaiming normal initialization.

forward(x, domain_labels=None, covariate_tensors=None, return_latent=False)

Performs a forward pass through the model.

Parameters:

- x (torch.Tensor): Input data. Required.
- domain_labels (torch.Tensor): Domain labels for embedding lookup. Defaults to None.
- covariate_tensors (dict): Dictionary of covariate labels. Defaults to None.
- return_latent (bool): Whether to return latent layer outputs. Defaults to False.

Returns:

- dict: A dictionary with encoded representations, decoded outputs (if enabled), classifier predictions (if enabled), and latent activations (if requested).
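
A hedged sketch of a forward pass with a model built as in the previous example; the batch size and tensor contents are arbitrary, and the exact keys of the returned dictionary are not assumed.

```python
import torch

x = torch.randn(64, 2000)                   # hypothetical batch of 64 samples x 2000 features
domain_labels = torch.randint(0, 4, (64,))  # one integer domain id per sample

outputs = model(x, domain_labels=domain_labels)
print(sorted(outputs.keys()))               # inspect which outputs (encoded, decoded, ...) were produced
```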

freeze_encoder()

Freezes encoder weights to prevent updates during training.

get_importance_weights()

Retrieves the learned importance weights for input features.

Returns:

- torch.Tensor: The importance weights.

load_model(path, device)

Loads a pre-trained model state.

Parameters:

- path (str or Path): Path to the saved model checkpoint. Required.
- device (torch.device): Device to load the model onto. Required.
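
A short loading sketch based on the signature above; the checkpoint path is a placeholder, and calling load_model on a freshly constructed instance is an assumption about the intended usage.

```python
import torch
from concord.model import ConcordModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConcordModel(input_dim=2000, hidden_dim=32, num_domains=4, num_classes=10)
model.load_model("checkpoints/concord_model.pt", device)   # hypothetical checkpoint path
model.eval()
```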

ConcordSampler

Bases: Sampler

A custom PyTorch sampler that performs probabilistic domain-aware and neighborhood-aware batch sampling for contrastive learning.

This sampler selects samples from both intra-domain and inter-domain distributions based on configurable probabilities.

Attributes:

- batch_size (int): Number of samples per batch.
- p_intra_knn (float): Probability of selecting samples from k-NN neighborhoods.
- p_intra_domain_dict (dict): Dictionary mapping domain indices to intra-domain probabilities.
- device (torch.device): Device on which tensors are stored (default: GPU if available).
- domain_ids (torch.Tensor): Tensor containing domain labels for each sample.
- neighborhood (Neighborhood): Precomputed k-NN index.
- unique_domains (torch.Tensor): Unique domain categories.
- domain_counts (torch.Tensor): Number of samples per domain.
- valid_batches (list): List of precomputed valid batches.
- min_batch_size (int): Minimum allowed batch size.

__init__(batch_size, domain_ids, neighborhood, p_intra_knn=0.3, p_intra_domain_dict=None, min_batch_size=4, device=None)

Initializes the ConcordSampler.

Parameters:

- batch_size (int): Number of samples per batch. Required.
- domain_ids (torch.Tensor): Tensor of domain labels for each sample. Required.
- neighborhood (Neighborhood): Precomputed k-NN index. Required.
- p_intra_knn (float): Probability of selecting samples from k-NN neighborhoods. Defaults to 0.3.
- p_intra_domain_dict (dict): Dictionary mapping domain indices to intra-domain probabilities. Defaults to None.
- min_batch_size (int): Minimum allowed batch size. Defaults to 4.
- device (torch.device): Device on which tensors are stored. Defaults to None, which selects the GPU if available.

__iter__()

Iterator for sampling batches.

Yields:

- torch.Tensor: A batch of sample indices.

__len__()

Returns the number of batches.

Returns:

- int: Number of valid batches.

_generate_batches()

Generates batches based on intra-domain and intra-neighborhood probabilities.

Returns:

- list: A list of valid batches.

permute_nonneg_and_fill(x, ncol) staticmethod

Permutes non-negative values and fills remaining positions with -1.

Parameters:

- x (torch.Tensor): Input tensor containing indices. Required.
- ncol (int): Number of columns to keep. Required.

Returns:

- torch.Tensor: Permuted tensor with -1s filling unused positions.
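
A hedged sketch of constructing the sampler and wiring it to the Neighborhood class documented further below; the embedding, the integer domain encoding, and the import path are illustrative assumptions.

```python
import numpy as np
import torch
from concord.model import ConcordSampler, Neighborhood   # import path assumed

emb = np.random.rand(5000, 50).astype(np.float32)   # hypothetical embedding (e.g., PCA coordinates)
domain_ids = torch.randint(0, 4, (5000,))            # hypothetical integer-encoded domain labels

neighborhood = Neighborhood(emb, k=300, use_faiss=True)
sampler = ConcordSampler(
    batch_size=64,
    domain_ids=domain_ids,
    neighborhood=neighborhood,
    p_intra_knn=0.3,
)

for batch_indices in sampler:    # each iteration yields a tensor of sample indices
    print(batch_indices.shape)
    break
```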

DataLoaderManager

Manages data loading for training and evaluation, including optional sampling.

This class handles embedding computation, k-NN graph construction, domain-aware sampling, and splits data into train/validation sets when needed.

Attributes:

- input_layer_key (str): Key for the input layer in AnnData.
- domain_key (str): Key for domain labels in adata.obs.
- class_key (str): Key for class labels in adata.obs. Defaults to None.
- covariate_keys (list): List of covariate keys in adata.obs. Defaults to None.
- batch_size (int): Batch size for data loading.
- train_frac (float): Fraction of data used for training.
- use_sampler (bool): Whether to use a custom sampler.
- sampler_emb (str): Key for the embeddings used in sampling.
- sampler_knn (int): Number of k-nearest neighbors for sampling.
- p_intra_knn (float): Probability of intra-cluster sampling.
- p_intra_domain (float or dict): Probability of intra-domain sampling.
- min_p_intra_domain (float): Minimum probability for intra-domain sampling.
- max_p_intra_domain (float): Maximum probability for intra-domain sampling.
- clr_mode (str): Contrastive learning mode.
- dist_metric (str): Distance metric for the k-NN graph.
- pca_n_comps (int): Number of PCA components used in embedding computation.
- use_faiss (bool): Whether to use FAISS for fast k-NN computation.
- use_ivf (bool): Whether to use IVF indexing for FAISS.
- ivf_nprobe (int): Number of probes for the FAISS IVF index.
- preprocess (callable): Preprocessing function applied to adata.
- num_cores (int): Number of CPU cores for parallel processing.
- device (torch.device): Device for computation (CPU or CUDA).

__init__(input_layer_key, domain_key, class_key=None, covariate_keys=None, batch_size=32, train_frac=0.9, use_sampler=True, sampler_emb=None, sampler_knn=300, p_intra_knn=0.3, p_intra_domain=None, min_p_intra_domain=1.0, max_p_intra_domain=1.0, clr_mode='aug', dist_metric='euclidean', pca_n_comps=50, use_faiss=True, use_ivf=False, ivf_nprobe=8, preprocess=None, num_cores=None, device=None)

Initializes the DataLoaderManager.

Parameters:

- input_layer_key (str): Key for the input layer in adata. Required.
- domain_key (str): Key for domain labels in adata.obs. Required.
- class_key (str): Key for class labels. Defaults to None.
- covariate_keys (list): List of covariate keys. Defaults to None.
- batch_size (int): Batch size. Defaults to 32.
- train_frac (float): Fraction of data used for training. Defaults to 0.9.
- use_sampler (bool): Whether to use the custom sampler. Defaults to True.
- sampler_emb (str): Key for the embeddings used in sampling. Defaults to None.
- sampler_knn (int): Number of neighbors for k-NN sampling. Defaults to 300.
- p_intra_knn (float): Probability of intra-cluster sampling. Defaults to 0.3.
- p_intra_domain (float or dict): Probability of intra-domain sampling. Defaults to None.
- min_p_intra_domain (float): Minimum probability for intra-domain sampling. Defaults to 1.0.
- max_p_intra_domain (float): Maximum probability for intra-domain sampling. Defaults to 1.0.
- clr_mode (str): Contrastive learning mode. Defaults to 'aug'.
- dist_metric (str): Distance metric for k-NN. Defaults to 'euclidean'.
- pca_n_comps (int): Number of PCA components. Defaults to 50.
- use_faiss (bool): Whether to use FAISS. Defaults to True.
- use_ivf (bool): Whether to use IVF-FAISS indexing. Defaults to False.
- ivf_nprobe (int): Number of probes for IVF-FAISS. Defaults to 8.
- preprocess (callable): Preprocessing function for adata. Defaults to None.
- num_cores (int): Number of CPU cores. Defaults to None.
- device (torch.device): Device for computation. Defaults to None.

anndata_to_dataloader(adata)

Converts an AnnData object to PyTorch DataLoader.

Parameters:

- adata (AnnData): The input AnnData object. Required.

Returns:

- tuple: Train DataLoader, validation DataLoader (if train_frac < 1.0), and the data structure.

compute_embedding_and_knn(emb_key='X_pca')

Constructs a k-NN graph from an existing embedding, or from PCA computed automatically if the embedding does not exist.

Parameters:

- emb_key (str): Key for the embedding basis. Defaults to 'X_pca'.
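
A hedged end-to-end sketch of producing loaders from an AnnData object; the file path, obs column, and import path are assumptions, and only the constructor arguments and anndata_to_dataloader come from the documentation above.

```python
import anndata as ad
from concord.model import DataLoaderManager   # import path assumed

adata = ad.read_h5ad("data.h5ad")              # placeholder path to any AnnData file

manager = DataLoaderManager(
    input_layer_key="X",
    domain_key="batch",        # hypothetical obs column holding domain labels
    batch_size=64,
    train_frac=0.9,
    sampler_emb="X_pca",       # embedding basis used by the sampler
    sampler_knn=300,
)
# Returns the train loader, a validation loader (since train_frac < 1.0), and the data structure.
train_loader, val_loader, data_structure = manager.anndata_to_dataloader(adata)
```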

ChunkLoader

A class for handling large datasets in chunks for efficient training.

This class manages chunk-based data loading for large single-cell datasets, allowing training in smaller subsets without loading the entire dataset into memory.

Attributes:

- adata (AnnData): The annotated data matrix.
- input_layer_key (str): Key for the input layer in adata.
- domain_key (str): Key for domain labels in adata.obs.
- class_key (str): Key for class labels in adata.obs.
- covariate_keys (list): List of keys for covariates in adata.obs.
- chunk_size (int): Number of samples per chunk.
- batch_size (int): Batch size used for training.
- train_frac (float): Fraction of data used for training.
- sampler_mode (str): Sampling mode (e.g., 'domain').
- sampler_knn (int): Number of nearest neighbors for k-NN-based sampling.
- emb_key (str): Key for the embedding space used in sampling.
- use_faiss (bool): Whether to use FAISS for k-NN computation.
- use_ivf (bool): Whether to use an IVF FAISS index.
- ivf_nprobe (int): Number of probes for the IVF FAISS index.
- class_weights (dict): Weights for balancing class sampling.
- p_intra_knn (float): Probability of sampling within k-NN neighborhoods.
- p_intra_domain (float): Probability of sampling within the same domain.
- p_intra_class (float): Probability of sampling within the same class.
- drop_last (bool): Whether to drop the last batch if it is smaller than batch_size.
- preprocess (callable): Preprocessing function applied to the dataset.
- device (torch.device): Device on which data is loaded (CPU or CUDA).
- total_samples (int): Total number of samples in the dataset.
- num_chunks (int): Number of chunks required to cover the full dataset.
- indices (np.ndarray): Array of shuffled indices used for chunking.
- data_structure (list): Structure of the dataset.

Methods:

- __len__: Returns the number of chunks.
- _shuffle_indices: Randomly shuffles dataset indices.
- _load_chunk: Loads a specific chunk of data.
- __iter__: Initializes the chunk iterator.
- __next__: Retrieves the next chunk of data.

__init__(adata, input_layer_key, domain_key, class_key=None, covariate_keys=None, chunk_size=10000, batch_size=32, train_frac=0.9, sampler_mode='domain', emb_key=None, sampler_knn=300, p_intra_knn=0.3, p_intra_domain=1.0, use_faiss=True, use_ivf=False, ivf_nprobe=8, class_weights=None, p_intra_class=0.3, drop_last=True, preprocess=None, device=None)

Initializes the ChunkLoader.

Parameters:

- adata (AnnData): The annotated data matrix. Required.
- input_layer_key (str): Key for the input layer in adata. Required.
- domain_key (str): Key for domain labels in adata.obs. Required.
- class_key (str): Key for class labels in adata.obs. Defaults to None.
- covariate_keys (list): List of covariate keys in adata.obs. Defaults to None.
- chunk_size (int): Number of samples per chunk. Defaults to 10000.
- batch_size (int): Batch size used in training. Defaults to 32.
- train_frac (float): Fraction of data used for training. Defaults to 0.9.
- sampler_mode (str): Sampling mode (e.g., 'domain'). Defaults to 'domain'.
- emb_key (str): Key for the embedding space used in sampling. Defaults to None.
- sampler_knn (int): Number of nearest neighbors for k-NN sampling. Defaults to 300.
- p_intra_knn (float): Probability of sampling within k-NN neighborhoods. Defaults to 0.3.
- p_intra_domain (float): Probability of sampling within the same domain. Defaults to 1.0.
- use_faiss (bool): Whether to use FAISS for k-NN. Defaults to True.
- use_ivf (bool): Whether to use an IVF FAISS index. Defaults to False.
- ivf_nprobe (int): Number of probes for the IVF FAISS index. Defaults to 8.
- class_weights (dict): Dictionary of class weights for balancing. Defaults to None.
- p_intra_class (float): Probability of sampling within the same class. Defaults to 0.3.
- drop_last (bool): Whether to drop the last batch if it is smaller than batch_size. Defaults to True.
- preprocess (callable): Function to preprocess the dataset. Defaults to None.
- device (torch.device): Device on which data is loaded (CPU/GPU). Defaults to None, which selects CUDA if available.

__iter__()

Initializes the chunk iterator.

Returns:

- ChunkLoader: The chunk loader object itself.

__len__()

Returns the total number of chunks.

Returns:

- int: Number of chunks.

__next__()

Retrieves the next chunk of data.

Returns:

- tuple: Training DataLoader, validation DataLoader, and chunk indices.

Raises:

- StopIteration: If all chunks have been iterated over.

_load_chunk(chunk_idx)

Loads a specific chunk of data.

Parameters:

- chunk_idx (int): Index of the chunk to load. Required.

Returns:

- tuple: Training DataLoader, validation DataLoader, and chunk indices.

_shuffle_indices()

Randomly shuffles dataset indices for chunking.
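
A hedged sketch of chunked iteration; the obs column and import path are assumptions, and the three-element unpacking follows the __next__ return value described above.

```python
from concord.model import ChunkLoader   # import path assumed

chunk_loader = ChunkLoader(
    adata,                      # an AnnData object loaded elsewhere
    input_layer_key="X",
    domain_key="batch",         # hypothetical obs column holding domain labels
    chunk_size=10000,
    batch_size=64,
)

for train_loader, val_loader, chunk_indices in chunk_loader:
    for batch in train_loader:
        ...                     # one training step per mini-batch
```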

AnnDataset

Bases: Dataset

A PyTorch Dataset class for handling annotated datasets (AnnData).

This dataset is designed to work with single-cell RNA-seq data stored in AnnData objects. It extracts relevant features, domain labels, class labels, and covariate labels while handling sparse and dense matrices.

Attributes:

- adata (AnnData): The annotated data matrix.
- input_layer_key (str): Key used to retrieve input features from adata.
- domain_key (str): Key in adata.obs specifying domain labels.
- class_key (str): Key in adata.obs specifying class labels.
- covariate_keys (list): List of keys for covariate labels in adata.obs.
- device (torch.device): Device on which tensors are stored (GPU or CPU).
- data (torch.Tensor): Tensor containing the input data.
- domain_labels (torch.Tensor): Tensor containing domain labels.
- class_labels (torch.Tensor): Tensor containing class labels, if provided.
- covariate_tensors (dict): Dictionary of tensors for covariate labels.
- indices (np.ndarray): Array of dataset indices.
- data_structure (list): A list describing the dataset structure.

__getitem__(idx)

Retrieves the dataset items for the given index.

Parameters:

- idx (int or list): The sample index. Required.

Returns:

- tuple: A tuple containing input data, domain labels, class labels, and covariate tensors.

__init__(adata, input_layer_key='X', domain_key='domain', class_key=None, covariate_keys=None, device=None)

Initializes the AnnDataset.

Parameters:

- adata (AnnData): The annotated dataset. Required.
- input_layer_key (str): Key used to extract input data. Defaults to 'X'.
- domain_key (str): Key for domain labels in adata.obs. Defaults to 'domain'.
- class_key (str): Key for class labels in adata.obs. Defaults to None.
- covariate_keys (list): List of keys for covariate labels in adata.obs. Defaults to None.
- device (torch.device): Device on which tensors are stored (GPU or CPU). Defaults to None, which selects the GPU if available.

Raises:

- ValueError: If the domain or class key is not found in adata.obs.
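
A hedged construction sketch based only on the signature above; the obs column and import path are assumptions.

```python
from concord.model import AnnDataset   # import path assumed

dataset = AnnDataset(
    adata,                    # an AnnData object loaded elsewhere
    input_layer_key="X",
    domain_key="batch",       # hypothetical obs column holding domain labels
)
print(len(dataset))           # number of samples
item = dataset[0]             # tuple of input data, domain label, and any class/covariate fields
```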

__len__()

Returns the number of samples in the dataset.

Returns:

- int: The dataset size.

_get_data_matrix()

Retrieves the feature matrix from adata.

Returns:

- np.ndarray: The feature matrix as a NumPy array.

Raises:

- KeyError: If the specified input layer is not found.

_init_data_structure()

Initializes the structure of the dataset.

Returns:

- list: A list defining the dataset structure.

get_class_labels(idx)

Retrieves the class labels for a given index.

Parameters:

- idx (int or list): Index or indices to retrieve. Required.

Returns:

- torch.Tensor: The class labels.

get_data_structure()

Returns the data structure of the dataset.

Returns:

- list: A list defining the dataset structure.

get_domain_labels(idx)

Retrieves the domain labels for a given index.

Parameters:

- idx (int or list): Index or indices to retrieve. Required.

Returns:

- torch.Tensor: The domain labels.

get_embedding(embedding_key, idx)

Retrieves embeddings for a given key and index.

Parameters:

- embedding_key (str): The embedding key in adata.obsm. Required.
- idx (int or list): Index or indices to retrieve. Required.

Returns:

- np.ndarray: The embedding matrix.

Raises:

- ValueError: If the embedding key is not found.

shuffle_indices()

Shuffles dataset indices.

subset(idx)

Creates a subset of the dataset with the given indices.

Parameters:

- idx (list): Indices of the subset. Required.

Returns:

- AnnDataset: A new AnnDataset instance containing only the selected indices.
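
A small follow-on sketch for subsetting and shuffling the dataset constructed above; the index list is arbitrary.

```python
subset_ds = dataset.subset([0, 1, 2, 3])   # new AnnDataset restricted to these indices
dataset.shuffle_indices()                  # shuffles the dataset indices
```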

Trainer

A trainer class for optimizing the Concord model.

This class manages the training and validation of the Concord model, including contrastive learning, classification, and reconstruction losses.

Attributes:

- model (Module): The neural network model being trained.
- data_structure (list): The structure of the dataset used in training.
- device (torch.device): The device on which computations are performed.
- logger (Logger): Logger for recording training progress.
- use_classifier (bool): Whether to use a classification head.
- classifier_weight (float): Weighting factor for the classification loss.
- unique_classes (list): List of unique class labels.
- unlabeled_class (int or None): Label representing unlabeled samples.
- use_decoder (bool): Whether to use a decoder for the reconstruction loss.
- decoder_weight (float): Weighting factor for the reconstruction loss.
- clr_mode (str): Contrastive learning mode ('aug' or 'nn').
- use_clr (bool): Whether contrastive learning is enabled.
- clr_weight (float): Weighting factor for the contrastive loss.
- importance_penalty_weight (float): Weighting factor for the feature importance penalty.
- importance_penalty_type (str): Type of regularization used for the importance penalty.

Methods:

- forward_pass: Computes loss components and model outputs.
- train_epoch: Runs one training epoch.
- validate_epoch: Runs one validation epoch.
- _run_epoch: Handles training or validation for one epoch.
- _compute_averages: Computes average losses over an epoch.

__init__(model, data_structure, device, logger, lr, schedule_ratio, use_classifier=False, classifier_weight=1.0, unique_classes=None, unlabeled_class=None, use_decoder=True, decoder_weight=1.0, clr_mode='aug', clr_temperature=0.5, clr_weight=1.0, importance_penalty_weight=0, importance_penalty_type='L1')

Initializes the Trainer.

Parameters:

- model (Module): The neural network model to train. Required.
- data_structure (list): List defining the structure of the input data. Required.
- device (torch.device): Device on which computations will run. Required.
- logger (Logger): Logger for recording training details. Required.
- lr (float): Learning rate for the optimizer. Required.
- schedule_ratio (float): Learning rate decay factor. Required.
- use_classifier (bool): Whether to use classification. Defaults to False.
- classifier_weight (float): Weight for the classification loss. Defaults to 1.0.
- unique_classes (list): List of unique class labels. Defaults to None.
- unlabeled_class (int or None): Label for unlabeled data. Defaults to None.
- use_decoder (bool): Whether to use a decoder. Defaults to True.
- decoder_weight (float): Weight for the decoder loss. Defaults to 1.0.
- clr_mode (str): Contrastive learning mode ('aug' or 'nn'). Defaults to 'aug'.
- clr_temperature (float): Temperature for the contrastive loss. Defaults to 0.5.
- clr_weight (float): Weight for the contrastive loss. Defaults to 1.0.
- importance_penalty_weight (float): Weight for the importance penalty. Defaults to 0.
- importance_penalty_type (str): Type of penalty ('L1' or 'L2'). Defaults to 'L1'.

forward_pass(inputs, class_labels, domain_labels, covariate_tensors=None)

Performs a forward pass and computes loss components.

Parameters:

- inputs (torch.Tensor): Input feature matrix. Required.
- class_labels (torch.Tensor): Class labels for the classification loss. Required.
- domain_labels (torch.Tensor): Domain labels for batch normalization. Required.
- covariate_tensors (dict): Dictionary of covariate tensors. Defaults to None.

Returns:

- tuple: Loss components (total loss, classification loss, MSE loss, contrastive loss, importance penalty loss).

train_epoch(epoch, train_dataloader)

Runs one epoch of training.

Parameters:

- epoch (int): Current epoch number. Required.
- train_dataloader (DataLoader): Training data loader. Required.

Returns:

- float: Average training loss.

validate_epoch(epoch, val_dataloader)

Runs one epoch of validation.

Parameters:

- epoch (int): Current epoch number. Required.
- val_dataloader (DataLoader): Validation data loader. Required.

Returns:

- float: Average validation loss.
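
A hedged sketch of a training loop assembled from the pieces above; the logger setup, epoch count, and hyperparameter values are assumptions, while the constructor arguments and the train_epoch/validate_epoch calls follow the documented signatures.

```python
import logging
import torch
from concord.model import Trainer   # import path assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger("concord")

trainer = Trainer(
    model=model,                     # a ConcordModel instance
    data_structure=data_structure,   # e.g., as returned by DataLoaderManager.anndata_to_dataloader
    device=device,
    logger=logger,
    lr=1e-3,
    schedule_ratio=0.95,
    clr_mode="aug",
    clr_temperature=0.5,
)

for epoch in range(10):                                     # hypothetical epoch count
    train_loss = trainer.train_epoch(epoch, train_loader)
    if val_loader is not None:
        val_loss = trainer.validate_epoch(epoch, val_loader)
```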

Neighborhood

A class for k-nearest neighbor (k-NN) computation using either FAISS or sklearn.

This class constructs a k-NN index, retrieves neighbors, and computes distances between embeddings.

Attributes:

- emb (np.ndarray): The embedding matrix (converted to float32).
- k (int): Number of nearest neighbors to retrieve.
- use_faiss (bool): Whether to use FAISS for k-NN computation.
- use_ivf (bool): Whether to use IVF indexing in FAISS.
- ivf_nprobe (int): Number of probes for the FAISS IVF index.
- metric (str): Distance metric ('euclidean' or 'cosine').

__init__(emb, k=10, use_faiss=True, use_ivf=False, ivf_nprobe=10, metric='euclidean')

Initializes the Neighborhood class.

Parameters:

- emb (np.ndarray): The embedding matrix. Required.
- k (int): Number of nearest neighbors to retrieve. Defaults to 10.
- use_faiss (bool): Whether to use FAISS for k-NN computation. Defaults to True.
- use_ivf (bool): Whether to use an IVF FAISS index. Defaults to False.
- ivf_nprobe (int): Number of probes for the FAISS IVF index. Defaults to 10.
- metric (str): Distance metric ('euclidean' or 'cosine'). Defaults to 'euclidean'.

Raises:

- ValueError: If the embedding contains NaN values or if the metric is invalid.

_build_knn_index()

Initializes the k-NN index using FAISS or sklearn.

average_knn_distance(core_samples, mtx, k=None, distance_metric='euclidean')

Computes the average distance to the k-th nearest neighbor for each sample.

Parameters:

- core_samples (np.ndarray): Indices of the core samples. Required.
- mtx (np.ndarray): The matrix against which distances are computed. Required.
- k (int): Number of neighbors to retrieve. Defaults to self.k if None.
- distance_metric (str): Distance metric to use ('euclidean', 'set_diff', or 'drop_diff'). Defaults to 'euclidean'.

Returns:

- np.ndarray: The average distance to the k-th nearest neighbor for each sample.

compute_knn_graph(k=None)

Constructs a sparse adjacency matrix for the k-NN graph.

Parameters:

- k (int): Number of neighbors. Defaults to self.k if None.

get_knn(core_samples, k=None, include_self=True, return_distance=False)

Retrieves the k-nearest neighbors for given samples.

Parameters:

- core_samples (np.ndarray or torch.Tensor): Indices of the samples for which k-NN are retrieved. Required.
- k (int): Number of neighbors. Defaults to self.k if None.
- include_self (bool): Whether to include the sample itself. Defaults to True.
- return_distance (bool): Whether to return distances. Defaults to False.

Returns:

- np.ndarray: Indices of the nearest neighbors (and distances if return_distance=True).

get_knn_graph()

Returns the precomputed k-NN graph. Computes it if not available.

Returns:

- scipy.sparse.csr_matrix: Sparse adjacency matrix of shape (n_samples, n_samples).

update_embedding(new_emb)

Updates the embedding matrix and rebuilds the k-NN index.

Parameters:

- new_emb (np.ndarray): The new embedding matrix. Required.
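
A hedged sketch of building and querying a k-NN index; the array sizes are arbitrary, and the calls follow the signatures documented above.

```python
import numpy as np
from concord.model import Neighborhood   # import path assumed

emb = np.random.rand(1000, 50).astype(np.float32)
neighborhood = Neighborhood(emb, k=15, use_faiss=True, metric="euclidean")

knn_indices = neighborhood.get_knn(np.arange(10))   # neighbors of the first 10 samples
knn_graph = neighborhood.get_knn_graph()            # sparse (n_samples, n_samples) adjacency matrix

new_emb = np.random.rand(1000, 50).astype(np.float32)
neighborhood.update_embedding(new_emb)              # rebuilds the index on the new embedding
```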