Research / Keck USC

MEMOIR-VLM: Multimodal Vision-Language Model for Alzheimer's Disease Classification and VQA

2,363

ADNI subjects

~70M

Model params

0.707

DX 3-class Bal. Acc.

0.933

CN vs Dem Bal. Acc.

MEMOIR-VLM is a two-stage multimodal vision-language framework for Alzheimer's disease characterization using T1-weighted MRI, DTI fractional anisotropy maps, and structured clinical features. A missing-modality-aware encoder learns a shared representation for diagnosis, clinical severity, age, and sex prediction, while a retrieval-augmented VQA pipeline enables case-based reasoning and natural-language answers grounded in similar ADNI subjects.

PyTorch · 3D ResNet-18 · Cross-Attention · CLIP · FAISS · RAG · Mistral 7B · Gemma 4 26B · MedGemma · ADNI

01

The Clinical Ask

Alzheimer's disease assessment is inherently multimodal. Structural MRI captures neurodegeneration, DTI reflects white-matter microstructural integrity, and cognitive and biomarker variables summarize clinical status. In practice, however, these modalities are not always available together. DTI is often missing, cognitive batteries may be incomplete, and many deep learning systems require complete inputs or only produce a categorical label.

This project asks whether a single missing-modality-aware model can integrate whatever data is available, classify disease stage, estimate clinical severity, retrieve similar historical cases, and support natural-language VQA over brain-scan-derived representations.

Primary target

Build a robust multimodal AI system for Alzheimer's disease diagnosis and clinical reasoning. The model predicts CN / MCI / Dementia, CN vs Dementia, CDR-SB severity, age, and sex, then extends the frozen encoder into a retrieval-augmented VQA pipeline for interpretable case-based answers.

02

Dataset

All data come from the Alzheimer's Disease Neuroimaging Initiative (ADNI). After filtering to subjects with valid DX labels and 9DOF T1 paths and deduplicating to one scan per subject, the cohort contains 2,363 subjects, all with a diagnostic label and a T1-weighted MRI. DTI-FA was available for 930 of them (39.4% of the cohort).

80 / 20 stratified split by diagnosis

Class | Train | Test | Total
CN - Cognitively Normal | 669 | 168 | 837
MCI - Mild Cognitive Impairment | 650 | 163 | 813
Dementia | 570 | 143 | 713
Total | 1,889 | 474 | 2,363
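
The split itself can be sketched in a few lines. This is a minimal illustration, not the project's actual splitting code: the function name `stratified_split`, the seed, and the rounding rule (test size per class = ceiling of 20%) are assumptions. With that rounding rule, the class sizes in the table give the same 1,889 / 474 totals.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_pct=20, seed=0):
    """Split indices into train/test so each class keeps roughly the same ratio.

    The per-class test size is ceil(n * test_pct / 100), computed with integer
    arithmetic to avoid floating-point rounding surprises (assumed rule).
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train, test = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_test = (len(indices) * test_pct + 99) // 100  # integer ceiling
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return train, test

# Class sizes come from the table; the label list itself is synthetic.
labels = ["CN"] * 837 + ["MCI"] * 813 + ["Dementia"] * 713
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))
```

Splitting inside each class, rather than over the pooled cohort, is what keeps the CN / MCI / Dementia proportions nearly identical in both partitions.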

100%

T1 MRI (9DOF 2mm)

39.4%

DTI FA coverage

~100%

Clinical scores

03

Model Architecture

The model is a multimodal vision-language model with missing-modality support. Three modality-specific encoders produce ℓ2-normalized 512-d embeddings, each gated by a per-modality masking probability during training. The masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing a fused representation z_f ∈ ℝ⁵¹² that feeds five MLP task heads.

[Figure: encoder architecture. Inputs: T1 MRI (91×109×91), DTI FA (91×109×91), Clinical (5 scores + APOE). Encoders: two 3D ResNet-18 networks and a clinical MLP, each → 512-d. Fusion: 8-head cross-attention with pool query and modality embeddings → z_f ∈ ℝ⁵¹². Heads: DX 3-class (3), DX Binary (2), Sex (2), Age (1), CDR-SB (1).]

Three modality-specific encoders produce ℓ2-normalized 512-d embeddings. Each passes through a masking gate that randomly drops modalities during training (T1: 10%, DTI: 30%, Clinical: 5%). Masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing z_f ∈ ℝ⁵¹² that feeds five MLP task heads: DX 3-class, DX Binary, Sex, Age, and CDR-SB.
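
To make the fusion step concrete, here is a deliberately simplified sketch of attention pooling over whichever modalities are present. It is not the model's fusion module: it uses a single head, a fixed rather than learned pool query, no key/value projections, no modality embeddings, and toy 4-d vectors instead of 512-d ones. The names `attention_pool` and `pool_query` are illustrative.

```python
import math

def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def attention_pool(query, embeddings, present):
    """Single-head scaled dot-product pooling over the modalities that are present.

    query      : pool query vector of length d (learned in a real model; fixed here)
    embeddings : dict modality -> embedding vector of length d
    present    : dict modality -> bool; absent modalities receive no attention weight
    """
    names = [m for m in embeddings if present.get(m, False)]
    if not names:
        raise ValueError("at least one modality must be present")
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, embeddings[m])) / math.sqrt(d)
              for m in names]
    # Numerically stable softmax over the present modalities only.
    max_s = max(scores)
    exps = [math.exp(s - max_s) for s in scores]
    total = sum(exps)
    weights = {m: e / total for m, e in zip(names, exps)}
    fused = [sum(weights[m] * embeddings[m][i] for m in names) for i in range(d)]
    return fused, weights

# Toy 4-d embeddings stand in for the 512-d vectors.
emb = {
    "t1": l2_normalize([1.0, 0.0, 0.0, 0.0]),
    "dti": l2_normalize([0.0, 1.0, 0.0, 0.0]),
    "clinical": l2_normalize([1.0, 1.0, 0.0, 0.0]),
}
pool_query = [0.5, 0.5, 0.0, 0.0]

fused_all, w_all = attention_pool(
    pool_query, emb, {"t1": True, "dti": True, "clinical": True})
fused_no_dti, w_no_dti = attention_pool(
    pool_query, emb, {"t1": True, "dti": False, "clinical": True})
print(w_all)
print(w_no_dti)
```

Because the softmax runs only over present modalities, the weights always sum to one and the fused vector has the same size whether one, two, or three inputs are available.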

Imaging encoders

Two independent 3D ResNet-18 networks for T1 MRI and DTI FA maps. Input volumes are 91×109×91, output is 512-d followed by a linear projection + LayerNorm.

Clinical encoder

MLP over 5 continuous clinical features (CDR-SB, ADAS-11, ADAS-13, MMSE, MoCA) plus an APOE genotype embedding. Output: 512-d, ℓ2-normalized.

Design decision: no label leakage

Diagnosis, sex, and age are prediction targets and are excluded from the clinical input space. The clinical encoder only receives cognitive / clinical scores and APOE genotype, forcing the model to learn relationships among imaging features and clinical indicators rather than copying labels.

04

Training Procedure

Training runs in two stages. Stage 1 is contrastive pre-training: a pairwise CLIP/InfoNCE loss between all three modality pairs (T1–DTI, T1–Clinical, DTI–Clinical), computed only on pairs where both modalities are present. Stage 2B is multi-task fine-tuning across five heads with differential learning rates (backbone 10⁻⁵, heads 5×10⁻⁴).
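
A pairwise InfoNCE term restricted to subjects that have both modalities might look like the sketch below. It is an illustration only: the function name, the temperature of 0.07, and the choice to return `None` when fewer than two subjects qualify are assumptions, and the toy embeddings are 2-d.

```python
import math

def log_softmax(row):
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def pairwise_info_nce(a, b, present_a, present_b, temperature=0.07):
    """Symmetric InfoNCE between two modalities, using only subjects with both.

    a, b                 : lists of L2-normalized embeddings, one per subject
    present_a, present_b : lists of bools marking which subjects have each modality
    Returns None when fewer than two subjects have both modalities.
    """
    keep = [i for i in range(len(a)) if present_a[i] and present_b[i]]
    if len(keep) < 2:
        return None
    a = [a[i] for i in keep]
    b = [b[i] for i in keep]
    n = len(keep)
    # sim[i][j] = <a_i, b_j> / temperature; the diagonal holds the matched pairs.
    sim = [[sum(x * y for x, y in zip(a[i], b[j])) / temperature for j in range(n)]
           for i in range(n)]
    loss_a_to_b = -sum(log_softmax(sim[i])[i] for i in range(n)) / n
    loss_b_to_a = -sum(log_softmax([sim[i][j] for i in range(n)])[j]
                       for j in range(n)) / n
    return 0.5 * (loss_a_to_b + loss_b_to_a)

t1 = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
dti = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # third subject has no DTI
has_t1 = [True, True, True]
has_dti = [True, True, False]

aligned = pairwise_info_nce(t1, dti, has_t1, has_dti)
swapped = pairwise_info_nce(t1, [dti[1], dti[0], dti[2]], has_t1, has_dti)
print(aligned, swapped)
```

The aligned pairing yields a loss near zero while the swapped pairing is heavily penalized. In a three-modality setting, the same function would be called once per pair (T1–DTI, T1–Clinical, DTI–Clinical) and the available terms combined.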

During training, modalities are randomly dropped so the model learns to work with any subset at inference time. T1 is dropped 10% of the time, DTI 30%, Clinical 5%. DTI gets the highest drop rate because only 39.4% of subjects have it, so the model needs to handle "missing DTI" as the normal case, not the exception.
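
The dropout step can be sketched as a per-sample mask. The drop rates are the ones quoted above; everything else is an assumption of this sketch, in particular the function name and the rule that restores one available modality if all of them happen to be dropped.

```python
import random

DROP_PROB = {"t1": 0.10, "dti": 0.30, "clinical": 0.05}  # rates quoted in the text

def sample_modality_mask(available, rng, drop_prob=DROP_PROB):
    """Return modality -> bool for one training sample.

    A modality is kept only if it is available for the subject and survives its
    random drop. If everything would be dropped, one available modality is
    restored (an assumption of this sketch, so fusion never sees an empty set).
    """
    mask = {m: available[m] and rng.random() >= drop_prob[m] for m in drop_prob}
    if not any(mask.values()):
        candidates = [m for m in drop_prob if available[m]]
        if candidates:
            mask[rng.choice(candidates)] = True
    return mask

rng = random.Random(0)
subject = {"t1": True, "dti": True, "clinical": True}
masks = [sample_modality_mask(subject, rng) for _ in range(10_000)]
keep_rate = {m: sum(mk[m] for mk in masks) / len(masks) for m in DROP_PROB}
print(keep_rate)

# A subject without DTI never gets a DTI entry switched on.
no_dti = {"t1": True, "dti": False, "clinical": True}
no_dti_masks = [sample_modality_mask(no_dti, rng) for _ in range(1_000)]
```

Over many samples the keep rates settle near 0.90, 0.70, and 0.95, so the fusion layer sees every modality combination during training.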

Stage 1 · Contrastive

30 epochs · pairwise InfoNCE

lr 1×10⁻⁴ · AdamW · cosine anneal

Stage 2B · Multi-task

30 epochs · 5 joint heads

focal (γ=2) + smooth L1 · AMP FP16
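
For reference, the two loss shapes named here can be written out per sample. This sketch assumes focal loss is applied to the classification heads and smooth L1 to the regression heads, uses no class weighting, and takes `beta=1.0` for smooth L1; none of those details are stated above.

```python
import math

def focal_loss(probs, target, gamma=2.0):
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    p_t = probs[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def smooth_l1(pred, target, beta=1.0):
    """Smooth L1: quadratic for |error| < beta, linear beyond it."""
    err = abs(pred - target)
    return 0.5 * err * err / beta if err < beta else err - 0.5 * beta

# A confidently correct prediction contributes far less than an uncertain one.
easy = focal_loss([0.90, 0.05, 0.05], target=0)
hard = focal_loss([0.40, 0.35, 0.25], target=0)
print(easy, hard)

# Small regression errors are penalized quadratically, large ones linearly.
print(smooth_l1(2.5, 2.0), smooth_l1(5.0, 2.0))
```

With γ=2, well-classified samples are down-weighted by the factor (1 - p_t)², which shifts the training signal toward harder cases such as MCI.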

[Figure: Stage 2B training curves. x-axis: epoch (1 to 30); y-axis: 0.30 to 0.85; series: DX3 Bal. Acc. and Sex Acc.; best checkpoint marked at epoch 5.]

Stage 2B training history. Best composite checkpoint selected at epoch 5; later epochs showed overfitting. The epoch-5 checkpoint was used for downstream evaluation.

05

Results

Evaluated on the held-out 474-subject test set using all available modalities. Headline metrics on the best model (Stage 2B, epoch 5):

Task | Bal. Acc. | Macro F1 | AUC | MAE ↓
DX 3-class (CN / MCI / Dem), primary | 0.707 | 0.703 | 0.865 | n/a
DX Binary (CN vs Dem) | 0.933 | 0.932 | 0.981 | n/a
Sex | 0.575 | 0.563 | 0.597 | n/a
Age (years) | n/a | n/a | n/a | 6.31
CDR-SB | n/a | n/a | n/a | 0.97

DX 3-class confusion matrix · full test set (n=474)

True \ Predicted | CN | MCI | Dem | Total
CN | 119 (71%) | 48 (29%) | 1 (1%) | 168
MCI | 40 (25%) | 90 (55%) | 33 (20%) | 163
Dem | 2 (1%) | 18 (13%) | 123 (86%) | 143

Rows = true class, columns = predicted. Each cell shows count and row-normalized %. MCI is the hardest class at 55% recall. It sits between CN and Dementia so the model hedges in both directions. Dementia recall is the strongest at 86%, with only 2 misclassified as CN.
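
The headline 3-class metrics can be recomputed directly from the confusion-matrix counts. The sketch below uses the standard definitions (balanced accuracy as the mean of per-class recall, macro F1 as the unweighted mean of per-class F1); the helper names are illustrative.

```python
# Rows = true class, columns = predicted class, order: CN, MCI, Dementia.
confusion = [
    [119, 48, 1],
    [40, 90, 33],
    [2, 18, 123],
]

def per_class_recall(cm):
    return [cm[i][i] / sum(cm[i]) for i in range(len(cm))]

def per_class_f1(cm):
    n = len(cm)
    scores = []
    for i in range(n):
        true_total = sum(cm[i])                       # TP + FN
        pred_total = sum(cm[r][i] for r in range(n))  # TP + FP
        scores.append(2 * cm[i][i] / (true_total + pred_total))
    return scores

recalls = per_class_recall(confusion)
balanced_accuracy = sum(recalls) / len(recalls)
macro_f1 = sum(per_class_f1(confusion)) / len(confusion)

print([round(r, 3) for r in recalls])
print(round(balanced_accuracy, 3), round(macro_f1, 3))
```

The counts give a balanced accuracy of 0.707 and a macro F1 of 0.703, matching the reported figures for the primary task.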

CN vs Dementia

0.933 Bal. Acc. · 0.981 AUC

Strong separation between the extremes of the disease spectrum. The model cleanly distinguishes cognitively normal subjects from those with dementia.

Severity & demographics

0.97 CDR-SB MAE · 6.31 yr age MAE

Alongside diagnosis, the shared representation supports clinical severity (CDR-SB) and age regression, learned jointly from the same fused embedding.

06

Modality Ablation

Every one of the seven possible modality subsets was evaluated using the same trained model, with non-selected modalities masked at inference time. This shows where the signal actually comes from and which combination works best for each task.

Bal. Acc. by modality subset

Modalities | DX 3-class (CN / MCI / Dementia) | DX Binary (CN vs Dementia)
T1 + DTI + Clin | 0.707 | 0.933
T1 + Clin | 0.692 | 0.932
DTI + Clin | 0.701 | 0.938
Clin only | 0.700 | 0.938
T1 + DTI | 0.598 | 0.848
T1 only | 0.587 | 0.833
DTI only | 0.388 | 0.528

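Clinical scores carry the strongest diagnostic signal: clinical-only input reaches 0.700 balanced accuracy on 3-class diagnosis against 0.707 for the full-modality model, and 0.938 against 0.933 on the binary task. Imaging remains useful, especially for non-diagnostic tasks and for deployment settings where clinical information is incomplete. With clinical variables masked, T1 + DTI still reaches 0.598 (3-class) and 0.848 (binary), so the modality-dropout strategy keeps the model usable when DTI or clinical variables are missing. DTI alone, however, is close to chance (0.388 and 0.528).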
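
The seven evaluation settings are simply the non-empty subsets of three modalities. A small sketch of how such inference masks could be enumerated (the names are illustrative, and the evaluation call itself is omitted):

```python
from itertools import combinations

MODALITIES = ("t1", "dti", "clinical")

def modality_subsets(modalities=MODALITIES):
    """Yield every non-empty subset as a modality -> bool inference mask."""
    for size in range(len(modalities), 0, -1):
        for chosen in combinations(modalities, size):
            yield {m: m in chosen for m in modalities}

masks = list(modality_subsets())
for mask in masks:
    print(" + ".join(m for m, on in mask.items() if on))
```

Each mask would be applied to the same trained checkpoint, so differences between rows reflect the inputs rather than retraining.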

07

Retrieval-Augmented VQA Extension

The encoder gives you a prediction and a confidence score. What it doesn't give you is an explanation, or any way to ask follow-up questions in natural language. That's what the VQA extension adds. The frozen encoder turns a query case into a 512-d embedding. FAISS finds the 50 most similar training subjects by inner product. A cross-encoder reranks those 50 down to the top 5 most relevant matches. Those 5 captions become the context fed to a language model, which answers clinical questions about the case.

The LLM never receives raw brain images. T1, DTI, and clinical inputs are encoded into a 512-dimensional fused representation. FAISS retrieves similar training subjects in embedding space, and only retrieved textual captions are passed to the LLM as context.
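
As a rough illustration of that last step, the prompt can be assembled from captions alone. The template wording, the function name `build_vqa_prompt`, and the placeholder captions are all hypothetical; the project's actual prompt is not shown here.

```python
def build_vqa_prompt(question, captions, max_cases=5):
    """Assemble an LLM prompt from retrieved captions only (no image data)."""
    cases = captions[:max_cases]
    context = "\n".join(f"Case {i}: {text}" for i, text in enumerate(cases, start=1))
    return (
        "You are assisting with Alzheimer's disease case review.\n"
        "Answer using only the similar cases below.\n\n"
        f"{context}\n\n"
        f"Question: {question}\nAnswer:"
    )

# Placeholder captions; real captions would come from the retrieval step.
captions = [f"placeholder caption {i}" for i in range(1, 8)]
prompt = build_vqa_prompt("What is the most likely diagnosis?", captions)
print(prompt)
```

Capping the context at five cases mirrors the top-5 rerank output, and keeping the model's input to text is what lets different LLM backbones be swapped without touching the encoder.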

[Figure: retrieval-augmented VQA pipeline. Frozen encoder (T1 Enc and DTI Enc: 3D ResNet-18; Clin Enc: MLP; 8-head attention fusion with pool query) → z_f ∈ ℝ⁵¹² → FAISS IndexFlatIP (1,889 vectors, top-50 similar cases) → cross-encoder rerank (top-50 → top-5) → LLM (Mistral 7B, Gemma 4 26B, MedGemma 4B) → VQA answer. Only the top-5 retrieved captions reach the LLM.]

The frozen five-head encoder produces the fused embedding z_f ∈ ℝ⁵¹². FAISS retrieves top-50 similar training subjects; a cross-encoder reranks down to top-5. Only the retrieved textual captions reach the LLM, which never sees raw brain images. Three LLM backbones are compared: Mistral 7B, Gemma 4 26B MoE, and MedGemma 1.5 4B.

Text encoder

all-MiniLM-L6-v2, MLM-pretrained on 26,889 clinical sentences (25K synthetic + 1,889 real captions), then contrastively aligned to the imaging embedding space. 384-d output projected up to 512-d.

Retrieval + rerank

FAISS IndexFlatIP over 1,889 ℓ2-normalized training vectors, exact inner-product search. Cross-encoder: ms-marco-MiniLM-L-6-v2, reranking top-50 to top-5.
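
The retrieve-then-rerank pattern can be sketched without any libraries. This is a brute-force stand-in for what an exact inner-product index computes, not the FAISS API, and the `relevance_fn` is a hypothetical placeholder for a cross-encoder score; the toy index holds 60 two-dimensional vectors rather than 1,889 vectors of 512-d.

```python
import math

def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def inner_product_search(query, index_vectors, k):
    """Exact top-k by inner product; on unit vectors this equals cosine similarity."""
    scored = [(sum(q * v for q, v in zip(query, vec)), i)
              for i, vec in enumerate(index_vectors)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [(i, score) for score, i in scored[:k]]

def rerank(candidates, relevance_fn, k):
    """Re-score retrieved candidates with a second scorer and keep the best k."""
    return sorted(candidates, key=lambda c: relevance_fn(c[0]), reverse=True)[:k]

# Toy index: unit vectors at angles 0.0, 0.1, ..., 5.9 radians.
index_vectors = [l2_normalize([math.cos(t / 10), math.sin(t / 10)]) for t in range(60)]
query = l2_normalize([1.0, 0.2])

top50 = inner_product_search(query, index_vectors, k=50)

# Hypothetical relevance score standing in for a cross-encoder: prefers ids near 4.
top5 = rerank(top50, relevance_fn=lambda idx: -abs(idx - 4), k=5)

print([i for i, _ in top50[:3]])
print([i for i, _ in top5])
```

The two stages trade off differently: the first pass is cheap and scans every stored vector, while the reranker is more expensive and only has to score the 50 survivors.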

08

LLM Backbone Comparison

Three models were given the same retrieved context: Mistral 7B Instruct v0.3 (general-purpose, dense), Gemma 4 26B MoE (larger, mixture-of-experts), and MedGemma 1.5 4B IT (smaller, fine-tuned on medical data). All three were quantized to 4-bit NF4. The question was simple: does a medical fine-tune beat a bigger general model on this task?

Metric | Mistral 7B (Instruct v0.3) | Gemma 4 26B (MoE) | MedGemma 1.5 4B (IT · medical FT)
VQA Diagnosis, full modality (T1 + DTI + Clinical) | 0.947 | 0.927 | 0.507
BERTScore (contextual similarity) | 0.894 | 0.845 | 0.823
SBERT CosSim (sentence-level) | 0.811 | 0.810 | 0.428

Same retrieved context across all three models; only the generation model changes. Mistral 7B leads on every metric: diagnosis VQA accuracy (0.947) and text quality (BERTScore 0.894, SBERT 0.811), although its SBERT margin over Gemma 4 26B is negligible (0.811 vs 0.810). MedGemma's medical fine-tune loses to a general-purpose 7B model. In this setup, instruction-following appears to matter more than raw model size or medical fine-tuning.

Headline finding

Instruction-following beats size and domain. Mistral 7B, a general-purpose dense model, outperforms both a 26B MoE and a medically fine-tuned 4B on every metric. The retrieved context already supplies the medical knowledge, so what appears to matter is whether the model can follow instructions and format its output correctly.

09

Key Findings

Clinical scores dominate diagnosis, but imaging adds robustness

Clinical-only performance approaches the full model for 3-class diagnosis, confirming that cognitive scores carry much of the diagnostic signal. Imaging still contributes useful structure, especially when clinical data is incomplete and for tasks such as age and severity estimation.

The model handles missing modalities

Stochastic modality dropout during training allows the encoder to operate with any subset of T1, DTI, and clinical inputs. This is important because only 39.4% of subjects had usable DTI.

Binary CN vs Dementia separation is strong

The model reaches 93.3% balanced accuracy and AUC 0.981 for CN vs Dementia, showing strong separation between the extremes of the Alzheimer's disease spectrum.

MCI remains the hardest class

MCI recall is 55%, with errors split toward both CN and Dementia. This reflects the transitional and heterogeneous nature of MCI rather than a simple modeling failure.

Retrieval-augmented VQA improves interpretability

The VQA pipeline retrieves similar training cases and uses their captions as grounded context for the LLM, producing natural-language answers instead of only class probabilities.

Mistral 7B is the best VQA backbone

Mistral 7B achieves 94.7% VQA diagnosis accuracy and outperforms Gemma 4 26B MoE and MedGemma 1.5 4B IT under the same retrieval context.

Status

Manuscript submitted / in review

This work has been submitted as MEMOIR-VLM: A Multimodal Vision-Language Model for Alzheimer's Disease Classification and Question Answering. It was completed at the Keck School of Medicine of USC. Code and manuscript links will be added when publicly available. Future work will extend MEMOIR-VLM to amyloid prediction using amyloid-specific supervision while avoiding biomarker leakage.
