LearnIT/docs/RAG_ENGINE_MVP_TASKS.md
2026-05-25 21:41:41 +01:00


RAG Engine MVP Tasks - AI Study Assistant

⚠️ IMPORTANT - OUTDATED DOCUMENT: This document describes a Python/FAISS architecture that WAS NOT IMPLEMENTED.

Actual Implementation:

  • Language: Dart (Flutter)
  • Location: lib/core/services/materials_rag_service.dart, lib/core/services/rag_ai_service.dart
  • Vector Store: Firestore with mock (hash-based) embeddings
  • PDF Processing: syncfusion_flutter_pdf (not Python)
  • Search: Keyword window search (not FAISS)

NOT PRESENT IN THE PROJECT: Python, FAISS, Sentence Transformers, OpenAI, Anthropic


🧠 MVP RAG ENGINE ROADMAP (ORIGINAL DOCUMENTATION - NOT IMPLEMENTED)


📚 WEEK 1-2: FOUNDATION & SETUP (NOT IMPLEMENTED)

Task 1.1: Vector Database Setup

Status: NOT IMPLEMENTED - FAISS is not used

What Actually Exists:

// lib/core/services/vector_service.dart
class VectorService {
  // Mock embedding generation using text hashing
  static List<double> generateEmbedding(String text) {
    final embedding = List<double>.filled(384, 0.0);
    // Hash-based deterministic embeddings (not ML)
    return embedding;
  }
}

// lib/core/services/materials_rag_service.dart
class MaterialsRAGService {
  // Keyword-based window search for PDFs
  static Future<String> getContextForQuestion(...) async {
    // 1. Extract PDF text with syncfusion_flutter_pdf
    // 2. Find keyword matches
    // 3. Return window of 1200 chars around match
    // NO FAISS, NO VECTOR SEARCH, NO PYTHON
  }
}
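
The two Dart stubs above only outline the approach. As a rough illustration, the sketch below shows in Python what a hash-based embedding and a keyword window lookup might look like. It is not the project's Dart code: the function names, the token-to-bucket hashing scheme, and the choice to centre the 1200-character window on the match are all assumptions made for this example.

```python
import hashlib
import math


def hash_embedding(text: str, dim: int = 384) -> list[float]:
    """Deterministic, non-ML embedding: each token increments one hashed bucket."""
    vec = [0.0] * dim
    for token in text.lower().split():
        digest = hashlib.sha256(token.encode("utf-8")).digest()
        bucket = int.from_bytes(digest[:4], "big") % dim
        vec[bucket] += 1.0
    # L2-normalise so that a dot product behaves like cosine similarity
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm else vec


def keyword_window(text: str, keywords: list[str], window: int = 1200) -> str:
    """Return up to `window` characters centred on the first keyword match."""
    lowered = text.lower()  # assumes lowercasing does not change string length
    for keyword in keywords:
        if not keyword:
            continue  # an empty keyword would match at position 0
        pos = lowered.find(keyword.lower())
        if pos != -1:
            start = max(0, pos - window // 2)
            return text[start:start + window]
    return ""  # no keyword found
```

Such an embedding is stable across runs but carries no semantic meaning: two texts score as similar only when they share tokens (or collide in a bucket), which is consistent with the note above that retrieval relies on keyword matching rather than vector search.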

Technology Stack (Original - NOT USED):

# These dependencies DO NOT EXIST in the project:
# ❌ pip install faiss-cpu
# ❌ pip install sentence-transformers
# ❌ pip install numpy
# ❌ pip install nltk
# ❌ pip install spacy

Project Structure:

rag_engine/
├── src/
│   ├── __init__.py
│   ├── main.py
│   ├── config/
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   └── constants.py
│   ├── core/
│   │   ├── __init__.py
│   │   ├── vector_store.py
│   │   ├── embeddings.py
│   │   ├── retriever.py
│   │   └── indexer.py
│   ├── preprocessing/
│   │   ├── __init__.py
│   │   ├── text_processor.py
│   │   ├── chunker.py
│   │   └── metadata_extractor.py
│   ├── retrieval/
│   │   ├── __init__.py
│   │   ├── keyword_search.py
│   │   ├── vector_search.py
│   │   ├── hybrid_search.py
│   │   └── ranker.py
│   ├── llm/
│   │   ├── __init__.py
│   │   ├── prompt_builder.py
│   │   ├── llm_client.py
│   │   └── response_processor.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── validators.py
│   │   └── helpers.py
│   └── models/
│       ├── __init__.py
│       ├── document.py
│       ├── chunk.py
│       └── query.py
├── tests/
│   ├── __init__.py
│   ├── test_vector_store.py
│   ├── test_embeddings.py
│   ├── test_retriever.py
│   └── test_integration.py
├── data/
│   ├── models/          # Saved embedding models
│   ├── indices/         # FAISS index files
│   ├── chunks/          # Processed content chunks
│   └── temp/            # Temporary files
├── requirements.txt
├── setup.py
├── README.md
└── docker-compose.yml

Configuration:

src/config/settings.py

from dataclasses import dataclass
from typing import Optional
import os

@dataclass
class VectorStoreConfig:
    """Configuration for vector storage"""
    index_type: str = "IVF"  # IVF, HNSW, Flat
    dimension: int = 384     # all-MiniLM-L6-v2 dimension
    nlist: int = 100         # Number of clusters for IVF
    nprobe: int = 10         # Number of clusters to search
    use_gpu: bool = False    # GPU acceleration
    metric: str = "INNER_PRODUCT"  # Similarity metric

@dataclass
class EmbeddingConfig:
    """Configuration for text embeddings"""
    model_name: str = "all-MiniLM-L6-v2"
    batch_size: int = 32
    max_length: int = 512
    normalize_embeddings: bool = True
    cache_embeddings: bool = True
    model_cache_dir: str = "data/models"

@dataclass
class RetrievalConfig:
    """Configuration for retrieval pipeline"""
    top_k: int = 10          # Number of results to retrieve
    rerank: bool = True      # Apply reranking
    rerank_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    hybrid_alpha: float = 0.5  # Weight for hybrid search
    min_similarity: float = 0.1  # Minimum similarity threshold

@dataclass
class ChunkingConfig:
    """Configuration for text chunking"""
    chunk_size: int = 300    # Target chunk size in tokens
    chunk_overlap: int = 50  # Overlap between chunks
    min_chunk_size: int = 50  # Minimum chunk size
    max_chunk_size: int = 800  # Maximum chunk size
    respect_sentence_boundaries: bool = True
    respect_paragraph_boundaries: bool = True

@dataclass
class RAGConfig:
    """Main RAG configuration"""
    vector_store: VectorStoreConfig
    embeddings: EmbeddingConfig
    retrieval: RetrievalConfig
    chunking: ChunkingConfig
    
    # LLM settings
    llm_provider: str = "anthropic"  # anthropic, openai
    llm_model: str = "claude-3-5-sonnet-20241022"
    max_context_tokens: int = 4000
    max_response_tokens: int = 500
    temperature: float = 0.7
    
    # Storage
    firebase_project_id: str = os.getenv("FIREBASE_PROJECT_ID", "")
    storage_bucket: str = os.getenv("STORAGE_BUCKET", "")
    
    # Logging
    log_level: str = "INFO"
    log_file: str = "logs/rag_engine.log"

# Default configuration
DEFAULT_CONFIG = RAGConfig(
    vector_store=VectorStoreConfig(),
    embeddings=EmbeddingConfig(),
    retrieval=RetrievalConfig(),
    chunking=ChunkingConfig(),
)
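
The chunking defaults above depend on each other (the overlap has to be smaller than the chunk size, and the chunk size has to sit between the minimum and maximum). A minimal sketch of how such constraints could be enforced is shown below. `ChunkingLimits` is a hypothetical, trimmed-down stand-in for `ChunkingConfig`, and the validation rules are assumptions rather than part of the original plan.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ChunkingLimits:
    """Hypothetical stand-in for ChunkingConfig with basic sanity checks."""
    chunk_size: int = 300
    chunk_overlap: int = 50
    min_chunk_size: int = 50
    max_chunk_size: int = 800

    def __post_init__(self) -> None:
        if not self.min_chunk_size <= self.chunk_size <= self.max_chunk_size:
            raise ValueError(
                "chunk_size must lie between min_chunk_size and max_chunk_size"
            )
        if not 0 <= self.chunk_overlap < self.chunk_size:
            raise ValueError(
                "chunk_overlap must be non-negative and smaller than chunk_size"
            )


# Derive a variant without mutating the defaults; replace() re-runs the checks.
small = replace(ChunkingLimits(), chunk_size=120, chunk_overlap=20)
```

Using `dataclasses.replace` keeps the default instance untouched, which matters when a single `DEFAULT_CONFIG` object is shared across modules.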

Task 1.2: Embedding Model Setup

Priority: Critical
Estimated Time: 6 hours
Dependencies: Task 1.1

Subtasks:

  • Download and configure sentence-transformers model
  • Create embedding service
  • Implement batch processing
  • Add embedding caching
  • Create embedding quality checks
  • Set up model versioning

Implementation:

src/core/embeddings.py

import os
import pickle
import hashlib
from typing import List, Dict, Optional, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
from pathlib import Path
from ..config.settings import EmbeddingConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)

class EmbeddingService:
    """Service for generating and managing text embeddings"""
    
    def __init__(self, config: EmbeddingConfig):
        self.config = config
        self.model = None
        self.embedding_cache = {}
        self.cache_file = Path("data/embeddings_cache.pkl")
        self._load_model()
        self._load_cache()
    
    def _load_model(self):
        """Load the sentence transformer model"""
        try:
            logger.info(f"Loading embedding model: {self.config.model_name}")
            
            # Create cache directory if it doesn't exist
            os.makedirs(self.config.model_cache_dir, exist_ok=True)
            
            # Load model with caching
            self.model = SentenceTransformer(
                self.config.model_name,
                cache_folder=self.config.model_cache_dir
            )
            
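            # Move to GPU if available and requested. EmbeddingConfig does not
            # define use_gpu, so fall back to False when the attribute is absent.
            if getattr(self.config, "use_gpu", False) and torch.cuda.is_available():
                self.model = self.model.to('cuda')
                logger.info("Model moved to GPU")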
            
            logger.info("Embedding model loaded successfully")
            
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise
    
    def _load_cache(self):
        """Load embedding cache from disk"""
        if self.config.cache_embeddings and self.cache_file.exists():
            try:
                with open(self.cache_file, 'rb') as f:
                    self.embedding_cache = pickle.load(f)
                logger.info(f"Loaded {len(self.embedding_cache)} cached embeddings")
            except Exception as e:
                logger.warning(f"Failed to load cache: {e}")
                self.embedding_cache = {}
    
    def _save_cache(self):
        """Save embedding cache to disk"""
        if self.config.cache_embeddings:
            try:
                os.makedirs(self.cache_file.parent, exist_ok=True)
                with open(self.cache_file, 'wb') as f:
                    pickle.dump(self.embedding_cache, f)
                logger.info("Embedding cache saved")
            except Exception as e:
                logger.warning(f"Failed to save cache: {e}")
    
    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()
    
    def encode(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
        """
        Encode texts to embeddings
        
        Args:
            texts: List of texts to encode
            batch_size: Batch size for processing (overrides config)
            
        Returns:
            numpy array of embeddings
        """
        if not texts:
            return np.array([])
        
        batch_size = batch_size or self.config.batch_size
        embeddings = []
        uncached_texts = []
        uncached_indices = []
        
        # Check cache first; with caching disabled, every text must be encoded
        if self.config.cache_embeddings:
            for i, text in enumerate(texts):
                cache_key = self._get_cache_key(text)
                if cache_key in self.embedding_cache:
                    embeddings.append(self.embedding_cache[cache_key])
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(i)
        else:
            uncached_texts = list(texts)
            uncached_indices = list(range(len(texts)))
        
        # Encode uncached texts
        if uncached_texts:
            try:
                # Process in batches
                batch_embeddings = []
                for i in range(0, len(uncached_texts), batch_size):
                    batch = uncached_texts[i:i + batch_size]
                    
                    # Truncate if necessary (note: this slices characters, not tokens)
                    truncated_batch = [
                        text[:self.config.max_length] 
                        for text in batch
                    ]
                    
                    # Generate embeddings
                    batch_emb = self.model.encode(
                        truncated_batch,
                        normalize_embeddings=self.config.normalize_embeddings,
                        convert_to_numpy=True,
                        show_progress_bar=False
                    )
                    batch_embeddings.append(batch_emb)
                
                # Combine batch results
                if batch_embeddings:
                    new_embeddings = np.vstack(batch_embeddings)
                    
                    # Update cache
                    if self.config.cache_embeddings:
                        for text, emb in zip(uncached_texts, new_embeddings):
                            cache_key = self._get_cache_key(text)
                            self.embedding_cache[cache_key] = emb
                    
                    # Insert new embeddings in correct positions
                    for idx, emb in zip(uncached_indices, new_embeddings):
                        embeddings.insert(idx, emb)
                
            except Exception as e:
                logger.error(f"Failed to encode texts: {e}")
                raise
        
        # Convert to numpy array
        result = np.array(embeddings) if embeddings else np.array([])
        
        # Ensure we have the right number of embeddings
        if len(result) != len(texts):
            logger.warning(f"Embedding count mismatch: {len(result)} vs {len(texts)}")
        
        return result
    
    def encode_single(self, text: str) -> np.ndarray:
        """Encode a single text"""
        return self.encode([text])[0]
    
    def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Calculate cosine similarity between two embeddings"""
        if self.config.normalize_embeddings:
            # For normalized embeddings, dot product equals cosine similarity
            return float(np.dot(embedding1, embedding2))
        else:
            # Manual cosine similarity calculation
            dot_product = np.dot(embedding1, embedding2)
            norm1 = np.linalg.norm(embedding1)
            norm2 = np.linalg.norm(embedding2)
            if norm1 == 0 or norm2 == 0:
                return 0.0  # cosine similarity is undefined for zero vectors
            return float(dot_product / (norm1 * norm2))
    
    def find_similar(
        self, 
        query_embedding: np.ndarray, 
        candidate_embeddings: np.ndarray,
        top_k: int = 10
    ) -> List[Tuple[int, float]]:
        """
        Find most similar embeddings to query
        
        Args:
            query_embedding: Query embedding
            candidate_embeddings: Array of candidate embeddings
            top_k: Number of top results to return
            
        Returns:
            List of (index, similarity_score) tuples
        """
        if len(candidate_embeddings) == 0:
            return []
        
        # Calculate similarities
        if self.config.normalize_embeddings:
            # For normalized embeddings, use matrix multiplication
            similarities = np.dot(candidate_embeddings, query_embedding)
        else:
            # Manual cosine similarity
            similarities = np.array([
                self.similarity(query_embedding, emb) 
                for emb in candidate_embeddings
            ])
        
        # Get top-k indices and scores
        top_indices = np.argsort(similarities)[::-1][:top_k]
        top_scores = similarities[top_indices]
        
        return [(int(idx), float(score)) for idx, score in zip(top_indices, top_scores)]
    
    def validate_embedding(self, embedding: np.ndarray) -> bool:
        """Validate embedding quality"""
        if not isinstance(embedding, np.ndarray):
            return False
        
        if embedding.size == 0:
            return False
        
        if np.isnan(embedding).any() or np.isinf(embedding).any():
            return False
        
        expected_dim = self.model.get_sentence_embedding_dimension()
        if embedding.shape[0] != expected_dim:
            return False
        
        return True
    
    def get_embedding_stats(self, embeddings: np.ndarray) -> Dict:
        """Get statistics about embeddings"""
        if len(embeddings) == 0:
            return {}
        
        return {
            "count": len(embeddings),
            "dimension": embeddings.shape[1],
            "mean_norm": np.mean(np.linalg.norm(embeddings, axis=1)),
            "std_norm": np.std(np.linalg.norm(embeddings, axis=1)),
            "has_nan": np.isnan(embeddings).any(),
            "has_inf": np.isinf(embeddings).any(),
        }
    
    def clear_cache(self):
        """Clear embedding cache"""
        self.embedding_cache.clear()
        if self.cache_file.exists():
            self.cache_file.unlink()
        logger.info("Embedding cache cleared")
    
    def save_cache(self):
        """Manually save cache to disk"""
        self._save_cache()
    
    def __del__(self):
        """Cleanup on deletion"""
        try:
            self._save_cache()
        except Exception:
            pass
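
The cache handling in `encode` has to return vectors in the same order as the input texts, whether a text was served from the cache or freshly encoded. One way to make that ordering explicit is to pre-allocate a result slot per input and fill the gaps afterwards. The sketch below illustrates the idea with plain lists. `encode_with_cache`, `cache_key`, and `fake_encoder` are hypothetical names for this example, and the encoder is a stand-in for the real model call.

```python
import hashlib
from typing import Callable, Dict, List, Optional, Sequence

Vector = List[float]


def cache_key(text: str) -> str:
    # Any stable hash works as a key; the class above uses MD5.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def encode_with_cache(
    texts: Sequence[str],
    encode_batch: Callable[[List[str]], List[Vector]],
    cache: Dict[str, Vector],
    use_cache: bool = True,
) -> List[Optional[Vector]]:
    """Encode texts, reusing cached vectors and preserving input order."""
    results: List[Optional[Vector]] = [None] * len(texts)
    pending: List[int] = []  # positions that still need the encoder

    for i, text in enumerate(texts):
        hit = cache.get(cache_key(text)) if use_cache else None
        if hit is not None:
            results[i] = hit
        else:
            pending.append(i)

    if pending:
        fresh = encode_batch([texts[i] for i in pending])
        for i, vector in zip(pending, fresh):
            results[i] = vector
            if use_cache:
                cache[cache_key(texts[i])] = vector

    return results


def fake_encoder(batch: List[str]) -> List[Vector]:
    # Stand-in for the model: one-dimensional "embedding" = text length
    return [[float(len(t))] for t in batch]


cache: Dict[str, Vector] = {}
encode_with_cache(["aa", "b"], fake_encoder, cache)
vectors = encode_with_cache(["ccc", "aa"], fake_encoder, cache)
```

On the second call only `"ccc"` reaches the encoder; `"aa"` is read back from the cache, and the result still lines up with the input order.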

🔍 WEEK 3-4: CONTENT PROCESSING & CHUNKING

Task 2.1: Text Preprocessing

Priority: High
Estimated Time: 10 hours
Dependencies: Task 1.2

Subtasks:

  • Implement text cleaning
  • Add language detection
  • Create tokenization utilities
  • Build text normalization
  • Add format detection (PDF, DOCX, etc.)
  • Implement quality checks
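
The format detection subtask is not covered by the implementation that follows. A minimal sketch of one possible approach, based on file signatures, is given here. `detect_format` and its return labels are hypothetical; the signatures themselves are standard (PDF files begin with `%PDF-`, and DOCX files are ZIP archives that contain `word/document.xml`).

```python
import io
import zipfile


def detect_format(data: bytes) -> str:
    """Guess a file format from its leading bytes (hypothetical helper)."""
    if data.startswith(b"%PDF-"):
        return "pdf"
    if data.startswith(b"PK\x03\x04"):
        # DOCX, XLSX and PPTX are all ZIP containers; look inside to tell
        # a Word document apart from a generic archive.
        try:
            with zipfile.ZipFile(io.BytesIO(data)) as archive:
                if "word/document.xml" in archive.namelist():
                    return "docx"
        except zipfile.BadZipFile:
            pass
        return "zip"
    try:
        data.decode("utf-8")
        return "text"
    except UnicodeDecodeError:
        return "unknown"
```

Checking content rather than the file extension avoids misrouting renamed uploads, at the cost of reading the file into memory first.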

Implementation:

src/preprocessing/text_processor.py

import re
import string
from typing import List, Dict, Optional, Tuple
import spacy
from langdetect import detect
from ..config.settings import ChunkingConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)

class TextProcessor:
    """Service for preprocessing and cleaning text"""
    
    def __init__(self, config: ChunkingConfig):
        self.config = config
        self.nlp = None
        self._load_spacy()
    
    def _load_spacy(self):
        """Load spaCy model for text processing"""
        try:
            self.nlp = spacy.load("en_core_web_sm")
            logger.info("spaCy model loaded successfully")
        except OSError:
            logger.warning("spaCy model not found, some features will be limited")
            self.nlp = None
    
    def clean_text(self, text: str) -> str:
        """
        Clean and normalize text
        
        Args:
            text: Raw text to clean
            
        Returns:
            Cleaned text
        """
        if not text or not text.strip():
            return ""
        
        # Collapse runs of spaces and tabs, but keep newlines so the
        # line-based patterns below still have something to match
        text = re.sub(r'[^\S\n]+', ' ', text)
        
        # Remove control characters
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]', '', text)
        
        # Normalize curly quotes to their ASCII equivalents
        text = re.sub(r'[\u201c\u201d]', '"', text)
        text = re.sub(r'[\u2018\u2019]', "'", text)
        
        # Remove page numbers and headers/footers patterns
        text = re.sub(r'\n\s*Page\s*\d+\s*\n', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'\n\s*\d+\s*\n', '\n', text)
        
        # Remove bullet points numbering
        text = re.sub(r'^\s*[\d\w][\).\.\-]\s+', '', text, flags=re.MULTILINE)
        
        # Clean up extra newlines
        text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
        
        return text.strip()
    
    def detect_language(self, text: str) -> str:
        """Detect text language"""
        try:
            if len(text) < 50:
                return "en"  # Default for short texts
            
            lang = detect(text)
            return lang if lang in ['en', 'pt', 'es', 'fr', 'de'] else "en"
        except Exception:
            return "en"
    
    def tokenize_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        if self.nlp:
            doc = self.nlp(text)
            return [sent.text.strip() for sent in doc.sents if sent.text.strip()]
        else:
            # Fallback using regex
            sentences = re.split(r'[.!?]+', text)
            return [s.strip() for s in sentences if s.strip()]
    
    def tokenize_paragraphs(self, text: str) -> List[str]:
        """Split text into paragraphs"""
        paragraphs = re.split(r'\n\s*\n', text)
        return [p.strip() for p in paragraphs if p.strip()]
    
    def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
        """Extract keywords from text"""
        if self.nlp:
            doc = self.nlp(text)
            
            # Extract noun phrases and proper nouns
            keywords = []
            
            # Add proper nouns
            for ent in doc.ents:
                if ent.label_ in ['PERSON', 'ORG', 'GPE', 'PRODUCT']:
                    keywords.append(ent.text)
            
            # Add noun chunks
            for chunk in doc.noun_chunks:
                if len(chunk.text.split()) <= 3:  # Keep short phrases
                    keywords.append(chunk.text)
            
            # Remove duplicates and limit
            unique_keywords = list(dict.fromkeys(keywords))
            return unique_keywords[:max_keywords]
        else:
            # Fallback: extract important words
            words = re.findall(r'\b[A-Z][a-z]+\b', text)
            return list(dict.fromkeys(words))[:max_keywords]
    
    def extract_math_expressions(self, text: str) -> List[str]:
        """Extract mathematical expressions from text"""
        # Common math patterns
        math_patterns = [
            r'\$[^$]+\$',  # LaTeX math mode
            r'\\[a-zA-Z]+\{[^}]+\}',  # LaTeX commands
            r'\b[a-zA-Z]+\s*=\s*[^,;\n]+',  # Equations
            r'\b(?:sin|cos|tan|log|ln|sqrt)\([^)]+\)',  # Functions
            r'\b\d+\s*[+\-*/]\s*\d+',  # Simple arithmetic
        ]
        
        expressions = []
        for pattern in math_patterns:
            matches = re.findall(pattern, text)
            expressions.extend(matches)
        
        return expressions
    
    def extract_code_blocks(self, text: str) -> List[str]:
        """Extract code blocks from text"""
        code_patterns = [
            r'```[\s\S]*?```',  # Markdown code blocks
            r'`[^`]+`',  # Inline code
            r'(?s)def\s+\w+\([^)]*\):.*?(?=\n\w|\Z)',  # Python functions
        ]
        
        code_blocks = []
        for pattern in code_patterns:
            matches = re.findall(pattern, text)
            code_blocks.extend(matches)
        
        return code_blocks
    
    def assess_readability(self, text: str) -> Dict:
        """Assess text readability metrics"""
        sentences = self.tokenize_sentences(text)
        words = text.split()
        syllables = sum(self._count_syllables(word) for word in words)
        
        if len(sentences) == 0 or len(words) == 0:
            return {"flesch_score": 0, "grade_level": 12}
        
        # Flesch Reading Ease
        avg_sentence_length = len(words) / len(sentences)
        avg_syllables_per_word = syllables / len(words)
        
        flesch_score = 206.835 - (1.015 * avg_sentence_length) - (84.6 * avg_syllables_per_word)
        
        # Approximate grade level
        if flesch_score >= 90:
            grade_level = 5
        elif flesch_score >= 80:
            grade_level = 6
        elif flesch_score >= 70:
            grade_level = 7
        elif flesch_score >= 60:
            grade_level = 8
        elif flesch_score >= 50:
            grade_level = 9
        elif flesch_score >= 40:
            grade_level = 10
        elif flesch_score >= 30:
            grade_level = 11
        else:
            grade_level = 12
        
        return {
            "flesch_score": max(0, min(100, flesch_score)),
            "grade_level": grade_level,
            "avg_sentence_length": avg_sentence_length,
            "avg_syllables_per_word": avg_syllables_per_word,
        }
    
    def _count_syllables(self, word: str) -> int:
        """Count syllables in a word (simplified)"""
        word = word.lower()
        vowels = "aeiouy"
        syllable_count = 0
        prev_char_was_vowel = False
        
        for char in word:
            is_vowel = char in vowels
            if is_vowel and not prev_char_was_vowel:
                syllable_count += 1
            prev_char_was_vowel = is_vowel
        
        # Adjust for silent 'e'
        if word.endswith('e') and syllable_count > 1:
            syllable_count -= 1
        
        return max(1, syllable_count)
    
    def extract_structure(self, text: str) -> Dict:
        """Extract document structure information"""
        structure = {
            "headings": [],
            "lists": [],
            "tables": [],
            "sections": [],
        }
        
        # Extract headings (markdown-style)
        heading_pattern = r'^(#{1,6})\s+(.+)$'
        for match in re.finditer(heading_pattern, text, re.MULTILINE):
            level = len(match.group(1))
            title = match.group(2).strip()
            structure["headings"].append({
                "level": level,
                "title": title,
                "position": match.start(),
            })
        
        # Extract lists
        list_patterns = [
            r'^\s*[\-\*\+]\s+(.+)$',  # Bullet lists
            r'^\s*\d+\.\s+(.+)$',     # Numbered lists
        ]
        
        for pattern in list_patterns:
            for match in re.finditer(pattern, text, re.MULTILINE):
                structure["lists"].append({
                    "content": match.group(1).strip(),
                    "position": match.start(),
                })
        
        # Extract sections based on headings
        if structure["headings"]:
            for i, heading in enumerate(structure["headings"]):
                start_pos = heading["position"]
                end_pos = (structure["headings"][i + 1]["position"] 
                          if i + 1 < len(structure["headings"]) 
                          else len(text))
                
                section_text = text[start_pos:end_pos].strip()
                structure["sections"].append({
                    "heading": heading,
                    "content": section_text,
                    "word_count": len(section_text.split()),
                })
        
        return structure
    
    def validate_text_quality(self, text: str) -> Dict:
        """Validate text quality and return metrics"""
        if not text or len(text.strip()) < 50:
            return {
                "is_valid": False,
                "reason": "Text too short",
                "score": 0.0,
            }
        
        # Quality metrics
        word_count = len(text.split())
        sentence_count = len(self.tokenize_sentences(text))
        
        # Check for minimum requirements
        if word_count < 10:
            return {
                "is_valid": False,
                "reason": "Too few words",
                "score": 0.1,
            }
        
        if sentence_count < 2:
            return {
                "is_valid": False,
                "reason": "Too few sentences",
                "score": 0.2,
            }
        
        # Calculate quality score
        readability = self.assess_readability(text)
        structure = self.extract_structure(text)
        
        quality_score = 0.5  # Base score
        
        # Readability bonus
        if readability["flesch_score"] > 60:
            quality_score += 0.2
        elif readability["flesch_score"] > 40:
            quality_score += 0.1
        
        # Structure bonus
        if structure["headings"]:
            quality_score += 0.1
        
        # Length bonus (appropriate length)
        if 50 <= word_count <= 500:
            quality_score += 0.1
        elif 500 < word_count <= 1000:
            quality_score += 0.05
        
        # Content variety bonus
        has_math = bool(self.extract_math_expressions(text))
        has_code = bool(self.extract_code_blocks(text))
        if has_math or has_code:
            quality_score += 0.05
        
        quality_score = min(1.0, quality_score)
        
        return {
            "is_valid": quality_score >= 0.3,
            "reason": "Quality check passed" if quality_score >= 0.3 else "Low quality",
            "score": quality_score,
            "metrics": {
                "word_count": word_count,
                "sentence_count": sentence_count,
                "readability": readability,
                "has_structure": bool(structure["headings"]),
                "has_math": has_math,
                "has_code": has_code,
            }
        }
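
To make the readability formula concrete, here is a small standalone sketch of the Flesch Reading Ease computation, using the same constants and the same vowel-group syllable heuristic as `assess_readability`. The function names are hypothetical, and unlike the class above it extracts words with a regular expression instead of `text.split()`.

```python
import re


def count_syllables(word: str) -> int:
    """Vowel-group heuristic, similar to the simplified counter above."""
    word = word.lower()
    groups = len(re.findall(r"[aeiouy]+", word))
    if word.endswith("e") and groups > 1:
        groups -= 1  # treat a trailing 'e' as silent
    return max(1, groups)


def flesch_reading_ease(text: str) -> float:
    """Unclamped Flesch Reading Ease score for English text."""
    sentences = [s for s in re.split(r"[.!?]+", text) if s.strip()]
    words = re.findall(r"[A-Za-z']+", text)
    if not sentences or not words:
        return 0.0
    syllables = sum(count_syllables(w) for w in words)
    words_per_sentence = len(words) / len(sentences)
    syllables_per_word = syllables / len(words)
    return 206.835 - 1.015 * words_per_sentence - 84.6 * syllables_per_word


score = flesch_reading_ease("The cat sat. The dog ran.")
```

For this input there are 6 words, 2 sentences, and 6 syllables, so the score is 206.835 - 1.015 × 3 - 84.6 × 1 = 119.19. Very simple text can exceed 100, which is why `assess_readability` clamps the reported value to the 0 to 100 range.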

Task 2.2: Intelligent Chunking

Priority: High
Estimated Time: 12 hours
Dependencies: Task 2.1

Subtasks:

  • Implement semantic chunking
  • Add respect for boundaries
  • Create overlap management
  • Build chunk validation
  • Add metadata extraction
  • Implement chunk quality scoring
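
As an illustration of the overlap management subtask, the sketch below groups sentences into chunks of roughly `chunk_size` words and starts each new chunk with whole sentences carried over from the end of the previous one. The function name and the rule for choosing the overlap (trailing sentences totalling at most `overlap` words) are assumptions for this example, not a description of the planned implementation.

```python
from typing import List


def sliding_window_chunks(
    sentences: List[str], chunk_size: int = 300, overlap: int = 50
) -> List[str]:
    """Group sentences into word-bounded chunks with sentence-level overlap."""
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0

    for sentence in sentences:
        n = len(sentence.split())
        if current and current_len + n > chunk_size:
            chunks.append(" ".join(current))
            # Carry over whole sentences from the tail, up to `overlap` words
            tail: List[str] = []
            tail_len = 0
            for prev in reversed(current):
                prev_len = len(prev.split())
                if tail_len + prev_len > overlap:
                    break
                tail.insert(0, prev)
                tail_len += prev_len
            current, current_len = tail, tail_len
        current.append(sentence)
        current_len += n

    if current:
        chunks.append(" ".join(current))
    return chunks


demo = sliding_window_chunks(
    ["a b c", "d e f", "g h i", "j k l"], chunk_size=6, overlap=3
)
```

Carrying over whole sentences keeps chunk boundaries readable. The trade-off is that a chunk can slightly exceed `chunk_size` when the carried-over tail plus the next sentence is longer than the limit.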

Implementation:

src/preprocessing/chunker.py

import re
from typing import List, Dict, Optional, Tuple
import numpy as np
from ..config.settings import ChunkingConfig
from ..models.chunk import Chunk, ChunkMetadata
from .text_processor import TextProcessor
from ..utils.logger import get_logger

logger = get_logger(__name__)

class IntelligentChunker:
    """Service for intelligent text chunking"""
    
    def __init__(self, config: ChunkingConfig):
        self.config = config
        self.text_processor = TextProcessor(config)
    
    def chunk_document(
        self, 
        text: str, 
        document_metadata: Dict,
        chunk_metadata: Optional[Dict] = None
    ) -> List[Chunk]:
        """
        Chunk a document into intelligent pieces
        
        Args:
            text: Document text
            document_metadata: Document-level metadata
            chunk_metadata: Default chunk metadata
            
        Returns:
            List of chunks
        """
        # Clean and preprocess text
        cleaned_text = self.text_processor.clean_text(text)
        
        # Extract document structure
        structure = self.text_processor.extract_structure(cleaned_text)
        
        # Determine chunking strategy
        chunks = []
        
        if structure["sections"] and len(structure["sections"]) > 1:
            # Use section-based chunking
            chunks = self._chunk_by_sections(
                cleaned_text, 
                structure, 
                document_metadata,
                chunk_metadata
            )
        else:
            # Use sliding window chunking
            chunks = self._chunk_by_sliding_window(
                cleaned_text,
                document_metadata,
                chunk_metadata
            )
        
        # Validate and filter chunks
        valid_chunks = []
        for chunk in chunks:
            validation = self._validate_chunk(chunk)
            if validation["is_valid"]:
                chunk.quality_score = validation["score"]
                valid_chunks.append(chunk)
            else:
                logger.warning(f"Invalid chunk: {validation['reason']}")
        
        logger.info(f"Created {len(valid_chunks)} valid chunks from document")
        return valid_chunks
    
    def _chunk_by_sections(
        self,
        text: str,
        structure: Dict,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk by document sections"""
        chunks = []
        
        for section in structure["sections"]:
            section_text = section["content"]
            section_heading = section["heading"]["title"]
            
            # If section is too large, further chunk it
            if self._is_too_large(section_text):
                section_chunks = self._chunk_large_section(
                    section_text,
                    section_heading,
                    document_metadata,
                    chunk_metadata
                )
                chunks.extend(section_chunks)
            else:
                # Create single chunk for section
                chunk = self._create_chunk(
                    section_text,
                    document_metadata,
                    chunk_metadata,
                    section_heading=section_heading
                )
                chunks.append(chunk)
        
        return chunks
    
    def _chunk_by_sliding_window(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk using sliding window approach"""
        chunks = []
        
        # Get sentences for boundary awareness
        sentences = self.text_processor.tokenize_sentences(text)
        
        if not sentences:
            return chunks
        
        # Build chunks with overlap
        current_chunk_sentences = []
        current_chunk_length = 0
        
        for i, sentence in enumerate(sentences):
            sentence_length = len(sentence.split())
            
            # Check if adding sentence exceeds chunk size
            if current_chunk_length + sentence_length > self.config.chunk_size and current_chunk_sentences:
                # Create chunk from accumulated sentences
                chunk_text = " ".join(current_chunk_sentences)
                chunk = self._create_chunk(
                    chunk_text,
                    document_metadata,
                    chunk_metadata
                )
                chunks.append(chunk)
                
                # Start new chunk with overlap
                overlap_sentences = self._get_overlap_sentences(current_chunk_sentences)
                current_chunk_sentences = overlap_sentences + [sentence]
                current_chunk_length = sum(len(s.split()) for s in current_chunk_sentences)
            else:
                current_chunk_sentences.append(sentence)
                current_chunk_length += sentence_length
        
        # Add final chunk if it has content
        if current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata
            )
            chunks.append(chunk)
        
        return chunks
    
    def _chunk_large_section(
        self,
        section_text: str,
        section_heading: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk a large section into smaller pieces"""
        chunks = []
        
        # Split section into paragraphs
        paragraphs = self.text_processor.tokenize_paragraphs(section_text)
        
        current_chunk_paragraphs = []
        current_chunk_length = 0
        
        for paragraph in paragraphs:
            paragraph_length = len(paragraph.split())
            
            # Check if paragraph is too large for a single chunk
            if paragraph_length > self.config.max_chunk_size:
                # Process current chunk if it has content
                if current_chunk_paragraphs:
                    chunk_text = "\n\n".join(current_chunk_paragraphs)
                    chunk = self._create_chunk(
                        chunk_text,
                        document_metadata,
                        chunk_metadata,
                        section_heading=section_heading
                    )
                    chunks.append(chunk)
                    current_chunk_paragraphs = []
                    current_chunk_length = 0
                
                # Chunk the large paragraph
                paragraph_chunks = self._chunk_large_paragraph(
                    paragraph,
                    document_metadata,
                    chunk_metadata,
                    section_heading
                )
                chunks.extend(paragraph_chunks)
            else:
                # Check if adding paragraph exceeds chunk size
                if current_chunk_length + paragraph_length > self.config.chunk_size and current_chunk_paragraphs:
                    # Create chunk
                    chunk_text = "\n\n".join(current_chunk_paragraphs)
                    chunk = self._create_chunk(
                        chunk_text,
                        document_metadata,
                        chunk_metadata,
                        section_heading=section_heading
                    )
                    chunks.append(chunk)
                    
                    # Start new chunk with overlap
                    overlap_paragraphs = self._get_overlap_paragraphs(current_chunk_paragraphs)
                    current_chunk_paragraphs = overlap_paragraphs + [paragraph]
                    current_chunk_length = sum(len(p.split()) for p in current_chunk_paragraphs)
                else:
                    current_chunk_paragraphs.append(paragraph)
                    current_chunk_length += paragraph_length
        
        # Add final chunk
        if current_chunk_paragraphs:
            chunk_text = "\n\n".join(current_chunk_paragraphs)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata,
                section_heading=section_heading
            )
            chunks.append(chunk)
        
        return chunks
    
    def _chunk_large_paragraph(
        self,
        paragraph: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict],
        section_heading: Optional[str] = None
    ) -> List[Chunk]:
        """Chunk a very large paragraph"""
        chunks = []
        
        # Split by sentences
        sentences = self.text_processor.tokenize_sentences(paragraph)
        
        current_chunk_sentences = []
        current_chunk_length = 0
        
        for sentence in sentences:
            sentence_length = len(sentence.split())
            
            if current_chunk_length + sentence_length > self.config.chunk_size and current_chunk_sentences:
                # Create chunk
                chunk_text = " ".join(current_chunk_sentences)
                chunk = self._create_chunk(
                    chunk_text,
                    document_metadata,
                    chunk_metadata,
                    section_heading=section_heading
                )
                chunks.append(chunk)
                
                # Start new chunk with overlap
                overlap_sentences = self._get_overlap_sentences(current_chunk_sentences)
                current_chunk_sentences = overlap_sentences + [sentence]
                current_chunk_length = sum(len(s.split()) for s in current_chunk_sentences)
            else:
                current_chunk_sentences.append(sentence)
                current_chunk_length += sentence_length
        
        # Add final chunk
        if current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata,
                section_heading=section_heading
            )
            chunks.append(chunk)
        
        return chunks
    
    def _get_overlap_sentences(self, sentences: List[str]) -> List[str]:
        """Get overlap sentences for next chunk"""
        if not sentences:
            return []
        
        # Calculate overlap based on word count
        total_words = sum(len(s.split()) for s in sentences)
        overlap_words = min(self.config.chunk_overlap, total_words // 2)
        
        # Get sentences from the end that contain the overlap words
        overlap_sentences = []
        word_count = 0
        
        for sentence in reversed(sentences):
            sentence_words = len(sentence.split())
            if word_count + sentence_words <= overlap_words:
                overlap_sentences.insert(0, sentence)
                word_count += sentence_words
            else:
                break
        
        return overlap_sentences
    
    def _get_overlap_paragraphs(self, paragraphs: List[str]) -> List[str]:
        """Get overlap paragraphs for next chunk"""
        if not paragraphs:
            return []
        
        # Take last 1-2 paragraphs for overlap
        overlap_count = min(2, len(paragraphs) // 2)
        return paragraphs[-overlap_count:] if overlap_count > 0 else []
    
    def _is_too_large(self, text: str) -> bool:
        """Check if text is too large for a single chunk"""
        word_count = len(text.split())
        return word_count > self.config.chunk_size
    
    def _create_chunk(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict],
        section_heading: Optional[str] = None
    ) -> Chunk:
        """Create a chunk with metadata"""
        
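        # Generate a deterministic chunk ID. The built-in hash() is salted per
        # process for strings, so it would produce different IDs on every run.
        import hashlib  # local import keeps this method self-contained
        chunk_id = f"chunk_{hashlib.sha1(text.encode('utf-8')).hexdigest()[:12]}"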
        
        # Extract metadata from text
        keywords = self.text_processor.extract_keywords(text)
        math_expressions = self.text_processor.extract_math_expressions(text)
        code_blocks = self.text_processor.extract_code_blocks(text)
        readability = self.text_processor.assess_readability(text)
        
        # Create chunk metadata
        metadata = ChunkMetadata(
            word_count=len(text.split()),
            character_count=len(text),
            sentence_count=len(self.text_processor.tokenize_sentences(text)),
            paragraph_count=len(self.text_processor.tokenize_paragraphs(text)),
            keywords=keywords,
            math_expressions=math_expressions,
            code_blocks=code_blocks,
            readability_score=readability["flesch_score"],
            grade_level=readability["grade_level"],
            language=self.text_processor.detect_language(text),
            has_structure=bool(section_heading),
            section_heading=section_heading,
            **(chunk_metadata or {})
        )
        
        # Create chunk
        chunk = Chunk(
            id=chunk_id,
            text=text,
            document_metadata=document_metadata,
            metadata=metadata
        )
        
        return chunk
    
    def _validate_chunk(self, chunk: Chunk) -> Dict:
        """Validate chunk quality"""
        
        # Check minimum length
        if chunk.metadata.word_count < self.config.min_chunk_size:
            return {
                "is_valid": False,
                "reason": f"Chunk too short: {chunk.metadata.word_count} words",
                "score": 0.1,
            }
        
        # Check maximum length
        if chunk.metadata.word_count > self.config.max_chunk_size:
            return {
                "is_valid": False,
                "reason": f"Chunk too long: {chunk.metadata.word_count} words",
                "score": 0.1,
            }
        
        # Check for meaningful content
        if not chunk.text.strip() or len(chunk.text.strip()) < 20:
            return {
                "is_valid": False,
                "reason": "Chunk contains insufficient content",
                "score": 0.0,
            }
        
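        # Calculate quality score
        # NOTE: the base of 0.5 already exceeds the 0.3 validity threshold used
        # below, so any chunk that reaches this point is reported as valid
        quality_score = 0.5  # Base score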
        
        # Length appropriateness
        optimal_length = self.config.chunk_size
        length_diff = abs(chunk.metadata.word_count - optimal_length)
        length_score = max(0, 1 - (length_diff / optimal_length))
        quality_score += length_score * 0.3
        
        # Readability
        if chunk.metadata.readability_score > 60:
            quality_score += 0.1
        elif chunk.metadata.readability_score > 40:
            quality_score += 0.05
        
        # Content richness
        if chunk.metadata.keywords:
            quality_score += 0.05
        
        if chunk.metadata.math_expressions or chunk.metadata.code_blocks:
            quality_score += 0.05
        
        # Structure
        if chunk.metadata.section_heading:
            quality_score += 0.05
        
        # Sentence completeness
        if chunk.metadata.sentence_count >= 2:
            quality_score += 0.05
        
        quality_score = min(1.0, quality_score)
        
        return {
            "is_valid": quality_score >= 0.3,
            "reason": "Quality check passed" if quality_score >= 0.3 else "Low quality score",
            "score": quality_score,
        }
    
    def get_chunking_stats(self, chunks: List[Chunk]) -> Dict:
        """Get statistics about chunking results"""
        if not chunks:
            return {}
        
        word_counts = [chunk.metadata.word_count for chunk in chunks]
        quality_scores = [chunk.quality_score for chunk in chunks]
        
        return {
            "total_chunks": len(chunks),
            "avg_word_count": np.mean(word_counts),
            "min_word_count": min(word_counts),
            "max_word_count": max(word_counts),
            "avg_quality_score": np.mean(quality_scores),
            "min_quality_score": min(quality_scores),
            "max_quality_score": max(quality_scores),
            "total_words": sum(word_counts),
            "chunks_with_headings": sum(1 for c in chunks if c.metadata.section_heading),
            "chunks_with_math": sum(1 for c in chunks if c.metadata.math_expressions),
            "chunks_with_code": sum(1 for c in chunks if c.metadata.code_blocks),
        }
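
The sliding-window strategy used by the chunker can be condensed into a standalone sketch. It assumes the sentences are already tokenized and measures size in whitespace-separated words, as the chunker does. The function name `sliding_window_chunks`, its parameters, and the sample sentences are illustrative only, and the `total_words // 2` cap applied in `_get_overlap_sentences` is left out for brevity.

```python
from typing import List


def sliding_window_chunks(
    sentences: List[str], chunk_size: int, chunk_overlap: int
) -> List[str]:
    """Group sentences into chunks of about chunk_size words with overlap."""
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0

    for sentence in sentences:
        n_words = len(sentence.split())
        if current and current_len + n_words > chunk_size:
            # Close the current chunk
            chunks.append(" ".join(current))

            # Carry trailing sentences forward, up to chunk_overlap words
            overlap: List[str] = []
            carried = 0
            for prev in reversed(current):
                prev_words = len(prev.split())
                if carried + prev_words > chunk_overlap:
                    break
                overlap.insert(0, prev)
                carried += prev_words

            current = overlap + [sentence]
            current_len = sum(len(s.split()) for s in current)
        else:
            current.append(sentence)
            current_len += n_words

    if current:
        chunks.append(" ".join(current))
    return chunks


sentences = [
    "Vectors encode meaning.",
    "Similar texts land close together.",
    "Search finds the nearest neighbours.",
    "Chunks keep context small.",
]
chunks = sliding_window_chunks(sentences, chunk_size=8, chunk_overlap=5)
for chunk in chunks:
    print(chunk)
```

Each chunk after the first starts with the last sentence of the previous one, so a passage that straddles a boundary is still retrievable from at least one chunk. Because the overlap is counted in whole sentences, a chunk can exceed `chunk_size` slightly, as it can in the chunker itself.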

🔎 WEEK 5-6: RETRIEVAL SYSTEM

Task 3.1: Vector Search Implementation

Priority: High
Estimated Time: 14 hours
Dependencies: Task 2.2

Subtasks:

  • Implement FAISS vector store
  • Create indexing pipeline
  • Add similarity search
  • Implement batch operations
  • Add index optimization
  • Create search performance monitoring
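
As a rough illustration of what the similarity-search subtask has to compute, the sketch below runs an exact cosine search in plain Python, without FAISS. After L2 normalization the inner product of two vectors equals their cosine similarity, which is why an inner-product index is usually fed normalized embeddings. `l2_normalize`, `exact_search`, and the sample vectors are hypothetical names and data chosen for this example.

```python
import math
from typing import Dict, List, Tuple


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def exact_search(
    query: List[float], vectors: Dict[str, List[float]], top_k: int = 10
) -> List[Tuple[str, float]]:
    """Brute-force search: score every stored vector and keep the best top_k."""
    q = l2_normalize(query)
    scored = [
        (chunk_id, sum(a * b for a, b in zip(q, l2_normalize(vec))))
        for chunk_id, vec in vectors.items()
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


store = {
    "chunk_a": [1.0, 0.0],
    "chunk_b": [0.0, 2.0],
    "chunk_c": [3.0, 3.0],
}
results = exact_search([2.0, 0.0], store, top_k=2)
print(results)
```

A flat index performs this same exhaustive comparison, only vectorized. IVF and HNSW indexes trade some recall for speed by scoring a subset of the stored vectors.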

Implementation:

src/core/vector_store.py

import os
import pickle
import numpy as np
import faiss
from typing import List, Dict, Tuple, Optional, Any
from pathlib import Path
from ..config.settings import VectorStoreConfig
from ..models.chunk import Chunk
from ..utils.logger import get_logger

logger = get_logger(__name__)

class VectorStore:
    """FAISS-based vector store for efficient similarity search"""
    
    def __init__(self, config: VectorStoreConfig, index_path: str = "data/indices"):
        self.config = config
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)
        
        self.index = None
        self.chunk_mapping = {}  # Maps index position to chunk ID
        self.embedding_dim = config.dimension
        
        self._initialize_index()
    
    def _initialize_index(self):
        """Initialize FAISS index based on configuration"""
        try:
            if self.config.index_type == "IVF":
                # Inverted File Index
                quantizer = faiss.IndexFlatIP(self.embedding_dim)
                self.index = faiss.IndexIVFFlat(
                    quantizer, 
                    self.embedding_dim, 
                    self.config.nlist, 
                    faiss.METRIC_INNER_PRODUCT
                )
                logger.info(f"Created IVF index with {self.config.nlist} clusters")
                
            elif self.config.index_type == "HNSW":
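                # Hierarchical Navigable Small World (32 neighbours per node).
                # Pass the metric explicitly: IndexHNSWFlat defaults to L2,
                # while the rest of this class assumes inner-product scores.
                self.index = faiss.IndexHNSWFlat(
                    self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT
                )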
                logger.info("Created HNSW index")
                
            else:
                # Flat Index (exact search)
                self.index = faiss.IndexFlatIP(self.embedding_dim)
                logger.info("Created Flat index")
            
            # Move to GPU if available and requested
            if self.config.use_gpu and faiss.get_num_gpus() > 0:
                try:
                    res = faiss.StandardGpuResources()
                    self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
                    logger.info("Index moved to GPU")
                except Exception as e:
                    logger.warning(f"Failed to move index to GPU: {e}")
            
            # Try to load existing index
            self._load_index()
            
        except Exception as e:
            logger.error(f"Failed to initialize index: {e}")
            raise
    
    def add_embeddings(self, embeddings: np.ndarray, chunk_ids: List[str]):
        """
        Add embeddings to the index
        
        Args:
            embeddings: Array of embeddings to add
            chunk_ids: Corresponding chunk IDs
        """
        if len(embeddings) != len(chunk_ids):
            raise ValueError("Number of embeddings must match number of chunk IDs")
        
        if len(embeddings) == 0:
            return
        
        try:
            # Ensure embeddings are the right shape and type
            embeddings = embeddings.astype(np.float32)
            
            # Normalize embeddings for inner product similarity
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(embeddings)
            
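            # Get current index size
            current_size = self.index.ntotal
            
            # IVF indexes must be trained before vectors can be added. The
            # first batch is used here, so it should hold at least nlist vectors.
            if not getattr(self.index, 'is_trained', True):
                self.index.train(embeddings)
                logger.info(f"Index trained with {len(embeddings)} embeddings")
            
            # Add embeddings to index
            self.index.add(embeddings)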
            
            # Update chunk mapping
            for i, chunk_id in enumerate(chunk_ids):
                self.chunk_mapping[current_size + i] = chunk_id
            
            logger.info(f"Added {len(embeddings)} embeddings to index (total: {self.index.ntotal})")
            
        except Exception as e:
            logger.error(f"Failed to add embeddings: {e}")
            raise
    
    def search(
        self, 
        query_embedding: np.ndarray, 
        top_k: int = 10,
        nprobe: Optional[int] = None
    ) -> List[Tuple[str, float]]:
        """
        Search for similar embeddings
        
        Args:
            query_embedding: Query embedding
            top_k: Number of results to return
            nprobe: Number of clusters to probe (for IVF index)
            
        Returns:
            List of (chunk_id, similarity_score) tuples
        """
        if self.index.ntotal == 0:
            return []
        
        try:
            # Prepare query
            query_embedding = query_embedding.astype(np.float32).reshape(1, -1)
            
            # Normalize for inner product
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(query_embedding)
            
            # Set nprobe for IVF index
            if nprobe and hasattr(self.index, 'nprobe'):
                self.index.nprobe = nprobe
            elif hasattr(self.index, 'nprobe') and self.config.nprobe:
                self.index.nprobe = self.config.nprobe
            
            # Search
            similarities, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
            
            # Convert to chunk IDs and scores
            results = []
            for similarity, idx in zip(similarities[0], indices[0]):
                if idx >= 0 and idx < len(self.chunk_mapping):  # Valid index
                    chunk_id = self.chunk_mapping[idx]
                    results.append((chunk_id, float(similarity)))
            
            return results
            
        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []
    
    def batch_search(
        self, 
        query_embeddings: np.ndarray, 
        top_k: int = 10,
        nprobe: Optional[int] = None
    ) -> List[List[Tuple[str, float]]]:
        """
        Batch search for multiple queries
        
        Args:
            query_embeddings: Array of query embeddings
            top_k: Number of results per query
            nprobe: Number of clusters to probe
            
        Returns:
            List of result lists, one per query
        """
        if self.index.ntotal == 0:
            return [[] for _ in range(len(query_embeddings))]
        
        try:
            # Prepare queries
            query_embeddings = query_embeddings.astype(np.float32)
            
            # Normalize for inner product
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(query_embeddings)
            
            # Set nprobe for IVF index
            if nprobe and hasattr(self.index, 'nprobe'):
                self.index.nprobe = nprobe
            elif hasattr(self.index, 'nprobe') and self.config.nprobe:
                self.index.nprobe = self.config.nprobe
            
            # Search
            similarities, indices = self.index.search(
                query_embeddings, 
                min(top_k, self.index.ntotal)
            )
            
            # Convert results
            all_results = []
            for sim_row, idx_row in zip(similarities, indices):
                query_results = []
                for similarity, idx in zip(sim_row, idx_row):
                    if idx >= 0 and idx < len(self.chunk_mapping):
                        chunk_id = self.chunk_mapping[idx]
                        query_results.append((chunk_id, float(similarity)))
                all_results.append(query_results)
            
            return all_results
            
        except Exception as e:
            logger.error(f"Batch search failed: {e}")
            return [[] for _ in range(len(query_embeddings))]
    
    def remove_embeddings(self, chunk_ids: List[str]):
        """
        Remove embeddings from index (rebuilds index)
        
        Args:
            chunk_ids: Chunk IDs to remove
        """
        if not chunk_ids:
            return
        
        try:
            # Not every FAISS index type supports in-place removal (HNSW, for
            # example, does not), so this MVP falls back to rebuilding.
            logger.info("Rebuilding index without specified chunks...")
            
            removed = set(chunk_ids)
            remaining_chunk_ids = [
                chunk_id
                for chunk_id in self.chunk_mapping.values()
                if chunk_id not in removed
            ]
            
            # NOTE: the raw embeddings are not stored anywhere, so the rebuild
            # cannot re-add the remaining chunks. A production version would
            # keep all embeddings externally and re-index them here.
            self._rebuild_index(remaining_chunk_ids)
            
            logger.warning(
                f"Index cleared; {len(remaining_chunk_ids)} remaining chunks must be re-added"
            )
            
        except Exception as e:
            logger.error(f"Failed to remove embeddings: {e}")
            raise
    
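    def _rebuild_index(self, chunk_ids: List[str]):
        """Rebuild index with specified chunks"""
        # Re-adding the chunks would require storing all embeddings externally.
        # For the MVP we just clear the index. reset() is used rather than
        # _initialize_index() because the latter reloads any saved index from
        # disk, which would leave stale vectors behind an empty mapping.
        self.index.reset()
        self.chunk_mapping = {}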
    
    def save_index(self, name: str = "default"):
        """Save index to disk"""
        try:
            index_file = self.index_path / f"{name}.index"
            mapping_file = self.index_path / f"{name}_mapping.pkl"
            
            # Save index. GPU indexes must be copied back to the CPU first;
            # FAISS GPU index classes are named Gpu* (e.g. GpuIndexFlat).
            if type(self.index).__name__.startswith("Gpu"):
                cpu_index = faiss.index_gpu_to_cpu(self.index)
                faiss.write_index(cpu_index, str(index_file))
            else:
                faiss.write_index(self.index, str(index_file))
            
            # Save mapping
            with open(mapping_file, 'wb') as f:
                pickle.dump(self.chunk_mapping, f)
            
            logger.info(f"Index saved to {index_file}")
            
        except Exception as e:
            logger.error(f"Failed to save index: {e}")
            raise
    
    def _load_index(self, name: str = "default"):
        """Load index from disk"""
        try:
            index_file = self.index_path / f"{name}.index"
            mapping_file = self.index_path / f"{name}_mapping.pkl"
            
            if index_file.exists() and mapping_file.exists():
                # Load index
                self.index = faiss.read_index(str(index_file))
                
                # Move to GPU if needed
                if self.config.use_gpu and faiss.get_num_gpus() > 0:
                    try:
                        res = faiss.StandardGpuResources()
                        self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
                    except Exception as e:
                        logger.warning(f"Failed to move loaded index to GPU: {e}")
                
                # Load mapping
                with open(mapping_file, 'rb') as f:
                    self.chunk_mapping = pickle.load(f)
                
                logger.info(f"Loaded index with {self.index.ntotal} embeddings")
                return True
            else:
                logger.info("No existing index found, starting with empty index")
                return False
                
        except Exception as e:
            logger.warning(f"Failed to load index: {e}")
            return False
    
    def get_stats(self) -> Dict:
        """Get index statistics"""
        return {
            "total_embeddings": self.index.ntotal,
            "index_type": self.config.index_type,
            "dimension": self.embedding_dim,
            "metric": self.config.metric,
            "is_trained": getattr(self.index, 'is_trained', True),
            "nlist": getattr(self.index, 'nlist', None),
            "use_gpu": self.config.use_gpu and faiss.get_num_gpus() > 0,
        }
    
    def train_index(self, embeddings: np.ndarray):
        """Train index (required for some index types)"""
        if hasattr(self.index, 'train') and not getattr(self.index, 'is_trained', True):
            try:
                embeddings = embeddings.astype(np.float32)
                if self.config.metric == "INNER_PRODUCT":
                    faiss.normalize_L2(embeddings)
                
                self.index.train(embeddings)
                logger.info(f"Index trained with {len(embeddings)} embeddings")
            except Exception as e:
                logger.error(f"Failed to train index: {e}")
                raise
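
The comments in `remove_embeddings` note that a proper rebuild needs the embeddings stored outside the index. A minimal sketch of that idea follows, in plain Python with no FAISS dependency. `EmbeddingRegistry` and its methods are hypothetical and not part of the planned module; the rows and mapping it returns are shaped so they could be passed to something like `add_embeddings` on a freshly cleared index.

```python
from typing import Dict, Iterable, List, Sequence, Tuple


class EmbeddingRegistry:
    """Hypothetical helper that keeps raw vectors so an index can be rebuilt."""

    def __init__(self) -> None:
        # Insertion-ordered: chunk_id -> vector
        self._vectors: Dict[str, List[float]] = {}

    def add(self, chunk_id: str, vector: Sequence[float]) -> None:
        self._vectors[chunk_id] = list(vector)

    def remove(self, chunk_ids: Iterable[str]) -> None:
        for chunk_id in chunk_ids:
            self._vectors.pop(chunk_id, None)  # ignore unknown IDs

    def rebuild(self) -> Tuple[List[List[float]], Dict[int, str]]:
        """Return the surviving rows and a fresh position -> chunk_id mapping."""
        rows: List[List[float]] = []
        mapping: Dict[int, str] = {}
        for position, (chunk_id, vector) in enumerate(self._vectors.items()):
            rows.append(vector)
            mapping[position] = chunk_id
        return rows, mapping


registry = EmbeddingRegistry()
registry.add("chunk_a", [0.1, 0.2])
registry.add("chunk_b", [0.3, 0.4])
registry.add("chunk_c", [0.5, 0.6])

registry.remove(["chunk_b"])
rows, mapping = registry.rebuild()
print(rows, mapping)
```

Positions are renumbered from zero on every rebuild, which matches how the vector store assigns positions as `ntotal + i` when vectors are appended to an empty index.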

Task 3.2: Hybrid Retrieval System

Priority: High
Estimated Time: 12 hours
Dependencies: Task 3.1

Subtasks:

  • Implement keyword search (BM25)
  • Create vector similarity search
  • Build hybrid ranking algorithm
  • Add metadata filtering
  • Implement result fusion
  • Optimize retrieval performance
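
To make the BM25 and result-fusion subtasks concrete, here is a small self-contained sketch. It uses the common Okapi BM25 formula with a non-negative IDF variant, then min-max normalizes each score set and blends them with a weight `alpha` (0 = keyword only, 1 = vector only). The function names, the defaults `k1=1.5` and `b=0.75`, the toy documents, and the vector scores are all assumptions made for illustration, not the project's `KeywordSearcher` API.

```python
import math
from collections import Counter
from typing import Dict, List, Tuple


def bm25_scores(
    query_terms: List[str],
    docs: Dict[str, List[str]],
    k1: float = 1.5,
    b: float = 0.75,
) -> Dict[str, float]:
    """Score every tokenized document against the query terms with BM25."""
    if not docs:
        return {}
    n_docs = len(docs)
    avg_len = sum(len(tokens) for tokens in docs.values()) / n_docs

    # Document frequency: in how many documents each term appears
    doc_freq: Counter = Counter()
    for tokens in docs.values():
        doc_freq.update(set(tokens))

    scores: Dict[str, float] = {}
    for doc_id, tokens in docs.items():
        tf = Counter(tokens)
        score = 0.0
        for term in query_terms:
            if tf[term] == 0:
                continue
            idf = math.log(
                (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1.0
            )
            length_norm = k1 * (1 - b + b * len(tokens) / avg_len)
            score += idf * tf[term] * (k1 + 1) / (tf[term] + length_norm)
        scores[doc_id] = score
    return scores


def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to the 0-1 range."""
    if not scores:
        return {}
    low, high = min(scores.values()), max(scores.values())
    if high == low:
        return {key: 1.0 for key in scores}
    return {key: (value - low) / (high - low) for key, value in scores.items()}


def fuse(
    keyword: Dict[str, float], vector: Dict[str, float], alpha: float = 0.5
) -> List[Tuple[str, float]]:
    """Blend normalized scores; a chunk missing from one side gets 0 there."""
    kw, vec = min_max(keyword), min_max(vector)
    fused = {
        chunk_id: alpha * vec.get(chunk_id, 0.0) + (1 - alpha) * kw.get(chunk_id, 0.0)
        for chunk_id in set(kw) | set(vec)
    }
    return sorted(fused.items(), key=lambda pair: pair[1], reverse=True)


docs = {
    "c1": ["faiss", "index", "search"],
    "c2": ["keyword", "search", "bm25"],
    "c3": ["flutter", "pdf", "text"],
}
keyword_scores = bm25_scores(["bm25", "search"], docs)
vector_scores = {"c1": 0.9, "c2": 0.6}  # made-up similarity scores

ranking = fuse(keyword_scores, vector_scores, alpha=0.5)
keyword_only = fuse(keyword_scores, vector_scores, alpha=0.0)
print(ranking)
```

Normalizing before blending matters because BM25 scores are unbounded while cosine similarities lie between -1 and 1; without it, one side would dominate regardless of `alpha`.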

Implementation:

src/retrieval/hybrid_search.py

import numpy as np
from typing import List, Dict, Tuple, Optional, Any
from collections import Counter, defaultdict
import math
from ..core.vector_store import VectorStore
from ..core.embeddings import EmbeddingService
from ..models.chunk import Chunk
from ..models.query import Query, QueryResult
from .keyword_search import KeywordSearcher
from .ranker import ResultRanker
from ..utils.logger import get_logger

logger = get_logger(__name__)

class HybridSearcher:
    """Hybrid search combining keyword and vector search"""
    
    def __init__(
        self,
        vector_store: VectorStore,
        embedding_service: EmbeddingService,
        keyword_searcher: KeywordSearcher,
        ranker: ResultRanker
    ):
        self.vector_store = vector_store
        self.embedding_service = embedding_service
        self.keyword_searcher = keyword_searcher
        self.ranker = ranker
    
    def search(
        self,
        query: Query,
        chunks: Dict[str, Chunk],  # chunk_id -> Chunk mapping
        top_k: int = 10,
        alpha: float = 0.5,  # Weight for hybrid combination
        rerank: bool = True
    ) -> QueryResult:
        """
        Perform hybrid search
        
        Args:
            query: Query object
            chunks: Mapping of chunk IDs to Chunk objects
            top_k: Number of results to return
            alpha: Weight for combining results (0=keyword only, 1=vector only)
            rerank: Whether to apply reranking
            
        Returns:
            QueryResult object
        """
        try:
            logger.info(f"Performing hybrid search for query: {query.text[:50]}...")
            
            # Step 1: Keyword search
            keyword_results = self._perform_keyword_search(query, chunks)
            
            # Step 2: Vector search
            vector_results = self._perform_vector_search(query, top_k * 2)
            
            # Step 3: Combine results
            combined_results = self._combine_results(
                keyword_results,
                vector_results,
                chunks,
                alpha
            )
            
            # Step 4: Apply metadata filtering
            filtered_results = self._apply_metadata_filters(
                combined_results,
                chunks,
                query.filters
            )
            
            # Step 5: Rerank if requested
            if rerank and len(filtered_results) > 1:
                filtered_results = self.ranker.rerank(
                    query,
                    filtered_results,
                    chunks
                )
            
            # Step 6: Limit to top_k
            final_results = filtered_results[:top_k]
            
            # Create result object
            result = QueryResult(
                query=query,
                results=final_results,
                total_found=len(combined_results),
                keyword_results_count=len(keyword_results),
                vector_results_count=len(vector_results),
                alpha=alpha,
                reranked=rerank
            )
            
            logger.info(f"Search completed: {len(final_results)} results")
            return result
            
        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            raise
    
    def _perform_keyword_search(self, query: Query, chunks: Dict[str, Chunk]) -> List[Tuple[str, float]]:
        """Perform keyword search using BM25"""
        try:
            # Extract keywords from query
            query_keywords = self._extract_keywords(query.text)
            
            if not query_keywords:
                logger.info("No keywords found in query")
                return []
            
            # Perform BM25 search
            keyword_results = self.keyword_searcher.search(
                query_keywords,
                chunks,
                top_k=50  # Get more results for combination
            )
            
            logger.info(f"Keyword search found {len(keyword_results)} results")
            return keyword_results
            
        except Exception as e:
            logger.error(f"Keyword search failed: {e}")
            return []
    
    def _perform_vector_search(self, query: Query, top_k: int) -> List[Tuple[str, float]]:
        """Perform vector similarity search"""
        try:
            # Generate query embedding
            query_embedding = self.embedding_service.encode_single(query.text)
            
            # Search vector store
            vector_results = self.vector_store.search(query_embedding, top_k)
            
            logger.info(f"Vector search found {len(vector_results)} results")
            return vector_results
            
        except Exception as e:
            logger.error(f"Vector search failed: {e}")
            return []
    
    def _combine_results(
        self,
        keyword_results: List[Tuple[str, float]],
        vector_results: List[Tuple[str, float]],
        chunks: Dict[str, Chunk],
        alpha: float
    ) -> List[Tuple[str, float]]:
        """Combine keyword and vector search results"""
        
        # Create score dictionaries
        keyword_scores = dict(keyword_results)
        vector_scores = dict(vector_results)
        
        # Get all unique chunk IDs
        all_chunk_ids = set(keyword_scores.keys()) | set(vector_scores.keys())
        
        combined_results = []
        
        for chunk_id in all_chunk_ids:
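            # A chunk missing from one result list contributes 0 for that side.
            # Only scores that were actually returned are min-max normalized;
            # normalizing a placeholder 0.0 could fall outside the 0-1 range.
            normalized_keyword = (
                self._normalize_score(keyword_scores[chunk_id], keyword_results)
                if chunk_id in keyword_scores else 0.0
            )
            normalized_vector = (
                self._normalize_score(vector_scores[chunk_id], vector_results)
                if chunk_id in vector_scores else 0.0
            )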
            
            # Combine scores
            combined_score = alpha * normalized_vector + (1 - alpha) * normalized_keyword
            
            combined_results.append((chunk_id, combined_score))
        
        # Sort by combined score
        combined_results.sort(key=lambda x: x[1], reverse=True)
        
        logger.info(f"Combined {len(all_chunk_ids)} unique results")
        return combined_results
    
    def _normalize_score(self, score: float, all_scores: List[Tuple[str, float]]) -> float:
        """Normalize score to 0-1 range"""
        if not all_scores:
            return 0.0
        
        scores = [s for _, s in all_scores]
        min_score = min(scores)
        max_score = max(scores)
        
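        if max_score == min_score:
            # All returned scores are equal: rank them neutrally and give a
            # lower value (such as a placeholder 0.0) no credit
            return 0.5 if score >= min_score else 0.0
        
        # Clamp so that scores outside the observed range stay within 0-1
        return max(0.0, min(1.0, (score - min_score) / (max_score - min_score)))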
    
    def _apply_metadata_filters(
        self,
        results: List[Tuple[str, float]],
        chunks: Dict[str, Chunk],
        filters: Optional[Dict[str, Any]]
    ) -> List[Tuple[str, float]]:
        """Apply metadata filters to results"""
        if not filters:
            return results
        
        filtered_results = []
        
        for chunk_id, score in results:
            if chunk_id not in chunks:
                continue
            
            chunk = chunks[chunk_id]
            
            if self._passes_filters(chunk, filters):
                filtered_results.append((chunk_id, score))
        
        logger.info(f"Filtered {len(results)} -> {len(filtered_results)} results")
        return filtered_results
    
    def _passes_filters(self, chunk: Chunk, filters: Dict[str, Any]) -> bool:
        """Check if chunk passes all filters"""
        
        # Subject filter
        if "subject" in filters:
            if chunk.document_metadata.get("subject") != filters["subject"]:
                return False
        
        # Difficulty filter
        if "max_difficulty" in filters:
            chunk_difficulty = chunk.metadata.get("difficulty", 0.5)
            if chunk_difficulty > filters["max_difficulty"]:
                return False
        
        # Bloom level filter
        if "max_bloom_level" in filters:
            chunk_bloom = chunk.metadata.get("bloom_level", 3)
            if chunk_bloom > filters["max_bloom_level"]:
                return False
        
        # Language filter
        if "language" in filters:
            if chunk.metadata.get("language") != filters["language"]:
                return False
        
        # Keyword filter (must contain at least one)
        if "required_keywords" in filters:
            chunk_text = chunk.text.lower()
            # Lowercase stored keywords so both checks are case-insensitive
            chunk_keywords = {k.lower() for k in chunk.metadata.get("keywords", [])}
            
            has_keyword = any(
                keyword.lower() in chunk_text or keyword.lower() in chunk_keywords
                for keyword in filters["required_keywords"]
            )
            if not has_keyword:
                return False
        
        return True
    
    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from query text"""
        # Simple keyword extraction - could be enhanced with NLP
        import re
        
        # Remove common stop words
        stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have',
            'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should',
            'what', 'how', 'when', 'where', 'why', 'who', 'which', 'that', 'this'
        }
        
        # Extract words and clean
        words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
        keywords = [word for word in words if word not in stop_words and len(word) > 2]
        
        # Return unique keywords
        return list(set(keywords))
    
    def batch_search(
        self,
        queries: List[Query],
        chunks: Dict[str, Chunk],
        top_k: int = 10,
        alpha: float = 0.5,
        rerank: bool = True
    ) -> List[QueryResult]:
        """Perform batch search for multiple queries"""
        results = []
        
        for query in queries:
            try:
                result = self.search(query, chunks, top_k, alpha, rerank)
                results.append(result)
            except Exception as e:
                logger.error(f"Batch search failed for query {query.id}: {e}")
                # Add empty result for failed query
                results.append(QueryResult(
                    query=query,
                    results=[],
                    total_found=0,
                    keyword_results_count=0,
                    vector_results_count=0,
                    alpha=alpha,
                    reranked=rerank
                ))
        
        return results
    
    def get_search_stats(self) -> Dict:
        """Get search system statistics"""
        return {
            "vector_store_stats": self.vector_store.get_stats(),
            "embedding_service_stats": {
                "model_name": self.embedding_service.config.model_name,
                "dimension": self.embedding_service.model.get_sentence_embedding_dimension(),
            },
            "keyword_searcher_stats": self.keyword_searcher.get_stats(),
        }
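
The score fusion performed by `_combine_results` can be illustrated in isolation. The following is a minimal standalone sketch, not the engine's own code: `min_max` and `fuse` are hypothetical helper names, and it assumes that a chunk missing from one result list contributes 0.0 from that list.

```python
from typing import Dict, List, Tuple


def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to the 0-1 range; identical scores all map to 0.5."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {key: 0.5 for key in scores}
    return {key: (value - lo) / (hi - lo) for key, value in scores.items()}


def fuse(
    keyword_results: List[Tuple[str, float]],
    vector_results: List[Tuple[str, float]],
    alpha: float = 0.5,
) -> List[Tuple[str, float]]:
    """Weighted sum of normalized scores; alpha is the weight of the vector side."""
    keyword = min_max(dict(keyword_results))
    vector = min_max(dict(vector_results))
    fused = {
        chunk_id: alpha * vector.get(chunk_id, 0.0)
        + (1 - alpha) * keyword.get(chunk_id, 0.0)
        for chunk_id in keyword.keys() | vector.keys()
    }
    # Highest combined score first
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


ranked = fuse(
    keyword_results=[("c1", 4.0), ("c2", 2.0)],
    vector_results=[("c2", 0.9), ("c3", 0.5)],
    alpha=0.7,
)
print(ranked)
```

With `alpha=0.7` the vector side dominates, so `c2` (best vector match) ranks above `c1` (best keyword match), and `c3` comes last because it is the weakest entry in the only list that contains it.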

🤖 WEEK 7-8: LLM INTEGRATION

Task 4.1: LLM Client Implementation

Priority: High
Estimated Time: 10 hours
Dependencies: Task 3.2

Subtasks:

  • Set up OpenAI API client
  • Set up Anthropic API client
  • Create prompt templates
  • Implement response generation
  • Add token counting
  • Create error handling
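
One possible shape for the error-handling subtask is a retry wrapper with exponential backoff around each API call. The sketch below is an assumption, not part of the planned client: `call_with_retry` is a hypothetical helper, and the default `retryable` tuple uses built-in exceptions as stand-ins for whatever transient errors the chosen SDK raises.

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def call_with_retry(
    fn: Callable[[], T],
    retryable: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
    max_attempts: int = 3,
    base_delay: float = 0.5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on the given exception types with exponential backoff."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except retryable:
            if attempt == max_attempts:
                raise  # Out of attempts: surface the last error to the caller
            # The delay doubles on each attempt; jitter avoids synchronized retries
            delay = base_delay * 2 ** (attempt - 1)
            sleep(delay + random.uniform(0, delay / 2))
    raise AssertionError("unreachable")


# Simulated call that fails twice before succeeding (hypothetical stand-in for an API call)
attempts = {"count": 0}


def flaky() -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


result = call_with_retry(flaky, sleep=lambda _: None)
print(result, attempts["count"])
```

Passing `sleep` as a parameter keeps the helper testable without real delays. Non-retryable errors, such as an invalid API key, propagate immediately because they are not listed in `retryable`.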

Implementation:

src/llm/llm_client.py

import os
import json
from typing import Dict, List, Optional, Any
import openai
import anthropic
from ..config.settings import RAGConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)

class LLMClient:
    """Client for interacting with LLM APIs"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.openai_client = None
        self.anthropic_client = None
        
        self._initialize_clients()
    
    def _initialize_clients(self):
        """Initialize API clients"""
        try:
            # Initialize OpenAI client
            if self.config.llm_provider in ["openai", "both"]:
                openai_api_key = os.getenv("OPENAI_API_KEY")
                if openai_api_key:
                    self.openai_client = openai.OpenAI(api_key=openai_api_key)
                    logger.info("OpenAI client initialized")
                else:
                    logger.warning("OpenAI API key not found")
            
            # Initialize Anthropic client
            if self.config.llm_provider in ["anthropic", "both"]:
                anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
                if anthropic_api_key:
                    self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
                    logger.info("Anthropic client initialized")
                else:
                    logger.warning("Anthropic API key not found")
                    
        except Exception as e:
            logger.error(f"Failed to initialize LLM clients: {e}")
            raise
    
    def generate_response(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Generate response from LLM
        
        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            provider: LLM provider to use
            
        Returns:
            Response dictionary with text and metadata
        """
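        provider = provider or self.config.llm_provider
        # Compare against None so an explicit value such as temperature=0.0 is respected
        max_tokens = max_tokens if max_tokens is not None else self.config.max_response_tokens
        temperature = temperature if temperature is not None else self.config.temperature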
        
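        try:
            # "both" names no single backend; fall back to whichever client is initialized
            if provider == "both":
                provider = "anthropic" if self.anthropic_client else "openai"
            
            if provider == "anthropic" and self.anthropic_client: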
                return self._generate_anthropic_response(prompt, max_tokens, temperature)
            elif provider == "openai" and self.openai_client:
                return self._generate_openai_response(prompt, max_tokens, temperature)
            else:
                raise ValueError(f"Provider {provider} not available")
                
        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            raise
    
    def _generate_anthropic_response(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float
    ) -> Dict[str, Any]:
        """Generate response using Anthropic Claude"""
        try:
            message = self.anthropic_client.messages.create(
                model=self.config.llm_model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
            )
            
            response_text = message.content[0].text
            
            # Use the token usage reported by the API instead of an estimate
            prompt_tokens = message.usage.input_tokens
            completion_tokens = message.usage.output_tokens
            
            return {
                "text": response_text,
                "provider": "anthropic",
                "model": self.config.llm_model,
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "temperature": temperature,
                "finish_reason": message.stop_reason,
            }
            
        except Exception as e:
            logger.error(f"Anthropic generation failed: {e}")
            raise
    
    def _generate_openai_response(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float
    ) -> Dict[str, Any]:
        """Generate response using OpenAI GPT"""
        try:
            response = self.openai_client.chat.completions.create(
                model=self.config.llm_model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful AI assistant."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=temperature,
            )
            
            response_text = response.choices[0].message.content
            
            return {
                "text": response_text,
                "provider": "openai",
                "model": self.config.llm_model,
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
                "temperature": temperature,
                "finish_reason": response.choices[0].finish_reason,
            }
            
        except Exception as e:
            logger.error(f"OpenAI generation failed: {e}")
            raise
    
    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation)"""
        # Simple approximation: ~4 characters per token
        return max(1, len(text) // 4)
    
    def validate_response(self, response: str, context: str) -> Dict[str, Any]:
        """Validate response quality and safety"""
        validation = {
            "is_safe": True,
            "is_relevant": True,
            "is_appropriate": True,
            "issues": [],
            "confidence": 1.0,
        }
        
        # Check for unsafe content
        unsafe_patterns = [
            "hate", "violence", "self-harm", "explicit", "illegal"
        ]
        
        response_lower = response.lower()
        for pattern in unsafe_patterns:
            if pattern in response_lower:
                validation["is_safe"] = False
                validation["issues"].append(f"Potentially unsafe content: {pattern}")
                validation["confidence"] *= 0.5
        
        # Check relevance to context
        if context:
            context_words = set(context.lower().split())
            response_words = set(response.lower().split())
            overlap = len(context_words & response_words)
            relevance_score = overlap / len(response_words) if response_words else 0
            
            if relevance_score < 0.1:
                validation["is_relevant"] = False
                validation["issues"].append("Low relevance to context")
                validation["confidence"] *= 0.7
        
        # Check for appropriate length
        if len(response) < 10:
            validation["is_appropriate"] = False
            validation["issues"].append("Response too short")
            validation["confidence"] *= 0.8
        elif len(response) > 2000:
            validation["issues"].append("Response very long")
            validation["confidence"] *= 0.9
        
        return validation
    
    def get_model_info(self, provider: str) -> Dict[str, Any]:
        """Get information about a specific model"""
        model_info = {
            "anthropic": {
                "claude-3-5-sonnet-20241022": {
                    "max_tokens": 4096,
                    "context_window": 200000,
                    "cost_per_1k_input": 0.003,
                    "cost_per_1k_output": 0.015,
                }
            },
            "openai": {
                "gpt-4": {
                    "max_tokens": 4096,
                    "context_window": 8192,
                    "cost_per_1k_input": 0.03,
                    "cost_per_1k_output": 0.06,
                },
                "gpt-3.5-turbo": {
                    "max_tokens": 4096,
                    "context_window": 16385,
                    "cost_per_1k_input": 0.0015,
                    "cost_per_1k_output": 0.002,
                }
            }
        }
        
        return model_info.get(provider, {})
    
    def estimate_cost(self, prompt_tokens: int, completion_tokens: int, provider: str) -> float:
        """Estimate cost for API call"""
        model_info = self.get_model_info(provider)
        model_data = model_info.get(self.config.llm_model, {})
        
        input_cost = (prompt_tokens / 1000) * model_data.get("cost_per_1k_input", 0)
        output_cost = (completion_tokens / 1000) * model_data.get("cost_per_1k_output", 0)
        
        return input_cost + output_cost
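
As a worked example of the arithmetic in `estimate_cost`, the sketch below applies the per-1,000-token prices from the `get_model_info` table to a call with 1,200 prompt tokens and 300 completion tokens. It is a standalone illustration: the `PRICING` dictionary and the free function are hypothetical, and the prices are only the values listed in the table above, which may not match current provider pricing.

```python
# Prices per 1,000 tokens, copied from the get_model_info table
PRICING = {
    "claude-3-5-sonnet-20241022": {"input": 0.003, "output": 0.015},
    "gpt-4": {"input": 0.03, "output": 0.06},
}


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Return the estimated cost; unknown models cost 0.0, as in the class method."""
    price = PRICING.get(model, {})
    input_cost = (prompt_tokens / 1000) * price.get("input", 0.0)
    output_cost = (completion_tokens / 1000) * price.get("output", 0.0)
    return input_cost + output_cost


# 1.2 * 0.003 + 0.3 * 0.015 = 0.0036 + 0.0045 = 0.0081
claude_cost = estimate_cost("claude-3-5-sonnet-20241022", 1200, 300)
# 1.2 * 0.03 + 0.3 * 0.06 = 0.036 + 0.018 = 0.054
gpt4_cost = estimate_cost("gpt-4", 1200, 300)
print(f"Claude: {claude_cost:.4f}, GPT-4: {gpt4_cost:.4f}")
```

Note that an unknown model silently yields a cost of 0.0, so a misspelled model name in the configuration would hide spending rather than raise an error.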

Task 4.2: Prompt Engineering

Priority: High
Estimated Time: 8 hours
Dependencies: Task 4.1

Subtasks:

  • Create prompt templates for different modes
  • Implement context injection
  • Add constraint enforcement
  • Create safety prompts
  • Build prompt optimization
  • Add prompt testing

Implementation:

src/llm/prompt_builder.py

from typing import Dict, List, Optional, Any
from ..models.query import Query
from ..models.chunk import Chunk
from ..config.settings import RAGConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)

class PromptBuilder:
    """Builds prompts for different interaction modes"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
    
    def build_prompt(
        self,
        query: Query,
        retrieved_chunks: List[Chunk],
        mode: str = "EXPLANATION",
        student_level: int = 2,
        constraints: Optional[Dict] = None
    ) -> str:
        """
        Build complete prompt for LLM
        
        Args:
            query: Student query
            retrieved_chunks: Retrieved context chunks
            mode: Interaction mode
            student_level: Student's current level (1-6 Bloom's)
            constraints: Additional constraints
            
        Returns:
            Complete prompt string
        """
        try:
            # Get mode-specific template
            template = self._get_mode_template(mode)
            
            # Build context section
            context_section = self._build_context_section(retrieved_chunks)
            
            # Build constraints section
            constraints_section = self._build_constraints_section(
                mode, student_level, constraints
            )
            
            # Build query section
            query_section = self._build_query_section(query)
            
            # Combine all sections
            prompt = template.format(
                constraints_section=constraints_section,
                context_section=context_section,
                query_section=query_section,
                mode=mode,
                student_level=student_level
            )
            
            logger.info(f"Built prompt for mode {mode}, {len(retrieved_chunks)} chunks")
            return prompt
            
        except Exception as e:
            logger.error(f"Failed to build prompt: {e}")
            raise
    
    def _get_mode_template(self, mode: str) -> str:
        """Get prompt template for specific mode"""
        templates = {
            "EXPLANATION": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
Generate a clear, educational explanation that:
1. Uses ONLY the provided context above
2. Is appropriate for a student at level {student_level} (Bloom's taxonomy)
3. Includes 1-2 concrete examples if helpful
4. Avoids complex proofs unless appropriate for the level
5. Ends with a guiding question or next step
6. Is concise but comprehensive (max 300 words)

Focus on understanding over memorization.
""",
            
            "TUTOR": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
Act as a Socratic tutor. Guide the student to discover the answer themselves:
1. Start with a clarifying question
2. Provide hints progressively, not the full answer
3. Use the context to guide your questions
4. Check for understanding between steps
5. Encourage critical thinking
6. Adapt to the student's level {student_level}

Never give the complete answer immediately. Guide, don't tell.
""",
            
            "EXAM": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
You are proctoring an exam. Provide minimal assistance:
1. Answer only if the question is about exam format or instructions
2. Do not provide content answers or hints
3. If the question is unclear, ask for clarification
4. Maintain formal, neutral tone
5. Do not use the provided context to answer content questions

If the question asks for content knowledge, respond: "I cannot provide answers to exam questions. Please focus on the question and use your knowledge."
""",
            
            "QUIZ": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
Create an interactive quiz experience:
1. Turn the student's question into a quiz question if appropriate
2. Or provide a related quiz question based on the context
3. Include multiple choice options if suitable
4. Ask the student to attempt an answer
5. Provide immediate feedback on their response
6. Use the context to ensure accuracy
7. Keep it engaging and educational

Level: {student_level}
""",
            
            "EXPLORATION": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
Encourage deeper exploration and curiosity:
1. Go beyond the basic answer
2. Connect to related concepts and real-world applications
3. Ask "what if" and "why" questions
4. Suggest extensions and further learning
5. Use the context as a starting point, not a limit
6. Inspire curiosity about the subject
7. Adapt to level {student_level} but challenge appropriately

Be engaging and thought-provoking.
""",
            
            "REMEDIAL": """
{constraints_section}

CONTEXT:
{context_section}

STUDENT QUESTION:
{query_section}

INSTRUCTIONS:
Address identified misconceptions patiently:
1. Directly acknowledge the confusion
2. Explain why the misconception is incorrect
3. Provide the correct mental model
4. Use simple, clear language
5. Include concrete examples to illustrate
6. Build confidence with positive reinforcement
7. Check for understanding before moving on

Level: {student_level} - focus on building foundation.
""",
        }
        
        return templates.get(mode, templates["EXPLANATION"])
    
    def _build_context_section(self, chunks: List[Chunk]) -> str:
        """Build context section from retrieved chunks"""
        if not chunks:
            return "No relevant context found."
        
        context_parts = []
        
        for i, chunk in enumerate(chunks, 1):
            chunk_text = chunk.text.strip()
            
            # Add chunk header
            header = f"[Source {i}: {chunk.document_metadata.get('concept', 'Unknown')}]"
            
            # Add chunk content
            context_parts.append(header)
            context_parts.append(chunk_text)
            
            # Add separator
            context_parts.append("")
        
        return "\n".join(context_parts)
    
    def _build_constraints_section(
        self,
        mode: str,
        student_level: int,
        constraints: Optional[Dict]
    ) -> str:
        """Build constraints section"""
        constraint_parts = []
        
        # Core constraints
        constraint_parts.append("CORE RULES:")
        constraint_parts.append("- Use ONLY the provided context for factual information")
        constraint_parts.append("- Never use your general knowledge for core content")
        constraint_parts.append("- Admit uncertainty if context is insufficient")
        constraint_parts.append("- Maintain educational appropriateness")
        
        # Mode-specific constraints
        if mode == "EXPLANATION":
            constraint_parts.append("\nEXPLANATION CONSTRAINTS:")
            constraint_parts.append(f"- Target Bloom's level: {student_level}")
            constraint_parts.append("- Include examples when helpful")
            constraint_parts.append("- Avoid unnecessary complexity")
            
        elif mode == "TUTOR":
            constraint_parts.append("\nTUTOR CONSTRAINTS:")
            constraint_parts.append("- Ask guiding questions first")
            constraint_parts.append("- Reveal information progressively")
            constraint_parts.append("- Check understanding between steps")
            
        elif mode == "EXAM":
            constraint_parts.append("\nEXAM CONSTRAINTS:")
            constraint_parts.append("- No content assistance")
            constraint_parts.append("- Formal tone only")
            constraint_parts.append("- Focus on exam instructions")
        
        # Student level constraints
        constraint_parts.append(f"\nLEVEL CONSTRAINTS (Level {student_level}):")
        
        if student_level <= 2:
            constraint_parts.append("- Focus on basic understanding")
            constraint_parts.append("- Use simple language")
            constraint_parts.append("- Avoid abstract concepts")
        elif student_level <= 4:
            constraint_parts.append("- Include application examples")
            constraint_parts.append("- Introduce some analysis")
            constraint_parts.append("- Balance simplicity and depth")
        else:
            constraint_parts.append("- Encourage critical thinking")
            constraint_parts.append("- Include complex applications")
            constraint_parts.append("- Allow for abstract reasoning")
        
        # Additional constraints
        if constraints:
            constraint_parts.append("\nADDITIONAL CONSTRAINTS:")
            for key, value in constraints.items():
                constraint_parts.append(f"- {key}: {value}")
        
        return "\n".join(constraint_parts)
    
    def _build_query_section(self, query: Query) -> str:
        """Build query section"""
        query_parts = []
        
        # Add original query
        query_parts.append(f"Question: {query.text}")
        
        # Add context if available
        if query.context:
            query_parts.append(f"Context: {query.context}")
        
        # Add student info if available
        if query.student_info:
            info_parts = []
            if "grade_level" in query.student_info:
                info_parts.append(f"Grade: {query.student_info['grade_level']}")
            if "subject" in query.student_info:
                info_parts.append(f"Subject: {query.student_info['subject']}")
            if "recent_topics" in query.student_info:
                info_parts.append(f"Recent topics: {', '.join(query.student_info['recent_topics'])}")
            
            if info_parts:
                query_parts.append(f"Student Info: {', '.join(info_parts)}")
        
        return "\n".join(query_parts)
    
    def build_system_prompt(self, mode: str, student_level: int) -> str:
        """Build system prompt for LLM"""
        system_prompts = {
            "EXPLANATION": f"You are an educational AI tutor specializing in clear explanations for students at Bloom's level {student_level}. Your goal is to build understanding through structured, example-rich explanations.",
            
            "TUTOR": f"You are a Socratic tutor guiding students at level {student_level}. Your role is to ask thoughtful questions that lead students to discover answers themselves, using the provided context as your knowledge base.",
            
            "EXAM": "You are an exam proctor maintaining academic integrity. Provide only procedural assistance and never help with exam content.",
            
            "QUIZ": f"You are an interactive quiz creator for students at level {student_level}. Transform questions into engaging learning opportunities with immediate feedback.",
            
            "EXPLORATION": f"You are an educational guide encouraging deeper exploration for students at level {student_level}. Connect concepts to real-world applications and inspire curiosity.",
            
            "REMEDIAL": f"You are a patient remedial tutor helping students at level {student_level} overcome misconceptions. Build confidence through clear, step-by-step explanations.",
        }
        
        return system_prompts.get(mode, system_prompts["EXPLANATION"])
    
    def validate_prompt(self, prompt: str) -> Dict[str, Any]:
        """Validate prompt quality and safety"""
        validation = {
            "is_valid": True,
            "issues": [],
            "suggestions": [],
            "token_estimate": len(prompt) // 4,
        }
        
        # Check for prompt injection attempts
        injection_patterns = [
            "ignore previous instructions",
            "forget everything above",
            "system prompt",
            "you are now",
            "pretend you are",
        ]
        
        prompt_lower = prompt.lower()
        for pattern in injection_patterns:
            if pattern in prompt_lower:
                validation["is_valid"] = False
                validation["issues"].append(f"Potential injection: {pattern}")
        
        # Check length
        if validation["token_estimate"] > 8000:
            validation["issues"].append("Prompt very long, may exceed context window")
            validation["suggestions"].append("Consider reducing context or chunking")
        
        # Check for required sections
        required_sections = ["CONTEXT:", "STUDENT QUESTION:", "INSTRUCTIONS:"]
        for section in required_sections:
            if section not in prompt:
                # A prompt without a required section cannot be considered valid
                validation["is_valid"] = False
                validation["issues"].append(f"Missing required section: {section}")
        
        return validation
    
    def optimize_prompt(self, prompt: str, max_tokens: int = 4000) -> str:
        """Optimize prompt to fit within token limit"""
        current_tokens = len(prompt) // 4
        
        if current_tokens <= max_tokens:
            return prompt
        
        # If too long, truncate context section first
        context_start = prompt.find("CONTEXT:")
        context_end = prompt.find("STUDENT QUESTION:")
        
        if context_start != -1 and context_end != -1:
            context_section = prompt[context_start:context_end]
            other_sections = prompt[:context_start] + prompt[context_end:]
            
            # Calculate allowed context length
            other_tokens = len(other_sections) // 4
            allowed_context_tokens = max_tokens - other_tokens
            
            if allowed_context_tokens > 100:
                # Truncate context proportionally
                context_ratio = allowed_context_tokens / (len(context_section) // 4)
                truncated_context = context_section[:int(len(context_section) * context_ratio)]
                
                # Add truncation notice
                truncated_context += "\n[Context truncated due to length limits]\n"
                
                return other_sections[:context_start] + truncated_context + other_sections[context_start:]
        
        # If still too long, truncate overall
        return prompt[:max_tokens * 4] + "\n[Prompt truncated]"
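
`optimize_prompt` cuts the context at a character offset, which can split a source mid-sentence. One alternative, sketched here as an assumption rather than the engine's behavior, is to drop whole chunks before the prompt is built once a token budget is exhausted. `select_chunks_within_budget` is a hypothetical helper, and it reuses the same rough estimate of about 4 characters per token.

```python
from typing import List


def estimate_tokens(text: str) -> int:
    """Rough estimate: about 4 characters per token."""
    return max(1, len(text) // 4)


def select_chunks_within_budget(chunk_texts: List[str], max_tokens: int) -> List[str]:
    """Keep chunks in ranked order until the next one would exceed the budget."""
    selected: List[str] = []
    used = 0
    for text in chunk_texts:
        cost = estimate_tokens(text)
        if used + cost > max_tokens:
            break  # Stop at the first chunk that does not fit, preserving rank order
        selected.append(text)
        used += cost
    return selected


# Three chunks of roughly 100 estimated tokens each, already sorted by relevance
chunks = ["a" * 400, "b" * 400, "c" * 400]
kept = select_chunks_within_budget(chunks, max_tokens=250)
print(len(kept))
```

Because the input is assumed to be sorted by relevance, stopping at the first chunk that does not fit keeps the highest-ranked sources intact and never emits a partial one.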

📊 WEEK 9-10: INTEGRATION & OPTIMIZATION

Task 5.1: End-to-End Pipeline

Priority: High
Estimated Time: 16 hours
Dependencies: Task 4.2

Subtasks:

  • Integrate all components
  • Create pipeline orchestration
  • Add error handling and recovery
  • Implement caching
  • Add performance monitoring
  • Create pipeline testing
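
For the caching subtask, one minimal approach is an in-memory cache with a size limit and a time-to-live, keyed on the normalized query. The sketch below is only an illustration under stated assumptions: `TTLCache` and `cache_key` are hypothetical names, and the key uses just the query text, mode, and `top_k`. A real key would probably also need to cover filters and student information.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Hashable, Optional, Tuple


class TTLCache:
    """Small LRU cache whose entries expire after ttl_seconds."""

    def __init__(
        self,
        max_size: int = 256,
        ttl_seconds: float = 300.0,
        clock: Callable[[], float] = time.monotonic,
    ):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self._clock = clock
        self._entries: "OrderedDict[Hashable, Tuple[float, Any]]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self._clock() - stored_at > self.ttl_seconds:
            del self._entries[key]  # Expired: drop it and report a miss
            return None
        self._entries.move_to_end(key)  # Mark as most recently used
        return value

    def put(self, key: Hashable, value: Any) -> None:
        self._entries[key] = (self._clock(), value)
        self._entries.move_to_end(key)
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)  # Evict the least recently used entry


def cache_key(query_text: str, mode: str, top_k: int) -> Tuple[str, str, int]:
    """Normalize case and whitespace so trivially different queries share a key."""
    return (" ".join(query_text.lower().split()), mode, top_k)


cache = TTLCache(max_size=2, ttl_seconds=60.0)
cache.put(cache_key("  What is   photosynthesis? ", "EXPLANATION", 5), {"answer": "..."})
hit = cache.get(cache_key("what is photosynthesis?", "EXPLANATION", 5))
print(hit)
```

The injectable `clock` makes expiry testable without waiting. A cached value of `None` is indistinguishable from a miss in this sketch, so responses should be stored as dictionaries or other non-`None` objects.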

Implementation:

src/main.py

import asyncio
import time
from typing import Dict, List, Optional, Any
from pathlib import Path
import json

from .config.settings import RAGConfig, DEFAULT_CONFIG
from .core.embeddings import EmbeddingService
from .core.vector_store import VectorStore
from .core.indexer import Indexer
from .preprocessing.text_processor import TextProcessor
from .preprocessing.chunker import IntelligentChunker
from .retrieval.hybrid_search import HybridSearcher
from .retrieval.keyword_search import KeywordSearcher
from .retrieval.ranker import ResultRanker
from .llm.llm_client import LLMClient
from .llm.prompt_builder import PromptBuilder
from .models.document import Document
from .models.query import Query
from .utils.logger import get_logger

logger = get_logger(__name__)

class RAGEngine:
    """Main RAG engine orchestrating all components"""
    
    def __init__(self, config: Optional[RAGConfig] = None):
        self.config = config or DEFAULT_CONFIG
        self.components = {}
        self._initialize_components()
    
    def _initialize_components(self):
        """Initialize all RAG components"""
        try:
            logger.info("Initializing RAG engine components...")
            
            # Core components
            self.components["embeddings"] = EmbeddingService(self.config.embeddings)
            self.components["vector_store"] = VectorStore(self.config.vector_store)
            self.components["text_processor"] = TextProcessor(self.config.chunking)
            self.components["chunker"] = IntelligentChunker(self.config.chunking)
            
            # Retrieval components
            self.components["keyword_searcher"] = KeywordSearcher()
            self.components["ranker"] = ResultRanker()
            self.components["hybrid_searcher"] = HybridSearcher(
                vector_store=self.components["vector_store"],
                embedding_service=self.components["embeddings"],
                keyword_searcher=self.components["keyword_searcher"],
                ranker=self.components["ranker"]
            )
            
            # LLM components
            self.components["llm_client"] = LLMClient(self.config)
            self.components["prompt_builder"] = PromptBuilder(self.config)
            
            # Indexer for document processing
            self.components["indexer"] = Indexer(
                embedding_service=self.components["embeddings"],
                vector_store=self.components["vector_store"],
                chunker=self.components["chunker"]
            )
            
            logger.info("All components initialized successfully")
            
        except Exception as e:
            logger.error(f"Failed to initialize components: {e}")
            raise
    
    async def process_document(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Process a document and add it to the index
        
        Args:
            text: Document text
            document_metadata: Document metadata
            chunk_metadata: Default chunk metadata
            
        Returns:
            Processing results
        """
        # Start the timer before the try block so the except branch can always read it
        start_time = time.time()
        try:
            logger.info(f"Processing document: {document_metadata.get('title', 'Unknown')}")
            
            # Use indexer to process document
            result = await self.components["indexer"].process_document(
                text, document_metadata, chunk_metadata
            )
            
            processing_time = time.time() - start_time
            
            logger.info(f"Document processed in {processing_time:.2f}s: {result['chunk_count']} chunks")
            
            return {
                "success": True,
                "chunk_count": result["chunk_count"],
                "processing_time": processing_time,
                "chunks": result["chunks"],
                "stats": result["stats"]
            }
            
        except Exception as e:
            logger.error(f"Document processing failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "processing_time": time.time() - start_time
            }
    
    async def query(
        self,
        query_text: str,
        mode: str = "EXPLANATION",
        top_k: int = 5,
        filters: Optional[Dict] = None,
        student_info: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Process a student query
        
        Args:
            query_text: Student's question
            mode: Interaction mode
            top_k: Number of results to retrieve
            filters: Metadata filters
            student_info: Student information
            
        Returns:
            Query response
        """
        # Start the timer before the try block so the except branch can always read it
        start_time = time.time()
        try:
            logger.info(f"Processing query: {query_text[:50]}...")
            
            # Create query object
            query = Query(
                text=query_text,
                mode=mode,
                filters=filters or {},
                student_info=student_info or {}
            )
            
            # Placeholder: chunks (chunk_id -> chunk) would be loaded from storage.
            # While this dict is empty, every query takes the no-context branch below.
            chunks = {}
            
            # Perform hybrid search
            search_result = self.components["hybrid_searcher"].search(
                query=query,
                chunks=chunks,
                top_k=top_k,
                alpha=self.config.retrieval.hybrid_alpha,
                rerank=self.config.retrieval.rerank
            )
            
            # Get retrieved chunks
            retrieved_chunks = [
                chunks[chunk_id] for chunk_id, _ in search_result.results
                if chunk_id in chunks
            ]
            
            if not retrieved_chunks:
                return {
                    "success": True,
                    "response": "I don't have relevant information to answer that question.",
                    "retrieved_chunks": [],
                    "processing_time": time.time() - start_time,
                    "no_context": True
                }
            
            # Build prompt
            prompt = self.components["prompt_builder"].build_prompt(
                query=query,
                retrieved_chunks=retrieved_chunks,
                mode=mode,
                student_level=student_info.get("level", 2) if student_info else 2
            )
            
            # Generate response
            llm_response = self.components["llm_client"].generate_response(
                prompt=prompt,
                max_tokens=self.config.max_response_tokens,
                temperature=self.config.temperature
            )
            
            # Validate response
            validation = self.components["llm_client"].validate_response(
                llm_response["text"],
                "\n".join(chunk.text for chunk in retrieved_chunks)
            )
            
            processing_time = time.time() - start_time
            
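            result = {
                "success": True,
                "response": llm_response["text"],
                "retrieved_chunks": [
                    {
                        "id": chunks[chunk_id].id,
                        "text": chunks[chunk_id].text[:200] + "...",  # Truncate for response
                        "score": score,
                        "metadata": chunks[chunk_id].metadata.__dict__
                    }
                    # Iterate the raw results (not a zip with the filtered list) so each
                    # score stays paired with its own chunk when some IDs are missing
                    for chunk_id, score in search_result.results
                    if chunk_id in chunks
                ],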
                "processing_time": processing_time,
                "search_stats": {
                    "total_found": search_result.total_found,
                    "keyword_results": search_result.keyword_results_count,
                    "vector_results": search_result.vector_results_count,
                },
                "llm_stats": {
                    "provider": llm_response["provider"],
                    "model": llm_response["model"],
                    "tokens": llm_response["total_tokens"],
                    "cost": self.components["llm_client"].estimate_cost(
                        llm_response["prompt_tokens"],
                        llm_response["completion_tokens"],
                        llm_response["provider"]
                    )
                },
                "validation": validation
            }
            
            logger.info(f"Query processed in {processing_time:.2f}s")
            return result
            
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "processing_time": time.time() - start_time
            }
    
    async def batch_query(
        self,
        queries: List[Dict[str, Any]],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Process multiple queries in batch"""
        results = []
        
        for query_data in queries:
            try:
                result = await self.query(
                    query_text=query_data["text"],
                    mode=query_data.get("mode", "EXPLANATION"),
                    top_k=top_k,
                    filters=query_data.get("filters"),
                    student_info=query_data.get("student_info")
                )
                results.append(result)
            except Exception as e:
                logger.error(f"Batch query failed: {e}")
                results.append({
                    "success": False,
                    "error": str(e),
                    # .get() avoids a second KeyError when "text" itself is missing
                    "query": query_data.get("text")
                })
        
        return results
    
    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive system statistics"""
        return {
            "config": {
                "vector_store": self.components["vector_store"].get_stats(),
                "embeddings": {
                    "model": self.config.embeddings.model_name,
                    "dimension": self.config.embeddings.dimension,
                }
            },
            "search": self.components["hybrid_searcher"].get_search_stats(),
            "timestamp": time.time()
        }
    
    def save_index(self, name: str = "default"):
        """Save the current index"""
        self.components["vector_store"].save_index(name)
        logger.info(f"Index saved as {name}")
    
    def load_index(self, name: str = "default"):
        """Load a saved index"""
        # This would require implementing load functionality
        logger.info(f"Index loading not yet implemented for {name}")
    
    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on all components"""
        health_status = {
            "overall": "healthy",
            "components": {},
            "timestamp": time.time()
        }
        
        for name, component in self.components.items():
            try:
                # Simple health check - try to get basic stats
                if hasattr(component, 'get_stats'):
                    stats = component.get_stats()
                    health_status["components"][name] = {
                        "status": "healthy",
                        "stats": stats
                    }
                else:
                    health_status["components"][name] = {
                        "status": "healthy"
                    }
            except Exception as e:
                health_status["components"][name] = {
                    "status": "unhealthy",
                    "error": str(e)
                }
                health_status["overall"] = "degraded"
        
        return health_status
    
    async def shutdown(self):
        """Graceful shutdown"""
        logger.info("Shutting down RAG engine...")
        
        # Save index
        try:
            self.save_index()
        except Exception as e:
            logger.error(f"Failed to save index during shutdown: {e}")
        
        # Cleanup components
        for name, component in self.components.items():
            try:
                if hasattr(component, 'cleanup'):
                    component.cleanup()
            except Exception as e:
                logger.error(f"Failed to cleanup {name}: {e}")
        
        logger.info("RAG engine shutdown complete")

# Main entry point
async def main():
    """Main function for testing"""
    engine = RAGEngine()
    
    try:
        # Health check
        health = await engine.health_check()
        print("Health Status:", json.dumps(health, indent=2))
        
        # Example query
        result = await engine.query(
            query_text="What is a derivative?",
            mode="EXPLANATION",
            student_info={"level": 2}
        )
        
        print("Query Result:", json.dumps(result, indent=2))
        
    finally:
        await engine.shutdown()

if __name__ == "__main__":
    asyncio.run(main())
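
The `batch_query` method above awaits each query in turn. As a hedged illustration only, the sketch below shows one way such a batch could run concurrently with a bounded number of in-flight queries. `run_batch_concurrently`, `query_fn`, and `max_concurrency` are hypothetical names that do not appear in the original design; `query_fn` stands in for `RAGEngine.query`.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List


async def run_batch_concurrently(
    query_fn: Callable[..., Awaitable[Dict[str, Any]]],
    queries: List[Dict[str, Any]],
    max_concurrency: int = 4,  # hypothetical default, tune per deployment
) -> List[Dict[str, Any]]:
    """Run query_fn over queries with at most max_concurrency in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(query_data: Dict[str, Any]) -> Dict[str, Any]:
        async with semaphore:
            try:
                return await query_fn(
                    query_text=query_data["text"],
                    mode=query_data.get("mode", "EXPLANATION"),
                )
            except Exception as exc:
                # One failing query must not abort the rest of the batch
                return {
                    "success": False,
                    "error": str(exc),
                    "query": query_data.get("text"),
                }

    # asyncio.gather returns results in the same order as the inputs
    return await asyncio.gather(*(run_one(q) for q in queries))
```

Whether concurrency helps depends on the components being truly asynchronous; if the search and LLM calls block the event loop, this version would behave much like the sequential one.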

🧪 WEEK 11-12: TESTING & DEPLOYMENT

Task 6.1: Comprehensive Testing

Priority: High
Estimated Time: 16 hours
Dependencies: Task 5.1

Subtasks:

  • Write unit tests for all components
  • Create integration tests
  • Add performance benchmarks
  • Test error scenarios
  • Create load tests
  • Add regression tests

Test Suite Structure:

tests/
├── unit/
│   ├── test_embeddings.py
│   ├── test_vector_store.py
│   ├── test_chunker.py
│   ├── test_keyword_search.py
│   ├── test_hybrid_search.py
│   ├── test_llm_client.py
│   └── test_prompt_builder.py
├── integration/
│   ├── test_end_to_end.py
│   ├── test_document_processing.py
│   ├── test_query_pipeline.py
│   └── test_error_handling.py
├── performance/
│   ├── benchmark_retrieval.py
│   ├── benchmark_llm_calls.py
│   └── stress_test.py
├── fixtures/
│   ├── sample_documents/
│   ├── test_queries.json
│   └── expected_results/
└── conftest.py
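
To make the layout above more concrete, here is a minimal sketch of what a file under `tests/unit/` might contain. It is an assumption-laden illustration: `aggregate_health` is a hypothetical stand-in that mirrors the aggregation logic of `RAGEngine.health_check`, and `HealthyStub` and `BrokenStub` are made-up test doubles. The functions follow the pytest naming convention but use only plain `assert`, so they can also be called directly.

```python
import time
from typing import Any, Dict


def aggregate_health(components: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in mirroring the logic of RAGEngine.health_check."""
    status: Dict[str, Any] = {
        "overall": "healthy",
        "components": {},
        "timestamp": time.time(),
    }
    for name, component in components.items():
        try:
            entry: Dict[str, Any] = {"status": "healthy"}
            if hasattr(component, "get_stats"):
                entry["stats"] = component.get_stats()
            status["components"][name] = entry
        except Exception as exc:
            status["components"][name] = {"status": "unhealthy", "error": str(exc)}
            status["overall"] = "degraded"
    return status


class HealthyStub:
    """Test double whose stats call succeeds."""

    def get_stats(self) -> Dict[str, int]:
        return {"total_vectors": 0}


class BrokenStub:
    """Test double whose stats call fails."""

    def get_stats(self) -> Dict[str, int]:
        raise RuntimeError("index not loaded")


def test_all_components_healthy() -> None:
    result = aggregate_health({"vector_store": HealthyStub(), "ranker": object()})
    assert result["overall"] == "healthy"
    # Components without get_stats are reported healthy with no stats entry
    assert result["components"]["ranker"] == {"status": "healthy"}


def test_failing_component_degrades_overall() -> None:
    result = aggregate_health({"vector_store": BrokenStub()})
    assert result["overall"] == "degraded"
    assert result["components"]["vector_store"]["error"] == "index not loaded"
```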

Task 6.2: Production Deployment

Priority: High
Estimated Time: 8 hours
Dependencies: Task 6.1

Subtasks:

  • Create Docker configuration
  • Set up environment variables
  • Configure monitoring
  • Create deployment scripts
  • Set up logging
  • Create health checks

Docker Configuration:

Dockerfile

FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY src/ ./src/
COPY tests/ ./tests/

# Create data directories
RUN mkdir -p data/models data/indices data/chunks logs

# Set environment variables
ENV PYTHONPATH=/app
ENV RAG_CONFIG_ENV=production

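# Health check: exit non-zero unless the engine reports "healthy",
# otherwise Docker would treat a degraded engine as passing
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD python -c "import asyncio, sys; from src.main import RAGEngine; sys.exit(0 if asyncio.run(RAGEngine().health_check())['overall'] == 'healthy' else 1)"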

# Expose port
EXPOSE 8000

# Default command
CMD ["python", "-m", "src.main"]

📋 DELIVERABLES

Week 2 Deliverables

  • Vector database setup with FAISS
  • Embedding service with sentence-transformers
  • Basic text processing pipeline
  • Project structure and configuration

Week 4 Deliverables

  • Intelligent chunking system
  • Text preprocessing and quality validation
  • Metadata extraction
  • Content quality scoring

Week 6 Deliverables

  • Vector search implementation
  • Keyword search (BM25)
  • Hybrid retrieval system
  • Result ranking and filtering

Week 8 Deliverables

  • LLM client integration
  • Prompt engineering system
  • Response generation and validation
  • Multi-provider support

Week 10 Deliverables

  • End-to-end RAG pipeline
  • Performance optimization
  • Caching and monitoring
  • Error handling and recovery

Week 12 Deliverables

  • Comprehensive test suite
  • Production deployment
  • Documentation and monitoring
  • Performance benchmarks

📈 PERFORMANCE TARGETS

Retrieval Performance

  • Query latency: < 500ms for top-10 results
  • Indexing speed: > 1000 chunks/second
  • Memory usage: < 2GB for 100k chunks
  • Accuracy: > 80% relevance for top-5 results
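
The accuracy target above does not say how relevance is measured. One common reading is precision at k over a labelled set of queries; the sketch below assumes that interpretation. `precision_at_k` and `meets_accuracy_target` are hypothetical helpers, and the labelled relevance sets would have to come from fixtures such as `tests/fixtures/expected_results/`.

```python
from typing import List, Sequence, Set, Tuple


def precision_at_k(retrieved_ids: Sequence[str], relevant_ids: Set[str], k: int = 5) -> float:
    """Fraction of the top-k retrieved chunk IDs that are labelled relevant.

    Divides by k even when fewer than k results were returned, so a short
    result list is penalised rather than rewarded.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    hits = sum(1 for chunk_id in retrieved_ids[:k] if chunk_id in relevant_ids)
    return hits / k


def meets_accuracy_target(
    runs: List[Tuple[Sequence[str], Set[str]]],
    k: int = 5,
    target: float = 0.8,  # the "> 80%" target from the list above
) -> bool:
    """Check whether mean precision@k over all labelled queries exceeds target."""
    if not runs:
        raise ValueError("at least one labelled query is required")
    mean = sum(precision_at_k(retrieved, relevant, k) for retrieved, relevant in runs) / len(runs)
    return mean > target
```

With five retrieved chunks of which three are labelled relevant, `precision_at_k` gives 3/5, which would fall short of the target.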

LLM Performance

  • Response generation: < 3 seconds
  • Token efficiency: < 1000 tokens per response
  • Cost optimization: < $0.01 per query
  • Quality score: > 0.8 validation score
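
The cost target can be checked with simple arithmetic once per-token prices are known. The sketch below is only illustrative: `estimate_query_cost` is a hypothetical helper (not the `estimate_cost` method referenced earlier), and the prices in the usage note are made-up placeholders, not real provider rates.

```python
def estimate_query_cost(
    prompt_tokens: int,
    completion_tokens: int,
    input_price_per_1k: float,
    output_price_per_1k: float,
) -> float:
    """Cost in dollars for one query, given prices per 1,000 tokens."""
    if prompt_tokens < 0 or completion_tokens < 0:
        raise ValueError("token counts must be non-negative")
    return (
        prompt_tokens / 1000 * input_price_per_1k
        + completion_tokens / 1000 * output_price_per_1k
    )


def within_cost_target(cost: float, target: float = 0.01) -> bool:
    """True when the cost is below the per-query target from the list above."""
    return cost < target
```

For example, with placeholder prices of $0.001 per 1k input tokens and $0.002 per 1k output tokens, a query using 1,500 prompt tokens and 500 completion tokens would cost roughly $0.0025, comfortably under the target.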

System Performance

  • Uptime: > 99.5%
  • Error rate: < 1%
  • Throughput: > 100 queries per second (QPS)
  • Memory efficiency: < 4GB total
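
An uptime percentage is easier to reason about as a downtime budget. The hypothetical helper below does that conversion; the 30-day period is an assumption, since the targets above do not state a measurement window.

```python
def downtime_budget_hours(uptime_target: float, period_days: float = 30.0) -> float:
    """Hours of allowed downtime in a period for a given uptime fraction."""
    if not 0.0 <= uptime_target <= 1.0:
        raise ValueError("uptime_target must be a fraction between 0 and 1")
    return (1.0 - uptime_target) * period_days * 24.0
```

Under that assumption, a 99.5% target (`downtime_budget_hours(0.995)`) allows about 3.6 hours of downtime per 30 days.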

🔧 MONITORING & LOGGING

Key Metrics

  • Query response times
  • Retrieval accuracy
  • LLM token usage and costs
  • System resource utilization
  • Error rates and types

Logging Strategy

  • Structured JSON logging
  • Different log levels for components
  • Performance tracing
  • Error correlation
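
As a rough sketch of the structured JSON logging item, the formatter below uses only the standard library `logging` and `json` modules. `JsonFormatter`, `build_logger`, and the `request_id` / `duration_ms` field names are illustrative assumptions, not part of the original plan.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as a single JSON object per line."""

    # Hypothetical extra fields used for tracing and error correlation
    EXTRA_FIELDS = ("request_id", "duration_ms")

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Values passed via `extra=` become attributes on the record
        for key in self.EXTRA_FIELDS:
            if hasattr(record, key):
                payload[key] = getattr(record, key)
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload)


def build_logger(name: str, level: int = logging.INFO) -> logging.Logger:
    """Return a logger that emits JSON lines; safe to call more than once."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(JsonFormatter())
        logger.addHandler(handler)
    return logger
```

A call such as `build_logger("rag.retrieval").info("query done", extra={"request_id": "abc", "duration_ms": 42})` would then emit one JSON object carrying both extra fields, and per-component levels can be set by passing a different `level` for each logger name.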

Alerting

  • High latency alerts
  • Error rate thresholds
  • Resource utilization warnings
  • Cost limit alerts

🛡️ SAFETY & RELIABILITY

Input Validation

  • Query length limits
  • Content filtering
  • Injection detection
  • Rate limiting
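
Two of the items above, query length limits and rate limiting, can be sketched with the standard library alone. Everything here is an assumption for illustration: the `MAX_QUERY_CHARS` value, the `validate_query` helper, and the `SlidingWindowRateLimiter` class are hypothetical, and content filtering and injection detection are deliberately left out.

```python
import time
from collections import defaultdict, deque
from typing import Deque, Dict, Optional

MAX_QUERY_CHARS = 1000  # hypothetical limit, not specified in the plan


def validate_query(text: str, max_chars: int = MAX_QUERY_CHARS) -> str:
    """Return the trimmed query, or raise ValueError if it is empty or too long."""
    cleaned = text.strip()
    if not cleaned:
        raise ValueError("query is empty")
    if len(cleaned) > max_chars:
        raise ValueError(f"query exceeds {max_chars} characters")
    return cleaned


class SlidingWindowRateLimiter:
    """Allow at most max_requests per client within the last window_seconds."""

    def __init__(self, max_requests: int, window_seconds: float) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, client_id: str, now: Optional[float] = None) -> bool:
        # `now` is injectable so the limiter can be tested deterministically
        now = time.monotonic() if now is None else now
        hits = self._hits[client_id]
        # Drop timestamps that have fallen out of the window
        while hits and now - hits[0] >= self.window_seconds:
            hits.popleft()
        if len(hits) >= self.max_requests:
            return False
        hits.append(now)
        return True
```

This in-memory limiter only works within a single process; a deployment with several workers would need shared state, which is outside the scope of this sketch.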

Output Validation

  • Response quality checks
  • Unsafe content detection
  • Hallucination detection
  • Context relevance validation
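
Context relevance validation could start from something as simple as lexical overlap between the response and the retrieved context. The sketch below assumes that approach; `context_overlap` and `is_grounded` are hypothetical helpers, the 0.5 threshold is arbitrary, and a word-overlap score is only a crude proxy that should not be mistaken for real hallucination detection.

```python
import re
from typing import Set


def _tokens(text: str) -> Set[str]:
    """Lowercased alphanumeric word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def context_overlap(response: str, context: str) -> float:
    """Fraction of distinct response tokens that also occur in the context."""
    response_tokens = _tokens(response)
    if not response_tokens:
        return 0.0
    return len(response_tokens & _tokens(context)) / len(response_tokens)


def is_grounded(response: str, context: str, threshold: float = 0.5) -> bool:
    """Flag a response as grounded when its overlap reaches the threshold."""
    return context_overlap(response, context) >= threshold
```

A response that paraphrases the context with different vocabulary would score low here, so a check like this is better suited to flagging responses for review than to rejecting them outright.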

Fallback Strategies

  • Multiple LLM providers
  • Graceful degradation
  • Cache fallbacks
  • Error recovery
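
The first two strategies above can be combined in a small loop: try each LLM provider in order and degrade to a fixed message if all of them fail. This is a sketch under stated assumptions; `generate_with_fallback`, the `(name, callable)` provider shape, and the fallback message are hypothetical and do not reflect the `LLMClient` interface.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical message returned when every provider fails
FALLBACK_MESSAGE = "I couldn't generate an answer right now. Please try again later."


def generate_with_fallback(
    prompt: str,
    providers: List[Tuple[str, Callable[[str], str]]],
) -> Dict[str, Any]:
    """Try providers in order; return the first success or a degraded reply."""
    errors: Dict[str, str] = {}
    for name, call in providers:
        try:
            text = call(prompt)
        except Exception as exc:
            # Record the failure and move on to the next provider
            errors[name] = str(exc)
            continue
        return {"success": True, "provider": name, "text": text, "errors": errors}
    # Graceful degradation: no provider answered
    return {"success": False, "provider": None, "text": FALLBACK_MESSAGE, "errors": errors}
```

Returning the collected `errors` alongside a successful result keeps earlier provider failures visible to monitoring even when a later provider answers.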

Last Updated: 2026-05-06
Version: 1.0.0
RAG Engine Lead: ML Engineering Team