Model
In import paths, model can be replaced with ml; e.g., concord.model.ConcordModel can also be written as concord.ml.ConcordModel.
ConcordModel
Bases: Module
A contrastive learning model for domain-aware and covariate-aware latent representations.
This model consists of an encoder, an optional decoder, and an optional classifier head. It supports probabilistic augmentation as well as domain and covariate embeddings.
Attributes:
Name | Type | Description |
---|---|---|
domain_embedding_dim | int | Dimensionality of domain embeddings. |
input_dim | int | Input feature dimension. |
augmentation_mask | Dropout | Dropout layer for augmentation masking. |
use_classifier | bool | Whether to include a classifier head. |
use_decoder | bool | Whether to include a decoder head. |
use_importance_mask | bool | Whether to include an importance mask for feature selection. |
encoder | Sequential | Encoder layers. |
decoder | Sequential | Decoder layers. |
classifier | Sequential | Classifier head. |
importance_mask | Parameter | Learnable importance mask. |
__init__(input_dim, hidden_dim, num_domains, num_classes, domain_embedding_dim=0, covariate_embedding_dims={}, covariate_num_categories={}, encoder_dims=[], decoder_dims=[], augmentation_mask_prob=0.3, dropout_prob=0.1, norm_type='layer_norm', use_decoder=True, decoder_final_activation='leaky_relu', use_classifier=False, use_importance_mask=False)
Initializes the Concord model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dim | int | Number of input features. | required |
hidden_dim | int | Latent representation dimensionality. | required |
num_domains | int | Number of unique domains for embeddings. | required |
num_classes | int | Number of unique classes for classification. | required |
domain_embedding_dim | int | Dimensionality of domain embeddings. Defaults to 0. | 0 |
covariate_embedding_dims | dict | Dictionary mapping covariate keys to embedding dimensions. | {} |
covariate_num_categories | dict | Dictionary mapping covariate keys to category counts. | {} |
encoder_dims | list | List of encoder layer sizes. Defaults to an empty list. | [] |
decoder_dims | list | List of decoder layer sizes. Defaults to an empty list. | [] |
augmentation_mask_prob | float | Dropout probability for augmentation mask. Defaults to 0.3. | 0.3 |
dropout_prob | float | Dropout probability for encoder/decoder layers. Defaults to 0.1. | 0.1 |
norm_type | str | Normalization type ('layer_norm' or 'batch_norm'). Defaults to 'layer_norm'. | 'layer_norm' |
use_decoder | bool | Whether to include a decoder. Defaults to True. | True |
decoder_final_activation | str | Activation function for decoder output. Defaults to 'leaky_relu'. | 'leaky_relu' |
use_classifier | bool | Whether to include a classifier head. Defaults to False. | False |
use_importance_mask | bool | Whether to learn an importance mask for input features. Defaults to False. | False |
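The augmentation mask is documented as a Dropout layer driven by augmentation_mask_prob. As a rough illustration of what dropout-style feature masking does to a batch, here is a NumPy sketch assuming the standard inverted-dropout behaviour of torch.nn.Dropout in training mode (each entry is zeroed with probability p and survivors are scaled by 1/(1 - p)). The helper augment_mask is hypothetical and not part of Concord.

```python
import numpy as np


def augment_mask(x: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper mimicking inverted dropout on input features."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = rng.random(x.shape) >= p  # each entry survives with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0)


rng = np.random.default_rng(0)
batch = np.ones((4, 10))
view_a = augment_mask(batch, p=0.3, rng=rng)  # two independently masked
view_b = augment_mask(batch, p=0.3, rng=rng)  # views of the same batch
```

Two calls on the same batch give two differently masked views, which is the usual way such masking produces positive pairs in augmentation-based contrastive learning; whether Concord pairs views exactly this way is an assumption.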
_initialize_weights()
Initializes model weights using Kaiming normal initialization.
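Kaiming (He) normal initialization draws weights from a zero-mean normal distribution with std = gain / sqrt(fan_in), where gain = sqrt(2 / (1 + a^2)) for a leaky ReLU with negative slope a. The NumPy sketch below reproduces what torch.nn.init.kaiming_normal_ computes with its defaults (mode='fan_in', nonlinearity='leaky_relu', a=0) for a Linear weight of shape (out_features, in_features); which mode and nonlinearity Concord actually passes is not stated here, so treat those as assumptions.

```python
import math
from typing import Optional

import numpy as np


def kaiming_normal(out_features: int, in_features: int, negative_slope: float = 0.0,
                   rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Sample a (out_features, in_features) weight matrix in fan_in mode."""
    rng = rng if rng is not None else np.random.default_rng()
    gain = math.sqrt(2.0 / (1.0 + negative_slope ** 2))  # leaky_relu gain
    std = gain / math.sqrt(in_features)                   # fan_in = in_features
    return rng.normal(0.0, std, size=(out_features, in_features))
```

In PyTorch itself the in-place equivalent is torch.nn.init.kaiming_normal_(layer.weight).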
forward(x, domain_labels=None, covariate_tensors=None, return_latent=False)
Performs a forward pass through the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input data. | required |
domain_labels | Tensor | Domain labels for embedding lookup. | None |
covariate_tensors | dict | Dictionary of covariate labels. | None |
return_latent | bool | Whether to return latent layer outputs. | False |
Returns:
Type | Description |
---|---|
dict | A dictionary with encoded representations, decoded outputs (if enabled), classifier predictions (if enabled), and latent activations (if requested). |
freeze_encoder()
Freezes encoder weights to prevent updates during training.
get_importance_weights()
Retrieves the learned importance weights for input features.
Returns:
Type | Description |
---|---|
torch.Tensor | The importance weights. |
load_model(path, device)
Loads a pre-trained model state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str or Path | Path to the saved model checkpoint. | required |
device | device | Device to load the model onto. | required |
ConcordSampler
Bases: Sampler
A custom PyTorch sampler that performs probabilistic domain-aware and neighborhood-aware batch sampling for contrastive learning.
This sampler selects samples from both intra-domain and inter-domain distributions based on configurable probabilities.
Attributes:
Name | Type | Description |
---|---|---|
batch_size | int | Number of samples per batch. |
p_intra_knn | float | Probability of selecting samples from k-NN neighborhoods. |
p_intra_domain_dict | dict | Dictionary mapping domain indices to intra-domain probabilities. |
device | device | Device to store tensors (default: GPU if available). |
domain_ids | Tensor | Tensor containing domain labels for each sample. |
neighborhood | Neighborhood | Precomputed k-NN index. |
unique_domains | Tensor | Unique domain categories. |
domain_counts | Tensor | Number of samples per domain. |
valid_batches | list | List of precomputed valid batches. |
min_batch_size | int | Minimum allowed batch size. |
__init__(batch_size, domain_ids, neighborhood, p_intra_knn=0.3, p_intra_domain_dict=None, min_batch_size=4, device=None)
Initializes the ConcordSampler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | Number of samples per batch. | required |
domain_ids | Tensor | Tensor of domain labels for each sample. | required |
neighborhood | Neighborhood | Precomputed k-NN index. | required |
p_intra_knn | float | Probability of selecting samples from k-NN neighborhoods. Default is 0.3. | 0.3 |
p_intra_domain_dict | dict | Dictionary mapping domain indices to intra-domain probabilities. Default is None. | None |
min_batch_size | int | Minimum allowed batch size. Default is 4. | 4 |
device | device | Device to store tensors. Defaults to GPU if available. | None |
__iter__()
Iterator for sampling batches.
Yields:
Type | Description |
---|---|
torch.Tensor | A batch of sample indices. |
__len__()
Returns the number of batches.
Returns:
Type | Description |
---|---|
int | Number of valid batches. |
_generate_batches()
Generates batches based on intra-domain and intra-neighborhood probabilities.
Returns:
Type | Description |
---|---|
list | A list of valid batches. |
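The description leaves the exact procedure open, so the NumPy sketch below shows one plausible way to combine the two probabilities. Assumptions (not confirmed by the source): each domain contributes ceil(count / batch_size) batches; a batch stays within its domain with probability p_intra_domain[d]; a fraction p_intra_knn of each batch comes from the k-NN neighborhood of a random anchor cell and the rest is drawn uniformly from the candidate pool; batches shorter than min_batch_size are dropped. The function generate_batches is hypothetical, not Concord's implementation.

```python
from typing import Dict, List, Optional

import numpy as np


def generate_batches(
    domain_ids: np.ndarray,
    knn: np.ndarray,
    batch_size: int,
    p_intra_knn: float,
    p_intra_domain: Dict[int, float],
    min_batch_size: int = 4,
    rng: Optional[np.random.Generator] = None,
) -> List[np.ndarray]:
    """Hypothetical sketch of domain- and neighborhood-aware batch sampling.

    domain_ids: (n,) integer domain label per sample.
    knn: (n, k) indices of each sample's nearest neighbors (unique per row).
    p_intra_domain: domain -> probability that a batch stays within that domain.
    """
    rng = rng if rng is not None else np.random.default_rng()
    all_idx = np.arange(len(domain_ids))
    batches = []
    for d in np.unique(domain_ids):
        in_domain = all_idx[domain_ids == d]
        for _ in range(int(np.ceil(len(in_domain) / batch_size))):
            # With probability p_intra_domain[d], restrict the batch to domain d.
            intra = rng.random() < p_intra_domain.get(int(d), 1.0)
            pool = in_domain if intra else all_idx
            # Draw part of the batch from the k-NN neighborhood of a random anchor.
            anchor = rng.choice(in_domain)
            nbrs = knn[anchor][np.isin(knn[anchor], pool)]
            n_knn = min(len(nbrs), int(round(p_intra_knn * batch_size)))
            picked = rng.choice(nbrs, size=n_knn, replace=False)
            # Fill the remainder uniformly from the pool, avoiding duplicates.
            rest = np.setdiff1d(pool, picked)
            filler = rng.choice(rest, size=min(len(rest), batch_size - n_knn), replace=False)
            batch = np.concatenate([picked, filler]).astype(np.int64)
            if len(batch) >= min_batch_size:
                batches.append(batch)
    # Shuffle batch order so domains are interleaved across an epoch.
    order = rng.permutation(len(batches))
    return [batches[i] for i in order]
```

With every p_intra_domain value at 1.0 (the documented default for min_p_intra_domain and max_p_intra_domain in DataLoaderManager), each batch in this sketch contains cells from a single domain.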
permute_nonneg_and_fill(x, ncol)
staticmethod
Permutes non-negative values and fills remaining positions with -1.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input tensor containing indices. | required |
ncol | int | Number of columns to keep. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | Permuted tensor with -1s filling unused positions. |
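One reading of this description, treating -1 as a "no index" sentinel, is: for each row, shuffle the non-negative entries, move them to the front, pad with -1, and truncate to ncol columns. The NumPy sketch below implements that interpretation; it is an assumption about the semantics, written against arrays rather than the torch.Tensor the real staticmethod takes.

```python
from typing import Optional

import numpy as np


def permute_nonneg_and_fill(x: np.ndarray, ncol: int,
                            rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Sketch: per row, shuffle the non-negative entries into the leading
    positions, pad the rest with -1, and keep only the first ncol columns."""
    rng = rng if rng is not None else np.random.default_rng()
    out = np.full((x.shape[0], ncol), -1, dtype=x.dtype)
    for i, row in enumerate(x):
        valid = rng.permutation(row[row >= 0])[:ncol]  # shuffled valid indices
        out[i, : len(valid)] = valid
    return out
```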
DataLoaderManager
Manages data loading for training and evaluation, including optional sampling.
This class handles embedding computation, k-NN graph construction, domain-aware sampling, and splits data into train/validation sets when needed.
Attributes:
Name | Type | Description |
---|---|---|
input_layer_key | str | Key for input layer in AnnData. |
domain_key | str | Key for domain labels in |
class_key | str | Key for class labels in |
covariate_keys | list | List of covariate keys in |
batch_size | int | Batch size for data loading. |
train_frac | float | Fraction of data used for training. |
use_sampler | bool | Whether to use a custom sampler. |
sampler_emb | str | Key for embeddings used in sampling. |
sampler_knn | int | Number of k-nearest neighbors for sampling. |
p_intra_knn | float | Probability of intra-cluster sampling. |
p_intra_domain | float or dict | Probability of intra-domain sampling. |
min_p_intra_domain | float | Minimum probability for intra-domain sampling. |
max_p_intra_domain | float | Maximum probability for intra-domain sampling. |
clr_mode | str | Contrastive learning mode. |
dist_metric | str | Distance metric for k-NN graph. |
pca_n_comps | int | Number of PCA components used in embedding computation. |
use_faiss | bool | Whether to use FAISS for fast k-NN computation. |
use_ivf | bool | Whether to use IVF indexing for FAISS. |
ivf_nprobe | int | Number of probes for IVF-Faiss. |
preprocess | callable | Preprocessing function for |
num_cores | int | Number of CPU cores for parallel processing. |
device | device | Device for computation (CPU or CUDA). |
__init__(input_layer_key, domain_key, class_key=None, covariate_keys=None, batch_size=32, train_frac=0.9, use_sampler=True, sampler_emb=None, sampler_knn=300, p_intra_knn=0.3, p_intra_domain=None, min_p_intra_domain=1.0, max_p_intra_domain=1.0, clr_mode='aug', dist_metric='euclidean', pca_n_comps=50, use_faiss=True, use_ivf=False, ivf_nprobe=8, preprocess=None, num_cores=None, device=None)
Initializes the DataLoaderManager.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_layer_key | str | Key for input layer in | required |
domain_key | str | Key for domain labels in | required |
class_key | str | Key for class labels. Defaults to None. | None |
covariate_keys | list | List of covariate keys. Defaults to None. | None |
batch_size | int | Batch size. Defaults to 32. | 32 |
train_frac | float | Fraction of data used for training. Defaults to 0.9. | 0.9 |
use_sampler | bool | Whether to use the custom sampler. Defaults to True. | True |
sampler_emb | str | Key for embeddings used in sampling. | None |
sampler_knn | int | Number of neighbors for k-NN sampling. Defaults to 300. | 300 |
p_intra_knn | float | Probability of intra-cluster sampling. Defaults to 0.3. | 0.3 |
p_intra_domain | float or dict | Probability of intra-domain sampling. | None |
min_p_intra_domain | float | Minimum probability for intra-domain sampling. Defaults to 1.0. | 1.0 |
max_p_intra_domain | float | Maximum probability for intra-domain sampling. Defaults to 1.0. | 1.0 |
clr_mode | str | Contrastive learning mode. Defaults to 'aug'. | 'aug' |
dist_metric | str | Distance metric for k-NN. Defaults to 'euclidean'. | 'euclidean' |
pca_n_comps | int | Number of PCA components. Defaults to 50. | 50 |
use_faiss | bool | Whether to use FAISS. Defaults to True. | True |
use_ivf | bool | Whether to use IVF-Faiss indexing. Defaults to False. | False |
ivf_nprobe | int | Number of probes for IVF-Faiss. Defaults to 8. | 8 |
preprocess | callable | Preprocessing function for | None |
num_cores | int | Number of CPU cores. Defaults to None. | None |
device | device | Device for computation. Defaults to None. | None |
anndata_to_dataloader(adata)
Converts an AnnData object to PyTorch DataLoaders.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | The input AnnData object. | required |
Returns:
Type | Description |
---|---|
tuple | Train DataLoader, validation DataLoader (if |
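The train/validation split governed by train_frac can be pictured as a random partition of sample indices. This NumPy sketch assumes a uniform random split with int(n * train_frac) training samples; the function train_val_split is hypothetical, and Concord's exact rounding and seeding are not documented here.

```python
import numpy as np


def train_val_split(n_samples: int, train_frac: float = 0.9, seed: int = 0):
    """Hypothetical sketch: randomly split sample indices by train_frac."""
    if not 0.0 < train_frac <= 1.0:
        raise ValueError("train_frac must be in (0, 1]")
    perm = np.random.default_rng(seed).permutation(n_samples)
    n_train = int(n_samples * train_frac)
    return perm[:n_train], perm[n_train:]  # (train indices, validation indices)
```

Index arrays like these can be passed to torch.utils.data.Subset(dataset, indices) to build one DataLoader per split.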
compute_embedding_and_knn(emb_key='X_pca')
Constructs a k-NN graph from an existing embedding or, if that embedding does not exist, from a PCA embedding computed automatically.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb_key | str | Key for embedding basis. Defaults to 'X_pca'. | 'X_pca' |
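When the requested embedding is missing, a PCA embedding is used instead (pca_n_comps components by default in DataLoaderManager). For reference, PCA scores can be obtained from the SVD of the feature-centered matrix, as in this NumPy sketch. The helper pca_embedding is hypothetical; Concord may compute PCA differently.

```python
import numpy as np


def pca_embedding(X: np.ndarray, n_comps: int = 50) -> np.ndarray:
    """Return the first n_comps principal-component scores of X (samples x features)."""
    Xc = X - X.mean(axis=0, keepdims=True)             # center each feature
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)  # Xc = U @ diag(S) @ Vt
    n_comps = min(n_comps, Vt.shape[0])
    return U[:, :n_comps] * S[:n_comps]                # equals Xc @ Vt[:n_comps].T
```

The resulting scores are what a k-NN search (see Neighborhood) would then run on.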
ChunkLoader
A class for handling large datasets in chunks for efficient training.
This class manages chunk-based data loading for large single-cell datasets, allowing training in smaller subsets without loading the entire dataset into memory.
Attributes:
Name | Type | Description |
---|---|---|
adata | AnnData | The annotated data matrix. |
input_layer_key | str | Key for the input layer in |
domain_key | str | Key for domain labels in |
class_key | str | Key for class labels in |
covariate_keys | list | List of keys for covariates in |
chunk_size | int | The number of samples per chunk. |
batch_size | int | The batch size used for training. |
train_frac | float | Fraction of data to be used for training. |
sampler_mode | str | Mode for sampling (e.g., 'domain'). |
sampler_knn | int | Number of nearest neighbors for k-NN-based sampling. |
emb_key | str | Key for embedding space used in sampling. |
use_faiss | bool | Whether to use FAISS for k-NN computation. |
use_ivf | bool | Whether to use an IVF FAISS index. |
ivf_nprobe | int | Number of probes for IVF FAISS index. |
class_weights | dict | Weights for balancing class sampling. |
p_intra_knn | float | Probability of sampling within k-NN. |
p_intra_domain | float | Probability of sampling within the same domain. |
p_intra_class | float | Probability of sampling within the same class. |
drop_last | bool | Whether to drop the last batch if it is smaller than batch_size. |
preprocess | callable | Preprocessing function to apply to the dataset. |
device | device | Device on which to load data (CPU or CUDA). |
total_samples | int | Total number of samples in the dataset. |
num_chunks | int | Number of chunks required to load the full dataset. |
indices | ndarray | Array of shuffled indices for chunking. |
data_structure | list | Structure of the dataset. |
Methods:
Name | Description |
---|---|
__len__ | Returns the number of chunks. |
_shuffle_indices | Randomly shuffles dataset indices. |
_load_chunk | Loads a specific chunk of data. |
__iter__ | Initializes the chunk iterator. |
__next__ | Retrieves the next chunk of data. |
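The methods above follow Python's iterator protocol. This NumPy sketch mirrors that structure for index chunks only: num_chunks = ceil(total_samples / chunk_size), indices are reshuffled when iteration starts, and __next__ raises StopIteration after the last chunk. The class ChunkIndexer is hypothetical; the real _load_chunk builds training and validation DataLoaders for each chunk, which the sketch replaces with the raw index array, and reshuffling inside __iter__ is an assumption.

```python
import math

import numpy as np


class ChunkIndexer:
    """Hypothetical sketch of ChunkLoader's iteration protocol over index chunks."""

    def __init__(self, total_samples: int, chunk_size: int, seed: int = 0):
        self.total_samples = total_samples
        self.chunk_size = chunk_size
        self.num_chunks = math.ceil(total_samples / chunk_size)
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(total_samples)
        self.current = 0

    def __len__(self) -> int:
        return self.num_chunks

    def _shuffle_indices(self) -> None:
        self.rng.shuffle(self.indices)  # in-place shuffle of sample indices

    def _load_chunk(self, chunk_idx: int) -> np.ndarray:
        start = chunk_idx * self.chunk_size
        # Copy so later reshuffles do not alter chunks already handed out.
        return self.indices[start:start + self.chunk_size].copy()

    def __iter__(self) -> "ChunkIndexer":
        self._shuffle_indices()  # new random chunk assignment on each pass
        self.current = 0
        return self

    def __next__(self) -> np.ndarray:
        if self.current >= self.num_chunks:
            raise StopIteration
        chunk = self._load_chunk(self.current)
        self.current += 1
        return chunk
```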
__init__(adata, input_layer_key, domain_key, class_key=None, covariate_keys=None, chunk_size=10000, batch_size=32, train_frac=0.9, sampler_mode='domain', emb_key=None, sampler_knn=300, p_intra_knn=0.3, p_intra_domain=1.0, use_faiss=True, use_ivf=False, ivf_nprobe=8, class_weights=None, p_intra_class=0.3, drop_last=True, preprocess=None, device=None)
Initializes the ChunkLoader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | The annotated data matrix. | required |
input_layer_key | str | Key for the input layer in | required |
domain_key | str | Key for domain labels in | required |
class_key | str | Key for class labels in | None |
covariate_keys | list | List of covariate keys in | None |
chunk_size | int | Number of samples per chunk. Default is 10,000. | 10000 |
batch_size | int | Batch size used in training. Default is 32. | 32 |
train_frac | float | Fraction of data for training. Default is 0.9. | 0.9 |
sampler_mode | str | Sampling mode ('domain', etc.). Default is 'domain'. | 'domain' |
emb_key | str | Key for the embedding space used in sampling. Default is None. | None |
sampler_knn | int | Number of nearest neighbors for k-NN sampling. Default is 300. | 300 |
p_intra_knn | float | Probability of sampling within k-NN. Default is 0.3. | 0.3 |
p_intra_domain | float | Probability of sampling within the same domain. Default is 1.0. | 1.0 |
use_faiss | bool | Whether to use FAISS for k-NN. Default is True. | True |
use_ivf | bool | Whether to use an IVF FAISS index. Default is False. | False |
ivf_nprobe | int | Number of probes for IVF FAISS index. Default is 8. | 8 |
class_weights | dict | Dictionary of class weights for balancing. Default is None. | None |
p_intra_class | float | Probability of sampling within the same class. Default is 0.3. | 0.3 |
drop_last | bool | Whether to drop the last batch if it is smaller than batch_size. Default is True. | True |
preprocess | callable | Function to preprocess the dataset. Default is None. | None |
device | device | Device on which to load data (CPU/GPU). Default is CUDA if available. | None |
__iter__()
Initializes the chunk iterator.
Returns:
Type | Description |
---|---|
ChunkLoader | The chunk loader object itself. |
__len__()
Returns the total number of chunks.
Returns:
Type | Description |
---|---|
int | Number of chunks. |
__next__()
Retrieves the next chunk of data.
Returns:
Type | Description |
---|---|
tuple | Training DataLoader, Validation DataLoader, and chunk indices. |
Raises:
Type | Description |
---|---|
StopIteration | If all chunks have been iterated over. |
_load_chunk(chunk_idx)
Loads a specific chunk of data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
chunk_idx | int | Index of the chunk to load. | required |
Returns:
Type | Description |
---|---|
tuple | Training DataLoader, Validation DataLoader, and chunk indices. |
_shuffle_indices()
Randomly shuffles dataset indices for chunking.
AnnDataset
Bases: Dataset
A PyTorch Dataset class for handling annotated datasets (AnnData).
This dataset is designed to work with single-cell RNA-seq data stored in AnnData objects. It extracts relevant features, domain labels, class labels, and covariate labels while handling sparse and dense matrices.
Attributes:
Name | Type | Description |
---|---|---|
adata | AnnData | The annotated data matrix. |
input_layer_key | str | The key to retrieve input features from |
domain_key | str | The key in |
class_key | str | The key in |
covariate_keys | list | A list of keys for covariate labels in |
device | device | The device to store tensors (GPU or CPU). |
data | Tensor | Tensor containing input data. |
domain_labels | Tensor | Tensor containing domain labels. |
class_labels | Tensor | Tensor containing class labels if provided. |
covariate_tensors | dict | A dictionary containing tensors for covariate labels. |
indices | ndarray | Array of dataset indices. |
data_structure | list | A list describing the dataset structure. |
__getitem__(idx)
Retrieves the dataset items for the given index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int or list | The sample index. | required |
Returns:
Type | Description |
---|---|
tuple | A tuple containing input data, domain labels, class labels, and covariate tensors. |
__init__(adata, input_layer_key='X', domain_key='domain', class_key=None, covariate_keys=None, device=None)
Initializes the AnnDataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | The annotated dataset. | required |
input_layer_key | str | Key to extract input data. Defaults to 'X'. | 'X' |
domain_key | str | Key for domain labels in | 'domain' |
class_key | str | Key for class labels in | None |
covariate_keys | list | List of keys for covariate labels in | None |
device | device | Device to store tensors (GPU or CPU). Defaults to GPU if available. | None |
Raises:
Type | Description |
---|---|
ValueError | If domain or class key is not found in |
__len__()
Returns the number of samples in the dataset.
Returns:
Type | Description |
---|---|
int | The dataset size. |
_get_data_matrix()
Retrieves the feature matrix from adata.
Returns:
Type | Description |
---|---|
np.ndarray | The feature matrix as a NumPy array. |
Raises:
Type | Description |
---|---|
KeyError | If the specified input layer is not found. |
_init_data_structure()
Initializes the structure of the dataset.
Returns:
Type | Description |
---|---|
list | A list defining the dataset structure. |
get_class_labels(idx)
Retrieves the class labels for a given index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The class labels. |
get_data_structure()
Returns the data structure of the dataset.
Returns:
Type | Description |
---|---|
list | A list defining the dataset structure. |
get_domain_labels(idx)
Retrieves the domain labels for a given index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The domain labels. |
get_embedding(embedding_key, idx)
Retrieves embeddings for a given key and index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
embedding_key | str | The embedding key in | required |
idx | int or list | Index or indices to retrieve. | required |
Returns:
Type | Description |
---|---|
np.ndarray | The embedding matrix. |
Raises:
Type | Description |
---|---|
ValueError | If the embedding key is not found. |
shuffle_indices()
Shuffles dataset indices.
subset(idx)
Creates a subset of the dataset with the given indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | list | Indices of the subset. | required |
Returns:
Type | Description |
---|---|
AnnDataset | A new AnnDataset instance containing only the selected indices. |
Trainer
A trainer class for optimizing the Concord model.
This class manages the training and validation of the Concord model, including contrastive learning, classification, and reconstruction losses.
Attributes:
Name | Type | Description |
---|---|---|
model | Module | The neural network model being trained. |
data_structure | list | The structure of the dataset used in training. |
device | device | The device on which computations are performed. |
logger | Logger | Logger for recording training progress. |
use_classifier | bool | Whether to use a classification head. |
classifier_weight | float | Weighting factor for classification loss. |
unique_classes | list | List of unique class labels. |
unlabeled_class | int or None | Label representing unlabeled samples. |
use_decoder | bool | Whether to use a decoder for reconstruction loss. |
decoder_weight | float | Weighting factor for reconstruction loss. |
clr_mode | str | Contrastive learning mode ('aug' or 'nn'). |
use_clr | bool | Whether contrastive learning is enabled. |
clr_weight | float | Weighting factor for contrastive loss. |
importance_penalty_weight | float | Weighting for feature importance penalty. |
importance_penalty_type | str | Type of regularization for importance penalty. |
Methods:
Name | Description |
---|---|
forward_pass | Computes loss components and model outputs. |
train_epoch | Runs one training epoch. |
validate_epoch | Runs one validation epoch. |
_run_epoch | Handles training or validation for one epoch. |
_compute_averages | Computes average losses over an epoch. |
__init__(model, data_structure, device, logger, lr, schedule_ratio, use_classifier=False, classifier_weight=1.0, unique_classes=None, unlabeled_class=None, use_decoder=True, decoder_weight=1.0, clr_mode='aug', clr_temperature=0.5, clr_weight=1.0, importance_penalty_weight=0, importance_penalty_type='L1')
Initializes the Trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The neural network model to train. | required |
data_structure | list | List defining the structure of input data. | required |
device | device | Device on which computations will run. | required |
logger | Logger | Logger for recording training details. | required |
lr | float | Learning rate for the optimizer. | required |
schedule_ratio | float | Learning rate decay factor. | required |
use_classifier | bool | Whether to use classification. Default is False. | False |
classifier_weight | float | Weight for classification loss. Default is 1.0. | 1.0 |
unique_classes | list | List of unique class labels. | None |
unlabeled_class | int or None | Label for unlabeled data. Default is None. | None |
use_decoder | bool | Whether to use a decoder. Default is True. | True |
decoder_weight | float | Weight for decoder loss. Default is 1.0. | 1.0 |
clr_mode | str | Contrastive learning mode ('aug', 'nn'). Default is 'aug'. | 'aug' |
clr_temperature | float | Temperature for contrastive loss. Default is 0.5. | 0.5 |
clr_weight | float | Weight for contrastive loss. Default is 1.0. | 1.0 |
importance_penalty_weight | float | Weight for importance penalty. Default is 0. | 0 |
importance_penalty_type | str | Type of penalty ('L1' or 'L2'). Default is 'L1'. | 'L1' |
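To show the role of clr_temperature, here is a NumPy sketch of the widely used NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR for two views of a batch: row i of z1 and row i of z2 form a positive pair, and every other row in the combined batch acts as a negative. The source does not say which contrastive objective Concord implements, so treat this as a stand-in; the function nt_xent is hypothetical.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """NT-Xent loss for paired embeddings z1[i] <-> z2[i], each of shape (n, d)."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors: dot = cosine
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)                    # exclude self-similarity
    # Row i's positive partner is i + n (first half) or i - n (second half).
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    row_max = sim.max(axis=1, keepdims=True)          # stable log-sum-exp
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    return float(np.mean(log_denom - sim[np.arange(2 * n), pos]))
```

A lower temperature sharpens the softmax over similarities, so hard negatives (similar cells that are not the positive) are penalized more strongly.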
forward_pass(inputs, class_labels, domain_labels, covariate_tensors=None)
Performs a forward pass and computes loss components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | Tensor | Input feature matrix. | required |
class_labels | Tensor | Class labels for classification loss. | required |
domain_labels | Tensor | Domain labels for batch normalization. | required |
covariate_tensors | dict | Dictionary of covariate tensors. | None |
Returns:
Type | Description |
---|---|
tuple | Loss components (total loss, classification loss, MSE loss, contrastive loss, importance penalty loss). |
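The total loss is returned alongside its components. A plausible combination, sketched below under two assumptions not confirmed by the source, is a weighted sum of the enabled terms using the Trainer's weight attributes, plus an importance penalty equal to importance_penalty_weight times the L1 norm (or squared L2 norm) of the importance weights. Both helper names are hypothetical.

```python
import numpy as np


def importance_penalty(weights: np.ndarray, penalty_type: str = "L1",
                       weight: float = 0.0) -> float:
    """Hypothetical penalty: weight * sum|w| (L1) or weight * sum w^2 (L2)."""
    if penalty_type == "L1":
        value = np.abs(weights).sum()
    elif penalty_type == "L2":
        value = np.square(weights).sum()
    else:
        raise ValueError(f"unknown penalty type: {penalty_type}")
    return float(weight * value)


def total_loss(loss_classifier: float, loss_mse: float, loss_clr: float,
               loss_penalty: float, classifier_weight: float = 1.0,
               decoder_weight: float = 1.0, clr_weight: float = 1.0,
               use_classifier: bool = False, use_decoder: bool = True,
               use_clr: bool = True) -> float:
    """Hypothetical weighted sum of the enabled loss terms."""
    total = loss_penalty
    if use_classifier:
        total += classifier_weight * loss_classifier
    if use_decoder:
        total += decoder_weight * loss_mse
    if use_clr:
        total += clr_weight * loss_clr
    return total
```

With the documented default importance_penalty_weight of 0, the penalty term in this sketch vanishes.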
train_epoch(epoch, train_dataloader)
Runs one epoch of training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
train_dataloader | DataLoader | Training data loader. | required |
Returns:
Type | Description |
---|---|
float | Average training loss. |
validate_epoch(epoch, val_dataloader)
Runs one epoch of validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
epoch | int | Current epoch number. | required |
val_dataloader | DataLoader | Validation data loader. | required |
Returns:
Type | Description |
---|---|
float | Average validation loss. |
Neighborhood
A class for k-nearest neighbor (k-NN) computation using either FAISS or sklearn.
This class constructs a k-NN index, retrieves neighbors, and computes distances between embeddings.
Attributes:
Name | Type | Description |
---|---|---|
emb | ndarray | The embedding matrix (converted to float32). |
k | int | Number of nearest neighbors to retrieve. |
use_faiss | bool | Whether to use FAISS for k-NN computation. |
use_ivf | bool | Whether to use IVF indexing in FAISS. |
ivf_nprobe | int | Number of probes for FAISS IVF. |
metric | str | Distance metric ('euclidean' or 'cosine'). |
__init__(emb, k=10, use_faiss=True, use_ivf=False, ivf_nprobe=10, metric='euclidean')
Initializes the Neighborhood class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
emb | ndarray | The embedding matrix. | required |
k | int | Number of nearest neighbors to retrieve. Defaults to 10. | 10 |
use_faiss | bool | Whether to use FAISS for k-NN computation. Defaults to True. | True |
use_ivf | bool | Whether to use IVF FAISS index. Defaults to False. | False |
ivf_nprobe | int | Number of probes for FAISS IVF. Defaults to 10. | 10 |
metric | str | Distance metric ('euclidean' or 'cosine'). Defaults to 'euclidean'. | 'euclidean' |
Raises:
Type | Description |
---|---|
ValueError | If there are NaN values in the embedding or if the metric is invalid. |
_build_knn_index()
Initializes the k-NN index using FAISS or sklearn.
average_knn_distance(core_samples, mtx, k=None, distance_metric='euclidean')
Computes the average distance to the k-th nearest neighbor for each sample.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_samples | np.ndarray | The indices of core samples. | required |
mtx | np.ndarray | The matrix to compute the distance to. | required |
k | int | Number of neighbors to retrieve. If None, uses self.k. | None |
distance_metric | str | Distance metric to use: 'euclidean', 'set_diff', or 'drop_diff'. | 'euclidean' |
Returns:
Type | Description |
---|---|
np.ndarray | The average distance to the k-th nearest neighbor for each sample. |
compute_knn_graph(k=None)
Constructs a sparse adjacency matrix for the k-NN graph.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
k | int | Number of neighbors. Defaults to self.k. | None |
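A k-NN graph is typically stored as sparse COO triplets: for each sample, one edge to each of its k nearest neighbors. This brute-force NumPy sketch builds those triplets with Euclidean distance; the resulting arrays can be wrapped with scipy.sparse.csr_matrix((data, (rows, cols)), shape=(n, n)). Unit edge weights, excluded self-loops, a directed (non-symmetrized) graph, and k < n are assumptions; knn_graph_coo is hypothetical.

```python
import numpy as np


def knn_graph_coo(emb: np.ndarray, k: int):
    """Return (rows, cols, data) for a directed k-NN graph without self-loops."""
    n = emb.shape[0]
    sq = (emb ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * emb @ emb.T  # squared Euclidean distances
    np.fill_diagonal(d2, np.inf)                         # forbid self-loops
    nbrs = np.argsort(d2, axis=1)[:, :k]                 # k closest samples per row
    rows = np.repeat(np.arange(n), k)                    # source node of each edge
    cols = nbrs.ravel()                                  # target node of each edge
    data = np.ones(n * k)                                # unit edge weights
    return rows, cols, data
```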
get_knn(core_samples, k=None, include_self=True, return_distance=False)
Retrieves the k-nearest neighbors for given samples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
core_samples | ndarray or Tensor | Indices of samples for which k-NN is retrieved. | required |
k | int | Number of neighbors. Defaults to self.k. | None |
include_self | bool | Whether to include the sample itself. Defaults to True. | True |
return_distance | bool | Whether to return distances. Defaults to False. | False |
Returns:
Type | Description |
---|---|
np.ndarray | Indices of nearest neighbors (and distances if return_distance=True). |
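The parameters above can be illustrated with an exact, brute-force NumPy version supporting both documented metrics ('euclidean' and cosine distance 1 - cos). One assumption here: with include_self=True, the query point is counted as one of the k results and placed first at distance 0. Concord's FAISS/sklearn backends may order ties or handle self-matches differently, and the function get_knn below is a hypothetical standalone sketch, not the method itself.

```python
import numpy as np


def get_knn(emb, core_samples, k=10, include_self=True, return_distance=False,
            metric="euclidean"):
    """Brute-force k-NN sketch; with include_self=True the query counts as one of the k."""
    emb = np.asarray(emb, dtype=np.float32)
    core = np.asarray(core_samples, dtype=np.int64)
    if metric == "euclidean":
        diff = emb[core][:, None, :] - emb[None, :, :]
        dist = np.sqrt((diff ** 2).sum(axis=-1))
    elif metric == "cosine":
        unit = emb / np.linalg.norm(emb, axis=1, keepdims=True)
        dist = 1.0 - unit[core] @ unit.T
    else:
        raise ValueError(f"invalid metric: {metric}")
    rows = np.arange(len(core))
    dist[rows, core] = np.inf                 # rank the query point separately
    n_other = k - 1 if include_self else k
    order = np.argsort(dist, axis=1, kind="stable")[:, :n_other]
    d = np.take_along_axis(dist, order, axis=1)
    if include_self:                          # put the query itself first, at distance 0
        order = np.hstack([core[:, None], order])
        d = np.hstack([np.zeros((len(core), 1), dtype=d.dtype), d])
    return (order, d) if return_distance else order
```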
get_knn_graph()
Returns the precomputed k-NN graph. Computes it if not available.
Returns:
Type | Description |
---|---|
scipy.sparse.csr_matrix | Sparse adjacency matrix of shape (n_samples, n_samples). |
update_embedding(new_emb)
Updates the embedding matrix and rebuilds the k-NN index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
new_emb | ndarray | The new embedding matrix. | required |