Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| from transformers import VisionEncoderDecoderModel,ViTFeatureExtractor,PreTrainedTokenizerFast,GPT2Tokenizer,AutoModelForCausalLM,AutoTokenizer | |
| import requests | |
| import gradio as gr | |
| import torch | |
| from transformers import pipeline | |
| import re | |
# Text displayed on the Gradio page.
title = "Story generator from images using ViT and GPT2"
description = (
    "Just upload an image, and generate a short story for the image.\n"
    " PS: GPT-2 is not perfect but it's fun to play with."
    "May take a minute for the output to generate. Enjoyy!!!"
)

# Captioning stage: the image-to-text model, the feature extractor that
# prepares its pixel input, and the tokenizer used to decode its output.
# All three are read by get_output_senten.
model = VisionEncoderDecoderModel.from_pretrained("gagan3012/ViTGPT2_vizwiz")
model = model.to("cpu")
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

# Story stage: causal language model and its matching tokenizer, both loaded
# from the same checkpoint.
story_gpt = AutoModelForCausalLM.from_pretrained(
    "pranavpsv/gpt2-genre-story-generator"
)
st_tokenizer = AutoTokenizer.from_pretrained(
    "pranavpsv/gpt2-genre-story-generator"
)

# Gradio wiring: one PIL image in, one text box out, two example images.
inputs = [gr.inputs.Image(type="pil", label="Original Image")]
outputs = [gr.outputs.Textbox(label="Story")]
examples = [["img_1.jpg"], ["img_2.jpg"]]
_END_OF_TEXT = "<|endoftext|>"
_MAX_CAPTION_CHARS = 75


def _clean_caption(decoded, max_chars=_MAX_CAPTION_CHARS):
    """Return the first sentence of a decoded caption, without marker residue."""
    # The caption is the first non-empty span between end-of-text markers.
    # Splitting on the full marker avoids leaving fragments such as
    # "<endoftext" attached to the last word.
    spans = (span.strip() for span in decoded.split(_END_OF_TEXT))
    caption = next((span for span in spans if span), "")

    # Drop any stray marker punctuation, then keep only the first sentence.
    caption = caption.replace(">", "").replace("|", "")
    caption = caption.split(".")[0].strip()

    if len(caption) > max_chars:
        clipped = caption[:max_chars]
        # Only back up to the previous space when the cut landed mid-word.
        if not caption[max_chars].isspace():
            last_space = clipped.rfind(" ")
            if last_space > 0:
                clipped = clipped[:last_space]
        caption = clipped.rstrip()
    return caption


def _trim_to_last_sentence(text):
    """Cut text after its final period; return it unchanged if it has none."""
    end = text.rfind(".")
    return text if end == -1 else text[: end + 1]


def get_output_senten(img) -> str:
    """Generate a short story from an image.

    The image is captioned with the module-level ``model``; the caption is
    reduced to a short prompt (first sentence, at most 75 characters) and
    then continued by the module-level ``story_gpt`` model.

    Args:
        img: The uploaded image, as delivered by the Gradio image input
            (configured with ``type="pil"``).

    Returns:
        The generated story, cut after its last period. If the story
        contains no period, the whole generated text is returned.

    Raises:
        ValueError: If no usable caption text could be derived from the
            image, so there is nothing to seed the story with.
    """
    pixel_values = vit_feature_extractor(
        images=img, return_tensors="pt"
    ).pixel_values.to("cpu")
    caption_ids = model.generate(pixel_values, num_beams=7)
    # skip_special_tokens removes markers the tokenizer knows about;
    # _clean_caption additionally strips the literal marker in case this
    # tokenizer does not flag it as special.
    decoded_captions = tokenizer.batch_decode(
        caption_ids, skip_special_tokens=True
    )

    caption = _clean_caption(decoded_captions[0])
    if not caption:
        raise ValueError("Could not derive a caption from the image.")
    print(caption)  # Shows the story prompt in the application log.

    # Shape (1, sequence_length): a batch holding the single prompt.
    prompt_ids = torch.tensor([st_tokenizer.encode(caption)])
    story_ids = story_gpt.generate(
        prompt_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    stories = st_tokenizer.batch_decode(story_ids, skip_special_tokens=True)
    return _trim_to_last_sentence(str(stories[0]))
# Build the web UI around get_output_senten and start serving it.
gr.Interface(
    fn=get_output_senten,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    theme="huggingface",
).launch(enable_queue=True)