Spaces:
Build error
Build error
danseith
commited on
Commit
·
c16370c
1
Parent(s):
2ce1788
Added edit slider and changed sampling back to multinomial.
Browse files
app.py
CHANGED
|
@@ -6,12 +6,16 @@ from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
|
|
| 6 |
from transformers import AutoModelForMaskedLM
|
| 7 |
|
| 8 |
# unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
def add_mask(text, size=1):
|
| 14 |
split_text = text.split()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
idx = np.random.randint(len(split_text), size=size)
|
| 16 |
for i in idx:
|
| 17 |
split_text[i] = '[MASK]'
|
|
@@ -114,32 +118,38 @@ PIPELINE_REGISTRY.register_pipeline(
|
|
| 114 |
)
|
| 115 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
textbox = gr.Textbox(label="Type language here", lines=5)
|
| 135 |
-
textbox2 = gr.Textbox(placeholder="
|
| 136 |
-
temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='
|
|
|
|
| 137 |
|
| 138 |
demo = gr.Interface(
|
| 139 |
fn=unmask,
|
| 140 |
-
inputs=[textbox, temp_slider],
|
| 141 |
-
outputs=[
|
| 142 |
-
examples=
|
| 143 |
)
|
| 144 |
|
| 145 |
demo.launch()
|
|
|
|
| 6 |
from transformers import AutoModelForMaskedLM
|
| 7 |
|
| 8 |
# unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 9 |
+
examples = [['A crustless [MASK] made from two slices of baked bread.', 1.2],
|
| 10 |
+
['The invention provides a method for altering or modifying [MASK] of one or more gene products.', 1.1],
|
| 11 |
+
['The graphite [MASK] is composed of a two-dimensional hexagonal lattice of carbon atoms.', 1.4]]
|
| 12 |
|
| 13 |
def add_mask(text, size=1):
|
| 14 |
split_text = text.split()
|
| 15 |
+
|
| 16 |
+
# If the user supplies a mask, don't add more
|
| 17 |
+
if '[MASK]' in split_text:
|
| 18 |
+
return text
|
| 19 |
idx = np.random.randint(len(split_text), size=size)
|
| 20 |
for i in idx:
|
| 21 |
split_text[i] = '[MASK]'
|
|
|
|
| 118 |
)
|
| 119 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
| 120 |
|
| 121 |
+
|
| 122 |
+
def unmask(text, temp, rounds):
|
| 123 |
+
sampling = 'multi'
|
| 124 |
+
|
| 125 |
+
for _ in range(rounds):
|
| 126 |
+
text = add_mask(text, size=1)
|
| 127 |
+
split_text = text.split()
|
| 128 |
+
res = scrambler(text, temp=temp, top_k=10)
|
| 129 |
+
mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
|
| 130 |
+
out = {item["token_str"]: item["score"] for item in res}
|
| 131 |
+
score_to_str = {out[k]:k for k in out.keys()}
|
| 132 |
+
score_list = list(score_to_str.keys())
|
| 133 |
+
if sampling == 'multi':
|
| 134 |
+
idx = np.argmax(np.random.multinomial(1, score_list, 1))
|
| 135 |
+
else:
|
| 136 |
+
idx = np.random.randint(0, len(score_list))
|
| 137 |
+
score = score_list[idx]
|
| 138 |
+
new_token = score_to_str[score]
|
| 139 |
+
split_text[mask_pos] = new_token
|
| 140 |
+
text = ' '.join(split_text)
|
| 141 |
+
return text
|
| 142 |
|
| 143 |
textbox = gr.Textbox(label="Type language here", lines=5)
|
| 144 |
+
textbox2 = gr.Textbox(placeholder="", lines=4)
|
| 145 |
+
temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='Creativity')
|
| 146 |
+
edit_slider = gr.Slider(1, 50, step=1, value=1.0, label='Number of edits')
|
| 147 |
|
| 148 |
demo = gr.Interface(
|
| 149 |
fn=unmask,
|
| 150 |
+
inputs=[textbox, temp_slider, edit_slider],
|
| 151 |
+
outputs=[textbox2],
|
| 152 |
+
examples=examples,
|
| 153 |
)
|
| 154 |
|
| 155 |
demo.launch()
|