Commit 8f42a5a (verified) by burtenshaw (HF Staff) · 1 Parent(s): 9a8c7e2

Update app.py

Files changed (1):
  1. app.py +13 -4
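
The change threads the chat history through to the model. With type="messages", Gradio's ChatInterface passes history to the callback as a list of OpenAI-style role/content dicts, so it can be concatenated directly with the new user turn. A minimal sketch of the arguments the callback receives on a follow-up turn (the string values are illustrative, not from the commit):

# Shape of the arguments Gradio hands to generate() on a second turn
# (illustrative values only):
prompt = "merhaba"
history = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]

# The commit builds the model input by appending the new turn:
messages = history + [{"role": "user", "content": prompt}]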
app.py CHANGED

@@ -23,16 +23,23 @@ tokenizer, model = load_model()
 
 @spaces.GPU
 def generate(prompt, history):
-    conversation = [
+
+    if len(history) > 0:
+        messages = history + [
         {"role": "user", "content": prompt},
     ]
+    else:
+        messages = [
+            {"role": "user", "content": prompt},
+        ]
 
+    print(history)
     inputs = tokenizer.apply_chat_template(
-        conversation,
+        messages,
         add_generation_prompt=True,
         tokenize=True,
         return_tensors="pt",
-        return_dict=True
+        return_dict=True,
     ).to(device)
 
     with torch.no_grad():
@@ -42,7 +49,9 @@ def generate(prompt, history):
         )
 
     generated_tokens = outputs[0, inputs.input_ids.shape[1]:]
-    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
+    output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    return output
 
 
 demo = gr.ChatInterface(fn=generate, type="messages", examples=["hello", "hola", "merhaba"], title="NanoChat")
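
For reference, a minimal sketch of the whole generate function after this commit. The model/tokenizer setup and the generation arguments inside the torch.no_grad() block are elided in the diff, so the device assignment and max_new_tokens below are assumptions for illustration:

import gradio as gr
import spaces
import torch

# Setup is elided in the diff; the hunk header shows only this line,
# and load_model() is defined elsewhere in the Space.
tokenizer, model = load_model()
device = "cuda"  # assumption: the Space runs inference on the GPU

@spaces.GPU
def generate(prompt, history):
    # history arrives as a list of {"role", "content"} dicts
    # (ChatInterface with type="messages").
    if len(history) > 0:
        messages = history + [{"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]

    print(history)  # debug print left in by the commit

    # Tokenize the full conversation with the model's chat template;
    # return_dict=True yields a BatchEncoding with input_ids/attention_mask.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    with torch.no_grad():
        # Generation arguments are not shown in the diff;
        # max_new_tokens=256 is an assumed placeholder.
        outputs = model.generate(**inputs, max_new_tokens=256)

    # Decode only the newly generated tokens, skipping the prompt.
    generated_tokens = outputs[0, inputs.input_ids.shape[1]:]
    output = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return output

demo = gr.ChatInterface(fn=generate, type="messages", examples=["hello", "hola", "merhaba"], title="NanoChat")

Since [] + [x] == [x], the if/else split in the commit is strictly defensive; messages = history + [{"role": "user", "content": prompt}] alone would cover both the first and later turns.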