Introduction to Vision-Language Models
Vision-Language Models (VLMs) bridge the gap between visual perception and natural language understanding, enabling systems that can see and communicate about what they see. These models have transformed how AI systems interact with multimodal data, powering applications from image search and captioning to visual question answering and multimodal assistants.
The Multimodal Challenge
Humans effortlessly integrate visual and linguistic information, but teaching machines to do the same presents fundamental challenges:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
@dataclass
class MultimodalChallenges:
"""
Key challenges in vision-language modeling:
1. Representation Gap: Images are continuous, high-dimensional signals
while text is discrete and sequential. How do we align these spaces?
2. Semantic Alignment: The word "dog" must connect to visual features
of dogs across breeds, poses, and contexts.
3. Compositional Understanding: "A red ball on a blue table" requires
understanding objects, attributes, and spatial relationships.
4. Grounding: Connecting language to specific image regions
(e.g., "the cat on the left").
5. Reasoning: Answering "How many people are wearing hats?" requires
detection, counting, and attribute recognition.
"""
@staticmethod
def demonstrate_modality_gap():
# Image features: continuous, spatial
image_feature_dim = 768
image_spatial_size = 14 * 14 # 196 tokens from ViT
# Text features: discrete tokens, sequential
vocab_size = 50000
max_seq_length = 77
print(f"Image: {image_spatial_size} spatial tokens × {image_feature_dim}D")
print(f"Text: {max_seq_length} tokens from {vocab_size} vocabulary")
print("Challenge: Align these different representation spaces")
class VisionLanguageModelTypes:
"""
Three main architectural paradigms for VLMs.
"""
@staticmethod
def describe_architectures():
architectures = {
"Dual Encoder": {
"description": "Separate encoders for image and text, aligned in shared space",
"examples": ["CLIP", "ALIGN", "OpenCLIP"],
"pros": "Fast retrieval, scalable",
"cons": "Limited cross-modal interaction",
"use_cases": ["Image-text retrieval", "Zero-shot classification"]
},
"Fusion Encoder": {
"description": "Early fusion of visual and text tokens in transformer",
"examples": ["VisualBERT", "UNITER", "OSCAR"],
"pros": "Rich cross-modal interaction",
"cons": "Slower, requires paired forward pass",
"use_cases": ["VQA", "Visual reasoning"]
},
"Encoder-Decoder": {
"description": "Encode image, decode text autoregressively",
"examples": ["BLIP", "GIT", "Flamingo", "LLaVA"],
"pros": "Natural for generation",
"cons": "Harder to train, slower inference",
"use_cases": ["Captioning", "Visual dialogue"]
}
}
for name, info in architectures.items():
print(f"\n{name}:")
for k, v in info.items():
print(f" {k}: {v}")Dual Encoder Architecture
The dual encoder approach processes images and text independently, then aligns them in a shared embedding space:
class DualEncoder(nn.Module):
"""
Dual encoder architecture for vision-language alignment.
Processes image and text separately, aligns in shared space.
"""
def __init__(
self,
vision_encoder: nn.Module,
text_encoder: nn.Module,
embed_dim: int = 512,
vision_dim: int = 768,
text_dim: int = 768
):
super().__init__()
self.vision_encoder = vision_encoder
self.text_encoder = text_encoder
# Projection heads to shared embedding space
self.vision_proj = nn.Linear(vision_dim, embed_dim)
self.text_proj = nn.Linear(text_dim, embed_dim)
# Learnable temperature for contrastive loss
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
"""Encode images to shared embedding space."""
features = self.vision_encoder(images)
# Use CLS token or pooled features
if features.dim() == 3:
features = features[:, 0] # CLS token
features = self.vision_proj(features)
return F.normalize(features, dim=-1)
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Encode text to shared embedding space."""
features = self.text_encoder(input_ids, attention_mask)
# Use CLS or EOS token
if features.dim() == 3:
features = features[:, 0] # CLS token
features = self.text_proj(features)
return F.normalize(features, dim=-1)
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass returning embeddings and similarity logits.
"""
image_embeds = self.encode_image(images)
text_embeds = self.encode_text(input_ids, attention_mask)
# Compute similarity with temperature scaling
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_embeds @ text_embeds.t()
logits_per_text = logits_per_image.t()
return image_embeds, text_embeds, logits_per_image
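To make the data flow concrete, here is a small, self-contained smoke test. The stub encoders are hypothetical stand-ins for a real ViT and text transformer: they only return random features of the expected shape, and the loss at the end is the symmetric InfoNCE computed directly on the returned logits.

# Smoke test for DualEncoder with hypothetical stub backbones (random features only).
class StubVisionEncoder(nn.Module):
    def forward(self, images):
        # Pretend ViT output: [B, 1 + 14*14 patches, 768]
        return torch.randn(images.size(0), 197, 768, device=images.device)

class StubTextEncoder(nn.Module):
    def forward(self, input_ids, attention_mask):
        # Pretend text transformer output: [B, L, 768]
        return torch.randn(*input_ids.shape, 768, device=input_ids.device)

model = DualEncoder(StubVisionEncoder(), StubTextEncoder())
images = torch.randn(4, 3, 224, 224)
input_ids = torch.randint(0, 50000, (4, 77))
attention_mask = torch.ones(4, 77, dtype=torch.long)

img_emb, txt_emb, logits = model(images, input_ids, attention_mask)
print(img_emb.shape, txt_emb.shape, logits.shape)  # [4, 512], [4, 512], [4, 4]

# Symmetric InfoNCE over the in-batch similarity matrix
labels = torch.arange(4)
loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
print(f"in-batch contrastive loss: {loss.item():.3f}")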
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for vision-language alignment.
Brings matching pairs together, pushes non-matching apart.
"""
def __init__(self, temperature: float = 0.07):
super().__init__()
self.temperature = temperature
def forward(
self,
image_embeds: torch.Tensor,
text_embeds: torch.Tensor,
logit_scale: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Symmetric contrastive loss (InfoNCE).
Args:
image_embeds: [B, D] normalized image embeddings
text_embeds: [B, D] normalized text embeddings
"""
if logit_scale is not None:
temperature = 1.0 / logit_scale.exp()
else:
temperature = self.temperature
# Similarity matrix [B, B]
logits = image_embeds @ text_embeds.t() / temperature
# Labels: diagonal elements are positives
batch_size = logits.size(0)
labels = torch.arange(batch_size, device=logits.device)
# Symmetric loss: image-to-text and text-to-image
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2

Fusion Encoder Architecture
Fusion encoders enable deep cross-modal interaction by processing both modalities together:
class FusionEncoder(nn.Module):
"""
Fusion encoder that processes image and text tokens together.
Enables rich cross-modal attention and interaction.
"""
def __init__(
self,
vision_encoder: nn.Module,
hidden_dim: int = 768,
num_layers: int = 6,
num_heads: int = 12,
vocab_size: int = 30522,
max_text_len: int = 128
):
super().__init__()
self.vision_encoder = vision_encoder
# Text embedding
self.text_embeddings = nn.Embedding(vocab_size, hidden_dim)
self.text_pos_embeddings = nn.Embedding(max_text_len, hidden_dim)
# Special tokens for modality
self.img_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
self.txt_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
# Cross-modal transformer layers
self.fusion_layers = nn.ModuleList([
FusionTransformerLayer(hidden_dim, num_heads)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_dim)
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
B = images.size(0)
# Encode image patches
image_features = self.vision_encoder(images) # [B, N_img, D]
# Embed text
text_features = self.text_embeddings(input_ids) # [B, N_txt, D]
positions = torch.arange(input_ids.size(1), device=input_ids.device)
text_features = text_features + self.text_pos_embeddings(positions)
# Add modality tokens
img_token = self.img_token.expand(B, -1, -1)
txt_token = self.txt_token.expand(B, -1, -1)
# Concatenate: [IMG] image_patches [TXT] text_tokens
combined = torch.cat([img_token, image_features, txt_token, text_features], dim=1)
# Create attention mask
img_len = image_features.size(1) + 1 # +1 for IMG token
txt_len = text_features.size(1) + 1 # +1 for TXT token
combined_mask = torch.ones(B, img_len + txt_len, device=images.device)
combined_mask[:, img_len + 1:] = attention_mask # Apply text mask
# Fusion layers
for layer in self.fusion_layers:
combined = layer(combined, combined_mask)
return self.norm(combined)
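Before the fusion layers run, the two modalities are packed into a single sequence. The standalone sketch below (toy sizes, random features; it does not depend on the classes in this section) shows how the [IMG] + patches + [TXT] + text layout lines up with the combined padding mask.

# Standalone illustration of the [IMG] + patches + [TXT] + text layout and its mask.
B, n_img, n_txt, D = 2, 196, 6, 768
image_features = torch.randn(B, n_img, D)
text_features = torch.randn(B, n_txt, D)
img_token = torch.zeros(B, 1, D)
txt_token = torch.zeros(B, 1, D)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],   # first caption: 4 real tokens + 2 pads
                               [1, 1, 1, 1, 1, 1]])  # second caption: no padding

combined = torch.cat([img_token, image_features, txt_token, text_features], dim=1)
img_len = n_img + 1                               # image patches + [IMG] token
combined_mask = torch.ones(B, combined.size(1))
combined_mask[:, img_len + 1:] = attention_mask   # text positions inherit the text mask

print(combined.shape)         # torch.Size([2, 204, 768])
print(combined_mask[0, -6:])  # tensor([1., 1., 1., 1., 0., 0.])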
class FusionTransformerLayer(nn.Module):
"""Transformer layer for multimodal fusion."""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
self.norm2 = nn.LayerNorm(hidden_dim)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention over the concatenated image + text sequence.
        # nn.MultiheadAttention expects key_padding_mask with True = ignore,
        # so convert the 1/0 keep-mask accordingly.
        key_padding_mask = None
        if mask is not None:
            key_padding_mask = (mask == 0)
        attn_out, _ = self.self_attn(
            self.norm1(x), self.norm1(x), self.norm1(x),
            key_padding_mask=key_padding_mask
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

Encoder-Decoder Architecture
For generative tasks like captioning, encoder-decoder architectures are natural:
class VisionEncoderDecoder(nn.Module):
"""
Encoder-decoder architecture for vision-to-language generation.
Encodes image, decodes text autoregressively.
"""
def __init__(
self,
vision_encoder: nn.Module,
decoder: nn.Module,
vision_dim: int = 768,
decoder_dim: int = 768,
num_query_tokens: int = 32
):
super().__init__()
self.vision_encoder = vision_encoder
self.decoder = decoder
# Project vision features to decoder space
self.vision_proj = nn.Linear(vision_dim, decoder_dim)
# Learnable query tokens for vision-to-language bridge
self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, decoder_dim))
# Cross-attention to compress vision features
self.cross_attn = nn.MultiheadAttention(decoder_dim, 8, batch_first=True)
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
"""Encode image and extract fixed-length representation."""
B = images.size(0)
# Get vision features
vision_features = self.vision_encoder(images) # [B, N, D_vision]
vision_features = self.vision_proj(vision_features) # [B, N, D_decoder]
# Use query tokens to extract fixed representation
queries = self.query_tokens.expand(B, -1, -1)
visual_tokens, _ = self.cross_attn(queries, vision_features, vision_features)
return visual_tokens # [B, num_queries, D]
def forward(
self,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
"""
Forward pass for training.
Args:
images: Input images [B, C, H, W]
input_ids: Target text tokens [B, L]
labels: Same as input_ids, shifted for LM loss
"""
# Encode images
visual_tokens = self.encode_image(images)
# Decode text with visual context
outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=visual_tokens,
labels=labels
)
return outputs
@torch.no_grad()
def generate(
self,
images: torch.Tensor,
max_length: int = 50,
num_beams: int = 5,
**kwargs
) -> torch.Tensor:
"""Generate captions for images."""
visual_tokens = self.encode_image(images)
return self.decoder.generate(
encoder_hidden_states=visual_tokens,
max_length=max_length,
num_beams=num_beams,
**kwargs
        )

Vision Encoders for VLMs
Vision encoders extract visual features for language models:
class VisionEncoderForVLM(nn.Module):
"""
Vision encoder adapted for vision-language models.
Can use ViT, CNN, or hybrid architectures.
"""
def __init__(
self,
model_type: str = 'vit',
pretrained: bool = True,
output_dim: int = 768,
num_output_tokens: int = 256
):
super().__init__()
self.model_type = model_type
self.num_output_tokens = num_output_tokens
if model_type == 'vit':
self.encoder = self._build_vit()
elif model_type == 'resnet':
self.encoder = self._build_resnet()
# Resampler to fixed number of tokens (like Perceiver)
if num_output_tokens > 0:
self.resampler = PerceiverResampler(output_dim, num_output_tokens)
else:
self.resampler = None
    def _build_vit(self) -> nn.Module:
        """Build a ViT-style encoder (patch embedding only in this sketch)."""
        return nn.Sequential(
            # Patch embedding: [B, 3, H, W] -> [B, 768, H/14, W/14]
            nn.Conv2d(3, 768, kernel_size=14, stride=14),
            nn.Flatten(2),  # -> [B, 768, N_patches]
            # A full implementation would transpose to [B, N_patches, 768]
            # and stack transformer blocks here.
        )
    def _build_resnet(self) -> nn.Module:
        """Build a ResNet encoder that returns spatial feature maps, not pooled features."""
        # Left unimplemented in this sketch; a real version would return the
        # pre-pooling feature map reshaped to [B, N, D].
        raise NotImplementedError("ResNet backbone is not implemented in this example")
def forward(self, images: torch.Tensor) -> torch.Tensor:
features = self.encoder(images)
if self.resampler is not None:
features = self.resampler(features)
return features
class PerceiverResampler(nn.Module):
"""
Perceiver-style resampler to convert variable-length vision tokens
to fixed-length representation. Used in Flamingo.
"""
def __init__(
self,
dim: int,
num_queries: int = 64,
num_layers: int = 2,
num_heads: int = 8
):
super().__init__()
self.queries = nn.Parameter(torch.randn(num_queries, dim))
self.layers = nn.ModuleList([
nn.ModuleDict({
'cross_attn': nn.MultiheadAttention(dim, num_heads, batch_first=True),
'cross_norm': nn.LayerNorm(dim),
'ffn': nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
),
'ffn_norm': nn.LayerNorm(dim)
})
for _ in range(num_layers)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Resample variable-length vision tokens to fixed-length queries.
Args:
x: Vision features [B, N, D]
Returns:
Resampled features [B, num_queries, D]
"""
B = x.size(0)
queries = self.queries.unsqueeze(0).expand(B, -1, -1)
for layer in self.layers:
# Cross-attention: queries attend to vision features
attn_out, _ = layer['cross_attn'](
layer['cross_norm'](queries),
x, x
)
queries = queries + attn_out
queries = queries + layer['ffn'](layer['ffn_norm'](queries))
        return queries

Text Encoders for VLMs
class TextEncoderForVLM(nn.Module):
"""
Text encoder for vision-language models.
Typically BERT-style for understanding, GPT-style for generation.
"""
def __init__(
self,
vocab_size: int = 49408, # CLIP vocab size
hidden_dim: int = 512,
num_layers: int = 12,
num_heads: int = 8,
max_length: int = 77,
encoder_type: str = 'bert' # 'bert' or 'gpt'
):
super().__init__()
self.encoder_type = encoder_type
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_embedding = nn.Embedding(max_length, hidden_dim)
self.layers = nn.ModuleList([
TransformerEncoderLayer(hidden_dim, num_heads)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(hidden_dim)
# For GPT-style, need causal mask
if encoder_type == 'gpt':
self.register_buffer(
'causal_mask',
torch.triu(torch.ones(max_length, max_length), diagonal=1).bool()
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
B, L = input_ids.shape
# Embed tokens and positions
positions = torch.arange(L, device=input_ids.device)
x = self.token_embedding(input_ids) + self.position_embedding(positions)
# Create attention mask
if self.encoder_type == 'gpt':
attn_mask = self.causal_mask[:L, :L]
else:
attn_mask = None
        # Transformer layers (nn.MultiheadAttention's key_padding_mask expects True = ignore)
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for layer in self.layers:
            x = layer(x, attention_mask=attn_mask, key_padding_mask=padding_mask)
x = self.final_norm(x)
return x
class TransformerEncoderLayer(nn.Module):
"""Standard transformer encoder layer."""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
self.norm2 = nn.LayerNorm(hidden_dim)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
attn_out, _ = self.self_attn(
self.norm1(x), self.norm1(x), self.norm1(x),
attn_mask=attention_mask,
key_padding_mask=key_padding_mask
)
x = x + attn_out
x = x + self.ffn(self.norm2(x))
        return x

Training Objectives
class VLMTrainingObjectives:
"""
Common training objectives for vision-language models.
"""
@staticmethod
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
"""Image-text contrastive loss (ITC)."""
logits = image_embeds @ text_embeds.t() / temperature
labels = torch.arange(len(logits), device=logits.device)
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.t(), labels)
return (loss_i2t + loss_t2i) / 2
@staticmethod
def image_text_matching_loss(fused_features, labels):
"""
Image-text matching loss (ITM).
Binary classification: does image match text?
"""
# Use CLS token for matching prediction
cls_features = fused_features[:, 0]
        # Note: created inline purely for illustration; in a real model the ITM
        # head is a persistent module attribute trained jointly with the encoder.
        matching_head = nn.Linear(cls_features.size(-1), 2)
logits = matching_head(cls_features)
return F.cross_entropy(logits, labels)
@staticmethod
def masked_language_modeling_loss(
fused_features,
labels,
vocab_size: int
):
"""
Masked language modeling with visual context (MLM).
"""
        # As with the ITM head, the LM head is created inline only for illustration.
        lm_head = nn.Linear(fused_features.size(-1), vocab_size)
logits = lm_head(fused_features)
# Only compute loss on masked positions (labels != -100)
return F.cross_entropy(
logits.view(-1, vocab_size),
labels.view(-1),
ignore_index=-100
)
@staticmethod
def language_modeling_loss(logits, labels):
"""
Autoregressive language modeling loss (LM).
Used for caption generation.
"""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
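A quick sanity check of the shift-by-one convention in the LM loss: on random logits, the loss should sit near ln(vocab_size). The shapes and the -100 masking of prompt positions below are made up for illustration.

# Toy check of the autoregressive LM loss (random logits, toy vocabulary).
B, L, V = 2, 8, 100
logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
labels[:, :2] = -100  # e.g. ignore prompt / BOS positions

loss = VLMTrainingObjectives.language_modeling_loss(logits, labels)
print(f"LM loss on random logits: {loss.item():.2f} (≈ ln(V) = {math.log(V):.2f})")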
def compare_training_objectives():
"""Compare VLM training objectives."""
objectives = {
'ITC (Contrastive)': {
'what': 'Align image-text pairs in embedding space',
'used_by': 'CLIP, ALIGN, BLIP',
'pros': 'Scalable, good for retrieval',
'cons': 'Limited fine-grained understanding'
},
'ITM (Matching)': {
'what': 'Binary classify if image-text match',
'used_by': 'UNITER, OSCAR, BLIP',
'pros': 'Fine-grained alignment',
'cons': 'Requires hard negatives'
},
'MLM (Masked LM)': {
'what': 'Predict masked text with visual context',
'used_by': 'VisualBERT, UNITER',
'pros': 'Rich language understanding',
'cons': 'Only for encoder models'
},
'LM (Language Modeling)': {
'what': 'Autoregressive text generation',
'used_by': 'BLIP, GIT, Flamingo',
'pros': 'Natural for generation',
'cons': 'Needs careful training'
}
}
for name, info in objectives.items():
print(f"\n{name}:")
for k, v in info.items():
print(f" {k}: {v}")
compare_training_objectives()

Key Takeaways
Vision-language models connect visual perception with natural language understanding through various architectural approaches. Dual encoders like CLIP process modalities independently and align them in a shared space, enabling efficient retrieval and zero-shot classification. Fusion encoders enable rich cross-modal interaction through joint attention over visual and textual tokens. Encoder-decoder architectures naturally support generation tasks like captioning. Training objectives include contrastive alignment (ITC), image-text matching (ITM), masked language modeling (MLM), and autoregressive generation (LM). The choice of architecture and objective depends on the target application: retrieval benefits from dual encoders, understanding tasks from fusion models, and generation from encoder-decoders.
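To close, here is a toy version of the zero-shot classification recipe that dual encoders enable. Random embeddings stand in for real encoder outputs, and the fixed scale of 100 stands in for a learned logit scale; with a trained model such as CLIP, the text embeddings would come from encoding prompts like "a photo of a dog".

# Toy zero-shot classification: compare one image embedding to per-class text embeddings.
class_prompts = ["a photo of a dog", "a photo of a cat", "a photo of a car"]
text_embeds = F.normalize(torch.randn(len(class_prompts), 512), dim=-1)  # stand-ins for encoded prompts
image_embed = F.normalize(torch.randn(1, 512), dim=-1)                   # stand-in for one encoded image

probs = (100.0 * image_embed @ text_embeds.t()).softmax(dim=-1)
pred = class_prompts[probs.argmax(dim=-1).item()]
print(f"predicted: {pred} (p = {probs.max().item():.2f})")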