|
|
|
|
|
""" |
|
|
Example inference script for DeBERTa v3 Small Explicit Content Classifier v2.0 |
|
|
""" |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
|
import torch |
|
|
|
|
|
def load_classifier(model_path="."):
    """Load the model and tokenizer and wrap them in a classification pipeline.

    Args:
        model_path: Path (or model identifier) passed to ``from_pretrained``
            for both the sequence-classification model and its tokenizer.
            Defaults to the current directory.

    Returns:
        A ``transformers`` "text-classification" pipeline that scores every
        label (``top_k=None``) and truncates over-long inputs.

        NOTE(review): for a single string input, the pipeline may wrap the
        per-label list in an extra outer list, depending on the transformers
        version. ``classify_text`` normalizes this.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # ``return_all_scores=True`` is deprecated; ``top_k=None`` is the documented
    # replacement. It returns a score for every label, sorted by descending
    # score, instead of in raw label-id order.
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        top_k=None,
        truncation=True,  # forwarded to the tokenizer so long texts don't overflow
    )
    return classifier
|
|
|
|
|
def _normalize_results(raw):
    """Flatten a text-classification output into a list sorted by score.

    Accepts a bare ``{"label", "score"}`` dict, a flat list of such dicts, or
    a list wrapping one flat list. For a single string input, the pipeline
    adds that outer list. Returns a new list sorted by descending score; the
    input is not mutated.
    """
    if isinstance(raw, dict):
        raw = [raw]
    # Unwrap the one-item outer list the pipeline adds for a single string.
    if raw and isinstance(raw[0], list):
        raw = raw[0]
    # Legacy ``return_all_scores=True`` output is in label order, not score
    # order; sort so element 0 is really the top prediction.
    return sorted(raw, key=lambda r: r["score"], reverse=True)


def classify_text(classifier, text, show_all_scores=True, threshold=None):
    """Classify *text*, print a readable report, and return the label scores.

    Args:
        classifier: Callable (e.g. the pipeline from ``load_classifier``) that
            maps a string to label/score dicts in any shape accepted by
            ``_normalize_results``.
        text: String to classify; only its first 100 characters are echoed.
        show_all_scores: When true, print the score of every label.
        threshold: Optional cut-off. When not ``None``, also list labels whose
            score is strictly greater than it (a threshold of 0 is honored).

    Returns:
        List of ``{"label", "score"}`` dicts sorted by descending score, or an
        empty list if the classifier returned nothing.
    """
    results = _normalize_results(classifier(text))

    preview = text[:100] + ("..." if len(text) > 100 else "")
    print(f'\nText: "{preview}"')
    print("-" * 60)

    if not results:
        print("No predictions returned.")
        return results

    top_prediction = results[0]
    print(f"🎯 Prediction: {top_prediction['label']} ({top_prediction['score']:.3f})")

    if show_all_scores:
        print("\n📊 All Class Probabilities:")
        for result in results:
            score = result["score"]
            if score > 0.7:
                marker = "🔥"
            elif score > 0.5:
                marker = "✅"
            else:
                marker = "⚪"
            print(f" {marker} {result['label']:<20}: {score:.3f}")

    # ``is not None`` rather than truthiness so threshold=0 still reports.
    if threshold is not None:
        print(f"\n⚠️ Above threshold ({threshold}):")
        for result in results:
            if result["score"] > threshold:
                print(f" {result['label']}: {result['score']:.3f}")

    return results
|
|
|
|
|
def main():
    """Run the demo: classify built-in examples, then start an interactive loop.

    The model is loaded with ``load_classifier()`` (current directory). The
    interactive loop ends on "quit", "exit" or "q" (case-insensitive), on end
    of input (EOF, e.g. Ctrl-D or a closed pipe), or on Ctrl-C.
    """
    print("🚀 DeBERTa v3 Small Explicit Content Classifier v2.0")
    print("=" * 60)

    print("Loading model...")
    classifier = load_classifier()

    test_examples = [
        "The morning sun cast long shadows across the peaceful meadow where children played.",
        "His fingers traced gentle patterns on her skin as she whispered his name.",
        "Content warning: This story contains mature themes including violence and sexual content.",
        "She gasped as he pulled her close, their bodies pressed together in desperate passion.",
        "The detective found the victim's body in a pool of blood, throat slashed.",
        "'Damn it,' he muttered, frustration evident in his voice.",
        "They shared a tender kiss under the starlit sky, hearts beating as one.",
    ]

    for text in test_examples:
        classify_text(classifier, text, show_all_scores=False)
        print()

    print("\n" + "=" * 60)
    print("Interactive Mode - Enter text to classify (or 'quit' to exit):")

    while True:
        try:
            user_text = input("\n📝 Enter text: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted (piped input,
            # Ctrl-D) and KeyboardInterrupt on Ctrl-C; leave the loop cleanly
            # instead of dumping a traceback.
            print()
            break

        if user_text.lower() in ("quit", "exit", "q"):
            break

        if user_text:
            classify_text(classifier, user_text, show_all_scores=True, threshold=0.3)
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()