# SmallCap demo — Hugging Face Space (retrieval-augmented image captioning).
import json

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTFeatureExtractor, AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig

from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
from src.utils import prep_strings, postprocess_preds
from src.retrieve_caps import *
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
# Run on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Feature extractor that preprocesses images for the captioning model's CLIP encoder.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# Tokenizer for the OPT decoder; pad/eos tokens overridden to match SmallCap's setup.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'

# Register the custom OPT-based SmallCap classes so AutoModel can resolve the
# "this_opt" / "smallcap" model types stored in the checkpoint's config.
AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)
model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
model = model.to(device)

# Prompt template for the decoder input; retrieved captions are spliced into it.
template = open('src/template.txt').read().strip() + ' '

# Retrieval components: precomputed COCO captions, a CLIP model to embed the
# query image, and a FAISS index over the caption embeddings.
captions = json.load(open('coco_index_captions.json'))
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')
# res = faiss.StandardGpuResources()
# retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)

# ImageNet labels — leftover from the gradio starter template; nothing below uses
# them. git.io is deprecated, so guard the fetch: an unreachable shortlink must
# not crash the whole Space at import time.
try:
    response = requests.get("https://git.io/JJkYN", timeout=10)
    labels = response.text.split("\n")
except requests.RequestException:
    labels = []
def retrieve_caps(image_embedding, index, k=4):
    """Return the indices of the *k* nearest captions for an image embedding.

    The embedding is cast to float32 and L2-normalised in place (FAISS
    inner-product / cosine convention) before querying the index.
    """
    query = image_embedding.astype(np.float32)
    faiss.normalize_L2(query)
    _, neighbour_ids = index.search(query, k)
    return neighbour_ids
def classify_image(image):
    """Caption a PIL image with SmallCap, augmented by retrieved COCO captions.

    Returns the generated caption followed by the four retrieved captions,
    separated by blank lines, as a single string for the gradio Textbox.
    """
    # NOTE(fix): removed the unused `inp = transforms.ToTensor()(image)` line and
    # the dead commented-out starter-template code — pure leftover work.
    # Embed the query image with the retrieval CLIP model and fetch the
    # nearest-neighbour captions from the FAISS index.
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
    nns = retrieve_caps(image_embedding, retrieval_index)[0]
    caps = [captions[i] for i in nns][:4]
    # Build the decoder prompt: template text plus the retrieved captions.
    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)
    # Generate the caption with beam search (3 beams, up to 25 new tokens).
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
    retrieved_caps = "Retrieved captions: \n{}\n{}\n{}\n{}".format(*caps)
    return str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer)) + "\n\n\n" + retrieved_caps
# Wire up the Gradio UI: a single image input, a multi-line text output.
image = gr.Image(type="pil")
textbox = gr.Textbox(placeholder="Generated caption and retrieved captions...", lines=4)
title = "SmallCap Demo"
demo = gr.Interface(fn=classify_image, inputs=image, outputs=textbox, title=title)
demo.launch()