Spaces:
Paused
Paused
Upload folder using huggingface_hub
Browse files- dataset_config.json +4 -4
- run_transformers_training.py +116 -8
dataset_config.json
CHANGED
|
@@ -3,8 +3,7 @@
|
|
| 3 |
"name": "George-API/cognitive-data",
|
| 4 |
"split": "train",
|
| 5 |
"column_mapping": {
|
| 6 |
-
"
|
| 7 |
-
"id": "id"
|
| 8 |
},
|
| 9 |
"processing": {
|
| 10 |
"sort_by_id": true,
|
|
@@ -17,7 +16,8 @@
|
|
| 17 |
"roles": {
|
| 18 |
"system": "System: {content}\n\n",
|
| 19 |
"human": "Human: {content}\n\n",
|
| 20 |
-
"assistant": "Assistant: {content}\n\n"
|
|
|
|
| 21 |
},
|
| 22 |
"metadata_handling": {
|
| 23 |
"include_paper_id": true,
|
|
@@ -29,7 +29,7 @@
|
|
| 29 |
"batch_size": 24,
|
| 30 |
"shuffle": false,
|
| 31 |
"drop_last": false,
|
| 32 |
-
"num_workers":
|
| 33 |
"pin_memory": true,
|
| 34 |
"prefetch_factor": 4
|
| 35 |
},
|
|
|
|
| 3 |
"name": "George-API/cognitive-data",
|
| 4 |
"split": "train",
|
| 5 |
"column_mapping": {
|
| 6 |
+
"conversations": "text"
|
|
|
|
| 7 |
},
|
| 8 |
"processing": {
|
| 9 |
"sort_by_id": true,
|
|
|
|
| 16 |
"roles": {
|
| 17 |
"system": "System: {content}\n\n",
|
| 18 |
"human": "Human: {content}\n\n",
|
| 19 |
+
"assistant": "Assistant: {content}\n\n",
|
| 20 |
+
"user": "Human: {content}\n\n"
|
| 21 |
},
|
| 22 |
"metadata_handling": {
|
| 23 |
"include_paper_id": true,
|
|
|
|
| 29 |
"batch_size": 24,
|
| 30 |
"shuffle": false,
|
| 31 |
"drop_last": false,
|
| 32 |
+
"num_workers": 4,
|
| 33 |
"pin_memory": true,
|
| 34 |
"prefetch_factor": 4
|
| 35 |
},
|
run_transformers_training.py
CHANGED
|
@@ -208,15 +208,51 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 208 |
logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
|
| 209 |
dataset = load_dataset(dataset_name, split=dataset_split)
|
| 210 |
|
| 211 |
-
# Map columns if specified
|
| 212 |
column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
|
| 213 |
if column_mapping:
|
| 214 |
-
logger.info(f"
|
| 215 |
|
| 216 |
-
#
|
|
|
|
| 217 |
for target, source in column_mapping.items():
|
| 218 |
if source in dataset.column_names:
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
# Sort dataset if required
|
| 222 |
sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
|
|
@@ -227,8 +263,14 @@ def load_dataset_with_mapping(dataset_config):
|
|
| 227 |
# Log the first few IDs to verify sorting
|
| 228 |
sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
|
| 229 |
logger.info(f"First few IDs after sorting: {sample_ids}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
|
|
|
| 232 |
return dataset
|
| 233 |
|
| 234 |
except Exception as e:
|
|
@@ -243,11 +285,13 @@ def format_phi_chat(messages, dataset_config):
|
|
| 243 |
roles = dataset_config.get("data_formatting", {}).get("roles", {
|
| 244 |
"system": "System: {content}\n\n",
|
| 245 |
"human": "Human: {content}\n\n",
|
|
|
|
| 246 |
"assistant": "Assistant: {content}\n\n"
|
| 247 |
})
|
| 248 |
|
| 249 |
# Handle research introduction metadata first
|
| 250 |
-
metadata = next((msg for msg in messages if
|
|
|
|
| 251 |
if metadata:
|
| 252 |
system_template = roles.get("system", "System: {content}\n\n")
|
| 253 |
formatted_chat = system_template.format(content=metadata['content'])
|
|
@@ -255,20 +299,29 @@ def format_phi_chat(messages, dataset_config):
|
|
| 255 |
|
| 256 |
# Process remaining messages
|
| 257 |
for message in messages:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
role = message.get("role", "").lower()
|
| 259 |
content = message.get("content", "")
|
| 260 |
|
| 261 |
# Format based on role
|
| 262 |
if role == "human" or role == "user":
|
| 263 |
-
template = roles.get("human", "Human: {content}\n\n")
|
| 264 |
formatted_chat += template.format(content=content)
|
| 265 |
-
elif role == "assistant":
|
| 266 |
template = roles.get("assistant", "Assistant: {content}\n\n")
|
| 267 |
formatted_chat += template.format(content=content)
|
| 268 |
elif role == "system":
|
| 269 |
# For system messages, prepend them
|
| 270 |
template = roles.get("system", "System: {content}\n\n")
|
| 271 |
formatted_chat = template.format(content=content) + formatted_chat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
return formatted_chat.strip()
|
| 274 |
|
|
@@ -284,8 +337,56 @@ class SimpleDataCollator:
|
|
| 284 |
self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
|
| 285 |
self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
|
| 286 |
self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
|
|
|
|
| 287 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
def __call__(self, features):
|
| 290 |
batch = {"input_ids": [], "attention_mask": [], "labels": []}
|
| 291 |
|
|
@@ -293,7 +394,12 @@ class SimpleDataCollator:
|
|
| 293 |
try:
|
| 294 |
# Get ID and conversation fields
|
| 295 |
paper_id = example.get("id", "")
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
if not conversation:
|
| 299 |
self.stats["skipped"] += 1
|
|
@@ -346,10 +452,12 @@ class SimpleDataCollator:
|
|
| 346 |
logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
|
| 347 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
| 348 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
|
|
|
| 349 |
else:
|
| 350 |
self.stats["skipped"] += 1
|
| 351 |
except Exception as e:
|
| 352 |
logger.warning(f"Error processing example: {str(e)[:100]}...")
|
|
|
|
| 353 |
self.stats["skipped"] += 1
|
| 354 |
continue
|
| 355 |
|
|
|
|
| 208 |
logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
|
| 209 |
dataset = load_dataset(dataset_name, split=dataset_split)
|
| 210 |
|
| 211 |
+
# Map columns if specified - with checks to avoid conflicts
|
| 212 |
column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
|
| 213 |
if column_mapping:
|
| 214 |
+
logger.info(f"Checking column mapping: {column_mapping}")
|
| 215 |
|
| 216 |
+
# Only apply mappings for columns that need renaming and don't already exist
|
| 217 |
+
safe_mappings = {}
|
| 218 |
for target, source in column_mapping.items():
|
| 219 |
if source in dataset.column_names:
|
| 220 |
+
# Skip if target already exists and is not the same as source
|
| 221 |
+
if target in dataset.column_names and target != source:
|
| 222 |
+
logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
|
| 223 |
+
else:
|
| 224 |
+
safe_mappings[source] = target
|
| 225 |
+
|
| 226 |
+
# Apply safe renames
|
| 227 |
+
if safe_mappings:
|
| 228 |
+
logger.info(f"Applying safe column mapping: {safe_mappings}")
|
| 229 |
+
for source, target in safe_mappings.items():
|
| 230 |
+
if source != target: # Only rename if names are different
|
| 231 |
+
dataset = dataset.rename_column(source, target)
|
| 232 |
+
|
| 233 |
+
# Verify expected columns exist
|
| 234 |
+
expected_columns = {"id", "conversations"}
|
| 235 |
+
for col in expected_columns:
|
| 236 |
+
if col not in dataset.column_names:
|
| 237 |
+
# If "conversations" is missing but "text" exists, it might need conversion
|
| 238 |
+
if col == "conversations" and "text" in dataset.column_names:
|
| 239 |
+
logger.info("Converting 'text' field to 'conversations' format")
|
| 240 |
+
|
| 241 |
+
def convert_text_to_conversations(example):
|
| 242 |
+
# Check if text is already a list of conversation turns
|
| 243 |
+
if isinstance(example.get("text"), list):
|
| 244 |
+
return {"conversations": example["text"]}
|
| 245 |
+
# Otherwise, create a simple conversation with the text as user message
|
| 246 |
+
else:
|
| 247 |
+
return {
|
| 248 |
+
"conversations": [
|
| 249 |
+
{"role": "user", "content": str(example.get("text", ""))}
|
| 250 |
+
]
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
dataset = dataset.map(convert_text_to_conversations)
|
| 254 |
+
else:
|
| 255 |
+
logger.warning(f"Expected column '{col}' not found in dataset")
|
| 256 |
|
| 257 |
# Sort dataset if required
|
| 258 |
sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
|
|
|
|
| 263 |
# Log the first few IDs to verify sorting
|
| 264 |
sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
|
| 265 |
logger.info(f"First few IDs after sorting: {sample_ids}")
|
| 266 |
+
|
| 267 |
+
# Log example of conversations structure to verify format
|
| 268 |
+
if "conversations" in dataset.column_names:
|
| 269 |
+
sample_conv = dataset["conversations"][0] if len(dataset) > 0 else []
|
| 270 |
+
logger.info(f"Example conversation structure: {sample_conv}")
|
| 271 |
|
| 272 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
| 273 |
+
logger.info(f"Dataset columns: {dataset.column_names}")
|
| 274 |
return dataset
|
| 275 |
|
| 276 |
except Exception as e:
|
|
|
|
| 285 |
roles = dataset_config.get("data_formatting", {}).get("roles", {
|
| 286 |
"system": "System: {content}\n\n",
|
| 287 |
"human": "Human: {content}\n\n",
|
| 288 |
+
"user": "Human: {content}\n\n",
|
| 289 |
"assistant": "Assistant: {content}\n\n"
|
| 290 |
})
|
| 291 |
|
| 292 |
# Handle research introduction metadata first
|
| 293 |
+
metadata = next((msg for msg in messages if isinstance(msg, dict) and
|
| 294 |
+
"[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
|
| 295 |
if metadata:
|
| 296 |
system_template = roles.get("system", "System: {content}\n\n")
|
| 297 |
formatted_chat = system_template.format(content=metadata['content'])
|
|
|
|
| 299 |
|
| 300 |
# Process remaining messages
|
| 301 |
for message in messages:
|
| 302 |
+
if not isinstance(message, dict) or "content" not in message:
|
| 303 |
+
logger.warning(f"Skipping invalid message format: {message}")
|
| 304 |
+
continue
|
| 305 |
+
|
| 306 |
role = message.get("role", "").lower()
|
| 307 |
content = message.get("content", "")
|
| 308 |
|
| 309 |
# Format based on role
|
| 310 |
if role == "human" or role == "user":
|
| 311 |
+
template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
|
| 312 |
formatted_chat += template.format(content=content)
|
| 313 |
+
elif role == "assistant" or role == "bot":
|
| 314 |
template = roles.get("assistant", "Assistant: {content}\n\n")
|
| 315 |
formatted_chat += template.format(content=content)
|
| 316 |
elif role == "system":
|
| 317 |
# For system messages, prepend them
|
| 318 |
template = roles.get("system", "System: {content}\n\n")
|
| 319 |
formatted_chat = template.format(content=content) + formatted_chat
|
| 320 |
+
else:
|
| 321 |
+
# Default to system for unknown roles
|
| 322 |
+
logger.warning(f"Unknown role '{role}' - treating as system message")
|
| 323 |
+
template = roles.get("system", "System: {content}\n\n")
|
| 324 |
+
formatted_chat += template.format(content=content)
|
| 325 |
|
| 326 |
return formatted_chat.strip()
|
| 327 |
|
|
|
|
| 337 |
self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
|
| 338 |
self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
|
| 339 |
self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
|
| 340 |
+
self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
|
| 341 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
| 342 |
|
| 343 |
+
def normalize_conversation(self, conversation):
|
| 344 |
+
"""Normalize conversation format to ensure consistent structure."""
|
| 345 |
+
normalized = []
|
| 346 |
+
|
| 347 |
+
# Handle non-list or empty inputs
|
| 348 |
+
if not isinstance(conversation, list):
|
| 349 |
+
logger.warning(f"Conversation is not a list: {type(conversation)}")
|
| 350 |
+
if hasattr(conversation, 'items'): # It's a dict-like object
|
| 351 |
+
conversation = [conversation]
|
| 352 |
+
else:
|
| 353 |
+
return []
|
| 354 |
+
|
| 355 |
+
for turn in conversation:
|
| 356 |
+
# Skip empty or None entries
|
| 357 |
+
if not turn:
|
| 358 |
+
continue
|
| 359 |
+
|
| 360 |
+
# Handle string entries (convert to user message)
|
| 361 |
+
if isinstance(turn, str):
|
| 362 |
+
normalized.append({"role": "user", "content": turn})
|
| 363 |
+
continue
|
| 364 |
+
|
| 365 |
+
# Handle dict-like entries
|
| 366 |
+
if not isinstance(turn, dict) and hasattr(turn, 'get'):
|
| 367 |
+
# Convert to dict
|
| 368 |
+
turn = {k: turn.get(k) for k in ['role', 'content'] if hasattr(turn, 'get') and turn.get(k) is not None}
|
| 369 |
+
|
| 370 |
+
# Ensure both role and content exist
|
| 371 |
+
if not isinstance(turn, dict) or 'role' not in turn or 'content' not in turn:
|
| 372 |
+
logger.warning(f"Skipping malformatted conversation turn: {turn}")
|
| 373 |
+
continue
|
| 374 |
+
|
| 375 |
+
# Normalize role field
|
| 376 |
+
role = turn.get('role', '').lower()
|
| 377 |
+
if role == 'user' or role == 'human':
|
| 378 |
+
role = 'user'
|
| 379 |
+
elif role == 'assistant' or role == 'bot':
|
| 380 |
+
role = 'assistant'
|
| 381 |
+
|
| 382 |
+
# Add normalized turn
|
| 383 |
+
normalized.append({
|
| 384 |
+
"role": role,
|
| 385 |
+
"content": str(turn.get('content', ''))
|
| 386 |
+
})
|
| 387 |
+
|
| 388 |
+
return normalized
|
| 389 |
+
|
| 390 |
def __call__(self, features):
|
| 391 |
batch = {"input_ids": [], "attention_mask": [], "labels": []}
|
| 392 |
|
|
|
|
| 394 |
try:
|
| 395 |
# Get ID and conversation fields
|
| 396 |
paper_id = example.get("id", "")
|
| 397 |
+
|
| 398 |
+
# Handle conversation field - could be under 'conversations' or 'text'
|
| 399 |
+
conversation = example.get("conversations", example.get("text", []))
|
| 400 |
+
|
| 401 |
+
# Normalize conversation format
|
| 402 |
+
conversation = self.normalize_conversation(conversation)
|
| 403 |
|
| 404 |
if not conversation:
|
| 405 |
self.stats["skipped"] += 1
|
|
|
|
| 452 |
logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
|
| 453 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
| 454 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
| 455 |
+
logger.info(f"Conversation structure: {conversation[:2]}...")
|
| 456 |
else:
|
| 457 |
self.stats["skipped"] += 1
|
| 458 |
except Exception as e:
|
| 459 |
logger.warning(f"Error processing example: {str(e)[:100]}...")
|
| 460 |
+
logger.warning(f"Problematic example: {str(example)[:200]}...")
|
| 461 |
self.stats["skipped"] += 1
|
| 462 |
continue
|
| 463 |
|