Fixed GIL issue

Fixes a race condition between CoreML inference and the causal_mask update: the causal mask is now built once with initialize_causal_mask() and passed explicitly through run_prefill(), generate_next_token(), and chat_loop() instead of being recreated on every call. Also adds a --nw flag to skip the warmup runs and a --batch-size option to override the prefill batch size.

chat.py
CHANGED
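The diff below threads a single precomputed causal mask through the whole generation path. As a rough, self-contained sketch of that pattern (not the project's actual code: the -inf initialization inside make_causal_mask and the standalone loop are assumptions for illustration; the function names mirror the ones added in this commit):

# Minimal sketch of the fix: build the causal mask once up front, then slice
# one row per generated token instead of rebuilding the full mask inside the
# generation loop.
import numpy as np
import torch

def make_causal_mask(length, start):
    # Assumed fill: -inf above the allowed positions, 0 elsewhere,
    # shaped [1, 1, length, length] as in chat.py.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Built once, outside the token loop (the diff does this in main()).
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

context_length = 512
causal_mask = initialize_causal_mask(context_length)

for pos in range(1, 4):  # stand-in for the real generation loop
    # Per-token slice, shape [1, 1, 1, context_length]; this is what the diff
    # feeds to the CoreML 'causal_mask' input in generate_next_token.
    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
    print(single_causal_mask.shape)

The full diff follows.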
@@ -243,18 +243,28 @@ def load_metadata(model,args):
         else:
             ctx_len = args.context_length
 
-        # Use defaults
+        # Use defaults or values from args
         metadata['context_length'] = ctx_len
         metadata['state_length'] = ctx_len
-
+        # Get batch size from args or use default
+        metadata['batch_size'] = getattr(args, 'batch_size', 64)
         metadata['lut_bits'] = 4
-        metadata['num_chunks'] = 4
-        print("\nUsing
+        metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+        print("\nUsing parameters:")
         print(f" Context Length: {metadata['context_length']}")
         print(f" State Length: {metadata['state_length']}")
         print(f" Prefill Batch Size: {metadata['batch_size']}")
         print(f" LUT Bits: {metadata['lut_bits']}")
         print(f" Number of Chunks: {metadata['num_chunks']}")
+
+        # Override with values from args if they exist
+        if hasattr(args, 'batch_size') and args.batch_size is not None:
+            metadata['batch_size'] = args.batch_size
+            print(f"\nOverriding batch size from args: {args.batch_size}")
+        if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+            metadata['num_chunks'] = args.num_chunks
+            print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -376,11 +386,19 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def
-    """
-    # Create causal mask
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
     causal_mask = make_causal_mask(context_length, 0)
     causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
+    """Run prefill on the input sequence."""
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask = make_causal_mask(context_length, 0)
+        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
 
     # Process in batches
     batch_pos = 0
@@ -423,7 +441,7 @@ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length,
 
     return torch.tensor([context_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos] # [1, 1]
@@ -437,8 +455,13 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
-
-
+
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask_data = make_causal_mask(context_length, 0)
+        single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
+    else:
+        single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks with state
     for ffn_model in ffn_models:
@@ -447,7 +470,7 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
             'hidden_states': hidden_states.numpy(),
             'update_mask': update_mask.numpy(),
             'position_ids': position_ids.numpy(),
-            'causal_mask':
+            'causal_mask': single_causal_mask.numpy(),
             'current_pos': position_ids.numpy()
         }
         output = ffn_model['infer'].predict(inputs, state)
@@ -493,7 +516,7 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
@@ -567,7 +590,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             # Start prefill timing
             prefill_start = time.time()
 
-            # Run prefill with state
+            # Run prefill with state and causal mask
            current_pos = run_prefill(
                 embed_model,
                 ffn_models,
@@ -575,7 +598,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
 
             # Calculate prefill timing
@@ -590,7 +614,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             inference_tokens = 0
 
             while pos < context_length - 1:
-                # Generate next token
+                # Generate next token with causal mask
                 next_token = generate_next_token(
                     embed_model,
                     ffn_models,
@@ -598,7 +622,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token to sequence
@@ -657,7 +682,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -678,9 +703,15 @@ def parse_args():
     parser.add_argument('--prompt', type=str,
                        help='If specified, run once with this prompt and exit')
 
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                       help='Skip warmup phase')
+
     # Model configuration
     parser.add_argument('--context-length', type=int,
                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                       help='Batch size for prefill (default: 64)')
 
     args = parser.parse_args()
 
@@ -711,9 +742,11 @@ def parse_args():
         if not args.tokenizer:
             args.tokenizer = args.d
 
-        # Set other parameters
-        args.context_length
-
+        # Set other parameters if not overridden by command line
+        if args.context_length is None:
+            args.context_length = int(params['context_length'])
+        if args.batch_size is None:
+            args.batch_size = int(params['batch_size'])
         args.num_chunks = num_chunks
 
         print(f"\nLoaded parameters from {args.meta}:")
@@ -782,18 +815,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-
-
-
-
-
-
-
-
-
-
-
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state,
+                causal_mask=causal_mask, # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -803,6 +841,7 @@ def main():
         tokenizer=tokenizer,
         metadata=metadata,
         state=state,
+        causal_mask=causal_mask, # Pass the causal mask
         warmup=False,
         auto_prompt=args.prompt
     )
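For reference, an illustrative way to exercise the options this commit adds (the meta.yaml path is a placeholder; --nw skips the warmup runs, --batch-size overrides the prefill batch size, and --meta/--prompt already existed):

python chat.py --meta /path/to/model/meta.yaml --prompt "who are you?"
python chat.py --meta /path/to/model/meta.yaml --nw --batch-size 64

The warmup runs call chat_loop() twice with warmup=True before the main run, which, per the in-code comment, works around Python GIL issues with CoreML.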