Spaces:
Running
on
Zero
Upload joycaption.py
Browse files — joycaption.py (+4 −6)
joycaption.py
CHANGED
|
@@ -11,12 +11,15 @@ import os
|
|
| 11 |
import gc
|
| 12 |
|
| 13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
| 14 |
|
| 15 |
llm_models = {
|
| 16 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
| 17 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
|
|
|
| 18 |
"mergekit-community/L3.1-Boshima-b-FIX": None,
|
| 19 |
-
"meta-llama/Meta-Llama-3.1-8B": None,
|
| 20 |
}
|
| 21 |
|
| 22 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
|
@@ -25,9 +28,6 @@ MODEL_PATH = list(llm_models.keys())[0]
|
|
| 25 |
CHECKPOINT_PATH = Path("wpkklhc6")
|
| 26 |
TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
|
| 27 |
|
| 28 |
-
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 29 |
-
use_inference_client = False
|
| 30 |
-
|
| 31 |
class ImageAdapter(nn.Module):
|
| 32 |
def __init__(self, input_features: int, output_features: int):
|
| 33 |
super().__init__()
|
|
@@ -200,8 +200,6 @@ def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: in
|
|
| 200 |
#generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
|
| 201 |
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
|
| 202 |
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
|
| 203 |
-
|
| 204 |
-
print(prompt)
|
| 205 |
|
| 206 |
# Trim off the prompt
|
| 207 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
|
|
|
| 11 |
import gc
|
| 12 |
|
| 13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 15 |
+
use_inference_client = False
|
| 16 |
|
| 17 |
llm_models = {
|
| 18 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
| 19 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
| 20 |
+
"DevQuasar/HermesNova-Llama-3.1-8B": None,
|
| 21 |
"mergekit-community/L3.1-Boshima-b-FIX": None,
|
| 22 |
+
"meta-llama/Meta-Llama-3.1-8B": None, # gated
|
| 23 |
}
|
| 24 |
|
| 25 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
|
|
|
| 28 |
CHECKPOINT_PATH = Path("wpkklhc6")
|
| 29 |
TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
class ImageAdapter(nn.Module):
|
| 32 |
def __init__(self, input_features: int, output_features: int):
|
| 33 |
super().__init__()
|
|
|
|
| 200 |
#generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
|
| 201 |
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
|
| 202 |
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
|
|
|
|
|
|
|
| 203 |
|
| 204 |
# Trim off the prompt
|
| 205 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|