Spaces:
Runtime error
Runtime error
im
commited on
Commit
·
4bb4754
1
Parent(s):
348c0ea
add heatmap and vector search examples
Browse files- app.py +34 -3
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -444,7 +444,6 @@ with st.expander("References:"):
|
|
| 444 |
|
| 445 |
divider()
|
| 446 |
st.header("Embeddings")
|
| 447 |
-
st.caption("TBD...")
|
| 448 |
|
| 449 |
st.write("""\
|
| 450 |
Following tokenization, each token is transformed into a vector of numeric characteristics, a process
|
|
@@ -473,9 +472,11 @@ st.write("""\
|
|
| 473 |
characteristics using numbers, not words.
|
| 474 |
""")
|
| 475 |
|
|
|
|
|
|
|
| 476 |
col1, col2 = st.columns(2)
|
| 477 |
-
token_king = col1.text_input("Choose
|
| 478 |
-
token_queen = col2.text_input("Choose
|
| 479 |
|
| 480 |
from torch import nn
|
| 481 |
from transformers import AutoConfig
|
|
@@ -516,8 +517,38 @@ fig.update_layout(legend=dict(orientation="h"))
|
|
| 516 |
st.plotly_chart(fig, use_container_width=True)
|
| 517 |
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
|
|
|
| 521 |
|
| 522 |
with st.expander("References:"):
|
| 523 |
st.write("""\
|
|
|
|
| 444 |
|
| 445 |
divider()
|
| 446 |
st.header("Embeddings")
|
|
|
|
| 447 |
|
| 448 |
st.write("""\
|
| 449 |
Following tokenization, each token is transformed into a vector of numeric characteristics, a process
|
|
|
|
| 472 |
characteristics using numbers, not words.
|
| 473 |
""")
|
| 474 |
|
| 475 |
+
# TODO: cache
|
| 476 |
+
|
| 477 |
col1, col2 = st.columns(2)
|
| 478 |
+
token_king = col1.text_input("Choose a word to compare embeddings:", value="king")
|
| 479 |
+
token_queen = col2.text_input("Choose a word to compare embeddings:", value="queen")
|
| 480 |
|
| 481 |
from torch import nn
|
| 482 |
from transformers import AutoConfig
|
|
|
|
| 517 |
st.plotly_chart(fig, use_container_width=True)
|
| 518 |
|
| 519 |
|
| 520 |
+
import numpy as np
|
| 521 |
+
|
| 522 |
+
sentence = st.text_input(label="words to explore embeddings", value="a the king queen space sit eat from on")
|
| 523 |
+
sentence = sentence.split()
|
| 524 |
+
|
| 525 |
+
def get_embeddings(text):
|
| 526 |
+
return np.array(openai.Embedding.create(input=text, model=EMBEDDING_MODEL)["data"][0]["embedding"])
|
| 527 |
+
|
| 528 |
+
input = {word: get_embeddings(word) for word in sentence}
|
| 529 |
+
|
| 530 |
+
scores_matrix = np.zeros((len(sentence), len(sentence)))
|
| 531 |
+
for i, word_i in enumerate(sentence):
|
| 532 |
+
for j, word_j in enumerate(sentence):
|
| 533 |
+
scores_matrix[i, j] = np.dot(input[word_i], input[word_j])
|
| 534 |
+
|
| 535 |
+
fig = px.imshow(scores_matrix, x=sentence, y=sentence, color_continuous_scale="hot_r")
|
| 536 |
+
fig.update_layout(coloraxis_showscale=False)
|
| 537 |
+
fig.update_layout(width=6000, title_text='Similar words have similar embeddings')
|
| 538 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 539 |
+
|
| 540 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
| 541 |
+
from langchain.vectorstores import FAISS
|
| 542 |
+
from langchain.schema.document import Document
|
| 543 |
+
db = FAISS.from_documents([Document(page_content="king"), Document(page_content="queen")], OpenAIEmbeddings(model=EMBEDDING_MODEL))
|
| 544 |
|
| 545 |
+
embeddings_query = st.text_input(label="search term")
|
| 546 |
+
if embeddings_query is not None and embeddings_query != '':
|
| 547 |
+
embedding_vector = OpenAIEmbeddings(model=EMBEDDING_MODEL).embed_query(embeddings_query)
|
| 548 |
+
docs = db.similarity_search_by_vector(embedding_vector)
|
| 549 |
+
st.write(docs[0].page_content)
|
| 550 |
|
| 551 |
+
st.caption("PCA explanation (optional materials) TBD...")
|
| 552 |
|
| 553 |
with st.expander("References:"):
|
| 554 |
st.write("""\
|
requirements.txt
CHANGED
|
@@ -3,4 +3,7 @@ tokenizers~=0.13.3
|
|
| 3 |
transformers~=4.31.0
|
| 4 |
torch~=2.0.1
|
| 5 |
openai~=0.27.8
|
| 6 |
-
plotly~=5.15.0
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
transformers~=4.31.0
|
| 4 |
torch~=2.0.1
|
| 5 |
openai~=0.27.8
|
| 6 |
+
plotly~=5.15.0
|
| 7 |
+
langchain~=0.0.242
|
| 8 |
+
faiss-cpu~=1.7.4
|
| 9 |
+
tiktoken~=0.4.0
|