From 78c84cb1fab0e3d9aca01292eba805761290e35c Mon Sep 17 00:00:00 2001 From: Marigold Date: Mon, 13 Jan 2025 11:45:46 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=89=20Admin=20for=20related=20charts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/cli/__init__.py | 6 + apps/housekeeper/charts.py | 4 +- apps/related_charts/__init__.py | 0 apps/related_charts/cli.py | 204 ++++++++ .../app_pages/insight_search/embeddings.py | 4 +- apps/wizard/app_pages/similar_charts/app.py | 492 +++++++++++++++--- apps/wizard/app_pages/similar_charts/data.py | 33 ++ .../app_pages/similar_charts/scoring.py | 48 +- etl/config.py | 5 +- etl/grapher/model.py | 63 +++ 10 files changed, 758 insertions(+), 101 deletions(-) create mode 100644 apps/related_charts/__init__.py create mode 100644 apps/related_charts/cli.py diff --git a/apps/cli/__init__.py b/apps/cli/__init__.py index ed459f65f33..c405a1bd6d5 100644 --- a/apps/cli/__init__.py +++ b/apps/cli/__init__.py @@ -198,6 +198,12 @@ def cli_back() -> None: "anomalist": "apps.anomalist.cli.cli", }, }, + { + "name": "Related Charts", + "commands": { + "related-charts": "apps.related_charts.cli.cli", + }, + }, ] # Add subgroups (don't modify) + subgroups diff --git a/apps/housekeeper/charts.py b/apps/housekeeper/charts.py index e8d2c12fe7f..edbf6f1d130 100644 --- a/apps/housekeeper/charts.py +++ b/apps/housekeeper/charts.py @@ -214,9 +214,7 @@ def _get_main_message_usage(chart, refs): def send_extra_messages(chart, refs, **kwargs): """Provide more context in the thread""" ## 1/ Similar charts - similar_messages = ( - f"πŸ•΅οΈ <{OWID_ENV.wizard_url}similar_charts?chart_search_text={chart['slug']}| β†’ Explore similar charts>" - ) + similar_messages = f"πŸ•΅οΈ <{OWID_ENV.wizard_url}similar_charts?slug={chart['slug']}| β†’ Explore similar charts>" ## 2/ AI: Chart description, chart edit timeline, suggestion log.info("Getting AI summary...") diff --git a/apps/related_charts/__init__.py 
b/apps/related_charts/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/apps/related_charts/cli.py b/apps/related_charts/cli.py new file mode 100644 index 00000000000..8998cd2a4b6 --- /dev/null +++ b/apps/related_charts/cli.py @@ -0,0 +1,204 @@ +import datetime as dt +from typing import Optional + +import click +import pandas as pd +import structlog +from rich_click.rich_command import RichCommand +from sqlalchemy import text +from tqdm.auto import tqdm + +from apps.wizard.app_pages.similar_charts import data, scoring +from etl import config +from etl.db import get_engine + +config.enable_bugsnag() +log = structlog.get_logger() + + +def load_data(chart_slug: Optional[str]) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Load chart data and coview sessions DataFrame. + + Returns: + charts: DataFrame indexed by slug, containing metadata (like chart_id, views_365d, etc.) + coviews_df: DataFrame with MultiIndex (slug1, slug2) + and columns ['coviews', 'pageviews']. + """ + log.info("Loading chart data...") + charts = data.get_raw_charts().set_index("slug", drop=False) + + # If chart_slug is provided, verify it's in charts + if chart_slug and chart_slug not in charts.index: + log.warning("Chart slug not found in data. 
Exiting.", chart_slug=chart_slug) + return pd.DataFrame(), pd.DataFrame() + + log.info("Loading coview sessions...") + coviews_df = data.get_coviews_sessions( + after_date=str(dt.date.today() - dt.timedelta(days=365)), min_sessions=3 + ).to_frame(name="coviews") + + # If a single chart slug is given, filter for that slug1 only + if chart_slug: + coviews_df = coviews_df[coviews_df.index.get_level_values("slug1") == chart_slug] + + # Filter out any coviews rows whose slug1 isn't in our charts + coviews_df = coviews_df[coviews_df.index.get_level_values("slug1").isin(charts.index)] + + # Add pageviews of slug2 to coviews dataframe + coviews_df["pageviews"] = charts["views_365d"].reindex(coviews_df.index.get_level_values("slug2")).values + + return charts, coviews_df + + +def compute_recommendations( + charts: pd.DataFrame, + coviews_df: pd.DataFrame, + chart_slug: Optional[str], + top: int, + regularization: float, +) -> pd.DataFrame: + """ + Given charts and coview data, compute a DataFrame of recommended pairs: + chosen_chart, related_chart, chartId, relatedChartId, etc. + + The 'score' is computed as: + score = coviews - regularization * pageviews + + Args: + charts: DataFrame of chart metadata. + coviews_df: DataFrame with columns ['coviews', 'pageviews']. + chart_slug: Optional single slug to process; otherwise compute for all slugs. + top: How many top-related charts to retrieve for each slug. + regularization: Factor to penalize high-view charts. + + Returns: + A DataFrame of recommended chart pairs (chosen_chart, related_chart, score, etc.). 
+ """ + # If we failed to load data (e.g., an invalid slug), return empty + if charts.empty or coviews_df.empty: + return pd.DataFrame() + + # Compute the score + coviews_df["score"] = coviews_df["coviews"] - regularization * coviews_df["pageviews"] + + # If a single chart slug is requested, ensure we only keep those rows + if chart_slug: + if chart_slug not in coviews_df.index.get_level_values("slug1"): + log.info("No coview data for this chart slug.", chart_slug=chart_slug) + return pd.DataFrame() + coviews_df = coviews_df.loc[[chart_slug]] + + recommended_rows = [] + + # Group by 'slug1' so each group has all charts related to that slug1 + grouped = coviews_df.groupby(level="slug1", sort=False) + log.info("Calculating related charts...") + + for slug1, group in tqdm(grouped, desc="Calculating related charts"): + top_related = group.sort_values("score", ascending=False).head(top) + for related_slug, score in zip(top_related.index.get_level_values("slug2"), top_related["score"]): + recommended_rows.append({"chosen_chart": slug1, "related_chart": related_slug, "score": score}) + + if not recommended_rows: + return pd.DataFrame() + + # Build the recommendations DataFrame + recommended_df = pd.DataFrame(recommended_rows) + recommended_df["chartId"] = recommended_df["chosen_chart"].map(charts["chart_id"]) + recommended_df["relatedChartId"] = recommended_df["related_chart"].map(charts["chart_id"]) + recommended_df["label"] = "good" + recommended_df["reviewer"] = "production" + + # Warn if some related_chart slugs can't be mapped to chartIds + ix_missing = recommended_df["relatedChartId"].isnull() + if ix_missing.any(): + log.warning("Chart ID not found for some related chart slugs.", n_missing=ix_missing.sum()) + recommended_df = recommended_df[~ix_missing] + + return recommended_df + + +def write_recommendations( + engine, recommended_df: pd.DataFrame, charts: pd.DataFrame, chart_slug: Optional[str] +) -> None: + """ + Writes the recommended DataFrame to the 
'related_charts' table in the database. + If 'chart_slug' is specified, only deletes existing rows for that slug before inserting. + Otherwise, clears all 'production' rows first. + """ + if recommended_df.empty: + log.info("No related charts found. Nothing to write.") + return + + with engine.begin() as conn: + if chart_slug: + log.info("Deleting existing 'production' reviews for this chart.", chart_slug=chart_slug) + conn.execute( + text(""" + DELETE FROM related_charts + WHERE reviewer = 'production' AND chartId = :chartId + """), + {"chartId": charts.loc[chart_slug, "chart_id"]}, + ) + else: + log.info("Deleting all existing 'production' reviews.") + conn.execute(text("DELETE FROM related_charts WHERE reviewer = 'production'")) + + log.info("Inserting new related chart records.", rows=len(recommended_df)) + recommended_df[["chartId", "relatedChartId", "label", "reviewer", "score"]].to_sql( + "related_charts", con=conn, if_exists="append", index=False + ) + + +@click.command(name="related-charts", cls=RichCommand, help=__doc__) +@click.option( + "--chart-slug", + type=str, + help="Get related charts only for the chart with this slug.", +) +@click.option( + "--top", + type=int, + default=6, + help="Pick the top N related charts.", +) +@click.option( + "--regularization", type=float, default=0.001, help="Factor by which to penalize charts with high pageviews." +) +@click.option( + "--dry-run/--no-dry-run", + default=False, + help="If set, no changes will be written to the database.", +) +def cli(chart_slug: Optional[str], top: int, regularization: float, dry_run: bool) -> None: + """ + Generates a table of related charts (by coviews) and optionally writes them + to the database. If a single chart slug is provided, only that chart’s + related charts will be generated. + """ + engine = get_engine() + + # 1. Load data (no score calculated here) + charts, coviews_df = load_data(chart_slug) + + # 2. 
Compute recommendations (score is applied here) + recommended_df = compute_recommendations(charts, coviews_df, chart_slug, top, regularization) + + if recommended_df.empty: + log.info("No recommendations generated. Exiting.") + return + + # 3. Dry-run check + if dry_run: + log.info("Dry run mode enabled. No changes will be written to the database.") + log.info("Recommended DataFrame preview:", data=recommended_df.head()) + return + + # 4. Otherwise, write to DB + write_recommendations(engine, recommended_df, charts, chart_slug) + log.info("Related charts updated successfully.") + + +if __name__ == "__main__": + cli() diff --git a/apps/wizard/app_pages/insight_search/embeddings.py b/apps/wizard/app_pages/insight_search/embeddings.py index ddd31562743..b288a23ce7f 100644 --- a/apps/wizard/app_pages/insight_search/embeddings.py +++ b/apps/wizard/app_pages/insight_search/embeddings.py @@ -5,14 +5,12 @@ import streamlit as st import torch -from joblib import Memory from sentence_transformers import SentenceTransformer, util from structlog import get_logger +from etl.config import memory from etl.paths import CACHE_DIR -memory = Memory(CACHE_DIR, verbose=0) - # Initialize log. 
log = get_logger() diff --git a/apps/wizard/app_pages/similar_charts/app.py b/apps/wizard/app_pages/similar_charts/app.py index fd2b4addf2f..351a8e3f2fc 100644 --- a/apps/wizard/app_pages/similar_charts/app.py +++ b/apps/wizard/app_pages/similar_charts/app.py @@ -1,17 +1,37 @@ +import datetime as dt import random +from typing import List, get_args import pandas as pd import streamlit as st +from sqlalchemy.orm import Session from structlog import get_logger from apps.wizard.app_pages.similar_charts import data, scoring from apps.wizard.utils import embeddings as emb +from apps.wizard.utils import start_profiler +from apps.wizard.utils.cached import get_grapher_user from apps.wizard.utils.components import Pagination, st_horizontal, st_multiselect_wider, url_persist from etl.config import OWID_ENV +from etl.db import get_engine +from etl.git_helpers import log_time +from etl.grapher import model as gm + +PROFILER = start_profiler() + +ITEMS_PER_PAGE = 20 # Initialize log. log = get_logger() +# Database engine. +engine = get_engine() + +# Get reviewer's name. 
+grapher_user = get_grapher_user(st.context.headers.get("X-Forwarded-For")) +assert grapher_user, "User not found" +reviewer = grapher_user.fullName + # PAGE CONFIG st.set_page_config( page_title="Wizard: Similar Charts", @@ -20,14 +40,33 @@ ) ######################################################################################################################## -# FUNCTIONS +# CONSTANTS & FUNCTIONS ######################################################################################################################## +DISPLAY_STATE_OPTIONS = { + "good": { + "label": "Good", + "color": "green", + "icon": "βœ…", + }, + "bad": { + "label": "Bad", + "color": "red", + "icon": "❌", + }, + "neutral": { + "label": "Neutral", + "color": "gray", + "icon": "⏳", + }, +} + +CHART_LABELS = get_args(gm.RELATED_CHART_LABEL) + @st.cache_data(show_spinner=False, ttl="1h") def get_charts() -> list[data.Chart]: with st.spinner("Loading charts..."): - # Get charts from the database.. df = data.get_raw_charts() if len(df) == 0: @@ -43,40 +82,40 @@ def get_charts() -> list[data.Chart]: return ret -def st_chart_info(chart: data.Chart) -> None: +@log_time +@st.cache_data(show_spinner=False) +def get_coviews() -> pd.Series: + # Load coviews for all charts for the past 365 days. 
+ with st.spinner("Loading coviews..."): + return data.get_coviews_sessions(after_date=str(dt.date.today() - dt.timedelta(days=365)), min_sessions=3) + + +def st_chart_info(chart: data.Chart, show_coviews=True) -> None: + """Displays general info about a single chart.""" chart_url = OWID_ENV.chart_site(chart.slug) - title = f"#### [{chart.title}]({chart_url})" + # title = f"#### [{chart.title}]({chart_url})" + title = f"[{chart.title}]({chart_url})" if chart.gpt_reason: title += " πŸ€–" - st.markdown(title) + st.subheader(title, anchor=chart.slug) st.markdown(f"Slug: {chart.slug}") st.markdown(f"Subtitle: {chart.subtitle}") st.markdown(f"Tags: **{', '.join(chart.tags)}**") st.markdown(f"Pageviews: **{chart.views_365d}**") + if show_coviews: + st.markdown(f"Coviews: **{chart.coviews}**") def st_chart_scores(chart: data.Chart, sim_components: pd.DataFrame) -> None: - st.markdown(f"#### Similarity: {chart.similarity:.0%}") + """Displays scoring info (score, breakdown table) for a single chart.""" + st.markdown(f"#### Score: {chart.similarity:.0%}") st.table(sim_components.loc[chart.chart_id].to_frame("score").style.format("{:.0%}")) if chart.gpt_reason: st.markdown(f"**GPT Diversity Reason**:\n{chart.gpt_reason}") -def st_display_chart( - chart: data.Chart, - sim_components: pd.DataFrame = pd.DataFrame(), -) -> None: - with st.container(border=True): - col1, col2 = st.columns(2) - with col1: - st_chart_info(chart) - with col2: - st_chart_scores(chart, sim_components) - - def split_input_string(input_string: str) -> tuple[str, list[str], list[str]]: """Break input string into query, includes and excludes.""" - # Break input string into query, includes and excludes query = [] includes = [] excludes = [] @@ -91,28 +130,257 @@ def split_input_string(input_string: str) -> tuple[str, list[str], list[str]]: return " ".join(query), includes, excludes -@st.cache_data(show_spinner=False, max_entries=1) +@log_time +@st.cache_data( + show_spinner=False, + max_entries=1, + 
hash_funcs={list[data.Chart]: lambda charts: len(charts)}, +) def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: with st.spinner("Loading model..."): scoring_model = scoring.ScoringModel(emb.get_model()) - scoring_model.fit(charts) + with st.spinner("Fitting model..."): + scoring_model.fit(charts) return scoring_model ######################################################################################################################## -# Fetch all data indicators. +# NEW COMPONENTS +######################################################################################################################## + + +class RelatedChartDisplayer: + """ + Encapsulates the logic for displaying and labeling a related chart, + including any database updates and UI feedback. + """ + + def __init__(self, engine, chosen_chart: data.Chart, sim_components: pd.DataFrame): + self.engine = engine + self.chosen_chart = chosen_chart + self.sim_components = sim_components + + def display( + self, + chart: data.Chart, + label: gm.RELATED_CHART_LABEL = "neutral", + ) -> None: + """ + Renders the chart block (info, scores, and label radio). + Also hooks up the callback for label changes. + """ + with st.container(): + col1, col2 = st.columns(2) + with col1: + st_chart_info(chart) + st.radio( + label="**Review Related Chart**", + key=f"label-{chart.chart_id}", + options=CHART_LABELS, + index=CHART_LABELS.index(label), + horizontal=True, + format_func=lambda x: f":{DISPLAY_STATE_OPTIONS[x]['color']}-background[{DISPLAY_STATE_OPTIONS[x]['label']}]", + on_change=self._push_status, + kwargs={"chart": chart}, + ) + with col2: + st_chart_scores(chart, self.sim_components) + + def _push_status(self, chart: data.Chart) -> None: + """ + Callback: triggered on label change. Saves to the DB and + shows an appropriate toast. 
+ """ + label: gm.RELATED_CHART_LABEL = st.session_state[f"label-{chart.chart_id}"] + + with Session(self.engine) as session: + gm.RelatedChart( + chartId=self.chosen_chart.chart_id, + relatedChartId=chart.chart_id, + label=label, + reviewer=reviewer, + ).upsert(session) + session.commit() + + # Notify user + with st.spinner(): + match label: + case "good": + st.toast(":green[Recommendation labeled as **good**]", icon="βœ…") + case "bad": + st.toast(":red[Recommendation labeled as **bad**]", icon="❌") + case "neutral": + st.toast("**Resetting** recommendation to neutral", icon=":material/restart_alt:") + + +def st_related_charts_table( + related_charts: list[gm.RelatedChart], chart_map: dict[int, data.Chart], chosen_chart: data.Chart +) -> None: + """ + Shows a "matrix" of reviews in a pivoted table using st.dataframe: + - Row per related chart + - Columns for slug, title, views_365d, link, and one column per reviewer (icon) + - Hides chart_id + """ + if not related_charts: + st.info("No related charts have been selected yet.") + return + + # 1) Convert the list of RelatedChart objects to a DataFrame + rows = [] + for rc in related_charts: + c = chart_map.get(rc.relatedChartId) + if not c: + # Skip if missing + continue + + # rev = rc.reviewer + # if not rev.startswith("πŸ€–"): + # reviewer = "πŸ‘€" + " " + reviewer + + rows.append( + { + "chart_id": c.chart_id, + "slug": c.slug, + "title": c.title, + "views_365d": c.views_365d, + "coviews": c.coviews, + "score": c.similarity, + "reviewer": rc.reviewer, + "label": rc.label, + } + ) + df = pd.DataFrame(rows) + + # Exclude neutral reviews + df = df[df["label"] != "neutral"] + + # 2) Pivot so that each reviewer is a column, with the label as the cell value + pivot_df = df.pivot( + index=["chart_id", "slug", "title", "views_365d", "coviews", "score"], columns="reviewer", values="label" + ).fillna("neutral") + + reviewer_cols = list(pivot_df.columns) + + if reviewer in reviewer_cols: + print(pivot_df) + 
pivot_df["favorite"] = pivot_df[reviewer] == "good" + del pivot_df[reviewer] + reviewer_cols.remove(reviewer) + else: + pivot_df["favorite"] = False + + # 3) Map each label (good/bad/neutral) to an icon + def label_to_icon(label: str) -> str: + if label == "neutral": + return "" + else: + return DISPLAY_STATE_OPTIONS.get(label, DISPLAY_STATE_OPTIONS["neutral"])["icon"] + + pivot_df[reviewer_cols] = pivot_df[reviewer_cols].applymap(label_to_icon) + + # 4) Flatten the multi-index so 'chart_id', 'slug', etc. become columns + pivot_df.reset_index(inplace=True) + + # 6) Create a new column "link" + pivot_df["link"] = pivot_df["slug"].apply(lambda x: OWID_ENV.chart_site(x)) + # TODO: jump to anchor + # pivot_df["link"] = pivot_df["slug"].apply(lambda x: f"#{x}") + + # 7) Build the final column order + final_cols = ["link", "chart_id", "slug", "title", "views_365d", "coviews", "score"] + reviewer_cols + ["favorite"] + + pivot_df = pivot_df[final_cols].sort_values(["score"], ascending=False) + + # 8) Configure columns for st.dataframe + column_config = { + # The link column becomes a clickable link + "link": st.column_config.LinkColumn( + "Open", + # display_text="Jump to detail", + display_text="Open", + ), + "favorite": st.column_config.CheckboxColumn( + "Your favorite?", + help="Select your **favorite** widgets", + default=False, + ), + "chart_id": None, + } + # You could also configure text columns or numeric columns (like "views_365d"). 
+ styled_df = pivot_df.style.format("{:.0%}", subset=["score"]) + + # Disable all columns except "favorite" + disabled_cols = [col for col in pivot_df.columns if col != "favorite"] + + old_favorites = set(pivot_df[pivot_df["favorite"]].chart_id) + + # 9) Show the result using st.data_editor + updated_df = st.data_editor( + styled_df, + use_container_width=True, + hide_index=True, + column_config=column_config, + disabled=disabled_cols, + ) + + new_favorites = set(updated_df[updated_df["favorite"]].chart_id) + + with Session(engine) as session: + for chart_id in new_favorites - old_favorites: + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=chart_id, + label="good", + reviewer=reviewer, + ).upsert(session) + + for chart_id in old_favorites - new_favorites: + # TODO: we can delete it as well + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=chart_id, + label="neutral", + reviewer=reviewer, + ).upsert(session) + + session.commit() + + +def add_coviews_to_charts(charts: List[data.Chart], chosen_chart: data.Chart, coviews: pd.Series) -> List[data.Chart]: + try: + chosen_chart_coviews = coviews.loc[chosen_chart.slug].to_dict() + except KeyError: + chosen_chart_coviews = {} + + for c in charts: + c.coviews = chosen_chart_coviews.get(c.slug, 0) + + return charts + + +######################################################################################################################## +# FETCH DATA & MODEL +######################################################################################################################## + charts = get_charts() -# Get scoring model. 
+coviews = get_coviews() + scoring_model = get_and_fit_model(charts) +# Re-set charts if the model comes from cache +scoring_model.charts = charts -######################################################################################################################## +# Build a chart map for quick lookups by chart_id +chart_map = {chart.chart_id: chart for chart in charts} + +# Pick top 100 charts by pageviews. +top_100_charts: list[data.Chart] = sorted(charts, key=lambda x: x.views_365d, reverse=True)[:100] # type: ignore ######################################################################################################################## # RENDER ######################################################################################################################## -# Streamlit app layout. st.title(":material/search: Similar charts") col1, col2 = st.columns(2) @@ -120,6 +388,7 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: st_multiselect_wider() with st_horizontal(): random_chart = st.button("Random chart", help="Get a random chart.") + random_100_chart = st.button("Random top 100 chart", help="Get a random chart from the top 100 charts.") # Filter indicators diversity_gpt = url_persist(st.checkbox)( @@ -130,28 +399,27 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: ) # Random chart was pressed or no search text - if random_chart or not st.query_params.get("chart_search_text"): - chart_slug = random.sample(charts, 1)[0].slug - st.session_state["chart_search_text"] = chart_slug - - # chart_search_text = url_persist(st.text_input)( - # key="chart_search_text", - # label="Chart slug or ID", - # placeholder="Type something...", - # ) - - chart_search_text = url_persist(st.selectbox)( + if random_chart or not st.query_params.get("slug"): + # weighted by views + chart = random.choices(charts, weights=[c.views_365d for c in charts], k=1)[0] # type: ignore + # non-weighted sample + # chart = 
random.sample(charts, 1)[0] + st.session_state["slug"] = chart.slug + elif random_100_chart: + chart_slug = random.sample(top_100_charts, 1)[0].slug + st.session_state["slug"] = chart_slug + + # Dropdown select for chart. + slug = url_persist(st.selectbox)( "Select a chart", - key="chart_search_text", + key="slug", options=[c.slug for c in charts], ) - # Advanced expander. + # Advanced options st.session_state.sim_charts_expander_advanced_options = st.session_state.get( "sim_charts_expander_advanced_options", False ) - - # Weights for each score with st.expander("Advanced options", expanded=st.session_state.sim_charts_expander_advanced_options): # Add text area for system prompt system_prompt = url_persist(st.text_area)( @@ -161,82 +429,140 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: height=150, ) - for score_name in ["title", "subtitle", "tags", "pageviews", "share_indicator"]: - # For some reason, if the slider minimum value is zero, streamlit raises an error when the slider is - # dragged to the minimum. Set it to a small, non-zero number. 
- key = f"w_{score_name}" - - # Set default values - if key not in st.session_state: - st.session_state[key] = scoring.DEFAULT_WEIGHTS[score_name] + # Regularization for coviews + url_persist(st.slider)( + "Coviews regularization", + key="coviews_regularization", + min_value=0.0, + max_value=0.001, + value=scoring.DEFAULT_COVIEWS_REGULARIZATION, + step=0.0001, + format="%.3f", + help="Penalize coviews score by subtracting this value times pageviews.", + ) + scoring_model.coviews_regularization = st.session_state["coviews_regularization"] + for score_name in ["title", "subtitle", "tags", "share_indicator", "pageviews_score", "coviews_score"]: + key = f"w_{score_name}" url_persist(st.slider)( f"Weight for {score_name} score", + key=key, min_value=1e-9, max_value=1.0, - # step=0.001, - key=key, value=scoring.DEFAULT_WEIGHTS[score_name], ) - scoring_model.weights[score_name] = st.session_state[key] - -# Find a chart based on inputs +# Find a chart chosen_chart = next( - (chart for chart in charts if chart.slug == chart_search_text or str(chart.chart_id) == chart_search_text), + (chart for chart in charts if chart.slug == slug or str(chart.chart_id) == slug), None, ) if not chosen_chart: - st.error(f"Chart with slug {chart_search_text} not found.") + st.error(f"Chart with slug {slug} not found.") + st.stop() - # # Find a chart by title - # chart_id = scoring_model.similar_chart_by_title(chart_search_text) - # chosen_chart = next((chart for chart in charts if chart.chart_id == chart_id), None) +# Add coviews +charts = add_coviews_to_charts(charts, chosen_chart, coviews) -assert chosen_chart - -# Display chosen chart -with col1: - st_chart_info(chosen_chart) - - -# Horizontal divider -st.markdown("---") +# Load "official" related charts from DB +with Session(engine) as session: + related_charts_db = gm.RelatedChart.load(session, chart_id=chosen_chart.chart_id) +# Compute similarity for all charts sim_dict = scoring_model.similarity(chosen_chart) sim_components = 
scoring_model.similarity_components(chosen_chart) -for chart in charts: - chart.similarity = sim_dict[chart.chart_id] +# Assign similarity +for c in charts: + c.similarity = sim_dict[c.chart_id] +# Sort by similarity sorted_charts = sorted(charts, key=lambda x: x.similarity, reverse=True) # type: ignore -# Postprocess charts with GPT and prioritize diversity +# Add reviews to top charts by similarity +for c in sorted_charts[:6]: + if c.chart_id == chosen_chart.chart_id: + continue + # Add to related charts + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="πŸ€– Score", + ) + ) + +# Possibly re-rank with GPT for diversity if diversity_gpt: with st.spinner("Diversifying chart results..."): slugs_to_reasons = scoring.gpt_diverse_charts(chosen_chart, sorted_charts, system_prompt=system_prompt) - for chart in sorted_charts: - if chart.slug in slugs_to_reasons: - chart.gpt_reason = slugs_to_reasons[chart.slug] + for c in sorted_charts: + if c.slug in slugs_to_reasons: + c.gpt_reason = slugs_to_reasons[c.slug] + + # Add to related charts + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="πŸ€– GPT", + ) + ) + + +# Add coviews reviewer +for chart_id in sim_components.sort_values("coviews_score", ascending=False).index[:5]: + c = chart_map[chart_id] + # Don't recommend zero coviews + if c.coviews == 0: + continue + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="πŸ€– Coviews", + ) + ) + +# Display chosen chart +with col1: + st_chart_info(chosen_chart, show_coviews=False) + +# Divider +st.markdown("---") +st.header("Reviewed Related Charts") +st_related_charts_table(related_charts_db, chart_map, chosen_chart) + +# Divider +st.markdown("---") +st.header("Recommended Related Charts") - # Put charts that are diverse at the 
top - # sorted_charts = sorted(sorted_charts, key=lambda x: (x.gpt_reason is not None, x.similarity), reverse=True) +# Create our new chart display component +displayer = RelatedChartDisplayer(engine, chosen_chart, sim_components) # Use pagination -items_per_page = 20 pagination = Pagination( - items=sorted_charts, - items_per_page=items_per_page, + items=sorted_charts[:100], + items_per_page=ITEMS_PER_PAGE, pagination_key=f"pagination-di-search-{chosen_chart.slug}", ) - -if len(charts) > items_per_page: +if len(sorted_charts) > ITEMS_PER_PAGE: pagination.show_controls(mode="bar") -# Show items (only current page) +# Display only the current page for item in pagination.get_page_items(): - # Don't show the chosen chart if item.slug == chosen_chart.slug: continue - st_display_chart(item, sim_components) + + # Check if we have a DB label for the related chart from us + labels = [r.label for r in related_charts_db if r.relatedChartId == item.chart_id and r.reviewer == reviewer] + label = labels[0] if labels else "neutral" + + # Use the new component to display + displayer.display(chart=item, label=label) # type: ignore + +PROFILER.stop() diff --git a/apps/wizard/app_pages/similar_charts/data.py b/apps/wizard/app_pages/similar_charts/data.py index 7294b99d711..4ef769eac9c 100644 --- a/apps/wizard/app_pages/similar_charts/data.py +++ b/apps/wizard/app_pages/similar_charts/data.py @@ -4,7 +4,9 @@ import pandas as pd +from apps.utils.google import read_gbq from apps.wizard.utils.embeddings import Doc +from etl.config import memory from etl.db import read_sql @@ -21,6 +23,7 @@ class Chart(Doc): views_14d: Optional[int] = None views_365d: Optional[int] = None gpt_reason: Optional[str] = None + coviews: Optional[int] = None def get_raw_charts() -> pd.DataFrame: @@ -66,3 +69,33 @@ def get_raw_charts() -> pd.DataFrame: assert df["chart_id"].nunique() == df.shape[0] return df + + +@memory.cache +def get_coviews_sessions(after_date: str, min_sessions: int = 5) -> pd.Series: + """ 
+ Count of sessions in which a pair of URLs are both visited, aggregated daily + + note: this is a nondirectional network. url1 and url2 are string sorted and + do not indicate anything about whether url1 was visited before/after url2 in + the session. + """ + query = f""" + SELECT + REGEXP_EXTRACT(url1, r'grapher/([^/]+)') AS slug1, + REGEXP_EXTRACT(url2, r'grapher/([^/]+)') AS slug2, + SUM(sessions_coviewed) AS total_sessions + FROM prod_google_analytics4.coviews_by_day_page + WHERE day >= '{after_date}' + AND url1 LIKE 'https://ourworldindata.org/grapher%' + AND url2 LIKE 'https://ourworldindata.org/grapher%' + GROUP BY slug1, slug2 + HAVING total_sessions >= {min_sessions} + """ + df = read_gbq(query, project_id="owid-analytics") + + # concat with reversed slug1 and slug2 + df = pd.concat([df, df.rename(columns={"slug1": "slug2", "slug2": "slug1"})]) + + # set index for faster lookups + return df.set_index(["slug1", "slug2"]).sort_index()["total_sessions"] diff --git a/apps/wizard/app_pages/similar_charts/scoring.py b/apps/wizard/app_pages/similar_charts/scoring.py index 7c60851d2f1..0b2d2d2474d 100644 --- a/apps/wizard/app_pages/similar_charts/scoring.py +++ b/apps/wizard/app_pages/similar_charts/scoring.py @@ -21,13 +21,17 @@ # These are the default thresholds for the different scores. DEFAULT_WEIGHTS = { - "title": 0.4, + "title": 0.3, "subtitle": 0.1, "tags": 0.1, - "pageviews": 0.3, "share_indicator": 0.1, + "pageviews_score": 0.3, + "coviews_score": 0.1, } +# Default regularization term for coviews +DEFAULT_COVIEWS_REGULARIZATION = 0.0 + PREFIX_SYSTEM_PROMPT = """ You are an expert in recommending visual data insights. Your task: From a given chosen chart and a list of candidate charts, recommend up to 5 charts that are most relevant. 
@@ -58,11 +62,14 @@ class ScoringModel: # Weights for the different scores weights: dict[str, float] - def __init__(self, model: SentenceTransformer, weights: Optional[dict[str, float]] = None) -> None: + def __init__( + self, model: SentenceTransformer, weights: Optional[dict[str, float]] = None, coviews_regularization: float = 0 + ) -> None: self.model = model self.weights = weights or DEFAULT_WEIGHTS.copy() + self.coviews_regularization = coviews_regularization - def fit(self, charts: list[Chart]): + def fit(self, charts: list[Chart]) -> None: self.charts = charts # Get embeddings for title and subtitle @@ -121,8 +128,9 @@ def similarity_components(self, chart: Chart) -> pd.DataFrame: "subtitle": subtitle_scores[i], # score 1 if there is at least one tag in common, 0 otherwise "tags": float(bool(set(c.tags) & set(chart.tags))), - "pageviews": c.views_365d or 0, "share_indicator": float(c.chart_id in charts_sharing_indicator), + "pageviews": c.views_365d or 0, + "coviews": c.coviews or 0, } ) @@ -134,10 +142,9 @@ def similarity_components(self, chart: Chart) -> pd.DataFrame: if chart.subtitle == "": ret["subtitle"] = 0 - # Scale pageviews to [0, 1] - ret["pageviews"] = np.log(ret["pageviews"] + 1) - ret["pageviews"] = (ret["pageviews"] - ret["pageviews"].min()) / ( - ret["pageviews"].max() - ret["pageviews"].min() + ret["pageviews_score"] = score_pageviews(ret["pageviews"]) + ret["coviews_score"] = score_coviews( + ret["coviews"], ret["pageviews"], regularization=self.coviews_regularization ) # Get weights and normalize them @@ -148,14 +155,33 @@ def similarity_components(self, chart: Chart) -> pd.DataFrame: ret = (ret * w).fillna(0) # Reorder - ret = ret[["title", "subtitle", "tags", "share_indicator", "pageviews"]] + ret = ret[["title", "subtitle", "tags", "share_indicator", "pageviews_score", "coviews_score"]] log.info("similarity_components.end", t=time.time() - t) return ret -@st.cache_data(show_spinner=False, persist="disk") +def score_pageviews(pageviews: 
pd.Series) -> pd.Series:
+    """Log transform pageviews and scale them to [0, 1]. Chart with the most pageviews gets score 1 and
+    chart with the least pageviews gets score 0.
+    """
+    pageviews = np.log(pageviews + 1)  # type: ignore
+    return (pageviews - pageviews.min()) / (pageviews.max() - pageviews.min())
+
+
+def score_coviews(coviews: pd.Series, pageviews: pd.Series, regularization: float) -> pd.Series:
+    """Score coviews. Subtract a regularization-weighted pageview penalty from coviews to demote
+    charts with high pageviews that tend to show up despite being not very relevant, then divide
+    by the maximum so the top chart scores 1 (values can be negative when regularization > 0).
+    """
+    # p = coviews / (pageviews + lam)
+    # return (p - p.min()) / (p.max() - p.min())
+    p = coviews - regularization * pageviews
+    return p / p.max()
+
+
+@st.cache_data(show_spinner=False, persist="disk", hash_funcs={Chart: lambda chart: chart.chart_id})
 def gpt_diverse_charts(
     chosen_chart: Chart, _charts: list[Chart], _n: int = 30, system_prompt=DEFAULT_SYSTEM_PROMPT
 ) -> dict[str, str]:
diff --git a/etl/config.py b/etl/config.py
index 6958fe64bbb..c5380dfe9a6 100644
--- a/etl/config.py
+++ b/etl/config.py
@@ -21,15 +21,18 @@
 import pandas as pd
 import structlog
 from dotenv import dotenv_values, load_dotenv
+from joblib import Memory
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import Session
 
-from etl.paths import BASE_DIR
+from etl.paths import BASE_DIR, CACHE_DIR
 
 log = structlog.get_logger()
 
 ENV_FILE = Path(env.get("ENV_FILE", BASE_DIR / ".env"))
 
+memory = Memory(CACHE_DIR, verbose=0)
+
 
 def get_username():
     return pwd.getpwuid(os.getuid())[0]
diff --git a/etl/grapher/model.py b/etl/grapher/model.py
index 822f5bf2ec6..9f6881d0df5 100644
--- a/etl/grapher/model.py
+++ b/etl/grapher/model.py
@@ -54,6 +54,7 @@
 )
 from sqlalchemy import JSON as _JSON
 from sqlalchemy.dialects.mysql import (
+    DOUBLE,
     ENUM,
     LONGBLOB,
     LONGTEXT,
@@ -1815,6 +1816,68 @@ def get_conflict_batch(
         return conflicts
+
+RELATED_CHART_LABEL =
Literal["good", "bad", "neutral"] + + +class RelatedChart(Base): + __tablename__ = "related_charts" + __table_args__ = ( + ForeignKeyConstraint( + ["chartId"], ["charts.id"], ondelete="CASCADE", onupdate="CASCADE", name="related_charts_ibfk_1" + ), + ForeignKeyConstraint( + ["relatedChartId"], ["charts.id"], ondelete="CASCADE", onupdate="CASCADE", name="related_charts_ibfk_2" + ), + # Existing Index on chartId + Index("idx_related_charts_chartId", "chartId"), + # 1) Unique index on (chartId, relatedChartId, reviewer) + Index("uq_chartId_relatedChartId_reviewer", "chartId", "relatedChartId", "reviewer", unique=True), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) + chartId: Mapped[int] = mapped_column(Integer, nullable=False) + relatedChartId: Mapped[int] = mapped_column(Integer, nullable=False) + label: Mapped[RELATED_CHART_LABEL] = mapped_column(VARCHAR(255), nullable=False) + reviewer: Mapped[Optional[str]] = mapped_column(VARCHAR(255)) + score: Mapped[Optional[float]] = mapped_column(DOUBLE, default=None) + reason: Mapped[Optional[str]] = mapped_column(TEXT, default=None) + updatedAt: Mapped[datetime] = mapped_column(DateTime, default=func.utc_timestamp()) + + @classmethod + def load(cls, session: Session, chart_id: Optional[int] = None) -> list["RelatedChart"]: + # Exclude "production" reviewer which is generated automatically + stm = select(cls).where(cls.reviewer != "production") + + if chart_id is None: + records = session.scalars(stm).all() + else: + records = session.scalars(stm.where(cls.chartId == chart_id)).all() + return list(records) + + def upsert( + self, + session: Session, + ) -> "RelatedChart": + cls = self.__class__ + + ds = session.scalars( + select(cls).where( + cls.chartId == self.chartId, cls.relatedChartId == self.relatedChartId, cls.reviewer == self.reviewer + ) + ).one_or_none() + + if not ds: + ds = self + else: + ds.label = self.label + ds.reason = self.reason + ds.updatedAt = func.utc_timestamp() + + 
session.add(ds) + session.flush() + return ds + + class MultiDimDataPage(Base): __tablename__ = "multi_dim_data_pages"