C98yhou079 committed on
Commit
d04f5e9
·
verified ·
1 Parent(s): 51a5a00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -98
app.py CHANGED
@@ -1,130 +1,93 @@
1
- # app.py
2
- # TinyLLaVA Gradio app tailored for CPU-limited Hugging Face Spaces.
3
- # This file delays heavy model load until the first inference request to reduce build-time memory use.
4
-
5
- import io
6
- import requests
7
- import traceback
8
- from typing import Optional
9
  from PIL import Image
10
  import gradio as gr
11
  import torch
12
 
13
- # Lazy imports for tinyllava and transformers; we import them only when needed
14
- _model_loaded = False
15
- model = None
16
- tokenizer = None
17
- image_processor = None
18
- context_len = None
19
- device = "cpu" # Spaces are CPU-only
20
 
21
- # Choose a TinyLLaVA model suitable for CPU (1.5B recommended)
22
- MODEL_PATH = "bczhou/TinyLLaVA-1.5B" # recommended for Spaces CPU
 
23
  DEFAULT_MAX_TOKENS = 128
24
- DEFAULT_CONV_MODE = "v1"
25
 
26
- def lazy_load_model():
27
- global _model_loaded, model, tokenizer, image_processor, context_len
28
- if _model_loaded:
29
- return
30
- try:
31
- # Import here after torch is installed by start.py (or already present)
32
- from tinyllava.model.builder import load_pretrained_model
33
- from tinyllava.mm_utils import get_model_name_from_path
34
- from transformers import logging as hf_logging
35
- hf_logging.set_verbosity_error()
36
- except Exception as e:
37
- raise RuntimeError(f"Failed to import TinyLLaVA or transformers: {e}")
38
 
 
 
 
 
39
  model_name = get_model_name_from_path(MODEL_PATH)
40
- tokenizer, model, image_processor, context_len = load_pretrained_model(
41
- model_path=MODEL_PATH, model_base=None, model_name=model_name
42
- )
43
- model.to(device)
44
- model.eval()
45
- _model_loaded = True
46
 
47
- def load_image_from_url(url: str) -> Image.Image:
48
  resp = requests.get(url, timeout=10)
49
  resp.raise_for_status()
50
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
51
 
52
- def _prepare_inputs(prompt: str, image: Image.Image):
53
  prompt_text = f"USER: <image>\n{(prompt or '').strip()}\nASSISTANT:"
54
- text_inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
55
- img_inputs = image_processor(images=image, return_tensors="pt").to(device)
56
- inputs = {**text_inputs, **img_inputs}
57
- return inputs, prompt_text
 
58
 
59
- def generate_text(prompt: str, uploaded_image: Optional[Image.Image], image_url: str,
60
- max_new_tokens: int = DEFAULT_MAX_TOKENS, conv_mode: str = DEFAULT_CONV_MODE):
61
  try:
62
- # Ensure model is loaded on first call
63
- if not _model_loaded:
64
- lazy_load_model()
65
-
66
- if uploaded_image is None and image_url:
67
- image = load_image_from_url(image_url)
68
- elif uploaded_image is not None:
69
- image = uploaded_image
70
  else:
71
- return "No image provided. Upload an image or provide an image URL."
72
-
73
- inputs, prompt_text = _prepare_inputs(prompt, image)
74
-
75
- gen_kwargs = {
76
- "max_new_tokens": int(max_new_tokens),
77
- "num_beams": 1,
78
- "temperature": 0.0,
79
- }
80
-
81
- outputs = model.generate(**inputs, **gen_kwargs)
82
-
83
- # Decode outputs
84
- if isinstance(outputs, torch.Tensor):
85
- decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
86
- elif isinstance(outputs, (list, tuple)):
87
- decoded = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)[0]
88
  else:
89
- decoded = str(outputs)
90
-
91
  if "ASSISTANT:" in decoded:
92
- reply = decoded.split("ASSISTANT:")[-1].strip()
93
- else:
94
- reply = decoded.strip()
95
- return reply
96
  except Exception as e:
97
- tb = traceback.format_exc()
98
- return f"Inference error: {e}\n\nTraceback:\n{tb}"
99
 
100
- # Gradio UI
101
  with gr.Blocks() as demo:
102
- gr.Markdown("# TinyLLaVA (CPU) — Hugging Face Spaces Demo")
103
  with gr.Row():
104
  with gr.Column(scale=2):
105
- prompt_input = gr.Textbox(label="Prompt (optional)", placeholder="Ask about the image...")
106
- upload = gr.Image(label="Upload Image (preferred)", type="pil")
107
- url = gr.Textbox(label="Image URL (used if upload empty)", placeholder="https://...")
108
- max_tokens = gr.Slider(minimum=32, maximum=512, step=32, value=DEFAULT_MAX_TOKENS, label="Max new tokens")
109
- run_btn = gr.Button("Generate")
110
  with gr.Column(scale=1):
111
- preview = gr.Image(label="Image preview", type="pil")
112
- out = gr.Textbox(label="Generated Text", lines=8)
113
-
114
- def update_preview(uploaded, url_text):
115
- if uploaded is not None:
116
- return uploaded
117
- if url_text:
118
  try:
119
- return load_image_from_url(url_text)
120
- except Exception:
121
  return None
122
  return None
123
-
124
- upload.change(fn=update_preview, inputs=[upload, url], outputs=preview)
125
- url.change(fn=update_preview, inputs=[upload, url], outputs=preview)
126
-
127
- run_btn.click(fn=generate_text, inputs=[prompt_input, upload, url, max_tokens], outputs=out)
128
 
129
  if __name__ == "__main__":
130
  demo.launch()
 
1
+ # app.py - Gradio UI using vendored tinyllava
2
+ import io, requests, traceback
 
 
 
 
 
 
3
  from PIL import Image
4
  import gradio as gr
5
  import torch
6
 
7
+ from tinyllava.model import load_pretrained_model
8
+ from tinyllava.mm_utils import get_model_name_from_path
 
 
 
 
 
9
 
10
+ # Use CPU-friendly TinyLLaVA model recommended for Spaces
11
+ MODEL_PATH = "bczhou/TinyLLaVA-1.5B"
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
  DEFAULT_MAX_TOKENS = 128
 
14
 
15
+ # Lazy load
16
+ _model = None
17
+ _tokenizer = None
18
+ _image_processor = None
19
+ _context_len = None
 
 
 
 
 
 
 
20
 
21
+ def lazy_load():
22
+ global _model, _tokenizer, _image_processor, _context_len
23
+ if _model is not None:
24
+ return
25
  model_name = get_model_name_from_path(MODEL_PATH)
26
+ _tokenizer, _model, _image_processor, _context_len = load_pretrained_model(MODEL_PATH, model_name=model_name)
27
+ _model.to(DEVICE)
28
+ _model.eval()
 
 
 
29
 
30
+ def load_image_from_url(url: str):
31
  resp = requests.get(url, timeout=10)
32
  resp.raise_for_status()
33
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
34
 
35
+ def prepare_inputs(prompt: str, image: Image.Image):
36
  prompt_text = f"USER: <image>\n{(prompt or '').strip()}\nASSISTANT:"
37
+ inputs = _tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
38
+ if _image_processor is not None:
39
+ img_inputs = _image_processor(images=image, return_tensors="pt").to(DEVICE)
40
+ inputs.update(img_inputs)
41
+ return inputs
42
 
43
+ def generate_text(prompt, upload, url, max_new_tokens=DEFAULT_MAX_TOKENS):
 
44
  try:
45
+ lazy_load()
46
+ if upload is None and url:
47
+ image = load_image_from_url(url)
48
+ elif upload is not None:
49
+ image = upload
 
 
 
50
  else:
51
+ return "No image provided."
52
+
53
+ inputs = prepare_inputs(prompt, image)
54
+ gen = _model.generate(**inputs, max_new_tokens=int(max_new_tokens), num_beams=1, temperature=0.0)
55
+ if isinstance(gen, torch.Tensor):
56
+ decoded = _tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
57
+ elif isinstance(gen, (list, tuple)):
58
+ decoded = _tokenizer.batch_decode(gen[0], skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
59
  else:
60
+ decoded = str(gen)
 
61
  if "ASSISTANT:" in decoded:
62
+ return decoded.split("ASSISTANT:")[-1].strip()
63
+ return decoded.strip()
 
 
64
  except Exception as e:
65
+ return f"Inference error: {e}\n\n{traceback.format_exc()}"
 
66
 
 
67
  with gr.Blocks() as demo:
68
+ gr.Markdown("TinyLLaVA (vendored loader) — CPU-friendly")
69
  with gr.Row():
70
  with gr.Column(scale=2):
71
+ prompt = gr.Textbox(label="Prompt (optional)")
72
+ upload = gr.Image(label="Upload Image", type="pil")
73
+ url = gr.Textbox(label="Image URL")
74
+ max_tokens = gr.Slider(32, 512, value=DEFAULT_MAX_TOKENS, step=32, label="Max new tokens")
75
+ btn = gr.Button("Generate")
76
  with gr.Column(scale=1):
77
+ preview = gr.Image(label="Preview", type="pil")
78
+ out = gr.Textbox(label="Output", lines=8)
79
+ def update_preview(u, s):
80
+ if u is not None:
81
+ return u
82
+ if s:
 
83
  try:
84
+ return load_image_from_url(s)
85
+ except:
86
  return None
87
  return None
88
+ upload.change(update_preview, [upload, url], preview)
89
+ url.change(update_preview, [upload, url], preview)
90
+ btn.click(generate_text, [prompt, upload, url, max_tokens], out)
 
 
91
 
92
  if __name__ == "__main__":
93
  demo.launch()