Spaces:
Build error
Build error
danseith
committed on
Commit
·
f023836
1
Parent(s):
0d6ff2f
Updated rules to ignore punctuation
Browse files
- app.py +12 -7
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
|
|
|
| 4 |
from nltk.stem import PorterStemmer
|
| 5 |
from collections import defaultdict
|
| 6 |
from transformers import pipeline
|
|
@@ -32,7 +33,8 @@ tab_two_examples = [[ex_str1, ex_key1],
|
|
| 32 |
# ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
| 33 |
# ]
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
def add_mask(text, lower_bound=0, index=None):
|
|
@@ -49,7 +51,7 @@ def add_mask(text, lower_bound=0, index=None):
|
|
| 49 |
idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
|
| 50 |
# Don't mask certain words
|
| 51 |
num_iters = 0
|
| 52 |
-
while split_text[idx].lower() in
|
| 53 |
num_iters += 1
|
| 54 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
| 55 |
if num_iters > 10:
|
|
@@ -220,8 +222,9 @@ def extract_keywords(text, queries):
|
|
| 220 |
# Iterate through text and mask each token
|
| 221 |
ps = PorterStemmer()
|
| 222 |
top_scores = defaultdict(list)
|
| 223 |
-
top_k_range =
|
| 224 |
-
|
|
|
|
| 225 |
for i in indices:
|
| 226 |
masked_text, masked = add_mask(text, index=i)
|
| 227 |
res = scrambler(masked_text, temp=temp, top_k=top_k_range)
|
|
@@ -229,12 +232,14 @@ def extract_keywords(text, queries):
|
|
| 229 |
sorted_keys = sorted(out, key=out.get)
|
| 230 |
# If the key does not appear, floor its rank for that round
|
| 231 |
for rank, token_str in enumerate(sorted_keys):
|
|
|
|
|
|
|
| 232 |
stemmed = ps.stem(token_str)
|
| 233 |
-
if token_str not in top_scores.keys():
|
| 234 |
-
top_scores[stemmed].append(0)
|
| 235 |
norm_rank = rank / top_k_range
|
| 236 |
top_scores[stemmed].append(norm_rank)
|
| 237 |
-
|
|
|
|
|
|
|
| 238 |
# Calc mean
|
| 239 |
for key in top_scores.keys():
|
| 240 |
top_scores[key] = np.mean(top_scores[key])
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
+
import re
|
| 5 |
from nltk.stem import PorterStemmer
|
| 6 |
from collections import defaultdict
|
| 7 |
from transformers import pipeline
|
|
|
|
| 33 |
# ['The _ plane is composed of a two-dimensional hexagonal lattice of carbon atoms.']
|
| 34 |
# ]
|
| 35 |
|
| 36 |
+
ignore_str = ['a', 'an', 'the', 'is', 'and', 'or', '!', '(', ')', '-', '[', ']', '{', '}', ';', ':', "'", '"', '\\',
|
| 37 |
+
',', '<', '>', '.', '/', '?', '@', '#', '$', '%', '^', '&', '*', '_', '~']
|
| 38 |
|
| 39 |
|
| 40 |
def add_mask(text, lower_bound=0, index=None):
|
|
|
|
| 51 |
idx = np.random.randint(low=lower_bound, high=len(split_text), size=1).astype(int)[0]
|
| 52 |
# Don't mask certain words
|
| 53 |
num_iters = 0
|
| 54 |
+
while split_text[idx].lower() in ignore_str:
|
| 55 |
num_iters += 1
|
| 56 |
idx = np.random.randint(len(split_text), size=1).astype(int)[0]
|
| 57 |
if num_iters > 10:
|
|
|
|
| 222 |
# Iterate through text and mask each token
|
| 223 |
ps = PorterStemmer()
|
| 224 |
top_scores = defaultdict(list)
|
| 225 |
+
top_k_range = 30
|
| 226 |
+
text_no_punc = re.sub(r'[^\w\s]', '', text)
|
| 227 |
+
indices = [i for i, t in enumerate(text_no_punc.split()) if t.lower() == query.lower()]
|
| 228 |
for i in indices:
|
| 229 |
masked_text, masked = add_mask(text, index=i)
|
| 230 |
res = scrambler(masked_text, temp=temp, top_k=top_k_range)
|
|
|
|
| 232 |
sorted_keys = sorted(out, key=out.get)
|
| 233 |
# If the key does not appear, floor its rank for that round
|
| 234 |
for rank, token_str in enumerate(sorted_keys):
|
| 235 |
+
if token_str in ignore_str:
|
| 236 |
+
continue
|
| 237 |
stemmed = ps.stem(token_str)
|
|
|
|
|
|
|
| 238 |
norm_rank = rank / top_k_range
|
| 239 |
top_scores[stemmed].append(norm_rank)
|
| 240 |
+
for key in top_scores.keys():
|
| 241 |
+
if key not in out.keys():
|
| 242 |
+
top_scores[key].append(0)
|
| 243 |
# Calc mean
|
| 244 |
for key in top_scores.keys():
|
| 245 |
top_scores[key] = np.mean(top_scores[key])
|
requirements.txt
CHANGED
|
@@ -2,4 +2,5 @@ gradio
|
|
| 2 |
torch
|
| 3 |
transformers
|
| 4 |
numpy
|
| 5 |
-
nltk
|
|
|
|
|
|
| 2 |
torch
|
| 3 |
transformers
|
| 4 |
numpy
|
| 5 |
+
nltk
|
| 6 |
+
re
|