from typing import Dict

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)


class NewsPipeline:
    """Run three text classifiers over a news article.

    The instance wraps three ``TextClassificationPipeline`` objects built from
    the ``elozano/news-category``, ``elozano/news-fake`` and
    ``elozano/news-clickbait`` checkpoints. Calling the instance returns the
    top label reported by each of them.

    Attributes:
        category_tokenizer: Tokenizer of the category checkpoint; also
            supplies the separator token used to join headline and content.
        category_pipeline: Classifier fed with headline + content.
        fake_tokenizer: Tokenizer of the fake-news checkpoint; also supplies
            the separator token used to join headline and content.
        fake_pipeline: Classifier fed with headline + content.
        clickbait_pipeline: Classifier fed with the headline only.
    """

    def __init__(self) -> None:
        """Load the tokenizers and models for all three classifiers."""
        # The category and fake tokenizers are kept as attributes because
        # __call__ needs their separator tokens to assemble the input text.
        self.category_tokenizer = AutoTokenizer.from_pretrained(
            "elozano/news-category"
        )
        self.category_pipeline = self._build_pipeline(
            "elozano/news-category", self.category_tokenizer
        )

        self.fake_tokenizer = AutoTokenizer.from_pretrained("elozano/news-fake")
        self.fake_pipeline = self._build_pipeline(
            "elozano/news-fake", self.fake_tokenizer
        )

        # The clickbait classifier only sees the headline, so its tokenizer
        # does not need to be reachable afterwards.
        self.clickbait_pipeline = self._build_pipeline(
            "elozano/news-clickbait",
            AutoTokenizer.from_pretrained("elozano/news-clickbait"),
        )

    @staticmethod
    def _build_pipeline(checkpoint: str, tokenizer) -> TextClassificationPipeline:
        """Return a classification pipeline for *checkpoint* using *tokenizer*."""
        return TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
            tokenizer=tokenizer,
        )

    @staticmethod
    def _join_article(tokenizer, headline: str, content: str) -> str:
        """Join headline and content with the tokenizer's separator token."""
        return f" {tokenizer.sep_token} ".join([headline, content])

    @staticmethod
    def _top_label(pipeline: TextClassificationPipeline, text: str) -> str:
        """Return the ``label`` of the first result *pipeline* gives for *text*."""
        # truncation=True is forwarded to the tokenizer so that inputs longer
        # than the model's maximum length are cut instead of making the
        # forward pass fail; inputs that already fit are unaffected.
        return pipeline(text, truncation=True)[0]["label"]

    def __call__(self, headline: str, content: str) -> Dict[str, str]:
        """Classify one article.

        Args:
            headline: The article's headline.
            content: The article's body text.

        Returns:
            A dict with the keys ``"category"``, ``"fake"`` and
            ``"clickbait"``, each mapped to the label produced by the
            corresponding pipeline. The category and fake classifiers receive
            the headline and content joined by their own tokenizer's separator
            token; the clickbait classifier receives the headline alone.
        """
        category_article_text = self._join_article(
            self.category_tokenizer, headline, content
        )
        fake_article_text = self._join_article(
            self.fake_tokenizer, headline, content
        )
        return {
            "category": self._top_label(
                self.category_pipeline, category_article_text
            ),
            "fake": self._top_label(self.fake_pipeline, fake_article_text),
            "clickbait": self._top_label(self.clickbait_pipeline, headline),
        }