import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from huggingface_hub import login
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import os
import torch
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')

# Log in to Hugging Face only if a token is available
if huggingface_token:
    login(token=huggingface_token)
# Model setup
@st.cache_resource  # cache the endpoint and tokenizer so they are not rebuilt on every Streamlit rerun
def load_llm():
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        task="text-generation"
    )
    llm_engine_hf = ChatHuggingFace(llm=llm)
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    return llm_engine_hf, tokenizer

llm_engine_hf, tokenizer = load_llm()
# Classification model setup
@st.cache_resource  # load the legal-document classifier once and reuse it across reruns
def load_classification_model():
    tokenizer = AutoTokenizer.from_pretrained("mrm8488/legal-longformer-base-8192-spanish")
    model = AutoModelForSequenceClassification.from_pretrained("mrm8488/legal-longformer-base-8192-spanish")
    return model, tokenizer

classification_model, classification_tokenizer = load_classification_model()

id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}
def classify_text(text):
    # Tokenize the document, run the classifier, and map the argmax logit to its label
    inputs = classification_tokenizer(text, return_tensors="pt", max_length=4096, truncation=True, padding="max_length")
    classification_model.eval()
    with torch.no_grad():
        outputs = classification_model(**inputs)
    logits = outputs.logits
    predicted_class_id = logits.argmax(dim=-1).item()
    predicted_label = id2label[predicted_class_id]
    return f"Clasificación: {predicted_label}\n\nDocumento:\n{text}"
def translate(text, target_language):
    template = '''
Por favor, traduzca el siguiente documento al {LANGUAGE}:
<document>
{TEXT}
</document>
Asegúrese de que la traducción sea precisa y conserve el significado original del documento.
'''
    # str.replace() is used instead of str.format() so braces inside the document text are left untouched
    formatted_prompt = template.replace("{TEXT}", text).replace("{LANGUAGE}", target_language)
    # ChatHuggingFace.invoke() takes the raw prompt string; no manual tokenization is needed
    outputs = llm_engine_hf.invoke(formatted_prompt)
    return outputs.content
def summarize(text, length):
    template = f'''
Por favor, haga un resumen {length} del siguiente documento:
<document>
{text}
</document>
Asegúrese de que el resumen sea conciso y conserve el significado original del documento.
'''
    outputs = llm_engine_hf.invoke(template)
    return outputs.content
def handle_uploaded_file(uploaded_file):
    # Extract plain text from the uploaded file based on its extension
    try:
        if uploaded_file.name.endswith(".txt"):
            text = uploaded_file.read().decode("utf-8")
        elif uploaded_file.name.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            text = ""
            for page in reader.pages:
                text += page.extract_text() or ""  # extract_text() can return None for empty pages
        elif uploaded_file.name.endswith(".docx"):
            doc = Document(uploaded_file)
            text = "\n".join(para.text for para in doc.paragraphs)
        elif uploaded_file.name.endswith(".csv"):
            content = uploaded_file.read().decode("utf-8").splitlines()
            reader = csv.reader(content)
            text = " ".join(" ".join(row) for row in reader)
        elif uploaded_file.name.endswith(".json"):
            data = json.load(uploaded_file)
            text = json.dumps(data, indent=4)
        else:
            text = "Tipo de archivo no soportado."
        return text
    except Exception as e:
        return str(e)
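
# NOTE: generate_response() is used by the "Enviar" button below but was not defined
# anywhere in this file. This is a minimal sketch of the missing function, assuming
# the chatbot should simply forward the user's message to the Mistral chat endpoint.
def generate_response(prompt):
    # invoke() returns an AIMessage; its .content attribute holds the generated text
    response = llm_engine_hf.invoke(prompt)
    return response.content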
st.title("LexAIcon")
st.write("Puedes conversar con este chatbot basado en Mistral7B-Instruct y subir archivos para que el chatbot los procese.")

if "generated" not in st.session_state:
    st.session_state["generated"] = []
# User input
user_input = st.text_input("Tú: ", "")

# Translation options
target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])

# Summary options
summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])

# Uploaded file handling
uploaded_files = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True)

if st.button("Enviar"):
    if user_input:
        response = generate_response(user_input)
        st.session_state["generated"].append({"user": user_input, "bot": response})
# Summarize, Translate, and Explain operations
operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])

if st.button("Ejecutar"):
    if uploaded_files:
        for uploaded_file in uploaded_files:
            file_content = handle_uploaded_file(uploaded_file)
            if operation == "Resumir":
                if summary_length == "corto":
                    length = "de aproximadamente 50 palabras"
                elif summary_length == "medio":
                    length = "de aproximadamente 100 palabras"
                elif summary_length == "largo":
                    length = "de aproximadamente 500 palabras"
                result = summarize(file_content, length)
            elif operation == "Traducir":
                result = translate(file_content, target_language)
            elif operation == "Explicar":
                # "Explicar" currently runs the legal-document classifier rather than a free-form explanation
                result = classify_text(file_content)
            st.write(result)
if st.session_state.get("generated"):
    for chat in st.session_state["generated"]:
        st.write(f"Tú: {chat['user']}")
        st.write(f"Chatbot: {chat['bot']}")