basmala12 committed
Commit 96b0750 · verified · 1 Parent(s): 83b358b

Update app.py

Files changed (1)
  1. app.py +45 -35
app.py CHANGED
@@ -1,70 +1,80 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer

  MODEL_NAME = "basmala12/smollm_finetuning5"

- # Load model & tokenizer
+ # Load tokenizer & model once at startup (on CPU)
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+ model.eval()

- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
- )

  def respond(message, history, system_message, max_tokens, temperature, top_p):
+     """
+     ChatInterface (type="messages") passes:
+     - message: current user message (str)
+     - history: list of dicts: [{"role": "...", "content": "..."}, ...]
+     - system_message, max_tokens, temperature, top_p: from additional_inputs
+     We return a single string: the assistant reply.
+     """

-     # Build chat messages correctly
-     msgs = [{"role": "system", "content": system_message}]
+     # Build full conversation for the chat template
+     messages = [{"role": "system", "content": system_message}]

-     # History as proper chat
-     for user_msg, bot_msg in history:
-         msgs.append({"role": "user", "content": user_msg})
-         msgs.append({"role": "assistant", "content": bot_msg})
+     # history is a list of {"role": "user"/"assistant", "content": str}
+     # We append it as-is to preserve previous turns
+     messages.extend(history)

-     # Add the new user message
-     msgs.append({"role": "user", "content": message})
+     # Add the new user question
+     messages.append({"role": "user", "content": message})

-     # Apply chat template
+     # Turn into model prompt using the tokenizer's chat template
      prompt = tokenizer.apply_chat_template(
-         msgs,
+         messages,
          tokenize=False,
          add_generation_prompt=True,
      )

-     # Generate output
-     out = pipe(
-         prompt,
-         max_new_tokens=max_tokens,
-         temperature=temperature,
-         top_p=top_p,
-         do_sample=True,
-     )[0]["generated_text"]
+     # Tokenize
+     inputs = tokenizer(prompt, return_tensors="pt")

-     # Extract only the assistant answer
-     if "<|im_start|>assistant" in out:
-         out = out.split("<|im_start|>assistant", 1)[-1]
-     out = out.replace("<|im_end|>", "").strip()
+     # Generate continuation (new assistant answer only)
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             do_sample=True,
+             temperature=float(temperature),
+             top_p=float(top_p),
+         )

-     # Enforce short answer + brief reasoning
-     # (additional safety)
-     if len(out.split()) > 45:
-         out = " ".join(out.split()[:45]) + " ..."
+     # Slice off the prompt tokens, keep only new tokens
+     generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+     answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

-     return out
+     # Optional: enforce "short answer + brief reasoning"
+     words = answer.split()
+     if len(words) > 60:
+         answer = " ".join(words[:60]) + " ..."

+     return answer


  chatbot = gr.ChatInterface(
      fn=respond,
      type="messages",
      additional_inputs=[
-         gr.Textbox("Give short answers with brief logical reasoning.", label="System message"),
+         gr.Textbox(
+             value="Give short answers with brief logical reasoning.",
+             label="System message",
+         ),
          gr.Slider(1, 512, value=256, step=1, label="Max new tokens"),
          gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
      ],
+     title="SmolLM2 – Short Reasoning Chatbot",
+     description="Fine-tuned SmolLM2 (basmala12/smollm_finetuning5) that gives short answers with brief logical reasoning.",
  )

  if __name__ == "__main__":
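
For a quick sanity check of this change, the updated respond() can be called directly with a messages-style history, the same shape gr.ChatInterface(type="messages") passes at runtime. The sketch below is illustrative only and not part of the commit: it assumes the file above is saved as app.py and importable (importing it triggers the model download), and the history turns and sampling values are invented for the example.

# Hypothetical local smoke test (not part of the commit): exercise respond()
# with a role/content history like the one Gradio's messages mode provides.
from app import respond  # assumes this script sits next to app.py

history = [
    {"role": "user", "content": "What is 2 + 2?"},                # invented turn
    {"role": "assistant", "content": "4, since 2 plus 2 is 4."},  # invented turn
]

reply = respond(
    message="And if I double that?",
    history=history,
    system_message="Give short answers with brief logical reasoning.",
    max_tokens=64,
    temperature=0.7,
    top_p=0.9,
)
print(reply)

Because the new code decodes only outputs[0][inputs["input_ids"].shape[1]:], the printed reply contains just the freshly generated assistant text rather than the echoed prompt, which is what makes the old `<|im_start|>assistant` string-splitting unnecessary.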