Model
The module name model can be replaced with ml, e.g., concord.model.ConcordModel can be written as concord.ml.ConcordModel.
ConcordModel
Bases: Module
A contrastive learning model for domain-aware and covariate-aware latent representations.
This model consists of an encoder, decoder, and optional classifier head.
Attributes:
Name | Type | Description |
---|---|---|
domain_embedding_dim | int | Dimensionality of domain embeddings. |
input_dim | int | Input feature dimension. |
use_classifier | bool | Whether to include a classifier head. |
use_decoder | bool | Whether to include a decoder head. |
use_importance_mask | bool | Whether to include an importance mask for feature selection. |
encoder | Sequential | Encoder layers. |
decoder | Sequential | Decoder layers. |
classifier | Sequential | Classifier head. |
importance_mask | Parameter | Learnable importance mask. |
__init__(input_dim, hidden_dim, num_domains, num_classes, domain_embedding_dim=0, covariate_embedding_dims={}, covariate_num_categories={}, encoder_dims=[], decoder_dims=[], dropout_prob=0.0, norm_type='layer_norm', use_decoder=True, decoder_final_activation='leaky_relu', use_classifier=False, use_importance_mask=False)
Initializes the Concord model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim | int | Number of input features. | required |
hidden_dim | int | Latent representation dimensionality. | required |
num_domains | int | Number of unique domains for embeddings. | required |
num_classes | int | Number of unique classes for classification. | required |
domain_embedding_dim | int | Dimensionality of domain embeddings. Defaults to 0. | 0 |
covariate_embedding_dims | dict | Dictionary mapping covariate keys to embedding dimensions. | {} |
covariate_num_categories | dict | Dictionary mapping covariate keys to category counts. | {} |
encoder_dims | list | List of encoder layer sizes. Defaults to empty list. | [] |
decoder_dims | list | List of decoder layer sizes. Defaults to empty list. | [] |
dropout_prob | float | Dropout probability for encoder/decoder layers. Defaults to 0.0. | 0.0 |
norm_type | str | Normalization type ('layer_norm' or 'batch_norm'). Defaults to 'layer_norm'. | 'layer_norm' |
use_decoder | bool | Whether to include a decoder. Defaults to True. | True |
decoder_final_activation | str | Activation function for decoder output. Defaults to 'leaky_relu'. | 'leaky_relu' |
use_classifier | bool | Whether to include a classifier head. Defaults to False. | False |
use_importance_mask | bool | Whether to learn an importance mask for input features. Defaults to False. | False |
_initialize_weights()
Initializes model weights using Kaiming normal initialization.
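As an illustration of Kaiming normal initialization, the sketch below draws weights from a zero-mean normal distribution with standard deviation sqrt(2 / fan_in). This is a pure-Python sketch, not the model's code: the fan-in mode and the ReLU gain of sqrt(2) are assumptions (they match the defaults of torch.nn.init.kaiming_normal_), and the helper name kaiming_normal is hypothetical.

```python
import math
import random


def kaiming_normal(fan_out, fan_in, rng=None):
    """Return a fan_out x fan_in matrix with entries drawn from N(0, 2 / fan_in)."""
    rng = rng or random.Random()
    std = math.sqrt(2.0 / fan_in)  # assumed: fan-in mode with ReLU gain sqrt(2)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


# Hypothetical layer shape, chosen only for the demonstration.
weights = kaiming_normal(fan_out=64, fan_in=256, rng=random.Random(0))
flat = [w for row in weights for w in row]
empirical_std = math.sqrt(sum(w * w for w in flat) / len(flat))
print(f"target std = {math.sqrt(2.0 / 256):.4f}, empirical std = {empirical_std:.4f}")
```

Scaling the variance by the fan-in keeps the magnitude of activations roughly stable from layer to layer at the start of training.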
forward(x, domain_labels=None, covariate_tensors=None)
Performs a forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input data. | required |
domain_labels | Tensor | Domain labels for embedding lookup. | None |
covariate_tensors | dict | Dictionary of covariate labels. | None |
Returns:
Type | Description |
---|---|
dict | A dictionary with encoded representations, decoded outputs (if enabled), classifier predictions (if enabled), and latent activations (if requested). |
freeze_encoder()
Freezes encoder weights to prevent updates during training.
get_embeddings(domain_labels=None, covariate_tensors=None)
Retrieves embeddings for the specified domain labels and covariate tensors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain_labels | Tensor | Domain labels for embedding lookup. | None |
covariate_tensors | dict | Dictionary of covariate tensors. | None |
Returns:
Type | Description |
---|---|
torch.Tensor | Concatenated embeddings. |
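The lookup-and-concatenate behavior described for get_embeddings can be pictured with plain Python lists. This is only a sketch under stated assumptions: the embedding tables, the covariate key "sex", the helper name lookup_and_concat, and the concatenation order (domain first, then covariates in sorted key order) are all hypothetical and are not taken from the model's implementation.

```python
def lookup_and_concat(domain_labels=None, covariate_labels=None,
                      domain_table=None, covariate_tables=None):
    """Look up one vector per label and concatenate them per sample."""
    columns = []  # one list of per-sample vectors for each embedding source
    if domain_labels is not None:
        columns.append([domain_table[d] for d in domain_labels])
    for key in sorted(covariate_labels or {}):
        columns.append([covariate_tables[key][c] for c in covariate_labels[key]])
    # zip(*columns) groups the vectors that belong to the same sample.
    return [[v for vec in parts for v in vec] for parts in zip(*columns)]


# Hypothetical tables: 2 domains with 2-d embeddings, 1 covariate with 1-d embeddings.
domain_table = [[0.1, 0.2], [0.3, 0.4]]
covariate_tables = {"sex": [[1.0], [2.0]]}
embeddings = lookup_and_concat(
    domain_labels=[1, 0],
    covariate_labels={"sex": [0, 1]},
    domain_table=domain_table,
    covariate_tables=covariate_tables,
)
print(embeddings)  # [[0.3, 0.4, 1.0], [0.1, 0.2, 2.0]]
```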
get_importance_weights()
Retrieves the learned importance weights for input features.
Returns:
Type | Description |
---|---|
torch.Tensor | The importance weights. |
load_model(path, device)
Loads a pre-trained model state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str or Path | Path to the saved model checkpoint. | required |
device | device | Device to load the model onto. | required |
ConcordSampler
Bases: Sampler
A custom PyTorch sampler that performs probabilistic domain-aware and neighborhood-aware batch sampling for contrastive learning.
This sampler selects samples from both intra-domain and inter-domain distributions based on configurable probabilities.
__init__(batch_size, domain_ids, neighborhood, p_intra_knn=0.3, p_intra_domain=1.0, min_batch_size=4, domain_minibatch_strategy='proportional', domain_minibatch_min_count=1, domain_coverage=None, sample_with_replacement=False, device=None)
Initializes the ConcordSampler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | Number of samples per batch. | required |
domain_ids | Tensor | Tensor of domain labels for each sample. | required |
neighborhood | Neighborhood | Precomputed k-NN index. | required |
p_intra_knn | float | Probability of selecting samples from k-NN neighborhoods. Default is 0.3. | 0.3 |
p_intra_domain | float | Probability of selecting samples from the same domain. Default is 1.0. | 1.0 |
min_batch_size | int | Minimum allowed batch size. Default is 4. | 4 |
device | device | Device to store tensors. Defaults to GPU if available. | None |
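To make the roles of p_intra_knn and p_intra_domain concrete, here is a minimal pure-Python sketch of one way such a probabilistic draw could work. It is an assumption-laden illustration, not the sampler's actual algorithm: the function sample_batch, the per-slot coin flips, the dictionary-based knn lookup, and sampling with replacement are all hypothetical choices.

```python
import random


def sample_batch(domain, batch_size, domain_ids, knn,
                 p_intra_knn=0.3, p_intra_domain=1.0, rng=None):
    """Draw one batch of sample indices around a random core sample of `domain`."""
    rng = rng or random.Random()
    in_domain = [i for i, d in enumerate(domain_ids) if d == domain]
    out_domain = [i for i, d in enumerate(domain_ids) if d != domain]
    core = rng.choice(in_domain)
    batch = [core]
    while len(batch) < batch_size:
        if rng.random() < p_intra_knn:
            pool = knn[core]          # neighborhood-aware draw
        elif rng.random() < p_intra_domain or not out_domain:
            pool = in_domain          # intra-domain draw
        else:
            pool = out_domain         # inter-domain draw
        batch.append(rng.choice(pool))  # assumed: sampling with replacement
    return batch


# Toy data: six samples in two domains, with a made-up 2-neighbor lookup.
domain_ids = ["a", "a", "a", "b", "b", "b"]
knn = {i: [j for j in range(6) if j != i][:2] for i in range(6)}
batch = sample_batch("a", 4, domain_ids, knn, p_intra_knn=0.0, rng=random.Random(0))
print(batch)
```

With p_intra_knn=0.0 and p_intra_domain=1.0, every index in the batch comes from the chosen domain; raising p_intra_knn shifts draws toward the core sample's neighbors.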
__iter__()
Iterator for sampling batches.
Yields:
Type | Description |
---|---|
torch.Tensor | A batch of sample indices. |
__len__()
Returns the number of batches.
Returns:
Type | Description |
---|---|
int | Number of valid batches. |
_calculate_num_batches_per_domain()
Calculates the number of minibatches to generate for each domain based on the chosen strategy.
Returns:
Type | Description |
---|---|
dict | A dictionary mapping each domain ID to its number of batches. |
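A 'proportional' strategy can be sketched as giving each domain a number of minibatches that scales with its sample count, with a floor so that small domains still appear. The exact rule used by the sampler is not documented here, so the formula below (integer division by the batch size, floored at min_count) and the helper name batches_per_domain are assumptions.

```python
from collections import Counter


def batches_per_domain(domain_ids, batch_size, min_count=1):
    """Map each domain ID to a batch count proportional to its size."""
    counts = Counter(domain_ids)
    return {d: max(min_count, n // batch_size) for d, n in counts.items()}


# Toy example: one large domain and one small domain.
plan = batches_per_domain(["a"] * 100 + ["b"] * 10, batch_size=32)
print(plan)  # {'a': 3, 'b': 1}
```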
_generate_batches()
Generates batches based on intra-domain and intra-neighborhood probabilities.
Returns:
Type | Description |
---|---|
list | A list of valid batches. |
_permute_nonneg_and_fill(x, ncol)
staticmethod
Permutes non-negative values and fills remaining positions with -1.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input tensor containing indices. | required |
ncol | int | Number of columns to keep. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | Permuted tensor with -1s filling unused positions. |
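The permute-and-fill operation can be illustrated row by row with plain lists. This sketch assumes a 2-D input in which negative entries mark empty slots; the real method operates on tensors and may differ in detail.

```python
import random


def permute_nonneg_and_fill(rows, ncol, rng=None):
    """Shuffle the non-negative entries of each row, keep ncol, pad with -1."""
    rng = rng or random.Random()
    out = []
    for row in rows:
        valid = [v for v in row if v >= 0]  # drop the -1 placeholders
        rng.shuffle(valid)                  # random order among valid indices
        kept = valid[:ncol]
        out.append(kept + [-1] * (ncol - len(kept)))
    return out


result = permute_nonneg_and_fill([[3, -1, 7, 5], [-1, -1, 2, -1]], ncol=3,
                                 rng=random.Random(0))
print(result)
```

Every output row has exactly ncol entries, so rows with different numbers of valid indices can still be stacked into one rectangular array.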
DataLoaderManager
Manages data loading for CONCORD, including optional preprocessing and sampling. This class handles the standard workflow of total-count normalization and log1p transformation.
__init__(domain_key, class_key=None, covariate_keys=None, feature_list=None, normalize_total=True, log1p=True, batch_size=32, train_frac=0.9, use_sampler=True, sampler_knn=300, sampler_emb=None, sampler_domain_minibatch_strategy='proportional', domain_coverage=None, p_intra_knn=0.3, p_intra_domain=1.0, dist_metric='euclidean', use_faiss=True, use_ivf=False, ivf_nprobe=8, preload_dense=False, num_workers=None, device=None)
Initializes the DataLoaderManager.
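The total-count normalization and log1p workflow mentioned in the class description can be sketched as follows. This is a pure-Python illustration rather than the manager's code: the target_sum argument and its value are hypothetical, since the signature above exposes only the normalize_total and log1p flags.

```python
import math


def normalize_total_log1p(counts, target_sum=1e4):
    """Scale each row (cell) to target_sum total counts, then apply log(1 + x)."""
    out = []
    for row in counts:
        total = sum(row)
        scale = target_sum / total if total > 0 else 0.0  # leave empty cells at zero
        out.append([math.log1p(v * scale) for v in row])
    return out


# Two toy cells with two genes each; the second cell has no counts.
normalized = normalize_total_log1p([[1, 3], [0, 0]], target_sum=4)
print(normalized)
```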
_get_data_structure()
Determines the structure of the data to be returned by the dataset. This logic is now owned by the manager, not the dataset.
anndata_to_dataloader(adata)
Converts an AnnData object to PyTorch DataLoader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | The input AnnData object. | required |
Returns:
Type | Description |
---|---|
tuple | Train DataLoader, validation DataLoader (if |
ChunkLoader
An iterator that loads and processes chunks of a large AnnData object, often in backed mode, for memory-efficient training.
It yields a tuple of (train_dataloader, val_dataloader, chunk_indices) for each chunk.
__init__(adata, data_manager, chunk_size)
Initializes the ChunkLoader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | The large, potentially backed AnnData object. | required |
data_manager | DataLoaderManager | A pre-initialized manager responsible for turning an AnnData chunk into DataLoaders. | required |
chunk_size | int | The number of observations (cells) per chunk. | required |
__iter__()
Shuffles indices at the beginning of each epoch (iteration).
__len__()
Returns the total number of chunks.
__next__()
Loads the next chunk into memory and processes it.
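The iterator protocol described by __iter__, __len__, and __next__ can be sketched with indices alone. The class below is hypothetical and yields only the chunk indices; the real ChunkLoader additionally builds the train and validation DataLoaders for each chunk. Sorting the indices inside a chunk is an assumption made here because backed reads are typically faster with increasing indices.

```python
import math
import random


class ChunkIndexLoader:
    """Yield shuffled, non-overlapping chunks of observation indices."""

    def __init__(self, n_obs, chunk_size, seed=None):
        self.n_obs = n_obs
        self.chunk_size = chunk_size
        self.rng = random.Random(seed)
        self.indices = list(range(n_obs))
        self.pos = 0

    def __len__(self):
        # The last chunk may be smaller than chunk_size.
        return math.ceil(self.n_obs / self.chunk_size)

    def __iter__(self):
        self.rng.shuffle(self.indices)  # reshuffle at the start of each pass
        self.pos = 0
        return self

    def __next__(self):
        if self.pos >= self.n_obs:
            raise StopIteration
        chunk = sorted(self.indices[self.pos:self.pos + self.chunk_size])
        self.pos += self.chunk_size
        return chunk


loader = ChunkIndexLoader(n_obs=10, chunk_size=4, seed=0)
chunks = list(loader)
print(len(loader), [len(c) for c in chunks])  # 3 [4, 4, 2]
```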
AnnDataset
Bases: Dataset
A PyTorch Dataset class for handling annotated datasets (AnnData).
This dataset is designed to work with single-cell RNA-seq data stored in AnnData objects. It extracts relevant features, domain labels, class labels, and covariate labels while handling sparse and dense matrices.
Attributes:
Name | Type | Description |
---|---|---|
adata | AnnData | The annotated data matrix. |
input_layer_key | str | The key to retrieve input features from |
domain_key | str | The key in |
class_key | str | The key in |
covariate_keys | list | A list of keys for covariate labels in |
device | device | The device to store tensors (GPU or CPU). |
data | Tensor | Tensor containing input data. |
domain_labels | Tensor | Tensor containing domain labels. |
class_labels | Tensor | Tensor containing class labels if provided. |
covariate_tensors | dict | A dictionary containing tensors for covariate labels. |
indices | ndarray | Array of dataset indices. |
__init__(adata, domain_key='domain', class_key=None, covariate_keys=None, preload_dense=False)
Initializes a lightweight AnnDataset that only manages labels and indices.
__len__()
Returns the number of samples in the dataset.
Returns:
Type | Description |
---|---|
int | The dataset size. |
_scipy_to_torch_sparse(matrix)
staticmethod
Converts a Scipy sparse matrix to a PyTorch sparse COO tensor.
get_class_labels(idx)
Retrieves the class labels for a given index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The class labels. |
get_domain_labels(idx)
Retrieves the domain labels for a given index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The domain labels. |
get_embedding(embedding_key, idx)
Retrieves embeddings for a given key and index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embedding_key | str | The embedding key in | required |
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
np.ndarray | The embedding matrix. |
Raises:
Type | Description |
---|---|
ValueError | If the embedding key is not found. |
shuffle_indices()
Shuffles dataset indices.
subset(idx)
Creates a subset of the dataset with the given indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | list | Indices of the subset. | required |
Returns:
Type | Description |
---|---|
AnnDataset | A new AnnDataset instance containing only the selected indices. |
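The idea of a lightweight dataset that manages only labels and indices, with shuffle_indices and subset working on an index list, can be sketched as below. The class IndexedLabels is hypothetical, uses plain lists instead of tensors, and assumes that idx values are absolute positions in the original data.

```python
import random


class IndexedLabels:
    """Hold labels once and expose views of them through an index list."""

    def __init__(self, domain_labels, indices=None):
        self.domain_labels = domain_labels
        self.indices = list(range(len(domain_labels))) if indices is None else list(indices)

    def __len__(self):
        return len(self.indices)

    def get_domain_labels(self, idx):
        if isinstance(idx, int):
            return self.domain_labels[idx]
        return [self.domain_labels[i] for i in idx]

    def shuffle_indices(self, seed=None):
        random.Random(seed).shuffle(self.indices)

    def subset(self, idx):
        # Share the label storage; only the index list differs.
        return IndexedLabels(self.domain_labels, indices=idx)


dataset = IndexedLabels(["a", "b", "a", "c"])
sub = dataset.subset([0, 2])
print(len(dataset), len(sub), sub.get_domain_labels(sub.indices))  # 4 2 ['a', 'a']
```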
Trainer
A trainer class for optimizing the Concord model.
This class manages the training and validation of the Concord model, including contrastive learning, classification, and reconstruction losses.
Attributes:
Name | Type | Description |
---|---|---|
model | Module | The neural network model being trained. |
data_structure | list | The structure of the dataset used in training. |
device | device | The device on which computations are performed. |
logger | Logger | Logger for recording training progress. |
use_classifier | bool | Whether to use a classification head. |
classifier_weight | float | Weighting factor for classification loss. |
unique_classes | list | List of unique class labels. |
unlabeled_class | int or None | Label representing unlabeled samples. |
use_decoder | bool | Whether to use a decoder for reconstruction loss. |
decoder_weight | float | Weighting factor for reconstruction loss. |
clr_weight | float | Weighting factor for contrastive loss. |
importance_penalty_weight | float | Weighting for feature importance penalty. |
importance_penalty_type | str | Type of regularization for importance penalty. |
Methods:
Name | Description |
---|---|
forward_pass | Computes loss components and model outputs. |
train_epoch | Runs one training epoch. |
validate_epoch | Runs one validation epoch. |
_run_epoch | Handles training or validation for one epoch. |
_compute_averages | Computes average losses over an epoch. |
__init__(model, data_structure, device, logger, lr, schedule_ratio, augment=None, use_classifier=False, classifier_weight=1.0, unique_classes=None, unlabeled_class=None, use_decoder=True, decoder_weight=1.0, clr_temperature=0.5, clr_weight=1.0, clr_beta=0.0, importance_penalty_weight=0, importance_penalty_type='L1')
Initializes the Trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The neural network model to train. | required |
data_structure | list | List defining the structure of input data. | required |
device | device | Device on which computations will run. | required |
logger | Logger | Logger for recording training details. | required |
lr | float | Learning rate for the optimizer. | required |
schedule_ratio | float | Learning rate decay factor. | required |
use_classifier | bool | Whether to use classification. Default is False. | False |
classifier_weight | float | Weight for classification loss. Default is 1.0. | 1.0 |
unique_classes | list | List of unique class labels. | None |
unlabeled_class | int or None | Label for unlabeled data. Default is None. | None |
use_decoder | bool | Whether to use a decoder. Default is True. | True |
decoder_weight | float | Weight for decoder loss. Default is 1.0. | 1.0 |
clr_temperature | float | Temperature for contrastive loss. Default is 0.5. | 0.5 |
clr_weight | float | Weight for contrastive loss. Default is 1.0. | 1.0 |
importance_penalty_weight | float | Weight for importance penalty. Default is 0. | 0 |
importance_penalty_type | str | Type of penalty ('L1' or 'L2'). Default is 'L1'. | 'L1' |
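To show what a contrastive temperature such as clr_temperature controls, the sketch below computes a generic InfoNCE-style loss for one anchor. It is not the Trainer's loss: the cosine similarity, the single-positive setup, and the function name info_nce are assumptions chosen for illustration, and all vectors are assumed to be non-zero.

```python
import math


def info_nce(anchor, positive, negatives, temperature=0.5):
    """Cross-entropy of picking the positive among positive + negatives."""
    def cosine(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - logits[0]


anchor, positive, negatives = [1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]]
loss_default = info_nce(anchor, positive, negatives, temperature=0.5)
loss_sharp = info_nce(anchor, positive, negatives, temperature=0.1)
print(loss_default, loss_sharp)
```

A lower temperature sharpens the softmax over similarities, so a well-separated positive yields a smaller loss while hard negatives are penalized more strongly.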
forward_pass(inputs, class_labels, domain_labels, covariate_tensors=None)
Performs a forward pass and computes loss components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | Tensor | Input feature matrix. | required |
class_labels | Tensor | Class labels for classification loss. | required |
domain_labels | Tensor | Domain labels for batch normalization. | required |
covariate_tensors | dict | Dictionary of covariate tensors. | None |
Returns:
Type | Description |
---|---|
tuple | Loss components (total loss, classification loss, MSE loss, contrastive loss, importance penalty loss). |
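One plausible way the loss components and the weights documented for the Trainer fit together is a weighted sum, sketched below. The combination rule and the exact penalty definitions (sum of absolute values for 'L1', sum of squares for 'L2') are assumptions, and the function names are hypothetical.

```python
def importance_penalty(weights, penalty_type="L1"):
    """Regularization term over the importance mask values."""
    if penalty_type == "L1":
        return sum(abs(w) for w in weights)
    if penalty_type == "L2":
        return sum(w * w for w in weights)
    raise ValueError(f"Unknown penalty type: {penalty_type!r}")


def total_loss(contrastive, classification, mse, penalty,
               clr_weight=1.0, classifier_weight=1.0,
               decoder_weight=1.0, importance_penalty_weight=0.0):
    """Weighted sum of the individual loss components (assumed rule)."""
    return (clr_weight * contrastive
            + classifier_weight * classification
            + decoder_weight * mse
            + importance_penalty_weight * penalty)


penalty = importance_penalty([0.5, -1.5], "L1")
print(penalty)  # 2.0
print(total_loss(1.0, 2.0, 3.0, penalty, importance_penalty_weight=0.5))  # 7.0
```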
train_epoch(epoch, train_dataloader)
Runs one epoch of training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
train_dataloader | DataLoader | Training data loader. | required |
Returns:
Type | Description |
---|---|
float | Average training loss. |
validate_epoch(epoch, val_dataloader)
Runs one epoch of validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
val_dataloader | DataLoader | Validation data loader. | required |
Returns:
Type | Description |
---|---|
float | Average validation loss. |
Neighborhood
A class for k-nearest neighbor (k-NN) computation using either FAISS or sklearn.
This class constructs a k-NN index, retrieves neighbors, and computes distances between embeddings.
Attributes:
Name | Type | Description |
---|---|---|
emb | ndarray | The embedding matrix (converted to float32). |
k | int | Number of nearest neighbors to retrieve. |
use_faiss | bool | Whether to use FAISS for k-NN computation. |
use_ivf | bool | Whether to use IVF indexing in FAISS. |
ivf_nprobe | int | Number of probes for FAISS IVF. |
metric | str | Distance metric ('euclidean' or 'cosine'). |
__init__(emb, k=10, use_faiss=True, use_ivf=False, ivf_nprobe=10, metric='euclidean')
Initializes the Neighborhood class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb | ndarray | The embedding matrix. | required |
k | int | Number of nearest neighbors to retrieve. Defaults to 10. | 10 |
use_faiss | bool | Whether to use FAISS for k-NN computation. Defaults to True. | True |
use_ivf | bool | Whether to use IVF FAISS index. Defaults to False. | False |
ivf_nprobe | int | Number of probes for FAISS IVF. Defaults to 10. | 10 |
metric | str | Distance metric ('euclidean' or 'cosine'). Defaults to 'euclidean'. | 'euclidean' |
Raises:
Type | Description |
---|---|
ValueError | If there are NaN values in the embedding or if the metric is invalid. |
_build_knn_index()
Initializes the k-NN index using FAISS or sklearn.
average_knn_distance(core_samples, mtx, k=None, distance_metric='euclidean')
Compute the average distance to the k-th nearest neighbor for each sample.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_samples | ndarray | The indices of core samples. | required |
mtx | ndarray | The matrix to compute the distance to. | required |
k | int | Number of neighbors to retrieve. If None, uses self.k. | None |
distance_metric | str | Distance metric to use: 'euclidean', 'set_diff', or 'drop_diff'. | 'euclidean' |
Returns:
Type | Description |
---|---|
np.ndarray | The average distance to the k-th nearest neighbor for each sample. |
compute_knn_graph(k=None)
Constructs a sparse adjacency matrix for the k-NN graph.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
k | int | Number of neighbors. Defaults to self.k. | None |
get_knn(core_samples, k=None, include_self=True, return_distance=False)
Retrieves the k-nearest neighbors for given samples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_samples | ndarray or Tensor | Indices of samples for which k-NN is retrieved. | required |
k | int | Number of neighbors. Defaults to self.k. | None |
include_self | bool | Whether to include the sample itself. Defaults to True. | True |
return_distance | bool | Whether to return distances. Defaults to False. | False |
Returns:
Type | Description |
---|---|
np.ndarray | Indices of nearest neighbors (and distances if return_distance=True). |
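The effect of include_self and return_distance can be seen in a brute-force k-NN sketch. This is a pure-Python stand-in using Euclidean distance, not the FAISS or sklearn index the class builds; whether the sample itself counts toward k when include_self=True is an assumption here.

```python
import math


def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def get_knn(emb, core_samples, k, include_self=True, return_distance=False):
    """Return the k nearest neighbor indices (and distances) for each core sample."""
    all_idx, all_dist = [], []
    for i in core_samples:
        # Sort candidates by (distance, index); ties are broken by index.
        candidates = sorted(
            (euclidean(emb[i], emb[j]), j)
            for j in range(len(emb))
            if include_self or j != i
        )
        top = candidates[:k]
        all_idx.append([j for _, j in top])
        all_dist.append([d for d, _ in top])
    return (all_idx, all_dist) if return_distance else all_idx


emb = [[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [6.0, 0.0]]
print(get_knn(emb, [0], k=2))                      # [[0, 1]]
print(get_knn(emb, [0], k=2, include_self=False))  # [[1, 2]]
```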
get_knn_graph()
Returns the precomputed k-NN graph. Computes it if not available.
Returns:
Type | Description |
---|---|
scipy.sparse.csr_matrix | Sparse adjacency matrix of shape (n_samples, n_samples). |
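A k-NN adjacency matrix in CSR form consists of three arrays: the stored values, their column indices, and row pointers. The sketch below builds those arrays from a neighbor list in pure Python. Binary edge weights and dropped self-loops are assumptions, and the helper name knn_graph_csr is hypothetical.

```python
def knn_graph_csr(neighbors, n_samples):
    """Build CSR arrays (data, indices, indptr) for a binary k-NN graph."""
    data, indices, indptr = [], [], [0]
    for i in range(n_samples):
        cols = sorted(set(neighbors[i]) - {i})  # unique neighbors, no self-loop
        indices.extend(cols)
        data.extend([1.0] * len(cols))
        indptr.append(len(indices))             # row i ends here
    return data, indices, indptr


# Neighbor lists for three samples (each list may include the sample itself).
data, indices, indptr = knn_graph_csr([[0, 1], [1, 0], [2, 1]], n_samples=3)
print(indices, indptr)  # [1, 0, 1] [0, 1, 2, 3]
```

If SciPy is available, the triple can be passed to scipy.sparse.csr_matrix((data, indices, indptr), shape=(n_samples, n_samples)) to obtain a matrix of the documented shape.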
update_embedding(new_emb)
Updates the embedding matrix and rebuilds the k-NN index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
new_emb | ndarray | The new embedding matrix. | required |