dungeon29 committed
Commit 8a1f01d · verified · 1 Parent(s): fffe967

Upload rag_engine.py

Files changed (1)
  1. rag_engine.py +231 -0
rag_engine.py ADDED
@@ -0,0 +1,231 @@
+import os
+from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
+from langchain_qdrant import Qdrant
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from qdrant_client import QdrantClient, models
+
+class RAGEngine:
+    def __init__(self, knowledge_base_dir="./knowledge_base"):
+        self.knowledge_base_dir = knowledge_base_dir
+
+        # Initialize embeddings (all-MiniLM-L6-v2 produces 384-dimensional
+        # vectors, which must match the collection's vector size below)
+        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        # Qdrant Cloud configuration: prioritize env vars, fall back to
+        # hardcoded values (user provided)
+        env_qdrant_url = os.environ.get("QDRANT_URL")
+        print(f"DEBUG: QDRANT_URL from env: '{env_qdrant_url}'")
+
+        self.qdrant_url = env_qdrant_url or "https://abd29675-7fb9-4d95-8941-e6130b09bf7f.us-east4-0.gcp.cloud.qdrant.io"
+        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY")  # don't default the key when using a local URL
+
+        # WARNING: hardcoded fallback credentials; prefer setting QDRANT_API_KEY in the environment
+        if not self.qdrant_api_key and "qdrant.io" in self.qdrant_url:
+            self.qdrant_api_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.L0aAAAbxRypLfBeGCtFr2xX06iveGb76NrA3BPJQiNM"
+
+        self.collection_name = "phishing_knowledge"
+
+        if not self.qdrant_url:
+            print("⚠️ QDRANT_URL not set. RAG will not function correctly.")
+            self.vector_store = None
+            return
+
+        print(f"☁️ Connecting to Qdrant: {self.qdrant_url}...")
+
+        # Initialize Qdrant client
+        self.client = QdrantClient(
+            url=self.qdrant_url,
+            api_key=self.qdrant_api_key
+        )
+
+        # Initialize the LangChain vector store wrapper
+        self.vector_store = Qdrant(
+            client=self.client,
+            collection_name=self.collection_name,
+            embeddings=self.embedding_fn
+        )
+
+        # Check whether the collection exists and is populated; build it if needed
+        try:
+            if not self.client.collection_exists(self.collection_name):
+                print(f"⚠️ Collection '{self.collection_name}' not found. Creating...")
+                self.client.create_collection(
+                    collection_name=self.collection_name,
+                    # size=384 matches the all-MiniLM-L6-v2 embedding dimension
+                    vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE)
+                )
+                print(f"✅ Collection '{self.collection_name}' created!")
+                self._build_index()
+            else:
+                # Check whether the HF dataset is already indexed
+                dataset_filter = models.Filter(
+                    must=[
+                        models.FieldCondition(
+                            key="metadata.source",
+                            match=models.MatchValue(value="hf_dataset")
+                        )
+                    ]
+                )
+                dataset_count = self.client.count(
+                    collection_name=self.collection_name,
+                    count_filter=dataset_filter
+                ).count
+
+                print(f"✅ Qdrant collection '{self.collection_name}' ready with {dataset_count} dataset vectors.")
+
+                if dataset_count == 0:
+                    print("⚠️ Phishing dataset not found. Please run 'index_dataset_colab.ipynb' to populate it.")
+                    # self.load_from_huggingface()  # disabled to prevent startup timeouts
+
+        except Exception as e:
+            print(f"⚠️ Collection check/creation failed: {e}")
+            # Try to build anyway; the wrapper may handle collection creation itself
+            self._build_index()
+
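+    # For reference (assumed from the LangChain Qdrant wrapper's defaults, not
+    # verified against this deployment): points written via add_documents carry
+    # payloads shaped like
+    #   {"page_content": "<chunk text>", "metadata": {"source": "hf_dataset", "label": 1}}
+    # which is why the startup check above filters on the "metadata.source" path.
+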
+    def _build_index(self):
+        """Load local documents, split them, and index them in Qdrant."""
+        print("🔄 Building knowledge base index on Qdrant Cloud...")
+
+        documents = self._load_documents()
+        if not documents:
+            print("⚠️ No documents found to index.")
+            return
+
+        # Split documents into overlapping chunks
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=500,
+            chunk_overlap=50,
+            separators=["\n\n", "\n", " ", ""]
+        )
+        chunks = text_splitter.split_documents(documents)
+
+        if chunks:
+            # Add to the vector store (Qdrant Cloud handles persistence)
+            try:
+                self.vector_store.add_documents(chunks)
+                print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
+            except Exception as e:
+                print(f"❌ Error indexing to Qdrant: {e}")
+        else:
+            print("⚠️ No chunks created.")
+
+    def _load_documents(self):
+        """Load documents from the knowledge base directory or a fallback file."""
+        documents = []
+
+        # Check for the directory, falling back to a single file in the repo root
+        target_path = self.knowledge_base_dir
+        if not os.path.exists(target_path):
+            if os.path.exists("knowledge_base.txt"):
+                target_path = "knowledge_base.txt"
+                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
+            else:
+                print(f"❌ Knowledge base not found at {target_path}")
+                return []
+
+        try:
+            if os.path.isfile(target_path):
+                # Load a single file
+                if target_path.endswith(".pdf"):
+                    loader = PyPDFLoader(target_path)
+                else:
+                    loader = TextLoader(target_path, encoding="utf-8")
+                documents.extend(loader.load())
+            else:
+                # Load a directory: one loader per supported file type
+                loaders = [
+                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
+                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
+                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
+                ]
+
+                for loader in loaders:
+                    try:
+                        docs = loader.load()
+                        documents.extend(docs)
+                    except Exception as e:
+                        print(f"⚠️ Error loading with {loader}: {e}")
+
+        except Exception as e:
+            print(f"❌ Error loading documents: {e}")
+
+        return documents
+
+    def load_from_huggingface(self):
+        """Download the phishing dataset JSON from Hugging Face and index it."""
+        dataset_url = "https://huggingface.co/datasets/ealvaradob/phishing-dataset/resolve/main/combined_reduced.json"
+        print(f"📥 Downloading dataset from {dataset_url}...")
+
+        try:
+            import requests
+
+            response = requests.get(dataset_url, timeout=60)  # timeout so a stalled download can't hang indexing
+            if response.status_code != 200:
+                print(f"❌ Failed to download dataset: {response.status_code}")
+                return
+
+            data = response.json()
+            print(f"✅ Dataset downloaded. Processing {len(data)} rows...")
+
+            documents = []
+            for row in data:
+                # Row structure: {"text": ..., "label": ...}
+                content = row.get('text', '')
+                label = row.get('label', -1)
+
+                if content:
+                    doc = Document(
+                        page_content=content,
+                        metadata={"source": "hf_dataset", "label": label}
+                    )
+                    documents.append(doc)
+
+            if documents:
+                print(f"🔄 Indexing {len(documents)} documents to Qdrant...")
+
+                # Use a larger chunk size for efficiency, since these are likely short texts
+                text_splitter = RecursiveCharacterTextSplitter(
+                    chunk_size=1000,
+                    chunk_overlap=100
+                )
+                chunks = text_splitter.split_documents(documents)
+
+                # Add in batches to avoid hitting API limits or timeouts
+                batch_size = 100
+                total_chunks = len(chunks)
+
+                for i in range(0, total_chunks, batch_size):
+                    batch = chunks[i:i + batch_size]
+                    try:
+                        self.vector_store.add_documents(batch)
+                        print(f"  - Indexed batch {i // batch_size + 1}/{(total_chunks + batch_size - 1) // batch_size}")
+                    except Exception as e:
+                        print(f"  ⚠️ Error indexing batch {i}: {e}")
+
+                print(f"✅ Successfully indexed {total_chunks} chunks from the dataset!")
+            else:
+                print("⚠️ No valid documents found in dataset.")
+
+        except Exception as e:
+            print(f"❌ Error loading HF dataset: {e}")
+
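+    # Batching sketch (illustrative numbers): with batch_size=100 and, say, 2,350
+    # chunks, the loop above makes ceil(2350 / 100) = 24 add_documents calls, so a
+    # failed batch loses at most 100 chunks instead of aborting the whole upload.
+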
+    def retrieve(self, query, n_results=3):
+        """Return up to n_results relevant chunk texts for a query."""
+        if not self.vector_store:
+            return []
+
+        # Similarity search against the Qdrant collection
+        try:
+            results = self.vector_store.similarity_search(query, k=n_results)
+            if results:
+                return [doc.page_content for doc in results]
+        except Exception as e:
+            print(f"⚠️ Retrieval Error: {e}")
+
+        return []
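
For reference, a minimal usage sketch of this engine (assumes QDRANT_URL and QDRANT_API_KEY are exported in the environment and the collection is populated; the query string is illustrative):

    from rag_engine import RAGEngine

    # Connects to Qdrant Cloud and ensures the "phishing_knowledge" collection exists
    engine = RAGEngine(knowledge_base_dir="./knowledge_base")

    # Returns up to n_results chunk strings, or [] if the store is unavailable
    context = engine.retrieve("email asking me to verify my account password", n_results=3)
    for chunk in context:
        print(chunk)

Each returned item is a chunk's raw page_content string; the caller is responsible for assembling these into the LLM prompt.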