tuan3335 committed
Commit c5208b6 · 1 Parent(s): 8103d43

refactor: switch to huggingface_hub InferenceClient for Qwen, remove local transformers usage

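For context, the pattern this commit switches to looks roughly like the following. This is a minimal, self-contained sketch of the huggingface_hub chat-completion call used in the diff below; the model name, `provider="auto"`, and the `HF_TOKEN` environment variable are taken from the diff, while the prompt and token limit here are placeholder values:

```python
import os

from huggingface_hub import InferenceClient

# provider="auto" lets huggingface_hub route to an available inference
# provider; HF_TOKEN must be set in the environment, as in the diff below.
client = InferenceClient(provider="auto", api_key=os.environ["HF_TOKEN"])

# OpenAI-compatible chat completion interface exposed by InferenceClient.
completion = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Hello!"}],  # placeholder prompt
    max_tokens=64,  # placeholder limit; agent.py defaults to 2048
)
print(completion.choices[0].message.content)
```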
Files changed (1): agent.py (+16 -24)
agent.py CHANGED
```diff
@@ -20,7 +20,7 @@ from typing_extensions import TypedDict
 from pydantic import BaseModel, Field
 
 # LangChain HuggingFace Integration
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from huggingface_hub import InferenceClient
 
 from utils import (
     process_question_with_tools,
@@ -55,35 +55,27 @@ class AIBrain:
     def __init__(self):
         self.model_name = "Qwen/Qwen3-8B"
 
-        print("🧠 Initializing Qwen3-8B with native transformers...")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            torch_dtype="auto",
-            device_map="auto"
+        print("🧠 Initializing Qwen3-8B with huggingface_hub InferenceClient...")
+        self.client = InferenceClient(
+            provider="auto",
+            api_key=os.environ["HF_TOKEN"],
         )
-        print("✅ Qwen3 AI Brain with native transformers is ready")
+        print("✅ Qwen3 AI Brain with huggingface_hub InferenceClient is ready")
 
     def _generate_with_qwen3(self, prompt: str, max_tokens: int = 2048) -> str:
-        """Generate text with Qwen3 using native transformers, thinking mode off"""
+        """Generate text with Qwen3 via huggingface_hub InferenceClient"""
         try:
-            messages = [{"role": "user", "content": prompt}]
-            text = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                enable_thinking=False
+            messages = [
+                {"role": "user", "content": prompt}
+            ]
+            completion = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                max_tokens=max_tokens
             )
-            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
-            generated_ids = self.model.generate(
-                **model_inputs,
-                max_new_tokens=max_tokens
-            )
-            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
-            response = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
-            return response
+            return completion.choices[0].message.content
         except Exception as e:
-            print(f"⚠️ Qwen3 generation error: {str(e)}")
+            print(f"⚠️ Qwen3 InferenceClient error: {str(e)}")
             return f"AI generation failed: {str(e)}"
 
     def analyze_question(self, question: str, task_id: str = "") -> Dict[str, Any]:
```
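After this change, running the agent needs a valid HF_TOKEN in the environment rather than local GPU weights. A hypothetical smoke test, assuming `AIBrain` is importable from agent.py as shown above (the token value and prompt are placeholders, not from the commit):

```python
import os

# Placeholder token for illustration; a real HF token must be set instead,
# since AIBrain.__init__ reads os.environ["HF_TOKEN"] at construction time.
os.environ.setdefault("HF_TOKEN", "hf_xxx")

from agent import AIBrain  # AIBrain as defined in the diff above

brain = AIBrain()
print(brain._generate_with_qwen3("What is 2 + 2?", max_tokens=64))
```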