
Chapter 24: Agentic AI Systems

Agent architecture, tool use, multi-agent systems, and safety.

Libraries covered: LangChain

Learning Objectives

["Build AI agents", "Implement tool use", "Design multi-agent systems"]


24.1 Introduction to RAG Systems

Retrieval-Augmented Generation (RAG) combines the power of large language models with external knowledge retrieval, enabling AI systems to access and reason over vast document collections without requiring model retraining. RAG addresses fundamental LLM limitations—hallucination, knowledge cutoffs, and lack of domain specificity—by grounding responses in retrieved evidence. This section introduces RAG architecture, embedding models, vector databases, and the core components that form modern knowledge-augmented AI systems.

The RAG Paradigm

RAG bridges parametric and non-parametric knowledge:

PYTHON
import numpy as np
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import hashlib
import json

@dataclass
class Document:
    """Represents a document or chunk in the RAG system."""
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    doc_id: Optional[str] = None
    embedding: Optional[np.ndarray] = None

    def __post_init__(self):
        if self.doc_id is None:
            # Generate deterministic ID from content
            self.doc_id = hashlib.md5(self.content.encode()).hexdigest()[:16]


@dataclass
class RetrievalResult:
    """Result from retrieval operation."""
    document: Document
    score: float
    rank: int


class RAGSystem:
    """
    Retrieval-Augmented Generation System.

    Architecture:
    1. Indexing Pipeline: Documents → Chunks → Embeddings → Vector Store
    2. Retrieval Pipeline: Query → Embedding → Search → Rerank → Top-K
    3. Generation Pipeline: Query + Context → LLM → Response

    Key Benefits:
    - Access to current/private knowledge without retraining
    - Reduced hallucination through grounding
    - Transparent sourcing with citations
    - Cost-effective knowledge updates
    """

    def __init__(
        self,
        embedding_model: 'EmbeddingModel',
        vector_store: 'VectorStore',
        llm: 'LanguageModel',
        chunker: 'TextChunker',
        top_k: int = 5
    ):
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm = llm
        self.chunker = chunker
        self.top_k = top_k

    def index_documents(self, documents: List[Document]) -> int:
        """
        Index documents into the vector store.

        Pipeline:
        1. Chunk documents into smaller pieces
        2. Generate embeddings for each chunk
        3. Store in vector database
        """
        all_chunks = []

        for doc in documents:
            # Chunk the document
            chunks = self.chunker.chunk(doc.content, doc.metadata)
            all_chunks.extend(chunks)

        # Generate embeddings
        contents = [chunk.content for chunk in all_chunks]
        embeddings = self.embedding_model.embed_documents(contents)

        # Attach embeddings to chunks
        for chunk, embedding in zip(all_chunks, embeddings):
            chunk.embedding = embedding

        # Store in vector database
        self.vector_store.add(all_chunks)

        return len(all_chunks)

    def query(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, List[RetrievalResult]]:
        """
        Answer a question using RAG.

        Pipeline:
        1. Embed the query
        2. Retrieve relevant documents
        3. Generate answer with context
        """
        # Retrieve relevant context
        results = self.retrieve(question, filters)

        # Build context string
        context = self._build_context(results)

        # Generate response
        response = self._generate(question, context)

        return response, results

    def retrieve(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Retrieve relevant documents for a query."""
        # Embed query
        query_embedding = self.embedding_model.embed_query(query)

        # Search vector store
        results = self.vector_store.search(
            query_embedding,
            top_k=self.top_k,
            filters=filters
        )

        return results

    def _build_context(self, results: List[RetrievalResult]) -> str:
        """Build context string from retrieval results."""
        context_parts = []

        for i, result in enumerate(results, 1):
            source = result.document.metadata.get('source', 'Unknown')
            context_parts.append(
                f"[Source {i}: {source}]\n{result.document.content}"
            )

        return "\n\n".join(context_parts)

    def _generate(self, question: str, context: str) -> str:
        """Generate answer using LLM with retrieved context."""
        prompt = f"""Answer the question based on the provided context.
If the context doesn't contain relevant information, say so.

Context:
{context}

Question: {question}

Answer:"""

        return self.llm.generate(prompt)


def rag_vs_fine_tuning():
    """Compare RAG with fine-tuning approaches."""
    comparison = {
        'RAG': {
            'knowledge_update': 'Instant (update documents)',
            'cost': 'Lower (no retraining)',
            'hallucination': 'Reduced (grounded)',
            'transparency': 'High (citations)',
            'knowledge_scope': 'Unlimited (external)',
            'latency': 'Higher (retrieval step)',
            'best_for': 'Dynamic knowledge, QA, search'
        },
        'Fine-tuning': {
            'knowledge_update': 'Requires retraining',
            'cost': 'Higher (compute intensive)',
            'hallucination': 'Can increase',
            'transparency': 'Low (black box)',
            'knowledge_scope': 'Limited to training data',
            'latency': 'Lower (single forward pass)',
            'best_for': 'Style, format, specialized tasks'
        },
        'RAG + Fine-tuning': {
            'knowledge_update': 'Hybrid approach',
            'cost': 'Highest',
            'hallucination': 'Lowest',
            'transparency': 'High',
            'knowledge_scope': 'Comprehensive',
            'latency': 'Medium',
            'best_for': 'Production systems, high accuracy'
        }
    }

    print("RAG vs Fine-tuning Comparison:")
    print("=" * 70)
    for approach, attrs in comparison.items():
        print(f"\n{approach}:")
        for k, v in attrs.items():
            print(f"  {k}: {v}")


rag_vs_fine_tuning()
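
RAGSystem only assumes interfaces: an EmbeddingModel, a VectorStore, and a TextChunker (all implemented later in this section), plus a LanguageModel exposing generate(prompt) -> str, which is never defined here. As a minimal sketch, a stub satisfying that last contract is enough to wire the pipeline together for testing; EchoLLM below is hypothetical and not part of any library.

PYTHON
class EchoLLM:
    """Hypothetical stub that satisfies the LanguageModel contract."""

    def generate(self, prompt: str) -> str:
        # Echo the tail of the prompt so retrieval wiring can be tested
        # without calling a real model.
        return f"[stub response; prompt ended with: {prompt.strip()[-60:]!r}]"

Any API client or local-model wrapper exposing the same generate method can replace it without touching RAGSystem.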

Text Embeddings

Embeddings convert text to dense vectors for semantic similarity:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

class EmbeddingModel(ABC):
    """Abstract base class for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> np.ndarray:
        """Embed a list of documents."""
        pass

    @abstractmethod
    def embed_query(self, text: str) -> np.ndarray:
        """Embed a single query."""
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Return embedding dimension."""
        pass


class SentenceTransformerEmbeddings(EmbeddingModel):
    """
    Sentence Transformer embeddings for RAG.

    Popular models:
    - all-MiniLM-L6-v2: Fast, good quality (384 dim)
    - all-mpnet-base-v2: Best quality (768 dim)
    - multi-qa-mpnet-base-dot-v1: Optimized for QA
    - e5-large-v2: State-of-the-art (1024 dim)
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        normalize: bool = True,
        max_length: int = 512
    ):
        self.device = device
        self.normalize = normalize
        self.max_length = max_length

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model.eval()

        self._dimension = self.model.config.hidden_size

    @property
    def dimension(self) -> int:
        return self._dimension

    def _mean_pooling(
        self,
        model_output: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean pooling over token embeddings."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            token_embeddings.size()
        ).float()

        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)

        return sum_embeddings / sum_mask

    @torch.no_grad()
    def embed_documents(
        self,
        texts: List[str],
        batch_size: int = 32
    ) -> np.ndarray:
        """Embed multiple documents with batching."""
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            # Tokenize
            encoded = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(self.device)

            # Forward pass
            outputs = self.model(**encoded)

            # Pool
            embeddings = self._mean_pooling(outputs, encoded['attention_mask'])

            # Normalize
            if self.normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

            all_embeddings.append(embeddings.cpu().numpy())

        return np.vstack(all_embeddings)

    def embed_query(self, text: str) -> np.ndarray:
        """Embed a single query."""
        return self.embed_documents([text])[0]


class E5Embeddings(EmbeddingModel):
    """
    E5 embeddings with query/passage prefixes.

    E5 models are trained with contrastive learning and
    require specific prefixes:
    - Query: "query: {text}"
    - Passage: "passage: {text}"
    """

    def __init__(
        self,
        model_name: str = "intfloat/e5-large-v2",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model.eval()
        self._dimension = self.model.config.hidden_size

    @property
    def dimension(self) -> int:
        return self._dimension

    def _average_pool(
        self,
        last_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    @torch.no_grad()
    def embed_documents(self, texts: List[str]) -> np.ndarray:
        """Embed documents with passage prefix."""
        prefixed = [f"passage: {text}" for text in texts]

        encoded = self.tokenizer(
            prefixed,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(self.device)

        outputs = self.model(**encoded)
        embeddings = self._average_pool(outputs.last_hidden_state, encoded['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()

    def embed_query(self, text: str) -> np.ndarray:
        """Embed query with query prefix."""
        prefixed = f"query: {text}"

        encoded = self.tokenizer(
            [prefixed],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**encoded)
            embedding = self._average_pool(outputs.last_hidden_state, encoded['attention_mask'])
            embedding = F.normalize(embedding, p=2, dim=1)

        return embedding.cpu().numpy()[0]


class OpenAIEmbeddings(EmbeddingModel):
    """
    OpenAI embedding API wrapper.

    Models:
    - text-embedding-3-small: Fast, economical (1536 dim)
    - text-embedding-3-large: Best quality (3072 dim)
    - text-embedding-ada-002: Legacy (1536 dim)
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        dimensions: Optional[int] = None
    ):
        self.model = model
        self.api_key = api_key
        self._dimensions = dimensions

        # Dimension lookup
        self._default_dims = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536
        }

    @property
    def dimension(self) -> int:
        if self._dimensions:
            return self._dimensions
        return self._default_dims.get(self.model, 1536)

    def embed_documents(self, texts: List[str]) -> np.ndarray:
        """Embed documents using OpenAI API."""
        # Simulated - actual implementation would call API
        # import openai
        # response = openai.embeddings.create(
        #     model=self.model,
        #     input=texts,
        #     dimensions=self._dimensions
        # )
        # return np.array([d.embedding for d in response.data])

        # Placeholder for demonstration
        return np.random.randn(len(texts), self.dimension).astype(np.float32)

    def embed_query(self, text: str) -> np.ndarray:
        return self.embed_documents([text])[0]


def embedding_model_comparison():
    """Compare popular embedding models."""
    models = {
        'all-MiniLM-L6-v2': {
            'dimension': 384,
            'max_tokens': 256,
            'speed': 'Very fast',
            'quality': 'Good',
            'use_case': 'General purpose, resource-constrained'
        },
        'all-mpnet-base-v2': {
            'dimension': 768,
            'max_tokens': 384,
            'speed': 'Fast',
            'quality': 'Very good',
            'use_case': 'General purpose'
        },
        'e5-large-v2': {
            'dimension': 1024,
            'max_tokens': 512,
            'speed': 'Medium',
            'quality': 'Excellent',
            'use_case': 'High-accuracy retrieval'
        },
        'bge-large-en-v1.5': {
            'dimension': 1024,
            'max_tokens': 512,
            'speed': 'Medium',
            'quality': 'Excellent',
            'use_case': 'Benchmark leader'
        },
        'text-embedding-3-small': {
            'dimension': 1536,
            'max_tokens': 8191,
            'speed': 'API latency',
            'quality': 'Very good',
            'use_case': 'Long documents, ease of use'
        },
        'text-embedding-3-large': {
            'dimension': 3072,
            'max_tokens': 8191,
            'speed': 'API latency',
            'quality': 'Excellent',
            'use_case': 'Maximum quality'
        }
    }

    print("Embedding Model Comparison:")
    print("=" * 60)
    for name, attrs in models.items():
        print(f"\n{name}:")
        for k, v in attrs.items():
            print(f"  {k}: {v}")


embedding_model_comparison()
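
As a quick check of the bi-encoder above, the sketch below embeds two documents and a query with SentenceTransformerEmbeddings and ranks the documents by cosine similarity; it assumes the all-MiniLM-L6-v2 weights can be fetched from the Hugging Face Hub.

PYTHON
embedder = SentenceTransformerEmbeddings()  # downloads all-MiniLM-L6-v2

docs = [
    "The cat sat on the mat.",
    "Stock prices fell sharply on Monday.",
]
doc_vecs = embedder.embed_documents(docs)            # shape: (2, 384)
query_vec = embedder.embed_query("Where is the cat?")

# Embeddings are L2-normalized, so a dot product is cosine similarity.
sims = doc_vecs @ query_vec
for text, sim in zip(docs, sims):
    print(f"{sim:.3f}  {text}")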

Vector Stores

Vector databases enable efficient similarity search at scale:

PYTHON
class VectorStore(ABC):
    """Abstract base class for vector stores."""

    @abstractmethod
    def add(self, documents: List[Document]) -> None:
        """Add documents to the store."""
        pass

    @abstractmethod
    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search for similar documents."""
        pass

    @abstractmethod
    def delete(self, doc_ids: List[str]) -> None:
        """Delete documents by ID."""
        pass


class InMemoryVectorStore(VectorStore):
    """
    Simple in-memory vector store using NumPy.

    Good for:
    - Development and testing
    - Small document collections (<100K)
    - Prototyping RAG systems
    """

    def __init__(self, dimension: int):
        self.dimension = dimension
        self.documents: List[Document] = []
        self.embeddings: Optional[np.ndarray] = None

    def add(self, documents: List[Document]) -> None:
        """Add documents to the store."""
        for doc in documents:
            if doc.embedding is None:
                raise ValueError(f"Document {doc.doc_id} has no embedding")
            if doc.embedding.shape[0] != self.dimension:
                raise ValueError(
                    f"Embedding dimension mismatch for {doc.doc_id}: "
                    f"expected {self.dimension}, got {doc.embedding.shape[0]}"
                )

        self.documents.extend(documents)

        # Rebuild embedding matrix
        embeddings = np.array([doc.embedding for doc in self.documents])
        self.embeddings = embeddings

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search using cosine similarity."""
        if self.embeddings is None or len(self.documents) == 0:
            return []

        # Normalize for cosine similarity
        query_norm = query_embedding / np.linalg.norm(query_embedding)
        embeddings_norm = self.embeddings / np.linalg.norm(
            self.embeddings, axis=1, keepdims=True
        )

        # Compute similarities
        similarities = embeddings_norm @ query_norm

        # Apply filters
        if filters:
            mask = self._apply_filters(filters)
            similarities = np.where(mask, similarities, -np.inf)

        # Get top-k
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        results = []
        for rank, idx in enumerate(top_indices):
            if similarities[idx] > -np.inf:
                results.append(RetrievalResult(
                    document=self.documents[idx],
                    score=float(similarities[idx]),
                    rank=rank
                ))

        return results

    def _apply_filters(self, filters: Dict[str, Any]) -> np.ndarray:
        """Apply metadata filters."""
        mask = np.ones(len(self.documents), dtype=bool)

        for key, value in filters.items():
            for i, doc in enumerate(self.documents):
                if doc.metadata.get(key) != value:
                    mask[i] = False

        return mask

    def delete(self, doc_ids: List[str]) -> None:
        """Delete documents by ID."""
        doc_ids_set = set(doc_ids)
        self.documents = [d for d in self.documents if d.doc_id not in doc_ids_set]

        if self.documents:
            self.embeddings = np.array([doc.embedding for doc in self.documents])
        else:
            self.embeddings = None


class FAISSVectorStore(VectorStore):
    """
    FAISS-based vector store for efficient similarity search.

    FAISS (Facebook AI Similarity Search) supports:
    - Multiple index types (Flat, IVF, HNSW, PQ)
    - GPU acceleration
    - Billion-scale search
    """

    def __init__(
        self,
        dimension: int,
        index_type: str = "Flat",
        nlist: int = 100,
        nprobe: int = 10,
        use_gpu: bool = False
    ):
        import faiss

        self.dimension = dimension
        self.index_type = index_type
        self.documents: List[Document] = []

        # Create index
        if index_type == "Flat":
            # Exact search (brute force)
            self.index = faiss.IndexFlatIP(dimension)
        elif index_type == "IVF":
            # Inverted file index for approximate search
            quantizer = faiss.IndexFlatIP(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
            self.index.nprobe = nprobe
        elif index_type == "HNSW":
            # Hierarchical Navigable Small World graph
            self.index = faiss.IndexHNSWFlat(dimension, 32)
        else:
            raise ValueError(f"Unknown index type: {index_type}")

        # Move to GPU if requested
        if use_gpu and faiss.get_num_gpus() > 0:
            self.index = faiss.index_cpu_to_gpu(
                faiss.StandardGpuResources(),
                0,
                self.index
            )

        self._needs_training = index_type == "IVF"
        self._is_trained = False

    def add(self, documents: List[Document]) -> None:
        """Add documents to FAISS index."""
        import faiss

        embeddings = np.array([doc.embedding for doc in documents]).astype('float32')

        # Normalize for inner product (cosine similarity)
        faiss.normalize_L2(embeddings)

        # Train if needed (IVF requires training)
        if self._needs_training and not self._is_trained:
            self.index.train(embeddings)
            self._is_trained = True

        # Add to index
        self.index.add(embeddings)
        self.documents.extend(documents)

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search FAISS index."""
        import faiss

        query = query_embedding.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query)

        # Search (may need to fetch more for filtering)
        fetch_k = top_k * 3 if filters else top_k
        scores, indices = self.index.search(query, fetch_k)

        results = []
        for rank, (idx, score) in enumerate(zip(indices[0], scores[0])):
            if idx == -1:  # FAISS returns -1 for missing results
                continue

            doc = self.documents[idx]

            # Apply filters
            if filters:
                if not self._matches_filters(doc, filters):
                    continue

            results.append(RetrievalResult(
                document=doc,
                score=float(score),
                rank=len(results)
            ))

            if len(results) >= top_k:
                break

        return results

    def _matches_filters(self, doc: Document, filters: Dict[str, Any]) -> bool:
        """Check if document matches filters."""
        for key, value in filters.items():
            if doc.metadata.get(key) != value:
                return False
        return True

    def delete(self, doc_ids: List[str]) -> None:
        """Delete is complex in FAISS - typically rebuild index."""
        # Mark for deletion and rebuild
        doc_ids_set = set(doc_ids)
        new_docs = [d for d in self.documents if d.doc_id not in doc_ids_set]

        # Rebuild index
        self.documents = []
        self.index.reset()
        if self._needs_training:
            self._is_trained = False

        if new_docs:
            self.add(new_docs)


class ChromaVectorStore(VectorStore):
    """
    ChromaDB vector store wrapper.

    Chroma provides:
    - Simple API for RAG applications
    - Metadata filtering
    - Persistent storage
    - Embedding function integration
    """

    def __init__(
        self,
        collection_name: str = "rag_collection",
        persist_directory: Optional[str] = None
    ):
        # import chromadb
        # if persist_directory:
        #     self.client = chromadb.PersistentClient(path=persist_directory)
        # else:
        #     self.client = chromadb.Client()
        # self.collection = self.client.get_or_create_collection(collection_name)

        # Simulated for demonstration
        self.collection_name = collection_name
        self.documents: Dict[str, Document] = {}

    def add(self, documents: List[Document]) -> None:
        """Add documents to Chroma."""
        for doc in documents:
            self.documents[doc.doc_id] = doc

        # Actual Chroma API:
        # self.collection.add(
        #     ids=[d.doc_id for d in documents],
        #     embeddings=[d.embedding.tolist() for d in documents],
        #     documents=[d.content for d in documents],
        #     metadatas=[d.metadata for d in documents]
        # )

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search Chroma collection."""
        # Actual Chroma API:
        # results = self.collection.query(
        #     query_embeddings=[query_embedding.tolist()],
        #     n_results=top_k,
        #     where=filters
        # )

        # Simulated search
        results = []
        for doc in self.documents.values():
            if doc.embedding is not None:
                score = np.dot(query_embedding, doc.embedding)
                results.append((doc, score))

        results.sort(key=lambda x: x[1], reverse=True)

        return [
            RetrievalResult(document=doc, score=score, rank=i)
            for i, (doc, score) in enumerate(results[:top_k])
        ]

    def delete(self, doc_ids: List[str]) -> None:
        """Delete documents from Chroma."""
        for doc_id in doc_ids:
            self.documents.pop(doc_id, None)

        # Actual Chroma API:
        # self.collection.delete(ids=doc_ids)


def vector_store_comparison():
    """Compare vector store options."""
    stores = {
        'In-Memory (NumPy)': {
            'scale': '<100K vectors',
            'search_type': 'Exact (brute force)',
            'persistence': 'None',
            'best_for': 'Development, small datasets'
        },
        'FAISS': {
            'scale': 'Billions of vectors',
            'search_type': 'Exact or approximate (IVF, HNSW)',
            'persistence': 'File-based',
            'best_for': 'High-performance, large scale'
        },
        'ChromaDB': {
            'scale': 'Millions of vectors',
            'search_type': 'Approximate (HNSW)',
            'persistence': 'SQLite or DuckDB',
            'best_for': 'Simple RAG applications'
        },
        'Pinecone': {
            'scale': 'Billions (managed)',
            'search_type': 'Approximate',
            'persistence': 'Cloud managed',
            'best_for': 'Production, serverless'
        },
        'Weaviate': {
            'scale': 'Billions',
            'search_type': 'HNSW + hybrid',
            'persistence': 'Docker/Cloud',
            'best_for': 'Hybrid search, GraphQL API'
        },
        'Qdrant': {
            'scale': 'Billions',
            'search_type': 'HNSW',
            'persistence': 'Docker/Cloud',
            'best_for': 'Filtering, payload support'
        },
        'Milvus': {
            'scale': 'Trillions',
            'search_type': 'Multiple indexes',
            'persistence': 'Distributed',
            'best_for': 'Enterprise, high availability'
        }
    }

    print("Vector Store Comparison:")
    print("=" * 60)
    for name, attrs in stores.items():
        print(f"\n{name}:")
        for k, v in attrs.items():
            print(f"  {k}: {v}")


vector_store_comparison()
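
Before reaching for an external database, the in-memory store is enough to exercise the search and filtering logic. The sketch below uses synthetic embeddings purely for illustration; a real pipeline would produce them with one of the embedding models above.

PYTHON
dim = 8
rng = np.random.default_rng(0)

toy_docs = [
    Document(
        content=f"toy document {i}",
        metadata={"lang": "en" if i % 2 == 0 else "de"},
        embedding=rng.normal(size=dim).astype(np.float32),
    )
    for i in range(10)
]

store = InMemoryVectorStore(dimension=dim)
store.add(toy_docs)

query_vec = rng.normal(size=dim).astype(np.float32)
for r in store.search(query_vec, top_k=3, filters={"lang": "en"}):
    print(f"rank={r.rank}  score={r.score:.3f}  id={r.document.doc_id}")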

Document Chunking

Effective chunking is critical for RAG quality:

PYTHON
class TextChunker(ABC):
    """Abstract base class for text chunkers."""

    @abstractmethod
    def chunk(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Split text into chunks."""
        pass


class FixedSizeChunker(TextChunker):
    """
    Fixed-size chunking with overlap.

    Simple but effective baseline approach.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        length_function: callable = len
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function

    def chunk(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Split text into fixed-size chunks with overlap."""
        chunks = []
        start = 0

        while start < len(text):
            end = start + self.chunk_size

            # Find chunk text
            chunk_text = text[start:end]

            # Create document
            chunk_metadata = {**(metadata or {}), 'chunk_index': len(chunks)}
            chunks.append(Document(
                content=chunk_text.strip(),
                metadata=chunk_metadata
            ))

            # Move start with overlap
            start = end - self.chunk_overlap

        return chunks


class RecursiveCharacterChunker(TextChunker):
    """
    Recursive chunking that respects text structure.

    Tries to split on natural boundaries:
    paragraphs > sentences > words > characters
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", ". ", " ", ""]

    def chunk(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Recursively split text on natural boundaries."""
        chunks = self._split_text(text, self.separators)

        return [
            Document(
                content=chunk.strip(),
                metadata={**(metadata or {}), 'chunk_index': i}
            )
            for i, chunk in enumerate(chunks)
            if chunk.strip()
        ]

    def _split_text(
        self,
        text: str,
        separators: List[str]
    ) -> List[str]:
        """Recursively split text."""
        final_chunks = []

        # Find the appropriate separator
        separator = separators[-1]
        for sep in separators:
            if sep in text:
                separator = sep
                break

        # Split
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)

        # Merge small chunks
        current_chunk = []
        current_length = 0

        for split in splits:
            split_length = len(split)

            if current_length + split_length > self.chunk_size:
                if current_chunk:
                    merged = separator.join(current_chunk)
                    final_chunks.append(merged)

                    # Keep overlap
                    overlap_chunks = []
                    overlap_length = 0
                    for chunk in reversed(current_chunk):
                        if overlap_length + len(chunk) <= self.chunk_overlap:
                            overlap_chunks.insert(0, chunk)
                            overlap_length += len(chunk) + len(separator)
                        else:
                            break
                    current_chunk = overlap_chunks

                current_chunk.append(split)
                current_length = sum(len(c) for c in current_chunk)
            else:
                current_chunk.append(split)
                current_length += split_length + len(separator)

        # Add remaining
        if current_chunk:
            final_chunks.append(separator.join(current_chunk))

        return final_chunks


class SemanticChunker(TextChunker):
    """
    Semantic chunking based on embedding similarity.

    Splits where semantic similarity between sentences drops,
    creating more coherent chunks.
    """

    def __init__(
        self,
        embedding_model: EmbeddingModel,
        breakpoint_threshold: float = 0.5,
        min_chunk_size: int = 100,
        max_chunk_size: int = 2000
    ):
        self.embedding_model = embedding_model
        self.breakpoint_threshold = breakpoint_threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

    def chunk(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Split text based on semantic similarity."""
        # Split into sentences
        sentences = self._split_sentences(text)

        if len(sentences) <= 1:
            return [Document(content=text, metadata=metadata or {})]

        # Get embeddings for each sentence
        embeddings = self.embedding_model.embed_documents(sentences)

        # Calculate similarities between adjacent sentences
        similarities = []
        for i in range(len(embeddings) - 1):
            sim = np.dot(embeddings[i], embeddings[i + 1])
            similarities.append(sim)

        # Find breakpoints (low similarity)
        breakpoints = []
        for i, sim in enumerate(similarities):
            if sim < self.breakpoint_threshold:
                breakpoints.append(i + 1)

        # Create chunks
        chunks = []
        start = 0

        for bp in breakpoints:
            chunk_text = " ".join(sentences[start:bp])

            # Check size constraints
            if len(chunk_text) >= self.min_chunk_size:
                chunks.append(chunk_text)
                start = bp
            # If too small, continue to next breakpoint

        # Add remaining
        if start < len(sentences):
            chunk_text = " ".join(sentences[start:])
            if chunks and len(chunk_text) < self.min_chunk_size:
                chunks[-1] += " " + chunk_text
            else:
                chunks.append(chunk_text)

        return [
            Document(
                content=chunk,
                metadata={**(metadata or {}), 'chunk_index': i}
            )
            for i, chunk in enumerate(chunks)
        ]

    def _split_sentences(self, text: str) -> List[str]:
        """Simple sentence splitting."""
        import re
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]


class MarkdownChunker(TextChunker):
    """
    Markdown-aware chunking that respects document structure.

    Preserves:
    - Headers and hierarchy
    - Code blocks
    - Lists
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 100
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Split markdown respecting structure."""
        import re

        chunks = []
        current_headers = []

        # Split by headers
        header_pattern = r'^(#{1,6})\s+(.+)$'
        sections = re.split(r'(?=^#{1,6}\s)', text, flags=re.MULTILINE)

        for section in sections:
            if not section.strip():
                continue

            # Extract header if present
            header_match = re.match(header_pattern, section, re.MULTILINE)
            if header_match:
                level = len(header_match.group(1))
                header_text = header_match.group(2)

                # Update header hierarchy
                current_headers = current_headers[:level-1] + [header_text]

            # Create chunk with header context
            chunk_metadata = {
                **(metadata or {}),
                'headers': ' > '.join(current_headers),
                'chunk_index': len(chunks)
            }

            # Split large sections
            if len(section) > self.chunk_size:
                sub_chunks = self._split_large_section(section)
                for i, sub_chunk in enumerate(sub_chunks):
                    sub_metadata = {**chunk_metadata, 'sub_chunk': i}
                    chunks.append(Document(content=sub_chunk, metadata=sub_metadata))
            else:
                chunks.append(Document(content=section, metadata=chunk_metadata))

        return chunks

    def _split_large_section(self, section: str) -> List[str]:
        """Split large section while preserving code blocks."""
        import re

        # Protect code blocks
        code_blocks = re.findall(r'```[\s\S]*?```', section)
        placeholders = [f"__CODE_BLOCK_{i}__" for i in range(len(code_blocks))]

        for placeholder, block in zip(placeholders, code_blocks):
            section = section.replace(block, placeholder)

        # Split on paragraphs
        chunks = []
        paragraphs = section.split('\n\n')
        current_chunk = []
        current_length = 0

        for para in paragraphs:
            para_length = len(para)

            if current_length + para_length > self.chunk_size and current_chunk:
                chunks.append('\n\n'.join(current_chunk))
                current_chunk = []
                current_length = 0

            current_chunk.append(para)
            current_length += para_length

        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))

        # Restore code blocks
        restored_chunks = []
        for chunk in chunks:
            for placeholder, block in zip(placeholders, code_blocks):
                chunk = chunk.replace(placeholder, block)
            restored_chunks.append(chunk)

        return restored_chunks
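
To compare chunkers concretely, the sketch below runs the fixed-size and recursive strategies over the same synthetic passage and reports chunk counts and size ranges; the text and parameters are illustrative only.

PYTHON
sample_text = (
    "Retrieval-augmented generation grounds answers in documents. "
    "Chunk size controls the granularity of retrieval.\n\n"
    "Smaller chunks improve precision but lose context; larger chunks "
    "keep context but dilute relevance.\n\n"
) * 5

for chunker in (
    FixedSizeChunker(chunk_size=200, chunk_overlap=20),
    RecursiveCharacterChunker(chunk_size=200, chunk_overlap=40),
):
    chunks = chunker.chunk(sample_text, metadata={"source": "demo"})
    sizes = [len(c.content) for c in chunks]
    print(f"{type(chunker).__name__}: {len(chunks)} chunks, "
          f"sizes {min(sizes)}-{max(sizes)} chars")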

Key Takeaways

RAG systems combine retrieval and generation to ground LLM responses in external knowledge. The core pipeline involves: (1) chunking documents into manageable pieces, (2) embedding chunks using models like E5 or OpenAI embeddings, (3) storing embeddings in vector databases like FAISS or ChromaDB, (4) retrieving relevant context for queries, and (5) generating responses conditioned on retrieved content. Key considerations include chunk size/overlap trade-offs, embedding model selection based on quality vs. speed, and vector store choice based on scale requirements. RAG offers significant advantages over fine-tuning for knowledge-intensive tasks—instant updates, reduced hallucination, and transparent sourcing—while fine-tuning remains better for style and format adaptation.
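
As a closing sketch, the snippet below wires steps (1) through (5) together using the classes defined in this section plus the hypothetical EchoLLM stub from earlier; it assumes the embedding weights are downloadable and is meant as a wiring check rather than a production setup.

PYTHON
embedder = SentenceTransformerEmbeddings()
store = InMemoryVectorStore(dimension=embedder.dimension)
chunker = RecursiveCharacterChunker(chunk_size=500, chunk_overlap=50)

rag = RAGSystem(
    embedding_model=embedder,
    vector_store=store,
    llm=EchoLLM(),          # hypothetical stub defined earlier
    chunker=chunker,
    top_k=3,
)

n_chunks = rag.index_documents([
    Document(
        content="FAISS supports Flat, IVF, and HNSW index types.",
        metadata={"source": "notes.md"},
    ),
])

answer, results = rag.query("Which index types does FAISS support?")
print(f"indexed {n_chunks} chunk(s)")
print(answer)
for r in results:
    print(f"  {r.score:.3f}  {r.document.metadata.get('source')}")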

24.2 Advanced Retrieval Strategies

Basic semantic search often falls short for complex queries, missing relevant documents due to vocabulary mismatch, failing to capture multi-faceted information needs, or returning superficially similar but irrelevant results. Advanced retrieval strategies address these limitations through hybrid search, query transformation, reranking, and multi-hop retrieval. This section explores techniques that significantly improve retrieval quality and RAG system performance.

Hybrid Search

Combining semantic and lexical search captures different relevance signals:

PYTHON
import numpy as np
from typing import List, Dict, Optional, Tuple, Any, Set
from dataclasses import dataclass
from abc import ABC, abstractmethod
import re
from collections import Counter
import math

@dataclass
class Document:
    """Document representation."""
    content: str
    metadata: Dict[str, Any]
    doc_id: str
    embedding: Optional[np.ndarray] = None


@dataclass
class RetrievalResult:
    """Retrieval result with score."""
    document: Document
    score: float
    rank: int
    source: str = "hybrid"  # Track which retriever found this


class HybridRetriever:
    """
    Hybrid retriever combining dense and sparse search.

    Dense (semantic): Captures meaning, handles paraphrase
    Sparse (keyword): Exact matches, rare terms, names

    The combination is more robust than either alone.
    """

    def __init__(
        self,
        dense_retriever: 'DenseRetriever',
        sparse_retriever: 'SparseRetriever',
        dense_weight: float = 0.5,
        fusion_method: str = "rrf"  # rrf, linear, or dbsf
    ):
        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.dense_weight = dense_weight
        self.sparse_weight = 1 - dense_weight
        self.fusion_method = fusion_method

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """
        Perform hybrid search combining dense and sparse results.
        """
        # Get results from both retrievers
        dense_results = self.dense_retriever.search(query, top_k=top_k * 2, filters=filters)
        sparse_results = self.sparse_retriever.search(query, top_k=top_k * 2, filters=filters)

        # Fuse results
        if self.fusion_method == "rrf":
            fused = self._reciprocal_rank_fusion(dense_results, sparse_results)
        elif self.fusion_method == "linear":
            fused = self._linear_combination(dense_results, sparse_results)
        elif self.fusion_method == "dbsf":
            fused = self._distribution_based_score_fusion(dense_results, sparse_results)
        else:
            raise ValueError(f"Unknown fusion method: {self.fusion_method}")

        # Return top-k
        return fused[:top_k]

    def _reciprocal_rank_fusion(
        self,
        dense_results: List[RetrievalResult],
        sparse_results: List[RetrievalResult],
        k: int = 60  # RRF constant
    ) -> List[RetrievalResult]:
        """
        Reciprocal Rank Fusion (RRF).

        RRF score = sum(1 / (k + rank)) across all result lists.
        Robust fusion that doesn't require score normalization.
        """
        doc_scores: Dict[str, float] = {}
        doc_objects: Dict[str, Document] = {}

        # Process dense results
        for result in dense_results:
            doc_id = result.document.doc_id
            rrf_score = 1.0 / (k + result.rank + 1)
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + rrf_score * self.dense_weight
            doc_objects[doc_id] = result.document

        # Process sparse results
        for result in sparse_results:
            doc_id = result.document.doc_id
            rrf_score = 1.0 / (k + result.rank + 1)
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + rrf_score * self.sparse_weight
            doc_objects[doc_id] = result.document

        # Sort by fused score
        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

        return [
            RetrievalResult(
                document=doc_objects[doc_id],
                score=score,
                rank=rank,
                source="hybrid_rrf"
            )
            for rank, (doc_id, score) in enumerate(sorted_docs)
        ]

    def _linear_combination(
        self,
        dense_results: List[RetrievalResult],
        sparse_results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Linear combination of normalized scores.

        Requires score normalization to [0, 1] range.
        """
        doc_scores: Dict[str, float] = {}
        doc_objects: Dict[str, Document] = {}

        # Normalize dense scores
        if dense_results:
            dense_max = max(r.score for r in dense_results)
            dense_min = min(r.score for r in dense_results)
            dense_range = dense_max - dense_min if dense_max != dense_min else 1

            for result in dense_results:
                doc_id = result.document.doc_id
                norm_score = (result.score - dense_min) / dense_range
                doc_scores[doc_id] = norm_score * self.dense_weight
                doc_objects[doc_id] = result.document

        # Normalize sparse scores
        if sparse_results:
            sparse_max = max(r.score for r in sparse_results)
            sparse_min = min(r.score for r in sparse_results)
            sparse_range = sparse_max - sparse_min if sparse_max != sparse_min else 1

            for result in sparse_results:
                doc_id = result.document.doc_id
                norm_score = (result.score - sparse_min) / sparse_range
                doc_scores[doc_id] = doc_scores.get(doc_id, 0) + norm_score * self.sparse_weight
                doc_objects[doc_id] = result.document

        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

        return [
            RetrievalResult(
                document=doc_objects[doc_id],
                score=score,
                rank=rank,
                source="hybrid_linear"
            )
            for rank, (doc_id, score) in enumerate(sorted_docs)
        ]

    def _distribution_based_score_fusion(
        self,
        dense_results: List[RetrievalResult],
        sparse_results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Distribution-Based Score Fusion (DBSF).

        Normalizes scores based on their distribution (z-score).
        """
        doc_scores: Dict[str, float] = {}
        doc_objects: Dict[str, Document] = {}

        def z_normalize(results: List[RetrievalResult]) -> Dict[str, float]:
            if not results:
                return {}
            scores = [r.score for r in results]
            mean = np.mean(scores)
            std = np.std(scores) if len(scores) > 1 else 1
            std = std if std > 0 else 1

            return {
                r.document.doc_id: (r.score - mean) / std
                for r in results
            }

        dense_normed = z_normalize(dense_results)
        sparse_normed = z_normalize(sparse_results)

        for result in dense_results:
            doc_id = result.document.doc_id
            doc_scores[doc_id] = dense_normed[doc_id] * self.dense_weight
            doc_objects[doc_id] = result.document

        for result in sparse_results:
            doc_id = result.document.doc_id
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + sparse_normed[doc_id] * self.sparse_weight
            doc_objects[doc_id] = result.document

        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

        return [
            RetrievalResult(
                document=doc_objects[doc_id],
                score=score,
                rank=rank,
                source="hybrid_dbsf"
            )
            for rank, (doc_id, score) in enumerate(sorted_docs)
        ]


class BM25Retriever:
    """
    BM25 sparse retriever for keyword-based search.

    BM25 is a bag-of-words ranking function that considers:
    - Term frequency (TF) with saturation
    - Inverse document frequency (IDF)
    - Document length normalization
    """

    def __init__(
        self,
        k1: float = 1.5,
        b: float = 0.75,
        epsilon: float = 0.25
    ):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon

        self.documents: List[Document] = []
        self.doc_lengths: List[int] = []
        self.avg_doc_length: float = 0
        self.term_freqs: List[Dict[str, int]] = []
        self.doc_freqs: Dict[str, int] = {}
        self.idf: Dict[str, float] = {}

    def index(self, documents: List[Document]) -> None:
        """Index documents for BM25 search."""
        self.documents = documents
        self.doc_lengths = []
        self.term_freqs = []
        self.doc_freqs = Counter()

        for doc in documents:
            # Tokenize
            tokens = self._tokenize(doc.content)
            self.doc_lengths.append(len(tokens))

            # Term frequencies
            tf = Counter(tokens)
            self.term_freqs.append(tf)

            # Document frequencies
            for term in set(tokens):
                self.doc_freqs[term] += 1

        self.avg_doc_length = np.mean(self.doc_lengths) if self.doc_lengths else 0

        # Compute IDF
        n_docs = len(documents)
        for term, df in self.doc_freqs.items():
            # IDF with smoothing
            self.idf[term] = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search using BM25 scoring."""
        query_tokens = self._tokenize(query)

        scores = []
        for i, doc in enumerate(self.documents):
            # Apply filters
            if filters and not self._matches_filters(doc, filters):
                scores.append(-float('inf'))
                continue

            score = self._score_document(query_tokens, i)
            scores.append(score)

        # Get top-k indices
        top_indices = np.argsort(scores)[-top_k:][::-1]

        results = []
        for rank, idx in enumerate(top_indices):
            if scores[idx] > -float('inf'):
                results.append(RetrievalResult(
                    document=self.documents[idx],
                    score=scores[idx],
                    rank=rank,
                    source="bm25"
                ))

        return results

    def _score_document(self, query_tokens: List[str], doc_idx: int) -> float:
        """Compute BM25 score for a document."""
        score = 0.0
        doc_len = self.doc_lengths[doc_idx]
        tf_dict = self.term_freqs[doc_idx]

        for token in query_tokens:
            if token not in self.idf:
                continue

            tf = tf_dict.get(token, 0)
            idf = self.idf[token]

            # BM25 formula
            numerator = tf * (self.k1 + 1)
            denominator = tf + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_length)

            score += idf * numerator / denominator

        return score

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization."""
        text = text.lower()
        tokens = re.findall(r'\b\w+\b', text)
        return tokens

    def _matches_filters(self, doc: Document, filters: Dict[str, Any]) -> bool:
        for key, value in filters.items():
            if doc.metadata.get(key) != value:
                return False
        return True


class DenseRetriever:
    """Dense retriever using vector similarity."""

    def __init__(self, embedding_model, vector_store):
        self.embedding_model = embedding_model
        self.vector_store = vector_store

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        query_embedding = self.embedding_model.embed_query(query)
        return self.vector_store.search(query_embedding, top_k, filters)

Query Transformation

Transforming queries improves retrieval for complex information needs:

PYTHON
class QueryTransformer(ABC):
    """Abstract base class for query transformers."""

    @abstractmethod
    def transform(self, query: str) -> List[str]:
        """Transform a query into one or more queries."""
        pass


class HyDEQueryTransformer(QueryTransformer):
    """
    Hypothetical Document Embeddings (HyDE).

    Instead of embedding the query directly, generate a
    hypothetical answer and embed that. The hypothetical
    answer is closer in embedding space to real answers.

    Query: "What causes rain?"
    HyDE: "Rain is caused by water vapor condensing..."
    """

    def __init__(self, llm, n_hypotheses: int = 1):
        self.llm = llm
        self.n_hypotheses = n_hypotheses

    def transform(self, query: str) -> List[str]:
        """Generate hypothetical document(s) for the query."""
        prompt = f"""Generate a detailed passage that would answer the following question.
Write as if you are writing a paragraph from a textbook or encyclopedia.
Do not include phrases like "According to" or "The answer is".

Question: {query}

Passage:"""

        hypotheses = []
        for _ in range(self.n_hypotheses):
            hypothesis = self.llm.generate(prompt)
            hypotheses.append(hypothesis)

        return hypotheses


class MultiQueryTransformer(QueryTransformer):
    """
    Generate multiple query variations to improve recall.

    Different phrasings capture different relevant documents.
    """

    def __init__(self, llm, n_queries: int = 3):
        self.llm = llm
        self.n_queries = n_queries

    def transform(self, query: str) -> List[str]:
        """Generate multiple query variations."""
        prompt = f"""Generate {self.n_queries} different versions of the following question.
Each version should ask for the same information but use different words and phrasing.
Return only the questions, one per line.

Original question: {query}

Alternative questions:"""

        response = self.llm.generate(prompt)
        queries = [q.strip() for q in response.strip().split('\n') if q.strip()]

        # Include original query
        return [query] + queries[:self.n_queries]


class StepBackQueryTransformer(QueryTransformer):
    """
    Step-back prompting for complex queries.

    For specific questions, first ask a more general question
    to retrieve broader context, then answer the specific question.

    Query: "What was Einstein's GPA at ETH Zurich?"
    Step-back: "What was Einstein's educational background?"
    """

    def __init__(self, llm):
        self.llm = llm

    def transform(self, query: str) -> List[str]:
        """Generate step-back query."""
        prompt = f"""Given a specific question, generate a more general question that would
provide useful background context for answering the original question.

Specific question: {query}

General question:"""

        general_query = self.llm.generate(prompt).strip()

        return [general_query, query]


class QueryDecomposer(QueryTransformer):
    """
    Decompose complex queries into sub-queries.

    Useful for multi-hop questions that require information
    from multiple documents.

    Query: "Who is the CEO of the company that made the iPhone?"
    Sub-queries:
    1. "What company made the iPhone?"
    2. "Who is the CEO of Apple?"
    """

    def __init__(self, llm, max_sub_queries: int = 4):
        self.llm = llm
        self.max_sub_queries = max_sub_queries

    def transform(self, query: str) -> List[str]:
        """Decompose query into sub-queries."""
        prompt = f"""Break down the following complex question into simpler sub-questions
that can be answered independently. Each sub-question should retrieve
specific information needed to answer the original question.

Complex question: {query}

Sub-questions (one per line):"""

        response = self.llm.generate(prompt)
        sub_queries = [q.strip() for q in response.strip().split('\n') if q.strip()]

        # Clean up numbering
        sub_queries = [re.sub(r'^\d+[\.\)]\s*', '', q) for q in sub_queries]

        return sub_queries[:self.max_sub_queries]


class MultiQueryRetriever:
    """
    Retriever that uses multiple query variations.

    Combines results from all query variations for better recall.
    """

    def __init__(
        self,
        base_retriever,
        query_transformer: QueryTransformer,
        fusion_method: str = "rrf"
    ):
        self.base_retriever = base_retriever
        self.query_transformer = query_transformer
        self.fusion_method = fusion_method

    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Search using multiple query variations."""
        # Transform query
        queries = self.query_transformer.transform(query)

        # Retrieve for each query
        all_results = []
        for q in queries:
            results = self.base_retriever.search(q, top_k=top_k, filters=filters)
            all_results.append(results)

        # Fuse results
        if self.fusion_method == "rrf":
            return self._rrf_fusion(all_results, top_k)
        else:
            return self._union_dedup(all_results, top_k)

    def _rrf_fusion(
        self,
        all_results: List[List[RetrievalResult]],
        top_k: int,
        k: int = 60
    ) -> List[RetrievalResult]:
        """RRF fusion across query variations."""
        doc_scores: Dict[str, float] = {}
        doc_objects: Dict[str, Document] = {}

        for results in all_results:
            for result in results:
                doc_id = result.document.doc_id
                rrf_score = 1.0 / (k + result.rank + 1)
                doc_scores[doc_id] = doc_scores.get(doc_id, 0) + rrf_score
                doc_objects[doc_id] = result.document

        sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

        return [
            RetrievalResult(
                document=doc_objects[doc_id],
                score=score,
                rank=rank,
                source="multi_query"
            )
            for rank, (doc_id, score) in enumerate(sorted_docs[:top_k])
        ]

    def _union_dedup(
        self,
        all_results: List[List[RetrievalResult]],
        top_k: int
    ) -> List[RetrievalResult]:
        """Union with deduplication, keeping highest score."""
        seen: Dict[str, RetrievalResult] = {}

        for results in all_results:
            for result in results:
                doc_id = result.document.doc_id
                if doc_id not in seen or result.score > seen[doc_id].score:
                    seen[doc_id] = result

        sorted_results = sorted(seen.values(), key=lambda x: x.score, reverse=True)

        return [
            RetrievalResult(
                document=r.document,
                score=r.score,
                rank=i,
                source="multi_query"
            )
            for i, r in enumerate(sorted_results[:top_k])
        ]
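
Before moving on to reranking, here is a minimal usage sketch showing how a query transformer plugs into the multi-query retriever above. The `dense_retriever` and `llm` objects are hypothetical placeholders for any retriever and language model that implement the `search()` and `generate()` interfaces assumed in this chapter.

PYTHON
# Usage sketch; `dense_retriever` and `llm` are hypothetical objects that
# implement the .search() and .generate() interfaces used above.
decomposer = QueryDecomposer(llm=llm, max_sub_queries=3)

multi_retriever = MultiQueryRetriever(
    base_retriever=dense_retriever,
    query_transformer=decomposer,
    fusion_method="rrf"  # reciprocal rank fusion across sub-query result lists
)

results = multi_retriever.search(
    "Who is the CEO of the company that made the iPhone?",
    top_k=5
)

for r in results:
    # Each result carries the fused RRF score and its final rank
    print(r.rank, round(r.score, 4), r.document.doc_id)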

Reranking

Cross-encoder reranking improves precision on retrieved candidates:

PYTHON
class Reranker(ABC):
    """Abstract base class for rerankers."""

    @abstractmethod
    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None
    ) -> List[RetrievalResult]:
        """Rerank documents for a query."""
        pass


class CrossEncoderReranker(Reranker):
    """
    Cross-encoder reranker using transformer models.

    Cross-encoders jointly encode query and document,
    allowing rich interaction between them. More accurate
    than bi-encoders but slower (can't pre-compute embeddings).

    Popular models:
    - cross-encoder/ms-marco-MiniLM-L-6-v2
    - BAAI/bge-reranker-base
    - Cohere rerank
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device: str = "cuda",
        batch_size: int = 32
    ):
        import torch
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        self.device = device
        self.batch_size = batch_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None
    ) -> List[RetrievalResult]:
        """Rerank documents using cross-encoder scores."""
        import torch

        if not documents:
            return []

        # Prepare pairs
        pairs = [[query, doc.content] for doc in documents]

        # Score in batches
        all_scores = []

        with torch.no_grad():
            for i in range(0, len(pairs), self.batch_size):
                batch = pairs[i:i + self.batch_size]

                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                outputs = self.model(**inputs)
                scores = outputs.logits.squeeze(-1)

                # Handle single output
                if scores.dim() == 0:
                    scores = scores.unsqueeze(0)

                all_scores.extend(scores.cpu().tolist())

        # Sort by score
        scored_docs = list(zip(documents, all_scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)

        if top_k:
            scored_docs = scored_docs[:top_k]

        return [
            RetrievalResult(
                document=doc,
                score=score,
                rank=rank,
                source="reranked"
            )
            for rank, (doc, score) in enumerate(scored_docs)
        ]


class LLMReranker(Reranker):
    """
    LLM-based reranker using relevance scoring.

    Uses an LLM to score document relevance to query.
    More flexible but slower and more expensive.
    """

    def __init__(self, llm, batch_size: int = 5):
        self.llm = llm
        self.batch_size = batch_size

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None
    ) -> List[RetrievalResult]:
        """Rerank using LLM relevance scoring."""
        scores = []

        for doc in documents:
            score = self._score_relevance(query, doc.content)
            scores.append(score)

        # Sort by score
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)

        if top_k:
            scored_docs = scored_docs[:top_k]

        return [
            RetrievalResult(
                document=doc,
                score=score,
                rank=rank,
                source="llm_reranked"
            )
            for rank, (doc, score) in enumerate(scored_docs)
        ]

    def _score_relevance(self, query: str, content: str) -> float:
        """Score relevance of content to query."""
        prompt = f"""Rate the relevance of the following document to the query.
Score from 0 to 10, where 0 is completely irrelevant and 10 is highly relevant.
Return only the numeric score.

Query: {query}

Document: {content[:1000]}

Score:"""

        response = self.llm.generate(prompt)

        match = re.search(r'\d+\.?\d*', response)
        if match is None:
            # No numeric score found in the LLM output
            return 0.0
        return min(10.0, max(0.0, float(match.group())))


class CohereReranker(Reranker):
    """
    Cohere Rerank API wrapper.

    High-quality commercial reranking service.
    """

    def __init__(self, api_key: str, model: str = "rerank-english-v3.0"):
        self.api_key = api_key
        self.model = model

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None
    ) -> List[RetrievalResult]:
        """Rerank using Cohere API."""
        # Actual implementation would call Cohere API
        # import cohere
        # co = cohere.Client(self.api_key)
        # results = co.rerank(
        #     query=query,
        #     documents=[d.content for d in documents],
        #     top_n=top_k,
        #     model=self.model
        # )

        # Simulated response
        scores = np.random.rand(len(documents))
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)

        if top_k:
            scored_docs = scored_docs[:top_k]

        return [
            RetrievalResult(
                document=doc,
                score=score,
                rank=rank,
                source="cohere_reranked"
            )
            for rank, (doc, score) in enumerate(scored_docs)
        ]


class RetrieveAndRerank:
    """
    Two-stage retrieval: fast retrieval followed by accurate reranking.

    Stage 1: Retrieve many candidates quickly (bi-encoder)
    Stage 2: Rerank candidates accurately (cross-encoder)
    """

    def __init__(
        self,
        retriever,
        reranker: Reranker,
        retrieve_k: int = 100,
        rerank_k: int = 10
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.retrieve_k = retrieve_k
        self.rerank_k = rerank_k

    def search(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[RetrievalResult]:
        """Two-stage retrieve and rerank."""
        # Stage 1: Fast retrieval
        candidates = self.retriever.search(
            query,
            top_k=self.retrieve_k,
            filters=filters
        )

        if not candidates:
            return []

        # Stage 2: Accurate reranking
        documents = [r.document for r in candidates]
        reranked = self.reranker.rerank(
            query,
            documents,
            top_k=self.rerank_k
        )

        return reranked
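
The two-stage pattern is typically wired up as in the sketch below: a fast bi-encoder retriever fetches a wide candidate pool, and the cross-encoder narrows it down to a precise top list. The `dense_retriever` object is a hypothetical first-stage retriever; the model name is the default listed in `CrossEncoderReranker`.

PYTHON
# Two-stage retrieve-and-rerank sketch (hypothetical `dense_retriever`).
reranker = CrossEncoderReranker(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    device="cpu",     # switch to "cuda" if a GPU is available
    batch_size=16
)

pipeline = RetrieveAndRerank(
    retriever=dense_retriever,
    reranker=reranker,
    retrieve_k=100,   # wide, cheap first stage
    rerank_k=10       # narrow, accurate second stage
)

top_results = pipeline.search("How does reciprocal rank fusion work?")
for r in top_results:
    print(r.rank, round(r.score, 3), r.document.content[:80])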

Multi-Hop Retrieval

Complex questions require information from multiple documents:

PYTHON
class MultiHopRetriever:
    """
    Multi-hop retriever for complex questions.

    Iteratively retrieves and reasons to gather information
    from multiple documents that need to be combined.
    """

    def __init__(
        self,
        retriever,
        llm,
        max_hops: int = 3
    ):
        self.retriever = retriever
        self.llm = llm
        self.max_hops = max_hops

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> Tuple[List[RetrievalResult], List[str]]:
        """
        Multi-hop retrieval with reasoning chain.

        Returns final results and the reasoning trace.
        """
        all_results = []
        seen_doc_ids: Set[str] = set()
        reasoning_trace = []

        current_query = query

        for hop in range(self.max_hops):
            # Retrieve for current query
            results = self.retriever.search(current_query, top_k=top_k)

            # Filter out already seen documents
            new_results = [
                r for r in results
                if r.document.doc_id not in seen_doc_ids
            ]

            if not new_results:
                reasoning_trace.append(f"Hop {hop + 1}: No new documents found")
                break

            all_results.extend(new_results)
            for r in new_results:
                seen_doc_ids.add(r.document.doc_id)

            # Build context from all retrieved documents
            context = self._build_context(all_results)

            # Check if we can answer
            can_answer, reasoning = self._can_answer(query, context)
            reasoning_trace.append(f"Hop {hop + 1}: {reasoning}")

            if can_answer:
                break

            # Generate follow-up query
            current_query = self._generate_followup(query, context)
            reasoning_trace.append(f"Follow-up query: {current_query}")

        return all_results, reasoning_trace

    def _build_context(self, results: List[RetrievalResult]) -> str:
        """Build context from retrieved documents."""
        context_parts = []
        for i, r in enumerate(results, 1):
            context_parts.append(f"[Document {i}]\n{r.document.content}")
        return "\n\n".join(context_parts)

    def _can_answer(self, query: str, context: str) -> Tuple[bool, str]:
        """Check if current context can answer the query."""
        prompt = f"""Given the following context and question, determine if the context
contains enough information to fully answer the question.

Context:
{context}

Question: {query}

Can this question be fully answered with the given context?
Respond with "YES" or "NO" followed by a brief explanation.

Response:"""

        response = self.llm.generate(prompt)

        can_answer = response.strip().upper().startswith("YES")
        return can_answer, response.strip()

    def _generate_followup(self, original_query: str, context: str) -> str:
        """Generate follow-up query for missing information."""
        prompt = f"""The following context does not contain enough information
to answer the question. Generate a follow-up search query to find
the missing information.

Original question: {original_query}

Current context:
{context[:2000]}

What additional information is needed? Generate a search query:"""

        followup = self.llm.generate(prompt)
        return followup.strip()


class SelfQueryRetriever:
    """
    Self-querying retriever that extracts metadata filters.

    Uses LLM to parse natural language queries into
    structured filters + semantic query.

    Query: "Find Python tutorials from 2023"
    Parsed: filters={'language': 'python', 'year': 2023}
            query="tutorials"
    """

    def __init__(
        self,
        retriever,
        llm,
        metadata_schema: Dict[str, Dict[str, Any]]
    ):
        """
        Args:
            retriever: Base retriever
            llm: Language model for query parsing
            metadata_schema: Schema of available metadata fields
                Example: {
                    'category': {'type': 'string', 'values': ['tech', 'science']},
                    'year': {'type': 'integer', 'min': 2000, 'max': 2024},
                    'author': {'type': 'string'}
                }
        """
        self.retriever = retriever
        self.llm = llm
        self.metadata_schema = metadata_schema

    def search(
        self,
        query: str,
        top_k: int = 10
    ) -> List[RetrievalResult]:
        """Parse query and search with extracted filters."""
        # Parse query into filters and semantic query
        parsed = self._parse_query(query)

        # Search with filters
        results = self.retriever.search(
            parsed['query'],
            top_k=top_k,
            filters=parsed['filters']
        )

        return results

    def _parse_query(self, query: str) -> Dict[str, Any]:
        """Parse natural language query into structured form."""
        schema_desc = self._format_schema()

        prompt = f"""Parse the following search query into a semantic search query
and metadata filters.

Available metadata fields:
{schema_desc}

Query: {query}

Return a JSON object with:
- "query": The semantic search portion (what to search for)
- "filters": A dictionary of metadata filters to apply

JSON response:"""

        response = self.llm.generate(prompt)

        try:
            # Extract JSON from response
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                import json
                parsed = json.loads(json_match.group())
                return {
                    'query': parsed.get('query', query),
                    'filters': parsed.get('filters', {})
                }
        except Exception:
            # Fall back to the raw query if JSON parsing fails
            pass

        # Fallback to original query
        return {'query': query, 'filters': {}}

    def _format_schema(self) -> str:
        """Format metadata schema for prompt."""
        lines = []
        for field, info in self.metadata_schema.items():
            field_type = info.get('type', 'string')
            if 'values' in info:
                lines.append(f"- {field} ({field_type}): one of {info['values']}")
            elif 'min' in info or 'max' in info:
                range_str = f"{info.get('min', '...')} to {info.get('max', '...')}"
                lines.append(f"- {field} ({field_type}): range {range_str}")
            else:
                lines.append(f"- {field} ({field_type})")
        return "\n".join(lines)
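
The sketch below exercises both retrievers. The `base_retriever` and `llm` objects are hypothetical, and the metadata schema is illustrative; it only needs to mirror the metadata fields your index actually stores.

PYTHON
# Hypothetical `base_retriever` and `llm` matching the interfaces above.
multi_hop = MultiHopRetriever(retriever=base_retriever, llm=llm, max_hops=3)
results, trace = multi_hop.search(
    "Who is the CEO of the company that made the iPhone?"
)
for step in trace:
    print(step)  # reasoning trace: one line per hop and follow-up query

self_query = SelfQueryRetriever(
    retriever=base_retriever,
    llm=llm,
    metadata_schema={
        'category': {'type': 'string', 'values': ['tech', 'science']},
        'year': {'type': 'integer', 'min': 2000, 'max': 2024},
    }
)
filtered = self_query.search("Find tech articles from 2023 about vector databases")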

Key Takeaways

Advanced retrieval strategies significantly improve RAG system quality. Hybrid search combines semantic understanding with exact keyword matching, capturing relevance signals that either approach alone would miss. Query transformation techniques like HyDE, multi-query generation, and decomposition help retrieve documents for complex or ambiguous queries. Cross-encoder reranking improves precision by jointly modeling query-document relevance, typically boosting accuracy by 5-15%. Multi-hop retrieval enables answering complex questions that require synthesizing information from multiple documents. Self-querying automatically extracts metadata filters from natural language. The optimal strategy depends on your use case: hybrid search for general-purpose applications, aggressive reranking for precision-critical systems, and multi-hop for complex reasoning tasks.

24.3 RAG Generation and Prompting Advanced

RAG Generation and Prompting

The generation phase of RAG determines how effectively retrieved context is used to produce accurate, well-grounded responses. Poor prompting can lead to hallucination, ignored context, or unfaithful answers even with perfect retrieval. This section covers prompt engineering for RAG, context injection strategies, citation generation, and techniques for reducing hallucination while maintaining response quality.

Context Injection Strategies

How you present retrieved context significantly impacts generation quality:

PYTHON
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass
from abc import ABC, abstractmethod
import re

@dataclass
class Document:
    """Document with content and metadata."""
    content: str
    metadata: Dict[str, Any]
    doc_id: str
    score: Optional[float] = None


@dataclass
class GenerationResult:
    """Result from RAG generation."""
    answer: str
    citations: List[Dict[str, Any]]
    confidence: float
    reasoning: Optional[str] = None


class ContextInjector(ABC):
    """Abstract base class for context injection strategies."""

    @abstractmethod
    def format_context(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Format retrieved documents into context string."""
        pass


class SimpleContextInjector(ContextInjector):
    """
    Simple context injection with numbered sources.

    Best for straightforward QA with clear source attribution.
    """

    def __init__(
        self,
        max_context_length: int = 4000,
        include_metadata: bool = True
    ):
        self.max_context_length = max_context_length
        self.include_metadata = include_metadata

    def format_context(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Format documents with source numbers."""
        context_parts = []
        current_length = 0

        for i, doc in enumerate(documents, 1):
            # Build source header
            if self.include_metadata:
                source = doc.metadata.get('source', f'Document {i}')
                header = f"[Source {i}: {source}]"
            else:
                header = f"[Source {i}]"

            # Check length
            chunk = f"{header}\n{doc.content}\n"
            if current_length + len(chunk) > self.max_context_length:
                # Truncate if needed
                remaining = self.max_context_length - current_length - len(header) - 10
                if remaining > 100:
                    chunk = f"{header}\n{doc.content[:remaining]}...\n"
                else:
                    break

            context_parts.append(chunk)
            current_length += len(chunk)

        return "\n".join(context_parts)


class StructuredContextInjector(ContextInjector):
    """
    Structured context with XML-like tags.

    Helps models distinguish between different parts of the prompt.
    """

    def format_context(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Format with structured tags."""
        context_parts = []

        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get('source', 'unknown')
            score = f" relevance={doc.score:.2f}" if doc.score else ""

            context_parts.append(f"""<document id="{i}" source="{source}"{score}>
{doc.content}
</document>""")

        return "\n\n".join(context_parts)


class HierarchicalContextInjector(ContextInjector):
    """
    Hierarchical context organization by relevance.

    Places most relevant documents first, with clear relevance indicators.
    """

    def __init__(self, relevance_tiers: int = 3):
        self.relevance_tiers = relevance_tiers

    def format_context(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Organize context by relevance tiers."""
        if not documents:
            return "No relevant documents found."

        # Sort by score
        sorted_docs = sorted(documents, key=lambda d: d.score or 0, reverse=True)

        # Divide into tiers
        tier_size = max(1, len(sorted_docs) // self.relevance_tiers)
        tiers = {
            "Highly Relevant": sorted_docs[:tier_size],
            "Relevant": sorted_docs[tier_size:tier_size*2],
            "Potentially Relevant": sorted_docs[tier_size*2:]
        }

        context_parts = []
        doc_num = 1

        for tier_name, tier_docs in tiers.items():
            if tier_docs:
                context_parts.append(f"=== {tier_name} Sources ===")
                for doc in tier_docs:
                    source = doc.metadata.get('source', 'Unknown')
                    context_parts.append(f"\n[{doc_num}] {source}\n{doc.content}")
                    doc_num += 1

        return "\n".join(context_parts)


class CompressedContextInjector(ContextInjector):
    """
    Compress context using LLM summarization.

    Useful when retrieved content exceeds context window.
    """

    def __init__(self, llm, target_length: int = 2000):
        self.llm = llm
        self.target_length = target_length

    def format_context(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Compress documents into concise context."""
        # First, try simple concatenation
        full_context = "\n\n".join([doc.content for doc in documents])

        if len(full_context) <= self.target_length:
            return full_context

        # Compress using LLM
        prompt = f"""Summarize the following documents to answer this question: {query}

Keep all facts and details relevant to the question.
Remove redundant information and filler text.
Maintain source attribution where possible.

Documents:
{full_context[:8000]}

Concise summary (max {self.target_length} chars):"""

        compressed = self.llm.generate(prompt)
        return compressed[:self.target_length]
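
To make the differences concrete, the sketch below formats the same two documents with the simple and structured injectors. The document contents and metadata are illustrative placeholders.

PYTHON
# Illustrative documents; contents and metadata are placeholders.
docs = [
    Document(content="RAG grounds LLM answers in retrieved evidence.",
             metadata={"source": "rag_overview.md"}, doc_id="d1", score=0.91),
    Document(content="Cross-encoders rerank candidates for higher precision.",
             metadata={"source": "reranking.md"}, doc_id="d2", score=0.74),
]

simple = SimpleContextInjector(max_context_length=4000)
structured = StructuredContextInjector()

print(simple.format_context(docs, query="What is RAG?"))
# [Source 1: rag_overview.md]
# RAG grounds LLM answers in retrieved evidence.
# ...

print(structured.format_context(docs, query="What is RAG?"))
# <document id="1" source="rag_overview.md" relevance=0.91>
# RAG grounds LLM answers in retrieved evidence.
# </document>
# ...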

RAG Prompt Engineering

Effective prompts guide the model to use context faithfully:

PYTHON
class RAGPromptTemplate:
    """
    RAG prompt templates for different use cases.

    Key principles:
    1. Clear instruction to use provided context
    2. Explicit handling of insufficient information
    3. Citation format specification
    4. Hallucination prevention instructions
    """

    @staticmethod
    def basic_qa(context: str, question: str) -> str:
        """Basic QA prompt with context."""
        return f"""Answer the question based on the provided context.
If the context doesn't contain relevant information, say "I don't have enough information to answer this question."

Context:
{context}

Question: {question}

Answer:"""

    @staticmethod
    def cited_qa(context: str, question: str) -> str:
        """QA with inline citations."""
        return f"""Answer the question using ONLY information from the provided sources.
Include citations in [1], [2], etc. format after each fact.
If sources don't contain the answer, say so.

Sources:
{context}

Question: {question}

Answer with citations:"""

    @staticmethod
    def strict_grounding(context: str, question: str) -> str:
        """Strict grounding to prevent hallucination."""
        return f"""You are a helpful assistant that answers questions using ONLY the provided context.

IMPORTANT RULES:
1. ONLY use information explicitly stated in the context
2. Do NOT add information from your training data
3. If the context doesn't answer the question, say "The provided documents don't contain this information"
4. Quote relevant passages when possible
5. Cite sources using [Source N] format

Context:
{context}

Question: {question}

Grounded Answer:"""

    @staticmethod
    def analytical_qa(context: str, question: str) -> str:
        """Analytical response with reasoning."""
        return f"""Analyze the provided documents to answer the question.

First, identify relevant information from each source.
Then, synthesize the information into a comprehensive answer.
Finally, note any gaps or conflicting information.

Documents:
{context}

Question: {question}

Analysis:
1. Relevant Information:
2. Synthesized Answer:
3. Limitations/Gaps:"""

    @staticmethod
    def conversational_qa(
        context: str,
        question: str,
        chat_history: List[Dict[str, str]]
    ) -> str:
        """Conversational QA with history."""
        history_str = ""
        for turn in chat_history[-5:]:  # Last 5 turns
            history_str += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n\n"

        return f"""You are having a conversation with a user. Use the provided context to answer their question.
Maintain conversation flow while staying grounded in the sources.

Previous conversation:
{history_str}

Relevant context:
{context}

Current question: {question}

Response:"""

    @staticmethod
    def multi_document_synthesis(context: str, question: str) -> str:
        """Synthesize information across multiple documents."""
        return f"""You are given multiple documents that may contain complementary, overlapping, or contradictory information.

Your task:
1. Identify the key points from each document
2. Note agreements and disagreements between sources
3. Provide a balanced synthesis that acknowledges different perspectives
4. Clearly attribute claims to their sources

Documents:
{context}

Question: {question}

Synthesis:"""


class RAGGenerator:
    """
    RAG response generator with multiple strategies.
    """

    def __init__(
        self,
        llm,
        context_injector: ContextInjector,
        citation_extractor: Optional['CitationExtractor'] = None
    ):
        self.llm = llm
        self.context_injector = context_injector
        self.citation_extractor = citation_extractor or CitationExtractor()

    def generate(
        self,
        query: str,
        documents: List[Document],
        prompt_style: str = "cited_qa"
    ) -> GenerationResult:
        """Generate response with retrieved context."""
        # Format context
        context = self.context_injector.format_context(documents, query)

        # Select prompt template
        prompt_fn = getattr(RAGPromptTemplate, prompt_style, RAGPromptTemplate.basic_qa)
        prompt = prompt_fn(context, query)

        # Generate response
        response = self.llm.generate(prompt)

        # Extract citations
        citations = self.citation_extractor.extract(response, documents)

        # Estimate confidence based on citation coverage
        confidence = self._estimate_confidence(response, citations, documents)

        return GenerationResult(
            answer=response,
            citations=citations,
            confidence=confidence
        )

    def _estimate_confidence(
        self,
        response: str,
        citations: List[Dict[str, Any]],
        documents: List[Document]
    ) -> float:
        """Estimate confidence based on grounding."""
        if not response or "don't have enough information" in response.lower():
            return 0.3

        # Higher confidence with more citations
        citation_score = min(1.0, len(citations) / 3)

        # Check for hedging language
        hedging_phrases = [
            "might", "could", "possibly", "perhaps",
            "I'm not sure", "it seems", "appears to"
        ]
        hedging_count = sum(1 for p in hedging_phrases if p in response.lower())
        hedging_penalty = min(0.3, hedging_count * 0.1)

        confidence = 0.5 + (citation_score * 0.4) - hedging_penalty
        return max(0.1, min(1.0, confidence))


class CitationExtractor:
    """Extract and validate citations from generated text."""

    def extract(
        self,
        response: str,
        source_documents: List[Document]
    ) -> List[Dict[str, Any]]:
        """Extract citations from response."""
        citations = []

        # Find citation patterns: [1], [Source 1], (1), etc.
        patterns = [
            r'\[(\d+)\]',
            r'\[Source (\d+)\]',
            r'\((\d+)\)',
            r'Source (\d+)',
        ]

        cited_indices = set()
        for pattern in patterns:
            matches = re.findall(pattern, response)
            for match in matches:
                try:
                    idx = int(match) - 1  # Convert to 0-indexed
                    cited_indices.add(idx)
                except ValueError:
                    continue

        # Build citation objects
        for idx in sorted(cited_indices):
            if 0 <= idx < len(source_documents):
                doc = source_documents[idx]
                citations.append({
                    'index': idx + 1,
                    'source': doc.metadata.get('source', f'Document {idx + 1}'),
                    'doc_id': doc.doc_id,
                    'metadata': doc.metadata
                })

        return citations

    def validate_citations(
        self,
        response: str,
        citations: List[Dict[str, Any]],
        source_documents: List[Document]
    ) -> Dict[str, Any]:
        """Validate that citations support the claims."""
        validation_result = {
            'valid': True,
            'issues': [],
            'coverage': 0.0
        }

        # Check if cited sources exist
        for citation in citations:
            idx = citation['index'] - 1
            if idx >= len(source_documents):
                validation_result['issues'].append(
                    f"Citation [{citation['index']}] references non-existent source"
                )
                validation_result['valid'] = False

        # Estimate how much of response is cited
        sentences = re.split(r'[.!?]+', response)
        cited_sentences = sum(
            1 for s in sentences
            if any(f'[{c["index"]}]' in s for c in citations)
        )
        validation_result['coverage'] = cited_sentences / max(1, len(sentences))

        return validation_result
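
Putting the template, injector, and citation extractor together, a minimal generation call might look like the sketch below. The `llm` object is hypothetical, and `docs` is the list of retrieved documents from the earlier injector sketch.

PYTHON
# Hypothetical `llm`; `docs` comes from the retrieval stage.
generator = RAGGenerator(
    llm=llm,
    context_injector=StructuredContextInjector(),
    citation_extractor=CitationExtractor()
)

result = generator.generate(
    query="How does reranking improve RAG quality?",
    documents=docs,
    prompt_style="cited_qa"  # selects RAGPromptTemplate.cited_qa
)

print(result.answer)
print("Citations:", [c['source'] for c in result.citations])
print("Estimated confidence:", round(result.confidence, 2))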

Hallucination Reduction

Techniques to minimize unfaithful generation:

PYTHON
import numpy as np

class HallucinationReducer:
    """
    Techniques to reduce hallucination in RAG responses.

    Hallucination types:
    1. Intrinsic: Contradicts source documents
    2. Extrinsic: Adds unsupported information
    """

    def __init__(self, llm, embedding_model):
        self.llm = llm
        self.embedding_model = embedding_model

    def verify_response(
        self,
        response: str,
        source_documents: List[Document],
        query: str
    ) -> Dict[str, Any]:
        """Verify response against source documents."""
        verification = {
            'is_grounded': True,
            'unsupported_claims': [],
            'contradictions': [],
            'confidence': 1.0
        }

        # Extract claims from response
        claims = self._extract_claims(response)

        # Verify each claim
        for claim in claims:
            support = self._find_support(claim, source_documents)

            if support['status'] == 'unsupported':
                verification['unsupported_claims'].append({
                    'claim': claim,
                    'reason': support['reason']
                })
                verification['is_grounded'] = False
            elif support['status'] == 'contradicted':
                verification['contradictions'].append({
                    'claim': claim,
                    'contradicting_text': support['evidence']
                })
                verification['is_grounded'] = False

        # Calculate confidence
        n_claims = len(claims)
        n_issues = len(verification['unsupported_claims']) + len(verification['contradictions'])
        verification['confidence'] = 1 - (n_issues / max(1, n_claims))

        return verification

    def _extract_claims(self, response: str) -> List[str]:
        """Extract factual claims from response."""
        prompt = f"""Extract the factual claims from the following text.
List each distinct claim on a separate line.
Only include verifiable factual statements, not opinions or hedged statements.

Text: {response}

Claims:"""

        claims_text = self.llm.generate(prompt)
        claims = [c.strip() for c in claims_text.split('\n') if c.strip()]
        # Remove numbering
        claims = [re.sub(r'^\d+[\.\)]\s*', '', c) for c in claims]

        return claims

    def _find_support(
        self,
        claim: str,
        documents: List[Document]
    ) -> Dict[str, Any]:
        """Find support for a claim in documents."""
        # Embed claim
        claim_embedding = self.embedding_model.embed_query(claim)

        # Find most similar passages
        best_score = 0
        best_passage = ""

        for doc in documents:
            # Simple chunking for comparison
            passages = doc.content.split('\n\n')
            for passage in passages:
                if len(passage) < 20:
                    continue

                passage_embedding = self.embedding_model.embed_query(passage)
                # Dot product equals cosine similarity when embeddings are unit-normalized
                score = np.dot(claim_embedding, passage_embedding)

                if score > best_score:
                    best_score = score
                    best_passage = passage

        # Check support level
        if best_score < 0.5:
            return {
                'status': 'unsupported',
                'reason': 'No similar passage found in documents',
                'score': best_score
            }

        # Use LLM to verify support
        verification_prompt = f"""Does the passage support, contradict, or neither support nor contradict the claim?

Claim: {claim}

Passage: {best_passage}

Answer with one of: SUPPORTS, CONTRADICTS, NEITHER
Then briefly explain why.

Answer:"""

        verification = self.llm.generate(verification_prompt)

        if 'CONTRADICTS' in verification.upper():
            return {
                'status': 'contradicted',
                'evidence': best_passage,
                'score': best_score
            }
        elif 'SUPPORTS' in verification.upper():
            return {
                'status': 'supported',
                'evidence': best_passage,
                'score': best_score
            }
        else:
            return {
                'status': 'unsupported',
                'reason': 'Passage does not clearly support claim',
                'score': best_score
            }


class SelfConsistencyChecker:
    """
    Self-consistency checking through multiple generations.

    Generate multiple responses and check for consistency.
    Inconsistent answers indicate uncertainty or hallucination.
    """

    def __init__(self, llm, n_samples: int = 3):
        self.llm = llm
        self.n_samples = n_samples

    def generate_with_consistency(
        self,
        prompt: str,
        temperature: float = 0.7
    ) -> Dict[str, Any]:
        """Generate multiple responses and check consistency."""
        responses = []

        for _ in range(self.n_samples):
            response = self.llm.generate(prompt, temperature=temperature)
            responses.append(response)

        # Check consistency
        consistency = self._check_consistency(responses)

        # Select best response
        if consistency['is_consistent']:
            best_response = responses[0]
        else:
            best_response = self._select_majority(responses)

        return {
            'response': best_response,
            'all_responses': responses,
            'consistency': consistency
        }

    def _check_consistency(self, responses: List[str]) -> Dict[str, Any]:
        """Check if responses are consistent with each other."""
        prompt = f"""Compare these responses and determine if they are consistent
(conveying the same core information) or inconsistent (contradicting each other).

Response 1: {responses[0][:500]}

Response 2: {responses[1][:500] if len(responses) > 1 else 'N/A'}

Response 3: {responses[2][:500] if len(responses) > 2 else 'N/A'}

Are these responses consistent? Answer YES or NO, then explain any inconsistencies.

Answer:"""

        analysis = self.llm.generate(prompt)
        is_consistent = analysis.strip().upper().startswith('YES')

        return {
            'is_consistent': is_consistent,
            'analysis': analysis
        }

    def _select_majority(self, responses: List[str]) -> str:
        """Select response that aligns with majority."""
        # Placeholder heuristic: a full implementation would cluster responses
        # and pick the majority answer; here we return the shortest response,
        # which tends to be the most focused.
        return min(responses, key=len)


class ChainOfVerification:
    """
    Chain-of-Verification (CoVe) for reducing hallucination.

    1. Generate initial response
    2. Generate verification questions
    3. Answer questions independently
    4. Revise response based on verified facts
    """

    def __init__(self, llm):
        self.llm = llm

    def generate_with_verification(
        self,
        context: str,
        question: str
    ) -> Dict[str, Any]:
        """Generate response with chain of verification."""
        # Step 1: Generate initial response
        initial_prompt = f"""Answer the question based on the context.

Context: {context}

Question: {question}

Answer:"""

        initial_response = self.llm.generate(initial_prompt)

        # Step 2: Generate verification questions
        verification_questions = self._generate_verification_questions(
            initial_response, question
        )

        # Step 3: Answer verification questions
        verified_facts = []
        for vq in verification_questions:
            answer = self._answer_verification_question(vq, context)
            verified_facts.append({
                'question': vq,
                'answer': answer
            })

        # Step 4: Revise response
        final_response = self._revise_response(
            initial_response,
            verified_facts,
            context,
            question
        )

        return {
            'initial_response': initial_response,
            'verification_questions': verification_questions,
            'verified_facts': verified_facts,
            'final_response': final_response
        }

    def _generate_verification_questions(
        self,
        response: str,
        original_question: str
    ) -> List[str]:
        """Generate questions to verify claims in response."""
        prompt = f"""Given this response to a question, generate specific factual questions
that can verify the accuracy of the claims made.

Original question: {original_question}
Response: {response}

Generate 3-5 verification questions (one per line):"""

        questions_text = self.llm.generate(prompt)
        questions = [q.strip() for q in questions_text.split('\n') if q.strip()]
        questions = [re.sub(r'^\d+[\.\)]\s*', '', q) for q in questions]

        return questions[:5]

    def _answer_verification_question(
        self,
        question: str,
        context: str
    ) -> str:
        """Answer verification question using only context."""
        prompt = f"""Answer this question using ONLY the provided context.
If the context doesn't contain the answer, say "Not found in context."

Context: {context}

Question: {question}

Answer:"""

        return self.llm.generate(prompt)

    def _revise_response(
        self,
        initial_response: str,
        verified_facts: List[Dict[str, str]],
        context: str,
        original_question: str
    ) -> str:
        """Revise response based on verified facts."""
        facts_str = "\n".join([
            f"Q: {f['question']}\nA: {f['answer']}"
            for f in verified_facts
        ])

        prompt = f"""Revise the initial response based on the verified facts below.
Remove or correct any information that contradicts the verified facts.
Only include information that is supported by the context.

Original question: {original_question}

Initial response: {initial_response}

Verified facts:
{facts_str}

Context: {context[:2000]}

Revised response:"""

        return self.llm.generate(prompt)
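
The sketch below shows where these verifiers slot in after generation. The `llm`, `embedding_model`, `docs`, and draft `answer` are hypothetical inputs produced by the earlier pipeline stages.

PYTHON
# Hypothetical `llm`, `embedding_model`, retrieved `docs`, and draft `answer`.
reducer = HallucinationReducer(llm=llm, embedding_model=embedding_model)
report = reducer.verify_response(answer, docs, query="How does reranking help?")

if not report['is_grounded']:
    print("Unsupported claims:", report['unsupported_claims'])
    print("Contradictions:", report['contradictions'])
print("Grounding confidence:", round(report['confidence'], 2))

# Chain-of-Verification: draft, verify, and revise in one call.
cove = ChainOfVerification(llm=llm)
cove_result = cove.generate_with_verification(
    context="\n\n".join(d.content for d in docs),
    question="How does reranking help?"
)
print(cove_result['final_response'])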

Long Context Handling

Strategies for handling large amounts of retrieved content:

PYTHON
class LongContextHandler:
    """
    Handle long contexts that exceed model limits.

    Strategies:
    1. Truncation with relevance ranking
    2. Map-reduce summarization
    3. Iterative refinement
    4. Hierarchical processing
    """

    def __init__(self, llm, max_context_length: int = 4000):
        self.llm = llm
        self.max_context_length = max_context_length

    def process_long_context(
        self,
        documents: List[Document],
        query: str,
        strategy: str = "map_reduce"
    ) -> str:
        """Process documents that exceed context length."""
        total_length = sum(len(doc.content) for doc in documents)

        if total_length <= self.max_context_length:
            # Fits in context
            return "\n\n".join([doc.content for doc in documents])

        if strategy == "truncate":
            return self._truncate_strategy(documents, query)
        elif strategy == "map_reduce":
            return self._map_reduce_strategy(documents, query)
        elif strategy == "refine":
            return self._refine_strategy(documents, query)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def _truncate_strategy(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """Truncate to fit, prioritizing relevant content."""
        # Sort by score if available
        sorted_docs = sorted(
            documents,
            key=lambda d: d.score or 0,
            reverse=True
        )

        context_parts = []
        current_length = 0

        for doc in sorted_docs:
            if current_length + len(doc.content) <= self.max_context_length:
                context_parts.append(doc.content)
                current_length += len(doc.content)
            else:
                # Add partial content
                remaining = self.max_context_length - current_length
                if remaining > 200:
                    context_parts.append(doc.content[:remaining] + "...")
                break

        return "\n\n".join(context_parts)

    def _map_reduce_strategy(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """
        Map-Reduce: Summarize each document, then combine.

        Map: Extract relevant information from each document
        Reduce: Combine summaries into final context
        """
        # Map phase: summarize each document
        summaries = []
        for i, doc in enumerate(documents):
            summary = self._summarize_for_query(doc.content, query)
            summaries.append(f"[Source {i+1}] {summary}")

        # Check if summaries fit
        combined = "\n\n".join(summaries)
        if len(combined) <= self.max_context_length:
            return combined

        # Reduce phase: combine summaries
        return self._reduce_summaries(summaries, query)

    def _summarize_for_query(self, content: str, query: str) -> str:
        """Summarize document focused on query."""
        prompt = f"""Summarize the following text, focusing on information relevant to this question: {query}

Text: {content[:3000]}

Relevant summary:"""

        return self.llm.generate(prompt)

    def _reduce_summaries(self, summaries: List[str], query: str) -> str:
        """Combine summaries into final context."""
        combined = "\n\n".join(summaries)

        prompt = f"""Combine these summaries into a comprehensive context for answering the question.
Remove redundant information while keeping all unique relevant facts.

Question: {query}

Summaries:
{combined[:6000]}

Combined context:"""

        return self.llm.generate(prompt)

    def _refine_strategy(
        self,
        documents: List[Document],
        query: str
    ) -> str:
        """
        Iterative refinement: build answer incrementally.

        Process one document at a time, refining the answer.
        """
        current_answer = ""

        for i, doc in enumerate(documents):
            prompt = f"""Question: {query}

Previous answer: {current_answer if current_answer else "None yet"}

New document [{i+1}]:
{doc.content[:2000]}

Based on this new document, update and improve the answer.
Keep information from previous answer if still relevant.
Add new information from this document.

Updated answer:"""

            current_answer = self.llm.generate(prompt)

        return current_answer


class ContextWindowOptimizer:
    """Optimize context usage within model limits."""

    def __init__(
        self,
        max_tokens: int = 4096,
        reserved_for_response: int = 500,
        reserved_for_prompt: int = 500
    ):
        self.max_tokens = max_tokens
        self.available_for_context = max_tokens - reserved_for_response - reserved_for_prompt

    def optimize_context(
        self,
        documents: List[Document],
        query: str
    ) -> Tuple[List[Document], Dict[str, Any]]:
        """Select and order documents to maximize utility."""
        # Estimate tokens (rough: 1 token ≈ 4 chars)
        def estimate_tokens(text: str) -> int:
            return len(text) // 4

        selected = []
        total_tokens = 0
        stats = {'included': 0, 'excluded': 0, 'truncated': 0}

        # Sort by score
        sorted_docs = sorted(documents, key=lambda d: d.score or 0, reverse=True)

        for doc in sorted_docs:
            doc_tokens = estimate_tokens(doc.content)

            if total_tokens + doc_tokens <= self.available_for_context:
                selected.append(doc)
                total_tokens += doc_tokens
                stats['included'] += 1
            elif total_tokens < self.available_for_context * 0.9:
                # Truncate to fit
                available = (self.available_for_context - total_tokens) * 4
                truncated_doc = Document(
                    content=doc.content[:available] + "...",
                    metadata=doc.metadata,
                    doc_id=doc.doc_id,
                    score=doc.score
                )
                selected.append(truncated_doc)
                stats['truncated'] += 1
                break
            else:
                stats['excluded'] += 1

        stats['total_tokens'] = total_tokens
        stats['utilization'] = total_tokens / self.available_for_context

        return selected, stats
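
As a rough worked example of the token budget above: with `max_tokens=4096` and 500 tokens reserved for the prompt scaffold plus 500 for the response, about 3,096 tokens (roughly 12,000 characters at the ~4 characters-per-token heuristic) remain for retrieved context. The sketch below applies the optimizer to a hypothetical `docs` list of scored documents.

PYTHON
# Hypothetical scored `docs` list from the retrieval stage.
optimizer = ContextWindowOptimizer(
    max_tokens=4096,
    reserved_for_response=500,
    reserved_for_prompt=500
)  # leaves 4096 - 500 - 500 = 3096 tokens for context

selected_docs, stats = optimizer.optimize_context(docs, query="What is RAG?")

print(f"Included: {stats['included']}, truncated: {stats['truncated']}, "
      f"excluded: {stats['excluded']}")
print(f"Context tokens used: {stats['total_tokens']} "
      f"({stats['utilization']:.0%} of budget)")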

Key Takeaways

RAG generation quality depends heavily on prompt engineering and context handling. Effective strategies include: (1) structured context injection with clear source attribution, (2) explicit grounding instructions that prevent the model from using parametric knowledge, (3) citation requirements that enable verification, (4) hallucination reduction through chain-of-verification and self-consistency, and (5) long context handling via map-reduce or iterative refinement. The prompt should clearly instruct the model to acknowledge when context is insufficient rather than hallucinating. For production systems, implement verification pipelines that check generated claims against source documents, providing confidence scores and flagging potential hallucinations before returning responses to users.
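
One way to realize such a verification pipeline is to gate answers on a grounding check before returning them, as in the minimal sketch below. The `generator` and `reducer` objects are the hypothetical instances from the earlier sketches, and the 0.7 threshold is purely illustrative.

PYTHON
# Gate generated answers on a grounding check (threshold is illustrative).
def answer_with_verification(query, docs, threshold=0.7):
    result = generator.generate(query, docs, prompt_style="strict_grounding")
    report = reducer.verify_response(result.answer, docs, query)

    if not report['is_grounded'] or report['confidence'] < threshold:
        # Flag instead of returning a potentially hallucinated answer
        return {
            'answer': "I couldn't verify an answer from the available sources.",
            'flagged': True,
            'issues': report['unsupported_claims'] + report['contradictions'],
        }

    return {
        'answer': result.answer,
        'citations': result.citations,
        'flagged': False,
    }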

24.4 RAG Evaluation and Optimization Advanced

RAG Evaluation and Optimization

Evaluating RAG systems requires measuring both retrieval quality and generation faithfulness: a retrieval failure means the LLM lacks the necessary context, while a generation failure means the model ignores or misuses the retrieved content. This section covers comprehensive evaluation frameworks, metrics for each RAG component, optimization strategies, and debugging techniques for improving system performance.

Retrieval Evaluation Metrics

Measuring how well retrieval surfaces relevant documents:

PYTHON
import numpy as np
from typing import List, Dict, Set, Optional, Any, Tuple
from dataclasses import dataclass
from collections import defaultdict
import json

@dataclass
class RetrievalEvalResult:
    """Results from retrieval evaluation."""
    precision: float
    recall: float
    f1: float
    mrr: float  # Mean Reciprocal Rank
    ndcg: float  # Normalized Discounted Cumulative Gain
    hit_rate: float  # At least one relevant doc retrieved
    map_score: float  # Mean Average Precision


class RetrievalEvaluator:
    """
    Evaluate retrieval quality against ground truth.

    Metrics:
    - Precision@K: Fraction of retrieved docs that are relevant
    - Recall@K: Fraction of relevant docs that are retrieved
    - MRR: Average reciprocal rank of first relevant doc
    - NDCG: Ranking quality considering graded relevance
    - Hit Rate: Percentage of queries with at least one relevant result
    """

    def __init__(self, k: int = 10):
        self.k = k

    def evaluate(
        self,
        queries: List[str],
        retrieved_doc_ids: List[List[str]],
        relevant_doc_ids: List[Set[str]],
        relevance_scores: Optional[List[Dict[str, float]]] = None
    ) -> RetrievalEvalResult:
        """
        Evaluate retrieval performance.

        Args:
            queries: List of query strings
            retrieved_doc_ids: Retrieved doc IDs for each query (ranked)
            relevant_doc_ids: Ground truth relevant doc IDs for each query
            relevance_scores: Optional graded relevance (for NDCG)
        """
        precisions = []
        recalls = []
        mrrs = []
        ndcgs = []
        hits = []
        aps = []  # Average Precision per query

        for i, (retrieved, relevant) in enumerate(zip(retrieved_doc_ids, relevant_doc_ids)):
            retrieved_k = retrieved[:self.k]
            relevant_set = set(relevant)

            # Precision@K
            relevant_retrieved = sum(1 for d in retrieved_k if d in relevant_set)
            precision = relevant_retrieved / len(retrieved_k) if retrieved_k else 0
            precisions.append(precision)

            # Recall@K
            recall = relevant_retrieved / len(relevant_set) if relevant_set else 0
            recalls.append(recall)

            # MRR (reciprocal rank of first relevant)
            rr = 0
            for rank, doc_id in enumerate(retrieved_k, 1):
                if doc_id in relevant_set:
                    rr = 1 / rank
                    break
            mrrs.append(rr)

            # Hit Rate
            hits.append(1 if relevant_retrieved > 0 else 0)

            # Average Precision
            ap = self._average_precision(retrieved_k, relevant_set)
            aps.append(ap)

            # NDCG
            if relevance_scores and i < len(relevance_scores):
                scores = relevance_scores[i]
            else:
                # Binary relevance
                scores = {d: 1.0 for d in relevant_set}
            ndcg = self._ndcg(retrieved_k, scores)
            ndcgs.append(ndcg)

        precision = np.mean(precisions)
        recall = np.mean(recalls)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        return RetrievalEvalResult(
            precision=precision,
            recall=recall,
            f1=f1,
            mrr=np.mean(mrrs),
            ndcg=np.mean(ndcgs),
            hit_rate=np.mean(hits),
            map_score=np.mean(aps)
        )

    def _average_precision(
        self,
        retrieved: List[str],
        relevant: Set[str]
    ) -> float:
        """Calculate Average Precision."""
        if not relevant:
            return 0.0

        hits = 0
        sum_precisions = 0

        for rank, doc_id in enumerate(retrieved, 1):
            if doc_id in relevant:
                hits += 1
                precision_at_rank = hits / rank
                sum_precisions += precision_at_rank

        return sum_precisions / len(relevant)

    def _ndcg(
        self,
        retrieved: List[str],
        relevance_scores: Dict[str, float]
    ) -> float:
        """Calculate Normalized Discounted Cumulative Gain."""
        def dcg(scores: List[float]) -> float:
            return sum(
                (2 ** score - 1) / np.log2(rank + 2)
                for rank, score in enumerate(scores)
            )

        # DCG of retrieved ranking
        retrieved_scores = [relevance_scores.get(d, 0) for d in retrieved]
        actual_dcg = dcg(retrieved_scores)

        # Ideal DCG (best possible ranking)
        ideal_scores = sorted(relevance_scores.values(), reverse=True)[:len(retrieved)]
        ideal_dcg = dcg(ideal_scores)

        return actual_dcg / ideal_dcg if ideal_dcg > 0 else 0


class ContextRelevanceEvaluator:
    """
    Evaluate relevance of retrieved context to the query.

    Uses LLM to judge whether retrieved content is useful
    for answering the query.
    """

    def __init__(self, llm):
        self.llm = llm

    def evaluate_relevance(
        self,
        query: str,
        retrieved_contexts: List[str]
    ) -> Dict[str, Any]:
        """Evaluate relevance of each retrieved context."""
        results = {
            'per_context_scores': [],
            'overall_relevance': 0.0,
            'relevant_count': 0
        }

        for i, context in enumerate(retrieved_contexts):
            score = self._score_relevance(query, context)
            results['per_context_scores'].append({
                'index': i,
                'score': score,
                'is_relevant': score >= 0.5
            })

            if score >= 0.5:
                results['relevant_count'] += 1

        if results['per_context_scores']:
            results['overall_relevance'] = np.mean([
                r['score'] for r in results['per_context_scores']
            ])

        return results

    def _score_relevance(self, query: str, context: str) -> float:
        """Score relevance of context to query."""
        prompt = f"""Rate how relevant the following context is for answering the question.
Score from 0 to 1:
- 0: Completely irrelevant
- 0.5: Somewhat relevant, contains related information
- 1: Highly relevant, directly addresses the question

Question: {query}

Context: {context[:1500]}

Score (0-1):"""

        response = self.llm.generate(prompt)

        try:
            import re
            match = re.search(r'([0-9]*\.?[0-9]+)', response)
            if match:
                score = float(match.group(1))
                return min(1.0, max(0.0, score))
        except ValueError:
            pass

        return 0.5  # Default if parsing fails
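
A tiny worked example makes the retrieval metrics concrete. With relevant set {d1, d3} and ranked retrieval [d2, d1, d4]: precision@3 = 1/3, recall@3 = 1/2, and MRR = 1/2, since the first relevant document appears at rank 2. The sketch below reproduces this with the evaluator (the document IDs are made up).

PYTHON
# Toy example: one query with ground-truth relevant docs {d1, d3}.
evaluator = RetrievalEvaluator(k=3)

result = evaluator.evaluate(
    queries=["what is rag?"],
    retrieved_doc_ids=[["d2", "d1", "d4"]],  # system's ranked output
    relevant_doc_ids=[{"d1", "d3"}],         # ground truth
)

print(f"P@3={result.precision:.2f}  R@3={result.recall:.2f}  "
      f"MRR={result.mrr:.2f}  hit_rate={result.hit_rate:.2f}")
# Expected: P@3=0.33  R@3=0.50  MRR=0.50  hit_rate=1.00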

Generation Evaluation Metrics

Measuring answer quality, faithfulness, and completeness:

PYTHON
class GenerationEvaluator:
    """
    Evaluate RAG generation quality.

    Key dimensions:
    - Faithfulness: Does answer stick to retrieved context?
    - Relevance: Does answer address the question?
    - Completeness: Does answer cover all relevant aspects?
    - Coherence: Is the answer well-structured and clear?
    """

    def __init__(self, llm, embedding_model=None):
        self.llm = llm
        self.embedding_model = embedding_model

    def evaluate(
        self,
        query: str,
        answer: str,
        contexts: List[str],
        reference_answer: Optional[str] = None
    ) -> Dict[str, float]:
        """Comprehensive generation evaluation."""
        results = {
            'faithfulness': self._evaluate_faithfulness(answer, contexts),
            'relevance': self._evaluate_relevance(query, answer),
            'completeness': self._evaluate_completeness(query, answer, contexts),
            'coherence': self._evaluate_coherence(answer)
        }

        if reference_answer:
            results['correctness'] = self._evaluate_correctness(
                answer, reference_answer
            )

        # Overall score (weighted average)
        weights = {
            'faithfulness': 0.35,
            'relevance': 0.30,
            'completeness': 0.20,
            'coherence': 0.15
        }

        results['overall'] = sum(
            results[k] * weights.get(k, 0.2)
            for k in ['faithfulness', 'relevance', 'completeness', 'coherence']
        )

        return results

    def _evaluate_faithfulness(
        self,
        answer: str,
        contexts: List[str]
    ) -> float:
        """
        Evaluate if answer is faithful to retrieved context.

        Checks for:
        - Claims not supported by context (hallucination)
        - Contradictions with context
        """
        combined_context = "\n\n".join(contexts)

        prompt = f"""Evaluate whether the answer is faithful to the provided context.

The answer should ONLY contain information that can be inferred from the context.
Penalize for:
- Claims not supported by the context
- Contradictions with the context
- Made-up facts not in the context

Context:
{combined_context[:3000]}

Answer to evaluate:
{answer}

Score the faithfulness from 0 to 1:
- 0: Contains significant hallucinations or contradictions
- 0.5: Mostly faithful with minor unsupported claims
- 1: Completely faithful to the context

Score:"""

        response = self.llm.generate(prompt)
        return self._extract_score(response)

    def _evaluate_relevance(self, query: str, answer: str) -> float:
        """Evaluate if answer addresses the question."""
        prompt = f"""Evaluate whether the answer addresses the question.

Question: {query}

Answer: {answer}

Score the relevance from 0 to 1:
- 0: Answer doesn't address the question at all
- 0.5: Answer partially addresses the question
- 1: Answer directly and fully addresses the question

Score:"""

        response = self.llm.generate(prompt)
        return self._extract_score(response)

    def _evaluate_completeness(
        self,
        query: str,
        answer: str,
        contexts: List[str]
    ) -> float:
        """Evaluate if answer covers all relevant information."""
        combined_context = "\n\n".join(contexts)

        prompt = f"""Evaluate the completeness of the answer.

Question: {query}

Available context:
{combined_context[:2500]}

Answer: {answer}

Does the answer cover all relevant information from the context?
Score from 0 to 1:
- 0: Misses most relevant information
- 0.5: Covers some but not all relevant information
- 1: Comprehensively covers all relevant information

Score:"""

        response = self.llm.generate(prompt)
        return self._extract_score(response)

    def _evaluate_coherence(self, answer: str) -> float:
        """Evaluate answer coherence and clarity."""
        prompt = f"""Evaluate the coherence and clarity of this answer.

Answer: {answer}

Score from 0 to 1:
- 0: Incoherent, poorly structured, hard to understand
- 0.5: Somewhat clear but could be better organized
- 1: Well-structured, clear, and easy to understand

Score:"""

        response = self.llm.generate(prompt)
        return self._extract_score(response)

    def _evaluate_correctness(
        self,
        answer: str,
        reference: str
    ) -> float:
        """Compare answer to reference answer."""
        prompt = f"""Compare the generated answer to the reference answer.

Reference answer: {reference}

Generated answer: {answer}

Score the correctness from 0 to 1:
- 0: Completely incorrect or contradicts reference
- 0.5: Partially correct, some matching information
- 1: Fully correct, matches reference semantically

Score:"""

        response = self.llm.generate(prompt)
        return self._extract_score(response)

    def _extract_score(self, response: str) -> float:
        """Extract numeric score from LLM response."""
        import re
        try:
            match = re.search(r'([0-9]*\.?[0-9]+)', response)
            if match:
                score = float(match.group(1))
                return min(1.0, max(0.0, score))
        except Exception:
            pass
        return 0.5


class RAGASEvaluator:
    """
    RAGAS-style evaluation framework.

    RAGAS metrics:
    - Faithfulness: Answer grounded in context
    - Answer Relevancy: Answer addresses question
    - Context Relevancy: Retrieved context is relevant
    - Context Recall: Important info retrieved
    """

    def __init__(self, llm, embedding_model):
        self.llm = llm
        self.embedding_model = embedding_model

    def evaluate(
        self,
        question: str,
        answer: str,
        contexts: List[str],
        ground_truth: Optional[str] = None
    ) -> Dict[str, float]:
        """Run RAGAS evaluation."""
        results = {}

        # Faithfulness
        results['faithfulness'] = self._faithfulness(answer, contexts)

        # Answer Relevancy
        results['answer_relevancy'] = self._answer_relevancy(question, answer)

        # Context Relevancy
        results['context_relevancy'] = self._context_relevancy(question, contexts)

        # Context Recall (requires ground truth)
        if ground_truth:
            results['context_recall'] = self._context_recall(
                ground_truth, contexts
            )

        # Aggregate score
        scores = list(results.values())
        results['ragas_score'] = np.mean(scores) if scores else 0.0

        return results

    def _faithfulness(self, answer: str, contexts: List[str]) -> float:
        """
        Faithfulness: All claims in answer supported by context.

        1. Extract claims from answer
        2. Check each claim against context
        3. Score = supported_claims / total_claims
        """
        # Extract claims
        claims = self._extract_claims(answer)

        if not claims:
            return 1.0  # No claims = nothing to verify

        # Verify each claim
        supported = 0
        combined_context = "\n\n".join(contexts)

        for claim in claims:
            if self._is_supported(claim, combined_context):
                supported += 1

        return supported / len(claims)

    def _extract_claims(self, answer: str) -> List[str]:
        """Extract factual claims from answer."""
        prompt = f"""Extract all factual claims from this text.
Return each claim on a separate line.
Only include verifiable factual statements.

Text: {answer}

Claims:"""

        response = self.llm.generate(prompt)
        claims = [c.strip() for c in response.split('\n') if c.strip()]
        return claims[:10]  # Limit to 10 claims

    def _is_supported(self, claim: str, context: str) -> bool:
        """Check if claim is supported by context."""
        prompt = f"""Is the following claim supported by the context?

Claim: {claim}

Context: {context[:2500]}

Answer YES if the claim is directly supported or can be inferred from the context.
Answer NO if the claim is not supported or contradicted.

Answer (YES/NO):"""

        response = self.llm.generate(prompt)
        return response.strip().upper().startswith('YES')

    def _answer_relevancy(self, question: str, answer: str) -> float:
        """
        Answer Relevancy: Does answer address the question?

        Generate questions from answer and compare to original.
        """
        # Generate questions that would be answered by this response
        prompt = f"""Generate 3 questions that this answer would address:

Answer: {answer}

Questions (one per line):"""

        response = self.llm.generate(prompt)
        generated_questions = [q.strip() for q in response.split('\n') if q.strip()][:3]

        if not generated_questions:
            return 0.5

        # Calculate similarity between original and generated questions
        original_embedding = self.embedding_model.embed_query(question)

        similarities = []
        for gq in generated_questions:
            gq_embedding = self.embedding_model.embed_query(gq)
            # Cosine similarity via dot product (assumes unit-normalized embeddings)
            sim = np.dot(original_embedding, gq_embedding)
            similarities.append(sim)

        return np.mean(similarities)

    def _context_relevancy(
        self,
        question: str,
        contexts: List[str]
    ) -> float:
        """
        Context Relevancy: Retrieved context is relevant to question.

        Measure: relevant_sentences / total_sentences
        """
        # For each context, extract relevant sentences
        total_sentences = 0
        relevant_sentences = 0

        for context in contexts:
            sentences = [s.strip() for s in context.split('.') if s.strip()]
            total_sentences += len(sentences)

            for sentence in sentences:
                if len(sentence) < 10:
                    continue

                prompt = f"""Is this sentence relevant to answering the question?

Question: {question}
Sentence: {sentence}

Answer YES or NO:"""

                response = self.llm.generate(prompt)
                if response.strip().upper().startswith('YES'):
                    relevant_sentences += 1

        return relevant_sentences / total_sentences if total_sentences > 0 else 0

    def _context_recall(
        self,
        ground_truth: str,
        contexts: List[str]
    ) -> float:
        """
        Context Recall: Important info from ground truth is in context.

        Check if facts in ground truth can be found in contexts.
        """
        # Extract key facts from ground truth
        facts = self._extract_claims(ground_truth)

        if not facts:
            return 1.0

        combined_context = "\n\n".join(contexts)

        found = sum(
            1 for fact in facts
            if self._is_supported(fact, combined_context)
        )

        return found / len(facts)
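
Before wiring these evaluators into a pipeline, it helps to see the expected call pattern. Below is a minimal sketch that assumes only the generate(prompt) -> str interface used throughout this section; StubLLM is a hypothetical stand-in that always returns a fixed score, so the numbers are illustrative rather than real judgments.

PYTHON
class StubLLM:
    """Hypothetical LLM client that always answers with a fixed score."""

    def generate(self, prompt: str) -> str:
        return "0.8"


stub_llm = StubLLM()
evaluator = GenerationEvaluator(llm=stub_llm)

scores = evaluator.evaluate(
    query="What is retrieval-augmented generation?",
    answer="RAG combines an LLM with retrieved documents to ground its answers.",
    contexts=["RAG augments language models with retrieved evidence..."],
    reference_answer=None
)

# With the stub every judged dimension is 0.8, so the weighted overall
# score is also 0.8; with a real LLM the dimensions will differ.
print(scores)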

End-to-End Evaluation

Comprehensive evaluation of the entire RAG pipeline:

PYTHON
from collections import defaultdict
from typing import Set

class RAGPipelineEvaluator:
    """
    End-to-end RAG pipeline evaluation.

    Evaluates:
    1. Retrieval quality
    2. Context utilization
    3. Generation quality
    4. Overall system performance
    """

    def __init__(
        self,
        rag_system,
        retrieval_evaluator: RetrievalEvaluator,
        generation_evaluator: GenerationEvaluator
    ):
        self.rag_system = rag_system
        self.retrieval_evaluator = retrieval_evaluator
        self.generation_evaluator = generation_evaluator

    def evaluate_dataset(
        self,
        eval_dataset: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Evaluate RAG system on a dataset.

        Dataset format:
        [
            {
                'question': 'What is...',
                'ground_truth_answer': 'The answer is...',
                'relevant_doc_ids': ['doc1', 'doc2']
            },
            ...
        ]
        """
        results = {
            'retrieval': defaultdict(list),
            'generation': defaultdict(list),
            'e2e': defaultdict(list),
            'examples': []
        }

        for item in eval_dataset:
            question = item['question']
            ground_truth = item.get('ground_truth_answer')
            relevant_docs = set(item.get('relevant_doc_ids', []))

            # Run RAG
            answer, retrieved_results = self.rag_system.query(question)

            # Retrieval evaluation
            retrieved_ids = [r.document.doc_id for r in retrieved_results]
            retrieved_contexts = [r.document.content for r in retrieved_results]

            if relevant_docs:
                ret_metrics = self._evaluate_retrieval(retrieved_ids, relevant_docs)
                for k, v in ret_metrics.items():
                    results['retrieval'][k].append(v)

            # Generation evaluation
            gen_metrics = self.generation_evaluator.evaluate(
                question, answer, retrieved_contexts, ground_truth
            )
            for k, v in gen_metrics.items():
                results['generation'][k].append(v)

            # End-to-end correctness
            if ground_truth:
                e2e_correct = self._evaluate_e2e_correctness(answer, ground_truth)
                results['e2e']['correctness'].append(e2e_correct)

            # Store example
            results['examples'].append({
                'question': question,
                'answer': answer,
                'ground_truth': ground_truth,
                'n_retrieved': len(retrieved_results),
                'gen_scores': gen_metrics
            })

        # Aggregate results
        summary = {
            'retrieval': {k: np.mean(v) for k, v in results['retrieval'].items()},
            'generation': {k: np.mean(v) for k, v in results['generation'].items()},
            'e2e': {k: np.mean(v) for k, v in results['e2e'].items()},
            'n_examples': len(eval_dataset)
        }

        return {'summary': summary, 'detailed': results}

    def _evaluate_retrieval(
        self,
        retrieved_ids: List[str],
        relevant_ids: Set[str]
    ) -> Dict[str, float]:
        """Evaluate retrieval for single query."""
        retrieved_set = set(retrieved_ids[:10])

        precision = len(retrieved_set & relevant_ids) / len(retrieved_set) if retrieved_set else 0
        recall = len(retrieved_set & relevant_ids) / len(relevant_ids) if relevant_ids else 0

        # MRR
        mrr = 0
        for rank, doc_id in enumerate(retrieved_ids, 1):
            if doc_id in relevant_ids:
                mrr = 1 / rank
                break

        return {
            'precision': precision,
            'recall': recall,
            'mrr': mrr,
            'hit': 1 if (retrieved_set & relevant_ids) else 0
        }

    def _evaluate_e2e_correctness(
        self,
        answer: str,
        ground_truth: str
    ) -> float:
        """Evaluate end-to-end answer correctness."""
        prompt = f"""Compare the generated answer to the ground truth.

Ground truth: {ground_truth}

Generated answer: {answer}

Is the generated answer correct and complete?
Score from 0 to 1:
- 0: Incorrect or missing key information
- 0.5: Partially correct
- 1: Fully correct

Score:"""

        response = self.rag_system.llm.generate(prompt)

        try:
            import re
            match = re.search(r'([0-9]*\.?[0-9]+)', response)
            if match:
                return min(1.0, max(0.0, float(match.group(1))))
        except Exception:
            pass
        return 0.5


class ABTestEvaluator:
    """
    A/B testing framework for RAG configurations.

    Compare different:
    - Chunking strategies
    - Embedding models
    - Retrieval methods
    - Prompt templates
    """

    def __init__(self, evaluator: RAGPipelineEvaluator):
        self.evaluator = evaluator

    def compare_configs(
        self,
        config_a: Dict[str, Any],
        config_b: Dict[str, Any],
        eval_dataset: List[Dict[str, Any]],
        metrics: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Compare two RAG configurations.

        Returns statistical comparison of performance.
        """
        metrics = metrics or ['faithfulness', 'relevance', 'correctness']

        # Evaluate both configurations.
        # NOTE: this is a placeholder -- both calls below run against the same
        # underlying system. In practice, build a separate RAG system (and
        # evaluator) from config_a and config_b and evaluate each one.
        results_a = self.evaluator.evaluate_dataset(eval_dataset)
        results_b = self.evaluator.evaluate_dataset(eval_dataset)

        comparison = {
            'config_a': config_a,
            'config_b': config_b,
            'metrics': {}
        }

        for metric in metrics:
            scores_a = results_a['detailed']['generation'].get(metric, [])
            scores_b = results_b['detailed']['generation'].get(metric, [])

            if scores_a and scores_b:
                comparison['metrics'][metric] = {
                    'mean_a': np.mean(scores_a),
                    'mean_b': np.mean(scores_b),
                    'std_a': np.std(scores_a),
                    'std_b': np.std(scores_b),
                    'improvement': np.mean(scores_b) - np.mean(scores_a),
                    'significant': self._significance_test(scores_a, scores_b)
                }

        return comparison

    def _significance_test(
        self,
        scores_a: List[float],
        scores_b: List[float],
        alpha: float = 0.05
    ) -> bool:
        """Paired t-test for significance."""
        from scipy import stats

        if len(scores_a) != len(scores_b) or len(scores_a) < 3:
            return False

        _, p_value = stats.ttest_rel(scores_a, scores_b)
        return p_value < alpha
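
A minimal sketch of running the pipeline evaluator over a small dataset, assuming rag_system, retrieval_evaluator, and generation_evaluator instances have already been constructed as in the earlier listings; the questions and document IDs are purely illustrative.

PYTHON
eval_dataset = [
    {
        'question': 'Which embedding model does the onboarding guide recommend?',
        'ground_truth_answer': 'It recommends text-embedding-3-small.',
        'relevant_doc_ids': ['guide_12']
    },
    {
        'question': 'How often is the security policy reviewed?',
        'ground_truth_answer': 'Annually, at the start of each fiscal year.',
        'relevant_doc_ids': ['policy_03', 'policy_04']
    },
]

pipeline_evaluator = RAGPipelineEvaluator(
    rag_system=rag_system,
    retrieval_evaluator=retrieval_evaluator,
    generation_evaluator=generation_evaluator
)

report = pipeline_evaluator.evaluate_dataset(eval_dataset)

# Mean precision/recall/MRR/hit over queries that provided relevant_doc_ids
print(report['summary']['retrieval'])
# Mean faithfulness, relevance, completeness, coherence, correctness, overall
print(report['summary']['generation'])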

Optimization Strategies

Techniques to improve RAG performance:

PYTHON
class ChunkingOptimizer:
    """
    Optimize chunking parameters for RAG quality.

    Parameters to optimize:
    - Chunk size
    - Chunk overlap
    - Chunking strategy (fixed, recursive, semantic)
    """

    def __init__(
        self,
        rag_system_builder,
        evaluator: RAGPipelineEvaluator
    ):
        self.build_system = rag_system_builder
        self.evaluator = evaluator

    def grid_search(
        self,
        documents: List[str],
        eval_dataset: List[Dict[str, Any]],
        chunk_sizes: List[int] = [256, 512, 1024, 2048],
        overlaps: List[int] = [0, 50, 100, 200]
    ) -> Dict[str, Any]:
        """Grid search over chunking parameters."""
        results = []

        for chunk_size in chunk_sizes:
            for overlap in overlaps:
                if overlap >= chunk_size:
                    continue

                # Build system with these parameters
                rag_system = self.build_system(
                    documents,
                    chunk_size=chunk_size,
                    chunk_overlap=overlap
                )

                # Evaluate
                eval_results = self.evaluator.evaluate_dataset(eval_dataset)

                results.append({
                    'chunk_size': chunk_size,
                    'overlap': overlap,
                    'retrieval_recall': eval_results['summary']['retrieval'].get('recall', 0),
                    'faithfulness': eval_results['summary']['generation'].get('faithfulness', 0),
                    'overall': eval_results['summary']['generation'].get('overall', 0)
                })

        # Find best configuration
        best = max(results, key=lambda x: x['overall'])

        return {
            'all_results': results,
            'best_config': best,
            'recommendation': f"chunk_size={best['chunk_size']}, overlap={best['overlap']}"
        }


class RetrievalOptimizer:
    """Optimize retrieval parameters."""

    def __init__(self, rag_system, evaluator):
        self.rag_system = rag_system
        self.evaluator = evaluator

    def optimize_top_k(
        self,
        eval_dataset: List[Dict[str, Any]],
        k_values: List[int] = [3, 5, 10, 15, 20]
    ) -> Dict[str, Any]:
        """Find optimal number of retrieved documents."""
        results = []

        for k in k_values:
            self.rag_system.top_k = k
            eval_results = self.evaluator.evaluate_dataset(eval_dataset)

            results.append({
                'top_k': k,
                'recall': eval_results['summary']['retrieval'].get('recall', 0),
                'precision': eval_results['summary']['retrieval'].get('precision', 0),
                'faithfulness': eval_results['summary']['generation'].get('faithfulness', 0),
                'relevance': eval_results['summary']['generation'].get('relevance', 0)
            })

        # Analyze trade-offs
        analysis = self._analyze_k_tradeoffs(results)

        return {
            'results': results,
            'analysis': analysis
        }

    def _analyze_k_tradeoffs(self, results: List[Dict]) -> Dict[str, Any]:
        """Analyze recall vs precision trade-off."""
        # Find knee point (diminishing returns)
        recalls = [r['recall'] for r in results]
        ks = [r['top_k'] for r in results]

        # Simple heuristic: biggest improvement per k
        improvements = []
        for i in range(1, len(recalls)):
            improvement = (recalls[i] - recalls[i-1]) / (ks[i] - ks[i-1])
            improvements.append(improvement)

        # Find where improvement drops significantly
        if improvements:
            threshold = max(improvements) * 0.3
            for i, imp in enumerate(improvements):
                if imp < threshold:
                    recommended_k = ks[i]
                    break
            else:
                recommended_k = ks[-1]
        else:
            recommended_k = ks[0]

        return {
            'recommended_k': recommended_k,
            'reasoning': f"Diminishing returns observed after k={recommended_k}"
        }


class RAGDebugger:
    """Debug RAG pipeline failures."""

    def __init__(self, rag_system, llm):
        self.rag_system = rag_system
        self.llm = llm

    def diagnose_failure(
        self,
        question: str,
        expected_answer: str,
        actual_answer: str,
        retrieved_docs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Diagnose why RAG failed for a specific example."""
        diagnosis = {
            'question': question,
            'expected': expected_answer,
            'actual': actual_answer,
            'failure_type': None,
            'root_cause': None,
            'suggestions': []
        }

        # Check retrieval
        retrieval_quality = self._check_retrieval(question, expected_answer, retrieved_docs)

        if not retrieval_quality['has_relevant']:
            diagnosis['failure_type'] = 'retrieval_failure'
            diagnosis['root_cause'] = 'Relevant documents not retrieved'
            diagnosis['suggestions'] = [
                'Check if relevant documents exist in index',
                'Try different embedding model',
                'Consider query expansion',
                'Review chunking strategy'
            ]
        elif retrieval_quality['relevant_rank'] > 5:
            diagnosis['failure_type'] = 'ranking_failure'
            diagnosis['root_cause'] = f"Relevant doc at rank {retrieval_quality['relevant_rank']}"
            diagnosis['suggestions'] = [
                'Implement reranking',
                'Increase top_k',
                'Fine-tune embedding model'
            ]
        else:
            # Retrieval ok, check generation
            gen_failure = self._check_generation(
                question, expected_answer, actual_answer, retrieved_docs
            )
            diagnosis['failure_type'] = gen_failure['type']
            diagnosis['root_cause'] = gen_failure['cause']
            diagnosis['suggestions'] = gen_failure['suggestions']

        diagnosis['retrieval_analysis'] = retrieval_quality

        return diagnosis

    def _check_retrieval(
        self,
        question: str,
        expected_answer: str,
        retrieved_docs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Check if retrieval found relevant documents."""
        relevant_rank = None

        for i, doc in enumerate(retrieved_docs):
            # Check if doc content could answer the question
            prompt = f"""Does this document contain information to answer the question?

Question: {question}
Expected answer should include: {expected_answer[:200]}

Document: {doc.get('content', '')[:500]}

Answer YES or NO:"""

            response = self.llm.generate(prompt)
            if response.strip().upper().startswith('YES'):
                relevant_rank = i + 1
                break

        return {
            'has_relevant': relevant_rank is not None,
            'relevant_rank': relevant_rank,
            'n_retrieved': len(retrieved_docs)
        }

    def _check_generation(
        self,
        question: str,
        expected: str,
        actual: str,
        docs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Diagnose generation failure."""
        context = "\n".join([d.get('content', '')[:500] for d in docs[:3]])

        prompt = f"""Analyze why the generated answer differs from the expected answer.

Question: {question}
Expected answer: {expected}
Generated answer: {actual}
Context used: {context[:1000]}

What went wrong? Choose one:
1. HALLUCINATION - Answer includes information not in context
2. INCOMPLETE - Answer missing key information from context
3. MISUNDERSTOOD - Answer misinterpreted the question
4. WRONG_FOCUS - Answer focused on wrong aspect

Answer with the category and brief explanation:"""

        analysis = self.llm.generate(prompt)

        failure_types = {
            'HALLUCINATION': {
                'type': 'hallucination',
                'cause': 'Model added unsupported information',
                'suggestions': ['Strengthen grounding instructions', 'Add chain-of-verification']
            },
            'INCOMPLETE': {
                'type': 'incomplete',
                'cause': 'Model missed relevant context',
                'suggestions': ['Improve context formatting', 'Highlight key information']
            },
            'MISUNDERSTOOD': {
                'type': 'misunderstanding',
                'cause': 'Model misinterpreted the question',
                'suggestions': ['Rephrase question', 'Add clarification']
            },
            'WRONG_FOCUS': {
                'type': 'wrong_focus',
                'cause': 'Model focused on irrelevant aspect',
                'suggestions': ['Improve retrieval relevance', 'Add query decomposition']
            }
        }

        for key, value in failure_types.items():
            if key in analysis.upper():
                return value

        return {
            'type': 'unknown',
            'cause': analysis,
            'suggestions': ['Manual investigation needed']
        }
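
These optimizers plug into an existing system and evaluation set. A minimal sketch of tuning top-k, assuming the rag_system, pipeline_evaluator, and eval_dataset objects from the previous examples:

PYTHON
retrieval_optimizer = RetrievalOptimizer(rag_system, pipeline_evaluator)

k_report = retrieval_optimizer.optimize_top_k(
    eval_dataset,
    k_values=[3, 5, 10, 20]
)

for row in k_report['results']:
    print(f"k={row['top_k']}: recall={row['recall']:.2f}, "
          f"precision={row['precision']:.2f}, faithfulness={row['faithfulness']:.2f}")

# Heuristic knee-point recommendation and its rationale
print(k_report['analysis']['recommended_k'])
print(k_report['analysis']['reasoning'])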

Key Takeaways

RAG evaluation requires measuring both retrieval and generation quality. Key retrieval metrics include precision, recall, MRR, and NDCG; generation metrics focus on faithfulness, relevance, and completeness. The RAGAS framework provides a comprehensive evaluation approach. Optimization involves: (1) grid search over chunking parameters to balance granularity vs. context, (2) tuning top-k for recall/precision trade-off, (3) comparing embedding models and retrieval strategies via A/B testing. Debugging RAG failures requires diagnosing whether issues stem from retrieval (relevant docs not found), ranking (relevant docs ranked low), or generation (model ignoring/misusing context). Systematic evaluation and iterative optimization are essential for building high-quality RAG systems.

24.5 Production RAG Systems Advanced

Production RAG Systems

Moving RAG from prototype to production requires addressing scalability, reliability, observability, and operational concerns. Production systems must handle concurrent users, maintain index freshness, provide consistent latency, and operate cost-effectively. This section covers architectural patterns, caching strategies, monitoring, and operational best practices for deploying RAG at scale.

Scalable Architecture Patterns

Designing RAG systems for production workloads:

PYTHON
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import asyncio
import hashlib
import json
import time
from datetime import datetime, timedelta
from collections import defaultdict
import logging
import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class RAGConfig:
    """Production RAG configuration."""
    # Retrieval settings
    embedding_model: str = "text-embedding-3-small"
    vector_store: str = "pinecone"
    top_k: int = 10
    rerank_top_k: int = 5

    # Generation settings
    llm_model: str = "gpt-4-turbo"
    max_tokens: int = 1000
    temperature: float = 0.1

    # Performance settings
    timeout_seconds: float = 30.0
    max_retries: int = 3
    cache_ttl_seconds: int = 3600

    # Index settings
    chunk_size: int = 512
    chunk_overlap: int = 50


class ProductionRAGSystem:
    """
    Production-ready RAG system with:
    - Async processing for high throughput
    - Caching for reduced latency/cost
    - Circuit breakers for fault tolerance
    - Comprehensive logging and metrics
    """

    def __init__(
        self,
        config: RAGConfig,
        embedding_service: 'EmbeddingService',
        vector_store: 'VectorStore',
        llm_service: 'LLMService',
        cache: 'CacheService',
        metrics: 'MetricsCollector'
    ):
        self.config = config
        self.embedding_service = embedding_service
        self.vector_store = vector_store
        self.llm_service = llm_service
        self.cache = cache
        self.metrics = metrics

        # Circuit breakers
        self.embedding_breaker = CircuitBreaker("embedding", failure_threshold=5)
        self.llm_breaker = CircuitBreaker("llm", failure_threshold=3)

    async def query(
        self,
        question: str,
        user_id: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        use_cache: bool = True
    ) -> Dict[str, Any]:
        """
        Process RAG query with production safeguards.
        """
        request_id = self._generate_request_id(question)
        start_time = time.time()

        try:
            # Check cache first
            if use_cache:
                cached = await self._check_cache(question, filters)
                if cached:
                    self.metrics.increment("cache_hit")
                    return {**cached, 'request_id': request_id, 'cached': True}
                self.metrics.increment("cache_miss")

            # Retrieve context
            with self.metrics.timer("retrieval_latency"):
                contexts = await self._retrieve_with_fallback(question, filters)

            # Generate response
            with self.metrics.timer("generation_latency"):
                response = await self._generate_with_fallback(question, contexts)

            # Build result
            result = {
                'request_id': request_id,
                'answer': response['answer'],
                'sources': response['sources'],
                'confidence': response.get('confidence', 0.5),
                'latency_ms': (time.time() - start_time) * 1000,
                'cached': False
            }

            # Cache result
            if use_cache:
                await self._cache_result(question, filters, result)

            # Log success
            self.metrics.increment("query_success")
            logger.info(f"Query {request_id} completed in {result['latency_ms']:.0f}ms")

            return result

        except Exception as e:
            self.metrics.increment("query_error")
            logger.error(f"Query {request_id} failed: {str(e)}")

            return {
                'request_id': request_id,
                'answer': "I'm sorry, I encountered an error processing your request.",
                'error': str(e),
                'latency_ms': (time.time() - start_time) * 1000
            }

    async def _retrieve_with_fallback(
        self,
        question: str,
        filters: Optional[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Retrieve with circuit breaker and fallback."""
        if not self.embedding_breaker.is_closed():
            logger.warning("Embedding circuit open, using fallback")
            return await self._keyword_fallback(question)

        try:
            # asyncio.timeout requires Python 3.11+; use asyncio.wait_for on older versions
            async with asyncio.timeout(self.config.timeout_seconds / 2):
                # Embed query
                embedding = await self.embedding_service.embed_async(question)

                # Search vector store
                results = await self.vector_store.search_async(
                    embedding,
                    top_k=self.config.top_k,
                    filters=filters
                )

                self.embedding_breaker.record_success()
                return results

        except Exception as e:
            self.embedding_breaker.record_failure()
            logger.error(f"Retrieval failed: {e}")
            return await self._keyword_fallback(question)

    async def _generate_with_fallback(
        self,
        question: str,
        contexts: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Generate with circuit breaker and retry."""
        if not self.llm_breaker.is_closed():
            return self._generate_fallback_response(question)

        for attempt in range(self.config.max_retries):
            try:
                async with asyncio.timeout(self.config.timeout_seconds):
                    response = await self.llm_service.generate_async(
                        question=question,
                        contexts=contexts,
                        max_tokens=self.config.max_tokens,
                        temperature=self.config.temperature
                    )

                    self.llm_breaker.record_success()
                    return response

            except asyncio.TimeoutError:
                logger.warning(f"Generation timeout, attempt {attempt + 1}")
                self.metrics.increment("generation_timeout")
            except Exception as e:
                logger.error(f"Generation error: {e}")
                self.llm_breaker.record_failure()

        return self._generate_fallback_response(question)

    async def _keyword_fallback(self, question: str) -> List[Dict[str, Any]]:
        """Fallback to keyword search when embedding fails."""
        self.metrics.increment("keyword_fallback")
        # Implement BM25 or simple keyword search
        return []

    def _generate_fallback_response(self, question: str) -> Dict[str, Any]:
        """Fallback response when LLM unavailable."""
        self.metrics.increment("llm_fallback")
        return {
            'answer': "I'm temporarily unable to generate a detailed response. Please try again shortly.",
            'sources': [],
            'confidence': 0.0
        }

    async def _check_cache(
        self,
        question: str,
        filters: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Check cache for existing response."""
        cache_key = self._generate_cache_key(question, filters)
        return await self.cache.get_async(cache_key)

    async def _cache_result(
        self,
        question: str,
        filters: Optional[Dict[str, Any]],
        result: Dict[str, Any]
    ) -> None:
        """Cache response for future queries."""
        cache_key = self._generate_cache_key(question, filters)
        await self.cache.set_async(
            cache_key,
            result,
            ttl=self.config.cache_ttl_seconds
        )

    def _generate_cache_key(
        self,
        question: str,
        filters: Optional[Dict[str, Any]]
    ) -> str:
        """Generate deterministic cache key."""
        key_data = {
            'q': question.lower().strip(),
            'f': json.dumps(filters, sort_keys=True) if filters else None,
            'v': 1  # Cache version for invalidation
        }
        key_string = json.dumps(key_data, sort_keys=True)
        return f"rag:{hashlib.sha256(key_string.encode()).hexdigest()[:16]}"

    def _generate_request_id(self, question: str) -> str:
        """Generate unique request ID."""
        timestamp = int(time.time() * 1000)
        content_hash = hashlib.md5(question.encode()).hexdigest()[:8]
        return f"req_{timestamp}_{content_hash}"


class CircuitBreaker:
    """Circuit breaker for fault tolerance."""

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout

        self.failure_count = 0
        self.last_failure_time = None
        self.state = "closed"  # closed, open, half-open

    def is_closed(self) -> bool:
        """Check if circuit is closed (allowing requests)."""
        if self.state == "closed":
            return True

        if self.state == "open":
            # Check if recovery timeout elapsed
            if self.last_failure_time:
                elapsed = time.time() - self.last_failure_time
                if elapsed >= self.recovery_timeout:
                    self.state = "half-open"
                    return True

        return self.state == "half-open"

    def record_success(self):
        """Record successful request."""
        self.failure_count = 0
        self.state = "closed"

    def record_failure(self):
        """Record failed request."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        # A failure while half-open, or hitting the threshold while closed,
        # re-opens the circuit.
        if self.state == "half-open" or self.failure_count >= self.failure_threshold:
            self.state = "open"
            logger.warning(f"Circuit {self.name} opened after {self.failure_count} failures")

Caching Strategies

Multi-level caching for performance and cost optimization:

PYTHON
class CacheService(ABC):
    """Abstract cache service interface."""

    @abstractmethod
    async def get_async(self, key: str) -> Optional[Any]:
        pass

    @abstractmethod
    async def set_async(self, key: str, value: Any, ttl: int) -> None:
        pass

    @abstractmethod
    async def delete_async(self, key: str) -> None:
        pass


class MultiLevelCache(CacheService):
    """
    Multi-level caching strategy:
    L1: In-memory (fastest, limited size)
    L2: Redis (fast, shared across instances)
    L3: Database (persistent, unlimited)
    """

    def __init__(
        self,
        l1_cache: 'InMemoryCache',
        l2_cache: 'RedisCache',
        l3_cache: Optional['DatabaseCache'] = None
    ):
        self.l1 = l1_cache
        self.l2 = l2_cache
        self.l3 = l3_cache

    async def get_async(self, key: str) -> Optional[Any]:
        """Get from cache, checking each level."""
        # L1: In-memory
        value = self.l1.get(key)
        if value is not None:
            return value

        # L2: Redis
        value = await self.l2.get_async(key)
        if value is not None:
            # Populate L1
            self.l1.set(key, value, ttl=300)
            return value

        # L3: Database (if configured)
        if self.l3:
            value = await self.l3.get_async(key)
            if value is not None:
                # Populate L1 and L2
                self.l1.set(key, value, ttl=300)
                await self.l2.set_async(key, value, ttl=3600)
                return value

        return None

    async def set_async(self, key: str, value: Any, ttl: int) -> None:
        """Set in all cache levels."""
        # L1: Short TTL
        self.l1.set(key, value, ttl=min(ttl, 300))

        # L2: Medium TTL
        await self.l2.set_async(key, value, ttl=ttl)

        # L3: Long TTL (if configured)
        if self.l3:
            await self.l3.set_async(key, value, ttl=ttl * 24)

    async def delete_async(self, key: str) -> None:
        """Delete from all levels."""
        self.l1.delete(key)
        await self.l2.delete_async(key)
        if self.l3:
            await self.l3.delete_async(key)


class InMemoryCache:
    """LRU in-memory cache."""

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.cache: Dict[str, Tuple[Any, float]] = {}  # key -> (value, expiry)

    def get(self, key: str) -> Optional[Any]:
        if key in self.cache:
            value, expiry = self.cache[key]
            if time.time() < expiry:
                return value
            else:
                del self.cache[key]
        return None

    def set(self, key: str, value: Any, ttl: int) -> None:
        # Evict if at capacity
        if len(self.cache) >= self.max_size:
            self._evict_oldest()

        self.cache[key] = (value, time.time() + ttl)

    def delete(self, key: str) -> None:
        self.cache.pop(key, None)

    def _evict_oldest(self):
        """Evict oldest entry."""
        if self.cache:
            oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k][1])
            del self.cache[oldest_key]


class SemanticCache:
    """
    Semantic caching for similar queries.

    Instead of exact key match, finds semantically similar
    cached queries and returns their results.
    """

    def __init__(
        self,
        embedding_model,
        similarity_threshold: float = 0.95,
        max_entries: int = 10000
    ):
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.max_entries = max_entries

        self.entries: List[Dict[str, Any]] = []
        self.embeddings: Optional[np.ndarray] = None

    async def get_async(self, query: str) -> Optional[Dict[str, Any]]:
        """Find semantically similar cached query."""
        if not self.entries:
            return None

        # Embed query
        query_embedding = await self.embedding_model.embed_async(query)

        # Find the most similar cached query (dot product equals cosine
        # similarity when embeddings are unit-normalized)
        similarities = self.embeddings @ query_embedding

        max_idx = np.argmax(similarities)
        max_sim = similarities[max_idx]

        if max_sim >= self.similarity_threshold:
            entry = self.entries[max_idx]

            # Check expiry
            if time.time() < entry['expiry']:
                logger.info(f"Semantic cache hit with similarity {max_sim:.3f}")
                return entry['value']

        return None

    async def set_async(
        self,
        query: str,
        value: Dict[str, Any],
        ttl: int
    ) -> None:
        """Cache query and response."""
        # Embed query
        query_embedding = await self.embedding_model.embed_async(query)

        entry = {
            'query': query,
            'value': value,
            'expiry': time.time() + ttl,
            'embedding': query_embedding
        }

        self.entries.append(entry)

        # Update embedding matrix
        if self.embeddings is None:
            self.embeddings = query_embedding.reshape(1, -1)
        else:
            self.embeddings = np.vstack([self.embeddings, query_embedding])

        # Evict if over capacity
        if len(self.entries) > self.max_entries:
            self._evict_expired_or_oldest()

    def _evict_expired_or_oldest(self):
        """Remove expired entries or oldest."""
        current_time = time.time()

        # Remove expired
        valid_indices = [
            i for i, e in enumerate(self.entries)
            if e['expiry'] > current_time
        ]

        if len(valid_indices) < len(self.entries):
            self.entries = [self.entries[i] for i in valid_indices]
            self.embeddings = self.embeddings[valid_indices]

        # Remove oldest if still over capacity
        while len(self.entries) > self.max_entries:
            self.entries.pop(0)
            self.embeddings = self.embeddings[1:]


class EmbeddingCache:
    """Cache embeddings to reduce API calls."""

    def __init__(self, cache_service: CacheService):
        self.cache = cache_service

    async def get_or_compute(
        self,
        text: str,
        embedding_fn: callable
    ) -> np.ndarray:
        """Get cached embedding or compute and cache."""
        cache_key = f"emb:{hashlib.md5(text.encode()).hexdigest()}"

        cached = await self.cache.get_async(cache_key)
        if cached is not None:
            return np.array(cached)

        embedding = await embedding_fn(text)

        # Cache for 24 hours (embeddings don't change)
        await self.cache.set_async(cache_key, embedding.tolist(), ttl=86400)

        return embedding
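
A quick sketch of the in-memory cache's TTL and eviction behavior; the keys and values are arbitrary:

PYTHON
import time

cache = InMemoryCache(max_size=2)

cache.set("a", {"answer": "first"}, ttl=1)
cache.set("b", {"answer": "second"}, ttl=60)
print(cache.get("a"))               # {'answer': 'first'} while the 1-second TTL lasts

# Adding a third key at capacity evicts the entry with the earliest
# expiry time ("a", since its TTL is the shortest).
cache.set("c", {"answer": "third"}, ttl=60)
print(sorted(cache.cache.keys()))   # ['b', 'c']
print(cache.get("b"))               # still cached: {'answer': 'second'}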

Monitoring and Observability

Comprehensive monitoring for production RAG:

PYTHON
class MetricsCollector:
    """Collect and report RAG metrics."""

    def __init__(self):
        self.counters: Dict[str, int] = defaultdict(int)
        self.timers: Dict[str, List[float]] = defaultdict(list)
        self.gauges: Dict[str, float] = {}

    def increment(self, name: str, value: int = 1) -> None:
        """Increment a counter."""
        self.counters[name] += value

    def gauge(self, name: str, value: float) -> None:
        """Set a gauge value."""
        self.gauges[name] = value

    def timer(self, name: str) -> 'TimerContext':
        """Context manager for timing operations."""
        return TimerContext(self, name)

    def record_time(self, name: str, duration: float) -> None:
        """Record a timing measurement."""
        self.timers[name].append(duration)

    def get_summary(self) -> Dict[str, Any]:
        """Get metrics summary."""
        summary = {
            'counters': dict(self.counters),
            'gauges': dict(self.gauges),
            'timers': {}
        }

        for name, times in self.timers.items():
            if times:
                summary['timers'][name] = {
                    'count': len(times),
                    'mean': np.mean(times),
                    'p50': np.percentile(times, 50),
                    'p95': np.percentile(times, 95),
                    'p99': np.percentile(times, 99),
                    'max': max(times)
                }

        return summary


class TimerContext:
    """Context manager for timing."""

    def __init__(self, metrics: MetricsCollector, name: str):
        self.metrics = metrics
        self.name = name
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, *args):
        duration = (time.time() - self.start_time) * 1000  # ms
        self.metrics.record_time(self.name, duration)


class RAGMonitor:
    """
    Production monitoring for RAG systems.

    Tracks:
    - Query latency and throughput
    - Retrieval quality metrics
    - Generation quality metrics
    - Error rates and types
    - Resource utilization
    """

    def __init__(
        self,
        metrics: MetricsCollector,
        alert_service: Optional['AlertService'] = None
    ):
        self.metrics = metrics
        self.alert_service = alert_service

        # Thresholds for alerts
        self.latency_threshold_ms = 5000
        self.error_rate_threshold = 0.05
        self.cache_hit_threshold = 0.3

    def record_query(
        self,
        request_id: str,
        latency_ms: float,
        success: bool,
        cached: bool,
        retrieval_count: int,
        confidence: float
    ) -> None:
        """Record query metrics."""
        self.metrics.record_time("query_latency", latency_ms)
        self.metrics.increment("total_queries")

        if success:
            self.metrics.increment("successful_queries")
        else:
            self.metrics.increment("failed_queries")

        if cached:
            self.metrics.increment("cached_queries")

        self.metrics.gauge("last_retrieval_count", retrieval_count)
        self.metrics.gauge("last_confidence", confidence)

        # Check for alerts
        self._check_alerts(latency_ms, success)

    def record_retrieval(
        self,
        query: str,
        num_results: int,
        latency_ms: float,
        top_score: float
    ) -> None:
        """Record retrieval metrics."""
        self.metrics.record_time("retrieval_latency", latency_ms)
        self.metrics.gauge("retrieval_results", num_results)
        self.metrics.gauge("retrieval_top_score", top_score)

    def record_generation(
        self,
        input_tokens: int,
        output_tokens: int,
        latency_ms: float
    ) -> None:
        """Record generation metrics."""
        self.metrics.record_time("generation_latency", latency_ms)
        self.metrics.increment("total_input_tokens", input_tokens)
        self.metrics.increment("total_output_tokens", output_tokens)

    def _check_alerts(self, latency_ms: float, success: bool) -> None:
        """Check if alerts should be triggered."""
        if not self.alert_service:
            return

        # Latency alert
        if latency_ms > self.latency_threshold_ms:
            self.alert_service.send_alert(
                severity="warning",
                message=f"High query latency: {latency_ms:.0f}ms"
            )

        # Error rate alert
        total = self.metrics.counters["total_queries"]
        failed = self.metrics.counters["failed_queries"]

        if total > 100:
            error_rate = failed / total
            if error_rate > self.error_rate_threshold:
                self.alert_service.send_alert(
                    severity="critical",
                    message=f"High error rate: {error_rate:.1%}"
                )

    def get_health_status(self) -> Dict[str, Any]:
        """Get system health status."""
        summary = self.metrics.get_summary()

        total_queries = summary['counters'].get('total_queries', 0)
        failed_queries = summary['counters'].get('failed_queries', 0)
        cached_queries = summary['counters'].get('cached_queries', 0)

        error_rate = failed_queries / total_queries if total_queries > 0 else 0
        cache_hit_rate = cached_queries / total_queries if total_queries > 0 else 0

        latency_stats = summary['timers'].get('query_latency', {})

        return {
            'status': 'healthy' if error_rate < 0.1 else 'degraded',
            'total_queries': total_queries,
            'error_rate': error_rate,
            'cache_hit_rate': cache_hit_rate,
            'latency_p50_ms': latency_stats.get('p50', 0),
            'latency_p95_ms': latency_stats.get('p95', 0),
            'latency_p99_ms': latency_stats.get('p99', 0)
        }


class QueryLogger:
    """Structured logging for RAG queries."""

    def __init__(self, log_level: str = "INFO"):
        self.logger = logging.getLogger("rag_queries")
        self.logger.setLevel(getattr(logging, log_level))

    def log_query(
        self,
        request_id: str,
        question: str,
        answer: str,
        contexts: List[str],
        latency_ms: float,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """Log query details for debugging and analysis."""
        log_entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'request_id': request_id,
            'question': question[:200],
            'answer_length': len(answer),
            'num_contexts': len(contexts),
            'latency_ms': latency_ms,
            'metadata': metadata or {}
        }

        self.logger.info(json.dumps(log_entry))

    def log_error(
        self,
        request_id: str,
        error: Exception,
        context: Dict[str, Any]
    ) -> None:
        """Log error details."""
        log_entry = {
            'timestamp': datetime.utcnow().isoformat(),
            'request_id': request_id,
            'error_type': type(error).__name__,
            'error_message': str(error),
            'context': context
        }

        self.logger.error(json.dumps(log_entry))
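
A short sketch of how the metrics collector and monitor fit together, simulating a few queries by hand; with no alert service wired in, alert checks are skipped:

PYTHON
metrics = MetricsCollector()
monitor = RAGMonitor(metrics)   # no alert service in this sketch

# Simulate four queries with varying latency; one of them was a cache hit.
for latency in [120.0, 340.0, 95.0, 4100.0]:
    monitor.record_query(
        request_id="req_demo",
        latency_ms=latency,
        success=True,
        cached=(latency < 100),
        retrieval_count=5,
        confidence=0.7
    )

health = monitor.get_health_status()
print(health['status'])           # healthy (error rate is 0)
print(health['cache_hit_rate'])   # 0.25
print(health['latency_p95_ms'])   # dominated by the 4100 ms outlier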

Index Management

Managing document indexes in production:

PYTHON
class IndexManager:
    """
    Manage RAG index lifecycle:
    - Initial indexing
    - Incremental updates
    - Index refresh/rebuild
    - Version management
    """

    def __init__(
        self,
        vector_store,
        embedding_service,
        chunker,
        config: RAGConfig
    ):
        self.vector_store = vector_store
        self.embedding_service = embedding_service
        self.chunker = chunker
        self.config = config

        self.index_metadata: Dict[str, Any] = {}

    async def build_index(
        self,
        documents: List[Dict[str, Any]],
        index_name: str,
        batch_size: int = 100
    ) -> Dict[str, Any]:
        """Build new index from documents."""
        start_time = time.time()
        total_chunks = 0

        logger.info(f"Building index '{index_name}' from {len(documents)} documents")

        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]

            # Process batch
            chunks = []
            for doc in batch:
                doc_chunks = self.chunker.chunk(
                    doc['content'],
                    metadata={
                        'doc_id': doc['id'],
                        'source': doc.get('source', ''),
                        **doc.get('metadata', {})
                    }
                )
                chunks.extend(doc_chunks)

            # Embed chunks
            contents = [c.content for c in chunks]
            embeddings = await self.embedding_service.embed_batch_async(contents)

            for chunk, embedding in zip(chunks, embeddings):
                chunk.embedding = embedding

            # Add to vector store
            await self.vector_store.add_async(chunks)
            total_chunks += len(chunks)

            logger.info(f"Indexed batch {i // batch_size + 1}, total chunks: {total_chunks}")

        # Store metadata
        self.index_metadata[index_name] = {
            'created_at': datetime.utcnow().isoformat(),
            'num_documents': len(documents),
            'num_chunks': total_chunks,
            'build_time_seconds': time.time() - start_time,
            'config': {
                'chunk_size': self.config.chunk_size,
                'chunk_overlap': self.config.chunk_overlap,
                'embedding_model': self.config.embedding_model
            }
        }

        return self.index_metadata[index_name]

    async def update_documents(
        self,
        documents: List[Dict[str, Any]],
        operation: str = "upsert"  # upsert, add, delete
    ) -> Dict[str, Any]:
        """Incrementally update index."""
        start_time = time.time()
        updated_count = 0

        for doc in documents:
            doc_id = doc['id']

            if operation == "delete":
                # Delete all chunks for this document
                await self.vector_store.delete_by_metadata_async(
                    {'doc_id': doc_id}
                )
                updated_count += 1

            elif operation in ["upsert", "add"]:
                if operation == "upsert":
                    # Delete existing chunks first
                    await self.vector_store.delete_by_metadata_async(
                        {'doc_id': doc_id}
                    )

                # Chunk and embed
                chunks = self.chunker.chunk(
                    doc['content'],
                    metadata={'doc_id': doc_id, **doc.get('metadata', {})}
                )

                contents = [c.content for c in chunks]
                embeddings = await self.embedding_service.embed_batch_async(contents)

                for chunk, embedding in zip(chunks, embeddings):
                    chunk.embedding = embedding

                await self.vector_store.add_async(chunks)
                updated_count += 1

        return {
            'operation': operation,
            'documents_processed': updated_count,
            'time_seconds': time.time() - start_time
        }

    async def refresh_index(
        self,
        index_name: str,
        documents: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Refresh index with new documents while maintaining availability."""
        # Build new index with temporary name
        temp_name = f"{index_name}_temp_{int(time.time())}"

        build_result = await self.build_index(documents, temp_name)

        # Atomic swap
        await self.vector_store.swap_indexes_async(index_name, temp_name)

        # Delete the old index (after the swap, its data lives under the temporary name)
        await self.vector_store.delete_index_async(temp_name)

        return build_result


class DocumentSyncService:
    """
    Sync documents from source systems to RAG index.

    Supports:
    - Change detection
    - Incremental sync
    - Scheduled refresh
    """

    def __init__(
        self,
        index_manager: IndexManager,
        source_connector: 'SourceConnector'
    ):
        self.index_manager = index_manager
        self.source = source_connector
        self.last_sync: Optional[datetime] = None
        self.sync_state: Dict[str, str] = {}  # doc_id -> hash

    async def sync(self, full: bool = False) -> Dict[str, Any]:
        """Sync documents from source."""
        start_time = time.time()

        if full:
            # Full refresh
            documents = await self.source.get_all_documents()
            result = await self.index_manager.build_index(documents, "main")
            self._update_sync_state(documents)

        else:
            # Incremental sync. get_documents_since returns only documents
            # added or changed since the last sync, so deletions cannot be
            # inferred from this batch; detect them during full syncs or from
            # explicit delete events from the source.
            documents = await self.source.get_documents_since(self.last_sync)
            changes = self._detect_changes(documents, detect_deletions=False)

            if changes['added']:
                await self.index_manager.update_documents(
                    changes['added'], operation="add"
                )

            if changes['modified']:
                await self.index_manager.update_documents(
                    changes['modified'], operation="upsert"
                )

            if changes['deleted']:
                await self.index_manager.update_documents(
                    changes['deleted'], operation="delete"
                )

            result = {
                'added': len(changes['added']),
                'modified': len(changes['modified']),
                'deleted': len(changes['deleted'])
            }

            # Merge the incremental changes into the sync state; replacing it
            # wholesale would forget documents that were not in this batch.
            for doc in changes['added'] + changes['modified']:
                self.sync_state[doc['id']] = self._hash_document(doc)
            for doc in changes['deleted']:
                self.sync_state.pop(doc['id'], None)

        self.last_sync = datetime.utcnow()

        return {
            **result,
            'sync_time_seconds': time.time() - start_time,
            'last_sync': self.last_sync.isoformat()
        }

    def _detect_changes(
        self,
        documents: List[Dict[str, Any]],
        detect_deletions: bool = True
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Detect added, modified, and (optionally) deleted documents."""
        changes = {'added': [], 'modified': [], 'deleted': []}

        current_ids = set()
        for doc in documents:
            doc_id = doc['id']
            doc_hash = self._hash_document(doc)
            current_ids.add(doc_id)

            if doc_id not in self.sync_state:
                changes['added'].append(doc)
            elif self.sync_state[doc_id] != doc_hash:
                changes['modified'].append(doc)

        # Find deleted documents (only meaningful when `documents` is the
        # full corpus, e.g. during a full sync)
        if detect_deletions:
            for doc_id in self.sync_state:
                if doc_id not in current_ids:
                    changes['deleted'].append({'id': doc_id})

        return changes

    def _hash_document(self, doc: Dict[str, Any]) -> str:
        """Generate hash of document content."""
        content = json.dumps(doc, sort_keys=True)
        return hashlib.md5(content.encode()).hexdigest()

    def _update_sync_state(
        self,
        documents: List[Dict[str, Any]],
        merge: bool = False
    ) -> None:
        """Replace (or, when merge=True, update) the stored content hashes."""
        new_state = {doc['id']: self._hash_document(doc) for doc in documents}
        self.sync_state = {**self.sync_state, **new_state} if merge else new_state
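
The docstring above lists scheduled refresh, but no driver loop is shown. The sketch below is one minimal way to run it with asyncio; the hourly interval and the periodic full rebuild (which is what picks up deletions that incremental batches cannot observe) are illustrative choices rather than requirements.

PYTHON
import asyncio

async def run_scheduled_sync(
    sync_service: DocumentSyncService,
    interval_seconds: int = 3600,
    full_every_n_runs: int = 24
) -> None:
    """Run incremental syncs on a fixed interval, with periodic full rebuilds."""
    run = 0
    while True:
        run += 1
        # Every Nth run, force a full refresh so documents deleted at the
        # source are also removed from the index
        is_full = (run % full_every_n_runs == 0)
        result = await sync_service.sync(full=is_full)
        print(f"sync #{run} ({'full' if is_full else 'incremental'}): {result}")
        await asyncio.sleep(interval_seconds)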

Cost Optimization

Strategies to reduce RAG operational costs:

PYTHON
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List


class CostOptimizer:
    """
    Optimize RAG costs while maintaining quality.

    Cost factors:
    - Embedding API calls
    - LLM API calls (tokens)
    - Vector store queries
    - Infrastructure
    """

    def __init__(
        self,
        embedding_cost_per_1k: float = 0.0001,
        llm_input_cost_per_1k: float = 0.01,
        llm_output_cost_per_1k: float = 0.03
    ):
        self.embedding_cost = embedding_cost_per_1k / 1000
        self.llm_input_cost = llm_input_cost_per_1k / 1000
        self.llm_output_cost = llm_output_cost_per_1k / 1000

        self.daily_costs: Dict[str, float] = defaultdict(float)

    def estimate_query_cost(
        self,
        query_tokens: int,
        context_tokens: int,
        output_tokens: int,
        cached: bool = False
    ) -> Dict[str, float]:
        """Estimate cost for a single query."""
        if cached:
            return {'total': 0, 'embedding': 0, 'llm': 0}

        embedding_cost = self.embedding_cost * query_tokens
        llm_cost = (
            self.llm_input_cost * (query_tokens + context_tokens) +
            self.llm_output_cost * output_tokens
        )

        return {
            'embedding': embedding_cost,
            'llm': llm_cost,
            'total': embedding_cost + llm_cost
        }

    def record_cost(self, cost: float, category: str) -> None:
        """Record cost for tracking."""
        today = datetime.utcnow().strftime("%Y-%m-%d")
        self.daily_costs[f"{today}_{category}"] += cost

    def get_optimization_recommendations(
        self,
        metrics: Dict[str, Any]
    ) -> List[Dict[str, str]]:
        """Generate cost optimization recommendations."""
        recommendations = []

        # Check cache hit rate
        cache_hit_rate = metrics.get('cache_hit_rate', 0)
        if cache_hit_rate < 0.3:
            recommendations.append({
                'category': 'caching',
                'impact': 'high',
                'recommendation': 'Implement semantic caching to increase hit rate',
                'estimated_savings': f'{(0.3 - cache_hit_rate) * 100:.0f}% of query costs'
            })

        # Check average context size
        avg_context_tokens = metrics.get('avg_context_tokens', 2000)
        if avg_context_tokens > 3000:
            recommendations.append({
                'category': 'context',
                'impact': 'medium',
                'recommendation': 'Reduce context size with better chunking or compression',
                'estimated_savings': f'{(avg_context_tokens - 2000) / avg_context_tokens * 100:.0f}% of LLM input costs'
            })

        # Check embedding calls
        embedding_calls = metrics.get('embedding_calls', 0)
        embedding_cache_hits = metrics.get('embedding_cache_hits', 0)

        if embedding_calls > 0 and embedding_cache_hits / embedding_calls < 0.5:
            recommendations.append({
                'category': 'embeddings',
                'impact': 'medium',
                'recommendation': 'Cache embeddings more aggressively',
                'estimated_savings': 'Up to 50% of embedding costs'
            })

        return recommendations
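
# Illustrative usage (not part of the classes in this chapter): with the
# default per-1k prices above, a 20-token query with 2,000 context tokens
# and 300 output tokens costs roughly $0.000002 for embedding and $0.0292
# for the LLM call -- the retrieved context dominates, which is why
# get_optimization_recommendations targets caching and context size first.
_optimizer = CostOptimizer()
_example_cost = _optimizer.estimate_query_cost(
    query_tokens=20, context_tokens=2000, output_tokens=300
)
# _example_cost -> {'embedding': ~2e-06, 'llm': ~0.0292, 'total': ~0.0292}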


class AdaptiveRAG:
    """
    Adaptive RAG that adjusts based on query complexity.

    Simple queries: Fewer retrieved docs, smaller model
    Complex queries: More docs, larger model, reranking
    """

    def __init__(
        self,
        simple_config: RAGConfig,
        complex_config: RAGConfig,
        complexity_classifier
    ):
        self.simple_config = simple_config
        self.complex_config = complex_config
        self.classifier = complexity_classifier

    async def query(self, question: str) -> Dict[str, Any]:
        """Route query based on complexity."""
        complexity = await self._classify_complexity(question)

        if complexity == "simple":
            config = self.simple_config
            # Use smaller model, fewer docs
        else:
            config = self.complex_config
            # Use larger model, more docs, reranking

        # Execute with appropriate config
        # ...

        return {
            'complexity': complexity,
            'config_used': config
        }

    async def _classify_complexity(self, question: str) -> str:
        """Classify query complexity with lightweight heuristics."""
        # Check analytical indicators first so that short questions such as
        # "why ...?" are still routed to the complex path
        complex_indicators = [
            "compare", "explain", "analyze", "why",
            "how does", "what are the differences"
        ]

        if any(ind in question.lower() for ind in complex_indicators):
            return "complex"

        # Short, direct lookups and everything else default to the simple path
        return "simple"

Key Takeaways

Production RAG systems require careful attention to reliability, performance, and cost. Key architectural patterns include: (1) async processing with circuit breakers for fault tolerance, (2) multi-level caching (L1 in-memory, L2 Redis, L3 database) for latency and cost reduction, (3) semantic caching for similar query reuse, and (4) comprehensive monitoring covering latency, error rates, and quality metrics. Index management should support incremental updates and zero-downtime refreshes. Cost optimization involves embedding caching, adaptive retrieval based on query complexity, and context compression. Production systems should implement structured logging for debugging, alerting for anomaly detection, and health endpoints for operational visibility. The goal is maintaining high availability and consistent performance while controlling costs at scale.
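
The multi-level cache pattern summarized above is easiest to see in code. The sketch below is illustrative rather than the chapter's implementation: it covers only the L1 (in-process) and L2 (shared) tiers, and it assumes the L2 client follows the redis.asyncio.Redis get/set signature; the key scheme, TTL, and L3 database fallback are left to the surrounding system.

PYTHON
import json
from typing import Any, Dict, Optional

class TwoLevelCache:
    """Minimal L1 (in-process dict) / L2 (shared store) read-through cache."""

    def __init__(self, l2_client, l2_ttl_seconds: int = 3600):
        # l2_client: any async client with get/set, e.g. redis.asyncio.Redis
        self.l1: Dict[str, Any] = {}
        self.l2 = l2_client
        self.l2_ttl = l2_ttl_seconds

    async def get(self, key: str) -> Optional[Any]:
        # L1 hit: cheapest path, no network round trip
        if key in self.l1:
            return self.l1[key]
        # L2 hit: shared across replicas; promote the value back into L1
        raw = await self.l2.get(key)
        if raw is not None:
            value = json.loads(raw)
            self.l1[key] = value
            return value
        return None

    async def set(self, key: str, value: Any) -> None:
        self.l1[key] = value
        await self.l2.set(key, json.dumps(value), ex=self.l2_ttl)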