Fixed GIL issue
Fixes a race condition between CoreML execution and the causal_mask update.
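The core of the change is that the fp16 causal mask is now built once (in the new `initialize_causal_mask` helper) and passed into `chat_loop`, `run_prefill`, and `generate_next_token`, which only take slices of it, instead of each call rebuilding the mask while CoreML predictions may still be in flight. Below is a minimal sketch of that pattern. It assumes a plausible body for `make_causal_mask` (only its last two lines appear in the diff), uses the script's documented defaults (context length 512, batch size 64) as illustrative values, and omits the CoreML calls entirely.

```python
import numpy as np
import torch

def make_causal_mask(length, start):
    # Additive attention mask: 0 where attention is allowed, -inf elsewhere.
    # (Reconstruction: only the last two lines of this function are visible in the diff.)
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Build the mask once up front, as the new helper in the diff does.
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

context_length, batch_size = 512, 64   # illustrative values (the script's defaults)
causal_mask = initialize_causal_mask(context_length)

# run_prefill now just takes a view over the rows of the current batch window...
batch_pos = 0
batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]

# ...and generate_next_token takes the single row for the current position.
pos = 10
single_causal_mask = causal_mask[:, :, pos - 1:pos, :]

print(batch_causal_mask.shape, single_causal_mask.shape)
```

Because the shared mask is only read after initialization, no new mask tensors are created per token, which is what removes the contention described in the commit message.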
- chat_full.py +219 -115
chat_full.py
CHANGED
@@ -28,6 +28,8 @@ RESET_COLOR = "\033[0m"
 
 # Add at the top with other constants
 WARMUP_TOKEN_LIMIT = 10  # Maximum tokens to generate during warmup
+THINKING_MODE = False
+THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
 
 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -191,6 +193,89 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise
 
+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f"  Context Length: {args.context_length}")
+            print(f"  Batch Size: {args.batch_size}")
+            print(f"  Num Chunks: {args.num_chunks}")
+            print(f"  Models Directory: {args.d}")
+            print(f"  Embeddings: {args.embed}")
+            print(f"  LM Head: {args.lmhead}")
+            print(f"  FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -246,18 +331,28 @@ def load_metadata(model,args):
         else:
             ctx_len = args.context_length
 
-        # Use defaults
+        # Use defaults or values from args
         metadata['context_length'] = ctx_len
         metadata['state_length'] = ctx_len
-        metadata['batch_size'] = 64
+        # Get batch size from args or use default
+        metadata['batch_size'] = getattr(args, 'batch_size', 64)
        metadata['lut_bits'] = 4
-        metadata['num_chunks'] = 4
-        print("\nUsing
+        metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+        print("\nUsing parameters:")
         print(f"  Context Length: {metadata['context_length']}")
         print(f"  State Length: {metadata['state_length']}")
         print(f"  Prefill Batch Size: {metadata['batch_size']}")
         print(f"  LUT Bits: {metadata['lut_bits']}")
         print(f"  Number of Chunks: {metadata['num_chunks']}")
+
+        # Override with values from args if they exist
+        if hasattr(args, 'batch_size') and args.batch_size is not None:
+            metadata['batch_size'] = args.batch_size
+            print(f"\nOverriding batch size from args: {args.batch_size}")
+        if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+            metadata['num_chunks'] = args.num_chunks
+            print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -379,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -404,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        #
-        causal_mask = make_causal_mask(context_length, 0)  # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
        batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -430,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -445,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    #
-
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -496,23 +588,84 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
-
-
-
-
-
-
-
-def
+    """Get input from user, handling special key combinations."""
+    global THINKING_MODE
+    try:
+        import termios
+        import tty
+        import sys
+
+        def _getch():
+            fd = sys.stdin.fileno()
+            old_settings = termios.tcgetattr(fd)
+            try:
+                tty.setraw(sys.stdin.fileno())
+                ch = sys.stdin.read(1)
+            finally:
+                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+            return ch
+
+        buffer = []
+        while True:
+            char = _getch()
+
+            # Debug: print the character code
+            print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
+
+            # Check for Enter key
+            if char == '\r' or char == '\n':
+                print()  # Move to next line
+                input_text = ''.join(buffer)
+                # Check if the command is /t
+                if input_text == '/t':
+                    THINKING_MODE = not THINKING_MODE
+                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                    buffer = []  # Clear buffer
+                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
+                    continue
+                return input_text
+
+            # Handle backspace
+            if char == '\x7f':  # backspace
+                if buffer:
+                    buffer.pop()
+                    sys.stdout.write('\b \b')  # Erase character
+                    sys.stdout.flush()
+                continue
+
+            # Handle Ctrl-C
+            if char == '\x03':  # Ctrl-C
+                print("^C")
+                raise KeyboardInterrupt
+
+            # Print character and add to buffer
+            sys.stdout.write(char)
+            sys.stdout.flush()
+            buffer.append(char)
+
+    except ImportError:
+        # Fallback for systems without termios
+        return input("> ")
+
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
+    global THINKING_MODE
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
 
     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
-        print("Type your message and press Enter to chat.")
+        print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
+        print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
 
     # Keep track of conversation history
     conversation = []
@@ -521,7 +674,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
     while True:
         try:
             if not warmup:
-                print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
+                print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
             if auto_prompt is not None:
                 user_input = auto_prompt
                 if not warmup:
@@ -535,16 +688,31 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
 
             if not user_input:
                 continue
+
+            # Handle /t command
+            if user_input == "/t":
+                THINKING_MODE = not THINKING_MODE
+                print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                continue
 
             # Add user message to conversation
             conversation.append({"role": "user", "content": user_input})
 
             # Format using chat template with full history
-            base_input_ids = tokenizer.apply_chat_template(
-                conversation,
-                return_tensors="pt",
-                add_generation_prompt=True
-            ).to(torch.int32)
+            if THINKING_MODE:
+                # Add thinking prompt to system message
+                conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation_with_thinking,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
+            else:
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
 
             # Check if we need to trim history
             while base_input_ids.size(1) > context_length - 100:  # Leave room for response
@@ -579,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             generation_start_time = time.time()
 
             try:
-                # Create initial causal mask
-                causal_mask = make_causal_mask(context_length, 0)
-                causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
                 # Run prefill on entire context
                 current_pos = run_prefill(
                     embed_model,
@@ -591,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     context_pos,
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
                 #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -625,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         new_size,  # Prefill the entire shifted content
                         context_length,
                         batch_size,
-                        state
+                        state,
+                        causal_mask
                     )
 
                     # Start generating from the next position
@@ -644,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token
@@ -697,76 +864,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             traceback.print_exc()
 
 def main():
-
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = 'llama_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'llama_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'llama_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters
-            args.context_length = int(params['context_length'])
-            args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f"  Context Length: {args.context_length}")
-            print(f"  Batch Size: {args.batch_size}")
-            print(f"  Num Chunks: {args.num_chunks}")
-            print(f"  Models Directory: {args.d}")
-            print(f"  Embeddings: {args.embed}")
-            print(f"  LM Head: {args.lmhead}")
-            print(f"  FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
+    args = parse_args()
 
     # Convert directory to absolute path
     model_dir = Path(args.d).resolve()
@@ -816,18 +914,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-
-
-
-
-
-
-
-
-
-
-
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state,  # Pass the state
+                causal_mask=causal_mask,  # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -837,6 +940,7 @@ def main():
         tokenizer=tokenizer,
         metadata=metadata,
        state=state,  # Pass the state
+        causal_mask=causal_mask,  # Pass the causal mask
        warmup=False,
        auto_prompt=args.prompt
     )
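The commit also adds a thinking mode toggled with `/t`: when it is on, a system message carrying `THINKING_PROMPT` is prepended to the conversation before the chat template is applied. A small stand-alone sketch of that branching follows; the prompt text is abbreviated and the real `tokenizer.apply_chat_template` call from chat_full.py is shown commented out, since it needs the actual tokenizer.

```python
# Sketch of how THINKING_MODE changes the messages that get tokenized.
THINKING_MODE = True
THINKING_PROMPT = "You are a deep thinking AI ... enclose your thoughts inside <think> </think> tags."  # abbreviated

conversation = [{"role": "user", "content": "who are you?"}]

if THINKING_MODE:
    # Prepend a system message carrying the thinking prompt, as the diff does.
    messages = [{"role": "system", "content": THINKING_PROMPT}] + conversation
else:
    messages = conversation

# In chat_full.py the result is then tokenized like this:
# base_input_ids = tokenizer.apply_chat_template(
#     messages, return_tensors="pt", add_generation_prompt=True
# ).to(torch.int32)

for m in messages:
    print(m["role"], ":", m["content"][:60])
```

Everything downstream (prefill, window shifting, token generation) is unchanged by the toggle; only the templated input grows by the extra system message.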