# RAG Engine MVP Tasks - AI Study Assistant

> ⚠️ **IMPORTANT - OUTDATED DOCUMENT**: This document describes a Python/FAISS architecture that **WAS NOT IMPLEMENTED**.
>
> **Actual Implementation:**
> - **Language**: Dart (Flutter)
> - **Location**: `lib/core/services/materials_rag_service.dart`, `lib/core/services/rag_ai_service.dart`
> - **Vector Store**: Firestore with mock (hash-based) embeddings
> - **PDF Processing**: `syncfusion_flutter_pdf` (not Python)
> - **Search**: Keyword window search (not FAISS)
>
> **DOES NOT EXIST:** Python, FAISS, Sentence Transformers, OpenAI, Anthropic

---

## 🧠 MVP RAG ENGINE ROADMAP (ORIGINAL DOCUMENTATION - NOT IMPLEMENTED)

---

## 📚 WEEK 1-2: FOUNDATION & SETUP (NOT IMPLEMENTED)

### Task 1.1: Vector Database Setup
**Status**: ❌ NOT IMPLEMENTED - FAISS is not used

#### What Actually Exists:
```dart
// lib/core/services/vector_service.dart
class VectorService {
  // Mock embedding generation using text hashing
  static List<double> generateEmbedding(String text) {
    final embedding = List<double>.filled(384, 0.0);
    // Hash-based deterministic embeddings (not ML)
    return embedding;
  }
}

// lib/core/services/materials_rag_service.dart
class MaterialsRAGService {
  // Keyword-based window search for PDFs
  static Future<String> getContextForQuestion(...) async {
    // 1. Extract PDF text with syncfusion_flutter_pdf
    // 2. Find keyword matches
    // 3. Return window of 1200 chars around match
    // NO FAISS, NO VECTOR SEARCH, NO PYTHON
  }
}
```

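For comparison with the original plan, the keyword window search described in the comments above can be illustrated roughly as follows. This is a minimal Python sketch, not the Dart code in `materials_rag_service.dart`: the function name `keyword_window`, the lowercase substring matching, the "words of at least 4 letters" keyword rule, and centring the 1200-character window on the first match are all assumptions.

```python
import re
from typing import Optional


def keyword_window(text: str, question: str, window: int = 1200,
                   min_word_len: int = 4) -> Optional[str]:
    """Return up to `window` chars of `text` centred on the first keyword hit.

    Hypothetical illustration of a keyword window search; the real Dart
    implementation may choose keywords and windows differently.
    """
    # Assumption: keywords are question words with at least `min_word_len` chars.
    keywords = [w for w in re.findall(r"\w+", question.lower()) if len(w) >= min_word_len]
    lowered = text.lower()
    for keyword in keywords:
        pos = lowered.find(keyword)
        if pos == -1:
            continue
        # Centre the window on the match, then clamp it to the text bounds.
        start = max(0, pos - window // 2)
        end = min(len(text), start + window)
        start = max(0, end - window)
        return text[start:end]
    return None  # no keyword matched
```

A fixed-size character window keeps the context passed to the model bounded without requiring any embedding model.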
#### Technology Stack (Original - NOT USED):
```bash
# These dependencies DO NOT EXIST in the project:
# ❌ pip install faiss-cpu
# ❌ pip install sentence-transformers
# ❌ pip install numpy
# ❌ pip install nltk
# ❌ pip install spacy
# ❌ pip install langdetect
```

#### Project Structure:
```
rag_engine/
├── src/
│   ├── __init__.py
│   ├── main.py
│   ├── config/
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   └── constants.py
│   ├── core/
│   │   ├── __init__.py
│   │   ├── vector_store.py
│   │   ├── embeddings.py
│   │   ├── retriever.py
│   │   └── indexer.py
│   ├── preprocessing/
│   │   ├── __init__.py
│   │   ├── text_processor.py
│   │   ├── chunker.py
│   │   └── metadata_extractor.py
│   ├── retrieval/
│   │   ├── __init__.py
│   │   ├── keyword_search.py
│   │   ├── vector_search.py
│   │   ├── hybrid_search.py
│   │   └── ranker.py
│   ├── llm/
│   │   ├── __init__.py
│   │   ├── prompt_builder.py
│   │   ├── llm_client.py
│   │   └── response_processor.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── validators.py
│   │   └── helpers.py
│   └── models/
│       ├── __init__.py
│       ├── document.py
│       ├── chunk.py
│       └── query.py
├── tests/
│   ├── __init__.py
│   ├── test_vector_store.py
│   ├── test_embeddings.py
│   ├── test_retriever.py
│   └── test_integration.py
├── data/
│   ├── models/      # Saved embedding models
│   ├── indices/     # FAISS index files
│   ├── chunks/      # Processed content chunks
│   └── temp/        # Temporary files
├── requirements.txt
├── setup.py
├── README.md
└── docker-compose.yml
```

#### Configuration:

**src/config/settings.py**
```python
import os
from dataclasses import dataclass, field


@dataclass
class VectorStoreConfig:
    """Configuration for vector storage"""
    index_type: str = "IVF"        # IVF, HNSW, Flat
    dimension: int = 384           # all-MiniLM-L6-v2 dimension
    nlist: int = 100               # Number of clusters for IVF
    nprobe: int = 10               # Number of clusters to search
    use_gpu: bool = False          # GPU acceleration
    metric: str = "INNER_PRODUCT"  # Similarity metric


@dataclass
class EmbeddingConfig:
    """Configuration for text embeddings"""
    model_name: str = "all-MiniLM-L6-v2"
    batch_size: int = 32
    max_length: int = 512
    normalize_embeddings: bool = True
    cache_embeddings: bool = True
    model_cache_dir: str = "data/models"
    use_gpu: bool = False  # Read by EmbeddingService._load_model


@dataclass
class RetrievalConfig:
    """Configuration for retrieval pipeline"""
    top_k: int = 10               # Number of results to retrieve
    rerank: bool = True           # Apply reranking
    rerank_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    hybrid_alpha: float = 0.5     # Weight for hybrid search
    min_similarity: float = 0.1   # Minimum similarity threshold


@dataclass
class ChunkingConfig:
    """Configuration for text chunking"""
    chunk_size: int = 300         # Target chunk size in tokens
    chunk_overlap: int = 50       # Overlap between chunks
    min_chunk_size: int = 50      # Minimum chunk size
    max_chunk_size: int = 800     # Maximum chunk size
    respect_sentence_boundaries: bool = True
    respect_paragraph_boundaries: bool = True


@dataclass
class RAGConfig:
    """Main RAG configuration"""
    vector_store: VectorStoreConfig
    embeddings: EmbeddingConfig
    retrieval: RetrievalConfig
    chunking: ChunkingConfig

    # LLM settings
    llm_provider: str = "anthropic"  # anthropic, openai
    llm_model: str = "claude-3-5-sonnet-20241022"
    max_context_tokens: int = 4000
    max_response_tokens: int = 500
    temperature: float = 0.7

    # Storage (read from the environment when the config is created,
    # not once at import time)
    firebase_project_id: str = field(default_factory=lambda: os.getenv("FIREBASE_PROJECT_ID", ""))
    storage_bucket: str = field(default_factory=lambda: os.getenv("STORAGE_BUCKET", ""))

    # Logging
    log_level: str = "INFO"
    log_file: str = "logs/rag_engine.log"


# Default configuration
DEFAULT_CONFIG = RAGConfig(
    vector_store=VectorStoreConfig(),
    embeddings=EmbeddingConfig(),
    retrieval=RetrievalConfig(),
    chunking=ChunkingConfig(),
)
```

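`RetrievalConfig.hybrid_alpha` is only described as a "weight for hybrid search". One common reading is a convex combination of normalized vector and keyword scores, sketched below under that assumption: `alpha` weights the vector score, both score maps are min-max normalized per query, and a document missing from one result list scores 0 for that signal. The function names `min_max` and `hybrid_scores` are hypothetical.

```python
from typing import Dict, List, Tuple


def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to [0, 1]; if all scores are equal, map them to 1.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {k: 1.0 for k in scores}
    return {k: (v - lo) / (hi - lo) for k, v in scores.items()}


def hybrid_scores(vector: Dict[str, float], keyword: Dict[str, float],
                  alpha: float = 0.5, min_similarity: float = 0.1) -> List[Tuple[str, float]]:
    """Blend as alpha * vector + (1 - alpha) * keyword (assumed convention)."""
    v, k = min_max(vector), min_max(keyword)
    blended = {
        doc: alpha * v.get(doc, 0.0) + (1 - alpha) * k.get(doc, 0.0)
        for doc in set(v) | set(k)
    }
    # Drop weak matches, then sort best-first (ties broken by id for determinism).
    kept = [(doc, s) for doc, s in blended.items() if s >= min_similarity]
    return sorted(kept, key=lambda item: (-item[1], item[0]))
```

With `alpha=1.0` the ranking is pure vector search and with `alpha=0.0` pure keyword search. Normalizing first matters because keyword scores such as BM25 are unbounded, while cosine scores are not.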
---

### Task 1.2: Embedding Model Setup
**Priority**: Critical
**Estimated Time**: 6 hours
**Dependencies**: Task 1.1

#### Subtasks:
- [ ] Download and configure sentence-transformers model
- [ ] Create embedding service
- [ ] Implement batch processing
- [ ] Add embedding caching
- [ ] Create embedding quality checks
- [ ] Set up model versioning

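The batch-processing and caching subtasks interact: cached texts must be skipped, the remaining texts encoded in batches, and the results returned in the caller's original order. A minimal stdlib sketch of that bookkeeping, where `encode_with_cache` is a hypothetical helper and `fake_encode` stands in for a real embedding model:

```python
import hashlib
from typing import Callable, Dict, List, Optional, Sequence

Vector = List[float]


def encode_with_cache(texts: Sequence[str],
                      encode_batch: Callable[[List[str]], List[Vector]],
                      cache: Dict[str, Vector],
                      batch_size: int = 32) -> List[Vector]:
    """Encode only cache misses, in batches, and return vectors in input order."""
    keys = [hashlib.md5(t.encode("utf-8")).hexdigest() for t in texts]
    # One slot per input; cache hits are filled immediately.
    results: List[Optional[Vector]] = [cache.get(k) for k in keys]
    missing = [i for i, r in enumerate(results) if r is None]
    for start in range(0, len(missing), batch_size):
        batch_idx = missing[start:start + batch_size]
        vectors = encode_batch([texts[i] for i in batch_idx])
        for i, vec in zip(batch_idx, vectors):
            results[i] = vec        # write back to the caller's position
            cache[keys[i]] = vec    # reuse on later calls
    if any(r is None for r in results):
        raise ValueError("encode_batch returned fewer vectors than texts")
    return results  # type: ignore[return-value]


if __name__ == "__main__":
    demo_cache: Dict[str, Vector] = {}

    def fake_encode(batch: List[str]) -> List[Vector]:
        # Stand-in for a real model: a 1-d "vector" holding the text length.
        return [[float(len(t))] for t in batch]

    print(encode_with_cache(["aa", "b", "cccc"], fake_encode, demo_cache, batch_size=2))
    # [[2.0], [1.0], [4.0]]
```

Pre-allocating one slot per input avoids index-shifting `list.insert` bookkeeping. MD5 serves only as a cache key here, as in `_get_cache_key` below, not as a security measure.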
#### Implementation:

**src/core/embeddings.py**
```python
import hashlib
import os
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from ..config.settings import EmbeddingConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)


class EmbeddingService:
    """Service for generating and managing text embeddings"""

    def __init__(self, config: EmbeddingConfig):
        self.config = config
        self.model = None
        self.embedding_cache: Dict[str, np.ndarray] = {}
        self.cache_file = Path("data/embeddings_cache.pkl")
        self._load_model()
        self._load_cache()

    def _load_model(self):
        """Load the sentence transformer model"""
        try:
            logger.info(f"Loading embedding model: {self.config.model_name}")

            # Create cache directory if it doesn't exist
            os.makedirs(self.config.model_cache_dir, exist_ok=True)

            # Load model with caching
            self.model = SentenceTransformer(
                self.config.model_name,
                cache_folder=self.config.model_cache_dir
            )

            # Move to GPU if available and requested
            if self.config.use_gpu and torch.cuda.is_available():
                self.model = self.model.to('cuda')
                logger.info("Model moved to GPU")

            logger.info("Embedding model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise

    def _load_cache(self):
        """Load embedding cache from disk"""
        if self.config.cache_embeddings and self.cache_file.exists():
            try:
                with open(self.cache_file, 'rb') as f:
                    self.embedding_cache = pickle.load(f)
                logger.info(f"Loaded {len(self.embedding_cache)} cached embeddings")
            except Exception as e:
                logger.warning(f"Failed to load cache: {e}")
                self.embedding_cache = {}

    def _save_cache(self):
        """Save embedding cache to disk"""
        if self.config.cache_embeddings:
            try:
                os.makedirs(self.cache_file.parent, exist_ok=True)
                with open(self.cache_file, 'wb') as f:
                    pickle.dump(self.embedding_cache, f)
                logger.info("Embedding cache saved")
            except Exception as e:
                logger.warning(f"Failed to save cache: {e}")

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text (MD5 is used for lookup only, not security)"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def encode(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
        """
        Encode texts to embeddings

        Args:
            texts: List of texts to encode
            batch_size: Batch size for processing (overrides config)

        Returns:
            numpy array of embeddings, one row per input text, in input order
        """
        if not texts:
            return np.array([])

        batch_size = batch_size or self.config.batch_size

        # One slot per input text, so results keep the caller's order
        embeddings: List[Optional[np.ndarray]] = [None] * len(texts)
        uncached_indices: List[int] = []

        # Check cache first; with caching disabled, every text is encoded
        for i, text in enumerate(texts):
            if self.config.cache_embeddings:
                cached = self.embedding_cache.get(self._get_cache_key(text))
                if cached is not None:
                    embeddings[i] = cached
                    continue
            uncached_indices.append(i)

        # Encode uncached texts in batches
        if uncached_indices:
            try:
                for start in range(0, len(uncached_indices), batch_size):
                    batch_indices = uncached_indices[start:start + batch_size]

                    # Crude character-level truncation; the model also
                    # truncates to its own token limit (max_seq_length)
                    batch = [texts[i][:self.config.max_length] for i in batch_indices]

                    # Generate embeddings
                    batch_emb = self.model.encode(
                        batch,
                        batch_size=batch_size,
                        normalize_embeddings=self.config.normalize_embeddings,
                        convert_to_numpy=True,
                        show_progress_bar=False
                    )

                    for i, emb in zip(batch_indices, batch_emb):
                        embeddings[i] = emb
                        # Update cache (keyed on the full, untruncated text)
                        if self.config.cache_embeddings:
                            self.embedding_cache[self._get_cache_key(texts[i])] = emb

            except Exception as e:
                logger.error(f"Failed to encode texts: {e}")
                raise

        # Every slot is filled now, so rows match inputs one-to-one
        return np.vstack(embeddings)

    def encode_single(self, text: str) -> np.ndarray:
        """Encode a single text"""
        return self.encode([text])[0]

    def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Calculate cosine similarity between two embeddings"""
        if self.config.normalize_embeddings:
            # For normalized embeddings, dot product equals cosine similarity
            return float(np.dot(embedding1, embedding2))

        # Manual cosine similarity calculation (0.0 for zero vectors)
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def find_similar(
        self,
        query_embedding: np.ndarray,
        candidate_embeddings: np.ndarray,
        top_k: int = 10
    ) -> List[Tuple[int, float]]:
        """
        Find most similar embeddings to query

        Args:
            query_embedding: Query embedding
            candidate_embeddings: Array of candidate embeddings
            top_k: Number of top results to return

        Returns:
            List of (index, similarity_score) tuples
        """
        if len(candidate_embeddings) == 0:
            return []

        # Calculate similarities
        if self.config.normalize_embeddings:
            # For normalized embeddings, use matrix multiplication
            similarities = np.dot(candidate_embeddings, query_embedding)
        else:
            # Manual cosine similarity
            similarities = np.array([
                self.similarity(query_embedding, emb)
                for emb in candidate_embeddings
            ])

        # Get top-k indices and scores (highest first)
        top_indices = np.argsort(similarities)[::-1][:top_k]
        top_scores = similarities[top_indices]

        return [(int(idx), float(score)) for idx, score in zip(top_indices, top_scores)]

    def validate_embedding(self, embedding: np.ndarray) -> bool:
        """Validate embedding quality"""
        if not isinstance(embedding, np.ndarray):
            return False

        if embedding.size == 0:
            return False

        if np.isnan(embedding).any() or np.isinf(embedding).any():
            return False

        expected_dim = self.model.get_sentence_embedding_dimension()
        if embedding.shape[0] != expected_dim:
            return False

        return True

    def get_embedding_stats(self, embeddings: np.ndarray) -> Dict:
        """Get statistics about a 2-D array of embeddings"""
        if len(embeddings) == 0:
            return {}

        norms = np.linalg.norm(embeddings, axis=1)
        return {
            "count": len(embeddings),
            "dimension": embeddings.shape[1],
            "mean_norm": float(np.mean(norms)),
            "std_norm": float(np.std(norms)),
            "has_nan": bool(np.isnan(embeddings).any()),
            "has_inf": bool(np.isinf(embeddings).any()),
        }

    def clear_cache(self):
        """Clear embedding cache"""
        self.embedding_cache.clear()
        if self.cache_file.exists():
            self.cache_file.unlink()
        logger.info("Embedding cache cleared")

    def save_cache(self):
        """Manually save cache to disk"""
        self._save_cache()

    def __del__(self):
        """Best-effort cache flush on garbage collection"""
        try:
            self._save_cache()
        except Exception:
            pass
```

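`find_similar` relies on the fact that, for unit-length vectors, the dot product equals cosine similarity, so ranking reduces to dot products plus a top-k selection. The same idea in dependency-free Python, which can help check expected rankings in unit tests without loading a model (the helper names `normalize` and `top_k_cosine` are hypothetical):

```python
import heapq
import math
from typing import List, Sequence, Tuple


def normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def top_k_cosine(query: Sequence[float], candidates: Sequence[Sequence[float]],
                 top_k: int = 10) -> List[Tuple[int, float]]:
    """Return (index, cosine similarity) pairs for the top_k most similar candidates."""
    q = normalize(query)
    # After normalization, cosine similarity is just the dot product.
    scores = ((i, sum(a * b for a, b in zip(q, normalize(c))))
              for i, c in enumerate(candidates))
    return heapq.nlargest(top_k, scores, key=lambda pair: pair[1])
```

`heapq.nlargest` runs in O(n log k) instead of sorting all n scores; for large NumPy arrays, `np.argpartition` plays the same role as the full `np.argsort` used in `find_similar`.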
---

## 🔍 WEEK 3-4: CONTENT PROCESSING & CHUNKING

### Task 2.1: Text Preprocessing
**Priority**: High
**Estimated Time**: 10 hours
**Dependencies**: Task 1.2

#### Subtasks:
- [ ] Implement text cleaning
- [ ] Add language detection
- [ ] Create tokenization utilities
- [ ] Build text normalization
- [ ] Add format detection (PDF, DOCX, etc.)
- [ ] Implement quality checks

#### Implementation:

**src/preprocessing/text_processor.py**
```python
import re
from typing import Dict, List

import spacy
from langdetect import detect

from ..config.settings import ChunkingConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)


class TextProcessor:
    """Service for preprocessing and cleaning text"""

    def __init__(self, config: ChunkingConfig):
        self.config = config
        self.nlp = None
        self._load_spacy()

    def _load_spacy(self):
        """Load spaCy model for text processing"""
        try:
            self.nlp = spacy.load("en_core_web_sm")
            logger.info("spaCy model loaded successfully")
        except OSError:
            logger.warning("spaCy model not found, some features will be limited")
            self.nlp = None

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize text

        Newlines are preserved so that paragraph splitting and
        heading/list detection in extract_structure still work.

        Args:
            text: Raw text to clean

        Returns:
            Cleaned text
        """
        if not text or not text.strip():
            return ""

        # Normalize line endings
        text = text.replace('\r\n', '\n').replace('\r', '\n')

        # Remove control characters (keeps \t and \n)
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]', '', text)

        # Normalize curly quotes to ASCII
        text = re.sub(r'[\u201c\u201d]', '"', text)
        text = re.sub(r'[\u2018\u2019]', "'", text)

        # Remove page-number lines ("Page 12" or a bare "12")
        text = re.sub(r'\n[ \t]*Page[ \t]*\d+[ \t]*(?=\n)', '', text, flags=re.IGNORECASE)
        text = re.sub(r'\n[ \t]*\d+[ \t]*(?=\n)', '', text)

        # Remove list numbering such as "1. ", "a) " or "2- " at line start
        text = re.sub(r'^[ \t]*(?:\d+|[A-Za-z])[.)\-][ \t]+', '', text, flags=re.MULTILINE)

        # Collapse runs of spaces/tabs but keep newlines
        text = re.sub(r'[ \t]+', ' ', text)
        text = re.sub(r' ?\n ?', '\n', text)

        # Clean up extra newlines (at most one blank line)
        text = re.sub(r'\n{3,}', '\n\n', text)

        return text.strip()

    def detect_language(self, text: str) -> str:
        """Detect text language, falling back to English"""
        if len(text) < 50:
            return "en"  # Detection is unreliable on short texts
        try:
            lang = detect(text)
        except Exception:  # langdetect raises on text without detectable features
            return "en"
        return lang if lang in ['en', 'pt', 'es', 'fr', 'de'] else "en"

    def tokenize_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        if self.nlp:
            doc = self.nlp(text)
            return [sent.text.strip() for sent in doc.sents if sent.text.strip()]

        # Fallback: split after sentence-ending punctuation, keeping it
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def tokenize_paragraphs(self, text: str) -> List[str]:
        """Split text into paragraphs"""
        paragraphs = re.split(r'\n\s*\n', text)
        return [p.strip() for p in paragraphs if p.strip()]

    def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
        """Extract keywords from text"""
        if self.nlp:
            doc = self.nlp(text)
            keywords = []

            # Add named entities (people, organizations, places, products)
            for ent in doc.ents:
                if ent.label_ in ['PERSON', 'ORG', 'GPE', 'PRODUCT']:
                    keywords.append(ent.text)

            # Add noun chunks
            for chunk in doc.noun_chunks:
                if len(chunk.text.split()) <= 3:  # Keep short phrases
                    keywords.append(chunk.text)

            # Remove duplicates (order-preserving) and limit
            unique_keywords = list(dict.fromkeys(keywords))
            return unique_keywords[:max_keywords]

        # Fallback: capitalized words
        words = re.findall(r'\b[A-Z][a-z]+\b', text)
        return list(dict.fromkeys(words))[:max_keywords]

    def extract_math_expressions(self, text: str) -> List[str]:
        """Extract mathematical expressions from text"""
        # Common math patterns
        math_patterns = [
            r'\$[^$]+\$',                                 # LaTeX math mode
            r'\\[a-zA-Z]+\{[^}]+\}',                      # LaTeX commands
            r'\b[a-zA-Z]+\s*=\s*[^,;\n]+',                # Equations
            r'\b(?:sin|cos|tan|log|ln|sqrt)\([^)]+\)',    # Functions
            r'\b\d+\s*[+\-*/]\s*\d+',                     # Simple arithmetic
        ]

        expressions = []
        for pattern in math_patterns:
            expressions.extend(re.findall(pattern, text))

        return expressions

    def extract_code_blocks(self, text: str) -> List[str]:
        """Extract code blocks from text"""
        code_patterns = [
            r'```[\s\S]*?```',                            # Markdown code blocks
            r'`[^`]+`',                                   # Inline code
            r'(?s)def\s+\w+\([^)]*\):.*?(?=\n\w|\Z)',     # Python functions
        ]

        code_blocks = []
        for pattern in code_patterns:
            code_blocks.extend(re.findall(pattern, text))

        return code_blocks

    def assess_readability(self, text: str) -> Dict:
        """Assess text readability metrics"""
        sentences = self.tokenize_sentences(text)
        words = text.split()

        if len(sentences) == 0 or len(words) == 0:
            return {"flesch_score": 0, "grade_level": 12}

        syllables = sum(self._count_syllables(word) for word in words)

        # Flesch Reading Ease
        avg_sentence_length = len(words) / len(sentences)
        avg_syllables_per_word = syllables / len(words)

        flesch_score = 206.835 - (1.015 * avg_sentence_length) - (84.6 * avg_syllables_per_word)

        # Approximate grade level
        if flesch_score >= 90:
            grade_level = 5
        elif flesch_score >= 80:
            grade_level = 6
        elif flesch_score >= 70:
            grade_level = 7
        elif flesch_score >= 60:
            grade_level = 8
        elif flesch_score >= 50:
            grade_level = 9
        elif flesch_score >= 40:
            grade_level = 10
        elif flesch_score >= 30:
            grade_level = 11
        else:
            grade_level = 12

        return {
            "flesch_score": max(0, min(100, flesch_score)),
            "grade_level": grade_level,
            "avg_sentence_length": avg_sentence_length,
            "avg_syllables_per_word": avg_syllables_per_word,
        }

    def _count_syllables(self, word: str) -> int:
        """Count syllables in a word (simplified vowel-group heuristic)"""
        # Strip punctuation so "mat." and "store," are handled like "mat" and "store"
        word = re.sub(r'[^\w]', '', word.lower())
        vowels = "aeiouy"
        syllable_count = 0
        prev_char_was_vowel = False

        for char in word:
            is_vowel = char in vowels
            if is_vowel and not prev_char_was_vowel:
                syllable_count += 1
            prev_char_was_vowel = is_vowel

        # Adjust for silent 'e'
        if word.endswith('e') and syllable_count > 1:
            syllable_count -= 1

        return max(1, syllable_count)

    def extract_structure(self, text: str) -> Dict:
        """Extract document structure information"""
        structure = {
            "headings": [],
            "lists": [],
            "tables": [],
            "sections": [],
        }

        # Extract headings (markdown-style)
        heading_pattern = r'^(#{1,6})\s+(.+)$'
        for match in re.finditer(heading_pattern, text, re.MULTILINE):
            structure["headings"].append({
                "level": len(match.group(1)),
                "title": match.group(2).strip(),
                "position": match.start(),
            })

        # Extract lists
        list_patterns = [
            r'^\s*[\-\*\+]\s+(.+)$',  # Bullet lists
            r'^\s*\d+\.\s+(.+)$',     # Numbered lists
        ]

        for pattern in list_patterns:
            for match in re.finditer(pattern, text, re.MULTILINE):
                structure["lists"].append({
                    "content": match.group(1).strip(),
                    "position": match.start(),
                })

        # Extract sections based on headings
        headings = structure["headings"]
        for i, heading in enumerate(headings):
            start_pos = heading["position"]
            end_pos = headings[i + 1]["position"] if i + 1 < len(headings) else len(text)

            section_text = text[start_pos:end_pos].strip()
            structure["sections"].append({
                "heading": heading,
                "content": section_text,
                "word_count": len(section_text.split()),
            })

        return structure

    def validate_text_quality(self, text: str) -> Dict:
        """Validate text quality and return metrics"""
        if not text or len(text.strip()) < 50:
            return {
                "is_valid": False,
                "reason": "Text too short",
                "score": 0.0,
            }

        # Quality metrics
        word_count = len(text.split())
        sentence_count = len(self.tokenize_sentences(text))

        # Check for minimum requirements
        if word_count < 10:
            return {
                "is_valid": False,
                "reason": "Too few words",
                "score": 0.1,
            }

        if sentence_count < 2:
            return {
                "is_valid": False,
                "reason": "Too few sentences",
                "score": 0.2,
            }

        # Calculate quality score
        readability = self.assess_readability(text)
        structure = self.extract_structure(text)

        quality_score = 0.5  # Base score

        # Readability bonus
        if readability["flesch_score"] > 60:
            quality_score += 0.2
        elif readability["flesch_score"] > 40:
            quality_score += 0.1

        # Structure bonus
        if structure["headings"]:
            quality_score += 0.1

        # Length bonus (appropriate length; texts under 50 words get none)
        if 50 <= word_count <= 500:
            quality_score += 0.1
        elif 500 < word_count <= 1000:
            quality_score += 0.05

        # Content variety bonus
        has_math = bool(self.extract_math_expressions(text))
        has_code = bool(self.extract_code_blocks(text))
        if has_math or has_code:
            quality_score += 0.05

        quality_score = min(1.0, quality_score)

        return {
            "is_valid": quality_score >= 0.3,
            "reason": "Quality check passed" if quality_score >= 0.3 else "Low quality",
            "score": quality_score,
            "metrics": {
                "word_count": word_count,
                "sentence_count": sentence_count,
                "readability": readability,
                "has_structure": bool(structure["headings"]),
                "has_math": has_math,
                "has_code": has_code,
            }
        }
```

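As a worked check of `assess_readability`, the Flesch Reading Ease formula it implements is 206.835 − 1.015 × (words / sentences) − 84.6 × (syllables / words). The sketch below recomputes the unclamped score with the same vowel-group syllable heuristic and the punctuation-based sentence split used as the fallback above; the standalone function names are hypothetical, and real scores will differ when spaCy's sentence splitter is active.

```python
import re


def count_syllables(word: str) -> int:
    """Vowel-group heuristic mirroring TextProcessor._count_syllables."""
    word = re.sub(r"[^\w]", "", word.lower())
    count, prev_vowel = 0, False
    for ch in word:
        is_vowel = ch in "aeiouy"
        if is_vowel and not prev_vowel:
            count += 1
        prev_vowel = is_vowel
    if word.endswith("e") and count > 1:
        count -= 1  # silent trailing 'e'
    return max(1, count)


def flesch_reading_ease(text: str) -> float:
    """206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words), unclamped."""
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    words = text.split()
    syllables = sum(count_syllables(w) for w in words)
    return 206.835 - 1.015 * len(words) / len(sentences) - 84.6 * syllables / len(words)


print(flesch_reading_ease("The cat sat on the mat."))
```

For "The cat sat on the mat." there are 6 words, 1 sentence and 6 syllables, giving 206.835 − 6.09 − 84.6 ≈ 116.1. That is above 100, which is why `assess_readability` clamps the score to the range [0, 100].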
---

### Task 2.2: Intelligent Chunking
**Priority**: High
**Estimated Time**: 12 hours
**Dependencies**: Task 2.1

#### Subtasks:
- [ ] Implement semantic chunking
- [ ] Respect sentence and paragraph boundaries
- [ ] Create overlap management
- [ ] Build chunk validation
- [ ] Add metadata extraction
- [ ] Implement chunk quality scoring

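The overlap rule used by `_get_overlap_sentences` below carries whole trailing sentences into the next chunk, up to `chunk_overlap` words and never more than half of the previous chunk's words. A standalone sketch of that rule and the greedy sentence packing around it, with hypothetical function names and word counts taken from `str.split`, as in the chunker:

```python
from typing import List


def overlap_tail(sentences: List[str], chunk_overlap: int) -> List[str]:
    """Trailing sentences totalling <= min(chunk_overlap, half the chunk's words)."""
    budget = min(chunk_overlap, sum(len(s.split()) for s in sentences) // 2)
    tail: List[str] = []
    used = 0
    for sentence in reversed(sentences):
        n = len(sentence.split())
        if used + n > budget:
            break
        tail.insert(0, sentence)
        used += n
    return tail


def pack_sentences(sentences: List[str], chunk_size: int, chunk_overlap: int) -> List[str]:
    """Greedily pack sentences into chunks of about chunk_size words, with overlap."""
    chunks: List[str] = []
    current: List[str] = []
    length = 0
    for sentence in sentences:
        n = len(sentence.split())
        if current and length + n > chunk_size:
            # Close the current chunk and seed the next one with its tail.
            chunks.append(" ".join(current))
            current = overlap_tail(current, chunk_overlap) + [sentence]
            length = sum(len(s.split()) for s in current)
        else:
            current.append(sentence)
            length += n
    if current:
        chunks.append(" ".join(current))
    return chunks
```

Because only whole sentences are carried over, a final sentence longer than the budget yields no overlap at all; this keeps chunks readable at the cost of occasionally losing context at a boundary.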
#### Implementation:

**src/preprocessing/chunker.py**
```python
import hashlib
import re
from typing import Dict, List, Optional, Tuple

import numpy as np

from ..config.settings import ChunkingConfig
from ..models.chunk import Chunk, ChunkMetadata
from .text_processor import TextProcessor
from ..utils.logger import get_logger

logger = get_logger(__name__)


class IntelligentChunker:
    """Service for intelligent text chunking"""

    def __init__(self, config: ChunkingConfig):
        self.config = config
        self.text_processor = TextProcessor(config)

    def chunk_document(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict] = None
    ) -> List[Chunk]:
        """
        Chunk a document into intelligent pieces

        Args:
            text: Document text
            document_metadata: Document-level metadata
            chunk_metadata: Default chunk metadata

        Returns:
            List of chunks
        """
        # Clean and preprocess text
        cleaned_text = self.text_processor.clean_text(text)

        # Extract document structure
        structure = self.text_processor.extract_structure(cleaned_text)

        # Determine chunking strategy
        if structure["sections"] and len(structure["sections"]) > 1:
            # Use section-based chunking
            chunks = self._chunk_by_sections(
                cleaned_text,
                structure,
                document_metadata,
                chunk_metadata
            )
        else:
            # Use sliding window chunking
            chunks = self._chunk_by_sliding_window(
                cleaned_text,
                document_metadata,
                chunk_metadata
            )

        # Validate and filter chunks
        valid_chunks = []
        for chunk in chunks:
            validation = self._validate_chunk(chunk)
            if validation["is_valid"]:
                chunk.quality_score = validation["score"]
                valid_chunks.append(chunk)
            else:
                logger.warning(f"Invalid chunk: {validation['reason']}")

        logger.info(f"Created {len(valid_chunks)} valid chunks from document")
        return valid_chunks

    def _chunk_by_sections(
        self,
        text: str,
        structure: Dict,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk by document sections"""
        chunks = []

        # Text before the first heading belongs to no section; chunk it
        # separately so it is not silently dropped
        preamble = text[:structure["headings"][0]["position"]].strip()
        if preamble:
            chunks.extend(self._chunk_by_sliding_window(
                preamble, document_metadata, chunk_metadata
            ))

        for section in structure["sections"]:
            section_text = section["content"]
            section_heading = section["heading"]["title"]

            # If section is too large, further chunk it
            if self._is_too_large(section_text):
                section_chunks = self._chunk_large_section(
                    section_text,
                    section_heading,
                    document_metadata,
                    chunk_metadata
                )
                chunks.extend(section_chunks)
            else:
                # Create single chunk for section
                chunk = self._create_chunk(
                    section_text,
                    document_metadata,
                    chunk_metadata,
                    section_heading=section_heading
                )
                chunks.append(chunk)

        return chunks

    def _chunk_by_sliding_window(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk using sliding window approach"""
        chunks = []

        # Get sentences for boundary awareness
        sentences = self.text_processor.tokenize_sentences(text)
        if not sentences:
            return chunks

        # Build chunks with overlap
        current_chunk_sentences = []
        current_chunk_length = 0

        for sentence in sentences:
            sentence_length = len(sentence.split())

            # Check if adding sentence exceeds chunk size
            if current_chunk_length + sentence_length > self.config.chunk_size and current_chunk_sentences:
                # Create chunk from accumulated sentences
                chunk_text = " ".join(current_chunk_sentences)
                chunk = self._create_chunk(
                    chunk_text,
                    document_metadata,
                    chunk_metadata
                )
                chunks.append(chunk)

                # Start new chunk with overlap
                overlap_sentences = self._get_overlap_sentences(current_chunk_sentences)
                current_chunk_sentences = overlap_sentences + [sentence]
                current_chunk_length = sum(len(s.split()) for s in current_chunk_sentences)
            else:
                current_chunk_sentences.append(sentence)
                current_chunk_length += sentence_length

        # Add final chunk if it has content
        if current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata
            )
            chunks.append(chunk)

        return chunks

    def _chunk_large_section(
        self,
        section_text: str,
        section_heading: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict]
    ) -> List[Chunk]:
        """Chunk a large section into smaller pieces"""
        chunks = []

        # Split section into paragraphs
        paragraphs = self.text_processor.tokenize_paragraphs(section_text)

        current_chunk_paragraphs = []
        current_chunk_length = 0

        for paragraph in paragraphs:
            paragraph_length = len(paragraph.split())

            # Check if paragraph is too large for a single chunk
            if paragraph_length > self.config.max_chunk_size:
                # Flush the current chunk if it has content
                if current_chunk_paragraphs:
                    chunk_text = "\n\n".join(current_chunk_paragraphs)
                    chunk = self._create_chunk(
                        chunk_text,
                        document_metadata,
                        chunk_metadata,
                        section_heading=section_heading
                    )
                    chunks.append(chunk)
                    current_chunk_paragraphs = []
                    current_chunk_length = 0

                # Chunk the large paragraph by sentences
                paragraph_chunks = self._chunk_large_paragraph(
                    paragraph,
                    document_metadata,
                    chunk_metadata,
                    section_heading
                )
                chunks.extend(paragraph_chunks)
            else:
                # Check if adding paragraph exceeds chunk size
                if current_chunk_length + paragraph_length > self.config.chunk_size and current_chunk_paragraphs:
                    # Create chunk
                    chunk_text = "\n\n".join(current_chunk_paragraphs)
                    chunk = self._create_chunk(
                        chunk_text,
                        document_metadata,
                        chunk_metadata,
                        section_heading=section_heading
                    )
                    chunks.append(chunk)

                    # Start new chunk with overlap
                    overlap_paragraphs = self._get_overlap_paragraphs(current_chunk_paragraphs)
                    current_chunk_paragraphs = overlap_paragraphs + [paragraph]
                    current_chunk_length = sum(len(p.split()) for p in current_chunk_paragraphs)
                else:
                    current_chunk_paragraphs.append(paragraph)
                    current_chunk_length += paragraph_length

        # Add final chunk
        if current_chunk_paragraphs:
            chunk_text = "\n\n".join(current_chunk_paragraphs)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata,
                section_heading=section_heading
            )
            chunks.append(chunk)

        return chunks

    def _chunk_large_paragraph(
        self,
        paragraph: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict],
        section_heading: Optional[str] = None
    ) -> List[Chunk]:
        """Chunk a very large paragraph"""
        chunks = []

        # Split by sentences
        sentences = self.text_processor.tokenize_sentences(paragraph)

        current_chunk_sentences = []
        current_chunk_length = 0

        for sentence in sentences:
            sentence_length = len(sentence.split())

            if current_chunk_length + sentence_length > self.config.chunk_size and current_chunk_sentences:
                # Create chunk
                chunk_text = " ".join(current_chunk_sentences)
                chunk = self._create_chunk(
                    chunk_text,
                    document_metadata,
                    chunk_metadata,
                    section_heading=section_heading
                )
                chunks.append(chunk)

                # Start new chunk with overlap
                overlap_sentences = self._get_overlap_sentences(current_chunk_sentences)
                current_chunk_sentences = overlap_sentences + [sentence]
                current_chunk_length = sum(len(s.split()) for s in current_chunk_sentences)
            else:
                current_chunk_sentences.append(sentence)
                current_chunk_length += sentence_length

        # Add final chunk
        if current_chunk_sentences:
            chunk_text = " ".join(current_chunk_sentences)
            chunk = self._create_chunk(
                chunk_text,
                document_metadata,
                chunk_metadata,
                section_heading=section_heading
            )
            chunks.append(chunk)

        return chunks

    def _get_overlap_sentences(self, sentences: List[str]) -> List[str]:
        """Get overlap sentences for next chunk"""
        if not sentences:
            return []

        # Overlap budget: chunk_overlap words, at most half of the chunk
        total_words = sum(len(s.split()) for s in sentences)
        overlap_words = min(self.config.chunk_overlap, total_words // 2)

        # Take whole sentences from the end while they fit in the budget
        overlap_sentences = []
        word_count = 0

        for sentence in reversed(sentences):
            sentence_words = len(sentence.split())
            if word_count + sentence_words <= overlap_words:
                overlap_sentences.insert(0, sentence)
                word_count += sentence_words
            else:
                break

        return overlap_sentences

    def _get_overlap_paragraphs(self, paragraphs: List[str]) -> List[str]:
        """Get overlap paragraphs for next chunk"""
        if not paragraphs:
            return []

        # Take the last 1-2 paragraphs for overlap (ignores chunk_overlap)
        overlap_count = min(2, len(paragraphs) // 2)
        return paragraphs[-overlap_count:] if overlap_count > 0 else []

    def _is_too_large(self, text: str) -> bool:
        """Check if text is too large for a single chunk"""
        word_count = len(text.split())
        return word_count > self.config.chunk_size

    def _create_chunk(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict],
        section_heading: Optional[str] = None
    ) -> Chunk:
        """Create a chunk with metadata"""

        # Generate a stable chunk ID. The built-in hash() is salted per process
        # (PYTHONHASHSEED), so it would change between runs.
        chunk_id = f"chunk_{hashlib.sha256(text.encode('utf-8')).hexdigest()[:16]}"

        # Extract metadata from text
        keywords = self.text_processor.extract_keywords(text)
        math_expressions = self.text_processor.extract_math_expressions(text)
        code_blocks = self.text_processor.extract_code_blocks(text)
|
|
readability = self.text_processor.assess_readability(text)
|
|
|
|
# Create chunk metadata
|
|
metadata = ChunkMetadata(
|
|
word_count=len(text.split()),
|
|
character_count=len(text),
|
|
sentence_count=len(self.text_processor.tokenize_sentences(text)),
|
|
paragraph_count=len(self.text_processor.tokenize_paragraphs(text)),
|
|
keywords=keywords,
|
|
math_expressions=math_expressions,
|
|
code_blocks=code_blocks,
|
|
readability_score=readability["flesch_score"],
|
|
grade_level=readability["grade_level"],
|
|
language=self.text_processor.detect_language(text),
|
|
has_structure=bool(section_heading),
|
|
section_heading=section_heading,
|
|
**(chunk_metadata or {})
|
|
)
|
|
|
|
# Create chunk
|
|
chunk = Chunk(
|
|
id=chunk_id,
|
|
text=text,
|
|
document_metadata=document_metadata,
|
|
metadata=metadata
|
|
)
|
|
|
|
return chunk
|
|
|
|
def _validate_chunk(self, chunk: Chunk) -> Dict:
|
|
"""Validate chunk quality"""
|
|
|
|
# Check minimum length
|
|
if chunk.metadata.word_count < self.config.min_chunk_size:
|
|
return {
|
|
"is_valid": False,
|
|
"reason": f"Chunk too short: {chunk.metadata.word_count} words",
|
|
"score": 0.1,
|
|
}
|
|
|
|
# Check maximum length
|
|
if chunk.metadata.word_count > self.config.max_chunk_size:
|
|
return {
|
|
"is_valid": False,
|
|
"reason": f"Chunk too long: {chunk.metadata.word_count} words",
|
|
"score": 0.1,
|
|
}
|
|
|
|
# Check for meaningful content
|
|
if not chunk.text.strip() or len(chunk.text.strip()) < 20:
|
|
return {
|
|
"is_valid": False,
|
|
"reason": "Chunk contains insufficient content",
|
|
"score": 0.0,
|
|
}
|
|
|
|
# Calculate quality score
|
|
quality_score = 0.5 # Base score
|
|
|
|
# Length appropriateness
|
|
optimal_length = self.config.chunk_size
|
|
length_diff = abs(chunk.metadata.word_count - optimal_length)
|
|
length_score = max(0, 1 - (length_diff / optimal_length))
|
|
quality_score += length_score * 0.3
|
|
|
|
# Readability
|
|
if chunk.metadata.readability_score > 60:
|
|
quality_score += 0.1
|
|
elif chunk.metadata.readability_score > 40:
|
|
quality_score += 0.05
|
|
|
|
# Content richness
|
|
if chunk.metadata.keywords:
|
|
quality_score += 0.05
|
|
|
|
if chunk.metadata.math_expressions or chunk.metadata.code_blocks:
|
|
quality_score += 0.05
|
|
|
|
# Structure
|
|
if chunk.metadata.section_heading:
|
|
quality_score += 0.05
|
|
|
|
# Sentence completeness
|
|
if chunk.metadata.sentence_count >= 2:
|
|
quality_score += 0.05
|
|
|
|
quality_score = min(1.0, quality_score)
|
|
|
|
return {
|
|
"is_valid": quality_score >= 0.3,
|
|
"reason": "Quality check passed" if quality_score >= 0.3 else "Low quality score",
|
|
"score": quality_score,
|
|
}
|
|
|
|
def get_chunking_stats(self, chunks: List[Chunk]) -> Dict:
|
|
"""Get statistics about chunking results"""
|
|
if not chunks:
|
|
return {}
|
|
|
|
word_counts = [chunk.metadata.word_count for chunk in chunks]
|
|
quality_scores = [chunk.quality_score for chunk in chunks]
|
|
|
|
return {
|
|
"total_chunks": len(chunks),
|
|
"avg_word_count": np.mean(word_counts),
|
|
"min_word_count": min(word_counts),
|
|
"max_word_count": max(word_counts),
|
|
"avg_quality_score": np.mean(quality_scores),
|
|
"min_quality_score": min(quality_scores),
|
|
"max_quality_score": max(quality_scores),
|
|
"total_words": sum(word_counts),
|
|
"chunks_with_headings": sum(1 for c in chunks if c.metadata.section_heading),
|
|
"chunks_with_math": sum(1 for c in chunks if c.metadata.math_expressions),
|
|
"chunks_with_code": sum(1 for c in chunks if c.metadata.code_blocks),
|
|
}
|
|
```
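
The sentence-overlap rule in `_get_overlap_sentences` is easiest to check in isolation. The sketch below is a standalone copy of that logic, with `chunk_overlap` passed as a plain argument instead of read from `self.config` (an assumption for illustration). It walks backwards from the end of the finished chunk and keeps whole sentences while their total word count stays within `min(chunk_overlap, total_words // 2)`.

```python
from typing import List


def get_overlap_sentences(sentences: List[str], chunk_overlap: int) -> List[str]:
    """Return trailing whole sentences to repeat at the start of the next chunk.

    Standalone version of the chunker's overlap rule: the word budget is
    capped at half of the chunk, and sentences are never split.
    """
    if not sentences:
        return []

    total_words = sum(len(s.split()) for s in sentences)
    overlap_words = min(chunk_overlap, total_words // 2)

    overlap: List[str] = []
    word_count = 0
    # Walk backwards, stopping at the first sentence that would exceed the budget
    for sentence in reversed(sentences):
        sentence_words = len(sentence.split())
        if word_count + sentence_words > overlap_words:
            break
        overlap.insert(0, sentence)
        word_count += sentence_words
    return overlap


sentences = [
    "Vectors are normalized first.",  # 4 words
    "Then the index is searched.",    # 5 words
    "Scores are inner products.",     # 4 words
]
# 13 words in total, so the budget is min(10, 13 // 2) = 6 words
print(get_overlap_sentences(sentences, chunk_overlap=10))  # ['Scores are inner products.']
```

Because the budget is capped at half of the chunk, a chunk never carries more than half of its words forward, and a single long final sentence produces no overlap at all.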

---

## 🔎 WEEK 5-6: RETRIEVAL SYSTEM

### Task 3.1: Vector Search Implementation
**Priority**: High
**Estimated Time**: 14 hours
**Dependencies**: Task 2.2

#### Subtasks:
- [ ] Implement FAISS vector store
- [ ] Create indexing pipeline
- [ ] Add similarity search
- [ ] Implement batch operations
- [ ] Add index optimization
- [ ] Create search performance monitoring

#### Implementation:

**src/core/vector_store.py**
```python
import os
import pickle
import numpy as np
import faiss
from typing import List, Dict, Tuple, Optional, Any
from pathlib import Path
from ..config.settings import VectorStoreConfig
from ..models.chunk import Chunk
from ..utils.logger import get_logger

logger = get_logger(__name__)


class VectorStore:
    """FAISS-based vector store for efficient similarity search"""

    def __init__(self, config: VectorStoreConfig, index_path: str = "data/indices"):
        self.config = config
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        self.index = None
        self.quantizer = None      # IVF coarse quantizer; referenced by the index, so keep it alive
        self.gpu_resources = None  # GPU resources; referenced by GPU indices, so keep them alive
        self.on_gpu = False
        self.chunk_mapping = {}  # Maps index position to chunk ID
        self.embedding_dim = config.dimension

        self._initialize_index()

    def _initialize_index(self):
        """Create an empty index, then load a saved index from disk if one exists"""
        try:
            self._create_index()
            self._load_index()
        except Exception as e:
            logger.error(f"Failed to initialize index: {e}")
            raise

    def _create_index(self):
        """Create an empty FAISS index based on configuration"""
        if self.config.index_type == "IVF":
            # Inverted File Index (must be trained before embeddings are added)
            self.quantizer = faiss.IndexFlatIP(self.embedding_dim)
            self.index = faiss.IndexIVFFlat(
                self.quantizer,
                self.embedding_dim,
                self.config.nlist,
                faiss.METRIC_INNER_PRODUCT
            )
            logger.info(f"Created IVF index with {self.config.nlist} clusters")

        elif self.config.index_type == "HNSW":
            # Hierarchical Navigable Small World (M=32 links per node). Its default
            # metric is L2, whose scores are distances (lower = closer); use inner
            # product so scores are similarities, like the other index types.
            self.index = faiss.IndexHNSWFlat(self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT)
            logger.info("Created HNSW index")

        else:
            # Flat Index (exact search)
            self.index = faiss.IndexFlatIP(self.embedding_dim)
            logger.info("Created Flat index")

        self.on_gpu = False
        self._move_to_gpu_if_requested()

    def _move_to_gpu_if_requested(self):
        """Move the index to GPU if available and requested"""
        if self.config.use_gpu and faiss.get_num_gpus() > 0:
            try:
                self.gpu_resources = faiss.StandardGpuResources()
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, self.index)
                self.on_gpu = True
                logger.info("Index moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move index to GPU: {e}")

    def add_embeddings(self, embeddings: np.ndarray, chunk_ids: List[str]):
        """
        Add embeddings to the index

        Args:
            embeddings: Array of shape (n, dimension)
            chunk_ids: Corresponding chunk IDs
        """
        if len(embeddings) != len(chunk_ids):
            raise ValueError("Number of embeddings must match number of chunk IDs")

        if len(embeddings) == 0:
            return

        if not self.index.is_trained:
            raise RuntimeError("Index is not trained; call train_index() before adding embeddings")

        try:
            # FAISS needs a C-contiguous float32 matrix; astype() copies, so the
            # in-place normalization below never mutates the caller's array
            embeddings = embeddings.astype(np.float32, order="C")

            # Normalize embeddings so inner product equals cosine similarity
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(embeddings)

            # Get current index size
            current_size = self.index.ntotal

            # Add embeddings to index
            self.index.add(embeddings)

            # Update chunk mapping
            for i, chunk_id in enumerate(chunk_ids):
                self.chunk_mapping[current_size + i] = chunk_id

            logger.info(f"Added {len(embeddings)} embeddings to index (total: {self.index.ntotal})")

        except Exception as e:
            logger.error(f"Failed to add embeddings: {e}")
            raise

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        nprobe: Optional[int] = None
    ) -> List[Tuple[str, float]]:
        """
        Search for similar embeddings

        Args:
            query_embedding: Query embedding of shape (dimension,)
            top_k: Number of results to return
            nprobe: Number of clusters to probe (for IVF index)

        Returns:
            List of (chunk_id, similarity_score) tuples, best first
        """
        if self.index.ntotal == 0:
            return []

        try:
            # Prepare query as a (1, dimension) float32 matrix
            query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

            # Normalize for inner product
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(query_embedding)

            # Set nprobe for IVF index
            if nprobe and hasattr(self.index, 'nprobe'):
                self.index.nprobe = nprobe
            elif hasattr(self.index, 'nprobe') and self.config.nprobe:
                self.index.nprobe = self.config.nprobe

            # Search
            similarities, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))

            # Convert to chunk IDs and scores; FAISS pads missing results with -1
            results = []
            for similarity, idx in zip(similarities[0], indices[0]):
                chunk_id = self.chunk_mapping.get(int(idx))
                if chunk_id is not None:
                    results.append((chunk_id, float(similarity)))

            return results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            return []

    def batch_search(
        self,
        query_embeddings: np.ndarray,
        top_k: int = 10,
        nprobe: Optional[int] = None
    ) -> List[List[Tuple[str, float]]]:
        """
        Batch search for multiple queries

        Args:
            query_embeddings: Array of query embeddings, shape (n_queries, dimension)
            top_k: Number of results per query
            nprobe: Number of clusters to probe

        Returns:
            List of result lists, one per query
        """
        if self.index.ntotal == 0:
            return [[] for _ in range(len(query_embeddings))]

        try:
            # Prepare queries (copied, float32, C-contiguous)
            query_embeddings = query_embeddings.astype(np.float32, order="C")

            # Normalize for inner product
            if self.config.metric == "INNER_PRODUCT":
                faiss.normalize_L2(query_embeddings)

            # Set nprobe for IVF index
            if nprobe and hasattr(self.index, 'nprobe'):
                self.index.nprobe = nprobe
            elif hasattr(self.index, 'nprobe') and self.config.nprobe:
                self.index.nprobe = self.config.nprobe

            # Search
            similarities, indices = self.index.search(
                query_embeddings,
                min(top_k, self.index.ntotal)
            )

            # Convert results, skipping -1 padding
            all_results = []
            for sim_row, idx_row in zip(similarities, indices):
                query_results = []
                for similarity, idx in zip(sim_row, idx_row):
                    chunk_id = self.chunk_mapping.get(int(idx))
                    if chunk_id is not None:
                        query_results.append((chunk_id, float(similarity)))
                all_results.append(query_results)

            return all_results

        except Exception as e:
            logger.error(f"Batch search failed: {e}")
            return [[] for _ in range(len(query_embeddings))]

    def remove_embeddings(self, chunk_ids: List[str]) -> List[str]:
        """
        Remove embeddings from the index

        MVP limitation: the raw embeddings are not stored, so the index cannot be
        rebuilt in place. The index is reset to empty and the IDs of the chunks that
        were NOT removed are returned; the caller must re-embed and re-add them.

        Args:
            chunk_ids: Chunk IDs to remove

        Returns:
            Chunk IDs that must be re-added
        """
        if not chunk_ids:
            return []

        try:
            # FAISS remove_ids() renumbers the remaining vectors of a flat index
            # (invalidating chunk_mapping) and HNSW indices do not support removal,
            # so reset the index instead
            to_remove = set(chunk_ids)
            remaining_chunk_ids = [
                chunk_id for _, chunk_id in sorted(self.chunk_mapping.items())
                if chunk_id not in to_remove
            ]

            self._rebuild_index()

            logger.info(f"Index reset; {len(remaining_chunk_ids)} remaining chunks must be re-added")
            return remaining_chunk_ids

        except Exception as e:
            logger.error(f"Failed to remove embeddings: {e}")
            raise

    def _rebuild_index(self):
        """Reset to an empty index (does not reload the saved index from disk)"""
        self._create_index()
        self.chunk_mapping = {}

    def save_index(self, name: str = "default"):
        """Save index and chunk mapping to disk"""
        try:
            index_file = self.index_path / f"{name}.index"
            mapping_file = self.index_path / f"{name}_mapping.pkl"

            # GPU indices must be copied back to CPU before serialization
            index_to_save = faiss.index_gpu_to_cpu(self.index) if self.on_gpu else self.index
            faiss.write_index(index_to_save, str(index_file))

            # Save mapping
            with open(mapping_file, 'wb') as f:
                pickle.dump(self.chunk_mapping, f)

            logger.info(f"Index saved to {index_file}")

        except Exception as e:
            logger.error(f"Failed to save index: {e}")
            raise

    def _load_index(self, name: str = "default") -> bool:
        """Load index and chunk mapping from disk; returns True if found"""
        try:
            index_file = self.index_path / f"{name}.index"
            mapping_file = self.index_path / f"{name}_mapping.pkl"

            if index_file.exists() and mapping_file.exists():
                # Load index (always read onto the CPU), then move to GPU if needed
                self.index = faiss.read_index(str(index_file))
                self.on_gpu = False
                self._move_to_gpu_if_requested()

                # Load mapping
                with open(mapping_file, 'rb') as f:
                    self.chunk_mapping = pickle.load(f)

                logger.info(f"Loaded index with {self.index.ntotal} embeddings")
                return True
            else:
                logger.info("No existing index found, starting with empty index")
                return False

        except Exception as e:
            logger.warning(f"Failed to load index: {e}")
            return False

    def get_stats(self) -> Dict:
        """Get index statistics"""
        return {
            "total_embeddings": self.index.ntotal,
            "index_type": self.config.index_type,
            "dimension": self.embedding_dim,
            "metric": self.config.metric,
            "is_trained": getattr(self.index, 'is_trained', True),
            "nlist": getattr(self.index, 'nlist', None),
            "use_gpu": self.on_gpu,
        }

    def train_index(self, embeddings: np.ndarray):
        """Train the index (IVF indices must be trained before embeddings are added)"""
        if not self.index.is_trained:
            try:
                embeddings = embeddings.astype(np.float32, order="C")
                if self.config.metric == "INNER_PRODUCT":
                    faiss.normalize_L2(embeddings)

                self.index.train(embeddings)
                logger.info(f"Index trained with {len(embeddings)} embeddings")
            except Exception as e:
                logger.error(f"Failed to train index: {e}")
                raise
```
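
`VectorStore` L2-normalizes every vector before adding or searching with an inner-product index, which makes the returned scores cosine similarities. The dependency-free sketch below (pure Python, no FAISS; `l2_normalize` and `exact_top_k` are illustrative names) reproduces what an exact flat inner-product search returns for normalized inputs. It can serve as a reference when checking the approximate IVF and HNSW indices on small samples.

```python
import heapq
import math
from typing import Dict, List, Sequence, Tuple


def l2_normalize(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def exact_top_k(
    query: Sequence[float],
    vectors: Dict[str, Sequence[float]],
    top_k: int,
) -> List[Tuple[str, float]]:
    """Brute-force inner-product search over L2-normalized vectors.

    With unit-length vectors the inner product equals cosine similarity,
    so higher scores mean more similar chunks.
    """
    q = l2_normalize(query)
    scored = (
        (chunk_id, sum(a * b for a, b in zip(q, l2_normalize(vec))))
        for chunk_id, vec in vectors.items()
    )
    return heapq.nlargest(top_k, scored, key=lambda item: item[1])


store = {
    "chunk_a": [1.0, 0.0],
    "chunk_b": [3.0, 4.0],   # normalizes to [0.6, 0.8]
    "chunk_c": [0.0, -2.0],  # normalizes to [0.0, -1.0]
}
print(exact_top_k([2.0, 0.0], store, top_k=2))  # [('chunk_a', 1.0), ('chunk_b', 0.6)]
```

Comparing the top-k IDs of an approximate index against this exact list over a sample of queries gives a simple recall estimate when tuning `nprobe` or HNSW parameters.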

---

### Task 3.2: Hybrid Retrieval System
**Priority**: High
**Estimated Time**: 12 hours
**Dependencies**: Task 3.1

#### Subtasks:
- [ ] Implement keyword search (BM25)
- [ ] Create vector similarity search
- [ ] Build hybrid ranking algorithm
- [ ] Add metadata filtering
- [ ] Implement result fusion
- [ ] Create performance optimization
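
The BM25 keyword search from the first subtask is delegated to `KeywordSearcher` (imported from `.keyword_search`). As a reference point, here is a minimal Okapi BM25 sketch over pre-tokenized chunks. The function name, the whitespace tokenization, and the common defaults `k1=1.5` and `b=0.75` are assumptions, not project settings; the sketch returns `(chunk_id, score)` pairs, the same shape `HybridSearcher` consumes.

```python
import math
from collections import Counter
from typing import Dict, List, Tuple


def bm25_search(
    query_terms: List[str],
    docs: Dict[str, List[str]],
    top_k: int = 10,
    k1: float = 1.5,
    b: float = 0.75,
) -> List[Tuple[str, float]]:
    """Rank tokenized documents against query terms with Okapi BM25."""
    if not docs:
        return []

    n_docs = len(docs)
    avg_len = sum(len(tokens) for tokens in docs.values()) / n_docs or 1.0
    # Document frequency: number of documents containing each term
    doc_freq = Counter(term for tokens in docs.values() for term in set(tokens))

    scores = []
    for doc_id, tokens in docs.items():
        term_freq = Counter(tokens)
        # Length normalization: longer-than-average documents are penalized
        length_norm = k1 * (1 - b + b * len(tokens) / avg_len)
        score = 0.0
        for term in set(query_terms):
            tf = term_freq.get(term, 0)
            if tf == 0:
                continue
            # "Plus one" IDF variant keeps every term weight positive
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            score += idf * tf * (k1 + 1) / (tf + length_norm)
        if score > 0:
            scores.append((doc_id, score))

    scores.sort(key=lambda item: (-item[1], item[0]))
    return scores[:top_k]


docs = {
    "c1": "faiss builds an inverted file index".split(),
    "c2": "bm25 ranks chunks by keyword overlap".split(),
    "c3": "keyword search and vector search are combined".split(),
}
print([doc_id for doc_id, _ in bm25_search(["keyword", "search"], docs)])  # ['c3', 'c2']
```

BM25 scores are unbounded, which is why they have to be normalized before being mixed with cosine similarities in the hybrid ranking step.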

#### Implementation:

**src/retrieval/hybrid_search.py**
```python
import re
from typing import List, Dict, Tuple, Optional, Any
from ..core.vector_store import VectorStore
from ..core.embeddings import EmbeddingService
from ..models.chunk import Chunk
from ..models.query import Query, QueryResult
from .keyword_search import KeywordSearcher
from .ranker import ResultRanker
from ..utils.logger import get_logger

logger = get_logger(__name__)


class HybridSearcher:
    """Hybrid search combining keyword and vector search"""

    def __init__(
        self,
        vector_store: VectorStore,
        embedding_service: EmbeddingService,
        keyword_searcher: KeywordSearcher,
        ranker: ResultRanker
    ):
        self.vector_store = vector_store
        self.embedding_service = embedding_service
        self.keyword_searcher = keyword_searcher
        self.ranker = ranker

    def search(
        self,
        query: Query,
        chunks: Dict[str, Chunk],  # chunk_id -> Chunk mapping
        top_k: int = 10,
        alpha: float = 0.5,  # Weight for hybrid combination
        rerank: bool = True
    ) -> QueryResult:
        """
        Perform hybrid search

        Args:
            query: Query object
            chunks: Mapping of chunk IDs to Chunk objects
            top_k: Number of results to return
            alpha: Weight for combining results (0=keyword only, 1=vector only)
            rerank: Whether to apply reranking

        Returns:
            QueryResult object
        """
        try:
            logger.info(f"Performing hybrid search for query: {query.text[:50]}...")

            # Step 1: Keyword search
            keyword_results = self._perform_keyword_search(query, chunks)

            # Step 2: Vector search (over-fetch so fusion has enough candidates)
            vector_results = self._perform_vector_search(query, top_k * 2)

            # Step 3: Combine results
            combined_results = self._combine_results(
                keyword_results,
                vector_results,
                chunks,
                alpha
            )

            # Step 4: Apply metadata filtering
            filtered_results = self._apply_metadata_filters(
                combined_results,
                chunks,
                query.filters
            )

            # Step 5: Rerank if requested
            if rerank and len(filtered_results) > 1:
                filtered_results = self.ranker.rerank(
                    query,
                    filtered_results,
                    chunks
                )

            # Step 6: Limit to top_k
            final_results = filtered_results[:top_k]

            # Create result object
            result = QueryResult(
                query=query,
                results=final_results,
                total_found=len(combined_results),
                keyword_results_count=len(keyword_results),
                vector_results_count=len(vector_results),
                alpha=alpha,
                reranked=rerank
            )

            logger.info(f"Search completed: {len(final_results)} results")
            return result

        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            raise

    def _perform_keyword_search(self, query: Query, chunks: Dict[str, Chunk]) -> List[Tuple[str, float]]:
        """Perform keyword search using BM25"""
        try:
            # Extract keywords from query
            query_keywords = self._extract_keywords(query.text)

            if not query_keywords:
                logger.info("No keywords found in query")
                return []

            # Perform BM25 search
            keyword_results = self.keyword_searcher.search(
                query_keywords,
                chunks,
                top_k=50  # Get more results for combination
            )

            logger.info(f"Keyword search found {len(keyword_results)} results")
            return keyword_results

        except Exception as e:
            logger.error(f"Keyword search failed: {e}")
            return []

    def _perform_vector_search(self, query: Query, top_k: int) -> List[Tuple[str, float]]:
        """Perform vector similarity search"""
        try:
            # Generate query embedding
            query_embedding = self.embedding_service.encode_single(query.text)

            # Search vector store
            vector_results = self.vector_store.search(query_embedding, top_k)

            logger.info(f"Vector search found {len(vector_results)} results")
            return vector_results

        except Exception as e:
            logger.error(f"Vector search failed: {e}")
            return []

    def _combine_results(
        self,
        keyword_results: List[Tuple[str, float]],
        vector_results: List[Tuple[str, float]],
        chunks: Dict[str, Chunk],
        alpha: float
    ) -> List[Tuple[str, float]]:
        """Combine keyword and vector results as a weighted sum of normalized scores"""
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be between 0 and 1")

        # Min-max normalize each result list independently; a chunk missing
        # from one list contributes 0 for that signal
        keyword_scores = self._normalize_scores(keyword_results)
        vector_scores = self._normalize_scores(vector_results)

        # Get all unique chunk IDs
        all_chunk_ids = set(keyword_scores) | set(vector_scores)

        combined_results = [
            (
                chunk_id,
                alpha * vector_scores.get(chunk_id, 0.0)
                + (1 - alpha) * keyword_scores.get(chunk_id, 0.0),
            )
            for chunk_id in all_chunk_ids
        ]

        # Sort by combined score (ties broken by chunk ID for deterministic output)
        combined_results.sort(key=lambda x: (-x[1], x[0]))

        logger.info(f"Combined {len(all_chunk_ids)} unique results")
        return combined_results

    def _normalize_scores(self, results: List[Tuple[str, float]]) -> Dict[str, float]:
        """Min-max normalize a result list to chunk_id -> score in the 0-1 range"""
        if not results:
            return {}

        scores = [s for _, s in results]
        min_score = min(scores)
        max_score = max(scores)

        if max_score == min_score:
            # All scores tie, so min-max is undefined; give every result 0.5
            return {chunk_id: 0.5 for chunk_id, _ in results}

        score_range = max_score - min_score
        return {chunk_id: (s - min_score) / score_range for chunk_id, s in results}

    def _apply_metadata_filters(
        self,
        results: List[Tuple[str, float]],
        chunks: Dict[str, Chunk],
        filters: Optional[Dict[str, Any]]
    ) -> List[Tuple[str, float]]:
        """Apply metadata filters to results"""
        if not filters:
            return results

        filtered_results = []

        for chunk_id, score in results:
            if chunk_id not in chunks:
                continue

            chunk = chunks[chunk_id]

            if self._passes_filters(chunk, filters):
                filtered_results.append((chunk_id, score))

        logger.info(f"Filtered {len(results)} -> {len(filtered_results)} results")
        return filtered_results

    def _passes_filters(self, chunk: Chunk, filters: Dict[str, Any]) -> bool:
        """Check if chunk passes all filters"""
        # document_metadata is a dict; chunk.metadata is a ChunkMetadata object,
        # so its optional fields are read with getattr() and a default

        # Subject filter
        if "subject" in filters:
            if chunk.document_metadata.get("subject") != filters["subject"]:
                return False

        # Difficulty filter
        if "max_difficulty" in filters:
            chunk_difficulty = getattr(chunk.metadata, "difficulty", 0.5)
            if chunk_difficulty > filters["max_difficulty"]:
                return False

        # Bloom level filter
        if "max_bloom_level" in filters:
            chunk_bloom = getattr(chunk.metadata, "bloom_level", 3)
            if chunk_bloom > filters["max_bloom_level"]:
                return False

        # Language filter
        if "language" in filters:
            if getattr(chunk.metadata, "language", None) != filters["language"]:
                return False

        # Keyword filter (must contain at least one)
        if "required_keywords" in filters:
            chunk_text = chunk.text.lower()
            chunk_keywords = getattr(chunk.metadata, "keywords", None) or []

            has_keyword = any(
                keyword.lower() in chunk_text or keyword in chunk_keywords
                for keyword in filters["required_keywords"]
            )
            if not has_keyword:
                return False

        return True

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from query text (simple stop-word filter; could use NLP)"""
        # Remove common stop words
        stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have',
            'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should',
            'what', 'how', 'when', 'where', 'why', 'who', 'which', 'that', 'this'
        }

        # Extract words and clean
        words = re.findall(r'\b[a-zA-Z]+\b', text.lower())
        keywords = [word for word in words if word not in stop_words and len(word) > 2]

        # Return unique keywords, preserving query order
        return list(dict.fromkeys(keywords))

    def batch_search(
        self,
        queries: List[Query],
        chunks: Dict[str, Chunk],
        top_k: int = 10,
        alpha: float = 0.5,
        rerank: bool = True
    ) -> List[QueryResult]:
        """Perform batch search for multiple queries"""
        results = []

        for query in queries:
            try:
                result = self.search(query, chunks, top_k, alpha, rerank)
                results.append(result)
            except Exception as e:
                logger.error(f"Batch search failed for query {query.id}: {e}")
                # Add empty result for failed query
                results.append(QueryResult(
                    query=query,
                    results=[],
                    total_found=0,
                    keyword_results_count=0,
                    vector_results_count=0,
                    alpha=alpha,
                    reranked=rerank
                ))

        return results

    def get_search_stats(self) -> Dict:
        """Get search system statistics"""
        return {
            "vector_store_stats": self.vector_store.get_stats(),
            "embedding_service_stats": {
                "model_name": self.embedding_service.config.model_name,
                "dimension": self.embedding_service.model.get_sentence_embedding_dimension(),
            },
            "keyword_searcher_stats": self.keyword_searcher.get_stats(),
        }
```
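
The fusion step weights the two signals with `alpha` (0 = keyword only, 1 = vector only) after rescaling each result list into [0, 1]. The standalone sketch below (`min_max` and `fuse` are illustrative names) shows that rule with a chunk that is missing from one list contributing 0 for that signal, and a tie in one list mapped to 0.5.

```python
from typing import Dict, List, Tuple


def min_max(results: List[Tuple[str, float]]) -> Dict[str, float]:
    """Map each chunk ID to its score rescaled into [0, 1]."""
    if not results:
        return {}
    scores = [score for _, score in results]
    low, high = min(scores), max(scores)
    if high == low:
        return {chunk_id: 0.5 for chunk_id, _ in results}
    return {chunk_id: (score - low) / (high - low) for chunk_id, score in results}


def fuse(
    keyword_results: List[Tuple[str, float]],
    vector_results: List[Tuple[str, float]],
    alpha: float = 0.5,
) -> List[Tuple[str, float]]:
    """alpha * vector + (1 - alpha) * keyword; a missing signal counts as 0."""
    kw, vec = min_max(keyword_results), min_max(vector_results)
    fused = [
        (cid, alpha * vec.get(cid, 0.0) + (1 - alpha) * kw.get(cid, 0.0))
        for cid in kw.keys() | vec.keys()
    ]
    # Highest score first; ties broken by chunk ID for deterministic output
    return sorted(fused, key=lambda item: (-item[1], item[0]))


keyword = [("c1", 7.2), ("c2", 3.1), ("c3", 1.0)]   # BM25 scores: unbounded
vector = [("c2", 0.91), ("c4", 0.80), ("c1", 0.42)]  # cosine similarities
print([cid for cid, _ in fuse(keyword, vector)])  # ['c2', 'c1', 'c4', 'c3']
```

`c2` ranks first even though `c1` has the best BM25 score, because `c2` is strong on both signals. Reciprocal rank fusion is a common alternative that combines ranks instead of scores and so needs no normalization.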

---

## 🤖 WEEK 7-8: LLM INTEGRATION

### Task 4.1: LLM Client Implementation
**Priority**: High
**Estimated Time**: 10 hours
**Dependencies**: Task 3.2

#### Subtasks:
- [ ] Set up OpenAI API client
- [ ] Set up Anthropic API client
- [ ] Create prompt templates
- [ ] Implement response generation
- [ ] Add token counting
- [ ] Create error handling
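
The `LLMClient` for this task logs and re-raises every API error. For transient failures (rate limits, timeouts), one common error-handling pattern is retrying with exponential backoff and jitter. This is a stdlib-only sketch: `with_retries`, its parameters, and the default retryable exception types are assumptions. In practice the `retryable` tuple would hold SDK exceptions such as `openai.RateLimitError` or `anthropic.RateLimitError`, and both official Python clients already accept a `max_retries` option, so check those settings before adding another retry layer.

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def with_retries(
    call: Callable[[], T],
    retryable: Tuple[Type[BaseException], ...] = (TimeoutError, ConnectionError),
    max_attempts: int = 4,
    base_delay: float = 0.5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `call()`, retrying retryable exceptions with exponential backoff.

    Delays grow as base_delay * 2**(attempt - 1) (0.5s, 1s, 2s, ...) plus up to
    10% random jitter so concurrent clients do not retry in lockstep.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    attempt = 1
    while True:
        try:
            return call()
        except retryable:
            if attempt >= max_attempts:
                raise  # Out of attempts: surface the last error unchanged
            delay = base_delay * 2 ** (attempt - 1)
            sleep(delay * (1 + 0.1 * random.random()))
            attempt += 1


# Usage: a hypothetical call that times out twice, then succeeds
attempts = []

def flaky_call() -> str:
    attempts.append(1)
    if len(attempts) < 3:
        raise TimeoutError("simulated timeout")
    return "ok"

recorded_delays = []  # Injected sleep records delays instead of waiting
print(with_retries(flaky_call, sleep=recorded_delays.append))  # ok
```

Non-retryable exceptions (for example a `ValueError` from a bad prompt) propagate immediately, so configuration bugs are not masked by retries.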

#### Implementation:

**src/llm/llm_client.py**
```python
import os
import re
from typing import Dict, List, Optional, Any
import openai
import anthropic
from ..config.settings import RAGConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)


class LLMClient:
    """Client for interacting with LLM APIs"""

    def __init__(self, config: RAGConfig):
        self.config = config
        self.openai_client = None
        self.anthropic_client = None

        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize API clients for the configured provider(s)"""
        try:
            # Initialize OpenAI client
            if self.config.llm_provider in ["openai", "both"]:
                openai_api_key = os.getenv("OPENAI_API_KEY")
                if openai_api_key:
                    self.openai_client = openai.OpenAI(api_key=openai_api_key)
                    logger.info("OpenAI client initialized")
                else:
                    logger.warning("OpenAI API key not found")

            # Initialize Anthropic client
            if self.config.llm_provider in ["anthropic", "both"]:
                anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
                if anthropic_api_key:
                    self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
                    logger.info("Anthropic client initialized")
                else:
                    logger.warning("Anthropic API key not found")

        except Exception as e:
            logger.error(f"Failed to initialize LLM clients: {e}")
            raise

    def generate_response(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Generate response from LLM

        Args:
            prompt: Input prompt
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            provider: "anthropic", "openai" or "both" (defaults to config.llm_provider)

        Returns:
            Response dictionary with text and metadata
        """
        provider = provider or self.config.llm_provider
        # Compare against None so explicit values such as temperature=0.0 are not
        # silently replaced by the configured default
        if max_tokens is None:
            max_tokens = self.config.max_response_tokens
        if temperature is None:
            temperature = self.config.temperature

        try:
            # "both": prefer Anthropic, fall back to OpenAI if it is not configured
            if provider in ("anthropic", "both") and self.anthropic_client:
                return self._generate_anthropic_response(prompt, max_tokens, temperature)
            elif provider in ("openai", "both") and self.openai_client:
                return self._generate_openai_response(prompt, max_tokens, temperature)
            else:
                raise ValueError(f"Provider {provider} not available")

        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            raise

    def _generate_anthropic_response(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float
    ) -> Dict[str, Any]:
        """Generate response using Anthropic Claude"""
        try:
            message = self.anthropic_client.messages.create(
                model=self.config.llm_model,
                max_tokens=max_tokens,
                temperature=temperature,
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
            )

            # The response content is a list of blocks; concatenate the text blocks
            response_text = "".join(
                block.text for block in message.content if block.type == "text"
            )

            # Exact token usage as reported by the API
            prompt_tokens = message.usage.input_tokens
            completion_tokens = message.usage.output_tokens

            return {
                "text": response_text,
                "provider": "anthropic",
                "model": self.config.llm_model,
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "temperature": temperature,
                "finish_reason": message.stop_reason,
            }

        except Exception as e:
            logger.error(f"Anthropic generation failed: {e}")
            raise

    def _generate_openai_response(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float
    ) -> Dict[str, Any]:
        """Generate response using OpenAI GPT"""
        try:
            response = self.openai_client.chat.completions.create(
                model=self.config.llm_model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful AI assistant."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # message.content can be None (e.g. for refusals or tool calls)
            response_text = response.choices[0].message.content or ""

            return {
                "text": response_text,
                "provider": "openai",
                "model": self.config.llm_model,
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
                "temperature": temperature,
                "finish_reason": response.choices[0].finish_reason,
            }

        except Exception as e:
            logger.error(f"OpenAI generation failed: {e}")
            raise

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation for English text)"""
        # Simple approximation: ~4 characters per token
        return max(1, len(text) // 4)

    def validate_response(self, response: str, context: str) -> Dict[str, Any]:
        """Heuristic checks of response safety, relevance to the context, and length"""
        validation = {
            "is_safe": True,
            "is_relevant": True,
            "is_appropriate": True,
            "issues": [],
            "confidence": 1.0,
        }

        # Check for unsafe content. This keyword list is a crude heuristic, not a
        # moderation system; whole-word matching avoids false hits such as
        # "hate" inside "whatever"
        unsafe_patterns = [
            "hate", "violence", "self-harm", "explicit", "illegal"
        ]

        response_lower = response.lower()
        for pattern in unsafe_patterns:
            if re.search(rf"\b{re.escape(pattern)}\b", response_lower):
                validation["is_safe"] = False
                validation["issues"].append(f"Potentially unsafe content: {pattern}")
                validation["confidence"] *= 0.5

        # Check relevance to context: share of response words that appear in the context
        if context:
            context_words = set(context.lower().split())
            response_words = set(response.lower().split())
            overlap = len(context_words & response_words)
            relevance_score = overlap / len(response_words) if response_words else 0

            if relevance_score < 0.1:
                validation["is_relevant"] = False
                validation["issues"].append("Low relevance to context")
                validation["confidence"] *= 0.7

        # Check for appropriate length (in characters)
        if len(response) < 10:
            validation["is_appropriate"] = False
            validation["issues"].append("Response too short")
            validation["confidence"] *= 0.8
        elif len(response) > 2000:
            validation["issues"].append("Response very long")
            validation["confidence"] *= 0.9

        return validation

    def get_model_info(self, provider: str) -> Dict[str, Any]:
        """Get information about the models of a provider"""
        # Pricing snapshot in USD per 1K tokens; provider prices change, so verify
        # against the current pricing pages before relying on these numbers
        model_info = {
            "anthropic": {
                "claude-3-5-sonnet-20241022": {
                    "max_tokens": 4096,
                    "context_window": 200000,
                    "cost_per_1k_input": 0.003,
                    "cost_per_1k_output": 0.015,
                }
            },
            "openai": {
                "gpt-4": {
                    "max_tokens": 4096,
                    "context_window": 8192,
                    "cost_per_1k_input": 0.03,
                    "cost_per_1k_output": 0.06,
                },
                "gpt-3.5-turbo": {
                    "max_tokens": 4096,
                    "context_window": 16385,
                    "cost_per_1k_input": 0.0015,
                    "cost_per_1k_output": 0.002,
                }
            }
        }

        return model_info.get(provider, {})

    def estimate_cost(self, prompt_tokens: int, completion_tokens: int, provider: str) -> float:
        """Estimate cost in USD for an API call (0 if the model is unknown)"""
        model_info = self.get_model_info(provider)
        model_data = model_info.get(self.config.llm_model, {})

        input_cost = (prompt_tokens / 1000) * model_data.get("cost_per_1k_input", 0)
        output_cost = (completion_tokens / 1000) * model_data.get("cost_per_1k_output", 0)

        return input_cost + output_cost
```
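
The relevance check in `validate_response` is a plain word-overlap ratio: the share of distinct response words that also appear in the context. The standalone sketch below reproduces that arithmetic so the 0.1 threshold is easier to reason about. `overlap_relevance` is a hypothetical name and not part of the planned module.

```python
def overlap_relevance(response: str, context: str) -> float:
    """Fraction of distinct response words that also occur in the context.

    Mirrors the heuristic in validate_response: lowercase, split on
    whitespace, compare word sets. Punctuation is not stripped, so
    "limit." and "limit" count as different words.
    """
    context_words = set(context.lower().split())
    response_words = set(response.lower().split())
    if not response_words:
        return 0.0
    return len(context_words & response_words) / len(response_words)


context = "A derivative measures the rate of change of a function"
response = "The derivative is the rate of change"

score = overlap_relevance(response, context)
print(round(score, 3))  # 0.833 (5 of 6 distinct response words appear in the context)
print(score < 0.1)      # False: this response passes the relevance check
```

Stop words such as "the" count toward the overlap, so a short answer can score well without being grounded. The ratio is best treated as a cheap first filter.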

---

### Task 4.2: Prompt Engineering
**Priority**: High
**Estimated Time**: 8 hours
**Dependencies**: Task 4.1

#### Subtasks:
- [ ] Create prompt templates for different modes
- [ ] Implement context injection
- [ ] Add constraint enforcement
- [ ] Create safety prompts
- [ ] Build prompt optimization
- [ ] Add prompt testing

#### Implementation:

**src/llm/prompt_builder.py**
```python
import textwrap
from typing import Dict, List, Optional, Any

from ..models.query import Query
from ..models.chunk import Chunk
from ..config.settings import RAGConfig
from ..utils.logger import get_logger

logger = get_logger(__name__)


class PromptBuilder:
    """Builds prompts for different interaction modes"""

    def __init__(self, config: RAGConfig):
        self.config = config

    def build_prompt(
        self,
        query: Query,
        retrieved_chunks: List[Chunk],
        mode: str = "EXPLANATION",
        student_level: int = 2,
        constraints: Optional[Dict] = None
    ) -> str:
        """
        Build the complete prompt for the LLM

        Args:
            query: Student query
            retrieved_chunks: Retrieved context chunks
            mode: Interaction mode (EXPLANATION, TUTOR, EXAM, QUIZ, EXPLORATION, REMEDIAL)
            student_level: Student's current level (1-6 on Bloom's taxonomy)
            constraints: Additional constraints

        Returns:
            Complete prompt string
        """
        try:
            # Get mode-specific template
            template = self._get_mode_template(mode)

            # Build context section
            context_section = self._build_context_section(retrieved_chunks)

            # Build constraints section
            constraints_section = self._build_constraints_section(
                mode, student_level, constraints
            )

            # Build query section
            query_section = self._build_query_section(query)

            # Combine all sections (str.format does not re-parse braces inside
            # the substituted values, so chunk text containing "{" is safe)
            prompt = template.format(
                constraints_section=constraints_section,
                context_section=context_section,
                query_section=query_section,
                mode=mode,
                student_level=student_level
            )

            logger.info(f"Built prompt for mode {mode}, {len(retrieved_chunks)} chunks")
            return prompt

        except Exception as e:
            logger.error(f"Failed to build prompt: {e}")
            raise

    def _get_mode_template(self, mode: str) -> str:
        """Get prompt template for a specific mode (falls back to EXPLANATION)"""
        templates = {
            "EXPLANATION": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                Generate a clear, educational explanation that:
                1. Uses ONLY the provided context above
                2. Is appropriate for a student at level {student_level} (Bloom's taxonomy)
                3. Includes 1-2 concrete examples if helpful
                4. Avoids complex proofs unless appropriate for the level
                5. Ends with a guiding question or next step
                6. Is concise but comprehensive (max 300 words)

                Focus on understanding over memorization.
            """,

            "TUTOR": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                Act as a Socratic tutor. Guide the student to discover the answer themselves:
                1. Start with a clarifying question
                2. Provide hints progressively, not the full answer
                3. Use the context to guide your questions
                4. Check for understanding between steps
                5. Encourage critical thinking
                6. Adapt to the student's level {student_level}

                Never give the complete answer immediately. Guide, don't tell.
            """,

            "EXAM": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                You are proctoring an exam. Provide minimal assistance:
                1. Answer only if the question is about exam format or instructions
                2. Do not provide content answers or hints
                3. If the question is unclear, ask for clarification
                4. Maintain formal, neutral tone
                5. Do not use the provided context to answer content questions

                If the question asks for content knowledge, respond: "I cannot provide answers to exam questions. Please focus on the question and use your knowledge."
            """,

            "QUIZ": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                Create an interactive quiz experience:
                1. Turn the student's question into a quiz question if appropriate
                2. Or provide a related quiz question based on the context
                3. Include multiple choice options if suitable
                4. Ask the student to attempt an answer
                5. Provide immediate feedback on their response
                6. Use the context to ensure accuracy
                7. Keep it engaging and educational

                Level: {student_level}
            """,

            "EXPLORATION": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                Encourage deeper exploration and curiosity:
                1. Go beyond the basic answer
                2. Connect to related concepts and real-world applications
                3. Ask "what if" and "why" questions
                4. Suggest extensions and further learning
                5. Use the context as a starting point, not a limit
                6. Inspire curiosity about the subject
                7. Adapt to level {student_level} but challenge appropriately

                Be engaging and thought-provoking.
            """,

            "REMEDIAL": """
                {constraints_section}

                CONTEXT:
                {context_section}

                STUDENT QUESTION:
                {query_section}

                INSTRUCTIONS:
                Address identified misconceptions patiently:
                1. Directly acknowledge the confusion
                2. Explain why the misconception is incorrect
                3. Provide the correct mental model
                4. Use simple, clear language
                5. Include concrete examples to illustrate
                6. Build confidence with positive reinforcement
                7. Check for understanding before moving on

                Level: {student_level} - focus on building foundation.
            """,
        }

        # dedent strips the source-code indentation from the template text
        return textwrap.dedent(templates.get(mode, templates["EXPLANATION"]))

    def _build_context_section(self, chunks: List[Chunk]) -> str:
        """Build context section from retrieved chunks"""
        if not chunks:
            return "No relevant context found."

        context_parts = []

        for i, chunk in enumerate(chunks, 1):
            chunk_text = chunk.text.strip()

            # Numbered source header so the answer can refer back to it
            header = f"[Source {i}: {chunk.document_metadata.get('concept', 'Unknown')}]"

            # Add header and chunk content
            context_parts.append(header)
            context_parts.append(chunk_text)

            # Blank line between sources
            context_parts.append("")

        return "\n".join(context_parts)

    def _build_constraints_section(
        self,
        mode: str,
        student_level: int,
        constraints: Optional[Dict]
    ) -> str:
        """Build constraints section"""
        constraint_parts = []

        # Core constraints
        constraint_parts.append("CORE RULES:")
        constraint_parts.append("- Use ONLY the provided context for factual information")
        constraint_parts.append("- Never use your general knowledge for core content")
        constraint_parts.append("- Admit uncertainty if context is insufficient")
        constraint_parts.append("- Maintain educational appropriateness")

        # Mode-specific constraints
        if mode == "EXPLANATION":
            constraint_parts.append("\nEXPLANATION CONSTRAINTS:")
            constraint_parts.append(f"- Target Bloom's level: {student_level}")
            constraint_parts.append("- Include examples when helpful")
            constraint_parts.append("- Avoid unnecessary complexity")

        elif mode == "TUTOR":
            constraint_parts.append("\nTUTOR CONSTRAINTS:")
            constraint_parts.append("- Ask guiding questions first")
            constraint_parts.append("- Reveal information progressively")
            constraint_parts.append("- Check understanding between steps")

        elif mode == "EXAM":
            constraint_parts.append("\nEXAM CONSTRAINTS:")
            constraint_parts.append("- No content assistance")
            constraint_parts.append("- Formal tone only")
            constraint_parts.append("- Focus on exam instructions")

        # Student level constraints (Bloom's levels 1-2, 3-4, 5-6)
        constraint_parts.append(f"\nLEVEL CONSTRAINTS (Level {student_level}):")

        if student_level <= 2:
            constraint_parts.append("- Focus on basic understanding")
            constraint_parts.append("- Use simple language")
            constraint_parts.append("- Avoid abstract concepts")
        elif student_level <= 4:
            constraint_parts.append("- Include application examples")
            constraint_parts.append("- Introduce some analysis")
            constraint_parts.append("- Balance simplicity and depth")
        else:
            constraint_parts.append("- Encourage critical thinking")
            constraint_parts.append("- Include complex applications")
            constraint_parts.append("- Allow for abstract reasoning")

        # Additional constraints
        if constraints:
            constraint_parts.append("\nADDITIONAL CONSTRAINTS:")
            for key, value in constraints.items():
                constraint_parts.append(f"- {key}: {value}")

        return "\n".join(constraint_parts)

    def _build_query_section(self, query: Query) -> str:
        """Build query section"""
        query_parts = []

        # Add original query
        query_parts.append(f"Question: {query.text}")

        # Add context if available
        if query.context:
            query_parts.append(f"Context: {query.context}")

        # Add student info if available
        if query.student_info:
            info_parts = []
            if "grade_level" in query.student_info:
                info_parts.append(f"Grade: {query.student_info['grade_level']}")
            if "subject" in query.student_info:
                info_parts.append(f"Subject: {query.student_info['subject']}")
            if "recent_topics" in query.student_info:
                info_parts.append(f"Recent topics: {', '.join(query.student_info['recent_topics'])}")

            if info_parts:
                query_parts.append(f"Student Info: {', '.join(info_parts)}")

        return "\n".join(query_parts)

    def build_system_prompt(self, mode: str, student_level: int) -> str:
        """Build system prompt for LLM"""
        system_prompts = {
            "EXPLANATION": f"You are an educational AI tutor specializing in clear explanations for students at Bloom's level {student_level}. Your goal is to build understanding through structured, example-rich explanations.",

            "TUTOR": f"You are a Socratic tutor guiding students at level {student_level}. Your role is to ask thoughtful questions that lead students to discover answers themselves, using the provided context as your knowledge base.",

            "EXAM": "You are an exam proctor maintaining academic integrity. Provide only procedural assistance and never help with exam content.",

            "QUIZ": f"You are an interactive quiz creator for students at level {student_level}. Transform questions into engaging learning opportunities with immediate feedback.",

            "EXPLORATION": f"You are an educational guide encouraging deeper exploration for students at level {student_level}. Connect concepts to real-world applications and inspire curiosity.",

            "REMEDIAL": f"You are a patient remedial tutor helping students at level {student_level} overcome misconceptions. Build confidence through clear, step-by-step explanations.",
        }

        return system_prompts.get(mode, system_prompts["EXPLANATION"])

    def validate_prompt(self, prompt: str) -> Dict[str, Any]:
        """Validate prompt quality and safety"""
        validation = {
            "is_valid": True,
            "issues": [],
            "suggestions": [],
            "token_estimate": len(prompt) // 4,
        }

        # Check for prompt injection attempts (substring match; when the full
        # prompt is passed in, retrieved context is scanned as well)
        injection_patterns = [
            "ignore previous instructions",
            "forget everything above",
            "system prompt",
            "you are now",
            "pretend you are",
        ]

        prompt_lower = prompt.lower()
        for pattern in injection_patterns:
            if pattern in prompt_lower:
                validation["is_valid"] = False
                validation["issues"].append(f"Potential injection: {pattern}")

        # Check length
        if validation["token_estimate"] > 8000:
            validation["issues"].append("Prompt very long, may exceed context window")
            validation["suggestions"].append("Consider reducing context or chunking")

        # Check for required sections
        required_sections = ["CONTEXT:", "STUDENT QUESTION:", "INSTRUCTIONS:"]
        for section in required_sections:
            if section not in prompt:
                validation["issues"].append(f"Missing required section: {section}")

        return validation

    def optimize_prompt(self, prompt: str, max_tokens: int = 4000) -> str:
        """Optimize prompt to fit within token limit (~4 characters per token)"""
        current_tokens = len(prompt) // 4

        if current_tokens <= max_tokens:
            return prompt

        # If too long, truncate the context section first
        context_start = prompt.find("CONTEXT:")
        context_end = prompt.find("STUDENT QUESTION:")

        if context_start != -1 and context_end > context_start:
            before_context = prompt[:context_start]
            context_section = prompt[context_start:context_end]
            after_context = prompt[context_end:]

            # Token budget left for the context once the other sections are counted
            other_tokens = (len(before_context) + len(after_context)) // 4
            allowed_context_tokens = max_tokens - other_tokens

            if allowed_context_tokens > 100:
                # Keep the first allowed_context_tokens * 4 characters of context
                truncated_context = context_section[:allowed_context_tokens * 4]

                # Add truncation notice
                truncated_context += "\n[Context truncated due to length limits]\n"

                return before_context + truncated_context + after_context

        # Fallback: hard-truncate the whole prompt (this can cut off the question)
        return prompt[:max_tokens * 4] + "\n[Prompt truncated]"
```
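
`optimize_prompt` slices the context at a character offset, which can cut a source off mid-sentence. An alternative is to drop whole lower-ranked chunks until the estimate fits, using the same ~4 characters per token heuristic. This is a different technique from what the class above does. `fit_chunks_to_budget` is a hypothetical helper, and the sketch assumes chunks arrive ordered best-first.

```python
from typing import List


def estimate_tokens(text: str) -> int:
    # Same rough heuristic as the LLM client: ~4 characters per token
    return max(1, len(text) // 4)


def fit_chunks_to_budget(chunks: List[str], fixed_text: str, max_tokens: int) -> List[str]:
    """Keep the highest-ranked chunks whose combined estimate fits the budget.

    `chunks` must be ordered best-first; `fixed_text` is everything else in
    the prompt (constraints, question, instructions) that cannot be dropped.
    """
    budget = max_tokens - estimate_tokens(fixed_text)
    kept: List[str] = []
    used = 0
    for chunk in chunks:
        cost = estimate_tokens(chunk)
        if used + cost > budget:
            break  # stop at the first chunk that does not fit, preserving rank order
        kept.append(chunk)
        used += cost
    return kept


chunks = ["a" * 400, "b" * 400, "c" * 400]  # ~100 tokens each
fixed = "x" * 200                            # ~50 tokens of instructions
print(len(fit_chunks_to_budget(chunks, fixed, max_tokens=260)))  # 2: a third chunk would exceed the 210-token budget
```

Because each source is kept or dropped whole, the `[Source N: ...]` headers built by `_build_context_section` stay intact. The trade-off is that a single oversized top-ranked chunk can leave the context empty.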

---

## 📊 WEEK 9-10: INTEGRATION & OPTIMIZATION

### Task 5.1: End-to-End Pipeline
**Priority**: High
**Estimated Time**: 16 hours
**Dependencies**: Task 4.2

#### Subtasks:
- [ ] Integrate all components
- [ ] Create pipeline orchestration
- [ ] Add error handling and recovery
- [ ] Implement caching
- [ ] Add performance monitoring
- [ ] Create pipeline testing

#### Implementation:

**src/main.py**
```python
import asyncio
import time
from typing import Dict, List, Optional, Any
from pathlib import Path
import json

from .config.settings import RAGConfig, DEFAULT_CONFIG
from .core.embeddings import EmbeddingService
from .core.vector_store import VectorStore
from .core.indexer import Indexer
from .preprocessing.text_processor import TextProcessor
from .preprocessing.chunker import IntelligentChunker
from .retrieval.hybrid_search import HybridSearcher
from .retrieval.keyword_search import KeywordSearcher
from .retrieval.ranker import ResultRanker
from .llm.llm_client import LLMClient
from .llm.prompt_builder import PromptBuilder
from .models.document import Document
from .models.query import Query
from .utils.logger import get_logger

logger = get_logger(__name__)


class RAGEngine:
    """Main RAG engine orchestrating all components"""

    def __init__(self, config: Optional[RAGConfig] = None):
        self.config = config or DEFAULT_CONFIG
        self.components = {}
        self._initialize_components()

    def _initialize_components(self):
        """Initialize all RAG components"""
        try:
            logger.info("Initializing RAG engine components...")

            # Core components
            self.components["embeddings"] = EmbeddingService(self.config.embeddings)
            self.components["vector_store"] = VectorStore(self.config.vector_store)
            self.components["text_processor"] = TextProcessor(self.config.chunking)
            self.components["chunker"] = IntelligentChunker(self.config.chunking)

            # Retrieval components
            self.components["keyword_searcher"] = KeywordSearcher()
            self.components["ranker"] = ResultRanker()
            self.components["hybrid_searcher"] = HybridSearcher(
                vector_store=self.components["vector_store"],
                embedding_service=self.components["embeddings"],
                keyword_searcher=self.components["keyword_searcher"],
                ranker=self.components["ranker"]
            )

            # LLM components
            self.components["llm_client"] = LLMClient(self.config)
            self.components["prompt_builder"] = PromptBuilder(self.config)

            # Indexer for document processing
            self.components["indexer"] = Indexer(
                embedding_service=self.components["embeddings"],
                vector_store=self.components["vector_store"],
                chunker=self.components["chunker"]
            )

            logger.info("All components initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize components: {e}")
            raise

    async def process_document(
        self,
        text: str,
        document_metadata: Dict,
        chunk_metadata: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Process a document and add it to the index

        Args:
            text: Document text
            document_metadata: Document metadata
            chunk_metadata: Default chunk metadata

        Returns:
            Processing results
        """
        # Start the timer before the try block so the error path can always use it
        start_time = time.time()
        try:
            logger.info(f"Processing document: {document_metadata.get('title', 'Unknown')}")

            # Use indexer to process document
            result = await self.components["indexer"].process_document(
                text, document_metadata, chunk_metadata
            )

            processing_time = time.time() - start_time
            logger.info(f"Document processed in {processing_time:.2f}s: {result['chunk_count']} chunks")

            return {
                "success": True,
                "chunk_count": result["chunk_count"],
                "processing_time": processing_time,
                "chunks": result["chunks"],
                "stats": result["stats"]
            }

        except Exception as e:
            logger.error(f"Document processing failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "processing_time": time.time() - start_time
            }

    async def query(
        self,
        query_text: str,
        mode: str = "EXPLANATION",
        top_k: int = 5,
        filters: Optional[Dict] = None,
        student_info: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Process a student query

        Args:
            query_text: Student's question
            mode: Interaction mode
            top_k: Number of results to retrieve
            filters: Metadata filters
            student_info: Student information

        Returns:
            Query response
        """
        # Start the timer before the try block so the error path can always use it
        start_time = time.time()
        try:
            logger.info(f"Processing query: {query_text[:50]}...")

            # Create query object
            query = Query(
                text=query_text,
                mode=mode,
                filters=filters or {},
                student_info=student_info or {}
            )

            # Placeholder: chunks (chunk_id -> Chunk) are not loaded from storage
            # yet, so this dict stays empty and every query currently takes the
            # "no relevant information" path below
            chunks = {}

            # Perform hybrid search
            search_result = self.components["hybrid_searcher"].search(
                query=query,
                chunks=chunks,
                top_k=top_k,
                alpha=self.config.retrieval.hybrid_alpha,
                rerank=self.config.retrieval.rerank
            )

            # Keep (chunk, score) pairs together so scores stay aligned with
            # their chunks even when some result IDs are missing from `chunks`
            retrieved = [
                (chunks[chunk_id], score)
                for chunk_id, score in search_result.results
                if chunk_id in chunks
            ]
            retrieved_chunks = [chunk for chunk, _ in retrieved]

            if not retrieved_chunks:
                return {
                    "success": True,
                    "response": "I don't have relevant information to answer that question.",
                    "retrieved_chunks": [],
                    "processing_time": time.time() - start_time,
                    "no_context": True
                }

            # Build prompt
            prompt = self.components["prompt_builder"].build_prompt(
                query=query,
                retrieved_chunks=retrieved_chunks,
                mode=mode,
                student_level=student_info.get("level", 2) if student_info else 2
            )

            # Generate response
            llm_response = self.components["llm_client"].generate_response(
                prompt=prompt,
                max_tokens=self.config.max_response_tokens,
                temperature=self.config.temperature
            )

            # Validate response against the retrieved context
            validation = self.components["llm_client"].validate_response(
                llm_response["text"],
                "\n".join(chunk.text for chunk in retrieved_chunks)
            )

            processing_time = time.time() - start_time

            result = {
                "success": True,
                "response": llm_response["text"],
                "retrieved_chunks": [
                    {
                        "id": chunk.id,
                        # Truncate long chunk text for the response payload
                        "text": chunk.text[:200] + ("..." if len(chunk.text) > 200 else ""),
                        "score": score,
                        "metadata": chunk.metadata.__dict__
                    }
                    for chunk, score in retrieved
                ],
                "processing_time": processing_time,
                "search_stats": {
                    "total_found": search_result.total_found,
                    "keyword_results": search_result.keyword_results_count,
                    "vector_results": search_result.vector_results_count,
                },
                "llm_stats": {
                    "provider": llm_response["provider"],
                    "model": llm_response["model"],
                    "tokens": llm_response["total_tokens"],
                    "cost": self.components["llm_client"].estimate_cost(
                        llm_response["prompt_tokens"],
                        llm_response["completion_tokens"],
                        llm_response["provider"]
                    )
                },
                "validation": validation
            }

            logger.info(f"Query processed in {processing_time:.2f}s")
            return result

        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "processing_time": time.time() - start_time
            }

    async def batch_query(
        self,
        queries: List[Dict[str, Any]],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Process multiple queries sequentially"""
        results = []

        for query_data in queries:
            try:
                result = await self.query(
                    query_text=query_data["text"],
                    mode=query_data.get("mode", "EXPLANATION"),
                    top_k=top_k,
                    filters=query_data.get("filters"),
                    student_info=query_data.get("student_info")
                )
                results.append(result)
            except Exception as e:
                # query() already returns error dicts; this catches malformed
                # input such as a missing "text" key
                logger.error(f"Batch query failed: {e}")
                results.append({
                    "success": False,
                    "error": str(e),
                    "query": query_data.get("text")
                })

        return results

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive system statistics"""
        return {
            "config": {
                "vector_store": self.components["vector_store"].get_stats(),
                "embeddings": {
                    "model": self.config.embeddings.model_name,
                    "dimension": self.config.embeddings.dimension,
                }
            },
            "search": self.components["hybrid_searcher"].get_search_stats(),
            "timestamp": time.time()
        }

    def save_index(self, name: str = "default"):
        """Save the current index"""
        self.components["vector_store"].save_index(name)
        logger.info(f"Index saved as {name}")

    def load_index(self, name: str = "default"):
        """Load a saved index"""
        # This would require implementing load functionality
        logger.info(f"Index loading not yet implemented for {name}")

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on all components"""
        health_status = {
            "overall": "healthy",
            "components": {},
            "timestamp": time.time()
        }

        for name, component in self.components.items():
            try:
                # Lightweight check: a component counts as healthy if
                # get_stats() (when present) does not raise
                if hasattr(component, 'get_stats'):
                    stats = component.get_stats()
                    health_status["components"][name] = {
                        "status": "healthy",
                        "stats": stats
                    }
                else:
                    health_status["components"][name] = {
                        "status": "healthy"
                    }
            except Exception as e:
                health_status["components"][name] = {
                    "status": "unhealthy",
                    "error": str(e)
                }
                health_status["overall"] = "degraded"

        return health_status

    async def shutdown(self):
        """Graceful shutdown"""
        logger.info("Shutting down RAG engine...")

        # Save index
        try:
            self.save_index()
        except Exception as e:
            logger.error(f"Failed to save index during shutdown: {e}")

        # Cleanup components
        for name, component in self.components.items():
            try:
                if hasattr(component, 'cleanup'):
                    component.cleanup()
            except Exception as e:
                logger.error(f"Failed to cleanup {name}: {e}")

        logger.info("RAG engine shutdown complete")


# Main entry point
async def main():
    """Main function for testing"""
    engine = RAGEngine()

    try:
        # Health check (default=str keeps non-JSON values from crashing the dump)
        health = await engine.health_check()
        print("Health Status:", json.dumps(health, indent=2, default=str))

        # Example query
        result = await engine.query(
            query_text="What is a derivative?",
            mode="EXPLANATION",
            student_info={"level": 2}
        )

        print("Query Result:", json.dumps(result, indent=2, default=str))

    finally:
        await engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```
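
The Task 5.1 subtasks include caching, which `RAGEngine.query` does not implement yet. One possible shape is a small in-memory TTL cache, assuming that identical (question, mode, level) triples may reuse a response within a short window. `QueryCache` and its key scheme are hypothetical.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple

CacheKey = Tuple[str, str, int]


class QueryCache:
    """In-memory cache of query results that expire after `ttl_seconds`."""

    def __init__(self, ttl_seconds: float = 300.0, clock: Callable[[], float] = time.monotonic):
        self.ttl_seconds = ttl_seconds
        self._clock = clock  # injectable so expiry can be tested without sleeping
        self._entries: Dict[CacheKey, Tuple[float, Dict[str, Any]]] = {}

    @staticmethod
    def make_key(query_text: str, mode: str, level: int) -> CacheKey:
        # Normalize whitespace and case so trivially different phrasings share a key
        return (" ".join(query_text.lower().split()), mode, level)

    def get(self, key: CacheKey) -> Optional[Dict[str, Any]]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self._clock() - stored_at > self.ttl_seconds:
            del self._entries[key]  # expired: evict and report a miss
            return None
        return value

    def put(self, key: CacheKey, value: Dict[str, Any]) -> None:
        self._entries[key] = (self._clock(), value)


now = [0.0]
cache = QueryCache(ttl_seconds=60, clock=lambda: now[0])
key = QueryCache.make_key("What is  a Derivative?", "EXPLANATION", 2)
cache.put(key, {"success": True, "response": "..."})
print(cache.get(key) is not None)  # True: still fresh
now[0] = 61.0
print(cache.get(key))              # None: entry expired after 60 s
```

In `query`, a lookup would sit before retrieval and a `put` after a successful response. Whether to also cache failures or `no_context` answers is a separate design choice.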

---

## 🧪 WEEK 11-12: TESTING & DEPLOYMENT

### Task 6.1: Comprehensive Testing
**Priority**: High
**Estimated Time**: 16 hours
**Dependencies**: Task 5.1

#### Subtasks:
- [ ] Write unit tests for all components
- [ ] Create integration tests
- [ ] Add performance benchmarks
- [ ] Test error scenarios
- [ ] Create load tests
- [ ] Add regression tests

#### Test Suite Structure:
```
tests/
├── unit/
│   ├── test_embeddings.py
│   ├── test_vector_store.py
│   ├── test_chunker.py
│   ├── test_keyword_search.py
│   ├── test_hybrid_search.py
│   ├── test_llm_client.py
│   └── test_prompt_builder.py
├── integration/
│   ├── test_end_to_end.py
│   ├── test_document_processing.py
│   ├── test_query_pipeline.py
│   └── test_error_handling.py
├── performance/
│   ├── benchmark_retrieval.py
│   ├── benchmark_llm_calls.py
│   └── stress_test.py
├── fixtures/
│   ├── sample_documents/
│   ├── test_queries.json
│   └── expected_results/
└── conftest.py
```
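
Unit tests under `tests/unit/` will likely need to run without network access or API keys. A common pattern is to inject a fake client in place of the real provider. The sketch below is self-contained and hypothetical: `FakeLLMClient` and `answer_with_llm` stand in for `LLMClient` and the generation step of `RAGEngine.query`, and they only mimic the response-dict shape used above.

```python
from typing import Any, Dict, List


class FakeLLMClient:
    """Test double that records prompts and returns a canned response dict."""

    def __init__(self, text: str = "A derivative is a rate of change."):
        self.text = text
        self.prompts: List[str] = []

    def generate_response(self, prompt: str, max_tokens: int, temperature: float) -> Dict[str, Any]:
        self.prompts.append(prompt)
        return {
            "text": self.text,
            "provider": "fake",
            "model": "fake-model",
            "prompt_tokens": len(prompt) // 4,
            "completion_tokens": len(self.text) // 4,
            "total_tokens": len(prompt) // 4 + len(self.text) // 4,
        }


def answer_with_llm(client, question: str, context: str) -> str:
    # Minimal stand-in for the prompt-then-generate step of the pipeline
    prompt = f"CONTEXT:\n{context}\n\nSTUDENT QUESTION:\n{question}\n\nINSTRUCTIONS:\nAnswer briefly."
    return client.generate_response(prompt, max_tokens=300, temperature=0.2)["text"]


def test_answer_uses_context_and_client():
    client = FakeLLMClient()
    answer = answer_with_llm(client, "What is a derivative?", "Derivatives measure change.")
    assert answer == "A derivative is a rate of change."
    assert len(client.prompts) == 1
    assert "Derivatives measure change." in client.prompts[0]


if __name__ == "__main__":
    test_answer_uses_context_and_client()  # pytest would discover this test on its own
```

Because the fake records every prompt, the same pattern can check prompt content (for example, that EXAM mode never forwards context) without any provider call.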

---

### Task 6.2: Production Deployment
**Priority**: High
**Estimated Time**: 8 hours
**Dependencies**: Task 6.1

#### Subtasks:
- [ ] Create Docker configuration
- [ ] Set up environment variables
- [ ] Configure monitoring
- [ ] Create deployment scripts
- [ ] Set up logging
- [ ] Create health checks

#### Docker Configuration:

**Dockerfile**
```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY src/ ./src/
COPY tests/ ./tests/

# Create data directories
RUN mkdir -p data/models data/indices data/chunks logs

# Set environment variables
ENV PYTHONPATH=/app
ENV RAG_CONFIG_ENV=production

# Health check: exit non-zero unless health_check() reports "healthy".
# A fresh RAGEngine is built on every check, so model loading may need
# a longer --start-period than 5s.
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD python -c "import asyncio, sys; from src.main import RAGEngine; status = asyncio.run(RAGEngine().health_check()); sys.exit(0 if status['overall'] == 'healthy' else 1)"

# Expose port
# Note: src.main currently runs a one-off smoke test and exits; no server listens on 8000 yet.
EXPOSE 8000

# Default command
CMD ["python", "-m", "src.main"]
```
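
Docker marks a container unhealthy only when the `HEALTHCHECK` command exits with a non-zero status. The dict returned by `RAGEngine.health_check()` therefore needs to be mapped to an exit code. The sketch below does that mapping, assuming the `overall` values used above (`healthy`, `degraded`). It treats `degraded` as unhealthy, which is a design choice. `health_exit_code` is a hypothetical helper that a health-check script could pass to `sys.exit`.

```python
from typing import Any, Dict


def health_exit_code(status: Dict[str, Any]) -> int:
    """Map RAGEngine.health_check() output to a Docker HEALTHCHECK exit code.

    Docker treats 0 as healthy and 1 as unhealthy (2 is reserved), so any
    overall state other than "healthy", including "degraded", returns 1.
    """
    return 0 if status.get("overall") == "healthy" else 1


sample = {"overall": "degraded", "components": {"llm_client": {"status": "unhealthy"}}}
print(health_exit_code(sample))  # 1
```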

---

## 📋 DELIVERABLES

### Week 2 Deliverables
- [ ] Vector database setup with FAISS
- [ ] Embedding service with sentence-transformers
- [ ] Basic text processing pipeline
- [ ] Project structure and configuration

### Week 4 Deliverables
- [ ] Intelligent chunking system
- [ ] Text preprocessing and quality validation
- [ ] Metadata extraction
- [ ] Content quality scoring

### Week 6 Deliverables
- [ ] Vector search implementation
- [ ] Keyword search (BM25)
- [ ] Hybrid retrieval system
- [ ] Result ranking and filtering

### Week 8 Deliverables
- [ ] LLM client integration
- [ ] Prompt engineering system
- [ ] Response generation and validation
- [ ] Multi-provider support

### Week 10 Deliverables
- [ ] End-to-end RAG pipeline
- [ ] Performance optimization
- [ ] Caching and monitoring
- [ ] Error handling and recovery

### Week 12 Deliverables
- [ ] Comprehensive test suite
- [ ] Production deployment
- [ ] Documentation and monitoring
- [ ] Performance benchmarks

---

## 📈 PERFORMANCE TARGETS

### Retrieval Performance
- **Query latency**: < 500 ms for top-10 results
- **Indexing speed**: > 1,000 chunks/second
- **Memory usage**: < 2 GB for 100k chunks
- **Accuracy**: > 80% of top-5 results relevant (precision@5 > 0.8)
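
Latency targets like the query latency bound above are usually checked against a high percentile rather than the mean, because a few slow queries can hide behind a good average. Reading the target as a p95 bound is an assumption. The sketch uses the standard library's `statistics.quantiles`, which needs at least two samples; `latency_report` is a hypothetical helper.

```python
import statistics
from typing import Any, Dict, List


def latency_report(latencies_ms: List[float], target_ms: float = 500.0) -> Dict[str, Any]:
    """Summarize query latencies and compare the 95th percentile to a target."""
    # n=20 yields 19 cut points; the last one is the 95th percentile
    p95 = statistics.quantiles(latencies_ms, n=20)[-1]
    return {
        "mean_ms": statistics.mean(latencies_ms),
        "p95_ms": p95,
        "meets_target": p95 < target_ms,
    }


# 19 fast queries and one slow outlier
samples = [100.0] * 19 + [2000.0]
report = latency_report(samples)
print(f"mean={report['mean_ms']:.0f} ms, p95={report['p95_ms']:.0f} ms, ok={report['meets_target']}")
# mean=195 ms, p95=1905 ms, ok=False
```

The mean alone would pass the 500 ms bound here while one query in twenty takes two seconds.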

### LLM Performance
- **Response generation**: < 3 seconds
- **Token efficiency**: < 1,000 tokens per response
- **Cost optimization**: < $0.01 per query
- **Quality score**: > 0.8 validation score (the `confidence` value from `validate_response`)
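
The per-query cost target can be checked with the same per-1k-token arithmetic as `estimate_cost`. The prices below are copied from the illustrative `get_model_info` table, not current list prices. The token counts (2,000 prompt tokens including retrieved context, 300 completion tokens) are assumptions.

```python
from typing import Dict

# Per-1k-token prices copied from the get_model_info table (illustrative only)
PRICES_PER_1K: Dict[str, Dict[str, float]] = {
    "claude-3-5-sonnet-20241022": {"input": 0.003, "output": 0.015},
    "gpt-4": {"input": 0.03, "output": 0.06},
    "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
}


def query_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Same formula as estimate_cost: tokens / 1000 * price per 1k tokens."""
    price = PRICES_PER_1K[model]
    return (prompt_tokens / 1000) * price["input"] + (completion_tokens / 1000) * price["output"]


for model in PRICES_PER_1K:
    cost = query_cost(model, prompt_tokens=2000, completion_tokens=300)
    print(f"{model}: ${cost:.4f} per query, within target: {cost < 0.01}")
```

Under these assumptions, only the `gpt-3.5-turbo` row stays under one cent. The Claude row lands at about $0.0105, so prompt size matters as much as model choice.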

### System Performance
- **Uptime**: > 99.5%
- **Error rate**: < 1%
- **Throughput**: > 100 queries per second (QPS)
- **Memory efficiency**: < 4 GB total

---

## 🔧 MONITORING & LOGGING

### Key Metrics
- Query response times
- Retrieval accuracy
- LLM token usage and costs
- System resource utilization
- Error rates and types

### Logging Strategy
- Structured JSON logging
- Per-component log levels
- Performance tracing
- Error correlation
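
For structured JSON logging, the standard library can emit one JSON object per line with a custom `logging.Formatter`. The field names and the `extra`-based `component` field are assumptions, not an existing `get_logger` contract.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            # Optional field supplied via logger.info(..., extra={"component": ...})
            "component": getattr(record, "component", None),
        }
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload)


handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger = logging.getLogger("rag_engine")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("Query processed", extra={"component": "hybrid_searcher"})
```

A `component` field on every line also supports the per-component levels and error-correlation items above, since log queries can filter on it.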

### Alerting
- High latency alerts
- Error rate thresholds
- Resource utilization warnings
- Cost limit alerts

---

## 🛡️ SAFETY & RELIABILITY

### Input Validation
- Query length limits
- Content filtering
- Injection detection
- Rate limiting
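
Rate limiting is listed without a mechanism. A token bucket is one common choice, though this plan does not record that decision. Each client gets a burst of up to `capacity` requests, refilled at `rate` tokens per second. `TokenBucket` is hypothetical and uses an injectable clock so it can be tested deterministically.

```python
import time
from typing import Callable


class TokenBucket:
    """Allow bursts of up to `capacity` requests, refilled at `rate` tokens/second."""

    def __init__(self, rate: float, capacity: float, clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self._clock = clock
        self._tokens = capacity  # start full so the first burst is allowed
        self._last = clock()

    def allow(self) -> bool:
        now = self._clock()
        # Refill in proportion to elapsed time, never above capacity
        self._tokens = min(self.capacity, self._tokens + (now - self._last) * self.rate)
        self._last = now
        if self._tokens >= 1:
            self._tokens -= 1
            return True
        return False


now = [0.0]
bucket = TokenBucket(rate=1.0, capacity=3, clock=lambda: now[0])
print([bucket.allow() for _ in range(4)])              # [True, True, True, False]
now[0] = 2.0
print(bucket.allow(), bucket.allow(), bucket.allow())  # True True False
```

In a multi-instance deployment the bucket state would need to live in shared storage rather than process memory; this sketch covers a single process only.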

### Output Validation
- Response quality checks
- Safety content detection
- Hallucination detection
- Context relevance validation
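
The substring check in `validate_response` flags harmless words; for example, `"hate"` is a substring of `"whatever"`. Matching on word boundaries with Python's `re` module avoids that class of false positive, at the cost of missing inflections such as "hates". The term list is the one from `validate_response`. This remains a keyword heuristic, not a trained safety classifier.

```python
import re
from typing import List

UNSAFE_TERMS = ["hate", "violence", "self-harm", "explicit", "illegal"]

# \b anchors each term to word boundaries; re.escape keeps the terms literal
UNSAFE_RE = re.compile(r"\b(" + "|".join(re.escape(t) for t in UNSAFE_TERMS) + r")\b", re.IGNORECASE)


def find_unsafe_terms(text: str) -> List[str]:
    """Return distinct unsafe terms that appear as whole words, in order of first match."""
    seen: List[str] = []
    for match in UNSAFE_RE.finditer(text):
        term = match.group(1).lower()
        if term not in seen:
            seen.append(term)
    return seen


print(find_unsafe_terms("Whatever the function, illegal steps are not allowed."))  # ['illegal']
print("hate" in "whatever")  # True: why plain substring checks misfire
```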

### Fallback Strategies
- Multiple LLM providers
- Graceful degradation
- Cache fallbacks
- Error recovery
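
A fallback across multiple LLM providers could try providers in a fixed priority order and return the first success. The sketch assumes each provider is a callable that takes a prompt and returns text. `generate_with_fallback` and the provider callables are hypothetical stand-ins for the `LLMClient` provider methods.

```python
from typing import Callable, List, Tuple

Provider = Tuple[str, Callable[[str], str]]


def generate_with_fallback(prompt: str, providers: List[Provider]) -> Tuple[str, str]:
    """Try providers in order; return (provider_name, text) from the first that succeeds."""
    errors: List[str] = []
    for name, generate in providers:
        try:
            return name, generate(prompt)
        except Exception as exc:  # any provider failure moves on to the next one
            errors.append(f"{name}: {exc}")
    raise RuntimeError("All providers failed: " + "; ".join(errors))


def flaky_primary(prompt: str) -> str:
    raise TimeoutError("request timed out")


def backup(prompt: str) -> str:
    return f"Answer to: {prompt}"


print(generate_with_fallback("What is a derivative?", [("anthropic", flaky_primary), ("openai", backup)]))
# ('openai', 'Answer to: What is a derivative?')
```

Falling back on every exception also retries requests that would fail everywhere, such as an oversized prompt. Restricting the fallback to transient errors like timeouts and rate limits is a natural refinement.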

---

*Last Updated: 2026-05-06*
*Version: 1.0.0*
*RAG Engine Lead: ML Engineering Team*