# src/scripts/embed_excel_to_qdrant.py
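"""
Embed cleaned wiki texts from an Excel export into Qdrant.

Steps:
  1. Load EXCEL_FILE and keep rows with no error and non-empty clean_text.
  2. Fetch titles, authors, and categories for those rows from PostgreSQL.
  3. Build Agno Documents (narrative text plus a strict metadata payload) and
     upsert them into the active Qdrant collection in batches of BATCH_SIZE.
  4. Append the collection name to wiki_wikicontent.embedded_in to mark the rows as synced.

Assumes the repository root is on the import path (so the `src.*` imports resolve),
e.g. run as: python -m src.scripts.embed_excel_to_qdrant
"""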
import hashlib
import os
import uuid
from urllib.parse import quote_plus

import pandas as pd
from dotenv import load_dotenv
from sqlalchemy import bindparam, create_engine, text

from agno.knowledge.document import Document

# Import your custom Qdrant and Embedding factories
from src.knowledge.embedding_factory import EmbeddingFactory
from src.knowledge.vector_store import get_qdrant_store

load_dotenv()

# --- Database Setup ---
db_user = quote_plus(os.getenv("DB_USER"))
db_pass = quote_plus(os.getenv("DB_PASSWORD"))
db_host = os.getenv("DB_HOST")
db_port = os.getenv("DB_PORT")
db_name = os.getenv("DB_NAME")
db_url = f"postgresql+psycopg://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}"

engine = create_engine(db_url)

# Read the cleaned file that the recovery script produced
EXCEL_FILE = "wiki_lang69_cleaned_final.xlsx"
BATCH_SIZE = 100  # Number of vectors to send to Qdrant at once
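# The Excel file is assumed to contain at least these columns (matching the recovery
# script's output): 'id' (wiki_wikicontent primary key), 'clean_text', and 'error'
# (non-null when recovery failed for that row).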


def get_text_from_json(json_data, target_lang='fa'):
    """Safely extract the text for the requested language (default Persian, 'fa')
    from a JSON language array, falling back to the first available entry."""
    if not json_data or not isinstance(json_data, list):
        return "Unknown"

    # 1. Strictly look for the requested language ('fa')
    for entry in json_data:
        if isinstance(entry, dict) and entry.get('language_code') == target_lang:
            return entry.get('text', 'Unknown')

    # 2. Fallback: If no Persian translation exists, grab the first available language
    if len(json_data) > 0 and isinstance(json_data[0], dict):
        return json_data[0].get('text', 'Unknown')

    return "Unknown"

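# Example with hypothetical data:
#   get_text_from_json([{"language_code": "en", "text": "Hello"},
#                       {"language_code": "fa", "text": "سلام"}])
#   returns "سلام"; if no 'fa' entry exists, the text of the first entry is returned.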


def run_qdrant_ingestion():
    print(f"📂 Loading cleaned text from {EXCEL_FILE}...")
    try:
        df = pd.read_excel(EXCEL_FILE)
    except FileNotFoundError:
        print(f"❌ Could not find {EXCEL_FILE}.")
        return

    # 1. Filter out any rows that still have errors or empty text
    valid_df = df[df['error'].isna() & df['clean_text'].notna() & (df['clean_text'] != '')]
    valid_ids = valid_df['id'].tolist()

    total_docs = len(valid_ids)
    print(f"📊 Found {total_docs} valid texts ready for embedding.")

    if total_docs == 0:
        return

    # 2. Fetch all relational metadata from PostgreSQL
    print("📥 Fetching relational metadata (Titles, Authors, Categories) from Database...")

    # expanding bindparam lets SQLAlchemy render the IN (...) list from a Python sequence
    query = text("""
        SELECT
            wc.id as content_id,
            wc.wiki_id,
            wc.language as lang_code,
            w.title as wiki_titles,
            a.id as author_id,
            a.name as author_names,
            c.id as category_id,
            c.name as cat_names
        FROM wiki_wikicontent wc
        JOIN wiki_wiki w ON wc.wiki_id = w.id
        LEFT JOIN wiki_author a ON w.author_id = a.id
        JOIN wiki_wikicategory c ON w.category_id = c.id
        WHERE wc.id IN :ids
    """).bindparams(bindparam("ids", expanding=True))

    with engine.connect() as conn:
        metadata_rows = conn.execute(query, {"ids": tuple(valid_ids)}).fetchall()

    metadata_lookup = {row.content_id: row for row in metadata_rows}

    # 3. Initialize Qdrant
    print("🤖 Initializing Embedding Model and Qdrant Connection...")
    embed_factory = EmbeddingFactory()
    embedder = embed_factory.get_embedder()
    vector_db = get_qdrant_store(embedder=embedder)
    active_collection = vector_db.collection

    # Track the collection name so we can mark the rows as embedded in the database later
    model_json_str = f'["{active_collection}"]'

    documents_to_upsert = []
    processed_count = 0

    print(f"🚀 Starting batch embedding into collection: {active_collection}")

    # 4. Process valid Excel rows and build Agno Documents
    for _, row in valid_df.iterrows():
        content_id = int(row['id'])  # cast numpy.int64 -> plain int so the payload stays JSON-serializable
        clean_text = row['clean_text']

        db_meta = metadata_lookup.get(content_id)
        if not db_meta:
            continue

        # Extract strictly Persian strings
        wiki_title = get_text_from_json(db_meta.wiki_titles, 'fa')
        author_name = get_text_from_json(db_meta.author_names, 'fa')
        category_name = get_text_from_json(db_meta.cat_names, 'fa')

        # Part A: The Narrative String
        narrative_text = (
            f"CATEGORY: {category_name}\n"
            f"WIKI TITLE: {wiki_title}\n"
            f"AUTHOR: {author_name}\n"
            f"CONTENT:\n{clean_text}"
        )

        # Part B: The Strict Payload
        payload = {
            "source_type": "WIKI",
            "content_id": content_id,
            "wiki_id": db_meta.wiki_id,
            "wiki_title": wiki_title,
            "category_id": db_meta.category_id,
            "category_name": category_name,
            "author_id": db_meta.author_id,  # Will be None if missing, which is fine
            "author_name": author_name,
            "language": db_meta.lang_code
        }

        # Deterministic ID: Qdrant point IDs must be unsigned integers or UUIDs,
        # so derive a UUID from the MD5 of a stable key. Re-running the script
        # then overwrites existing points instead of duplicating them.
        hash_id = hashlib.md5(f"WIKI_{content_id}_{active_collection}".encode()).hexdigest()
        qdrant_id = str(uuid.UUID(hash_id))

        doc = Document(
            id=qdrant_id,
            content=narrative_text,
            meta_data=payload
        )
        documents_to_upsert.append(doc)

        # 5. Batch Upsert to Qdrant
        if len(documents_to_upsert) >= BATCH_SIZE:
            vector_db.upsert(documents=documents_to_upsert)
            processed_count += len(documents_to_upsert)
            print(f"✅ Embedded {processed_count}/{total_docs} vectors...")
            documents_to_upsert = []

    # Flush final partial batch
    if documents_to_upsert:
        vector_db.upsert(documents=documents_to_upsert)
        processed_count += len(documents_to_upsert)
        print(f"✅ Embedded {processed_count}/{total_docs} vectors...")

    # 6. Mark as synced in PostgreSQL
    print("🔄 Updating PostgreSQL to mark items as embedded...")
    with engine.begin() as conn:  # Use .begin() for automatic transaction commit
        update_sql = text("""
            UPDATE wiki_wikicontent
            SET embedded_in = CAST(COALESCE(embedded_in, '[]') AS jsonb) || CAST(:model_json AS jsonb)
            WHERE id IN :ids AND NOT (COALESCE(embedded_in, '[]')::jsonb @> :model_json::jsonb)
        """).bindparams(bindparam("ids", expanding=True))
        conn.execute(update_sql, {"model_json": model_json_str, "ids": tuple(valid_ids)})

    print("🎉 All documents successfully embedded and database updated!")

if __name__ == "__main__":
    run_qdrant_ingestion()