2,363 ADNI subjects · ~70M model params · 0.707 DX 3-class Bal. Acc. · 0.933 CN vs Dem Bal. Acc.
MEMOIR-VLM is a two-stage multimodal vision-language framework for Alzheimer's disease characterization using T1-weighted MRI, DTI fractional anisotropy maps, and structured clinical features. A missing-modality-aware encoder learns a shared representation for diagnosis, clinical severity, age, and sex prediction, while a retrieval-augmented VQA pipeline enables case-based reasoning and natural-language answers grounded in similar ADNI subjects.
Alzheimer's disease assessment is inherently multimodal. Structural MRI captures neurodegeneration, DTI reflects white-matter microstructural integrity, and cognitive and biomarker variables summarize clinical status. In practice, however, these modalities are not always available together: DTI is often missing and cognitive batteries may be incomplete. Meanwhile, many deep learning systems either require complete inputs or produce only a categorical label.
This project asks whether a single missing-modality-aware model can integrate whatever data is available, classify disease stage, estimate clinical severity, retrieve similar historical cases, and support natural-language VQA over brain-scan-derived representations.
Primary target
Build a robust multimodal AI system for Alzheimer's disease diagnosis and clinical reasoning. The model predicts CN / MCI / Dementia, CN vs Dementia, CDR-SB severity, age, and sex, then extends the frozen encoder into a retrieval-augmented VQA pipeline for interpretable case-based answers.
All data comes from the Alzheimer's Disease Neuroimaging Initiative (ADNI). After filtering to subjects with valid DX labels and 9DOF T1 paths, and deduplicating to one scan per subject, the cohort comprises 2,363 subjects, all with a valid diagnostic label and T1-weighted MRI. DTI-FA is available for a subset of 930 participants (39.4% of the cohort).
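The cohort-building steps can be sketched as a filter followed by a per-subject deduplication. This is a minimal sketch under stated assumptions: the `ScanRecord` fields are hypothetical, and keeping the earliest eligible scan per subject is an illustrative rule, since the text does not say which scan was retained. The coverage helper reproduces the 930 / 2,363 ≈ 39.4% arithmetic.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ScanRecord:
    # Hypothetical record layout for one ADNI scan session.
    subject_id: str
    scan_date: str           # ISO date string, e.g. "2012-05-31"
    dx: Optional[str]        # "CN", "MCI", "Dementia", or None if missing
    t1_path: Optional[str]   # path to the 9DOF T1 volume, if any
    fa_path: Optional[str]   # path to the DTI-FA map, if any

VALID_DX = {"CN", "MCI", "Dementia"}

def build_cohort(records):
    """Keep scans with a valid DX label and a T1 path, then one scan per subject."""
    eligible = [r for r in records if r.dx in VALID_DX and r.t1_path]
    by_subject = {}
    for r in sorted(eligible, key=lambda r: r.scan_date):
        by_subject.setdefault(r.subject_id, r)  # assumption: earliest eligible scan wins
    return list(by_subject.values())

def dti_coverage(cohort):
    """Fraction of cohort subjects whose retained scan has a DTI-FA map."""
    return sum(1 for r in cohort if r.fa_path) / len(cohort)
```

With the reported counts, `930 / 2363` rounds to 0.394, matching the 39.4% coverage figure.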
80 / 20 stratified split by diagnosis
| Class | Train | Test | Total |
|---|---|---|---|
| CN - Cognitively Normal | 669 | 168 | 837 |
| MCI - Mild Cognitive Impairment | 650 | 163 | 813 |
| Dementia | 570 | 143 | 713 |
| Total | 1,889 | 474 | 2,363 |
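A per-class (stratified) 80 / 20 split can be sketched as follows. This is an illustrative implementation, not necessarily the splitter the authors used; rounding each class's test share up happens to reproduce the table's test counts exactly (837 → 168, 813 → 163, 713 → 143).

```python
import random
from collections import defaultdict

def stratified_split(labels, test_percent=20, seed=0):
    """Split indices class by class so each class keeps the same test share."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for y in sorted(by_class):
        idx = list(by_class[y])
        rng.shuffle(idx)
        # Ceiling of len(idx) * test_percent / 100, in exact integer arithmetic.
        n_test = (len(idx) * test_percent + 99) // 100
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return train_idx, test_idx
```

Doing the ceiling with integers avoids floating-point surprises such as a product landing just above a whole number.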
100% T1 MRI (9DOF, 2 mm) · 39.4% DTI FA coverage · ~100% clinical scores
The model is a multimodal vision-language model with missing-modality support. Three modality-specific encoders each produce an ℓ2-normalized 512-d embedding. Each embedding passes through a masking gate that randomly drops that modality during training (T1: 10%, DTI: 30%, Clinical: 5%). The masked embeddings are fused via 8-head cross-attention with a learnable pool query, producing a fused representation z_f ∈ ℝ⁵¹² that feeds five MLP task heads: DX 3-class, DX Binary, Sex, Age, and CDR-SB.
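The fusion step can be illustrated with attention pooling under a key mask. This is a deliberately simplified sketch: it uses a single head and no learned query/key/value projections, whereas the model uses 8-head cross-attention. It shows the core mechanism, though: a learnable query attends over the modality tokens, and dropped or missing modalities get exactly zero attention weight.

```python
import math

def masked_attention_pool(query, tokens, present):
    """Single-head attention pooling of modality tokens with a key mask.

    query:   learnable pool query (list of d floats)
    tokens:  one d-dim embedding per modality, e.g. [t1, dti, clinical]
    present: one bool per modality; False means dropped or missing
    Returns (fused vector, attention weights).
    """
    if not any(present):
        raise ValueError("at least one modality must be present")
    d = len(query)
    scores = [
        sum(q * t for q, t in zip(query, tok)) / math.sqrt(d) if ok else -math.inf
        for tok, ok in zip(tokens, present)
    ]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # masked slots: exp(-inf) == 0.0
    total = sum(exps)
    weights = [e / total for e in exps]
    fused = [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]
    return fused, weights
```

Because masked slots receive a score of negative infinity before the softmax, the fused vector is a convex combination of only the modalities that are actually present, whatever subset that is.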
Imaging encoders
Two independent 3D ResNet-18 networks, one for T1 MRI and one for DTI FA maps. Each takes a 91×109×91 input volume and outputs a 512-d feature vector, followed by a linear projection and LayerNorm.
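The projection head on top of each imaging backbone can be sketched as Linear → LayerNorm → ℓ2 normalization. The ordering is inferred from this description plus the statement that all encoder outputs are ℓ2-normalized, and the LayerNorm here omits the learnable scale/shift for brevity. The 3D ResNet-18 itself is not reproduced; `features` stands in for its 512-d pooled output.

```python
import math

def linear(x, weight, bias):
    """y = W x + b, with W stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]

def layer_norm(x, eps=1e-5):
    """LayerNorm over the feature dimension, without learnable affine parameters."""
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def l2_normalize(x, eps=1e-12):
    """Scale to unit Euclidean length (guarding against a zero vector)."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]

def project_embedding(features, weight, bias):
    """Pooled ResNet-18 features -> Linear -> LayerNorm -> unit-length embedding."""
    return l2_normalize(layer_norm(linear(features, weight, bias)))
```

The final ℓ2 step matters for the later stages: on unit vectors, inner products equal cosine similarities, which is what both the contrastive loss and FAISS inner-product retrieval rely on.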
Clinical encoder
MLP over 5 continuous clinical features (CDR-SB, ADAS-11, ADAS-13, MMSE, MoCA) plus an APOE genotype embedding. Output: 512-d, ℓ2-normalized.
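A minimal sketch of the clinical encoder, under several assumptions not given in the text: the APOE genotype is summarized as an ε4 allele count (0, 1, or 2) that indexes a small embedding table, the five scores are z-scored with caller-supplied statistics, and missing values are not handled. The genotype string format, layer sizes, and random initialization are all hypothetical and for illustration only.

```python
import math
import random

CLINICAL_FEATURES = ["CDR-SB", "ADAS-11", "ADAS-13", "MMSE", "MoCA"]

def apoe4_count(genotype):
    """Number of ε4 alleles in a genotype string like 'e3/e4' (hypothetical format)."""
    return sum(1 for allele in genotype.split("/") if allele.strip().lower() == "e4")

def clinical_input(scores, genotype, mean, std, apoe_table):
    """z-score the five clinical features and append an APOE embedding row."""
    z = [(scores[f] - mean[f]) / std[f] for f in CLINICAL_FEATURES]
    return z + list(apoe_table[apoe4_count(genotype)])

def mlp(x, layers):
    """Dense layers given as (weight_rows, bias); ReLU between layers, not after the last."""
    for i, (weight, bias) in enumerate(layers):
        x = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]
        if i < len(layers) - 1:
            x = [max(0.0, v) for v in x]
    return x

def l2_normalize(x, eps=1e-12):
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]

def random_layer(n_in, n_out, rng):
    """Randomly initialized (weight_rows, bias) pair, for illustration only."""
    scale = 1.0 / math.sqrt(n_in)
    weight = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weight, [0.01] * n_out

def encode_clinical(scores, genotype, mean, std, apoe_table, layers):
    """Clinical scores + APOE -> MLP -> unit-length embedding."""
    return l2_normalize(mlp(clinical_input(scores, genotype, mean, std, apoe_table), layers))
```

In the real model the output width would be 512; small sizes are used here only to keep the example light.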
Design decision: no label leakage
The diagnostic and demographic targets are excluded from the clinical input space: diagnosis, sex, and age are prediction targets rather than inputs. The clinical encoder receives only cognitive/clinical scores and APOE genotype, which forces the model to learn relationships among imaging features and clinical indicators rather than copying labels.
Training runs in two stages. Stage 1 is contrastive pre-training: a pairwise CLIP/InfoNCE loss between all three modality pairs (T1–DTI, T1–Clinical, DTI–Clinical), computed only on pairs where both modalities are present. Stage 2B is multi-task fine-tuning across five heads with differential learning rates (backbone 10⁻⁵, heads 5×10⁻⁴).
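The Stage 1 objective can be sketched as a symmetric CLIP-style InfoNCE loss per modality pair, computed only over the samples in a batch that have both modalities. Two details are assumptions rather than facts from the text: the temperature of 0.07 (a common CLIP default) and averaging, rather than summing, the three pairwise losses.

```python
import math

def _similarity(a, b):
    """Pairwise dot products; equal to cosine similarity for ℓ2-normalized rows."""
    return [[sum(x * y for x, y in zip(u, v)) for v in b] for u in a]

def _diag_cross_entropy(logits):
    """Mean cross-entropy where the correct column for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        top = max(row)
        log_z = top + math.log(sum(math.exp(v - top) for v in row))
        total += log_z - row[i]
    return total / len(logits)

def pair_infonce(emb_a, emb_b, has_a, has_b, temperature=0.07):
    """Symmetric CLIP-style loss, restricted to samples that have both modalities."""
    keep = [i for i in range(len(emb_a)) if has_a[i] and has_b[i]]
    if len(keep) < 2:
        return None  # need at least one negative per anchor
    a = [emb_a[i] for i in keep]
    b = [emb_b[i] for i in keep]
    logits = [[s / temperature for s in row] for row in _similarity(a, b)]
    logits_t = [list(col) for col in zip(*logits)]
    return 0.5 * (_diag_cross_entropy(logits) + _diag_cross_entropy(logits_t))

PAIRS = [("t1", "dti"), ("t1", "clinical"), ("dti", "clinical")]

def stage1_loss(batch, has, temperature=0.07):
    """Average the pairwise losses that could actually be computed in this batch."""
    losses = [pair_infonce(batch[a], batch[b], has[a], has[b], temperature) for a, b in PAIRS]
    losses = [l for l in losses if l is not None]
    return sum(losses) / len(losses) if losses else 0.0
```

Restricting each pair to co-present samples means a batch where no subject has DTI still trains the T1–Clinical alignment instead of contributing nothing.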
During training, modalities are randomly dropped so the model learns to work with any subset at inference time. T1 is dropped 10% of the time, DTI 30%, Clinical 5%. DTI gets the highest drop rate because only 39.4% of subjects have it, so the model needs to handle "missing DTI" as the normal case, not the exception.
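Per-sample modality dropout can be sketched by combining real availability with the stated drop rates. The fallback rule (if every modality would be dropped, keep the ones that are actually available) is an assumption added so that the fusion layer always receives at least one input; the text does not say how that case is handled.

```python
import random

DROP_PROB = {"t1": 0.10, "dti": 0.30, "clinical": 0.05}

def sample_training_mask(available, rng):
    """Decide which modalities reach the fusion layer for one training sample.

    available: modality -> bool, whether the subject actually has that data.
    A modality is used only if it is available AND survives random dropout.
    """
    mask = {m: has and rng.random() >= DROP_PROB[m] for m, has in available.items()}
    if not any(mask.values()):
        # Assumption: never drop everything; fall back to the available modalities.
        mask = dict(available)
    return mask
```

For the roughly 60% of subjects without DTI, the DTI slot is masked on every step regardless of the random draw, so "missing DTI" is indeed the common training condition.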
Stage 1 · Contrastive
30 epochs · pairwise InfoNCE
lr 1×10⁻⁴ · AdamW · cosine anneal
Stage 2B · Multi-task
30 epochs · 5 joint heads
focal (γ=2) + smooth L1 · AMP FP16
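The two Stage 2B loss types can be sketched per sample. It is an assumption that focal loss applies to the classification heads (DX 3-class, DX Binary, Sex) and smooth L1 to the regression heads (Age, CDR-SB). The text gives neither a focal α class weighting nor per-head loss weights, so both are omitted. `beta=1.0` matches the PyTorch `SmoothL1Loss` default.

```python
import math

def focal_loss(logits, target, gamma=2.0):
    """Multi-class focal loss for one sample: -(1 - p_t)**gamma * log(p_t)."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(v - top) for v in logits))
    log_pt = logits[target] - log_z  # log-softmax of the true class
    pt = math.exp(log_pt)
    return -((1.0 - pt) ** gamma) * log_pt

def smooth_l1(pred, target, beta=1.0):
    """Quadratic for |error| < beta, linear beyond it."""
    diff = abs(pred - target)
    return 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
```

With γ = 2 the factor (1 − p_t)² shrinks the loss on confidently correct samples, shifting gradient toward hard cases such as MCI. Setting γ = 0 recovers ordinary cross-entropy.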
Stage 2B training history. Best composite checkpoint selected at epoch 5; later epochs showed overfitting. The epoch-5 checkpoint was used for downstream evaluation.
Evaluated on the held-out 474-subject test set using all available modalities. Headline metrics on the best model (Stage 2B, epoch 5):
| Task | Bal. Acc. | Macro F1 | AUC | MAE ↓ |
|---|---|---|---|---|
| DX 3-class (CN / MCI / Dem), primary | 0.707 | 0.703 | 0.865 | – |
| DX Binary (CN vs Dem) | 0.933 | 0.932 | 0.981 | – |
| Sex | 0.575 | 0.563 | 0.597 | – |
| Age (years) | – | – | – | 6.31 |
| CDR-SB | – | – | – | 0.97 |
DX 3-class confusion matrix · full test set (n=474)
Rows = true class, columns = predicted class. Each cell shows the count and the row-normalized %. MCI is the hardest class at 55% recall: it sits between CN and Dementia, so the model hedges in both directions. Dementia recall is the strongest at 86%, with only 2 Dementia subjects misclassified as CN.
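Balanced accuracy is the mean of the per-class recalls (the diagonal of the row-normalized confusion matrix), which is why the 55% MCI recall weighs so heavily on the 3-class score. The sketch below computes it, along with macro F1, from a confusion matrix in the same rows = true, columns = predicted convention. The matrix in the usage check is a toy example, not the paper's results.

```python
def per_class_recall(cm):
    """Row-normalized diagonal: recall of each true class (rows = true labels)."""
    return [row[i] / sum(row) for i, row in enumerate(cm)]

def balanced_accuracy(cm):
    """Unweighted mean of per-class recall."""
    recalls = per_class_recall(cm)
    return sum(recalls) / len(recalls)

def macro_f1(cm):
    """Unweighted mean of per-class F1 scores."""
    k = len(cm)
    f1s = []
    for i in range(k):
        tp = cm[i][i]
        predicted = sum(cm[r][i] for r in range(k))  # column sum
        actual = sum(cm[i])                          # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1s.append(2 * precision * recall / denom if denom else 0.0)
    return sum(f1s) / k
```

For the toy matrix `[[8, 2, 0], [3, 5, 2], [0, 1, 9]]`, the recalls are 0.8, 0.5, and 0.9, so balanced accuracy is their mean, about 0.733.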
CN vs Dementia
0.933 Bal. Acc. · 0.981 AUC
Strong separation between the extremes of the disease spectrum. The model cleanly distinguishes cognitively normal subjects from those with dementia.
Severity & demographics
0.97 CDR-SB MAE · 6.31 yr age MAE
Alongside diagnosis, the shared representation supports clinical severity (CDR-SB) and age regression, learned jointly from the same fused embedding.
All seven non-empty modality subsets were evaluated with the same trained model, masking the non-selected modalities at inference time. This shows where the signal actually comes from and which combination works best for each task.
Figure: balanced accuracy by modality subset for DX 3-class (CN / MCI / Dementia) and DX Binary (CN vs Dementia).
Clinical scores carry the strongest diagnostic signal, with clinical-only performance approaching the full-modality model for 3-class diagnosis. Imaging remains useful, especially for non-diagnostic tasks and for deployment settings where clinical information is incomplete. The modality-dropout strategy prevents catastrophic degradation when DTI or clinical variables are missing.
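The seven-subset ablation can be enumerated as every non-empty combination of the three modalities (2³ − 1 = 7). How subjects lacking a selected modality (for example, DTI) were handled in each run is not stated. This sketch simply leaves such a modality masked, which is an assumption.

```python
from itertools import combinations

MODALITIES = ("T1", "DTI", "Clinical")

def modality_subsets(modalities=MODALITIES):
    """Every non-empty subset: 2**3 - 1 = 7 for three modalities."""
    return [
        frozenset(combo)
        for r in range(1, len(modalities) + 1)
        for combo in combinations(modalities, r)
    ]

def inference_mask(subset, available):
    """Use a modality only if this run selects it and the subject actually has it."""
    return {m: (m in subset) and available[m] for m in MODALITIES}
```

Because the same trained weights serve every subset, the ablation isolates each modality's contribution without retraining.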
The encoder gives you a prediction and a confidence score. What it doesn't give you is an explanation, or any way to ask follow-up questions in natural language. That's what the VQA extension adds. The frozen encoder turns a query case into a 512-d embedding. FAISS finds the 50 most similar training subjects by inner product. A cross-encoder reranks those 50 down to the top 5 most relevant matches. Those 5 captions become the context fed to a language model, which answers clinical questions about the case.
The LLM never receives raw brain images. T1, DTI, and clinical inputs are encoded into a 512-dimensional fused representation. FAISS retrieves similar training subjects in embedding space, and only retrieved textual captions are passed to the LLM as context.
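To make the "text only" boundary concrete, here is a sketch of prompt assembly from retrieved captions. The template wording is hypothetical, since the actual prompt is not given. The point it illustrates is that only caption strings and the question enter the prompt: no image tensors and no embedding vectors.

```python
def build_prompt(question, captions, max_cases=5):
    """Assemble a text-only prompt from retrieved case captions (hypothetical template)."""
    lines = ["Use the similar ADNI cases below to answer the question.", ""]
    for rank, caption in enumerate(captions[:max_cases], start=1):
        lines.append(f"Case {rank}: {caption}")
    lines += ["", f"Question: {question}", "Answer:"]
    return "\n".join(lines)
```

Capping the prompt at five cases mirrors the reranker's top-5 output.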
Text encoder
all-MiniLM-L6-v2, MLM-pretrained on 26,889 clinical sentences (25K synthetic + 1,889 real captions), then contrastively aligned to the imaging embedding space. 384-d output projected up to 512-d.
Retrieval + rerank
FAISS IndexFlatIP over 1,889 ℓ2-normalized training vectors, exact inner-product search. Cross-encoder: ms-marco-MiniLM-L-6-v2, reranking top-50 to top-5.
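The two-stage retrieval can be sketched in plain Python. `exact_top_k` does the same exhaustive inner-product scoring that `faiss.IndexFlatIP` performs (in FAISS: `index.add(x)`, then `index.search(q, 50)`). `score_pair` is a hypothetical stand-in for the cross-encoder; with sentence-transformers that would be `CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2").predict(pairs)`. What `query_text` contains (the question or a description of the query case) is not specified, so treat it as an assumption.

```python
import heapq

def inner_product(u, v):
    return sum(a * b for a, b in zip(u, v))

def exact_top_k(query, vectors, k=50):
    """Exhaustive inner-product search over every stored vector (what IndexFlatIP computes)."""
    scored = ((inner_product(query, v), i) for i, v in enumerate(vectors))
    return [i for _, i in heapq.nlargest(k, scored)]

def retrieve_and_rerank(query_vec, query_text, vectors, captions, score_pair, k=50, top=5):
    """Top-k by embedding similarity, then reorder those k with a text-pair scorer."""
    candidates = exact_top_k(query_vec, vectors, k)
    ranked = sorted(candidates, key=lambda i: score_pair(query_text, captions[i]), reverse=True)
    return ranked[:top]
```

Since the stored vectors are ℓ2-normalized, ranking by inner product is the same as ranking by cosine similarity. The cheap embedding search narrows 1,889 candidates to 50, and the slower cross-encoder only has to score those 50.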
Three models were given the same retrieved context: Mistral 7B Instruct v0.3 (general-purpose, dense), Gemma 4 26B MoE (larger, mixture-of-experts), and MedGemma 1.5 4B IT (smaller, fine-tuned on medical data). All quantized to 4-bit NF4. The question was simple: does a medical fine-tune beat a bigger general model on this task?
Figure: VQA diagnosis accuracy (full modality: T1 + DTI + Clinical), BERTScore (contextual similarity), and SBERT cosine similarity (sentence-level) for Mistral 7B Instruct v0.3, Gemma 4 26B MoE, and MedGemma 1.5 4B IT (medical fine-tune).
Same retrieved context across all three models; only the generation model changes. Mistral 7B wins every metric: diagnosis VQA accuracy and text quality (BERTScore, SBERT). MedGemma's medical fine-tune loses to a general-purpose 7B model. In this setup, instruction-following matters more than raw model size or medical fine-tuning.
Headline finding
Instruction-following beats size and domain. Mistral 7B, a general-purpose dense model, outperforms both a 26B MoE and a medically fine-tuned 4B on every metric. The retrieved context already supplies the medical knowledge. What matters is whether the model can follow instructions and format its output correctly.
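One way output formatting directly affects a diagnosis-accuracy metric: if the label must be parsed from free text, an answer that ignores the requested format scores as wrong even when its reasoning is sound. The sketch assumes a hypothetical "Diagnosis: <label>" answer format. The actual prompt and scoring procedure are not described here.

```python
import re

LABELS = ("CN", "MCI", "Dementia")
# Hypothetical answer format: "Diagnosis: <label>", matched case-insensitively.
_PATTERN = re.compile(r"Diagnosis:\s*(CN|MCI|Dementia)\b", re.IGNORECASE)

def parse_diagnosis(answer):
    """Return the canonical label after 'Diagnosis:', or None if the output is off-format."""
    m = _PATTERN.search(answer)
    if m is None:
        return None
    return next(label for label in LABELS if label.lower() == m.group(1).lower())

def vqa_accuracy(answers, gold):
    """Fraction of answers whose parsed label matches the gold label; off-format counts as wrong."""
    correct = sum(parse_diagnosis(a) == g for a, g in zip(answers, gold))
    return correct / len(gold)
```

Under a parser like this, a model that follows the format instruction reliably gains accuracy independent of its medical knowledge, which is consistent with the finding above.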
Clinical-only performance approaches the full model for 3-class diagnosis, confirming that cognitive scores carry much of the diagnostic signal. Imaging still contributes useful structure, especially when clinical data is incomplete and for tasks such as age and severity estimation.
Stochastic modality dropout during training allows the encoder to operate with any subset of T1, DTI, and clinical inputs. This is important because only 39.4% of subjects had usable DTI.
The model reaches 93.3% balanced accuracy and AUC 0.981 for CN vs Dementia, showing strong separation between the extremes of the Alzheimer's disease spectrum.
MCI recall is 55%, with errors split toward both CN and Dementia. This reflects the transitional and heterogeneous nature of MCI rather than a simple modeling failure.
The VQA pipeline retrieves similar training cases and uses their captions as grounded context for the LLM, producing natural-language answers instead of only class probabilities.
Mistral 7B achieves 94.7% VQA diagnosis accuracy and outperforms Gemma 4 26B MoE and MedGemma 1.5 4B IT under the same retrieval context.
Status
This work has been submitted as MEMOIR-VLM: A Multimodal Vision-Language Model for Alzheimer's Disease Classification and Question Answering. It was completed at the Keck School of Medicine of USC. Code and manuscript links will be added when publicly available. Future work will extend MEMOIR-VLM to amyloid prediction using amyloid-specific supervision while avoiding biomarker leakage.