Spaces:
Runtime error
Runtime error
File size: 2,681 Bytes
3ae84a3 f7ed480 650e271 3ae84a3 45effd2 b8b9bf6 3ae84a3 45effd2 3ae84a3 45effd2 f7ed480 3ae84a3 0c9c8ed b8b9bf6 f7ed480 0c9c8ed e647581 44b4666 e647581 1fbc5c6 45effd2 650e271 45effd2 650e271 45effd2 650e271 3ae84a3 44b4666 45effd2 44b4666 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import pickle
import gradio as gr
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
import wikipedia
# One-time start-up work: everything below executes when the script is run.

# Search index, unpickled from a file given by a relative path.
INDEX_PATH = "insectarium_768.pickle"
with open(INDEX_PATH, "rb") as handle:
    index = pickle.load(handle)

# The same checkpoint supplies both the preprocessing step and the model
# that are used to compute embeddings.
MODEL_CHECKPOINT = "sasha/vit-base-butterflies"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(MODEL_CHECKPOINT)

# Candidate images that search results are selected from ("train" split only).
dataset = load_dataset("sasha/insectarium-butterflies")
ds = dataset["train"]
def query(image, top_k=4):
    """Find the dataset butterflies that most resemble ``image``.

    Args:
        image: Input picture, in whatever form ``feature_extractor`` accepts.
        top_k: Number of neighbours to request from ``index``.

    Returns:
        A 3-tuple ``(images_with_captions, labels_with_probs, description)``:

        * a list of ``(image, name)`` pairs for the retrieved rows,
        * a dict mapping each retrieved name to ``1 - score`` (when a name
          occurs more than once, the last occurrence wins),
        * a Markdown string describing the first match, or a generic
          butterfly blurb when the Wikipedia lookup fails.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Surface a readable message instead of failing deep inside the
        # feature extractor.
        raise gr.Error("Please upload an image first.")

    inputs = feature_extractor(image, return_tensors="pt")
    embedding = model(**inputs).pooler_output.detach()

    results = index.query(embedding, k=top_k)
    neighbour_ids = results[0][0].tolist()
    # NOTE(review): presumably distances in [0, 1], since they are turned
    # into scores with ``1 - value`` below; confirm against the index.
    scores = results[1][0].tolist()

    # Select the matching rows once and read both columns from that subset.
    matches = ds.select(neighbour_ids)
    images = matches["image"]
    captions = matches["name"]

    images_with_captions = list(zip(images, captions))
    labels_with_probs = {
        name: 1 - score for name, score in zip(captions, scores)
    }

    try:
        summary = wikipedia.summary(captions[0], sentences=1)
        url = wikipedia.page(captions[0]).url
        description = (
            "### "
            + summary
            + " You can learn more about your butterfly [here]("
            + str(url)
            + ")!"
        )
    except Exception:
        # Best-effort lookup: any failure (no page, ambiguous title, network
        # problem, no matches at all) falls back to a generic description.
        # Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        description = (
            "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
            " You can learn more about butterflies [here]("
            "https://en.wikipedia.org/wiki/Butterfly"
            ")!"
        )

    return images_with_captions, labels_with_probs, description
# Build the page: the image input, button and description in a narrow column,
# the result gallery and score label in a wider one, and clickable examples
# underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown(
        "## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!"
    )
    with gr.Row():
        with gr.Column(scale=1):
            inputs = gr.Image()
            btn = gr.Button("Find my butterfly!")
            description = gr.Markdown()
        with gr.Column(scale=2):
            outputs = gr.Gallery(rows=2)
            labels = gr.Label()
    gr.Markdown("### Image Examples")
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        # query() returns three values, so all three components are listed
        # here, exactly as for the button below; with only two, the
        # description did not line up with the function's return value.
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()
|