# app.py — Gradio emotion-classification demo (commit 6c24e56, by EagleOfEmpire)
import gradio as gr
import tensorflow as tf
import numpy as np
import pickle
import torch
from transformers import AutoTokenizer, AutoModel
# additional imports needed for text preprocessing
import re
import string
import emoji
import pymorphy2
import joblib
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
# ---------------------------
# LOAD BERT (Russian BERT used as a frozen text-feature extractor)
# ---------------------------
MODEL_NAME = 'sberbank-ai/ruBert-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)
# ---------------------------
# LOAD SCALER AND KERAS MODEL
# ---------------------------
# The scaler was serialized with joblib; loaded from an open binary handle.
with open("scaler.joblib", "rb") as f:
    scaler = joblib.load(f)
# compile=False: inference only, so the saved optimizer/loss config is not needed.
keras_model = tf.keras.models.load_model("tf.keras", compile=False)
# Class labels: fear, anger, sadness, joy (Russian).
# NOTE(review): presumably ordered to match the model's output units — confirm.
EMOTIONS = ["страх", "гнев", "грусть", "радость"]
# ---------------------------------------
# EMOJI-PROCESSING FUNCTIONS
# ---------------------------------------
def remove_duplicate_emojis(text):
    """Collapse each run of the same emoji into a single occurrence.

    Example: "привет 😂😂😂" -> "привет 😂". Non-emoji characters are
    left untouched. Replaces the previous pass-through stub.
    """
    # Common emoji codepoint ranges: flags, main emoji planes,
    # misc symbols/dingbats, misc symbols & arrows, VS-16.
    emoji_class = ('[\U0001F1E6-\U0001F1FF\U0001F300-\U0001FAFF'
                   '\u2600-\u27BF\u2B00-\u2BFF\uFE0F]')
    # Backreference \1 matches immediate repeats of the captured emoji.
    return re.sub('(' + emoji_class + ')' + r'\1+', r'\1', str(text))
def is_emoji_spam(text, threshold=0.5):
    """Return True when emojis dominate the message.

    A message is "emoji spam" when more than *threshold* of its
    non-whitespace characters are emojis. Empty / whitespace-only
    input is never spam. Replaces the previous always-False stub;
    *threshold* is a new keyword with a backward-compatible default.
    """
    s = str(text)
    emoji_class = ('[\U0001F1E6-\U0001F1FF\U0001F300-\U0001FAFF'
                   '\u2600-\u27BF\u2B00-\u2BFF\uFE0F]')
    emoji_count = len(re.findall(emoji_class, s))
    visible = sum(1 for ch in s if not ch.isspace())
    return visible > 0 and emoji_count / visible > threshold
def remove_all_emojis(text):
    """Strip every emoji character from *text*.

    Used on messages flagged as emoji spam. Replaces the previous
    pass-through stub. NOTE: ZWJ (U+200D) joiners inside compound
    emoji sequences are not covered by the class and may remain.
    """
    emoji_class = ('[\U0001F1E6-\U0001F1FF\U0001F300-\U0001FAFF'
                   '\u2600-\u27BF\u2B00-\u2BFF\uFE0F]')
    return re.sub(emoji_class, '', str(text))
# ---------------------------
# TEXT PREPROCESSING
# ---------------------------
_MORPH_CACHE = {}

def _morph_analyzer():
    """Return a shared pymorphy2.MorphAnalyzer, creating it on first use.

    Construction loads dictionaries from disk and is expensive, so it
    must not happen on every preprocess_text() call (the original code
    rebuilt it per call). Returns None when pymorphy2 fails to
    initialise; lemmatization is then skipped.
    """
    if 'morph' not in _MORPH_CACHE:
        try:
            _MORPH_CACHE['morph'] = pymorphy2.MorphAnalyzer()
        except Exception:  # analyzer unavailable -> degrade gracefully
            _MORPH_CACHE['morph'] = None
    return _MORPH_CACHE['morph']

def preprocess_text(text):
    """Normalize raw (Russian) text for the emotion classifier.

    Steps: emoji de-duplication / spam stripping -> lowercase ->
    drop URLs, @mentions, #hashtags -> strip punctuation ->
    demojize remaining emojis (e.g. ':face_with_tears_of_joy:') ->
    drop digits -> tokenize -> filter stop words and short tokens ->
    lemmatize with pymorphy2 (best effort).

    Returns the cleaned tokens joined with single spaces.
    """
    text = remove_duplicate_emojis(text)
    if is_emoji_spam(text):
        text = remove_all_emojis(text)

    text = str(text).lower()
    text = re.sub(r'http\S+|www\S+|https\S+', '', text)  # URLs
    text = re.sub(r'@\w+|#\w+', '', text)                # mentions, hashtags
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = emoji.demojize(text)
    text = re.sub(r'\d+', '', text)

    # NLTK data ('punkt', 'stopwords') may be missing at runtime; fall
    # back instead of crashing. (Was a bare `except:`, which also
    # swallowed KeyboardInterrupt/SystemExit.)
    try:
        tokens = word_tokenize(text, language="russian")
    except Exception:
        tokens = text.split()
    try:
        stop_words = set(stopwords.words('russian'))
    except LookupError:  # stopwords corpus not downloaded
        stop_words = set()

    # Keep alphabetic tokens and demojized emoji codes (':smile:'),
    # dropping stop words and tokens of <= 2 characters.
    tokens = [
        word for word in tokens
        if (word.isalpha() or (word.startswith(':') and word.endswith(':')))
        and word not in stop_words
        and len(word) > 2
    ]

    # Lemmatization is best effort: skip silently if the analyzer is
    # unavailable or a token cannot be parsed (preserves the original
    # swallow-and-continue behavior, minus the bare except).
    morph = _morph_analyzer()
    if morph is not None:
        try:
            tokens = [morph.parse(word)[0].normal_form for word in tokens]
        except Exception:
            pass

    return ' '.join(tokens)