Spaces:
Sleeping
Sleeping
Commit
·
74688de
1
Parent(s):
9e81616
space created
Browse files- .streamlit/config.toml +6 -0
- app.py +94 -0
- assets/logo.svg +1 -0
- requirements.txt +1 -0
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
primaryColor="#FF8000"
|
| 3 |
+
#backgroundColor="#FFFFFF"
|
| 4 |
+
#secondaryBackgroundColor="#F0F2F6"
|
| 5 |
+
#textColor="#262730"
|
| 6 |
+
#font="sans serif"
|
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The GIRT Authors.
|
| 3 |
+
# Lint as: python3
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# This space is built based on AMR-KELEG/ALDi and cis-lmu/GlotLID space.
|
| 7 |
+
# GIRT Space
|
| 8 |
+
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 10 |
+
import streamlit as st
|
| 11 |
+
import base64
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@st.cache_data
|
| 15 |
+
def render_svg(svg):
|
| 16 |
+
"""Renders the given svg string."""
|
| 17 |
+
b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
|
| 18 |
+
html = rf'<p align="center"> <img src="data:image/svg+xml;base64,{b64}", width="40%"/> </p>'
|
| 19 |
+
c = st.container()
|
| 20 |
+
c.write(html, unsafe_allow_html=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@st.cache_resource
|
| 24 |
+
def load_model(model_name):
|
| 25 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 26 |
+
return model
|
| 27 |
+
|
| 28 |
+
@st.cache_resource
|
| 29 |
+
def load_tokenizer(model_name):
|
| 30 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 31 |
+
return tokenizer
|
| 32 |
+
|
| 33 |
+
with st.spinner(text="Please wait while the model is loading...."):
|
| 34 |
+
|
| 35 |
+
model = load_model('nafisehNik/girt-t5-base')
|
| 36 |
+
tokenizer = load_tokenizer('nafisehNik/girt-t5-base')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def compute(sample, num_beams, length_penalty, early_stopping, max_length, min_length):
|
| 40 |
+
|
| 41 |
+
inputs = tokenizer(sample, return_tensors="pt").to('cpu')
|
| 42 |
+
|
| 43 |
+
outputs = model.generate(
|
| 44 |
+
**inputs,
|
| 45 |
+
num_beams=num_beams,
|
| 46 |
+
num_return_sequences=1,
|
| 47 |
+
length_penalty=length_penalty,
|
| 48 |
+
no_repeat_ngram_size=2,
|
| 49 |
+
early_stopping=early_stopping,
|
| 50 |
+
max_length=max_length,
|
| 51 |
+
min_length=min_length).to('cpu')
|
| 52 |
+
|
| 53 |
+
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=False)
|
| 54 |
+
generated_text = generated_texts[0]
|
| 55 |
+
|
| 56 |
+
replace_dict = {
|
| 57 |
+
'\n ': '\n',
|
| 58 |
+
'</s>': '',
|
| 59 |
+
'<pad> ': '',
|
| 60 |
+
'<pad>': '',
|
| 61 |
+
'<unk>': ''
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
postprocess_text = generated_text
|
| 65 |
+
for key, value in replace_dict.items():
|
| 66 |
+
postprocess_text = postprocess_text.replace(key, value)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
return postprocess_text
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
st.markdown("[](https://huggingface.co/spaces/nafisehNik/girt-space?duplicate=true)")
|
| 73 |
+
|
| 74 |
+
render_svg(open("assets/logo.svg").read())
|
| 75 |
+
|
| 76 |
+
tab1, tab2 = st.tabs(["Design GitHub Issue Template", "Manual Prompt"])
|
| 77 |
+
|
| 78 |
+
with tab1:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
with tab2:
|
| 82 |
+
|
| 83 |
+
sent = st.text_input(
|
| 84 |
+
"Sentence:", placeholder="Enter a prompt.", on_change=None
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# TODO: Check if this is needed!
|
| 88 |
+
clicked = st.button("Submit")
|
| 89 |
+
|
| 90 |
+
if sent:
|
| 91 |
+
res = compute(sent, num_beams=2, length_penalty=1.0, early_stopping=True, max_length=300, min_length=20)
|
| 92 |
+
st.code(res, language="python")
|
| 93 |
+
|
| 94 |
+
|
assets/logo.svg
ADDED
|
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.35.0,<4.45.0
|