Spaces:
Runtime error
Runtime error
plotting perplexity vs text length with text hover
Browse files
app.py
CHANGED
|
@@ -1,32 +1,48 @@
|
|
| 1 |
import dash
|
| 2 |
-
|
| 3 |
from dash import dcc, html
|
| 4 |
from dash.dependencies import Input, Output
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
# Create dash app
|
| 9 |
app = dash.Dash(__name__)
|
| 10 |
|
| 11 |
-
# Set dog and cat images
|
| 12 |
-
dogImage = "https://www.iconexperience.com/_img/v_collection_png/256x256/shadow/dog.png"
|
| 13 |
-
catImage = "https://d2ph5fj80uercy.cloudfront.net/06/cat3602.jpg"
|
| 14 |
-
|
| 15 |
-
# Generate dataframe
|
| 16 |
-
df = pd.DataFrame(
|
| 17 |
-
dict(
|
| 18 |
-
x=[1, 2],
|
| 19 |
-
y=[2, 4],
|
| 20 |
-
images=[dogImage, catImage],
|
| 21 |
-
)
|
| 22 |
-
)
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# Update layout and update traces
|
| 28 |
fig.update_layout(clickmode='event+select')
|
| 29 |
-
fig.update_traces(marker_size=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Create app layout to show dash graph
|
| 32 |
app.layout = html.Div(
|
|
@@ -35,14 +51,14 @@ app.layout = html.Div(
|
|
| 35 |
id="graph_interaction",
|
| 36 |
figure=fig,
|
| 37 |
),
|
| 38 |
-
html.
|
| 39 |
]
|
| 40 |
)
|
| 41 |
|
| 42 |
|
| 43 |
# html callback function to hover the data on specific coordinates
|
| 44 |
@app.callback(
|
| 45 |
-
Output('
|
| 46 |
Input('graph_interaction', 'hoverData'))
|
| 47 |
def open_url(hoverData):
|
| 48 |
if hoverData:
|
|
@@ -52,4 +68,4 @@ def open_url(hoverData):
|
|
| 52 |
|
| 53 |
|
| 54 |
if __name__ == '__main__':
|
| 55 |
-
app.run_server(port=7860, host="0.0.0.0", debug=True
|
|
|
|
| 1 |
import dash
|
| 2 |
+
import plotly.express as px
|
| 3 |
from dash import dcc, html
|
| 4 |
from dash.dependencies import Input, Output
|
| 5 |
+
from dash.exceptions import PreventUpdate
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
|
| 8 |
# Create dash app
|
| 9 |
app = dash.Dash(__name__)
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
def get_dataset(name, n_items=1000):
    """Load a shuffled sample of ``ola13/small-<name>-dedup`` as a DataFrame.

    Args:
        name: Short dataset name interpolated into the hub path.
        n_items: Maximum number of rows to keep after shuffling.

    Returns:
        A ``(dataset, max_perp)`` tuple. ``dataset`` holds only the
        ``text``, ``perplexity`` and ``text_length`` columns, sorted by
        ascending perplexity; ``max_perp`` is the largest perplexity in
        the sample.
    """
    ola_path = f"ola13/small-{name}-dedup"
    shuffled = load_dataset(ola_path, split="train").shuffle()

    # Cap the sample size: selecting indices past the end of a split that
    # is shorter than n_items would raise instead of returning what exists.
    sample_size = min(n_items, len(shuffled))
    dataset = shuffled.select(range(sample_size)).to_pandas()

    # Vectorised character count instead of a row-wise apply().
    dataset["text_length"] = dataset["text"].str.len()

    # Keep only the columns the plot needs, preserving their original order.
    wanted = ("text", "perplexity", "text_length")
    kept_columns = [column for column in dataset.columns if column in wanted]
    dataset = dataset[kept_columns].sort_values("perplexity")

    max_perp = dataset["perplexity"].max()
    return dataset, max_perp
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Dataset to visualise. Names listed by the original author:
# "oscar", "the_pile", "c4", "roots_en".
name = "oscar"
df, max_perplexity = get_dataset(name)

# Scatter plot of perplexity against text length; custom_data attaches each
# row's text to its point.
fig = px.scatter(
    df,
    x="perplexity",
    y="text_length",
    custom_data=["text"],
)

# Layout, marker and axis tweaks. Both axes use a log scale.
fig.update_layout(clickmode='event+select').update_traces(marker_size=3)
fig.update_xaxes(type="log", title_text="Perplexity (log scale)")
fig.update_yaxes(type="log", title_text="Text Length (log scale)")
|
| 38 |
+
|
| 39 |
+
# Inline styles for the layout, keyed by the element they are applied to.
styles = {
    'textbox': {
        'border': 'thin lightgrey solid',
        'overflowX': 'scroll',
        # No trailing semicolon: inline style values are passed through as
        # raw CSS values, and "pre-wrap;" is rejected as invalid, which
        # silently disables the wrapping behaviour.
        "whiteSpace": "pre-wrap",
    }
}
|
| 46 |
|
| 47 |
# Create app layout to show dash graph
|
| 48 |
app.layout = html.Div(
|
|
|
|
| 51 |
id="graph_interaction",
|
| 52 |
figure=fig,
|
| 53 |
),
|
| 54 |
+
html.Div(id='text', style=styles['textbox'])
|
| 55 |
]
|
| 56 |
)
|
| 57 |
|
| 58 |
|
| 59 |
# html callback function to hover the data on specific coordinates
|
| 60 |
@app.callback(
|
| 61 |
+
Output('text', 'children'),
|
| 62 |
Input('graph_interaction', 'hoverData'))
|
| 63 |
def open_url(hoverData):
|
| 64 |
if hoverData:
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
if __name__ == '__main__':
    # Prefer ``app.run``: newer Dash releases replace ``run_server`` with it.
    # Fall back to ``run_server`` only when ``run`` is unavailable, so older
    # installations keep working.
    launch = app.run if hasattr(app, "run") else app.run_server
    # NOTE(review): debug=True together with host 0.0.0.0 exposes the Dash
    # dev tools to every client that can reach the port; confirm this is
    # intended for a publicly hosted app.
    launch(port=7860, host="0.0.0.0", debug=True)
|