---
license: apache-2.0
language:
  - sl
  - hr
  - sr
base_model:
  - facebook/w2v-bert-2.0
pipeline_tag: audio-classification
metrics:
  - f1
---

Frame classification for filled pauses

This model classifies individual 20 ms frames of audio based on the presence of filled pauses ("eee", "errm", ...).

It was trained on the human-annotated Slovenian speech corpus ROG-Artur and achieves an F1 of 0.95 for the positive class on the test split of the same dataset.

Evaluation

Although the model outputs a series of 0s and 1s, one for each 20 ms frame, the evaluation was done at the event level: spans of consecutive 1s were merged into a single event. When a true and a predicted event at least partially overlap, this is counted as a true positive. We report precision, recall, and F1 score for the positive class.
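A minimal sketch of this event-level scoring follows. The helper names (`frames_to_events`, `overlaps`, `event_scores`) are hypothetical, and the matching rule is an assumption: the card only states that partial overlap counts as a true positive, so here a predicted event counts as correct if it overlaps any true event, and a true event counts as found if any predicted event overlaps it.

```python
from itertools import groupby


def frames_to_events(frames: list[int]) -> list[tuple[int, int]]:
    """Merge runs of consecutive 1s into (start, end) frame spans, end exclusive."""
    events = []
    pos = 0
    for label, run in groupby(frames):
        length = sum(1 for _ in run)
        if label == 1:
            events.append((pos, pos + length))
        pos += length
    return events


def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """True if two half-open frame spans share at least one frame."""
    return a[0] < b[1] and b[0] < a[1]


def event_scores(true_frames: list[int], pred_frames: list[int]) -> tuple[float, float, float]:
    """Event-level precision, recall and F1, counting partial overlap as a hit."""
    true_events = frames_to_events(true_frames)
    pred_events = frames_to_events(pred_frames)
    # Assumed rule: a predicted event is correct if it overlaps any true event.
    tp_pred = sum(any(overlaps(p, t) for t in true_events) for p in pred_events)
    # Assumed rule: a true event is found if any predicted event overlaps it.
    tp_true = sum(any(overlaps(t, p) for p in pred_events) for t in true_events)
    precision = tp_pred / len(pred_events) if pred_events else 0.0
    recall = tp_true / len(true_events) if true_events else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


y_true = [0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]
y_pred = [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0]
print(event_scores(y_true, y_pred))
```

In this toy example the first predicted event only partially overlaps the first true event but still counts as a hit, while the second true event is missed and the last predicted event is a false positive, giving precision, recall, and F1 of 0.5 each.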

We observed several failure modes of the automatic inference process and designed post-processing steps to mitigate them. Some false positives were caused by improper audio segmentation, so discarding predictions that start at the very beginning of the audio or end at its very end can be beneficial. Another failure mode is the prediction of very short events, which can be safely discarded.
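The sketch below shows how these filters could be combined into the post-processing variants named in the evaluation tables. The function `postprocess` and the `VARIANTS` mapping are hypothetical; in particular, reading `drop_short_initial_and_final` as "all three filters on" is an assumption based on the names. The 0.08 s cutoff is the default used by `frames_to_intervals` in the example code.

```python
FRAME_S = 0.02  # each model output frame covers 20 ms

# Assumed mapping from the postprocessing names in the tables to filter flags.
VARIANTS = {
    "raw": dict(drop_short=False, drop_initial=False, drop_final=False),
    "drop_short": dict(drop_short=True, drop_initial=False, drop_final=False),
    "drop_initial": dict(drop_short=False, drop_initial=True, drop_final=False),
    "drop_short_and_initial": dict(drop_short=True, drop_initial=True, drop_final=False),
    "drop_short_initial_and_final": dict(drop_short=True, drop_initial=True, drop_final=True),
}


def postprocess(
    events: list[tuple[int, int]],
    n_frames: int,
    drop_short: bool,
    drop_initial: bool,
    drop_final: bool,
    short_cutoff_s: float = 0.08,
) -> list[tuple[int, int]]:
    """Filter (start, end) frame spans (end exclusive) of predicted events."""
    kept = []
    for start, end in events:
        # Small epsilon guards against floating-point error at the cutoff.
        if drop_short and (end - start) * FRAME_S < short_cutoff_s - 1e-9:
            continue  # very short prediction: often a false positive
        if drop_initial and start == 0:
            continue  # starts at the very beginning: likely a segmentation artifact
        if drop_final and end == n_frames:
            continue  # runs to the very end: likely a segmentation artifact
        kept.append((start, end))
    return kept


events = [(0, 6), (10, 12), (20, 30), (45, 50)]
for name, flags in VARIANTS.items():
    print(name, postprocess(events, n_frames=50, **flags))
```

Filtering on frame indices rather than on rounded timestamps keeps the "starts at 0" and "ends at the end" checks exact.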

Evaluation on ROG corpus

| postprocessing | recall | precision | F1 |
|---|---|---|---|
| raw | 0.981 | 0.955 | 0.968 |
| drop_short | 0.981 | 0.957 | 0.969 |
| drop_short_initial_and_final | 0.964 | 0.966 | 0.965 |
| drop_short_and_initial | 0.964 | 0.966 | 0.965 |
| drop_initial | 0.964 | 0.963 | 0.963 |
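The F1 column is the harmonic mean of precision and recall, F1 = 2 · P · R / (P + R). As a quick consistency check (a sketch, not part of the original evaluation code), the F1 values in the ROG table can be recomputed from the rounded precision and recall values:

```python
# (postprocessing, recall, precision, reported F1), copied from the ROG table.
ROWS = [
    ("raw", 0.981, 0.955, 0.968),
    ("drop_short", 0.981, 0.957, 0.969),
    ("drop_short_initial_and_final", 0.964, 0.966, 0.965),
    ("drop_short_and_initial", 0.964, 0.966, 0.965),
    ("drop_initial", 0.964, 0.963, 0.963),
]


def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


for name, recall, precision, reported in ROWS:
    computed = f1_score(precision, recall)
    # The table values are rounded to three decimals, so expect tiny differences.
    print(f"{name:30s} computed={computed:.4f} reported={reported:.3f}")
```

Every recomputed value agrees with the reported F1 to within 0.001; the remaining differences come from rounding precision and recall to three decimals.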

Evaluation on ParlaSpeech corpora

For every language in the ParlaSpeech collection, 400 instances were sampled and manually annotated.

Evaluation on human-annotated instances produced the following metrics:

| lang | postprocessing | recall | precision | F1 |
|---|---|---|---|---|
| CZ | drop_short_initial_and_final | 0.889 | 0.859 | 0.874 |
| CZ | drop_short_and_initial | 0.889 | 0.859 | 0.874 |
| CZ | drop_short | 0.905 | 0.833 | 0.868 |
| CZ | drop_initial | 0.889 | 0.846 | 0.867 |
| CZ | raw | 0.905 | 0.814 | 0.857 |
| HR | drop_short_initial_and_final | 0.94 | 0.887 | 0.913 |
| HR | drop_short_and_initial | 0.94 | 0.887 | 0.913 |
| HR | drop_short | 0.94 | 0.884 | 0.911 |
| HR | drop_initial | 0.94 | 0.875 | 0.906 |
| HR | raw | 0.94 | 0.872 | 0.905 |
| PL | drop_short | 0.906 | 0.947 | 0.926 |
| PL | drop_short_initial_and_final | 0.903 | 0.947 | 0.924 |
| PL | drop_short_and_initial | 0.903 | 0.947 | 0.924 |
| PL | raw | 0.91 | 0.924 | 0.917 |
| PL | drop_initial | 0.908 | 0.924 | 0.916 |
| RS | drop_short | 0.966 | 0.915 | 0.94 |
| RS | drop_short_initial_and_final | 0.966 | 0.915 | 0.94 |
| RS | drop_short_and_initial | 0.966 | 0.915 | 0.94 |
| RS | drop_initial | 0.974 | 0.9 | 0.936 |
| RS | raw | 0.974 | 0.9 | 0.936 |

The metrics are reported at the event level: if a true and a predicted filled pause at least partially overlap, we count them as a true positive event.

Example use:


import numpy as np
import torch
from datasets import Audio, Dataset
from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "classla/wav2vecbert2-filledPause"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)

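# Replace with the path(s) to your own audio; the Audio feature decodes each
# file and resamples it to 16 kHz mono.
ds = Dataset.from_dict(
    {
        "audio": [
            "/cache/peterr/mezzanine_resources/filled_pauses/data/dev/Iriss-J-Gvecg-P500001-avd_2082.293_2112.194.wav"
        ],
    }
).cast_column("audio", Audio(sampling_rate=16_000, mono=True))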


def frames_to_intervals(
    frames: list[int],
    drop_short: bool = True,
    drop_initial: bool = True,
    drop_final: bool = False,
    short_cutoff_s: float = 0.08,
) -> list[tuple[float, float]]:
    """Transforms a list of ones or zeros, corresponding to annotations on frame
    level, to a list of intervals (start second, end second).

    Allows for additional filtering on duration (false positives are often
    short) and on start/end times (false positives starting at 0.0 or ending
    at the end of the audio are often an artifact of poor segmentation).

    :param list[int] frames: Input frame labels, one per 20 ms frame
    :param bool drop_short: Drop everything shorter than short_cutoff_s,
        defaults to True
    :param bool drop_initial: Drop predictions starting at 0.0, defaults to True
    :param bool drop_final: Drop predictions ending at audio end, defaults to False
    :param float short_cutoff_s: Duration in seconds of shortest allowable
        prediction, defaults to 0.08

    :return list[tuple[float, float]]: List of intervals (start_s, end_s)
    """
    from itertools import groupby

    frame_s = 0.020  # duration of one output frame in seconds
    n_frames = len(frames)

    # Collect runs of consecutive 1s as (start, end) frame indices, end
    # exclusive, including a run that lasts until the final frame.
    spans = []
    pos = 0
    for label, run in groupby(int(f) for f in frames):
        length = sum(1 for _ in run)
        if label == 1:
            spans.append((pos, pos + length))
        pos += length

    # Filter on frame indices to avoid comparing rounded float timestamps.
    if drop_short:
        spans = [(s, e) for s, e in spans if (e - s) * frame_s >= short_cutoff_s - 1e-9]
    if drop_initial:
        spans = [(s, e) for s, e in spans if s != 0]
    if drop_final:
        spans = [(s, e) for s, e in spans if e != n_frames]

    # Convert frame indices to seconds.
    return [(round(s * frame_s, 3), round(e * frame_s, 3)) for s, e in spans]


def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits
    y_pred = np.array(logits.cpu()).argmax(axis=-1)
    intervals = [frames_to_intervals(i) for i in y_pred]
    return {"y_pred": y_pred.tolist(), "intervals": intervals}


ds = ds.map(evaluator, batched=True)
print(ds["y_pred"][0])
# Prints a list of 20 ms frame labels: [0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,0....]
# with 1 indicating a filled pause detected in that frame and 0 indicating none

print(ds["intervals"][0])
# Prints the identified intervals as a list of [start_s, end_s]:
# [[0.08, 0.28], ...]

Citation

Coming soon.