Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
·
9b9ffd0
1
Parent(s):
bdcc521
Fix prediction and evaluation arguments
Browse files- src/evaluate.py +3 -4
- src/predict.py +15 -7
- src/train.py +1 -3
src/evaluate.py
CHANGED
|
@@ -7,7 +7,7 @@ from transformers import (
|
|
| 7 |
from preprocess import DatasetArguments, ProcessedArguments, get_words
|
| 8 |
from model import get_classifier_vectorizer
|
| 9 |
from shared import device
|
| 10 |
-
from predict import ClassifierArguments,
|
| 11 |
from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
|
| 12 |
import pandas as pd
|
| 13 |
from dataclasses import dataclass, field
|
|
@@ -19,7 +19,7 @@ import random
|
|
| 19 |
|
| 20 |
|
| 21 |
@dataclass
|
| 22 |
-
class EvaluationArguments:
|
| 23 |
"""
|
| 24 |
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
| 25 |
"""
|
|
@@ -29,8 +29,7 @@ class EvaluationArguments:
|
|
| 29 |
'help': 'The number of videos to test on'
|
| 30 |
}
|
| 31 |
)
|
| 32 |
-
|
| 33 |
-
'model_path']
|
| 34 |
data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
|
| 35 |
dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
|
| 36 |
'validation_file']
|
|
|
|
| 7 |
from preprocess import DatasetArguments, ProcessedArguments, get_words
|
| 8 |
from model import get_classifier_vectorizer
|
| 9 |
from shared import device
|
| 10 |
+
from predict import ClassifierArguments, predict, filter_predictions, TrainingOutputArguments
|
| 11 |
from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
|
| 12 |
import pandas as pd
|
| 13 |
from dataclasses import dataclass, field
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@dataclass
|
| 22 |
+
class EvaluationArguments(TrainingOutputArguments):
|
| 23 |
"""
|
| 24 |
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
| 25 |
"""
|
|
|
|
| 29 |
'help': 'The number of videos to test on'
|
| 30 |
}
|
| 31 |
)
|
| 32 |
+
|
|
|
|
| 33 |
data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
|
| 34 |
dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
|
| 35 |
'validation_file']
|
src/predict.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from shared import OutputArguments
|
| 2 |
from typing import Optional
|
| 3 |
from segment import (
|
|
@@ -21,7 +22,6 @@ from dataclasses import dataclass, field
|
|
| 21 |
from transformers import HfArgumentParser
|
| 22 |
from shared import device
|
| 23 |
import logging
|
| 24 |
-
from transformers.trainer_utils import get_last_checkpoint
|
| 25 |
|
| 26 |
|
| 27 |
def seconds_to_time(seconds):
|
|
@@ -31,12 +31,7 @@ def seconds_to_time(seconds):
|
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass
|
| 34 |
-
class
|
| 35 |
-
|
| 36 |
-
video_id: str = field(
|
| 37 |
-
metadata={
|
| 38 |
-
'help': 'Video to predict sponsorship segments for'}
|
| 39 |
-
)
|
| 40 |
|
| 41 |
model_path: str = field(
|
| 42 |
default=None,
|
|
@@ -59,6 +54,15 @@ class PredictArguments:
|
|
| 59 |
'Unable to find model, explicitly set `--model_path`')
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
|
| 63 |
|
| 64 |
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
|
@@ -252,6 +256,10 @@ def main():
|
|
| 252 |
))
|
| 253 |
predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
|
| 256 |
model.to(device())
|
| 257 |
|
|
|
|
| 1 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 2 |
from shared import OutputArguments
|
| 3 |
from typing import Optional
|
| 4 |
from segment import (
|
|
|
|
| 22 |
from transformers import HfArgumentParser
|
| 23 |
from shared import device
|
| 24 |
import logging
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def seconds_to_time(seconds):
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass
|
| 34 |
+
class TrainingOutputArguments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
model_path: str = field(
|
| 37 |
default=None,
|
|
|
|
| 54 |
'Unable to find model, explicitly set `--model_path`')
|
| 55 |
|
| 56 |
|
| 57 |
+
@dataclass
|
| 58 |
+
class PredictArguments(TrainingOutputArguments):
|
| 59 |
+
video_id: str = field(
|
| 60 |
+
default=None,
|
| 61 |
+
metadata={
|
| 62 |
+
'help': 'Video to predict sponsorship segments for'}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
|
| 67 |
|
| 68 |
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
|
|
|
| 256 |
))
|
| 257 |
predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
| 258 |
|
| 259 |
+
if predict_args.video_id is None:
|
| 260 |
+
print('No video ID supplied. Use `--video_id`.')
|
| 261 |
+
return
|
| 262 |
+
|
| 263 |
model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
|
| 264 |
model.to(device())
|
| 265 |
|
src/train.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
from preprocess import load_datasets, DatasetArguments
|
| 2 |
from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
|
| 3 |
-
from shared import device
|
| 4 |
-
from shared import GeneralArguments, OutputArguments
|
| 5 |
from model import ModelArguments
|
| 6 |
import transformers
|
| 7 |
-
import logging
|
| 8 |
from model import get_model, get_tokenizer
|
| 9 |
import logging
|
| 10 |
import os
|
|
|
|
| 1 |
from preprocess import load_datasets, DatasetArguments
|
| 2 |
from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
|
| 3 |
+
from shared import device, GeneralArguments, OutputArguments
|
|
|
|
| 4 |
from model import ModelArguments
|
| 5 |
import transformers
|
|
|
|
| 6 |
from model import get_model, get_tokenizer
|
| 7 |
import logging
|
| 8 |
import os
|