
Chapter 22: Text-to-Image and Beyond

Stable Diffusion, ControlNet, and image editing.

Libraries covered: Diffusers

Learning Objectives

["Generate images from text", "Fine-tune diffusion models", "Apply ControlNet"]


22.1 Introduction to Vision-Language Models

Vision-Language Models (VLMs) bridge the gap between visual perception and natural language understanding, enabling systems that can see and communicate about what they see. These models have transformed how AI systems interact with multimodal data, powering applications from image search and captioning to visual question answering and multimodal assistants.

The Multimodal Challenge

Humans effortlessly integrate visual and linguistic information, but teaching machines to do the same presents fundamental challenges:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass

@dataclass
class MultimodalChallenges:
    """
    Key challenges in vision-language modeling:

    1. Representation Gap: Images are continuous, high-dimensional signals
       while text is discrete and sequential. How do we align these spaces?

    2. Semantic Alignment: The word "dog" must connect to visual features
       of dogs across breeds, poses, and contexts.

    3. Compositional Understanding: "A red ball on a blue table" requires
       understanding objects, attributes, and spatial relationships.

    4. Grounding: Connecting language to specific image regions
       (e.g., "the cat on the left").

    5. Reasoning: Answering "How many people are wearing hats?" requires
       detection, counting, and attribute recognition.
    """

    @staticmethod
    def demonstrate_modality_gap():
        # Image features: continuous, spatial
        image_feature_dim = 768
        image_spatial_size = 14 * 14  # 196 tokens from ViT

        # Text features: discrete tokens, sequential
        vocab_size = 50000
        max_seq_length = 77

        print(f"Image: {image_spatial_size} spatial tokens × {image_feature_dim}D")
        print(f"Text: {max_seq_length} tokens from {vocab_size} vocabulary")
        print("Challenge: Align these different representation spaces")


class VisionLanguageModelTypes:
    """
    Three main architectural paradigms for VLMs.
    """

    @staticmethod
    def describe_architectures():
        architectures = {
            "Dual Encoder": {
                "description": "Separate encoders for image and text, aligned in shared space",
                "examples": ["CLIP", "ALIGN", "OpenCLIP"],
                "pros": "Fast retrieval, scalable",
                "cons": "Limited cross-modal interaction",
                "use_cases": ["Image-text retrieval", "Zero-shot classification"]
            },
            "Fusion Encoder": {
                "description": "Early fusion of visual and text tokens in transformer",
                "examples": ["VisualBERT", "UNITER", "OSCAR"],
                "pros": "Rich cross-modal interaction",
                "cons": "Slower, requires paired forward pass",
                "use_cases": ["VQA", "Visual reasoning"]
            },
            "Encoder-Decoder": {
                "description": "Encode image, decode text autoregressively",
                "examples": ["BLIP", "GIT", "Flamingo", "LLaVA"],
                "pros": "Natural for generation",
                "cons": "Harder to train, slower inference",
                "use_cases": ["Captioning", "Visual dialogue"]
            }
        }

        for name, info in architectures.items():
            print(f"\n{name}:")
            for k, v in info.items():
                print(f"  {k}: {v}")

Dual Encoder Architecture

The dual encoder approach processes images and text independently, then aligns them in a shared embedding space:

PYTHON
class DualEncoder(nn.Module):
    """
    Dual encoder architecture for vision-language alignment.
    Processes image and text separately, aligns in shared space.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        text_encoder: nn.Module,
        embed_dim: int = 512,
        vision_dim: int = 768,
        text_dim: int = 768
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder

        # Projection heads to shared embedding space
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

        # Learnable temperature for contrastive loss
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images to shared embedding space."""
        features = self.vision_encoder(images)
        # Use CLS token or pooled features
        if features.dim() == 3:
            features = features[:, 0]  # CLS token
        features = self.vision_proj(features)
        return F.normalize(features, dim=-1)

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode text to shared embedding space."""
        features = self.text_encoder(input_ids, attention_mask)
        # Use CLS or EOS token
        if features.dim() == 3:
            features = features[:, 0]  # CLS token
        features = self.text_proj(features)
        return F.normalize(features, dim=-1)

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass returning embeddings and similarity logits.
        """
        image_embeds = self.encode_image(images)
        text_embeds = self.encode_text(input_ids, attention_mask)

        # Compute similarity with temperature scaling
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.t()
        logits_per_text = logits_per_image.t()

        return image_embeds, text_embeds, logits_per_image


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for vision-language alignment.
    Brings matching pairs together, pushes non-matching apart.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        logit_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Symmetric contrastive loss (InfoNCE).

        Args:
            image_embeds: [B, D] normalized image embeddings
            text_embeds: [B, D] normalized text embeddings
        """
        if logit_scale is not None:
            temperature = 1.0 / logit_scale.exp()
        else:
            temperature = self.temperature

        # Similarity matrix [B, B]
        logits = image_embeds @ text_embeds.t() / temperature

        # Labels: diagonal elements are positives
        batch_size = logits.size(0)
        labels = torch.arange(batch_size, device=logits.device)

        # Symmetric loss: image-to-text and text-to-image
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)

        return (loss_i2t + loss_t2i) / 2
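
The pieces above can be wired together into a minimal training step. The sketch below runs on random data with two toy stand-in encoders (ToyViT, ToyTextEncoder); these are illustrative placeholders, not real pretrained models or library components.

PYTHON
class ToyViT(nn.Module):
    """Stand-in vision encoder: 16x16 patch embedding, no transformer blocks."""
    def __init__(self, dim: int = 768):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=16, stride=16)

    def forward(self, images):
        return self.proj(images).flatten(2).transpose(1, 2)  # [B, N, D]


class ToyTextEncoder(nn.Module):
    """Stand-in text encoder: embedding lookup only."""
    def __init__(self, vocab_size: int = 50000, dim: int = 768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask):
        return self.embed(input_ids)  # [B, L, D]


model = DualEncoder(ToyViT(), ToyTextEncoder())
loss_fn = ContrastiveLoss()

images = torch.randn(8, 3, 224, 224)
input_ids = torch.randint(0, 50000, (8, 77))
attention_mask = torch.ones(8, 77)

image_embeds, text_embeds, _ = model(images, input_ids, attention_mask)
loss = loss_fn(image_embeds, text_embeds, model.logit_scale)
loss.backward()
print(f"Contrastive loss on random data: {loss.item():.4f}")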

Fusion Encoder Architecture

Fusion encoders enable deep cross-modal interaction by processing both modalities together:

PYTHON
class FusionEncoder(nn.Module):
    """
    Fusion encoder that processes image and text tokens together.
    Enables rich cross-modal attention and interaction.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 12,
        vocab_size: int = 30522,
        max_text_len: int = 128
    ):
        super().__init__()

        self.vision_encoder = vision_encoder

        # Text embedding
        self.text_embeddings = nn.Embedding(vocab_size, hidden_dim)
        self.text_pos_embeddings = nn.Embedding(max_text_len, hidden_dim)

        # Special tokens for modality
        self.img_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.txt_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        # Cross-modal transformer layers
        self.fusion_layers = nn.ModuleList([
            FusionTransformerLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        B = images.size(0)

        # Encode image patches
        image_features = self.vision_encoder(images)  # [B, N_img, D]

        # Embed text
        text_features = self.text_embeddings(input_ids)  # [B, N_txt, D]
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        text_features = text_features + self.text_pos_embeddings(positions)

        # Add modality tokens
        img_token = self.img_token.expand(B, -1, -1)
        txt_token = self.txt_token.expand(B, -1, -1)

        # Concatenate: [IMG] image_patches [TXT] text_tokens
        combined = torch.cat([img_token, image_features, txt_token, text_features], dim=1)

        # Create attention mask
        img_len = image_features.size(1) + 1  # +1 for IMG token
        txt_len = text_features.size(1) + 1   # +1 for TXT token
        combined_mask = torch.ones(B, img_len + txt_len, device=images.device)
        combined_mask[:, img_len + 1:] = attention_mask  # Apply text mask

        # Fusion layers
        for layer in self.fusion_layers:
            combined = layer(combined, combined_mask)

        return self.norm(combined)


class FusionTransformerLayer(nn.Module):
    """Transformer layer for multimodal fusion."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention over all modalities; padded text positions are excluded
        # via key_padding_mask (True marks keys to ignore)
        key_padding_mask = (mask == 0) if mask is not None else None

        normed = self.norm1(x)
        attn_out, _ = self.self_attn(
            normed, normed, normed,
            key_padding_mask=key_padding_mask
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
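
A quick shape check of the fusion encoder on random inputs, reusing the ToyViT stand-in from the dual-encoder sketch above (an assumed helper, not a library component). The comment assumes 224×224 inputs with 16×16 patches.

PYTHON
fusion = FusionEncoder(vision_encoder=ToyViT(), hidden_dim=768, num_layers=2)

images = torch.randn(2, 3, 224, 224)          # 196 patches at 16x16
input_ids = torch.randint(0, 30522, (2, 16))  # 16 text tokens
attention_mask = torch.ones(2, 16)

fused = fusion(images, input_ids, attention_mask)
print(fused.shape)  # [2, 1 + 196 + 1 + 16, 768] = [2, 214, 768]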

Encoder-Decoder Architecture

For generative tasks like captioning, encoder-decoder architectures are natural:

PYTHON
class VisionEncoderDecoder(nn.Module):
    """
    Encoder-decoder architecture for vision-to-language generation.
    Encodes image, decodes text autoregressively.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        decoder: nn.Module,
        vision_dim: int = 768,
        decoder_dim: int = 768,
        num_query_tokens: int = 32
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.decoder = decoder

        # Project vision features to decoder space
        self.vision_proj = nn.Linear(vision_dim, decoder_dim)

        # Learnable query tokens for vision-to-language bridge
        self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, decoder_dim))

        # Cross-attention to compress vision features
        self.cross_attn = nn.MultiheadAttention(decoder_dim, 8, batch_first=True)

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode image and extract fixed-length representation."""
        B = images.size(0)

        # Get vision features
        vision_features = self.vision_encoder(images)  # [B, N, D_vision]
        vision_features = self.vision_proj(vision_features)  # [B, N, D_decoder]

        # Use query tokens to extract fixed representation
        queries = self.query_tokens.expand(B, -1, -1)
        visual_tokens, _ = self.cross_attn(queries, vision_features, vision_features)

        return visual_tokens  # [B, num_queries, D]

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            images: Input images [B, C, H, W]
            input_ids: Target text tokens [B, L]
            labels: Same as input_ids, shifted for LM loss
        """
        # Encode images
        visual_tokens = self.encode_image(images)

        # Decode text with visual context
        outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=visual_tokens,
            labels=labels
        )

        return outputs

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        max_length: int = 50,
        num_beams: int = 5,
        **kwargs
    ) -> torch.Tensor:
        """Generate captions for images."""
        visual_tokens = self.encode_image(images)
        return self.decoder.generate(
            encoder_hidden_states=visual_tokens,
            max_length=max_length,
            num_beams=num_beams,
            **kwargs
        )

Vision Encoders for VLMs

Vision encoders extract visual features for language models:

PYTHON
class VisionEncoderForVLM(nn.Module):
    """
    Vision encoder adapted for vision-language models.
    Can use ViT, CNN, or hybrid architectures.
    """

    def __init__(
        self,
        model_type: str = 'vit',
        pretrained: bool = True,
        output_dim: int = 768,
        num_output_tokens: int = 256
    ):
        super().__init__()
        self.model_type = model_type
        self.num_output_tokens = num_output_tokens

        if model_type == 'vit':
            self.encoder = self._build_vit()
        elif model_type == 'resnet':
            self.encoder = self._build_resnet()

        # Resampler to fixed number of tokens (like Perceiver)
        if num_output_tokens > 0:
            self.resampler = PerceiverResampler(output_dim, num_output_tokens)
        else:
            self.resampler = None

    def _build_vit(self) -> nn.Module:
        """Build a minimal ViT-style patch embedder (transformer blocks omitted)."""
        return nn.Sequential(
            # Patch embedding: [B, 3, H, W] -> [B, 768, H/14, W/14]
            nn.Conv2d(3, 768, kernel_size=14, stride=14),
            nn.Flatten(2),  # [B, 768, num_patches]
            # A full ViT would add positional embeddings and transformer blocks here
        )

    def _build_resnet(self) -> nn.Module:
        """Build ResNet encoder returning spatial feature maps (not pooled)."""
        # Plug in e.g. a torchvision ResNet truncated before global pooling
        raise NotImplementedError

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        features = self.encoder(images)

        if self.model_type == 'vit':
            # Patch embedder outputs [B, D, N]; downstream modules expect [B, N, D]
            features = features.transpose(1, 2)

        if self.resampler is not None:
            features = self.resampler(features)

        return features


class PerceiverResampler(nn.Module):
    """
    Perceiver-style resampler to convert variable-length vision tokens
    to fixed-length representation. Used in Flamingo.
    """

    def __init__(
        self,
        dim: int,
        num_queries: int = 64,
        num_layers: int = 2,
        num_heads: int = 8
    ):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim))

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'cross_attn': nn.MultiheadAttention(dim, num_heads, batch_first=True),
                'cross_norm': nn.LayerNorm(dim),
                'ffn': nn.Sequential(
                    nn.Linear(dim, dim * 4),
                    nn.GELU(),
                    nn.Linear(dim * 4, dim)
                ),
                'ffn_norm': nn.LayerNorm(dim)
            })
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Resample variable-length vision tokens to fixed-length queries.

        Args:
            x: Vision features [B, N, D]

        Returns:
            Resampled features [B, num_queries, D]
        """
        B = x.size(0)
        queries = self.queries.unsqueeze(0).expand(B, -1, -1)

        for layer in self.layers:
            # Cross-attention: queries attend to vision features
            attn_out, _ = layer['cross_attn'](
                layer['cross_norm'](queries),
                x, x
            )
            queries = queries + attn_out
            queries = queries + layer['ffn'](layer['ffn_norm'](queries))

        return queries
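
The resampler's effect is easiest to see on a concrete shape: a variable number of patch tokens is compressed into a fixed set of query tokens. A minimal check on random features:

PYTHON
resampler = PerceiverResampler(dim=768, num_queries=64)

vision_tokens = torch.randn(2, 196, 768)  # e.g. ViT patch tokens for a 224x224 image
compressed = resampler(vision_tokens)
print(compressed.shape)  # torch.Size([2, 64, 768]) regardless of input token count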

Text Encoders for VLMs

PYTHON
class TextEncoderForVLM(nn.Module):
    """
    Text encoder for vision-language models.
    Typically BERT-style for understanding, GPT-style for generation.
    """

    def __init__(
        self,
        vocab_size: int = 49408,  # CLIP vocab size
        hidden_dim: int = 512,
        num_layers: int = 12,
        num_heads: int = 8,
        max_length: int = 77,
        encoder_type: str = 'bert'  # 'bert' or 'gpt'
    ):
        super().__init__()
        self.encoder_type = encoder_type

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_length, hidden_dim)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(hidden_dim)

        # For GPT-style, need causal mask
        if encoder_type == 'gpt':
            self.register_buffer(
                'causal_mask',
                torch.triu(torch.ones(max_length, max_length), diagonal=1).bool()
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, L = input_ids.shape

        # Embed tokens and positions
        positions = torch.arange(L, device=input_ids.device)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Causal mask for GPT-style encoding; bidirectional (None) for BERT-style
        if self.encoder_type == 'gpt':
            attn_mask = self.causal_mask[:L, :L]
        else:
            attn_mask = None

        # key_padding_mask expects True at padded positions that should be ignored
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask=attn_mask, key_padding_mask=padding_mask)

        x = self.final_norm(x)

        return x


class TransformerEncoderLayer(nn.Module):
    """Standard transformer encoder layer."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        attn_out, _ = self.self_attn(
            self.norm1(x), self.norm1(x), self.norm1(x),
            attn_mask=attention_mask,
            key_padding_mask=key_padding_mask
        )
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
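
A small sanity check of the text encoder in GPT-style (causal) mode on random token ids; the tiny two-layer configuration is only for illustration.

PYTHON
text_encoder = TextEncoderForVLM(num_layers=2, encoder_type='gpt')

input_ids = torch.randint(0, 49408, (2, 16))
attention_mask = torch.ones(2, 16)

hidden_states = text_encoder(input_ids, attention_mask)
print(hidden_states.shape)  # torch.Size([2, 16, 512])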

Training Objectives

PYTHON
class VLMTrainingObjectives:
    """
    Common training objectives for vision-language models.
    """

    @staticmethod
    def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
        """Image-text contrastive loss (ITC)."""
        logits = image_embeds @ text_embeds.t() / temperature
        labels = torch.arange(len(logits), device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2

    @staticmethod
    def image_text_matching_loss(fused_features, labels):
        """
        Image-text matching loss (ITM).
        Binary classification: does the image match the text?

        Note: the matching head is created inline only for illustration;
        in a real model it is a persistent, trained module.
        """
        # Use CLS token for matching prediction
        cls_features = fused_features[:, 0]
        matching_head = nn.Linear(cls_features.size(-1), 2).to(cls_features.device)
        logits = matching_head(cls_features)
        return F.cross_entropy(logits, labels)

    @staticmethod
    def masked_language_modeling_loss(
        fused_features,
        labels,
        vocab_size: int
    ):
        """
        Masked language modeling with visual context (MLM).

        Note: as above, the LM head is instantiated inline only for illustration.
        """
        lm_head = nn.Linear(fused_features.size(-1), vocab_size).to(fused_features.device)
        logits = lm_head(fused_features)
        # Only compute loss on masked positions (labels != -100)
        return F.cross_entropy(
            logits.view(-1, vocab_size),
            labels.view(-1),
            ignore_index=-100
        )

    @staticmethod
    def language_modeling_loss(logits, labels):
        """
        Autoregressive language modeling loss (LM).
        Used for caption generation.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )


def compare_training_objectives():
    """Compare VLM training objectives."""
    objectives = {
        'ITC (Contrastive)': {
            'what': 'Align image-text pairs in embedding space',
            'used_by': 'CLIP, ALIGN, BLIP',
            'pros': 'Scalable, good for retrieval',
            'cons': 'Limited fine-grained understanding'
        },
        'ITM (Matching)': {
            'what': 'Binary classify if image-text match',
            'used_by': 'UNITER, OSCAR, BLIP',
            'pros': 'Fine-grained alignment',
            'cons': 'Requires hard negatives'
        },
        'MLM (Masked LM)': {
            'what': 'Predict masked text with visual context',
            'used_by': 'VisualBERT, UNITER',
            'pros': 'Rich language understanding',
            'cons': 'Only for encoder models'
        },
        'LM (Language Modeling)': {
            'what': 'Autoregressive text generation',
            'used_by': 'BLIP, GIT, Flamingo',
            'pros': 'Natural for generation',
            'cons': 'Needs careful training'
        }
    }

    for name, info in objectives.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

compare_training_objectives()
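
In practice, models such as BLIP optimize several of these objectives jointly as a weighted sum. The sketch below shows the idea with placeholder weights (not values from any specific paper) and checks the ITC term on random embeddings, where the loss should land near log(batch_size).

PYTHON
def combined_pretraining_loss(itc, itm, lm, w_itc=1.0, w_itm=1.0, w_lm=1.0):
    """Weighted sum of pre-training objectives (weights are illustrative)."""
    return w_itc * itc + w_itm * itm + w_lm * lm

image_embeds = F.normalize(torch.randn(8, 512), dim=-1)
text_embeds = F.normalize(torch.randn(8, 512), dim=-1)
itc = VLMTrainingObjectives.contrastive_loss(image_embeds, text_embeds)
print(f"ITC on random embeddings: {itc.item():.3f} (log(8) = {np.log(8):.3f})")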

Key Takeaways

Vision-language models connect visual perception with natural language understanding through various architectural approaches. Dual encoders like CLIP process modalities independently and align them in a shared space, enabling efficient retrieval and zero-shot classification. Fusion encoders enable rich cross-modal interaction through joint attention over visual and textual tokens. Encoder-decoder architectures naturally support generation tasks like captioning. Training objectives include contrastive alignment (ITC), matching classification (ITM), masked language modeling (MLM), and autoregressive generation (LM). The choice of architecture and objective depends on the target application: retrieval benefits from dual encoders, understanding tasks from fusion models, and generation from encoder-decoders.

22.2 CLIP and Contrastive Vision-Language

CLIP (Contrastive Language-Image Pre-training) revolutionized vision-language modeling by demonstrating that simple contrastive learning on web-scale image-text pairs produces remarkably powerful and transferable representations. This section explores CLIP's architecture, training methodology, and the ecosystem of contrastive vision-language models it inspired.

CLIP Architecture

CLIP consists of two encoders that map images and text to a shared embedding space:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List, Dict

class CLIP(nn.Module):
    """
    CLIP: Contrastive Language-Image Pre-training.

    Key insight: Learn visual concepts from natural language supervision
    by training on 400M image-text pairs from the web.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        # Vision
        image_resolution: int = 224,
        vision_layers: int = 12,
        vision_width: int = 768,
        vision_patch_size: int = 32,
        # Text
        context_length: int = 77,
        vocab_size: int = 49408,
        transformer_width: int = 512,
        transformer_heads: int = 8,
        transformer_layers: int = 12
    ):
        super().__init__()

        self.context_length = context_length

        # Vision Transformer
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_width // 64,
            output_dim=embed_dim
        )

        # Text Transformer
        self.transformer = TextTransformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            vocab_size=vocab_size,
            context_length=context_length,
            output_dim=embed_dim
        )

        # Learnable temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode images to normalized embeddings."""
        return self.visual(image)

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        """Encode text to normalized embeddings."""
        return self.transformer(text)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns image and text features for contrastive loss.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # Normalize features
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        return image_features, text_features


class VisionTransformer(nn.Module):
    """Vision Transformer for CLIP."""

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Patch embedding
        self.conv1 = nn.Conv2d(
            3, width, kernel_size=patch_size, stride=patch_size, bias=False
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )

        self.ln_pre = nn.LayerNorm(width)

        # Transformer blocks
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(width, heads) for _ in range(layers)
        ])

        self.ln_post = nn.LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch embedding
        x = self.conv1(x)  # [B, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, width, grid**2]
        x = x.permute(0, 2, 1)  # [B, grid**2, width]

        # Prepend class token
        x = torch.cat([
            self.class_embedding.expand(x.shape[0], 1, -1),
            x
        ], dim=1)

        # Add positional embedding
        x = x + self.positional_embedding

        x = self.ln_pre(x)

        # Transformer
        for block in self.transformer:
            x = block(x)

        x = self.ln_post(x[:, 0, :])  # Use CLS token

        # Project to embed_dim
        x = x @ self.proj

        return x


class TextTransformer(nn.Module):
    """Text Transformer for CLIP."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        vocab_size: int,
        context_length: int,
        output_dim: int
    ):
        super().__init__()
        self.context_length = context_length

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, width)
        )

        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(width, heads, causal=True)
            for _ in range(layers)
        ])

        self.ln_final = nn.LayerNorm(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        nn.init.normal_(self.text_projection, std=self.transformer[0].attn.in_proj_weight.shape[1] ** -0.5)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        x = self.token_embedding(text)
        x = x + self.positional_embedding[:x.size(1)]

        for block in self.transformer:
            x = block(x)

        x = self.ln_final(x)

        # Take features from EOS token (highest index in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x


class ResidualAttentionBlock(nn.Module):
    """Transformer block with pre-norm."""

    def __init__(self, d_model: int, n_head: int, causal: bool = False):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.ln_2 = nn.LayerNorm(d_model)
        self.causal = causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Create causal mask if needed
        attn_mask = None
        if self.causal:
            attn_mask = torch.triu(
                torch.ones(x.size(1), x.size(1), device=x.device),
                diagonal=1
            ).bool()

        # Self-attention with residual
        x = x + self.attn(
            self.ln_1(x), self.ln_1(x), self.ln_1(x),
            attn_mask=attn_mask,
            need_weights=False
        )[0]

        # FFN with residual
        x = x + self.mlp(self.ln_2(x))

        return x

Contrastive Pre-training

CLIP learns by maximizing similarity between matched image-text pairs:

PYTHON
class CLIPLoss(nn.Module):
    """
    CLIP contrastive loss with learnable temperature.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
        logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """
        Symmetric contrastive loss.

        Args:
            image_features: [B, D] normalized image embeddings
            text_features: [B, D] normalized text embeddings
            logit_scale: Learnable temperature (log scale)
        """
        # Compute similarity
        logits_per_image = logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # Labels: diagonal is positive
        batch_size = image_features.size(0)
        labels = torch.arange(batch_size, device=image_features.device)

        # Symmetric cross-entropy loss
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)

        return (loss_i2t + loss_t2i) / 2


class CLIPTrainer:
    """Training loop for CLIP."""

    def __init__(
        self,
        model: CLIP,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        max_grad_norm: float = 1.0
    ):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.max_grad_norm = max_grad_norm
        self.loss_fn = CLIPLoss()

    def train_step(
        self,
        images: torch.Tensor,
        texts: torch.Tensor
    ) -> float:
        """Single training step."""
        self.model.train()
        images = images.to(self.device)
        texts = texts.to(self.device)

        # Forward pass
        image_features, text_features = self.model(images, texts)

        # Compute loss
        loss = self.loss_fn(
            image_features,
            text_features,
            self.model.logit_scale
        )

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.optimizer.step()

        # Clamp logit scale
        with torch.no_grad():
            self.model.logit_scale.clamp_(0, np.log(100))

        return loss.item()


def clip_training_recipe():
    """CLIP training hyperparameters."""
    config = {
        'batch_size': 32768,  # Very large batches
        'learning_rate': 5e-4,
        'weight_decay': 0.2,
        'warmup_steps': 2000,
        'total_steps': 400000,
        'optimizer': 'AdamW',
        'lr_scheduler': 'cosine',
        'temperature_init': 0.07,
        'mixed_precision': True,
        'gradient_checkpointing': True
    }

    print("CLIP Training Recipe:")
    for k, v in config.items():
        print(f"  {k}: {v}")

    return config
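
The trainer can be exercised end to end on random data with a deliberately tiny configuration; the sizes below are only for illustration and bear no relation to real CLIP models.

PYTHON
tiny_clip = CLIP(
    embed_dim=64, image_resolution=64, vision_layers=2, vision_width=64,
    vision_patch_size=16, context_length=16, vocab_size=1000,
    transformer_width=64, transformer_heads=2, transformer_layers=2
)
optimizer = torch.optim.AdamW(tiny_clip.parameters(), lr=5e-4, weight_decay=0.2)
trainer = CLIPTrainer(tiny_clip, optimizer, device=torch.device('cpu'))

images = torch.randn(8, 3, 64, 64)
texts = torch.randint(1, 1000, (8, 16))
print(f"Step loss: {trainer.train_step(images, texts):.4f}")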

Zero-Shot Classification

CLIP enables classification without task-specific training:

PYTHON
class CLIPZeroShotClassifier:
    """
    Zero-shot image classification using CLIP.
    """

    def __init__(
        self,
        model: CLIP,
        tokenizer,
        device: torch.device
    ):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def create_text_embeddings(
        self,
        class_names: List[str],
        templates: Optional[List[str]] = None
    ) -> torch.Tensor:
        """
        Create text embeddings for class names using prompt templates.

        Templates like "a photo of a {}" help bridge the domain gap
        between web captions and classification labels.
        """
        if templates is None:
            templates = [
                "a photo of a {}.",
                "a blurry photo of a {}.",
                "a photo of many {}.",
                "a sculpture of a {}.",
                "a photo of the hard to see {}.",
                "a low resolution photo of the {}.",
                "a rendering of a {}.",
                "graffiti of a {}.",
                "a bad photo of the {}.",
                "a cropped photo of the {}.",
                "a photo of a large {}.",
                "a photo of a small {}.",
            ]

        all_embeddings = []

        with torch.no_grad():
            for class_name in class_names:
                # Create prompts from templates
                texts = [template.format(class_name) for template in templates]
                tokens = self.tokenizer(texts).to(self.device)

                # Encode and average
                embeddings = self.model.encode_text(tokens)
                embeddings = F.normalize(embeddings, dim=-1)
                class_embedding = embeddings.mean(dim=0)
                class_embedding = F.normalize(class_embedding, dim=0)

                all_embeddings.append(class_embedding)

        return torch.stack(all_embeddings)

    @torch.no_grad()
    def classify(
        self,
        images: torch.Tensor,
        text_embeddings: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Classify images using precomputed text embeddings.

        Returns:
            probs: Probability distribution over classes [B, num_classes]
            predictions: Predicted class indices [B]
        """
        images = images.to(self.device)

        # Encode images
        image_embeddings = self.model.encode_image(images)
        image_embeddings = F.normalize(image_embeddings, dim=-1)

        # Compute similarity
        logit_scale = self.model.logit_scale.exp()
        logits = logit_scale * image_embeddings @ text_embeddings.t()

        probs = F.softmax(logits, dim=-1)
        predictions = probs.argmax(dim=-1)

        return probs, predictions


def zero_shot_imagenet_example():
    """Example: Zero-shot ImageNet classification."""

    # ImageNet class names (subset)
    imagenet_classes = [
        "tench", "goldfish", "great white shark", "tiger shark",
        "hammerhead", "electric ray", "stingray", "cock", "hen",
        "ostrich", "brambling", "goldfinch", "house finch", "junco"
    ]

    # Enhanced templates for ImageNet
    templates = [
        "a bad photo of a {}.",
        "a photo of many {}.",
        "a sculpture of a {}.",
        "a photo of the hard to see {}.",
        "a low resolution photo of the {}.",
        "a rendering of a {}.",
        "graffiti of a {}.",
        "a bad photo of the {}.",
        "a cropped photo of the {}.",
        "a tattoo of a {}.",
        "the embroidered {}.",
        "a photo of a hard to see {}.",
        "a bright photo of a {}.",
        "a photo of a clean {}.",
        "a photo of a dirty {}.",
        "a dark photo of the {}.",
        "a drawing of a {}.",
        "a photo of my {}.",
        "the plastic {}.",
        "a photo of the cool {}.",
        "a close-up photo of a {}.",
        "a black and white photo of the {}.",
        "a painting of the {}.",
        "a painting of a {}.",
        "a pixelated photo of the {}.",
        "a sculpture of the {}.",
        "a bright photo of the {}.",
        "a cropped photo of a {}.",
        "a plastic {}.",
        "a photo of the dirty {}.",
        "a jpeg corrupted photo of a {}.",
        "a blurry photo of the {}.",
        "a photo of the {}.",
        "a good photo of the {}.",
        "a rendering of the {}.",
        "a {} in a video game.",
        "a photo of one {}.",
        "a doodle of a {}.",
        "a close-up photo of the {}.",
        "a photo of a {}.",
        "the origami {}.",
        "the {} in a video game.",
        "a sketch of a {}.",
        "a doodle of the {}.",
        "a origami {}.",
        "a low resolution photo of a {}.",
        "the toy {}.",
        "a rendition of the {}.",
        "a photo of the clean {}.",
        "a photo of a large {}.",
        "a rendition of a {}.",
        "a photo of a nice {}.",
        "a photo of a weird {}.",
        "a blurry photo of a {}.",
        "a cartoon {}.",
        "art of a {}.",
        "a sketch of the {}.",
        "a embroidered {}.",
        "a pixelated photo of a {}.",
        "itap of the {}.",
        "a jpeg corrupted photo of the {}.",
        "a good photo of a {}.",
        "a plushie {}.",
        "a photo of the nice {}.",
        "a photo of the small {}.",
        "a photo of the weird {}.",
        "the cartoon {}.",
        "art of the {}.",
        "a drawing of the {}.",
        "a photo of the large {}.",
        "a black and white photo of a {}.",
        "the plushie {}.",
        "a dark photo of a {}.",
        "itap of a {}.",
        "graffiti of the {}.",
        "a toy {}.",
        "itap of my {}.",
        "a photo of a cool {}.",
        "a photo of a small {}.",
        "a tattoo of the {}.",
    ]

    print(f"Using {len(templates)} prompt templates")
    print(f"Classifying into {len(imagenet_classes)} classes")

Image-Text Retrieval

PYTHON
class CLIPRetrieval:
    """
    Image-text retrieval using CLIP embeddings.
    """

    def __init__(self, model: CLIP, device: torch.device):
        self.model = model.to(device)
        self.device = device
        self.model.eval()

        # Index storage
        self.image_embeddings = None
        self.text_embeddings = None
        self.image_ids = None
        self.text_ids = None

    @torch.no_grad()
    def index_images(
        self,
        images: torch.Tensor,
        image_ids: Optional[List] = None
    ):
        """Build image index."""
        images = images.to(self.device)
        embeddings = self.model.encode_image(images)
        self.image_embeddings = F.normalize(embeddings, dim=-1)
        self.image_ids = image_ids or list(range(len(images)))

    @torch.no_grad()
    def index_texts(
        self,
        texts: torch.Tensor,
        text_ids: Optional[List] = None
    ):
        """Build text index."""
        texts = texts.to(self.device)
        embeddings = self.model.encode_text(texts)
        self.text_embeddings = F.normalize(embeddings, dim=-1)
        self.text_ids = text_ids or list(range(len(texts)))

    @torch.no_grad()
    def text_to_image_retrieval(
        self,
        query_text: torch.Tensor,
        top_k: int = 10
    ) -> Tuple[List, torch.Tensor]:
        """
        Retrieve images given text query.
        """
        query_text = query_text.to(self.device)
        query_embedding = self.model.encode_text(query_text)
        query_embedding = F.normalize(query_embedding, dim=-1)

        # Compute similarities
        similarities = query_embedding @ self.image_embeddings.t()

        # Get top-k
        scores, indices = similarities.topk(top_k, dim=-1)

        retrieved_ids = [self.image_ids[i] for i in indices[0].tolist()]
        return retrieved_ids, scores[0]

    @torch.no_grad()
    def image_to_text_retrieval(
        self,
        query_image: torch.Tensor,
        top_k: int = 10
    ) -> Tuple[List, torch.Tensor]:
        """
        Retrieve texts given image query.
        """
        query_image = query_image.to(self.device)
        query_embedding = self.model.encode_image(query_image)
        query_embedding = F.normalize(query_embedding, dim=-1)

        # Compute similarities
        similarities = query_embedding @ self.text_embeddings.t()

        # Get top-k
        scores, indices = similarities.topk(top_k, dim=-1)

        retrieved_ids = [self.text_ids[i] for i in indices[0].tolist()]
        return retrieved_ids, scores[0]
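
A hedged usage sketch: clip_model is again an assumed trained CLIP, gallery_images a preprocessed image batch, and tokenize a CLIP-style tokenizer; none of these are defined here.

PYTHON
retriever = CLIPRetrieval(clip_model, device=torch.device('cpu'))
retriever.index_images(gallery_images, image_ids=list(range(len(gallery_images))))

query = tokenize(["a dog playing fetch on the beach"])
ids, scores = retriever.text_to_image_retrieval(query, top_k=5)
print(ids, scores)  # best-matching image ids with similarity scores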

CLIP Variants

PYTHON
class SigLIP(nn.Module):
    """
    SigLIP: Sigmoid Loss for Language Image Pre-Training.

    Replaces softmax contrastive loss with sigmoid binary classification,
    enabling better scaling and batch-size independent training.
    """

    def __init__(self, vision_encoder, text_encoder, embed_dim=768):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder

        self.vision_proj = nn.Linear(embed_dim, embed_dim)
        self.text_proj = nn.Linear(embed_dim, embed_dim)

        # Bias for sigmoid loss
        self.logit_bias = nn.Parameter(torch.zeros(1))

    def forward(self, images, texts):
        image_embeds = F.normalize(self.vision_proj(self.vision_encoder(images)), dim=-1)
        text_embeds = F.normalize(self.text_proj(self.text_encoder(texts)), dim=-1)
        return image_embeds, text_embeds


class SigLIPLoss(nn.Module):
    """
    Sigmoid loss for vision-language alignment.

    Instead of softmax over all pairs, uses binary classification
    for each image-text pair independently.
    """

    def __init__(self, temperature: float = 10.0):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        logit_bias: torch.Tensor
    ) -> torch.Tensor:
        # All pairwise similarities
        logits = image_embeds @ text_embeds.t() * self.temperature + logit_bias

        # Labels: 1 for matching pairs (diagonal), -1 for non-matching
        batch_size = image_embeds.size(0)
        labels = 2 * torch.eye(batch_size, device=logits.device) - 1

        # Binary cross-entropy with logits
        loss = -F.logsigmoid(labels * logits).mean()

        return loss


def compare_clip_variants():
    """Compare CLIP-style models."""
    variants = {
        'CLIP (OpenAI)': {
            'data': '400M image-text pairs (WIT)',
            'vision': 'ViT-L/14, ResNet',
            'loss': 'Softmax contrastive',
            'strengths': 'Original, well-studied'
        },
        'OpenCLIP': {
            'data': 'LAION-2B, LAION-400M',
            'vision': 'ViT-G/14, ConvNeXt',
            'loss': 'Softmax contrastive',
            'strengths': 'Open weights, larger scale'
        },
        'ALIGN': {
            'data': '1.8B noisy image-text pairs',
            'vision': 'EfficientNet',
            'loss': 'Softmax contrastive',
            'strengths': 'Noisy data robustness'
        },
        'SigLIP': {
            'data': 'WebLI (10B+)',
            'vision': 'ViT-So400m, ViT-L',
            'loss': 'Sigmoid binary',
            'strengths': 'Better scaling, batch-size independent'
        },
        'EVA-CLIP': {
            'data': 'Merged datasets',
            'vision': 'EVA ViT (MAE pretrained)',
            'loss': 'Softmax contrastive',
            'strengths': 'Strong SSL initialization'
        }
    }

    for name, info in variants.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

compare_clip_variants()
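
A quick check of the sigmoid loss on random normalized embeddings. Each image-text pair contributes an independent binary term, so no in-batch softmax normalization is needed; the large negative bias mirrors the initialization suggested by the SigLIP authors.

PYTHON
sigmoid_loss = SigLIPLoss(temperature=10.0)

img = F.normalize(torch.randn(16, 768), dim=-1)
txt = F.normalize(torch.randn(16, 768), dim=-1)
loss = sigmoid_loss(img, txt, logit_bias=torch.tensor(-10.0))
print(f"SigLIP loss on random embeddings: {loss.item():.4f}")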

Fine-tuning CLIP

PYTHON
class CLIPFineTuner:
    """
    Fine-tune CLIP for downstream tasks.
    """

    def __init__(
        self,
        model: CLIP,
        num_classes: int,
        device: torch.device,
        freeze_vision: bool = False,
        freeze_text: bool = True
    ):
        self.model = model.to(device)
        self.device = device

        # Freeze encoders as needed
        if freeze_vision:
            for param in self.model.visual.parameters():
                param.requires_grad = False

        if freeze_text:
            for param in self.model.transformer.parameters():
                param.requires_grad = False

        # Add classification head
        self.classifier = nn.Linear(
            model.visual.output_dim, num_classes
        ).to(device)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        features = self.model.encode_image(images)
        return self.classifier(features)


class LinearProbe:
    """Linear probe evaluation for CLIP."""

    def __init__(self, model: CLIP, device: torch.device):
        self.model = model.to(device)
        self.device = device

        # Freeze CLIP
        for param in self.model.parameters():
            param.requires_grad = False

    def extract_features(
        self,
        dataloader: torch.utils.data.DataLoader
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from dataset."""
        all_features = []
        all_labels = []

        self.model.eval()
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(self.device)
                features = self.model.encode_image(images)
                features = F.normalize(features, dim=-1)
                all_features.append(features.cpu())
                all_labels.append(labels)

        return torch.cat(all_features), torch.cat(all_labels)

    def train_linear_classifier(
        self,
        train_features: torch.Tensor,
        train_labels: torch.Tensor,
        num_classes: int,
        epochs: int = 100
    ) -> nn.Linear:
        """Train linear classifier on frozen features."""
        classifier = nn.Linear(train_features.size(1), num_classes)

        optimizer = torch.optim.LBFGS(
            classifier.parameters(),
            lr=0.1,
            max_iter=epochs
        )

        def closure():
            optimizer.zero_grad()
            logits = classifier(train_features)
            loss = F.cross_entropy(logits, train_labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        return classifier


class CoOp(nn.Module):
    """
    CoOp: Context Optimization for Vision-Language Models.

    Learns soft prompt vectors instead of using hand-crafted templates.
    """

    def __init__(
        self,
        clip_model: CLIP,
        num_classes: int,
        n_ctx: int = 16,  # Number of context tokens
        ctx_init: str = ""  # Optional initialization
    ):
        super().__init__()
        self.clip = clip_model
        self.n_ctx = n_ctx

        # Freeze CLIP
        for param in self.clip.parameters():
            param.requires_grad = False

        # Learnable context vectors
        ctx_dim = clip_model.transformer.token_embedding.embedding_dim
        self.ctx = nn.Parameter(torch.randn(n_ctx, ctx_dim))

        # Class name embeddings (frozen)
        self.register_buffer('class_token_ids', torch.zeros(num_classes, 77).long())

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Encode images
        image_features = self.clip.encode_image(images)
        image_features = F.normalize(image_features, dim=-1)

        # Build prompts with learned context
        # [ctx1, ctx2, ..., ctxN, class_name, ., EOS]
        prompts = self._build_prompts()

        # Encode text
        text_features = self._encode_prompts(prompts)
        text_features = F.normalize(text_features, dim=-1)

        # Compute logits
        logit_scale = self.clip.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

    def _build_prompts(self) -> torch.Tensor:
        """Build prompt embeddings by splicing the learned context with class-name tokens."""
        # Sketch only: a full implementation inserts self.ctx between the start
        # token embedding and the embedded class-name tokens for each class
        raise NotImplementedError

    def _encode_prompts(self, prompts: torch.Tensor) -> torch.Tensor:
        """Encode prompt embeddings through the frozen CLIP text encoder."""
        raise NotImplementedError

Key Takeaways

CLIP demonstrated that contrastive learning on web-scale image-text pairs creates powerful, transferable vision-language representations. The key innovations include: (1) simple symmetric contrastive loss with learnable temperature, (2) separate vision and text encoders aligned in a shared space, (3) zero-shot transfer through natural language supervision, and (4) prompt engineering for downstream tasks. CLIP variants like SigLIP improve scaling through sigmoid loss, while fine-tuning methods like CoOp learn optimal prompts. CLIP embeddings have become foundational for many vision-language applications, from image search to multimodal generation models.

22.3 Image Captioning and Visual QA

Image captioning and Visual Question Answering (VQA) are generative vision-language tasks that require models to produce natural language outputs conditioned on visual inputs. Unlike contrastive models that learn to match images and text, these tasks demand the ability to generate coherent, contextually appropriate text descriptions or answers. This section explores the architectures and techniques that enable machines to describe what they see and answer questions about visual content.

Image Captioning Fundamentals

Image captioning generates natural language descriptions of images using encoder-decoder architectures:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
import numpy as np

class ShowAndTell(nn.Module):
    """
    Show and Tell: Neural Image Caption Generator.

    Classic encoder-decoder architecture where CNN encodes the image
    and LSTM decodes the caption one word at a time.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        hidden_dim: int = 512,
        vocab_size: int = 10000,
        num_layers: int = 1,
        max_length: int = 50,
        dropout: float = 0.5
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_length = max_length

        # Image encoder (using pretrained CNN features)
        self.image_projection = nn.Sequential(
            nn.Linear(2048, embed_dim),  # From ResNet features
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Word embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # LSTM decoder
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Output projection
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        image_features: torch.Tensor,
        captions: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Training forward pass with teacher forcing.

        Args:
            image_features: [B, 2048] CNN features
            captions: [B, max_len] tokenized captions
            lengths: [B] actual caption lengths
        """
        batch_size = image_features.size(0)

        # Project image to embedding space
        image_embed = self.image_projection(image_features)  # [B, embed_dim]

        # Embed caption words
        embeddings = self.embedding(captions)  # [B, max_len, embed_dim]

        # Prepend image embedding as first "word"
        embeddings = torch.cat([
            image_embed.unsqueeze(1),
            embeddings[:, :-1]  # Shift right
        ], dim=1)

        # Pack sequence for efficient LSTM
        packed = nn.utils.rnn.pack_padded_sequence(
            embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # Decode
        outputs, _ = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        # Project to vocabulary
        outputs = self.fc(self.dropout(outputs))

        return outputs

    @torch.no_grad()
    def generate(
        self,
        image_features: torch.Tensor,
        start_token: int = 1,
        end_token: int = 2,
        max_length: Optional[int] = None
    ) -> torch.Tensor:
        """
        Generate captions using greedy decoding.
        """
        max_length = max_length or self.max_length
        batch_size = image_features.size(0)
        device = image_features.device

        # Initialize
        image_embed = self.image_projection(image_features)
        hidden = None

        # Start with image embedding
        input_embed = image_embed.unsqueeze(1)

        generated = []

        for _ in range(max_length):
            output, hidden = self.lstm(input_embed, hidden)
            logits = self.fc(output.squeeze(1))
            predicted = logits.argmax(dim=-1)

            generated.append(predicted)

            # Check for end token
            if (predicted == end_token).all():
                break

            # Next input
            input_embed = self.embedding(predicted).unsqueeze(1)

        return torch.stack(generated, dim=1)
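
A toy usage pass on random inputs; in practice image_features would come from a pretrained CNN's 2048-d pooled output and captions from a tokenizer with start/end tokens.

PYTHON
captioner = ShowAndTell(vocab_size=10000)

image_features = torch.randn(2, 2048)
captions = torch.randint(3, 10000, (2, 20))
lengths = torch.tensor([20, 15])

logits = captioner(image_features, captions, lengths)  # [2, 20, 10000]
generated = captioner.generate(image_features)         # greedy-decoded token ids
print(logits.shape, generated.shape)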

Attention-based Captioning

Show, Attend and Tell introduced visual attention for image captioning:

PYTHON
class SoftAttention(nn.Module):
    """
    Soft attention mechanism for image captioning.

    Allows the decoder to focus on relevant image regions
    when generating each word.
    """

    def __init__(self, encoder_dim: int, decoder_dim: int, attention_dim: int):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute attention weights and context vector.

        Args:
            encoder_out: [B, num_pixels, encoder_dim]
            decoder_hidden: [B, decoder_dim]

        Returns:
            context: [B, encoder_dim] weighted sum of encoder outputs
            alpha: [B, num_pixels] attention weights
        """
        att1 = self.encoder_att(encoder_out)  # [B, num_pixels, attention_dim]
        att2 = self.decoder_att(decoder_hidden)  # [B, attention_dim]

        att = self.full_att(torch.tanh(att1 + att2.unsqueeze(1)))  # [B, num_pixels, 1]
        alpha = F.softmax(att.squeeze(2), dim=1)  # [B, num_pixels]

        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # [B, encoder_dim]

        return context, alpha
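

# --- Illustrative shape check (not part of the original model code) ---
# encoder_out: [B, num_pixels, encoder_dim], decoder_hidden: [B, decoder_dim]
_demo_attention = SoftAttention(encoder_dim=2048, decoder_dim=512, attention_dim=512)
_demo_context, _demo_alpha = _demo_attention(
    torch.randn(2, 196, 2048), torch.randn(2, 512)
)
assert _demo_context.shape == (2, 2048) and _demo_alpha.shape == (2, 196)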


class ShowAttendTell(nn.Module):
    """
    Show, Attend and Tell: Visual Attention for Image Captioning.

    Extends Show and Tell with attention mechanism that learns
    to look at relevant image regions when generating each word.
    """

    def __init__(
        self,
        encoder_dim: int = 2048,
        attention_dim: int = 512,
        embed_dim: int = 512,
        decoder_dim: int = 512,
        vocab_size: int = 10000,
        dropout: float = 0.5
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        # Attention mechanism
        self.attention = SoftAttention(encoder_dim, decoder_dim, attention_dim)

        # Word embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # LSTM cell (single step at a time for attention)
        self.lstm_cell = nn.LSTMCell(
            embed_dim + encoder_dim,  # Concatenate embedding and context
            decoder_dim
        )

        # Initialize hidden state from image features
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)

        # Output layers
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # Gate for context
        self.fc = nn.Linear(decoder_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def init_hidden_state(
        self,
        encoder_out: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize LSTM state from mean encoder features."""
        mean_encoder = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder)
        c = self.init_c(mean_encoder)
        return h, c

    def forward(
        self,
        encoder_out: torch.Tensor,
        captions: torch.Tensor,
        lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Training forward pass.

        Args:
            encoder_out: [B, num_pixels, encoder_dim] CNN feature maps
            captions: [B, max_len] tokenized captions
            lengths: [B] caption lengths
        """
        batch_size = encoder_out.size(0)
        num_pixels = encoder_out.size(1)
        device = encoder_out.device

        # Sort by length for packing
        lengths, sort_idx = lengths.sort(descending=True)
        encoder_out = encoder_out[sort_idx]
        captions = captions[sort_idx]

        # Embed captions
        embeddings = self.embedding(captions)  # [B, max_len, embed_dim]

        # Initialize hidden state
        h, c = self.init_hidden_state(encoder_out)

        # Remove <end> token for decoding
        decode_lengths = (lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Tensors to hold outputs
        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size).to(device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels).to(device)

        # Decode word by word
        for t in range(max_decode_length):
            batch_size_t = sum([l > t for l in decode_lengths])

            # Attention
            context, alpha = self.attention(
                encoder_out[:batch_size_t],
                h[:batch_size_t]
            )

            # Gating scalar for context
            gate = torch.sigmoid(self.f_beta(h[:batch_size_t]))
            gated_context = gate * context

            # LSTM step
            h, c = self.lstm_cell(
                torch.cat([embeddings[:batch_size_t, t], gated_context], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )

            # Predict
            preds = self.fc(self.dropout(h))

            predictions[:batch_size_t, t] = preds
            alphas[:batch_size_t, t] = alpha

        return predictions, alphas

    @torch.no_grad()
    def beam_search(
        self,
        encoder_out: torch.Tensor,
        beam_size: int = 5,
        max_length: int = 50,
        start_token: int = 1,
        end_token: int = 2
    ) -> List[int]:
        """
        Generate a caption for a single image using beam search.

        Expects encoder_out of shape [1, num_pixels, encoder_dim].
        """
        device = encoder_out.device
        k = beam_size

        # Expand encoder output for beam
        encoder_out = encoder_out.expand(k, -1, -1)  # [k, num_pixels, encoder_dim]

        # Initialize
        h, c = self.init_hidden_state(encoder_out)

        # Start tokens
        prev_words = torch.full((k, 1), start_token, dtype=torch.long, device=device)

        # Lists to store completed sequences
        complete_seqs = []
        complete_seqs_scores = []

        # Cumulative scores
        top_k_scores = torch.zeros(k, 1, device=device)

        step = 1

        while step < max_length:
            # Embed only the most recent word of each beam
            embeddings = self.embedding(prev_words[:, -1])

            context, alpha = self.attention(encoder_out, h)
            gate = torch.sigmoid(self.f_beta(h))
            gated_context = gate * context

            h, c = self.lstm_cell(
                torch.cat([embeddings, gated_context], dim=1),
                (h, c)
            )

            scores = F.log_softmax(self.fc(h), dim=1)
            scores = top_k_scores.expand_as(scores) + scores

            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, dim=0)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(k, dim=0)

            prev_word_idxs = top_k_words // self.vocab_size
            next_word_idxs = top_k_words % self.vocab_size

            # Check for complete sequences
            incomplete_idxs = []
            for idx, word_idx in enumerate(next_word_idxs):
                if word_idx == end_token:
                    complete_seqs.append(prev_words[prev_word_idxs[idx]].tolist() + [word_idx.item()])
                    complete_seqs_scores.append(top_k_scores[idx].item())
                else:
                    incomplete_idxs.append(idx)

            if len(incomplete_idxs) == 0:
                break

            # Update state for incomplete sequences
            prev_words = torch.cat([
                prev_words[prev_word_idxs[incomplete_idxs]],
                next_word_idxs[incomplete_idxs].unsqueeze(1)
            ], dim=1)

            h = h[prev_word_idxs[incomplete_idxs]]
            c = c[prev_word_idxs[incomplete_idxs]]
            encoder_out = encoder_out[prev_word_idxs[incomplete_idxs]]
            top_k_scores = top_k_scores[incomplete_idxs].unsqueeze(1)

            k = len(incomplete_idxs)
            step += 1

        # Return best sequence
        if complete_seqs:
            best_idx = complete_seqs_scores.index(max(complete_seqs_scores))
            return complete_seqs[best_idx]
        else:
            return prev_words[0].tolist()
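
As a quick sanity check, the sketch below runs a toy forward pass and beam search through ShowAttendTell with random features and captions; the shapes and vocabulary size are arbitrary, not values from the original paper:

PYTHON
# Toy run of ShowAttendTell with random inputs (illustrative only)
model = ShowAttendTell(vocab_size=1000)

encoder_out = torch.randn(4, 196, 2048)      # flattened 14x14 CNN feature map
captions = torch.randint(3, 1000, (4, 20))   # random token ids
lengths = torch.tensor([20, 18, 15, 12])

predictions, alphas = model(encoder_out, captions, lengths)
print(predictions.shape, alphas.shape)       # [4, 19, 1000], [4, 19, 196]

# Beam search expects a single image: [1, num_pixels, encoder_dim]
caption_ids = model.beam_search(torch.randn(1, 196, 2048), beam_size=3, max_length=20)
print(caption_ids)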

Visual Question Answering

VQA requires understanding both image content and natural language questions:

PYTHON
class VQAModel(nn.Module):
    """
    Visual Question Answering baseline model.

    Fuses image and question features to predict answers
    as multi-class classification.
    """

    def __init__(
        self,
        vocab_size: int = 10000,
        embed_dim: int = 300,
        hidden_dim: int = 1024,
        image_dim: int = 2048,
        num_answers: int = 3129,  # VQA v2 answer vocabulary
        dropout: float = 0.5
    ):
        super().__init__()

        # Question encoder
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.question_lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        # Image projection
        self.image_fc = nn.Linear(image_dim, hidden_dim * 2)

        # Fusion and classification
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim),  # Concatenate image + question
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_answers)
        )

    def forward(
        self,
        image_features: torch.Tensor,
        questions: torch.Tensor,
        question_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            image_features: [B, image_dim] CNN features
            questions: [B, max_len] tokenized questions
            question_lengths: [B] question lengths
        """
        batch_size = image_features.size(0)

        # Encode question
        question_embed = self.embedding(questions)
        packed = nn.utils.rnn.pack_padded_sequence(
            question_embed,
            question_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        _, (h, _) = self.question_lstm(packed)
        question_features = torch.cat([h[0], h[1]], dim=1)  # [B, hidden_dim*2]

        # Project image features
        image_projected = self.image_fc(image_features)  # [B, hidden_dim*2]

        # Fuse by concatenating the element-wise product and element-wise sum
        fused = torch.cat([
            image_projected * question_features,
            image_projected + question_features
        ], dim=1)

        # Classify
        logits = self.fusion(fused)

        return logits


class VQAWithAttention(nn.Module):
    """
    VQA model with visual attention.

    Uses question to attend over image regions for
    better answer prediction.
    """

    def __init__(
        self,
        vocab_size: int = 10000,
        embed_dim: int = 300,
        hidden_dim: int = 1024,
        image_dim: int = 2048,
        num_regions: int = 196,  # 14x14 feature map
        attention_dim: int = 512,
        num_answers: int = 3129,
        dropout: float = 0.5
    ):
        super().__init__()

        # Question encoder
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.question_lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True
        )

        # Image projection
        self.image_proj = nn.Linear(image_dim, hidden_dim)

        # Attention
        self.question_att = nn.Linear(hidden_dim * 2, attention_dim)
        self.image_att = nn.Linear(hidden_dim, attention_dim)
        self.att_fc = nn.Linear(attention_dim, 1)

        # Answer classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_answers)
        )

    def forward(
        self,
        image_features: torch.Tensor,
        questions: torch.Tensor,
        question_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            image_features: [B, num_regions, image_dim] spatial CNN features
            questions: [B, max_len] tokenized questions
            question_lengths: [B] question lengths
        """
        batch_size = image_features.size(0)
        num_regions = image_features.size(1)

        # Encode question
        question_embed = self.embedding(questions)
        packed = nn.utils.rnn.pack_padded_sequence(
            question_embed,
            question_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        _, (h, _) = self.question_lstm(packed)
        question_features = torch.cat([h[0], h[1]], dim=1)  # [B, hidden_dim*2]

        # Project image regions
        image_projected = self.image_proj(image_features)  # [B, num_regions, hidden_dim]

        # Compute attention
        q_att = self.question_att(question_features).unsqueeze(1)  # [B, 1, attention_dim]
        i_att = self.image_att(image_projected)  # [B, num_regions, attention_dim]
        att_scores = self.att_fc(torch.tanh(q_att + i_att)).squeeze(2)  # [B, num_regions]
        att_weights = F.softmax(att_scores, dim=1)

        # Weighted image features
        attended_image = (image_projected * att_weights.unsqueeze(2)).sum(dim=1)  # [B, hidden_dim]

        # Fuse and classify
        fused = torch.cat([attended_image, question_features], dim=1)
        logits = self.classifier(fused)

        return logits, att_weights


class BottomUpTopDown(nn.Module):
    """
    Bottom-Up and Top-Down Attention for VQA.

    Uses object-level features from Faster R-CNN (bottom-up)
    with question-guided attention (top-down).
    """

    def __init__(
        self,
        vocab_size: int = 10000,
        embed_dim: int = 300,
        hidden_dim: int = 1024,
        object_dim: int = 2048,
        num_objects: int = 36,
        glimpses: int = 2,
        num_answers: int = 3129,
        dropout: float = 0.2
    ):
        super().__init__()
        self.glimpses = glimpses

        # Question encoder (GRU)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.question_gru = nn.GRU(
            embed_dim,
            hidden_dim,
            batch_first=True
        )

        # Object features projection
        self.object_proj = nn.Linear(object_dim, hidden_dim)

        # Multi-glimpse attention
        self.attention = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, 1)
            )
            for _ in range(glimpses)
        ])

        # Answer prediction
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * glimpses, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, num_answers)
        )

    def forward(
        self,
        object_features: torch.Tensor,
        questions: torch.Tensor,
        question_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            object_features: [B, num_objects, object_dim] bottom-up features
            questions: [B, max_len] tokenized questions
            question_lengths: [B] question lengths
        """
        batch_size = object_features.size(0)

        # Encode question
        question_embed = self.embedding(questions)
        packed = nn.utils.rnn.pack_padded_sequence(
            question_embed,
            question_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        _, h = self.question_gru(packed)
        question_features = h.squeeze(0)  # [B, hidden_dim]

        # Project object features
        object_projected = self.object_proj(object_features)  # [B, num_objects, hidden_dim]

        # Multi-glimpse attention
        attended_features = []

        for glimpse in range(self.glimpses):
            # Combine question and object features
            q_expanded = question_features.unsqueeze(1).expand(-1, object_projected.size(1), -1)
            combined = torch.cat([object_projected, q_expanded], dim=2)

            # Attention weights
            att_scores = self.attention[glimpse](combined).squeeze(2)
            att_weights = F.softmax(att_scores, dim=1)

            # Attended features
            attended = (object_projected * att_weights.unsqueeze(2)).sum(dim=1)
            attended_features.append(attended)

        # Concatenate glimpses
        fused = torch.cat(attended_features, dim=1)

        # Classify
        logits = self.classifier(fused)

        return logits
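
A minimal smoke test of the baseline with random features and questions (shapes and sizes are illustrative only):

PYTHON
# Smoke test for the VQA baseline with random inputs (illustrative only)
vqa = VQAModel(vocab_size=1000, num_answers=100)

image_features = torch.randn(2, 2048)
questions = torch.randint(1, 1000, (2, 12))
question_lengths = torch.tensor([12, 9])

logits = vqa(image_features, questions, question_lengths)
print(logits.shape)  # torch.Size([2, 100])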

BLIP: Bootstrapped Language-Image Pre-training

BLIP unifies understanding and generation with a multimodal mixture of encoder-decoder (MED) architecture, trained jointly with contrastive, matching, and captioning objectives:

PYTHON
class BLIPEncoder(nn.Module):
    """
    BLIP visual encoder with cross-attention to text.
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12
    ):
        super().__init__()

        # Vision Transformer backbone
        self.vision_encoder = nn.ModuleList([
            TransformerBlock(vision_dim, num_heads)
            for _ in range(num_layers)
        ])

        # Cross-attention layers for image-text fusion
        self.cross_attention = nn.ModuleList([
            CrossAttentionBlock(vision_dim, text_dim, num_heads)
            for _ in range(num_layers // 2)
        ])

    def forward(
        self,
        image_embeds: torch.Tensor,
        text_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Encode image with optional text conditioning.
        """
        # Self-attention over image
        hidden = image_embeds
        for layer in self.vision_encoder:
            hidden = layer(hidden)

        # Cross-attention with text if provided
        if text_embeds is not None:
            for cross_layer in self.cross_attention:
                hidden = cross_layer(hidden, text_embeds)

        return hidden


class TransformerBlock(nn.Module):
    """Standard transformer block."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim)
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=attn_mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class CrossAttentionBlock(nn.Module):
    """Cross-attention between two modalities."""

    def __init__(self, query_dim: int, kv_dim: int, num_heads: int):
        super().__init__()
        self.norm_q = nn.LayerNorm(query_dim)
        self.norm_kv = nn.LayerNorm(kv_dim)
        self.cross_attn = nn.MultiheadAttention(
            query_dim, num_heads, kdim=kv_dim, vdim=kv_dim, batch_first=True
        )
        self.norm_ffn = nn.LayerNorm(query_dim)
        self.ffn = nn.Sequential(
            nn.Linear(query_dim, query_dim * 4),
            nn.GELU(),
            nn.Linear(query_dim * 4, query_dim)
        )

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor
    ) -> torch.Tensor:
        # Cross-attention
        q = self.norm_q(query)
        kv = self.norm_kv(key_value)
        query = query + self.cross_attn(q, kv, kv)[0]

        # FFN
        query = query + self.ffn(self.norm_ffn(query))

        return query


class BLIP(nn.Module):
    """
    BLIP: Bootstrapped Language-Image Pre-training.

    Multimodal mixture of encoder-decoder supporting:
    - Image-Text Contrastive (ITC)
    - Image-Text Matching (ITM)
    - Image-Captioning (LM)
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        embed_dim: int = 256,
        vocab_size: int = 30522,
        max_length: int = 77
    ):
        super().__init__()

        # Vision encoder (ViT)
        self.visual_encoder = BLIPEncoder(vision_dim, text_dim)

        # Text encoder (BERT-like)
        self.text_encoder = TextEncoder(
            vocab_size=vocab_size,
            dim=text_dim,
            max_length=max_length
        )

        # Text decoder (for captioning)
        self.text_decoder = TextDecoder(
            vocab_size=vocab_size,
            dim=text_dim,
            max_length=max_length
        )

        # Projections for contrastive learning
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

        # ITM head
        self.itm_head = nn.Linear(text_dim, 2)

        # Temperature
        self.temp = nn.Parameter(torch.ones([]) * 0.07)

    def forward_itc(
        self,
        image: torch.Tensor,
        text: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Image-Text Contrastive loss forward."""
        # Encode
        image_embeds = self.visual_encoder(image)
        text_embeds = self.text_encoder(text)

        # Use CLS tokens
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0]), dim=-1)
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0]), dim=-1)

        # Similarity
        sim_i2t = image_feat @ text_feat.t() / self.temp
        sim_t2i = text_feat @ image_feat.t() / self.temp

        return sim_i2t, sim_t2i

    def forward_itm(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        text_attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Image-Text Matching forward."""
        # Encode image
        image_embeds = self.visual_encoder(image)

        # Encode text with image cross-attention
        text_embeds = self.text_encoder(
            text,
            encoder_hidden_states=image_embeds,
            attention_mask=text_attention_mask
        )

        # Binary classification
        itm_output = self.itm_head(text_embeds[:, 0])

        return itm_output

    def forward_lm(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """Language Modeling (captioning) forward."""
        # Encode image
        image_embeds = self.visual_encoder(image)

        # Decode with causal attention
        lm_output = self.text_decoder(
            text,
            encoder_hidden_states=image_embeds,
            labels=labels
        )

        return lm_output

    @torch.no_grad()
    def generate_caption(
        self,
        image: torch.Tensor,
        max_length: int = 30,
        min_length: int = 10,
        num_beams: int = 3
    ) -> List[str]:
        """Generate captions for images."""
        image_embeds = self.visual_encoder(image)

        # Use text decoder for generation
        generated = self.text_decoder.generate(
            encoder_hidden_states=image_embeds,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams
        )

        return generated


class TextEncoder(nn.Module):
    """BERT-style text encoder."""

    def __init__(self, vocab_size: int, dim: int, max_length: int, num_layers: int = 6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(max_length, dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim, dim // 64) for _ in range(num_layers)
        ])
        self.cross_layers = nn.ModuleList([
            CrossAttentionBlock(dim, dim, dim // 64) for _ in range(num_layers // 2)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)

        x = self.embedding(input_ids) + self.position_embedding(positions)

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if encoder_hidden_states is not None and i % 2 == 1:
                cross_idx = i // 2
                if cross_idx < len(self.cross_layers):
                    x = self.cross_layers[cross_idx](x, encoder_hidden_states)

        return x


class TextDecoder(nn.Module):
    """Causal text decoder for captioning."""

    def __init__(self, vocab_size: int, dim: int, max_length: int, num_layers: int = 6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(max_length, dim)
        self.layers = nn.ModuleList([
            CausalTransformerBlock(dim, dim // 64) for _ in range(num_layers)
        ])
        self.cross_layers = nn.ModuleList([
            CrossAttentionBlock(dim, dim, dim // 64) for _ in range(num_layers)
        ])
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)

        x = self.embedding(input_ids) + self.position_embedding(positions)

        for self_layer, cross_layer in zip(self.layers, self.cross_layers):
            x = self_layer(x)
            x = cross_layer(x, encoder_hidden_states)

        logits = self.lm_head(x)

        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )
            return loss

        return logits

    def generate(self, encoder_hidden_states, max_length, min_length, num_beams):
        # Placeholder: a full implementation would run beam search over the
        # lm_head logits, conditioned on encoder_hidden_states
        raise NotImplementedError("Generation is omitted in this sketch")


class CausalTransformerBlock(nn.Module):
    """Transformer block with causal attention."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()

        x = x + self.attn(
            self.norm1(x), self.norm1(x), self.norm1(x),
            attn_mask=causal_mask
        )[0]
        x = x + self.mlp(self.norm2(x))
        return x
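
The similarity matrices returned by forward_itc are typically optimized with a symmetric cross-entropy over the batch, where each image's matching caption sits on the diagonal. The sketch below shows only this hard-label version; BLIP itself additionally uses momentum encoders and soft targets, which are omitted here:

PYTHON
def itc_loss(sim_i2t: torch.Tensor, sim_t2i: torch.Tensor) -> torch.Tensor:
    """Symmetric contrastive loss over a batch of matched image-text pairs."""
    targets = torch.arange(sim_i2t.size(0), device=sim_i2t.device)
    loss_i2t = F.cross_entropy(sim_i2t, targets)  # image -> matching text
    loss_t2i = F.cross_entropy(sim_t2i, targets)  # text -> matching image
    return (loss_i2t + loss_t2i) / 2

# Example with random similarity matrices for a batch of 8 pairs
print(itc_loss(torch.randn(8, 8), torch.randn(8, 8)).item())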

GIT: Generative Image-to-text Transformer

GIT simplifies vision-language generation with a single unified model:

PYTHON
class GIT(nn.Module):
    """
    GIT: Generative Image-to-text Transformer.

    Simple architecture: image encoder + single text decoder
    without complex fusion modules or auxiliary losses.
    """

    def __init__(
        self,
        image_dim: int = 768,
        text_dim: int = 768,
        vocab_size: int = 30522,
        max_length: int = 128,
        num_decoder_layers: int = 6,
        num_heads: int = 12
    ):
        super().__init__()

        # Image encoder (frozen or fine-tuned CLIP/ViT)
        self.image_encoder = VisionEncoder(image_dim)

        # Linear projection from image to text space
        self.image_proj = nn.Linear(image_dim, text_dim)

        # Text embeddings
        self.token_embedding = nn.Embedding(vocab_size, text_dim)
        self.position_embedding = nn.Embedding(max_length, text_dim)

        # Unified decoder that processes [image_tokens, text_tokens]
        self.decoder = nn.ModuleList([
            TransformerBlock(text_dim, num_heads)
            for _ in range(num_decoder_layers)
        ])

        self.ln_f = nn.LayerNorm(text_dim)
        self.lm_head = nn.Linear(text_dim, vocab_size, bias=False)

        self.max_length = max_length
        self.vocab_size = vocab_size

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            images: [B, 3, H, W] input images
            input_ids: [B, seq_len] text tokens
            labels: [B, seq_len] target tokens (-100 for image positions)
        """
        batch_size = images.size(0)
        device = images.device

        # Encode and project images
        image_embeds = self.image_encoder(images)  # [B, num_patches, image_dim]
        image_embeds = self.image_proj(image_embeds)  # [B, num_patches, text_dim]
        num_image_tokens = image_embeds.size(1)

        # Embed text tokens
        text_embeds = self.token_embedding(input_ids)  # [B, seq_len, text_dim]

        # Add positional embeddings to text (after image tokens)
        text_positions = torch.arange(
            num_image_tokens,
            num_image_tokens + input_ids.size(1),
            device=device
        )
        text_embeds = text_embeds + self.position_embedding(text_positions)

        # Concatenate [image, text]
        hidden = torch.cat([image_embeds, text_embeds], dim=1)

        # Create attention mask: image tokens attend bidirectionally among
        # themselves; text tokens attend causally to image + previous text
        attn_mask = self._create_git_attention_mask(
            num_image_tokens, input_ids.size(1), device
        )

        # Process through decoder, applying the mask at every layer
        for layer in self.decoder:
            hidden = layer(hidden, attn_mask=attn_mask)

        hidden = self.ln_f(hidden)

        # Only compute logits for text positions
        text_hidden = hidden[:, num_image_tokens:]
        logits = self.lm_head(text_hidden)

        output = {'logits': logits}

        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
            output['loss'] = loss

        return output

    def _create_git_attention_mask(
        self,
        num_image: int,
        num_text: int,
        device: torch.device
    ) -> torch.Tensor:
        """
        Create attention mask for GIT.

        Image tokens: bidirectional attention among themselves
        Text tokens: causal attention to image + previous text
        """
        total = num_image + num_text
        mask = torch.zeros(total, total, device=device)

        # Image tokens may not attend to text tokens
        mask[:num_image, num_image:] = 1

        # Text tokens use a causal mask among themselves
        # (they can always attend to all image tokens)
        text_mask = torch.triu(
            torch.ones(num_text, num_text, device=device),
            diagonal=1
        )
        mask[num_image:, num_image:] = text_mask

        return mask.bool()

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        bos_token_id: int = 101,
        eos_token_id: int = 102
    ) -> torch.Tensor:
        """
        Generate text from images using sampling.
        """
        batch_size = images.size(0)
        device = images.device

        # Encode images
        image_embeds = self.image_encoder(images)
        image_embeds = self.image_proj(image_embeds)
        num_image_tokens = image_embeds.size(1)

        # Start with BOS token
        generated = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=device
        )

        for _ in range(max_new_tokens):
            # Embed current sequence
            text_embeds = self.token_embedding(generated)
            positions = torch.arange(
                num_image_tokens,
                num_image_tokens + generated.size(1),
                device=device
            )
            text_embeds = text_embeds + self.position_embedding(positions)

            # Concatenate with image
            hidden = torch.cat([image_embeds, text_embeds], dim=1)

            # Forward through decoder with the same image/text attention mask
            attn_mask = self._create_git_attention_mask(
                num_image_tokens, generated.size(1), device
            )
            for layer in self.decoder:
                hidden = layer(hidden, attn_mask=attn_mask)

            hidden = self.ln_f(hidden)

            # Get logits for last position
            logits = self.lm_head(hidden[:, -1]) / temperature

            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop if all sequences have EOS
            if (next_token == eos_token_id).all():
                break

        return generated


class VisionEncoder(nn.Module):
    """Simple vision encoder wrapper."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Placeholder - actual implementation would use ViT/CLIP

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Placeholder: a real backbone (ViT/CLIP) returns patch features
        # of shape [B, num_patches, dim]
        raise NotImplementedError("Plug in a pretrained ViT/CLIP backbone here")

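The masking pattern is easier to see on a tiny example. The snippet below builds a small GIT instance purely to inspect the mask (the hyperparameters are arbitrary); after casting to int, 1 marks key positions a query token may not attend to:

PYTHON
# Inspect the GIT attention mask on a tiny configuration (illustrative only)
tiny_git = GIT(image_dim=32, text_dim=32, vocab_size=100, max_length=16,
               num_decoder_layers=1, num_heads=4)

mask = tiny_git._create_git_attention_mask(
    num_image=3, num_text=4, device=torch.device("cpu")
)
print(mask.int())
# Rows = queries, columns = keys:
#   image rows attend only to the 3 image tokens
#   text rows attend to all image tokens plus earlier text tokens
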
VQA Evaluation and Metrics

PYTHON
class VQAEvaluator:
    """
    VQA evaluation following standard protocols.
    """

    def __init__(self):
        self.contractions = {
            "aint": "ain't", "arent": "aren't", "cant": "can't",
            "couldve": "could've", "couldnt": "couldn't",
            "didnt": "didn't", "doesnt": "doesn't", "dont": "don't",
            "hadnt": "hadn't", "hasnt": "hasn't", "havent": "haven't",
            "hed": "he'd", "hell": "he'll", "hes": "he's",
            "wouldnt": "wouldn't", "youre": "you're", "youve": "you've"
        }

    def normalize_answer(self, answer: str) -> str:
        """Normalize answer for comparison."""
        answer = answer.lower().strip()

        # Handle contractions
        for contraction, expanded in self.contractions.items():
            answer = answer.replace(expanded, contraction)

        # Remove punctuation
        import string
        answer = answer.translate(str.maketrans('', '', string.punctuation))

        # Remove articles
        articles = ['a', 'an', 'the']
        answer = ' '.join(w for w in answer.split() if w not in articles)

        return answer.strip()

    def compute_accuracy(
        self,
        prediction: str,
        ground_truths: List[str]
    ) -> float:
        """
        VQA accuracy: prediction matches if >=3 annotators agree.

        Score = min(1, #humans_that_provided_answer / 3)
        """
        prediction = self.normalize_answer(prediction)
        ground_truths = [self.normalize_answer(gt) for gt in ground_truths]

        # Count matches
        num_matches = sum(1 for gt in ground_truths if gt == prediction)

        return min(1.0, num_matches / 3.0)

    def evaluate_dataset(
        self,
        predictions: Dict[str, str],
        annotations: Dict[str, List[str]]
    ) -> Dict[str, float]:
        """
        Evaluate predictions on VQA dataset.

        Args:
            predictions: {question_id: predicted_answer}
            annotations: {question_id: [annotator_answers]}
        """
        accuracies = []

        for qid, pred in predictions.items():
            if qid in annotations:
                acc = self.compute_accuracy(pred, annotations[qid])
                accuracies.append(acc)

        return {
            'accuracy': 100.0 * sum(accuracies) / max(len(accuracies), 1),
            'num_questions': len(accuracies)
        }


def captioning_metrics_example():
    """Common metrics for image captioning."""

    metrics = {
        'BLEU-4': {
            'description': 'Precision of n-gram overlaps (n=1,2,3,4)',
            'range': '0-100',
            'notes': 'Most common, but rewards exact matches'
        },
        'METEOR': {
            'description': 'Harmonic mean of precision and recall with synonyms',
            'range': '0-100',
            'notes': 'Better correlation with human judgment'
        },
        'CIDEr': {
            'description': 'Consensus-based metric using TF-IDF weighting',
            'range': '0-10+',
            'notes': 'Designed specifically for captioning'
        },
        'SPICE': {
            'description': 'Semantic similarity via scene graphs',
            'range': '0-100',
            'notes': 'Captures semantic content better than n-grams'
        },
        'CLIPScore': {
            'description': 'CLIP similarity between image and caption',
            'range': '0-100',
            'notes': 'Reference-free, correlates with human preference'
        }
    }

    print("Image Captioning Evaluation Metrics:")
    print("-" * 50)
    for name, info in metrics.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


captioning_metrics_example()
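
A short usage example of the accuracy protocol with toy annotator answers (illustrative only; the official evaluation additionally normalizes number words and uses a much longer contraction list):

PYTHON
# Illustrative use of the VQA accuracy protocol
evaluator = VQAEvaluator()

# Ten annotator answers for one question
ground_truths = ["2", "two", "2", "2", "3", "2", "two", "two", "2", "4"]

print(evaluator.compute_accuracy("2", ground_truths))      # min(1, 5/3) = 1.0
print(evaluator.compute_accuracy("three", ground_truths))  # 0.0 with this simplified normalizer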

Key Takeaways

Image captioning and VQA represent foundational generative vision-language tasks that require both visual understanding and language generation capabilities. The evolution from CNN+LSTM architectures (Show and Tell) to attention-based models (Show, Attend and Tell) and finally to transformer-based approaches (BLIP, GIT) mirrors the broader trend in deep learning toward unified, scalable architectures. Attention mechanisms are crucial for grounding language in visual content, whether through soft attention over image regions or cross-attention in transformers. Modern models like BLIP combine multiple objectives (contrastive, matching, generation) for more robust representations, while GIT shows that simple architectures can achieve strong results when scaled with sufficient data. Evaluation remains challenging, with metrics like BLEU and CIDEr capturing surface-level similarity while CLIPScore better reflects semantic quality.

22.4 Multimodal Large Language Models Advanced

Multimodal Large Language Models

Multimodal Large Language Models (MLLMs) represent a paradigm shift in vision-language AI by extending the remarkable capabilities of Large Language Models to understand and reason about visual content. Rather than training vision-language models from scratch, MLLMs leverage pretrained LLMs as powerful reasoning engines and connect them with visual encoders through learned projection modules. This approach enables sophisticated visual understanding, instruction following, and multi-turn conversational abilities about images.

MLLM Architecture Overview

The typical MLLM architecture consists of three components: vision encoder, projection module, and language model:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass

@dataclass
class MLLMConfig:
    """Configuration for Multimodal LLM."""
    vision_encoder: str = "clip-vit-large"
    llm_model: str = "llama-7b"
    vision_hidden_size: int = 1024
    llm_hidden_size: int = 4096
    projector_type: str = "mlp"  # "linear", "mlp", "qformer"
    num_query_tokens: int = 32  # For Q-Former
    freeze_vision: bool = True
    freeze_llm: bool = False
    image_token: str = "<image>"
    max_length: int = 2048


class MultimodalLLM(nn.Module):
    """
    Generic Multimodal Large Language Model architecture.

    Architecture: Vision Encoder -> Projector -> LLM

    The projector bridges the modality gap between vision features
    and the LLM's embedding space.
    """

    def __init__(self, config: MLLMConfig):
        super().__init__()
        self.config = config

        # Vision encoder (typically frozen)
        self.vision_encoder = self._build_vision_encoder()
        if config.freeze_vision:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False

        # Vision-language projector
        self.projector = self._build_projector()

        # Large Language Model
        self.llm = self._build_llm()
        if config.freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

        # Special tokens
        self.image_token_id = None  # Set during tokenizer init

    def _build_vision_encoder(self) -> nn.Module:
        """Build vision encoder (CLIP ViT, SigLIP, etc.)."""
        # Placeholder - actual implementation loads pretrained model
        return VisionTransformer(
            img_size=224,
            patch_size=14,
            embed_dim=self.config.vision_hidden_size,
            depth=24,
            num_heads=16
        )

    def _build_projector(self) -> nn.Module:
        """Build vision-to-language projector."""
        if self.config.projector_type == "linear":
            return nn.Linear(
                self.config.vision_hidden_size,
                self.config.llm_hidden_size
            )
        elif self.config.projector_type == "mlp":
            return MLPProjector(
                self.config.vision_hidden_size,
                self.config.llm_hidden_size
            )
        elif self.config.projector_type == "qformer":
            return QFormerProjector(
                self.config.vision_hidden_size,
                self.config.llm_hidden_size,
                self.config.num_query_tokens
            )
        else:
            raise ValueError(f"Unknown projector type: {self.config.projector_type}")

    def _build_llm(self) -> nn.Module:
        """Build or load pretrained LLM."""
        # Placeholder - actual implementation loads pretrained model
        return LLMDecoder(
            vocab_size=32000,
            hidden_size=self.config.llm_hidden_size,
            num_layers=32,
            num_heads=32
        )

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images and project to LLM space.

        Args:
            images: [B, C, H, W] or [B, num_images, C, H, W]

        Returns:
            image_features: [B, num_tokens, llm_hidden_size]
        """
        # Handle multiple images per sample
        if images.dim() == 5:
            batch_size, num_images = images.shape[:2]
            images = images.view(-1, *images.shape[2:])
        else:
            batch_size = images.shape[0]
            num_images = 1

        # Extract vision features
        with torch.no_grad() if self.config.freeze_vision else torch.enable_grad():
            vision_features = self.vision_encoder(images)  # [B*N, num_patches, vision_dim]

        # Project to LLM space
        image_features = self.projector(vision_features)  # [B*N, num_tokens, llm_dim]

        # Reshape for multiple images
        if num_images > 1:
            image_features = image_features.view(
                batch_size, -1, self.config.llm_hidden_size
            )

        return image_features

    def forward(
        self,
        input_ids: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        image_positions: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for training.

        Args:
            input_ids: [B, seq_len] text token ids
            images: [B, C, H, W] input images
            attention_mask: [B, seq_len] attention mask
            labels: [B, seq_len] targets for language modeling
            image_positions: [B] positions to insert image tokens
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        # Get text embeddings
        text_embeds = self.llm.get_input_embeddings()(input_ids)

        # Process images if provided
        if images is not None:
            image_features = self.encode_images(images)
            num_image_tokens = image_features.size(1)

            # Insert image features at specified positions
            if image_positions is not None:
                # Replace <image> tokens with actual image features
                inputs_embeds = self._merge_image_text_embeddings(
                    text_embeds, image_features, input_ids, self.image_token_id
                )
            else:
                # Prepend image features
                inputs_embeds = torch.cat([image_features, text_embeds], dim=1)
                if attention_mask is not None:
                    image_attention = torch.ones(
                        batch_size, num_image_tokens, device=device
                    )
                    attention_mask = torch.cat([image_attention, attention_mask], dim=1)
                if labels is not None:
                    # Don't compute loss on image tokens
                    image_labels = torch.full(
                        (batch_size, num_image_tokens), -100, device=device
                    )
                    labels = torch.cat([image_labels, labels], dim=1)
        else:
            inputs_embeds = text_embeds

        # Forward through LLM
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs

    def _merge_image_text_embeddings(
        self,
        text_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
        image_token_id: int
    ) -> torch.Tensor:
        """
        Replace image token placeholders with actual image features.

        Handles variable-length image tokens per position.
        """
        batch_size, seq_len, hidden_size = text_embeds.shape
        num_image_tokens = image_features.size(1)

        # Find image token positions
        image_mask = input_ids == image_token_id

        # Build new embeddings
        new_embeds = []
        for b in range(batch_size):
            positions = torch.where(image_mask[b])[0]
            if len(positions) == 0:
                new_embeds.append(text_embeds[b])
                continue

            # Split text at image positions and insert image features
            parts = []
            prev_pos = 0
            for i, pos in enumerate(positions):
                # Add text before image
                if pos > prev_pos:
                    parts.append(text_embeds[b, prev_pos:pos])
                # Add image features
                parts.append(image_features[b])
                prev_pos = pos + 1

            # Add remaining text
            if prev_pos < seq_len:
                parts.append(text_embeds[b, prev_pos:])

            new_embeds.append(torch.cat(parts, dim=0))

        # Pad to same length
        max_len = max(e.size(0) for e in new_embeds)
        padded = torch.zeros(batch_size, max_len, hidden_size, device=text_embeds.device)
        for b, emb in enumerate(new_embeds):
            padded[b, :emb.size(0)] = emb

        return padded

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        prompts: List[str],
        tokenizer,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> List[str]:
        """
        Generate responses for image-text prompts.
        """
        # Encode images
        image_features = self.encode_images(images)

        # Tokenize prompts
        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(images.device)
        attention_mask = inputs.attention_mask.to(images.device)

        # Get embeddings and prepend images
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([image_features, text_embeds], dim=1)

        # Update attention mask
        batch_size = images.size(0)
        num_image_tokens = image_features.size(1)
        image_attention = torch.ones(batch_size, num_image_tokens, device=images.device)
        attention_mask = torch.cat([image_attention, attention_mask], dim=1)

        # Generate
        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0
        )

        # Decode
        responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return responses
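
To make the prepend-and-mask step in the forward pass concrete, here is a small standalone sketch (independent of the placeholder encoder and LLM above; all shapes are arbitrary) showing how image positions are excluded from the language-modeling loss via -100 labels:

PYTHON
# Standalone sketch: prepend image tokens and mask their labels (illustrative only)
batch_size, num_image_tokens, seq_len, hidden = 2, 4, 6, 8

image_features = torch.randn(batch_size, num_image_tokens, hidden)
text_embeds = torch.randn(batch_size, seq_len, hidden)
text_labels = torch.randint(0, 100, (batch_size, seq_len))

inputs_embeds = torch.cat([image_features, text_embeds], dim=1)
image_labels = torch.full((batch_size, num_image_tokens), -100)
labels = torch.cat([image_labels, text_labels], dim=1)

print(inputs_embeds.shape)  # torch.Size([2, 10, 8])
print(labels[0])            # first 4 entries are -100 and are ignored by the loss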

Visual Projectors

Different projector designs trade off between simplicity, expressiveness, and computational cost:

PYTHON
class MLPProjector(nn.Module):
    """
    MLP Projector (LLaVA style).

    Simple but effective: two linear layers with GELU activation.
    Projects each vision token independently.
    """

    def __init__(
        self,
        vision_dim: int,
        llm_dim: int,
        hidden_dim: Optional[int] = None
    ):
        super().__init__()
        hidden_dim = hidden_dim or llm_dim

        self.proj = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, llm_dim)
        )

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [B, num_patches, vision_dim]

        Returns:
            projected: [B, num_patches, llm_dim]
        """
        return self.proj(vision_features)


class QFormerProjector(nn.Module):
    """
    Q-Former Projector (BLIP-2 style).

    Uses learnable query tokens to extract relevant visual information.
    More parameter-efficient but adds computational overhead.
    """

    def __init__(
        self,
        vision_dim: int,
        llm_dim: int,
        num_query_tokens: int = 32,
        num_layers: int = 6,
        num_heads: int = 16  # must evenly divide vision_dim (e.g. 16 for 1024-dim features)
    ):
        super().__init__()
        self.num_query_tokens = num_query_tokens

        # Learnable query tokens
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_query_tokens, vision_dim) * 0.02
        )

        # Cross-attention layers
        self.layers = nn.ModuleList([
            QFormerLayer(vision_dim, num_heads)
            for _ in range(num_layers)
        ])

        # Final projection to LLM dimension
        self.projection = nn.Linear(vision_dim, llm_dim)

        # Layer norm
        self.ln = nn.LayerNorm(vision_dim)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [B, num_patches, vision_dim]

        Returns:
            query_output: [B, num_query_tokens, llm_dim]
        """
        batch_size = vision_features.size(0)

        # Expand query tokens for batch
        queries = self.query_tokens.expand(batch_size, -1, -1)

        # Process through Q-Former layers
        for layer in self.layers:
            queries = layer(queries, vision_features)

        # Final projection
        queries = self.ln(queries)
        output = self.projection(queries)

        return output


class QFormerLayer(nn.Module):
    """Single Q-Former layer with self-attention and cross-attention."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        # Self-attention among queries
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_ln = nn.LayerNorm(dim)

        # Cross-attention to vision features
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_ln = nn.LayerNorm(dim)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.ffn_ln = nn.LayerNorm(dim)

    def forward(
        self,
        queries: torch.Tensor,
        vision_features: torch.Tensor
    ) -> torch.Tensor:
        # Self-attention
        q = self.self_ln(queries)
        queries = queries + self.self_attn(q, q, q)[0]

        # Cross-attention
        q = self.cross_ln(queries)
        queries = queries + self.cross_attn(q, vision_features, vision_features)[0]

        # FFN
        queries = queries + self.ffn(self.ffn_ln(queries))

        return queries


class ResamplerProjector(nn.Module):
    """
    Perceiver Resampler Projector (Flamingo style).

    Compresses variable-length vision features to fixed-length output
    using cross-attention with learnable latent queries.
    """

    def __init__(
        self,
        vision_dim: int,
        llm_dim: int,
        num_latents: int = 64,
        num_layers: int = 6,
        num_heads: int = 8
    ):
        super().__init__()

        # Learnable latent queries
        self.latents = nn.Parameter(torch.randn(num_latents, llm_dim))

        # Cross-attention layers with vision features
        self.layers = nn.ModuleList([
            PerceiverResamplerLayer(llm_dim, vision_dim, num_heads)
            for _ in range(num_layers)
        ])

        # Project vision features to latent dimension
        self.vision_proj = nn.Linear(vision_dim, llm_dim)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_features: [B, num_patches, vision_dim]

        Returns:
            latent_output: [B, num_latents, llm_dim]
        """
        batch_size = vision_features.size(0)

        # Project vision features
        vision_features = self.vision_proj(vision_features)

        # Expand latents for batch
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        # Process through resampler layers
        for layer in self.layers:
            latents = layer(latents, vision_features)

        return latents


class PerceiverResamplerLayer(nn.Module):
    """Single Perceiver Resampler layer."""

    def __init__(self, latent_dim: int, kv_dim: int, num_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(latent_dim)
        self.cross_attn = nn.MultiheadAttention(
            latent_dim, num_heads,
            kdim=kv_dim, vdim=kv_dim,
            batch_first=True
        )
        self.ln2 = nn.LayerNorm(latent_dim)
        self.ffn = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim)
        )

    def forward(
        self,
        latents: torch.Tensor,
        kv: torch.Tensor
    ) -> torch.Tensor:
        latents = latents + self.cross_attn(
            self.ln1(latents), kv, kv
        )[0]
        latents = latents + self.ffn(self.ln2(latents))
        return latents


def compare_projectors():
    """Compare different visual projector approaches."""
    projectors = {
        'Linear': {
            'params': 'V × L',
            'tokens': 'Same as vision',
            'pros': 'Simplest, fastest',
            'cons': 'Limited capacity',
            'used_by': 'Early MLLMs'
        },
        'MLP (LLaVA)': {
            'params': 'V × H + H × L',
            'tokens': 'Same as vision (e.g., 576)',
            'pros': 'Simple, effective, maintains all info',
            'cons': 'Many tokens for LLM',
            'used_by': 'LLaVA, LLaVA-1.5'
        },
        'Q-Former (BLIP-2)': {
            'params': '~100M',
            'tokens': 'Fixed (e.g., 32)',
            'pros': 'Compresses to few tokens',
            'cons': 'Complex, information bottleneck',
            'used_by': 'BLIP-2, InstructBLIP'
        },
        'Resampler (Flamingo)': {
            'params': '~50M',
            'tokens': 'Fixed (e.g., 64)',
            'pros': 'Flexible compression',
            'cons': 'Moderate complexity',
            'used_by': 'Flamingo, IDEFICS'
        },
        'C-Abstractor': {
            'params': 'Small',
            'tokens': 'Adaptive',
            'pros': 'Content-aware compression',
            'cons': 'More complex training',
            'used_by': 'Honeybee'
        }
    }

    print("Visual Projector Comparison:")
    print("=" * 60)
    for name, info in projectors.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

LLaVA Architecture

LLaVA (Large Language and Vision Assistant) introduced a simple and effective MLLM design:

PYTHON
class LLaVA(nn.Module):
    """
    LLaVA: Large Language and Vision Assistant.

    Simple design: CLIP ViT + MLP Projector + Vicuna/LLaMA LLM
    Two-stage training:
    1. Pre-training: Image-text pairs (freeze LLM)
    2. Fine-tuning: Visual instruction data (unfreeze LLM)
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        llm: nn.Module,
        vision_dim: int = 1024,
        llm_dim: int = 4096,
        mm_hidden_dim: int = 4096
    ):
        super().__init__()

        self.vision_encoder = vision_encoder
        self.llm = llm

        # Two-layer MLP projector
        self.mm_projector = nn.Sequential(
            nn.Linear(vision_dim, mm_hidden_dim),
            nn.GELU(),
            nn.Linear(mm_hidden_dim, llm_dim)
        )

        # Freeze vision encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Extract and project image features."""
        with torch.no_grad():
            # Get CLIP patch features and drop the CLS token if present
            # (257 = 1 CLS + 16x16 patches for ViT-L/14 at 224px input)
            image_features = self.vision_encoder(images)
            if image_features.size(1) == 257:  # Has CLS token
                image_features = image_features[:, 1:]  # Remove CLS

        # Project to LLM dimension
        image_features = self.mm_projector(image_features)

        return image_features

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
        images: torch.Tensor,
        image_token_index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs by inserting image features at image token positions.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        # Get image features
        image_features = self.encode_images(images)
        num_image_tokens = image_features.size(1)

        # Get text embeddings
        embed_tokens = self.llm.get_input_embeddings()
        text_embeds = embed_tokens(input_ids)

        # Process each sample
        new_input_embeds = []
        new_labels = []
        new_attention_mask = []

        for batch_idx in range(batch_size):
            cur_input_ids = input_ids[batch_idx]
            cur_labels = labels[batch_idx]
            cur_attention_mask = attention_mask[batch_idx]
            cur_text_embeds = text_embeds[batch_idx]

            # Find image token positions
            image_token_indices = torch.where(
                cur_input_ids == image_token_index
            )[0]

            if len(image_token_indices) == 0:
                # No image token, use text only
                new_input_embeds.append(cur_text_embeds)
                new_labels.append(cur_labels)
                new_attention_mask.append(cur_attention_mask)
                continue

            # Split and insert image features
            cur_image_features = image_features[batch_idx]

            # Build new sequence
            segments_embeds = []
            segments_labels = []
            segments_mask = []

            prev_idx = 0
            for i, image_idx in enumerate(image_token_indices):
                # Text before image
                if image_idx > prev_idx:
                    segments_embeds.append(cur_text_embeds[prev_idx:image_idx])
                    segments_labels.append(cur_labels[prev_idx:image_idx])
                    segments_mask.append(cur_attention_mask[prev_idx:image_idx])

                # Image features (labels = -100 to ignore in loss)
                segments_embeds.append(cur_image_features)
                segments_labels.append(
                    torch.full((num_image_tokens,), -100, device=device, dtype=labels.dtype)
                )
                segments_mask.append(
                    torch.ones(num_image_tokens, device=device, dtype=attention_mask.dtype)
                )

                prev_idx = image_idx + 1

            # Text after last image
            if prev_idx < len(cur_input_ids):
                segments_embeds.append(cur_text_embeds[prev_idx:])
                segments_labels.append(cur_labels[prev_idx:])
                segments_mask.append(cur_attention_mask[prev_idx:])

            # Concatenate segments
            new_input_embeds.append(torch.cat(segments_embeds, dim=0))
            new_labels.append(torch.cat(segments_labels, dim=0))
            new_attention_mask.append(torch.cat(segments_mask, dim=0))

        # Pad sequences
        max_len = max(e.size(0) for e in new_input_embeds)

        padded_embeds = torch.zeros(
            batch_size, max_len, new_input_embeds[0].size(-1),
            device=device, dtype=new_input_embeds[0].dtype
        )
        padded_labels = torch.full(
            (batch_size, max_len), -100, device=device, dtype=labels.dtype
        )
        padded_mask = torch.zeros(
            batch_size, max_len, device=device, dtype=attention_mask.dtype
        )

        for i, (emb, lab, msk) in enumerate(zip(
            new_input_embeds, new_labels, new_attention_mask
        )):
            length = emb.size(0)
            padded_embeds[i, :length] = emb
            padded_labels[i, :length] = lab
            padded_mask[i, :length] = msk

        return padded_embeds, padded_mask, padded_labels, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        image_token_index: int = -200
    ) -> Dict[str, torch.Tensor]:
        """Forward pass for training."""

        if images is not None:
            inputs_embeds, attention_mask, labels, _ = \
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, attention_mask, labels, images, image_token_index
                )
            input_ids = None
        else:
            inputs_embeds = None

        outputs = self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )

        return outputs


class LLaVATrainer:
    """Training pipeline for LLaVA."""

    def __init__(
        self,
        model: LLaVA,
        optimizer: torch.optim.Optimizer,
        device: torch.device
    ):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device

    def pretrain_step(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor
    ) -> float:
        """
        Stage 1: Pre-training on image-caption pairs.

        Only trains the projector, LLM is frozen.
        """
        # Freeze LLM
        for param in self.model.llm.parameters():
            param.requires_grad = False

        self.model.train()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            images=images,
            labels=labels
        )

        loss = outputs.loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def finetune_step(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor
    ) -> float:
        """
        Stage 2: Fine-tuning on visual instructions.

        Trains both projector and LLM.
        """
        # Unfreeze LLM
        for param in self.model.llm.parameters():
            param.requires_grad = True

        self.model.train()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            images=images,
            labels=labels
        )

        loss = outputs.loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()


def llava_training_config():
    """LLaVA training hyperparameters."""
    config = {
        'stage1_pretrain': {
            'data': '558K image-caption pairs (CC3M subset)',
            'epochs': 1,
            'batch_size': 256,
            'lr': 1e-3,
            'trainable': 'projector only',
            'warmup_ratio': 0.03
        },
        'stage2_finetune': {
            'data': '665K visual instruction data',
            'epochs': 1,
            'batch_size': 128,
            'lr': 2e-5,
            'trainable': 'projector + LLM',
            'warmup_ratio': 0.03
        }
    }

    print("LLaVA Two-Stage Training:")
    for stage, params in config.items():
        print(f"\n{stage}:")
        for k, v in params.items():
            print(f"  {k}: {v}")

    return config
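
The core of LLaVA's multimodal input construction is the splice performed in prepare_inputs_labels_for_multimodal: the <image> placeholder token is replaced by a run of projected visual embeddings, and those positions get labels of -100 so they are ignored by the loss. The toy sketch below reproduces that splice on made-up tensors (tiny, hypothetical dimensions; no real encoder or LLM involved):

PYTHON
import torch

IMAGE_TOKEN = -200                                   # placeholder id, as in forward() above
input_ids = torch.tensor([1, IMAGE_TOKEN, 5, 6, 2])  # toy prompt with one <image> slot
text_embeds = torch.randn(5, 8)                      # toy LLM token embeddings (dim 8)
image_feats = torch.randn(3, 8)                      # 3 projected visual tokens (dim 8)
labels = input_ids.clone()

# Replace the placeholder position with the visual tokens
pos = int((input_ids == IMAGE_TOKEN).nonzero(as_tuple=True)[0])
spliced_embeds = torch.cat(
    [text_embeds[:pos], image_feats, text_embeds[pos + 1:]], dim=0
)
spliced_labels = torch.cat(
    [labels[:pos], torch.full((image_feats.size(0),), -100), labels[pos + 1:]]
)

print(spliced_embeds.shape)     # torch.Size([7, 8])
print(spliced_labels.tolist())  # [1, -100, -100, -100, 5, 6, 2]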

Visual Instruction Tuning

Visual instruction tuning enables MLLMs to follow diverse user instructions:

PYTHON
class VisualInstructionDataset:
    """
    Dataset for visual instruction tuning.

    Formats multimodal conversations as LLM training data.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        image_processor,
        max_length: int = 2048,
        image_token: str = "<image>"
    ):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.image_token = image_token

        # Load data
        self.data = self._load_data(data_path)

    def _load_data(self, path: str) -> List[Dict]:
        """Load instruction data from JSON."""
        import json
        with open(path, 'r') as f:
            return json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.data[idx]

        # Load and process image
        image = self._load_image(sample['image'])
        image_tensor = self.image_processor(image)

        # Format conversation
        conversation = sample['conversations']
        text = self._format_conversation(conversation)

        # Tokenize
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )

        # Create labels (mask user turns)
        labels = self._create_labels(tokens.input_ids, conversation)

        return {
            'input_ids': tokens.input_ids.squeeze(),
            'attention_mask': tokens.attention_mask.squeeze(),
            'labels': labels.squeeze(),
            'images': image_tensor
        }

    def _format_conversation(self, conversation: List[Dict]) -> str:
        """
        Format multi-turn conversation for training.

        Example format:
        <image>\nUSER: Describe this image.\nASSISTANT: This is...
        """
        formatted = ""

        for i, turn in enumerate(conversation):
            role = turn['from']
            content = turn['value']

            if i == 0 and self.image_token not in content:
                # Add image token at the start if not present
                content = f"{self.image_token}\n{content}"

            if role == 'human':
                formatted += f"USER: {content}\n"
            elif role == 'gpt':
                formatted += f"ASSISTANT: {content}\n"

        return formatted.strip()

    def _create_labels(
        self,
        input_ids: torch.Tensor,
        conversation: List[Dict]
    ) -> torch.Tensor:
        """
        Create labels with user turns masked (-100).

        Only compute loss on assistant responses.
        """
        labels = input_ids.clone()

        # Find assistant response boundaries and mask everything else
        # Simplified - actual implementation tracks token positions carefully
        # to mask user inputs and special tokens

        return labels

    def _load_image(self, image_path: str):
        """Load image from path or URL."""
        from PIL import Image
        return Image.open(image_path).convert('RGB')


class InstructionDataFormats:
    """Common instruction data formats for MLLMs."""

    @staticmethod
    def llava_format():
        """LLaVA instruction format."""
        example = {
            "id": "unique_id",
            "image": "path/to/image.jpg",
            "conversations": [
                {
                    "from": "human",
                    "value": "<image>\nWhat is shown in this image?"
                },
                {
                    "from": "gpt",
                    "value": "The image shows a cat sitting on a windowsill..."
                },
                {
                    "from": "human",
                    "value": "What color is the cat?"
                },
                {
                    "from": "gpt",
                    "value": "The cat appears to be orange and white..."
                }
            ]
        }
        return example

    @staticmethod
    def sharegpt_format():
        """ShareGPT-style format (multi-turn)."""
        example = {
            "id": "unique_id",
            "images": ["path/to/image1.jpg", "path/to/image2.jpg"],
            "conversations": [
                {"role": "user", "content": "<img>What's in the first image?"},
                {"role": "assistant", "content": "I can see..."},
                {"role": "user", "content": "<img>Compare it with this second image."},
                {"role": "assistant", "content": "Comparing the two images..."}
            ]
        }
        return example

    @staticmethod
    def instruction_types():
        """Categories of visual instructions."""
        types = {
            'detailed_description': 'Describe this image in detail.',
            'conversation': 'Multi-turn QA about image content',
            'complex_reasoning': 'What might happen next in this scene?',
            'grounding': 'Where is the red car in this image?',
            'ocr': 'What text can you read in this image?',
            'chart_understanding': 'Summarize the data shown in this chart.',
            'document_qa': 'Answer questions about document content',
            'creative': 'Write a story based on this image.'
        }
        return types


def create_instruction_data_pipeline():
    """Pipeline for creating visual instruction data."""

    pipeline_steps = {
        '1. Seed Data': {
            'sources': ['COCO', 'Visual Genome', 'TextCaps'],
            'description': 'Base image-caption pairs'
        },
        '2. GPT-4 Augmentation': {
            'method': 'Prompt GPT-4 with image captions/boxes',
            'output': 'Diverse instruction-response pairs',
            'description': 'Generate varied questions and detailed answers'
        },
        '3. Quality Filtering': {
            'checks': ['Response length', 'Relevance', 'Factuality'],
            'description': 'Remove low-quality or hallucinated responses'
        },
        '4. Diversity Balancing': {
            'method': 'Cluster by instruction type',
            'description': 'Ensure coverage of different capabilities'
        },
        '5. Format Conversion': {
            'output': 'Standardized conversation format',
            'description': 'Convert to training-ready JSON'
        }
    }

    print("Visual Instruction Data Pipeline:")
    for step, info in pipeline_steps.items():
        print(f"\n{step}:")
        for k, v in info.items():
            if isinstance(v, list):
                print(f"  {k}: {', '.join(v)}")
            else:
                print(f"  {k}: {v}")
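
To see what the formatting logic in VisualInstructionDataset actually feeds the tokenizer, the LLaVA-style record above can be rendered directly; no tokenizer or image files are needed for this minimal sketch:

PYTHON
sample = InstructionDataFormats.llava_format()

formatted = ""
for i, turn in enumerate(sample["conversations"]):
    content = turn["value"]
    if i == 0 and "<image>" not in content:
        # Prepend the image placeholder if the first turn lacks it
        content = f"<image>\n{content}"
    role = "USER" if turn["from"] == "human" else "ASSISTANT"
    formatted += f"{role}: {content}\n"

print(formatted.strip())
# USER: <image>
# What is shown in this image?
# ASSISTANT: The image shows a cat sitting on a windowsill...
# USER: What color is the cat?
# ASSISTANT: The cat appears to be orange and white...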

BLIP-2 and InstructBLIP

BLIP-2 introduced efficient vision-language pre-training with frozen models:

PYTHON
class BLIP2(nn.Module):
    """
    BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and LLMs.

    Uses Q-Former to bridge frozen vision encoder and frozen LLM.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        llm: nn.Module,
        vision_dim: int = 1408,  # EVA-CLIP ViT-G
        llm_dim: int = 4096,
        num_query_tokens: int = 32,
        cross_attention_freq: int = 2
    ):
        super().__init__()

        # Frozen vision encoder
        self.vision_encoder = vision_encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Frozen LLM
        self.llm = llm
        for param in self.llm.parameters():
            param.requires_grad = False

        # Q-Former (trainable)
        self.qformer = QFormer(
            vision_dim=vision_dim,
            hidden_dim=768,
            num_query_tokens=num_query_tokens,
            cross_attention_freq=cross_attention_freq
        )

        # Project Q-Former output to LLM dimension
        self.llm_proj = nn.Linear(768, llm_dim)

    def forward_stage1(
        self,
        images: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Stage 1: Vision-Language Representation Learning.

        Three losses:
        1. Image-Text Contrastive (ITC)
        2. Image-Text Matching (ITM)
        3. Image-grounded Text Generation (ITG)
        """
        # Get image features
        with torch.no_grad():
            image_features = self.vision_encoder(images)

        # Q-Former forward with all three objectives
        outputs = self.qformer.forward_stage1(
            image_features=image_features,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask
        )

        return outputs

    def forward_stage2(
        self,
        images: torch.Tensor,
        text_input_ids: torch.Tensor,
        labels: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Stage 2: Vision-to-Language Generative Learning.

        Connect Q-Former to LLM for generation.
        """
        # Get image features
        with torch.no_grad():
            image_features = self.vision_encoder(images)

        # Get Q-Former output (visual tokens for LLM)
        query_output = self.qformer.extract_features(image_features)

        # Project to LLM dimension
        visual_tokens = self.llm_proj(query_output)

        # Get text embeddings
        with torch.no_grad():
            text_embeds = self.llm.get_input_embeddings()(text_input_ids)

        # Prepend visual tokens
        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)

        # Prepare attention mask and labels
        batch_size = images.size(0)
        num_visual_tokens = visual_tokens.size(1)

        visual_attention = torch.ones(
            batch_size, num_visual_tokens, device=images.device
        )
        attention_mask = torch.cat([
            visual_attention,
            torch.ones_like(text_input_ids)
        ], dim=1)

        # Labels: -100 for visual tokens
        visual_labels = torch.full(
            (batch_size, num_visual_tokens), -100, device=labels.device
        )
        labels = torch.cat([visual_labels, labels], dim=1)

        # Forward through LLM. The LLM's weights are frozen above, but this
        # call must NOT be wrapped in torch.no_grad(): gradients still have to
        # flow through the LLM back into the trainable Q-Former and llm_proj.
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        prompt: str,
        tokenizer,
        max_new_tokens: int = 100
    ) -> str:
        """Generate text from image and prompt."""
        # Get visual tokens
        image_features = self.vision_encoder(images)
        query_output = self.qformer.extract_features(image_features)
        visual_tokens = self.llm_proj(query_output)

        # Tokenize prompt
        prompt_tokens = tokenizer(prompt, return_tensors='pt')
        prompt_embeds = self.llm.get_input_embeddings()(
            prompt_tokens.input_ids.to(images.device)
        )

        # Combine
        inputs_embeds = torch.cat([visual_tokens, prompt_embeds], dim=1)

        # Generate
        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)


class QFormer(nn.Module):
    """
    Q-Former: Querying Transformer for BLIP-2.

    Uses learnable queries to extract visual features and align with text.
    """

    def __init__(
        self,
        vision_dim: int,
        hidden_dim: int,
        num_query_tokens: int,
        num_layers: int = 12,
        num_heads: int = 12,
        cross_attention_freq: int = 2
    ):
        super().__init__()
        self.num_query_tokens = num_query_tokens

        # Learnable query tokens
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_query_tokens, hidden_dim) * 0.02
        )

        # BERT-like encoder with cross-attention
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            layer = QFormerBlock(
                hidden_dim=hidden_dim,
                vision_dim=vision_dim,
                num_heads=num_heads,
                has_cross_attention=(i % cross_attention_freq == 0)
            )
            self.layers.append(layer)

        # Projection heads
        self.vision_proj = nn.Linear(hidden_dim, 256)
        self.text_proj = nn.Linear(hidden_dim, 256)
        self.itm_head = nn.Linear(hidden_dim, 2)

        # Temperature for ITC
        self.temp = nn.Parameter(torch.ones([]) * 0.07)

    def extract_features(
        self,
        image_features: torch.Tensor
    ) -> torch.Tensor:
        """Extract visual features using queries."""
        batch_size = image_features.size(0)
        queries = self.query_tokens.expand(batch_size, -1, -1)

        for layer in self.layers:
            queries = layer(queries, image_features)

        return queries

    def forward_stage1(
        self,
        image_features: torch.Tensor,
        text_input_ids: torch.Tensor,
        text_attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Stage 1 forward with ITC, ITM, and ITG losses.
        """
        batch_size = image_features.size(0)

        # Image-Text Contrastive
        image_query_output = self.extract_features(image_features)
        image_feats = F.normalize(
            self.vision_proj(image_query_output[:, 0]), dim=-1
        )

        # Text encoding (simplified): the actual BLIP-2 runs text through the
        # Q-Former's text branch and projects its [CLS] output; reusing the
        # image queries here makes the ITC loss degenerate but keeps the
        # example compact.
        text_feats = F.normalize(
            self.text_proj(image_query_output[:, 0]), dim=-1
        )

        # ITC loss
        sim_i2t = image_feats @ text_feats.t() / self.temp
        sim_t2i = text_feats @ image_feats.t() / self.temp
        labels = torch.arange(batch_size, device=image_features.device)
        loss_itc = (
            F.cross_entropy(sim_i2t, labels) +
            F.cross_entropy(sim_t2i, labels)
        ) / 2

        # ITM loss (binary classification)
        itm_logits = self.itm_head(image_query_output[:, 0])
        # Need positive and negative pairs - simplified here
        itm_labels = torch.ones(batch_size, dtype=torch.long, device=image_features.device)
        loss_itm = F.cross_entropy(itm_logits, itm_labels)

        # ITG loss - causal language modeling
        # Simplified - actual implementation is more complex
        loss_itg = torch.tensor(0.0, device=image_features.device)

        return {
            'loss': loss_itc + loss_itm + loss_itg,
            'loss_itc': loss_itc,
            'loss_itm': loss_itm,
            'loss_itg': loss_itg
        }


class QFormerBlock(nn.Module):
    """Q-Former block with optional cross-attention."""

    def __init__(
        self,
        hidden_dim: int,
        vision_dim: int,
        num_heads: int,
        has_cross_attention: bool
    ):
        super().__init__()
        self.has_cross_attention = has_cross_attention

        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True
        )
        self.self_ln = nn.LayerNorm(hidden_dim)

        # Cross-attention (optional)
        if has_cross_attention:
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads,
                kdim=vision_dim, vdim=vision_dim,
                batch_first=True
            )
            self.cross_ln = nn.LayerNorm(hidden_dim)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.ffn_ln = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Self-attention
        residual = hidden_states
        hidden_states = self.self_ln(hidden_states)
        hidden_states = self.self_attn(
            hidden_states, hidden_states, hidden_states
        )[0]
        hidden_states = residual + hidden_states

        # Cross-attention
        if self.has_cross_attention and encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.cross_ln(hidden_states)
            hidden_states = self.cross_attn(
                hidden_states, encoder_hidden_states, encoder_hidden_states
            )[0]
            hidden_states = residual + hidden_states

        # FFN
        residual = hidden_states
        hidden_states = self.ffn_ln(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
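
As with the Perceiver Resampler, a quick shape check makes the Q-Former's role concrete: however many patch features come in, a fixed number of query vectors comes out. The sketch below uses illustrative dimensions (EVA-CLIP ViT-G-like 1408-dim patches, 32 queries) and only two layers to keep it fast, whereas BLIP-2 uses twelve:

PYTHON
import torch

qformer = QFormer(vision_dim=1408, hidden_dim=768, num_query_tokens=32, num_layers=2)

image_features = torch.randn(2, 257, 1408)   # e.g., 256 patches + CLS from a ViT-G/14
queries = qformer.extract_features(image_features)
print(queries.shape)  # torch.Size([2, 32, 768]) -- always 32 tokens for the LLM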

Modern MLLM Landscape

PYTHON
def mllm_comparison():
    """Compare modern Multimodal LLMs."""
    models = {
        'LLaVA-1.5': {
            'vision': 'CLIP ViT-L/14 (336px)',
            'llm': 'Vicuna-13B',
            'projector': '2-layer MLP',
            'training': '558K pretrain + 665K finetune',
            'strengths': 'Simple, strong baseline'
        },
        'BLIP-2': {
            'vision': 'EVA-CLIP ViT-G',
            'llm': 'FlanT5-XXL or OPT-6.7B',
            'projector': 'Q-Former (32 queries)',
            'training': '129M images, 2-stage',
            'strengths': 'Parameter efficient'
        },
        'InstructBLIP': {
            'vision': 'EVA-CLIP ViT-G',
            'llm': 'Vicuna-13B',
            'projector': 'Instruction-aware Q-Former',
            'training': 'BLIP-2 + instruction tuning',
            'strengths': 'Strong instruction following'
        },
        'Qwen-VL': {
            'vision': 'OpenCLIP ViT-G (448px)',
            'llm': 'Qwen-7B',
            'projector': 'Single-layer cross-attention',
            'training': '1.4B image-text + multi-task',
            'strengths': 'Multi-image, grounding, OCR'
        },
        'LLaVA-NeXT': {
            'vision': 'CLIP ViT-L (672px AnyRes)',
            'llm': 'Various (7B-110B)',
            'projector': '2-layer MLP',
            'training': 'Dynamic high-resolution',
            'strengths': 'High-res, efficient scaling'
        },
        'GPT-4V': {
            'vision': 'Unknown',
            'llm': 'GPT-4',
            'projector': 'Unknown',
            'training': 'Unknown',
            'strengths': 'Best overall performance (closed)'
        }
    }

    print("Multimodal LLM Comparison:")
    print("=" * 70)
    for name, info in models.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


def mllm_capabilities():
    """Capabilities enabled by modern MLLMs."""
    capabilities = {
        'Visual Understanding': [
            'Scene description',
            'Object recognition and counting',
            'Attribute identification',
            'Spatial relationship understanding'
        ],
        'Visual Reasoning': [
            'Multi-step reasoning about images',
            'Commonsense inference',
            'Cause-effect understanding',
            'Hypothetical reasoning'
        ],
        'Document/Chart Understanding': [
            'OCR and text extraction',
            'Chart and graph interpretation',
            'Table understanding',
            'Document QA'
        ],
        'Grounding': [
            'Object localization',
            'Region description',
            'Referring expression comprehension',
            'Visual pointing'
        ],
        'Multi-Image': [
            'Image comparison',
            'Temporal reasoning (before/after)',
            'Multi-view understanding',
            'Story understanding'
        ],
        'Creative Tasks': [
            'Image-based storytelling',
            'Advertisement copy generation',
            'Meme explanation',
            'Visual humor understanding'
        ]
    }

    print("\nMLLM Capabilities:")
    for category, items in capabilities.items():
        print(f"\n{category}:")
        for item in items:
            print(f"  - {item}")


mllm_comparison()
print("\n" + "=" * 70)
mllm_capabilities()

Key Takeaways

Multimodal Large Language Models represent a powerful paradigm for vision-language AI that leverages pretrained LLMs as the backbone for multimodal reasoning. The key architectural choices include: (1) vision encoder selection (CLIP, SigLIP, or custom), (2) projector design (MLP for simplicity vs. Q-Former for efficiency), and (3) LLM backbone (trading off capability vs. computational cost). The two-stage training approach—first aligning vision-language representations, then instruction tuning—has proven highly effective. Modern MLLMs like LLaVA, BLIP-2, and their successors demonstrate that combining frozen pretrained components with lightweight trainable modules enables strong performance while being computationally efficient. The field continues to advance rapidly, with improvements in high-resolution processing, multi-image understanding, and specialized capabilities like document analysis and visual grounding.

22.5 Vision-Language Applications Advanced

Vision-Language Applications

Vision-language models have enabled a wide range of practical applications that bridge visual understanding and natural language. From semantic image search to accessibility tools, these applications demonstrate how combining vision and language creates systems that are more intuitive, capable, and aligned with human communication. This section explores key application domains and provides implementation patterns for building vision-language powered systems.

Semantic Image Search

CLIP-based embeddings enable searching images using natural language queries:

PYTHON
import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
from dataclasses import dataclass
from pathlib import Path
import pickle

@dataclass
class SearchResult:
    """Search result with image path and score."""
    image_path: str
    score: float
    metadata: Optional[Dict] = None


class SemanticImageSearch:
    """
    Semantic image search using CLIP embeddings.

    Enables natural language queries like:
    - "a sunset over mountains"
    - "person wearing a red hat"
    - "modern kitchen interior"
    """

    def __init__(
        self,
        clip_model,
        clip_processor,
        device: torch.device,
        embedding_dim: int = 512
    ):
        self.model = clip_model.to(device)
        self.processor = clip_processor
        self.device = device
        self.embedding_dim = embedding_dim

        # Index storage
        self.image_embeddings = None
        self.image_paths = []
        self.metadata = []

        self.model.eval()

    @torch.no_grad()
    def encode_images(
        self,
        images: List,
        batch_size: int = 32
    ) -> torch.Tensor:
        """Encode multiple images to embeddings."""
        all_embeddings = []

        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            inputs = self.processor(images=batch, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            embeddings = self.model.get_image_features(**inputs)
            embeddings = F.normalize(embeddings, dim=-1)
            all_embeddings.append(embeddings.cpu())

        return torch.cat(all_embeddings, dim=0)

    @torch.no_grad()
    def encode_text(self, queries: List[str]) -> torch.Tensor:
        """Encode text queries to embeddings."""
        inputs = self.processor(text=queries, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        embeddings = self.model.get_text_features(**inputs)
        embeddings = F.normalize(embeddings, dim=-1)

        return embeddings.cpu()

    def build_index(
        self,
        image_paths: List[str],
        metadata: Optional[List[Dict]] = None,
        batch_size: int = 32
    ):
        """
        Build search index from images.

        Args:
            image_paths: List of paths to images
            metadata: Optional metadata for each image
            batch_size: Batch size for encoding
        """
        from PIL import Image

        self.image_paths = image_paths
        self.metadata = metadata or [{}] * len(image_paths)

        # Load and encode images
        images = []
        valid_indices = []

        for i, path in enumerate(image_paths):
            try:
                img = Image.open(path).convert('RGB')
                images.append(img)
                valid_indices.append(i)
            except Exception as e:
                print(f"Failed to load {path}: {e}")

        # Update paths to only valid ones
        self.image_paths = [image_paths[i] for i in valid_indices]
        self.metadata = [self.metadata[i] for i in valid_indices]

        # Encode
        self.image_embeddings = self.encode_images(images, batch_size)

        print(f"Indexed {len(self.image_paths)} images")

    def search(
        self,
        query: str,
        top_k: int = 10,
        filter_fn: Optional[callable] = None
    ) -> List[SearchResult]:
        """
        Search images using natural language query.

        Args:
            query: Natural language search query
            top_k: Number of results to return
            filter_fn: Optional filter function on metadata

        Returns:
            List of SearchResult objects
        """
        # Encode query
        query_embedding = self.encode_text([query])[0]

        # Compute similarities
        similarities = query_embedding @ self.image_embeddings.t()

        # Apply filter if provided
        if filter_fn is not None:
            mask = torch.tensor([
                filter_fn(m) for m in self.metadata
            ], dtype=torch.bool)
            similarities = similarities.masked_fill(~mask, -float('inf'))

        # Get top-k
        scores, indices = similarities.topk(min(top_k, len(self.image_paths)))

        results = []
        for score, idx in zip(scores.tolist(), indices.tolist()):
            results.append(SearchResult(
                image_path=self.image_paths[idx],
                score=score,
                metadata=self.metadata[idx]
            ))

        return results

    def search_by_image(
        self,
        image,
        top_k: int = 10,
        exclude_self: bool = True
    ) -> List[SearchResult]:
        """
        Find similar images to a query image.
        """
        # Encode query image
        query_embedding = self.encode_images([image])[0]

        # Compute similarities
        similarities = query_embedding @ self.image_embeddings.t()

        # Get top-k
        k = top_k + 1 if exclude_self else top_k
        scores, indices = similarities.topk(min(k, len(self.image_paths)))

        results = []
        for score, idx in zip(scores.tolist(), indices.tolist()):
            # Skip near-perfect matches if excluding self
            if exclude_self and score > 0.99:
                continue
            results.append(SearchResult(
                image_path=self.image_paths[idx],
                score=score,
                metadata=self.metadata[idx]
            ))

        return results[:top_k]

    def hybrid_search(
        self,
        text_query: Optional[str] = None,
        image_query = None,
        text_weight: float = 0.5,
        top_k: int = 10
    ) -> List[SearchResult]:
        """
        Hybrid search combining text and image queries.
        """
        embeddings = []
        weights = []

        if text_query:
            text_emb = self.encode_text([text_query])[0]
            embeddings.append(text_emb)
            weights.append(text_weight)

        if image_query:
            image_emb = self.encode_images([image_query])[0]
            embeddings.append(image_emb)
            weights.append(1 - text_weight)

        # Weighted combination
        combined = sum(w * e for w, e in zip(weights, embeddings))
        combined = F.normalize(combined, dim=0)

        # Search
        similarities = combined @ self.image_embeddings.t()
        scores, indices = similarities.topk(min(top_k, len(self.image_paths)))

        return [
            SearchResult(
                image_path=self.image_paths[idx],
                score=score,
                metadata=self.metadata[idx]
            )
            for score, idx in zip(scores.tolist(), indices.tolist())
        ]

    def save_index(self, path: str):
        """Save index to disk."""
        data = {
            'embeddings': self.image_embeddings.numpy(),
            'paths': self.image_paths,
            'metadata': self.metadata
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def load_index(self, path: str):
        """Load index from disk."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.image_embeddings = torch.from_numpy(data['embeddings'])
        self.image_paths = data['paths']
        self.metadata = data['metadata']
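
A hedged usage sketch: the snippet below assumes the Hugging Face CLIP checkpoint openai/clip-vit-base-patch32 and a hypothetical local folder photos/ of JPEG files; any CLIP model exposing get_image_features / get_text_features should work the same way:

PYTHON
import torch
from pathlib import Path
from transformers import CLIPModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

search = SemanticImageSearch(model, processor, device)
search.build_index([str(p) for p in Path("photos").glob("*.jpg")])

for result in search.search("a sunset over mountains", top_k=5):
    print(f"{result.score:.3f}  {result.image_path}")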

Content Moderation

Vision-language models enable flexible, natural language-defined content policies:

PYTHON
class ContentModerator:
    """
    Vision-language based content moderation.

    Uses CLIP to classify images against policy definitions
    specified in natural language.
    """

    def __init__(
        self,
        clip_model,
        clip_processor,
        device: torch.device
    ):
        self.model = clip_model.to(device)
        self.processor = clip_processor
        self.device = device
        self.model.eval()

        # Default policy categories
        self.policies = {}
        self.policy_embeddings = {}

    def add_policy(
        self,
        category: str,
        positive_descriptions: List[str],
        negative_descriptions: List[str],
        threshold: float = 0.3
    ):
        """
        Add a moderation policy defined in natural language.

        Args:
            category: Policy category name (e.g., "violence", "nsfw")
            positive_descriptions: Text describing violating content
            negative_descriptions: Text describing acceptable content
            threshold: Classification threshold
        """
        self.policies[category] = {
            'positive': positive_descriptions,
            'negative': negative_descriptions,
            'threshold': threshold
        }

        # Pre-compute text embeddings
        with torch.no_grad():
            pos_inputs = self.processor(
                text=positive_descriptions,
                return_tensors="pt",
                padding=True
            )
            pos_inputs = {k: v.to(self.device) for k, v in pos_inputs.items()}
            pos_embeds = self.model.get_text_features(**pos_inputs)
            pos_embeds = F.normalize(pos_embeds, dim=-1)

            neg_inputs = self.processor(
                text=negative_descriptions,
                return_tensors="pt",
                padding=True
            )
            neg_inputs = {k: v.to(self.device) for k, v in neg_inputs.items()}
            neg_embeds = self.model.get_text_features(**neg_inputs)
            neg_embeds = F.normalize(neg_embeds, dim=-1)

            self.policy_embeddings[category] = {
                'positive': pos_embeds,
                'negative': neg_embeds
            }

    def setup_default_policies(self):
        """Setup common moderation policies."""

        self.add_policy(
            category='violence',
            positive_descriptions=[
                'violent imagery',
                'graphic violence',
                'blood and gore',
                'weapons being used to harm',
                'physical assault'
            ],
            negative_descriptions=[
                'peaceful scene',
                'normal everyday activity',
                'people interacting normally',
                'safe content'
            ],
            threshold=0.25
        )

        self.add_policy(
            category='hate_symbols',
            positive_descriptions=[
                'hate symbols',
                'extremist imagery',
                'discriminatory symbols',
                'offensive gestures'
            ],
            negative_descriptions=[
                'normal symbols',
                'everyday objects',
                'neutral imagery'
            ],
            threshold=0.3
        )

        self.add_policy(
            category='dangerous_activities',
            positive_descriptions=[
                'dangerous stunts',
                'self-harm',
                'reckless behavior',
                'illegal drug use'
            ],
            negative_descriptions=[
                'safe activities',
                'normal behavior',
                'everyday activities'
            ],
            threshold=0.28
        )

    @torch.no_grad()
    def moderate(
        self,
        image,
        categories: Optional[List[str]] = None
    ) -> Dict[str, Dict]:
        """
        Check image against moderation policies.

        Args:
            image: PIL Image or path
            categories: Specific categories to check (default: all)

        Returns:
            Dictionary with results for each category
        """
        from PIL import Image

        if isinstance(image, str):
            image = Image.open(image).convert('RGB')

        # Encode image
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        image_embed = self.model.get_image_features(**inputs)
        image_embed = F.normalize(image_embed, dim=-1)

        # Check against policies
        categories = categories or list(self.policies.keys())
        results = {}

        for category in categories:
            if category not in self.policies:
                continue

            policy = self.policies[category]
            embeddings = self.policy_embeddings[category]

            # Similarity to positive (violating) descriptions
            pos_sim = (image_embed @ embeddings['positive'].t()).max().item()

            # Similarity to negative (acceptable) descriptions
            neg_sim = (image_embed @ embeddings['negative'].t()).max().item()

            # Score: higher means more likely to violate
            score = pos_sim - neg_sim

            results[category] = {
                'score': score,
                'flagged': score > policy['threshold'],
                'confidence': min(abs(score) / policy['threshold'], 1.0),
                'positive_similarity': pos_sim,
                'negative_similarity': neg_sim
            }

        return results

    def batch_moderate(
        self,
        images: List,
        categories: Optional[List[str]] = None,
        batch_size: int = 32
    ) -> List[Dict]:
        """Moderate multiple images efficiently."""
        results = []

        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            batch_results = [
                self.moderate(img, categories)
                for img in batch
            ]
            results.extend(batch_results)

        return results


class AdaptiveContentFilter:
    """
    Adaptive content filter that learns from feedback.
    """

    def __init__(self, base_moderator: ContentModerator):
        self.moderator = base_moderator
        self.feedback_history = []

    def record_feedback(
        self,
        image_embedding: torch.Tensor,
        category: str,
        was_correct: bool,
        actual_label: bool
    ):
        """Record human feedback for improving thresholds."""
        self.feedback_history.append({
            'embedding': image_embedding,
            'category': category,
            'was_correct': was_correct,
            'actual_label': actual_label
        })

    def update_thresholds(self):
        """Update thresholds based on accumulated feedback."""
        from collections import defaultdict

        category_feedback = defaultdict(list)
        for fb in self.feedback_history:
            category_feedback[fb['category']].append(fb)

        for category, feedbacks in category_feedback.items():
            if category not in self.moderator.policies:
                continue

            # Analyze false positives and negatives
            false_positives = [f for f in feedbacks if not f['actual_label'] and not f['was_correct']]
            false_negatives = [f for f in feedbacks if f['actual_label'] and not f['was_correct']]

            # Adjust threshold
            current = self.moderator.policies[category]['threshold']
            if len(false_positives) > len(false_negatives) * 2:
                # Too many false positives, increase threshold
                self.moderator.policies[category]['threshold'] = current * 1.1
            elif len(false_negatives) > len(false_positives) * 2:
                # Too many false negatives, decrease threshold
                self.moderator.policies[category]['threshold'] = current * 0.9
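
A hedged usage sketch, assuming the same CLIP checkpoint as in the search example above and a hypothetical local file upload.jpg:

PYTHON
import torch
from transformers import CLIPModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

moderator = ContentModerator(model, processor, device)
moderator.setup_default_policies()

report = moderator.moderate("upload.jpg")   # hypothetical image to screen
for category, info in report.items():
    status = "FLAGGED" if info["flagged"] else "ok"
    print(f"{category:22s} score={info['score']:+.3f}  {status}")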

Automated Alt-Text Generation

Vision-language models can generate accessibility descriptions:

PYTHON
class AltTextGenerator:
    """
    Generate alt-text descriptions for images.

    Optimized for accessibility: concise, informative, and
    focused on content relevant to understanding.
    """

    def __init__(self, mllm_model, tokenizer, device: torch.device):
        self.model = mllm_model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def generate_alt_text(
        self,
        image,
        context: Optional[str] = None,
        max_length: int = 100,
        style: str = "concise"
    ) -> str:
        """
        Generate alt-text for an image.

        Args:
            image: PIL Image
            context: Optional context about where the image appears
            max_length: Maximum length of description
            style: "concise", "detailed", or "functional"

        Returns:
            Alt-text description
        """
        if style == "concise":
            prompt = "Describe this image in one brief sentence for a visually impaired user. Focus on the main subject and action."
        elif style == "detailed":
            prompt = "Provide a detailed description of this image for someone who cannot see it. Include the main subjects, their positions, colors, and any text visible."
        elif style == "functional":
            prompt = "What is the purpose of this image? Describe what information it conveys."
        else:
            prompt = "Describe this image."

        if context:
            prompt += f" Context: This image appears in {context}."

        # Generate
        response = self.model.generate(
            image=image,
            prompt=prompt,
            max_new_tokens=max_length,
            temperature=0.3  # Lower temperature for factual descriptions
        )

        # Post-process
        alt_text = self._postprocess(response, max_length)

        return alt_text

    def _postprocess(self, text: str, max_length: int) -> str:
        """Clean and truncate alt-text."""
        # Remove common prefixes
        prefixes_to_remove = [
            "This image shows",
            "The image shows",
            "In this image,",
            "This is an image of",
            "I can see"
        ]
        for prefix in prefixes_to_remove:
            if text.lower().startswith(prefix.lower()):
                text = text[len(prefix):].strip()

        # Capitalize first letter
        if text:
            text = text[0].upper() + text[1:]

        # Ensure it ends with punctuation
        if text and text[-1] not in '.!?':
            text += '.'

        # Truncate if too long
        if len(text) > max_length:
            # Try to truncate at sentence boundary
            sentences = text.split('. ')
            truncated = ""
            for sent in sentences:
                if len(truncated) + len(sent) + 2 <= max_length:
                    truncated += sent + ". "
                else:
                    break
            text = truncated.strip() or text[:max_length-3] + "..."

        return text

    def batch_generate(
        self,
        images: List,
        contexts: Optional[List[str]] = None,
        style: str = "concise"
    ) -> List[str]:
        """Generate alt-text for multiple images."""
        contexts = contexts or [None] * len(images)

        return [
            self.generate_alt_text(img, ctx, style=style)
            for img, ctx in zip(images, contexts)
        ]

    def generate_structured_description(
        self,
        image
    ) -> Dict[str, str]:
        """
        Generate structured accessibility information.
        """
        prompts = {
            'brief': "In 10 words or less, what is this image?",
            'main_subject': "What is the main subject of this image?",
            'colors': "What are the dominant colors in this image?",
            'text_content': "Is there any text in this image? If so, what does it say?",
            'mood': "What mood or atmosphere does this image convey?"
        }

        results = {}
        for key, prompt in prompts.items():
            response = self.model.generate(
                image=image,
                prompt=prompt,
                max_new_tokens=50,
                temperature=0.3
            )
            results[key] = response.strip()

        return results


class AccessibilityChecker:
    """Check and improve image accessibility."""

    def __init__(self, alt_text_generator: AltTextGenerator):
        self.generator = alt_text_generator

    def analyze_alt_text(self, existing_alt: str, image) -> Dict:
        """
        Analyze quality of existing alt-text.
        """
        issues = []
        suggestions = []

        # Check length
        if len(existing_alt) < 10:
            issues.append("Alt-text is too short")
        elif len(existing_alt) > 200:
            issues.append("Alt-text may be too long")

        # Check for common bad patterns
        alt_lower = existing_alt.lower()

        # Generic openers are only a problem at the start of the alt-text
        for opener, suggestion in [
            ("image", "Avoid starting with 'image'"),
            ("picture", "Avoid starting with 'picture'")
        ]:
            if alt_lower.startswith(opener):
                issues.append(suggestion)

        # File-name debris is a problem anywhere in the alt-text
        for fragment, suggestion in [
            (".jpg", "Don't include file extensions"),
            (".png", "Don't include file extensions"),
            ("img_", "Don't include file naming conventions"),
            ("dsc", "Don't include camera file names")
        ]:
            if fragment in alt_lower:
                issues.append(suggestion)

        # Generate improved version
        improved = self.generator.generate_alt_text(image)

        return {
            'existing': existing_alt,
            'issues': issues,
            'improved': improved,
            'quality_score': max(0, 100 - len(issues) * 20)
        }
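
The post-processing rules above can be checked in isolation. Since _postprocess does not touch any instance state, this minimal sketch calls it unbound with None in place of self, on a made-up caption:

PYTHON
raw = "The image shows a golden retriever catching a frisbee in a park"
clean = AltTextGenerator._postprocess(None, raw, max_length=100)
print(clean)  # "A golden retriever catching a frisbee in a park."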

Visual Product Search

CLIP embeddings also power e-commerce features such as "shop the look" recommendations and similar-product discovery:

PYTHON
class VisualProductSearch:
    """
    Visual search for e-commerce products.

    Enables "shop the look" and similar product discovery.
    """

    def __init__(
        self,
        clip_model,
        clip_processor,
        device: torch.device
    ):
        self.model = clip_model.to(device)
        self.processor = clip_processor
        self.device = device
        self.model.eval()

        # Product index
        self.products = []
        self.product_embeddings = None

    def index_products(
        self,
        products: List[Dict],
        batch_size: int = 32
    ):
        """
        Index products with images and metadata.

        Args:
            products: List of dicts with 'image', 'name', 'category', 'price', etc.
        """
        self.products = products

        # Encode product images
        embeddings = []
        for i in range(0, len(products), batch_size):
            batch = products[i:i + batch_size]
            images = [p['image'] for p in batch]

            inputs = self.processor(images=images, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                batch_embeds = self.model.get_image_features(**inputs)
                batch_embeds = F.normalize(batch_embeds, dim=-1)
                embeddings.append(batch_embeds.cpu())

        self.product_embeddings = torch.cat(embeddings, dim=0)

    def search_by_image(
        self,
        query_image,
        top_k: int = 20,
        category_filter: Optional[str] = None,
        price_range: Optional[Tuple[float, float]] = None
    ) -> List[Dict]:
        """
        Find similar products to query image.
        """
        # Encode query
        inputs = self.processor(images=query_image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            query_embed = self.model.get_image_features(**inputs)
            query_embed = F.normalize(query_embed, dim=-1).cpu()

        # Compute similarities
        similarities = (query_embed @ self.product_embeddings.t()).squeeze()

        # Apply filters
        valid_mask = torch.ones(len(self.products), dtype=torch.bool)

        if category_filter:
            for i, p in enumerate(self.products):
                if p.get('category') != category_filter:
                    valid_mask[i] = False

        if price_range:
            min_price, max_price = price_range
            for i, p in enumerate(self.products):
                price = p.get('price', 0)
                if price < min_price or price > max_price:
                    valid_mask[i] = False

        similarities = similarities.masked_fill(~valid_mask, -float('inf'))

        # Get top-k (cast the mask sum to a Python int so topk accepts it)
        scores, indices = similarities.topk(min(top_k, int(valid_mask.sum())))

        results = []
        for score, idx in zip(scores.tolist(), indices.tolist()):
            product = self.products[idx].copy()
            product['similarity_score'] = score
            results.append(product)

        return results

    def search_by_text_and_image(
        self,
        query_image,
        text_query: str,
        image_weight: float = 0.7,
        top_k: int = 20
    ) -> List[Dict]:
        """
        Combined visual and text search.

        Example: Upload a dress photo and search for "similar but in blue"
        """
        # Encode image
        img_inputs = self.processor(images=query_image, return_tensors="pt")
        img_inputs = {k: v.to(self.device) for k, v in img_inputs.items()}

        # Encode text
        txt_inputs = self.processor(text=text_query, return_tensors="pt", padding=True)
        txt_inputs = {k: v.to(self.device) for k, v in txt_inputs.items()}

        with torch.no_grad():
            img_embed = self.model.get_image_features(**img_inputs)
            img_embed = F.normalize(img_embed, dim=-1).cpu()

            txt_embed = self.model.get_text_features(**txt_inputs)
            txt_embed = F.normalize(txt_embed, dim=-1).cpu()

        # Combine embeddings
        combined = image_weight * img_embed + (1 - image_weight) * txt_embed
        combined = F.normalize(combined, dim=-1)

        # Search
        similarities = (combined @ self.product_embeddings.t()).squeeze()
        scores, indices = similarities.topk(min(top_k, len(self.products)))

        results = []
        for score, idx in zip(scores.tolist(), indices.tolist()):
            product = self.products[idx].copy()
            product['similarity_score'] = score
            results.append(product)

        return results

    def complete_the_look(
        self,
        anchor_image,
        anchor_category: str,
        complement_categories: List[str],
        top_k_per_category: int = 5
    ) -> Dict[str, List[Dict]]:
        """
        Find complementary products to complete an outfit.
        """
        results = {}

        # Encode anchor
        inputs = self.processor(images=anchor_image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            anchor_embed = self.model.get_image_features(**inputs)
            anchor_embed = F.normalize(anchor_embed, dim=-1).cpu()

        # Find complementary items for each category
        for category in complement_categories:
            # Filter by category
            category_indices = [
                i for i, p in enumerate(self.products)
                if p.get('category') == category
            ]

            if not category_indices:
                continue

            category_embeddings = self.product_embeddings[category_indices]

            # Find matches (squeeze only the batch dim so a single-item
            # category still yields a 1-D similarity vector)
            similarities = (anchor_embed @ category_embeddings.t()).squeeze(0)
            k = min(top_k_per_category, len(category_indices))
            scores, indices = similarities.topk(k)

            category_results = []
            for score, idx in zip(scores.tolist(), indices.tolist()):
                product = self.products[category_indices[idx]].copy()
                product['compatibility_score'] = score
                category_results.append(product)

            results[category] = category_results

        return results
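
A quick usage sketch for the e-commerce search methods above. The product-search class is assumed to be instantiated and indexed already (its constructor appears earlier in this section); the variable name `product_search` and the image file name are placeholders.

PYTHON
from PIL import Image

# Hypothetical: `product_search` is an instance of the product-search class
# above, with `products` and `product_embeddings` already populated.
query = Image.open("floral_dress.jpg")  # placeholder file name

# Combined image + text query: "similar, but in blue"
blue_variants = product_search.search_by_text_and_image(
    query, text_query="blue floral dress", image_weight=0.7, top_k=10
)
for item in blue_variants[:3]:
    print(item.get('category'), round(item['similarity_score'], 3))

# Outfit completion: shoes and bags that pair with the dress
outfit = product_search.complete_the_look(
    query,
    anchor_category="dresses",
    complement_categories=["shoes", "bags"],
    top_k_per_category=3,
)
for category, items in outfit.items():
    print(category, [round(i['compatibility_score'], 3) for i in items])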

Document Understanding

Vision-language models for document analysis:

PYTHON
class DocumentAnalyzer:
    """
    Document understanding using vision-language models.

    Handles invoices, receipts, forms, and general documents.
    """

    def __init__(self, mllm_model, device: torch.device):
        self.model = mllm_model.to(device)
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def extract_information(
        self,
        document_image,
        schema: Dict[str, str]
    ) -> Dict[str, str]:
        """
        Extract structured information from document.

        Args:
            document_image: Image of document
            schema: Dict mapping field names to descriptions

        Example schema:
            {
                'vendor': 'Name of the vendor or store',
                'date': 'Date of transaction',
                'total': 'Total amount',
                'items': 'List of items purchased'
            }
        """
        results = {}

        for field, description in schema.items():
            prompt = f"Looking at this document, extract the {field} ({description}). Respond with only the extracted value, nothing else."

            response = self.model.generate(
                image=document_image,
                prompt=prompt,
                max_new_tokens=100,
                temperature=0.1
            )

            results[field] = response.strip()

        return results

    def analyze_receipt(self, receipt_image) -> Dict:
        """Extract information from receipt."""
        schema = {
            'store_name': 'Name of the store',
            'date': 'Date of purchase (format: YYYY-MM-DD)',
            'total': 'Total amount paid',
            'payment_method': 'How payment was made (cash/card/etc)',
            'items': 'List of items with prices'
        }

        return self.extract_information(receipt_image, schema)

    def analyze_invoice(self, invoice_image) -> Dict:
        """Extract information from invoice."""
        schema = {
            'invoice_number': 'Invoice number or ID',
            'date': 'Invoice date',
            'due_date': 'Payment due date',
            'vendor': 'Company name of the vendor',
            'amount_due': 'Total amount due',
            'line_items': 'Description and amount of each line item'
        }

        return self.extract_information(invoice_image, schema)

    def answer_document_question(
        self,
        document_image,
        question: str
    ) -> str:
        """
        Answer free-form questions about document.
        """
        prompt = f"Based on this document, answer the following question:\n{question}"

        response = self.model.generate(
            image=document_image,
            prompt=prompt,
            max_new_tokens=200,
            temperature=0.3
        )

        return response.strip()

    def summarize_document(self, document_image) -> str:
        """Generate summary of document content."""
        prompt = "Provide a brief summary of this document. What type of document is it and what are the key details?"

        response = self.model.generate(
            image=document_image,
            prompt=prompt,
            max_new_tokens=300,
            temperature=0.5
        )

        return response.strip()

    def verify_document(
        self,
        document_image,
        expected_values: Dict[str, str]
    ) -> Dict[str, Dict]:
        """
        Verify document contains expected values.

        Returns match status for each field.
        """
        extracted = self.extract_information(
            document_image,
            {k: f"the {k}" for k in expected_values.keys()}
        )

        results = {}
        for field, expected in expected_values.items():
            actual = extracted.get(field, "")
            match = expected.lower() in actual.lower()
            results[field] = {
                'expected': expected,
                'extracted': actual,
                'match': match
            }

        return results
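
As a quick usage sketch: DocumentAnalyzer assumes a multimodal LLM wrapper exposing a generate(image=..., prompt=..., max_new_tokens=..., temperature=...) method; the wrapper instance `mllm_wrapper` and the file name below are placeholders.

PYTHON
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
analyzer = DocumentAnalyzer(mllm_wrapper, device)  # mllm_wrapper: hypothetical model wrapper

receipt = Image.open("receipt_scan.jpg")  # placeholder file name
fields = analyzer.analyze_receipt(receipt)
print(fields['store_name'], fields['total'])

# Cross-check extracted values against what the accounting system expects
checks = analyzer.verify_document(
    receipt,
    expected_values={'store_name': 'Acme Grocery', 'total': '42.17'},
)
for field, outcome in checks.items():
    status = 'OK' if outcome['match'] else 'MISMATCH'
    print(field, status, outcome['extracted'])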

Visual Grounding

Locating objects in images using natural language:

PYTHON
class VisualGrounder:
    """
    Visual grounding: locate objects in images using natural language.

    Maps referring expressions to bounding boxes.
    """

    def __init__(self, grounding_model, device: torch.device):
        self.model = grounding_model.to(device)
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def ground(
        self,
        image,
        expression: str
    ) -> List[Dict]:
        """
        Ground a referring expression to bounding boxes.

        Args:
            image: Input image
            expression: Natural language referring expression

        Returns:
            List of dicts with 'box' (x1, y1, x2, y2) and 'confidence'
        """
        # Forward through model
        outputs = self.model(image, expression)

        results = []
        for box, confidence in zip(outputs['boxes'], outputs['scores']):
            if confidence > 0.5:
                results.append({
                    'box': box.tolist(),  # [x1, y1, x2, y2]
                    'confidence': confidence.item()
                })

        return results

    def ground_multiple(
        self,
        image,
        expressions: List[str]
    ) -> Dict[str, List[Dict]]:
        """
        Ground multiple expressions in single image.
        """
        results = {}
        for expr in expressions:
            results[expr] = self.ground(image, expr)
        return results

    def interactive_segmentation(
        self,
        image,
        expression: str
    ) -> Dict:
        """
        Get segmentation mask for referred object.
        """
        # Ground expression
        boxes = self.ground(image, expression)

        if not boxes:
            return {'mask': None, 'expression': expression}

        # Get best box
        best_box = max(boxes, key=lambda x: x['confidence'])

        # Generate mask (would use SAM or similar)
        # Placeholder
        mask = self._generate_mask(image, best_box['box'])

        return {
            'mask': mask,
            'box': best_box['box'],
            'confidence': best_box['confidence'],
            'expression': expression
        }

    def _generate_mask(self, image, box):
        """Generate segmentation mask from box."""
        # Placeholder: a real implementation would prompt SAM (or a similar
        # promptable segmentation model) with the box. Returning None keeps
        # interactive_segmentation's 'mask': None convention when no mask
        # is available.
        return None


class ReferringExpressionGenerator:
    """
    Generate referring expressions for objects in images.
    """

    def __init__(self, mllm_model, device: torch.device):
        self.model = mllm_model.to(device)
        self.device = device

    @torch.no_grad()
    def generate_expression(
        self,
        image,
        box: List[float],
        context: str = ""
    ) -> str:
        """
        Generate referring expression for boxed region.

        Args:
            image: Input image
            box: [x1, y1, x2, y2] normalized coordinates
            context: Additional context

        Returns:
            Referring expression that uniquely identifies the object
        """
        # The region is described via its coordinates in the prompt; a stronger
        # approach would crop the region or draw the box onto the image first.
        prompt = f"Generate a short phrase that would uniquely identify the object located at coordinates {box} in this image."

        if context:
            prompt += f" Context: {context}"

        response = self.model.generate(
            image=image,
            prompt=prompt,
            max_new_tokens=50,
            temperature=0.5
        )

        return response.strip()
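
A usage sketch for the grounder: the detector passed in is assumed to return a dict with 'boxes' (N×4) and 'scores' tensors, as ground() expects; `grounding_model` and the image file name are placeholders.

PYTHON
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grounder = VisualGrounder(grounding_model, device)  # grounding_model: hypothetical detector

image = Image.open("street_scene.jpg")  # placeholder file name
matches = grounder.ground_multiple(
    image,
    expressions=["the person in the red jacket", "the dog on the left"],
)

for expr, detections in matches.items():
    for det in detections:
        x1, y1, x2, y2 = det['box']
        print(f"{expr}: ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}) "
              f"conf={det['confidence']:.2f}")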

Application Integration Patterns

In production, these capabilities are usually exposed behind a single service layer that routes each image to whichever features are enabled in configuration:

PYTHON
class VisionLanguageAPI:
    """
    Unified API for vision-language applications.
    """

    def __init__(self, config: Dict):
        self.config = config
        self._init_models()

    def _init_models(self):
        """Initialize required models based on config."""
        self.search_engine = None
        self.moderator = None
        self.alt_text_gen = None
        self.doc_analyzer = None

        # Initialize based on enabled features
        if self.config.get('search_enabled'):
            self.search_engine = SemanticImageSearch(...)

        if self.config.get('moderation_enabled'):
            self.moderator = ContentModerator(...)
            self.moderator.setup_default_policies()

        if self.config.get('accessibility_enabled'):
            self.alt_text_gen = AltTextGenerator(...)

        if self.config.get('document_enabled'):
            self.doc_analyzer = DocumentAnalyzer(...)

    async def process_image(
        self,
        image,
        tasks: List[str]
    ) -> Dict[str, object]:
        """
        Process image for multiple tasks in parallel.
        """
        import asyncio

        task_coroutines = []
        scheduled_tasks = []  # keep task names aligned with coroutines so results match

        for task in tasks:
            if task == 'moderate' and self.moderator:
                task_coroutines.append(self._async_moderate(image))
                scheduled_tasks.append(task)
            elif task == 'alt_text' and self.alt_text_gen:
                task_coroutines.append(self._async_alt_text(image))
                scheduled_tasks.append(task)
            elif task == 'search_similar' and self.search_engine:
                task_coroutines.append(self._async_search(image))
                scheduled_tasks.append(task)

        # Run in parallel; tasks that are disabled or unrecognized are simply omitted
        task_results = await asyncio.gather(*task_coroutines)

        return dict(zip(scheduled_tasks, task_results))

    async def _async_moderate(self, image):
        """Async wrapper: run blocking moderation off the event loop."""
        import asyncio
        return await asyncio.to_thread(self.moderator.moderate, image)

    async def _async_alt_text(self, image):
        """Async wrapper: run blocking alt-text generation off the event loop."""
        import asyncio
        return await asyncio.to_thread(self.alt_text_gen.generate_alt_text, image)

    async def _async_search(self, image):
        """Async wrapper: run blocking similarity search off the event loop."""
        import asyncio
        return await asyncio.to_thread(self.search_engine.search_by_image, image)


def deployment_best_practices():
    """Best practices for deploying vision-language applications."""

    practices = {
        'Model Serving': {
            'batch_requests': 'Batch multiple requests for GPU efficiency',
            'caching': 'Cache embeddings for repeated images',
            'quantization': 'Use INT8/FP16 for faster inference',
            'model_sharding': 'Distribute large models across GPUs'
        },
        'Scaling': {
            'horizontal': 'Multiple replicas behind load balancer',
            'async_processing': 'Queue long-running tasks',
            'cdn_integration': 'Cache static results at edge'
        },
        'Monitoring': {
            'latency_tracking': 'Monitor p50/p99 latencies',
            'quality_metrics': 'Track output quality over time',
            'drift_detection': 'Detect data distribution shifts',
            'feedback_loops': 'Collect user feedback for improvement'
        },
        'Safety': {
            'rate_limiting': 'Prevent abuse',
            'input_validation': 'Validate image formats and sizes',
            'output_filtering': 'Secondary check on generated content',
            'audit_logging': 'Log requests for compliance'
        }
    }

    print("Vision-Language Deployment Best Practices:")
    print("=" * 50)
    for category, items in practices.items():
        print(f"\n{category}:")
        for k, v in items.items():
            print(f"  {k}: {v}")

    return practices


deployment_best_practices()
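
The caching recommendation above is often the cheapest latency win, since production workloads frequently see the same images repeatedly. Below is a minimal sketch of an in-memory embedding cache keyed by a hash of the image bytes, assuming a CLIP-style model with get_image_features (as used throughout this section) and PIL images as input; a real deployment would typically back it with a shared store such as Redis and add eviction.

PYTHON
import hashlib
import io

class ImageEmbeddingCache:
    """
    Minimal in-memory cache for image embeddings, keyed by a hash of the
    encoded image bytes so identical uploads hit the cache.
    """

    def __init__(self, model, processor, device: torch.device):
        self.model = model.to(device).eval()
        self.processor = processor
        self.device = device
        self._cache: Dict[str, torch.Tensor] = {}

    @staticmethod
    def _key(image) -> str:
        # Hash the re-encoded pixels (assumes a PIL image)
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return hashlib.sha256(buffer.getvalue()).hexdigest()

    @torch.no_grad()
    def embed(self, image) -> torch.Tensor:
        key = self._key(image)
        if key not in self._cache:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            features = self.model.get_image_features(**inputs)
            self._cache[key] = F.normalize(features, dim=-1).cpu()
        return self._cache[key]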

Key Takeaways

Vision-language applications transform how we interact with visual content through natural language interfaces. The key application areas covered in this section are:

1. Semantic search enabling intuitive image discovery with text queries
2. Content moderation with flexible, language-defined policies
3. Accessibility tools generating alt-text and descriptions
4. E-commerce visual search for product discovery
5. Document understanding for automated information extraction
6. Visual grounding connecting language to image regions

Successful deployment requires attention to latency optimization through batching and caching, quality monitoring with human feedback loops, and safety measures including content filtering and rate limiting. The versatility of vision-language models allows a single foundation model to power multiple applications, making them increasingly central to modern AI systems that need to understand and communicate about the visual world.