In a recent hackathon, the Virtual Cell Program Initiative (VCPI) prediction contest hosted by Ginkgo Bioworks, I worked on predicting gene expression directly from chemical structures. The primary goal of the contest, and of the resulting model, was to enable high-throughput screening of compounds in silico. Given a compound’s SMILES string (a text-based representation of its chemical structure), the model predicts the compound’s downstream effect on the expression of 12,995 specific genes, bypassing the need for costly and time-consuming physical assays.
The Dataset
The training data consisted of compound perturbational profiles across multiple experimental batches. In total, the raw dataset contained over 70,000 samples and roughly 78,000 gene measurements. However, the contest focused specifically on a 10 µM concentration level. After filtering for the relevant dose and control samples (like DMSO), the training set was distilled down to 32,500 profiles covering 14,031 unique compounds.
For the target variables, the raw gene counts were converted to expression levels using a log2(CPM + 1) transformation. The final evaluation required predicting the exact expression values for these 12,995 genes across a blinded test set of 1,063 novel compounds, with performance scored via a Weighted Mean Squared Error (wMSE) metric.
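As a rough illustration of the log2(CPM + 1) transformation mentioned above, here is a minimal sketch for a single sample. It assumes CPM is computed against the sample's total count over the genes passed in; the function name `log2_cpm` and the toy counts are hypothetical, and the actual contest preprocessing may differ in detail.

```python
import math


def log2_cpm(counts):
    """Convert one sample's raw gene counts to log2(CPM + 1).

    CPM (counts per million) scales each gene's count by the sample's
    total count, so samples with different sequencing depth are comparable.
    """
    total = sum(counts)  # library size of this sample (assumed: all genes given)
    if total == 0:
        raise ValueError("sample has no counts")
    return [math.log2(c / total * 1_000_000 + 1) for c in counts]


# Toy sample with three genes (made-up numbers, for illustration only).
sample_counts = [0, 250, 750]
expression = log2_cpm(sample_counts)
```

A gene with zero counts maps to exactly 0, which is the reason for the +1 inside the logarithm.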
This post walks through how I built an end-to-end fine-tuned model using ChemBERTa to solve this problem.
The Architecture
To tackle this challenge, I leveraged ChemBERTa (seyonec/ChemBERTa-zinc-base-v1), a RoBERTa-like transformer model pre-trained on millions of SMILES strings. Instead of just using ChemBERTa as a frozen feature extractor, I opted for an end-to-end fine-tuning approach where both the encoder and a custom MLP (Multi-Layer Perceptron) head learn together.
The overall architecture looks like this:
SMILES → ChemBERTa (fine-tuned) → 384-dim → MLP head → 12,995 genes
- Encoder: The tokenized SMILES strings are passed through the ChemBERTa encoder. I use mean pooling over the token embeddings to extract a single 384-dimensional vector representing the entire molecule.
- MLP Head: This 384-dimensional embedding is then passed through a custom neural network head consisting of linear layers, batch normalization, ReLU activations, and dropout, finally outputting the predictions for the 12,995 genes.
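The pooling and head described above might look roughly like the sketch below. It uses a random tensor in place of the real encoder output so that it runs without downloading ChemBERTa. Several details are assumptions rather than the hackathon code: the names `mean_pool` and `ExpressionHead`, the single hidden layer of width 1024, the dropout rate of 0.2, and the choice to exclude padding tokens from the average.

```python
import torch
import torch.nn as nn


def mean_pool(token_embeddings, attention_mask):
    """Average token embeddings per molecule, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(dim=1)                   # (batch, hidden)
    n_tokens = mask.sum(dim=1).clamp(min=1.0)                       # (batch, 1)
    return summed / n_tokens


class ExpressionHead(nn.Module):
    """MLP head: linear -> batch norm -> ReLU -> dropout -> linear."""

    def __init__(self, in_dim=384, hidden_dim=1024, n_genes=12995, dropout=0.2):
        super().__init__()
        # hidden_dim, dropout, and the depth are illustrative assumptions.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_genes),
        )

    def forward(self, pooled):
        return self.net(pooled)


# Stand-in for the encoder's last hidden state: 4 molecules, 16 tokens each.
token_embeddings = torch.randn(4, 16, 384)
attention_mask = torch.ones(4, 16, dtype=torch.long)
attention_mask[:, 10:] = 0  # pretend the last 6 positions are padding

pooled = mean_pool(token_embeddings, attention_mask)

head = ExpressionHead()
head.eval()  # disable dropout and use batch-norm running statistics
with torch.no_grad():
    preds = head(pooled)
```

In the real model, `token_embeddings` would come from the fine-tuned encoder's output and `attention_mask` from the tokenizer, so gradients flow through both the head and the encoder.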
The Training Strategy
One of the key improvements in this approach was optimizing for the exact metric used by the contest leaderboard: Weighted Mean Squared Error (wMSE).
Custom wMSE Loss
The dataset provided a canonical weight matrix indicating the importance of different gene-compound pairs. I pre-computed weight vectors for each compound and built a custom wMSE loss function in PyTorch:
def wmse_loss(pred, truth, weights):
    """
    pred, truth, weights: (batch_size, n_genes)
    Returns a scalar: the mean wMSE over the batch.
    """
    sq_err = (pred - truth) ** 2                  # (batch, genes)
    per_compound = (sq_err * weights).sum(dim=1)  # (batch,)
    return per_compound.mean()

Differential Learning Rates
Since the ChemBERTa encoder is already pre-trained and our MLP head is initialized randomly, applying the same learning rate to both would be suboptimal. A large learning rate might destroy the pre-trained weights of the encoder, while a small learning rate would make the MLP head learn too slowly.
To solve this, I used differential learning rates with the AdamW optimizer:
- 1e-5 for the ChemBERTa encoder (small, for fine-tuning)
- 1e-3 for the MLP head (larger, to learn the new mapping quickly)
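In PyTorch, per-module learning rates like these are usually expressed as optimizer parameter groups. The sketch below shows the idea with two small placeholder modules standing in for the encoder and the head; the variable names are hypothetical, and other AdamW settings such as weight decay are left at their defaults because the post does not state them.

```python
import torch
import torch.nn as nn

# Small placeholders so the sketch runs without downloading ChemBERTa.
encoder = nn.Linear(32, 384)   # hypothetical stand-in for the pre-trained encoder
head = nn.Linear(384, 12995)   # hypothetical stand-in for the MLP head

# One parameter group per module, each with its own learning rate.
optimizer = torch.optim.AdamW(
    [
        {"params": encoder.parameters(), "lr": 1e-5},  # gentle fine-tuning
        {"params": head.parameters(), "lr": 1e-3},     # faster learning from scratch
    ]
)

group_lrs = [group["lr"] for group in optimizer.param_groups]
```

Each call to `optimizer.step()` then updates the encoder and the head with their respective rates in a single pass.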
Results
The model was trained on a dataset of over 14,000 compounds. By fine-tuning the transformer end-to-end and directly optimizing the wMSE loss, the model was able to capture complex chemical representations tailored specifically to the task of gene expression prediction.
This end-to-end approach significantly outperformed simpler baselines, such as using frozen ChemBERTa embeddings with a separate model, or relying on traditional Morgan fingerprints.
You can find the full code and implementation details in the hackathon repository.