Interpreting the CONCORD latent space with gradient-based attribution
In this tutorial, I will demonstrate how to apply gradient-based attribution analysis to derive feature contributions to the CONCORD latent space, using a simple PBMC dataset as an example.
Load packages and data
# Load required packages
import concord as ccd
import scanpy as sc
import torch
from pathlib import Path
save_dir = Path("./save/")
save_dir.mkdir(parents=True, exist_ok=True)
# Load and prepare example data
adata = sc.datasets.pbmc3k_processed()
adata = adata.raw.to_adata() # Assume starting from raw counts (in this example dataset .raw is not integer counts, which triggers the seurat_v3 warning below)
# (Optional) Select top variably expressed/accessible features for analysis (other methods besides seurat_v3 available)
feature_list = ccd.ul.select_features(adata, n_top_features=2000, flavor='seurat_v3')
sc.pp.normalize_total(adata) # Normalize counts per cell
sc.pp.log1p(adata) # Log-transform data
concord.utils.feature_selector - INFO - Selecting highly variable features with flavor seurat_v3...
/opt/anaconda3/envs/concord/lib/python3.12/site-packages/scanpy/preprocessing/_highly_variable_genes.py:74: UserWarning: `flavor='seurat_v3'` expects raw count data, but non-integers were found.
  warnings.warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
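For reference, the two preprocessing calls above (`sc.pp.normalize_total` followed by `sc.pp.log1p`) can be approximated in plain NumPy. This is a minimal sketch assuming a dense matrix and scanpy's default `target_sum=None`, which scales every cell to the median total count of non-empty cells; `normalize_and_log` is a hypothetical helper, not part of scanpy or CONCORD.

```python
import numpy as np

def normalize_and_log(counts, target_sum=None):
    """Approximate sc.pp.normalize_total + sc.pp.log1p for a dense (n_cells, n_genes) array."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)          # total counts per cell
    if target_sum is None:
        target_sum = np.median(totals[totals > 0])      # median size of non-empty cells
    safe_totals = np.where(totals > 0, totals, 1.0)     # empty cells stay all-zero
    return np.log1p(counts / safe_totals * target_sum)  # natural log(1 + x)

# Two toy cells with totals 2 and 8: after scaling, both sum to the median total, 5
counts = np.array([[1, 1], [3, 5]])
logged = normalize_and_log(counts)
print(np.expm1(logged).sum(axis=1))
```

Undoing the log with `np.expm1` recovers the scaled counts, so each row should sum to the common target.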
Run CONCORD
# Use the GPU if torch is set up for CUDA, otherwise the CPU; on a Mac you can also use torch.device('mps')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Initialize Concord with an AnnData object, skip input_feature to use all features, set preload_dense=False if your data is very large
cur_ccd = ccd.Concord(adata=adata, input_feature=feature_list, device=device, preload_dense=True)
# To integrate across batches, provide domain_key (a column in adata.obs that contains the batch label):
# cur_ccd = ccd.Concord(adata=adata, input_feature=feature_list, domain_key='batch', device=device, preload_dense=True)
# Encode data, saving the latent embedding in adata.obsm['Concord']
cur_ccd.fit_transform(output_key='Concord')
concord - WARNING - domain/batch information not found, all samples will be treated as from single domain/batch.
concord - WARNING - Only one domain found in the data. Setting p_intra_domain to 1.0.
Epoch 0 Training
Epoch 1 Training
Epoch 2 Training
Epoch 3 Training
Epoch 4 Training
Epoch 5 Training
Epoch 6 Training
Epoch 7 Training
Epoch 8 Training
Epoch 9 Training
Epoch 10 Training
Epoch 11 Training
Epoch 12 Training
Epoch 13 Training
Epoch 14 Training
CONCORD UMAP
ccd.ul.run_umap(adata, source_key='Concord', result_key='Concord_UMAP', n_components=2, n_neighbors=30, min_dist=0.1, metric='euclidean')
# Plot the UMAP embeddings
adata.obs['cell_type'] = adata.obs['louvain'].copy() # Use original cell type annotations
color_by = ['n_genes', 'cell_type'] # Choose which variables you want to visualize
ccd.pl.plot_embedding(
adata, basis='Concord_UMAP', color_by=color_by, figsize=(10, 5), dpi=600, ncols=2, font_size=6, point_size=10, legend_loc='on data',
save_path=save_dir / 'Concord_UMAP.png'
)
/opt/anaconda3/envs/concord/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism. warn(
Examine the CONCORD latent space
You can plot a heatmap of the CONCORD latent embedding with the example code below.
import numpy as np
show_basis = 'Concord'
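ncells = min(1000, adata.n_obs) # downsample to 1000 cells for faster visualization
rng = np.random.default_rng(0) # fixed seed so the subsample is reproducible
adata_ds = adata[rng.choice(adata.n_obs, ncells, replace=False)].copy() # subset first, then copy only the subset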
ccd.pl.heatmap_with_annotations(adata_ds, val=show_basis, transpose=True, obs_keys=['cell_type'],
cmap='viridis', vmin=None, vmax=None,
cluster_rows=True, cluster_cols=True, add_color_legend=True,
value_annot=False, title=None, title_fontsize=8, annot_fontsize=8,
yticklabels=False, xticklabels=False,
use_clustermap=True,
cluster_method='ward',
cluster_metric='euclidean',
rasterize=True,
ax=None,
figsize=(8,4),
dpi=600, show=True, save_path=save_dir / f"heatmap_latent.pdf")
We can observe a clear block structure in the latent space. To identify which input features (genes) drive the activation of specific latent neurons, we first need to define a context. A context can be a single cell or a broader group such as a cell type, but it should not encompass all cells in the dataset, since attribution requires a localized reference.
For each context, we determine which neurons are most strongly activated. One systematic way to do this is to use clustering results, which group cells into biologically meaningful contexts and reveal their characteristic activation patterns. In the example below, we take a simpler approach: for each cell type, we select the top two most highly activated neurons. Our goal is then to explain which genes are responsible for driving those neurons’ activation within the chosen context.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
cts = adata.obs['cell_type']
Z = np.asarray(adata.obsm['Concord']) # (n_cells, n_dims)
# Dims x cell_types mean matrix
means = (
pd.DataFrame(Z, index=adata.obs.index)
.groupby(cts, observed=True).mean() # Mean per cell_type
.T # Dims x cell_types
.reindex(columns=cts.cat.categories) # Keep category order
)
# Compute top_dims dict
N = 2
top_dims = {
ct: means[ct].nlargest(N).index.tolist()
for ct in means.columns[means.notna().any(axis=0)] # Skip empty categories
}
# Print results
print("Top activated latent dimensions per cell type:")
for ct, dims in top_dims.items():
    print(f"{ct}: {dims}")
# Take top-N dims per cell type, unioned in order of first appearance (for heatmap)
keep, seen = [], set()
for ct in means.columns[means.notna().any(axis=0)]:
    for d in top_dims[ct]: # Top dims for this cell type
        if d not in seen:
            keep.append(d)
            seen.add(d)
top_means = means.loc[keep]
plt.figure(figsize=(4, 5))
ax = sns.heatmap(top_means, cmap='RdBu_r', cbar_kws={'label': 'Mean activation'})
ax.set_xlabel('Cell type')
ax.set_ylabel('Embedding dimension')
ax.set_title('Top activated embedding dims per cell type')
plt.tight_layout()
plt.show()
Top activated latent dimensions per cell type:
CD4 T cells: [50, 86]
CD14+ Monocytes: [89, 68]
B cells: [14, 68]
CD8 T cells: [45, 38]
NK cells: [26, 12]
FCGR3A+ Monocytes: [12, 96]
Dendritic cells: [96, 68]
Megakaryocytes: [83, 90]
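The `keep`/`seen` loop above is an ordered union of the per-cell-type lists. A more compact equivalent, shown here on a toy matrix, relies on `dict.fromkeys` preserving insertion order. `top_dims_per_group` is a hypothetical helper, not a CONCORD function; unlike the original cell it drops empty categories through `observed=True` instead of reindexing and skipping NaN columns.

```python
import numpy as np
import pandas as pd

def top_dims_per_group(Z, labels, n_top=2):
    """Return {group: top-n latent dims by mean activation} and their ordered union."""
    labels = pd.Series(pd.Categorical(labels))
    means = pd.DataFrame(Z).groupby(labels, observed=True).mean().T  # dims x groups
    top = {g: means[g].nlargest(n_top).index.tolist() for g in means.columns}
    # dict.fromkeys drops duplicates while keeping first-seen order
    union = list(dict.fromkeys(d for dims in top.values() for d in dims))
    return top, union

# Toy latent matrix: 4 cells x 3 dims, two groups
Z = np.array([[5, 1, 0], [4, 2, 0], [0, 1, 6], [2, 0, 5]])
top, union = top_dims_per_group(Z, ["A", "A", "B", "B"])
print(top, union)
```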
Gradient-based attribution
# Reload CONCORD model from disk if necessary:
cur_ccd = ccd.Concord.load(model_dir=save_dir)
For each context, we compute an importance matrix using a gradient-based attribution method. In this tutorial, the context is defined as a cell type. Concretely, we input the data corresponding to all cells of a given type, compute attribution scores for each cell, and then average these scores across cells. The result is a context-specific importance profile that highlights which genes most strongly influence neuron activations for that cell type.
import scipy.sparse as sp
all_cts = adata.obs['cell_type'].cat.categories.tolist()
importance_results = {}
layer_index = 6 # The final layer index is 6 for Concord with default settings (if you did not change the number of layers when initializing Concord)
for ct in all_cts:
    adata_subset = adata[adata.obs['cell_type'] == ct, cur_ccd.config.input_feature].copy()
    # Densify only if the matrix is sparse, and cast to float32 to match the model weights
    X = adata_subset.X.toarray() if sp.issparse(adata_subset.X) else adata_subset.X
    input_tensors = torch.tensor(X, dtype=torch.float32).to(cur_ccd.config.device)
    importance_matrix = ccd.ul.compute_feature_importance(cur_ccd.model, input_tensors, layer_index=layer_index)
    importance_results[ct] = importance_matrix
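The internals of `ccd.ul.compute_feature_importance` are not shown in this tutorial. As an illustration only, the sketch below implements one common gradient-based rule, gradient × input, for a hypothetical one-hidden-layer ReLU encoder with analytic gradients, and averages the per-cell scores over a context as described above. The function name, the toy weights, and the `(n_latent, n_genes)` output orientation are assumptions; CONCORD may use a different rule, normalization, or layout. The toy usage also shows one reason to restrict the context: a ReLU unit that is inactive in one group contributes zero gradient there, so pooling all cells dilutes the attribution.

```python
import numpy as np

def grad_x_input_importance(X, W1, b1, W2):
    """Mean gradient-x-input attribution for a toy encoder z = relu(X @ W1 + b1) @ W2.

    X: (n_cells, n_genes). Returns (n_latent, n_genes); entry [k, j] is the
    average over cells of x_ij * dz_ik / dx_ij.
    """
    pre = X @ W1 + b1                                   # (n_cells, n_hidden)
    active = (pre > 0).astype(float)                    # ReLU derivative per cell and hidden unit
    # dz_ik / dx_ij = sum_h W1[j, h] * active[i, h] * W2[h, k]
    grads = np.einsum("jh,ih,hk->ijk", W1, active, W2)  # (n_cells, n_genes, n_latent)
    attributions = grads * X[:, :, None]                # gradient x input
    return attributions.mean(axis=0).T                  # average over cells -> (n_latent, n_genes)

# Toy weights (hypothetical): z_k = relu(x_k - 1)
W1, b1, W2 = np.eye(2), np.array([-1.0, -1.0]), np.eye(2)
X_a = np.array([[3.0, 0.0], [3.0, 0.0]])  # context A: only gene 0 expressed
X_b = np.array([[0.0, 3.0], [0.0, 3.0]])  # context B: only gene 1 expressed
imp_a = grad_x_input_importance(X_a, W1, b1, W2)
imp_all = grad_x_input_importance(np.vstack([X_a, X_b]), W1, b1, W2)
print(imp_a)    # neuron 0 <- gene 0 scores 3.0 within context A
print(imp_all)  # pooling both contexts halves it to 1.5
```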
Once the importance matrix is obtained, we can rank genes by their attribution scores within each context and for each neuron:
# Filter thresholds: keep genes detected in more than 3% of cells and with positive expression
min_nonzero_frac = 0.03
min_expression_level = 0
# Filter function
def filter_genes(df, min_nonzero_frac, min_expression_level):
    return df[(df["Nonzero Fraction"] > min_nonzero_frac) & (df["Expression Level"] > min_expression_level)]
ranked_gene_lists = {}
for ct in all_cts:
    adata_subset = adata[adata.obs['cell_type'] == ct, cur_ccd.config.input_feature].copy()
    ranked_lists = ccd.ul.prepare_ranked_list(importance_results[ct], adata=adata_subset, expr_level=True)
    # Apply the filter to every neuron's ranked list in this context
    ranked_gene_lists[ct] = {key: filter_genes(df, min_nonzero_frac, min_expression_level)
                             for key, df in ranked_lists.items()}
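To make the ranking and filtering step concrete, here is a hypothetical stand-in for a ranked-list builder on toy data. It assumes the importance matrix is laid out as `(n_latent, n_genes)` and that "Expression Level" means mean expression per gene; the column names and `Neuron {k}` keys mirror those used elsewhere in this tutorial, but `ccd.ul.prepare_ranked_list` may compute them differently.

```python
import numpy as np
import pandas as pd

def rank_genes_per_neuron(importance, X, genes):
    """Return {'Neuron k': DataFrame sorted by Importance, descending} for one context."""
    nonzero_frac = (X > 0).mean(axis=0)  # fraction of cells expressing each gene
    mean_expr = X.mean(axis=0)           # assumed meaning of "Expression Level"
    ranked = {}
    for k, scores in enumerate(importance):
        df = pd.DataFrame({"Gene": genes, "Importance": scores,
                           "Nonzero Fraction": nonzero_frac, "Expression Level": mean_expr})
        ranked[f"Neuron {k}"] = df.sort_values("Importance", ascending=False).reset_index(drop=True)
    return ranked

def filter_genes(df, min_nonzero_frac=0.03, min_expression_level=0):
    # Keep genes detected in enough cells and with positive mean expression
    return df[(df["Nonzero Fraction"] > min_nonzero_frac) & (df["Expression Level"] > min_expression_level)]

# Toy context: 2 cells x 3 genes, one latent neuron; G2 is never expressed
importance = np.array([[0.1, 0.9, 0.5]])
X = np.array([[1.0, 0.0, 2.0], [0.0, 0.0, 2.0]])
ranked = rank_genes_per_neuron(importance, X, ["G1", "G2", "G3"])
print(filter_genes(ranked["Neuron 0"]))
```

G2 has the highest raw importance but is dropped by the filter because it is not expressed in any cell, which is the kind of spurious hit the detection threshold is meant to remove.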
We can visualize the top-ranked genes that drive neuron activation within each cell type. Note that these gene rankings are only interpretable for neurons that are strongly activated in the given context—weakly activated neurons may not yield meaningful gene associations.
# show_cts = all_cts
show_cts = ['CD4 T cells', 'CD8 T cells', 'CD14+ Monocytes']
for ct in show_cts:
    print(f"Processing cell type: {ct}")
    adata_subset = adata[adata.obs['cell_type'] == ct, cur_ccd.config.input_feature].copy()
    show_neurons = top_dims[ct] # Strongly activated neurons in this context
    show_gene_lists = ranked_gene_lists[ct]
    show_gene_lists = {key: show_gene_lists[f'Neuron {key}'] for key in show_neurons}
    show_basis = 'Concord_UMAP'
    ccd.pl.plot_top_genes_embedding(adata_subset, show_gene_lists, show_basis, top_x=8, figsize=(7.5, 1), point_size=3,
                                    font_size=7, colorbar_loc=None, vmax_quantile=.99,
                                    save_path=save_dir / f"{ct}_embeddings_{show_basis}")
Processing cell type: CD4 T cells
Plotting Top 8 genes for 50 on Concord_UMAP
Plotting Top 8 genes for 86 on Concord_UMAP
Processing cell type: CD8 T cells
Plotting Top 8 genes for 45 on Concord_UMAP
Plotting Top 8 genes for 38 on Concord_UMAP
Processing cell type: CD14+ Monocytes
Plotting Top 8 genes for 89 on Concord_UMAP
Plotting Top 8 genes for 68 on Concord_UMAP
The gene expression patterns of the top activated neurons show that CONCORD can recover biologically meaningful co-expressed gene sets: for example, genes such as LEF1 and CCR7 that characterize naïve CD4⁺ T cells, genes such as CD40LG and KLRB1 that mark effector memory CD4⁺ T cells, and transcriptional programs associated with monocyte polarization.
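One optional way to check the co-expression observation quantitatively (a suggestion, not part of the CONCORD workflow) is to compare the average pairwise Pearson correlation of a neuron's top genes within a context against that of randomly drawn genes. `mean_pairwise_corr` is a hypothetical helper that assumes a dense `(n_cells, n_genes)` expression matrix for the cells in that context.

```python
import numpy as np

def mean_pairwise_corr(X, gene_idx):
    """Average off-diagonal Pearson correlation among the selected gene columns.

    Zero-variance genes produce NaN correlations, which nanmean ignores.
    """
    sub = X[:, gene_idx]
    with np.errstate(invalid="ignore", divide="ignore"):
        C = np.corrcoef(sub, rowvar=False)       # (n_selected, n_selected)
    off_diag = C[~np.eye(len(gene_idx), dtype=bool)]
    return float(np.nanmean(off_diag))

# Toy data: gene 1 is exactly 2x gene 0; gene 2 is uncorrelated with gene 0
X = np.array([[1.0, 2.0, 0.0], [2.0, 4.0, 1.0], [3.0, 6.0, 0.0]])
print(mean_pairwise_corr(X, [0, 1]), mean_pairwise_corr(X, [0, 2]))
```

In practice, a baseline can come from repeating the call on random index sets of the same size (for example drawn with `np.random.default_rng().choice(n_genes, size=k, replace=False)`).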
To further interpret these results, CONCORD also provides an interface to the gseapy package for GO enrichment analysis of the top-ranked genes:
import gseapy as gp
import os
all_gsea_results = {}
gene_sets='GO_Biological_Process_2021'
ct = 'CD4 T cells'
show_neurons = top_dims[ct]
show_gene_lists = ranked_gene_lists[ct]
show_gene_lists = {key: show_gene_lists[f'Neuron {key}'] for key in show_neurons}
for neuron, ranked_list in show_gene_lists.items():
    print(f"Running GSEA for {ct} - {neuron}")
    # GO enrichment on the top 5% of genes ranked by importance
    top_genes = ranked_list[ranked_list['Importance'] > ranked_list['Importance'].quantile(0.95)]
    ccd.ul.compute_go(top_genes['Gene'], organism="human", font_size=12, figsize=(7,3), dpi=600, save_path=save_dir / f"gsea_{ct}_{neuron}.pdf")
Running GSEA for CD4 T cells - 50 (7, 3)
Running GSEA for CD4 T cells - 86 (7, 3)
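Over-representation analysis of a top-gene list typically scores each GO term with a one-sided hypergeometric test (equivalent to a one-sided Fisher's exact test): how surprising is the overlap between the selected genes and the term, given the background? The sketch below computes that p-value in pure Python. The helper name and the numbers in the usage line are hypothetical, and `ccd.ul.compute_go` may use a different backend, background set, and multiple-testing correction.

```python
from math import comb

def hypergeom_enrichment_p(n_background, n_term, n_selected, n_overlap):
    """P(overlap >= n_overlap) when n_selected genes are drawn without replacement
    from n_background genes, n_term of which belong to the GO term."""
    total = comb(n_background, n_selected)
    upper = min(n_term, n_selected)
    return sum(
        comb(n_term, k) * comb(n_background - n_term, n_selected - k)
        for k in range(n_overlap, upper + 1)
    ) / total

# Hypothetical: 2000 input features, top 5% = 100 genes, a 50-gene term, 10 genes shared
print(hypergeom_enrichment_p(2000, 50, 100, 10))
```

With these numbers about 2.5 overlapping genes are expected by chance, so an overlap of 10 yields a small p-value. Enrichment tools usually then adjust p-values across all tested terms.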