From 215d7abb5fa4d5f60f212814b81064f1d19bebac Mon Sep 17 00:00:00 2001 From: Marigold Date: Mon, 9 Dec 2024 18:53:42 +0100 Subject: [PATCH] wip --- apps/wizard/app_pages/indicator_search/app.py | 1 + apps/wizard/app_pages/similar_charts/app.py | 207 ++++++++++++++++++ apps/wizard/app_pages/similar_charts/data.py | 60 +++++ .../app_pages/similar_charts/scoring.py | 90 ++++++++ 4 files changed, 358 insertions(+) create mode 100644 apps/wizard/app_pages/similar_charts/app.py create mode 100644 apps/wizard/app_pages/similar_charts/data.py create mode 100644 apps/wizard/app_pages/similar_charts/scoring.py diff --git a/apps/wizard/app_pages/indicator_search/app.py b/apps/wizard/app_pages/indicator_search/app.py index 2ae4bc0ff62..26dcd587bff 100644 --- a/apps/wizard/app_pages/indicator_search/app.py +++ b/apps/wizard/app_pages/indicator_search/app.py @@ -26,6 +26,7 @@ def st_display_indicators(indicators: list[dict]): df = pd.DataFrame(indicators) + # TODO: make the link dynamic df["link"] = df.apply(lambda x: f"http://staging-site-indicator-search/admin/variables/{x['variableId']}/", axis=1) df["catalogPath"] = df["catalogPath"].str.replace("grapher/", "") diff --git a/apps/wizard/app_pages/similar_charts/app.py b/apps/wizard/app_pages/similar_charts/app.py new file mode 100644 index 00000000000..b170beedd57 --- /dev/null +++ b/apps/wizard/app_pages/similar_charts/app.py @@ -0,0 +1,207 @@ +import random +import re +import time + +import pandas as pd +import streamlit as st +import torch +from sentence_transformers import SentenceTransformer, util +from structlog import get_logger + +from apps.wizard.app_pages.insight_search import embeddings as emb +from apps.wizard.app_pages.similar_charts import data, scoring +from apps.wizard.utils import cached, set_states, url_persist +from apps.wizard.utils.components import Pagination, st_horizontal, st_multiselect_wider, tag_in_md +from etl.config import OWID_ENV + +DEVICE = "cpu" + +# Initialize log. 
log = get_logger()

# PAGE CONFIG
st.set_page_config(
    page_title="Wizard: Similar Charts",
    page_icon="🪄",
    layout="wide",
)

########################################################################################################################
# FUNCTIONS
########################################################################################################################


def _chart_field(chart, name: str):
    """Return field `name` of `chart`, whether it is a mapping or an object with attributes."""
    if isinstance(chart, dict):
        return chart[name]
    return getattr(chart, name)


def st_display_chart(chart, show_score=True):
    """Render one chart as a bordered Streamlit card.

    Args:
        chart: Chart to display. The page passes `Chart` dataclass instances; plain dicts with the
            same keys ("title", "subtitle", "tags", "slug", "similarity") are accepted as well.
        show_score: If True, a second column shows the chart's similarity and a score table.

    Returns:
        None. The function only writes Streamlit elements.
    """
    # Tags are stored as a single ";"-separated string; a missing value means no tags.
    raw_tags = _chart_field(chart, "tags")
    tags = raw_tags.split(";") if raw_tags else []

    # TODO: fix this URL
    chart_url = OWID_ENV.chart_site(_chart_field(chart, "slug"))

    with st.container(border=True):
        col1, col2 = st.columns(2)
        with col1:
            # Single quotes inside the f-strings: reusing the outer double quote is a syntax error
            # on Python versions before 3.12.
            st.markdown(f"#### [{_chart_field(chart, 'title')}]({chart_url})")
            st.markdown(f"Subtitle: {_chart_field(chart, 'subtitle')}")
            st.markdown(f"Tags: **{', '.join(tags)}**")
        if show_score:
            with col2:
                st.markdown(f"#### Similarity: {_chart_field(chart, 'similarity'):.0%}")
                # NOTE: the per-component values below are hard-coded placeholders, not real scores.
                st.table(
                    pd.Series(
                        {
                            "Tags": 0.1,
                            "Semantic": 0.2,
                        }
                    )
                    .to_frame("score")
                    .style.format("{:.0%}")
                )
                # TODO: Add scoring information here
                # st.write("Tags: +10%")
                # st.write("Semantic: +20%")

    return


def split_input_string(input_string: str) -> tuple[str, list[str], list[str]]:
    """Break input string into query, includes and excludes.

    Whitespace-separated terms starting with "+" go to `includes`, terms starting with "-" go to
    `excludes` (both lower-cased, prefix removed); all remaining terms are joined with single spaces
    to form the query.
    """
    query = []
    includes = []
    excludes = []
    for term in input_string.split():
        if term.startswith("+"):
            includes.append(term[1:].lower())
        elif term.startswith("-"):
            excludes.append(term[1:].lower())
        else:
            query.append(term)

    return " ".join(query), includes, excludes


def indicator_query(indicator: dict) -> str:
    """Concatenate an indicator's name, description and catalog path (if any) into one text."""
    return indicator["name"] + " " + indicator["description"] + " " + (indicator["catalogPath"] or "")


def chart_text(chart: dict) -> str:
    """Return the text representing a chart; currently just its title."""
    return chart["title"]


########################################################################################################################
# Get embedding model.
MODEL = emb.get_model()
# Load the charts that can be selected and compared.
charts = data.get_charts()

# Build the scoring model. These initial weights are overwritten later in this script, once the
# weight sliders have been read (see the `set_weights` call).
scoring_model = scoring.ScoringModel(MODEL, weights={"title": 1.0, "subtitle": 1e-9})

# Compute the chart embeddings used for scoring.
scoring_model.fit(charts)

########################################################################################################################


########################################################################################################################
# RENDER
########################################################################################################################

# Streamlit app layout.
st.title(":material/search: Similar charts")

# Box for input text: the user identifies the reference chart by slug or by numeric ID.
chart_slug_or_id = st.text_input(
    label="Chart slug or ID",
    placeholder="Type something...",
    value="human-trafficking-victims-over-18-years-old-male-vs-female",
    help="Keep it empty to get a random chart.",
)

st_multiselect_wider()
with st_horizontal():
    # Placeholder: no filter widgets are defined yet.
    pass


if chart_slug_or_id == "":
    # Empty input: pick a random chart.
    chosen_chart = random.sample(charts, 1)[0]
else:
    # Find the first chart whose slug equals the input, or whose ID (compared as a string) does.
    chosen_chart = next(
        (chart for chart in charts if chart.slug == chart_slug_or_id or str(chart.chart_id) == chart_slug_or_id),
        None,
    )
    if not chosen_chart:
        # Nothing matched: show an error and halt this script run.
        st.error(f"Chart with slug or ID '{chart_slug_or_id}' not found.")
        st.stop()

# NOTE(review): commented-out call that refers to indicators rather than charts — presumably a
# leftover from the indicator search page; confirm and remove.
# sorted_inds = emb.get_sorted_documents_by_similarity(MODEL, query, docs=indicators, embeddings=embeddings)  # type: ignore

# Display chosen chart (without a similarity score, since it is the reference).
st_display_chart(chosen_chart, show_score=False)

# Advanced expander: keep its open/closed state in the session, defaulting to closed.
st.session_state.sim_charts_expander_advanced_options = st.session_state.get(
    "sim_charts_expander_advanced_options", False
)

# Scores.
# Default weights of the different score components follow.
st.session_state.w_title = st.session_state.get("w_title", 1.0)
st.session_state.w_subtitle = st.session_state.get("w_subtitle", 1e-9)
st.session_state.w_tags = st.session_state.get("w_tags", 1e-9)
st.session_state.w_pageviews = st.session_state.get("w_pageviews", 1e-9)


with st.expander("Advanced options", expanded=st.session_state.sim_charts_expander_advanced_options):
    for score_name in ["title", "subtitle", "tags", "pageviews"]:
        # For some reason, if the slider minimum value is zero, streamlit raises an error when the slider is
        # dragged to the minimum. Set it to a small, non-zero number.
        url_persist(st.slider)(
            f"Weight for {score_name} score",
            min_value=1e-9,
            max_value=1.0,
            # step=0.001,
            key=f"w_{score_name}",
        )

# Hand the (possibly user-adjusted) weights to the scoring model.
scoring_model.set_weights(
    {
        "title": st.session_state.w_title,
        "subtitle": st.session_state.w_subtitle,
        "tags": st.session_state.w_tags,
        "pageviews": st.session_state.w_pageviews,
    }
)

# Horizontal divider
st.markdown("---")

# Similarity of every chart to the chosen one, keyed by chart ID.
similarity_dict = scoring_model.similarity(chosen_chart)

for chart in charts:
    chart.similarity = similarity_dict[chart.chart_id]

# Most similar charts first.
sorted_charts = sorted(charts, key=lambda x: x.similarity, reverse=True)


# Use pagination
items_per_page = 10
pagination = Pagination(
    items=sorted_charts,
    items_per_page=items_per_page,
    pagination_key=f"pagination-di-search-{chart_slug_or_id}",
)

if len(charts) > items_per_page:
    pagination.show_controls(mode="bar")

# Show items (only current page)
for item in pagination.get_page_items():
    # Don't show the chosen chart
    if item.slug == chosen_chart.slug:
        continue
    st_display_chart(item)

# diff --git a/apps/wizard/app_pages/similar_charts/data.py b/apps/wizard/app_pages/similar_charts/data.py
# new file mode 100644
# index 00000000000..98be6ff31c2
# --- /dev/null
# +++ b/apps/wizard/app_pages/similar_charts/data.py
# @@ -0,0 +1,60 @@
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

import pandas as pd
import streamlit as st

from etl.db import read_sql


@dataclass
class Chart:
    """A chart as used by the similar-charts page."""

    chart_id: int
    title: str
    # Read from the chart config JSON; may be missing.
    subtitle: Optional[str]
    # ";"-separated tag names; None when the chart has no tags (the query left-joins tags).
    tags: Optional[str]
    slug: str
    # Filled in by the app after scoring. It needs a default because the database query does not
    # return this column.
    similarity: Optional[float] = None


def get_raw_charts() -> pd.DataFrame:
    """Get all charts that exist in the database.

    Returns:
        DataFrame with columns `chartId`, `slug`, `title`, `subtitle`, `note` and `tags`
        (tag names joined with ";").
    """
    # NOTE: the `where` clause restricts the result to charts with 'human' in the title (test filter).
    query = """
    with tags as (
        select
            ct.chartId as chart_id,
            -- t.name as tag_name,
            -- t.slug as tag_slug,
            group_concat(t.name separator ';') as tags
        from chart_tags as ct
        join tags as t on ct.tagId = t.id
        group by 1
    )
    select
        c.id as chartId,
        cf.slug,
        cf.full->>'$.title' as title,
        cf.full->>'$.subtitle' as subtitle,
        cf.full->>'$.note' as note,
        t.tags
    from charts as c
    join chart_configs as cf on c.configId = cf.id
    left join tags as t on c.id = t.chart_id
    -- test it on charts with 'human' in the title
    where lower(cf.full->>'$.title') like '%%human%%'
    """
    df = read_sql(query)

    return df


@st.cache_data(show_spinner=False, persist="disk", max_entries=1)
def get_charts() -> list[Chart]:
    """Load charts from the database and convert them to `Chart` objects.

    The raw query returns `chartId` and an extra `note` column, neither of which is a `Chart`
    field, so columns are renamed and filtered before the dataclass is constructed; passing the
    raw rows to `Chart(**row)` raises TypeError.
    """
    with st.spinner("Loading charts..."):
        # Get charts from the database.
        df = get_raw_charts()

    df = df.rename(columns={"chartId": "chart_id"})

    # Keep only the columns that are actual Chart fields (drops e.g. `note`).
    chart_fields = [f.name for f in fields(Chart) if f.name in df.columns]
    df = df[chart_fields]

    # Missing values may come back as NaN; normalise them to None so that truthiness checks such
    # as `chart.subtitle or ""` behave as intended.
    df = df.astype(object).where(df.notna(), None)

    return [Chart(**row) for row in df.to_dict(orient="records")]  # type: ignore

# diff --git a/apps/wizard/app_pages/similar_charts/scoring.py b/apps/wizard/app_pages/similar_charts/scoring.py
# new file mode 100644
# index 00000000000..67a41293cb9
# --- /dev/null
# +++ b/apps/wizard/app_pages/similar_charts/scoring.py
# @@ -0,0 +1,90 @@
import time
from dataclasses import dataclass
from typing import Optional

import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from structlog import get_logger

from apps.wizard.app_pages.insight_search import embeddings as emb
from apps.wizard.app_pages.similar_charts.data import Chart

DEVICE = "cpu"

# Initialize log.
log = get_logger()


class ScoringModel:
    """Scores how similar charts are to a given chart, using title and subtitle embeddings.

    Usage: construct with an embedding model, call `fit` with the candidate charts, optionally
    adjust weights with `set_weights`, then call `similarity` with the reference chart.
    """

    model: SentenceTransformer
    chart_ids: list[int]
    embeddings: dict[str, torch.Tensor]

    def __init__(self, model: SentenceTransformer, weights: Optional[dict[str, float]]) -> None:
        """Store the model and weights.

        Args:
            model: Sentence embedding model used to encode titles and subtitles.
            weights: Weight per score component ("title", "subtitle"). None means equal weights.
        """
        self.model = model
        self.weights = weights
        # Create the containers here; a class-level annotation alone does not create the attribute,
        # so `fit` would otherwise fail when assigning into `self.embeddings`.
        self.chart_ids = []
        self.embeddings = {}

    def fit(self, charts: list[Chart]):
        """Compute and store title and subtitle embeddings for `charts`."""
        self.chart_ids = [c.chart_id for c in charts]

        # Create an embedding for each chart.
        self.embeddings["title"] = get_chart_embeddings(
            self.model, [c.title for c in charts], model_name="sim_charts_title"
        )
        # Missing subtitles are embedded as the empty string, the same way `similarity` treats the
        # reference chart.
        self.embeddings["subtitle"] = get_chart_embeddings(
            self.model, [c.subtitle or "" for c in charts], model_name="sim_charts_subtitle"
        )

    def set_weights(self, weights: dict[str, float]):
        """Replace the component weights used by `similarity`."""
        self.weights = weights

    def similarity(self, chart: Chart) -> dict[int, float]:
        """Return the weighted similarity of every fitted chart to `chart`.

        Args:
            chart: Reference chart.

        Returns:
            Mapping from chart ID to the weighted average of the title and subtitle scores.

        Raises:
            RuntimeError: If `fit` has not been called.
            ValueError: If the title and subtitle weights do not sum to a positive number.
        """
        if "title" not in self.embeddings or "subtitle" not in self.embeddings:
            raise RuntimeError("ScoringModel.fit() must be called before similarity()")

        # Use the model's own weights (set in the constructor or via `set_weights`) rather than
        # reading UI state, so the model also works outside a Streamlit session.
        weights = self.weights if self.weights is not None else {"title": 1.0, "subtitle": 1.0}
        w_title = weights.get("title", 0.0)
        w_subtitle = weights.get("subtitle", 0.0)
        total_w = w_title + w_subtitle
        if total_w <= 0:
            raise ValueError("Title and subtitle weights must sum to a positive number")

        log.info("calculate_similarity.start", n_docs=len(self.chart_ids))
        t = time.time()

        title_embedding = self.model.encode(chart.title, convert_to_tensor=True, device=DEVICE)
        # TODO: Missing subtitle should be treated differently
        subtitle_embedding = self.model.encode(chart.subtitle or "", convert_to_tensor=True, device=DEVICE)

        title_scores = _get_score(title_embedding, self.embeddings["title"])
        subtitle_scores = _get_score(subtitle_embedding, self.embeddings["subtitle"])

        # Weighted average over all charts at once (the scores are arrays aligned with chart_ids).
        combined = (w_title * title_scores + w_subtitle * subtitle_scores) / total_w
        ret = {chart_id: float(score) for chart_id, score in zip(self.chart_ids, combined)}

        log.info("calculate_similarity.end", t=time.time() - t)

        return ret

    def similarity_components(self, chart: Chart):
        """Not implemented yet; currently returns None."""
        pass


# Compute the cosine similarity between the input and each document.
def _get_score(input_embedding, embeddings, typ="cosine"):
    """Score each row of `embeddings` against `input_embedding`.

    Args:
        input_embedding: Embedding of the reference text.
        embeddings: Embeddings of the documents to score, one per row.
        typ: "cosine" rescales the cosine similarity with (x + 1) / 2; "euclidean" applies
            1 / (1 - s) to `util.euclidean_sim` (presumably the negative distance — confirm
            against the sentence-transformers docs).

    Returns:
        NumPy array holding the first column of the score matrix.

    Raises:
        ValueError: If `typ` is neither "cosine" nor "euclidean".
    """
    if typ == "cosine":
        score = util.pytorch_cos_sim(embeddings, input_embedding)
        score = (score + 1) / 2
    elif typ == "euclidean":
        # distance = torch.cdist(embeddings, input_embedding)
        score = util.euclidean_sim(embeddings, input_embedding)
        score = 1 / (1 - score)  # Normalize to [0, 1]
    else:
        raise ValueError(f"Invalid similarity type: {typ}")

    return score.cpu().numpy()[:, 0]


# TODO: memoization would be very expensive
# The underscore-prefixed parameters are excluded from Streamlit's cache key, so entries are keyed by
# `model_name` only. This function is called with two names (title and subtitle embeddings); with a
# single cache entry each call would evict the other and nothing would ever be served from cache.
# NOTE(review): because the texts are not part of the key, a changed chart list with the same
# `model_name` returns the previously cached embeddings — confirm this is acceptable.
@st.cache_data(show_spinner=False, max_entries=2)
def get_chart_embeddings(_model, _indicators_texts: list[str], model_name: str) -> torch.Tensor:
    """Return embeddings of `_indicators_texts` computed with `_model`, cached per `model_name`."""
    with st.spinner("Generating embeddings..."):
        return emb.get_embeddings(_model, _indicators_texts, model_name=model_name)  # type: ignore