File size: 2,778 Bytes
6495706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# app.py
# Minimal Gradio app to classify music genre using a Hugging Face pretrained model.
# Paste into a Hugging Face Space (Gradio) or run locally.

from transformers import pipeline
import gradio as gr
import os
import math

# Load a pretrained music genre classifier from the HF Hub.
# Model used here: ccmusic-database/music_genre (public). The hosted model
# might change over time, so predictions are not guaranteed to stay stable.
MODEL_ID = "ccmusic-database/music_genre"

# Initialize the pipeline once at import time so every classification request
# reuses the same instance (this may take a few seconds on first load).
classifier = pipeline("audio-classification", model=MODEL_ID)

def pretty_results(result_list, top_k=5):
    """Format classifier predictions as a numbered, human-readable list.

    Args:
        result_list: Sequence of prediction dicts, usually shaped like
            ``[{'label': 'rock', 'score': 0.72}, ...]``. An entry without a
            ``label`` is shown as ``"unknown"``; an entry without a ``score``
            is shown as ``0.0%``. ``None`` or an empty sequence yields ``""``.
        top_k: Maximum number of leading entries to include.

    Returns:
        A string with one line per prediction, in the form
        ``"<rank>. <label> <em dash> <percent>%"``, joined with newlines.
    """
    # Nothing to format: return an empty string rather than failing on the slice.
    if not result_list:
        return ""

    # Written as an escape so the separator survives any re-encoding of this
    # source file (the previous literal had been corrupted into mojibake).
    em_dash = "\u2014"
    lines = [
        f"{rank}. {entry.get('label', 'unknown')} {em_dash} "
        f"{entry.get('score', 0.0) * 100:.1f}%"
        for rank, entry in enumerate(result_list[:top_k], start=1)
    ]
    return "\n".join(lines)

def classify_audio(audio_file):
    """
    Classify one audio clip and report its top predicted genres.

    ``audio_file`` is the value handed over by the Gradio audio input --
    typically a file path string, or None when nothing was supplied. It is
    forwarded unchanged to the module-level HF ``classifier`` pipeline.

    Returns a ``(text, scores)`` pair: ``text`` is the formatted top-5 list
    (or an explanatory message), and ``scores`` maps label -> float score,
    or is None when there was no input or inference failed.
    """
    # Guard: nothing was uploaded or recorded.
    if audio_file is None:
        return "No audio provided.", None

    # Any failure inside the pipeline is reported to the UI as text
    # instead of propagating out of the callback.
    try:
        predictions = classifier(audio_file, top_k=5)
    except Exception as err:
        return f"Model inference failed: {err}", None

    summary = pretty_results(predictions, top_k=5)

    # label -> score mapping for the probabilities widget
    label_scores = {}
    for entry in predictions:
        label_scores[entry.get("label", "unknown")] = float(entry.get("score", 0.0))

    return summary, label_scores

# Build the Gradio UI: a heading, an input row (audio + button), two output
# widgets, and a short notes section.
with gr.Blocks() as demo:
    # "\u2014" is an em dash, written as an escape so the title survives any
    # re-encoding of this file (the previous literal had turned into mojibake).
    gr.Markdown("# 🎧 Genre Guessr \u2014 Upload a song, get a genre")
    gr.Markdown("Free hosted: Hugging Face Spaces + pretrained model. Expect ~10 common genres (e.g. rock, pop, jazz).")

    with gr.Row():
        # type="filepath" makes the component hand classify_audio a path string.
        audio_in = gr.Audio(label="Upload song (MP3/WAV/OGG) or record", type="filepath")
        classify_btn = gr.Button("Classify")

    # Read-only text list of the top predictions.
    output_text = gr.Textbox(label="Top predictions", interactive=False)
    # Receives the label -> score dict returned by classify_audio.
    output_scores = gr.Label(num_top_classes=5, label="Probabilities")

    # classify_audio returns (text, scores), matching the two outputs in order.
    classify_btn.click(fn=classify_audio, inputs=audio_in, outputs=[output_text, output_scores])

    gr.Markdown(
        """
        **Notes:**  
        - This uses a public pretrained model fine-tuned on common datasets (GTZAN-style 10-genre set). Expect mistakes on short clips, remixes or genre-blends.  
        - If you want more genres or better accuracy, we can swap to a bigger model or fine-tune later.
        """
    )

# Launch the app only when this file is run as a script; importing the module
# builds `demo` but does not call launch().
if __name__ == "__main__":
    demo.launch()