File size: 2,681 Bytes
3ae84a3
 
 
f7ed480
650e271
3ae84a3
 
 
45effd2
b8b9bf6
3ae84a3
 
45effd2
 
3ae84a3
 
45effd2
f7ed480
3ae84a3
 
0c9c8ed
b8b9bf6
 
f7ed480
0c9c8ed
e647581
44b4666
e647581
1fbc5c6
45effd2
 
 
650e271
45effd2
 
 
 
 
650e271
45effd2
 
 
 
650e271
3ae84a3
 
44b4666
45effd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44b4666
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pickle

import gradio as gr
import torch
import wikipedia
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor


# Start-up work: these statements execute a single time, when the script is
# first run, and leave their results in module-level names used by query().

# Search index that query() interrogates with an image embedding.
# NOTE(review): unpickling executes arbitrary code from the file, so this
# pickle must only ever come from a trusted source.
with open("insectarium_768.pickle", "rb") as index_file:
    index = pickle.load(index_file)

# The feature extractor and the embedding model share one checkpoint.
_CHECKPOINT = "sasha/vit-base-butterflies"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT)

# Candidate images; positions returned by the index are looked up in `ds`.
dataset = load_dataset("sasha/insectarium-butterflies")
ds = dataset["train"]


# Markdown shown when no Wikipedia entry can be fetched for the best match.
_FALLBACK_DESCRIPTION = (
    "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
    " You can learn more about butterflies [here](https://en.wikipedia.org/wiki/Butterfly)!"
)


def _describe_butterfly(name):
    """Return a Markdown blurb about *name*, or a generic one if lookup fails."""
    try:
        summary = wikipedia.summary(name, sentences=1)
        url = wikipedia.page(name).url
    except Exception:
        # Best effort: a missing/ambiguous page or a network failure must not
        # break the search itself. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        return _FALLBACK_DESCRIPTION
    return f"### {summary} You can learn more about your butterfly [here]({url})!"


def query(image, top_k=4):
    """Find the candidate images closest to *image*.

    Args:
        image: Query image, passed as-is to the feature extractor.
        top_k: Number of neighbours requested from the index.

    Returns:
        A 3-tuple of:
        - a list of ``(image, name)`` pairs for the matches,
        - a dict mapping each matched name to ``1 - value``, where ``value``
          is the second array returned by the index (presumably a distance
          -- confirm against the index implementation); when a name occurs
          more than once, its highest score is kept,
        - a Markdown description of the first match, or a generic butterfly
          description when none could be retrieved.
    """
    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        embedding = model(**inputs).pooler_output.detach()

    results = index.query(embedding, k=top_k)
    match_ids = results[0][0].tolist()
    distances = results[1][0].tolist()

    # Select the matching rows once and read both columns from that subset.
    matches = ds.select(match_ids)
    images = matches["image"]
    captions = matches["name"]
    images_with_captions = list(zip(images, captions))

    # Several matches may share a name; keep the best score for each name
    # instead of letting a later match silently overwrite an earlier one.
    labels_with_probs = {}
    for caption, distance in zip(captions, distances):
        score = 1 - distance
        if caption not in labels_with_probs or score > labels_with_probs[caption]:
            labels_with_probs[caption] = score

    if match_ids:
        description = _describe_butterfly(captions[0])
    else:
        # Nothing matched (e.g. top_k=0): fall back to the generic text.
        description = _FALLBACK_DESCRIPTION
    return images_with_captions, labels_with_probs, description


# Page layout: image upload, search button and description in the narrow left
# column; gallery of matches and their scores in the wider right column.
with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown(
        "## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!"
    )
    with gr.Row():
        with gr.Column(scale=1):
            inputs = gr.Image()
            btn = gr.Button("Find my butterfly!")
            description = gr.Markdown()

        with gr.Column(scale=2):
            outputs = gr.Gallery(rows=2)
            labels = gr.Label()

    gr.Markdown("### Image Examples")
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        # query() returns three values (gallery items, label scores and the
        # description), so all three components are listed here; with only
        # two, picking an example never refreshed the description.
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    # The button feeds the same function into the same three components.
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()