diff --git a/apps/wizard/app_pages/indicator_upgrade/charts_update.py b/apps/wizard/app_pages/indicator_upgrade/charts_update.py index 0e0f631b0a4e..7a1ed4bdf5f7 100644 --- a/apps/wizard/app_pages/indicator_upgrade/charts_update.py +++ b/apps/wizard/app_pages/indicator_upgrade/charts_update.py @@ -11,7 +11,7 @@ import etl.grapher.model as gm from apps.chart_sync.admin_api import AdminAPI from apps.wizard.utils import set_states -from apps.wizard.utils.cached import get_grapher_user_id +from apps.wizard.utils.cached import get_grapher_user from apps.wizard.utils.components import st_toast_error, st_wizard_page_link from apps.wizard.utils.db import WizardDB from etl.config import OWID_ENV @@ -97,7 +97,8 @@ def push_new_charts(charts: List[gm.Chart]) -> None: """Updating charts in the database.""" # Use Tailscale user if it is available, otherwise use GRAPHER_USER_ID from env if "X-Forwarded-For" in st.context.headers: - grapher_user_id = get_grapher_user_id(st.context.headers["X-Forwarded-For"]) + grapher_user = get_grapher_user(st.context.headers["X-Forwarded-For"]) + grapher_user_id = grapher_user.id if grapher_user else None else: grapher_user_id = None diff --git a/apps/wizard/app_pages/insight_search/embeddings.py b/apps/wizard/app_pages/insight_search/embeddings.py index ddd31562743b..35f96dbacaad 100644 --- a/apps/wizard/app_pages/insight_search/embeddings.py +++ b/apps/wizard/app_pages/insight_search/embeddings.py @@ -5,13 +5,10 @@ import streamlit as st import torch -from joblib import Memory from sentence_transformers import SentenceTransformer, util from structlog import get_logger -from etl.paths import CACHE_DIR - -memory = Memory(CACHE_DIR, verbose=0) +from etl.config import memory # Initialize log. log = get_logger() diff --git a/apps/wizard/app_pages/producer_analytics.py b/apps/wizard/app_pages/producer_analytics.py new file mode 100644 index 000000000000..b6f2d17aaff6 --- /dev/null +++ b/apps/wizard/app_pages/producer_analytics.py @@ -0,0 +1,564 @@ +from datetime import datetime, timedelta +from typing import Optional, cast + +import owid.catalog.processing as pr +import pandas as pd +import plotly.express as px +import streamlit as st +from st_aggrid import AgGrid, GridUpdateMode, JsCode +from st_aggrid.grid_options_builder import GridOptionsBuilder +from structlog import get_logger + +from apps.wizard.utils.components import st_horizontal +from apps.wizard.utils.db import read_gbq +from etl.snapshot import Snapshot +from etl.version_tracker import VersionTracker + +# Initialize log. +log = get_logger() + +# Define constants. +TODAY = datetime.today() +# Date when the new views metric started to be recorded. +MIN_DATE = datetime.strptime("2024-11-01", "%Y-%m-%d") +GRAPHERS_BASE_URL = "https://ourworldindata.org/grapher/" +# List of auxiliary steps to be (optionally) excluded from the DAG. +# It may be convenient to ignore these steps because the analytics are heavily affected by a few producers (e.g. those that are involved in the population and income groups datasets). +AUXILIARY_STEPS = [ + "data://garden/demography/.*/population", + # Primary energy consumption is loaded by GCB. 
+ "data://garden/energy/.*/primary_energy_consumption", + "data://garden/ggdc/.*/maddison_project_database", + "data://garden/wb/.*/income_groups", +] + +# PAGE CONFIG +st.set_page_config( + page_title="Wizard: Producer analytics", + layout="wide", + page_icon="🪄", +) + + +######################################################################################################################## +# FUNCTIONS & GLOBAL VARS +######################################################################################################################## +def columns_producer(min_date, max_date): + # Define columns to be shown. + cols_prod = { + "producer": { + "headerName": "Producer", + "headerTooltip": "Name of the producer. This is NOT the name of the dataset.", + }, + "n_charts": { + "headerName": "Charts", + "headerTooltip": "Number of charts using data from a producer.", + }, + "renders_custom": { + "headerName": "Views in custom range", + "headerTooltip": f"Number of renders between {min_date} and {max_date}.", + }, + "renders_365d": { + "headerName": "Views 365 days", + "headerTooltip": "Number of renders in the last 365 days.", + }, + "renders_30d": { + "headerName": "Views 30 days", + "headerTooltip": "Number of renders in the last 30 days.", + }, + } + return cols_prod + + +@st.cache_data(show_spinner=False) +def get_grapher_views( + date_start: str = MIN_DATE.strftime("%Y-%m-%d"), + date_end: str = TODAY.strftime("%Y-%m-%d"), + groupby: Optional[list[str]] = None, + grapher_urls: Optional[list[str]] = None, +) -> pd.DataFrame: + grapher_filter = "" + if grapher_urls: + # If a list of grapher URLs is given, consider only those. + grapher_urls_formatted = ", ".join(f"'{url}'" for url in grapher_urls) + grapher_filter = f"AND grapher IN ({grapher_urls_formatted})" + else: + # If no list is given, consider all grapher URLs. + grapher_filter = f"AND grapher LIKE '{GRAPHERS_BASE_URL}%'" + + if not groupby: + # If a groupby list is not given, assume the simplest case, which gives total views for each grapher. + groupby = ["grapher"] + + # Prepare the query. + groupby_clause = ", ".join(groupby) + select_clause = f"{groupby_clause}, SUM(events) AS renders" + query = f""" + SELECT + {select_clause} + FROM prod_google_analytics4.grapher_views_by_day_page_grapher_device_country_iframe + WHERE + day >= '{date_start}' + AND day <= '{date_end}' + {grapher_filter} + GROUP BY {groupby_clause} + ORDER BY {groupby_clause} + """ + + # Execute the query. + df_views = read_gbq(query, project_id="owid-analytics") + + return cast(pd.DataFrame, df_views) + + +@st.cache_data(show_spinner=False) +def get_chart_renders(min_date: str, max_date: str) -> pd.DataFrame: + # List ranges of dates to fetch views. + date_ranges = { + "renders_365d": ((TODAY - timedelta(days=365)).strftime("%Y-%m-%d"), TODAY.strftime("%Y-%m-%d")), + "renders_30d": ((TODAY - timedelta(days=30)).strftime("%Y-%m-%d"), TODAY.strftime("%Y-%m-%d")), + "renders_custom": (min_date, max_date), # Use user-defined date range. + } + + # Get analytics for those ranges, for all grapher URLs. + list_renders = [ + get_grapher_views(date_start=date_start, date_end=date_end, grapher_urls=None, groupby=["grapher"]).rename( + columns={"renders": column_name} + ) + for column_name, (date_start, date_end) in date_ranges.items() + ] + + # Merge all dataframes. 
+ df_renders = pr.multi_merge(list_renders, on="grapher", how="outer") # type: ignore + + return df_renders + + +@st.cache_data(show_spinner=False) +def load_steps_df(excluded_steps) -> pd.DataFrame: + # Load steps dataframe. + steps_df = VersionTracker(exclude_steps=excluded_steps).steps_df + + return steps_df + + +@st.cache_data(show_spinner=False) +def load_steps_df_with_producer_data(excluded_steps) -> pd.DataFrame: + # Load steps dataframe. + # st.toast("⌛ Loading data from VersionTracker...") + steps_df = load_steps_df(excluded_steps=excluded_steps) + + # st.toast("⌛ Processing VersionTracker data...") + # Select only active snapshots. + df = steps_df[(steps_df["channel"] == "snapshot") & (steps_df["state"] == "active")].reset_index(drop=True) + + # Select only relevant columns. + df = df[["step", "all_chart_slugs"]] + + # Add a column of producer to steps df (where possible). + for i, row in df.iterrows(): + snap_uri = row["step"].split("snapshot://" if "snapshot://" in row["step"] else "snapshot-private://")[1] + snap = Snapshot(snap_uri) + origin = snap.metadata.origin + if (origin is not None) and (snap.metadata.namespace not in ["dummy"]): + producer = snap.metadata.origin.producer # type: ignore + df.loc[i, "producer"] = producer + + # Select only relevant columns. + df = df[["producer", "all_chart_slugs"]] + + # Remove rows with no producer. + df = df.dropna(subset=["producer"]).reset_index(drop=True) + + # Ignore the chart id, and keep only the slug. + df["all_chart_slugs"] = [sorted(set([slug for _, slug in id_slug])) for id_slug in df["all_chart_slugs"]] + + # Create a row for each producer-slug pair. Fill with "" (in cases where the producer has no charts). + df_expanded = df.explode("all_chart_slugs") + + # Remove duplicates. + # NOTE: This happens because df contains one row per snapshot. Some grapher datasets come from a combination of multiple snapshots (often from the same producer). We want to count producer-chart pairs only once. + df_expanded = df_expanded.drop_duplicates(subset=["producer", "all_chart_slugs"]).reset_index(drop=True) + + # Add a column for grapher URL. + df_expanded["grapher"] = GRAPHERS_BASE_URL + df_expanded["all_chart_slugs"] + + return df_expanded + + +@st.cache_data(show_spinner=False) +def get_producer_charts_analytics(min_date, max_date, excluded_steps): + # Get chart renders using user-defined date range for "renders_custom". + # st.toast("⌛ Getting analytics on chart renders...") + df_renders = get_chart_renders(min_date=min_date, max_date=max_date) + + # Load the steps dataframe with producer data. + df_expanded = load_steps_df_with_producer_data(excluded_steps=excluded_steps) + + # Add columns with the numbers of chart renders. + df_expanded = df_expanded.merge(df_renders, on="grapher", how="left").drop(columns=["all_chart_slugs"]) + + return df_expanded + + +@st.cache_data(show_spinner=False) +def get_producer_analytics_per_chart(min_date, max_date, excluded_steps): + # Load the steps dataframe with producer data and analytics. + df_expanded = get_producer_charts_analytics(min_date=min_date, max_date=max_date, excluded_steps=excluded_steps) + + # Create an expanded table with number of views per chart. 
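+    # (df_expanded here holds one row per producer-chart pair plus the render columns,
+    # e.g. producer="Producer X", grapher="https://ourworldindata.org/grapher/some-slug";
+    # names in this example are illustrative.)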
+ df_renders_per_chart = df_expanded.dropna(subset=["grapher"]).fillna(0).reset_index(drop=True) + df_renders_per_chart = df_renders_per_chart.sort_values("renders_custom", ascending=False).reset_index(drop=True) + + return df_renders_per_chart + + +@st.cache_data(show_spinner=False) +def get_producer_analytics_per_producer(min_date, max_date, excluded_steps): + # Load the steps dataframe with producer data and analytics. + df_expanded = get_producer_charts_analytics(min_date=min_date, max_date=max_date, excluded_steps=excluded_steps) + + # st.toast("⌛ Adapting the data for presentation...") + # Group by producer and get the full list of chart slugs for each producer. + df_grouped = df_expanded.groupby("producer", observed=True, as_index=False).agg( + { + "grapher": lambda x: [item for item in x if pd.notna(item)], # Filter out NaN values + "renders_365d": "sum", + "renders_30d": "sum", + "renders_custom": "sum", + } + ) + df_grouped["n_charts"] = df_grouped["grapher"].apply(len) + + # Check if lists are unique. If not, make them unique in the previous line. + error = "Duplicated chart slugs found for a given producer." + assert df_grouped["grapher"].apply(lambda x: len(x) == len(set(x))).all(), error + + # Drop unnecessary columns. + df_grouped = df_grouped.drop(columns=["grapher"]) + + # Sort conveniently. + df_grouped = df_grouped.sort_values(["renders_custom"], ascending=False).reset_index(drop=True) + + return df_grouped + + +def show_producers_grid(df_producers, min_date, max_date): + """Show table with producers analytics.""" + gb = GridOptionsBuilder.from_dataframe(df_producers) + gb.configure_grid_options(domLayout="autoHeight", enableCellTextSelection=True) + gb.configure_selection( + selection_mode="multiple", + use_checkbox=True, + rowMultiSelectWithClick=True, + suppressRowDeselection=False, + groupSelectsChildren=True, + groupSelectsFiltered=True, + ) + gb.configure_default_column(editable=False, groupable=True, sortable=True, filterable=True, resizable=True) + + # Enable column auto-sizing for the grid. + gb.configure_grid_options(suppressSizeToFit=False) # Allows dynamic resizing to fit. + gb.configure_default_column(autoSizeColumns=True) # Ensures all columns can auto-size. + + # Configure individual columns with specific settings. + COLUMNS_PRODUCERS = columns_producer(min_date, max_date) + for column in COLUMNS_PRODUCERS: + gb.configure_column(column, **COLUMNS_PRODUCERS[column]) + # Configure pagination with dynamic page size. + gb.configure_pagination(paginationAutoPageSize=False, paginationPageSize=20) + # Build the grid options. + grid_options = gb.build() + # Custom CSS to ensure the table stretches across the page. + custom_css = { + ".ag-theme-streamlit": { + "max-width": "100% !important", + "width": "100% !important", + "margin": "0 auto !important", # Centers the grid horizontally. + }, + } + # Display the grid table with the updated grid options. + grid_response = AgGrid( + data=df_producers, + gridOptions=grid_options, + height=1000, + width="100%", + update_mode=GridUpdateMode.MODEL_CHANGED, + fit_columns_on_grid_load=True, # Automatically adjust columns when the grid loads. + allow_unsafe_jscode=True, + theme="streamlit", + custom_css=custom_css, + # excel_export_mode=ExcelExportMode.MANUAL, # Doesn't work? + ) + + # Get the selected producers from the first table. 
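+    # NOTE: With update_mode=GridUpdateMode.MODEL_CHANGED, AgGrid reruns the script when the
+    # selection changes; selected_rows then holds the ticked rows as dicts keyed by column name.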
+ producers_selected = [row["producer"] for row in grid_response["selected_rows"]] + + return producers_selected + + +def plot_chart_analytics(df): + """Show chart with analytics on producer's charts.""" + # Get total daily views of selected producers. + grapher_urls_selected = df["grapher"].unique().tolist() # type: ignore + df_total_daily_views = get_grapher_views( + date_start=min_date, date_end=max_date, groupby=["day"], grapher_urls=grapher_urls_selected + ) + + # Get daily views of the top 10 charts. + grapher_urls_top_10 = ( + df.sort_values("renders_custom", ascending=False)["grapher"].unique().tolist()[0:10] # type: ignore + ) + df_top_10_daily_views = get_grapher_views( + date_start=min_date, date_end=max_date, groupby=["day", "grapher"], grapher_urls=grapher_urls_top_10 + ) + + # Get total number of views and average daily views. + total_views = df_total_daily_views["renders"].sum() + average_daily_views = df_total_daily_views["renders"].mean() + # Get total views of the top 10 charts in the selected date range. + df_top_10_total_views = df_top_10_daily_views.groupby("grapher", as_index=False).agg({"renders": "sum"}) + + # Create a line chart. + df_plot = pd.concat([df_total_daily_views.assign(**{"grapher": "Total"}), df_top_10_daily_views]).rename( + columns={"grapher": "Chart slug"} + ) + df_plot["Chart slug"] = df_plot["Chart slug"].apply(lambda x: x.split("/")[-1]) + df_plot["day"] = pd.to_datetime(df_plot["day"]).dt.strftime("%a. %Y-%m-%d") + fig = px.line( + df_plot, + x="day", + y="renders", + color="Chart slug", + title="Total daily views and views of top 10 charts", + ).update_layout(xaxis_title=None, yaxis_title=None) + + # Display the chart. + st.plotly_chart(fig, use_container_width=True) + + return total_views, average_daily_views, df_top_10_total_views + + +def show_producer_charts_grid(df): + """Show table with analytics on producer's charts.""" + # Configure and display the second table. + gb2 = GridOptionsBuilder.from_dataframe(df) + gb2.configure_grid_options(domLayout="autoHeight", enableCellTextSelection=True) + gb2.configure_default_column(editable=False, groupable=True, sortable=True, filterable=True, resizable=True) + + # Create a JavaScript renderer for clickable slugs. + grapher_slug_jscode = JsCode( + r""" + class UrlCellRenderer { + init(params) { + this.eGui = document.createElement('a'); + if (params.value) { + // Extract the slug from the full URL. + const url = new URL(params.value); + const slug = url.pathname.split('/').pop(); // Get the last part of the path as the slug. + this.eGui.innerText = slug; + this.eGui.setAttribute('href', params.value); + } else { + this.eGui.innerText = ''; + } + this.eGui.setAttribute('style', "text-decoration:none; color:blue"); + this.eGui.setAttribute('target', "_blank"); + } + getGui() { + return this.eGui; + } + } + """ + ) + + # Define columns to be shown, including the cell renderer for "grapher". + COLUMNS_PRODUCERS = columns_producer(min_date, max_date) + COLUMNS_PRODUCER_CHARTS = { + column: ( + { + "headerName": "Chart URL", + "headerTooltip": "URL of the chart in the grapher.", + "cellRenderer": grapher_slug_jscode, + } + if column == "grapher" + else COLUMNS_PRODUCERS[column] + ) + for column in ["producer", "renders_custom", "renders_365d", "renders_30d", "grapher"] + } + # Configure and display the second table. 
+    # Apply column configurations directly from the dictionary.
+    for column, config in COLUMNS_PRODUCER_CHARTS.items():
+        gb2.configure_column(column, **config)
+
+    # Configure pagination with dynamic page size.
+    gb2.configure_pagination(paginationAutoPageSize=False, paginationPageSize=20)
+    grid_options2 = gb2.build()
+
+    # Display the grid.
+    AgGrid(
+        data=df,
+        gridOptions=grid_options2,
+        height=500,
+        width="100%",
+        fit_columns_on_grid_load=True,
+        allow_unsafe_jscode=True,
+        theme="streamlit",
+        # excel_export_mode=ExcelExportMode.MANUAL,  # Doesn't work?
+    )
+
+
+def prepare_summary(
+    df_top_10_total_views, producers_selected, total_views, average_daily_views, min_date, max_date
+) -> str:
+    """Prepare the summary shown at the end of the app."""
+    # Prepare the total number of views.
+    total_views_str = f"{total_views:9,}"
+    # Prepare the average daily views.
+    average_views_str = f"{round(average_daily_views):9,}"
+    # Prepare a summary of the top 10 charts to be copy-pasted.
+    if len(producers_selected) == 0:
+        producers_selected_str = "all producers"
+    elif len(producers_selected) == 1:
+        producers_selected_str = producers_selected[0]
+    else:
+        producers_selected_str = ", ".join(producers_selected[:-1]) + " and " + producers_selected[-1]
+    # NOTE: I tried .to_string() and .to_markdown() and couldn't find a way to keep a meaningful format.
+    df_summary_str = ""
+    for _, row in df_top_10_total_views.sort_values("renders", ascending=False).iterrows():
+        df_summary_str += f"{row['renders']:9,}" + " - " + row["grapher"] + "\n"
+
+    # Define the content to copy.
+    summary = f"""\
+Analytics of charts using data by {producers_selected_str} between {min_date} and {max_date}:
+- Total number of chart views: {total_views_str}
+- Average daily chart views: {average_views_str}
+- Views of top performing charts:
+{df_summary_str}
+
+    """
+    return summary
+
+
+########################################################################################################################
+# RENDER
+########################################################################################################################
+
+# Streamlit app layout.
+st.title(":material/analytics: Producer analytics")
+st.markdown("Explore analytics of data producers.")
+
+# SEARCH BOX
+with st.container(border=True):
+    st.markdown(
+        f"Select a custom date range (note that this metric started to be recorded on {MIN_DATE.strftime('%Y-%m-%d')})."
+    )
+
+    with st_horizontal(vertical_alignment="center"):
+        # Create input fields for minimum and maximum dates.
+        min_date = st.date_input(
+            "Select minimum date",
+            value=MIN_DATE,
+            key="min_date",
+            format="YYYY-MM-DD",
+        ).strftime(  # type: ignore
+            "%Y-%m-%d"
+        )
+        max_date = st.date_input(
+            "Select maximum date",
+            value=TODAY,
+            key="max_date",
+            format="YYYY-MM-DD",
+        ).strftime(  # type: ignore
+            "%Y-%m-%d"
+        )
+        exclude_auxiliary_steps = st.checkbox(
+            "Exclude auxiliary steps (e.g. population)",
+            False,
+            help="Exclude steps that are commonly used as auxiliary data, so they do not skew the analytics in favor of a few producers. But note that this will exclude all uses of these steps, even when they are the main datasets (not auxiliary). 
Auxiliary steps are:\n- "
+            + "\n- ".join(sorted(f"`{s}`" for s in AUXILIARY_STEPS)),
+        )
+
+if exclude_auxiliary_steps:
+    # If the user wants to exclude auxiliary steps, take the default list of excluded steps.
+    excluded_steps = AUXILIARY_STEPS
+else:
+    # Otherwise, do not exclude any steps.
+    excluded_steps = []
+
+########################################################################################################################
+# 1/ PRODUCER ANALYTICS: Display main table, with analytics per producer.
+# Allow the user to select a subset of producers.
+########################################################################################################################
+st.header("Analytics by producer")
+st.markdown(
+    "Total number of charts and chart views for each producer. Producers selected in this table will be used to filter the producer-charts table below."
+)
+
+# Load table content and select only columns to be shown.
+with st.spinner("Loading producer data. We are accessing various databases. This can take a few seconds..."):
+    df_producers = get_producer_analytics_per_producer(
+        min_date=min_date, max_date=max_date, excluded_steps=excluded_steps
+    )
+
+# Prepare and display the grid table with producer analytics.
+producers_selected = show_producers_grid(
+    df_producers=df_producers,
+    min_date=min_date,
+    max_date=max_date,
+)
+
+########################################################################################################################
+# 2/ CHART ANALYTICS: Display a chart with the total number of daily views, and the daily views of the top performing charts.
+########################################################################################################################
+st.header("Analytics by chart")
+st.markdown("Number of views for each chart that uses data by the selected producers.")
+
+# Load detailed analytics per producer-chart.
+with st.spinner("Loading chart data. This can take a few seconds..."):
+    df_producer_charts = get_producer_analytics_per_chart(
+        min_date=min_date, max_date=max_date, excluded_steps=excluded_steps
+    )
+
+# Get the selected producers from the first table.
+if len(producers_selected) == 0:
+    # If no producers are selected, show all producer-charts.
+    df_producer_charts_filtered = df_producer_charts
+else:
+    # Filter producer-charts by selected producers.
+    df_producer_charts_filtered = df_producer_charts[df_producer_charts["producer"].isin(producers_selected)]
+
+# Show chart with chart analytics, and get some summary data.
+total_views, average_daily_views, df_top_10_total_views = plot_chart_analytics(df_producer_charts_filtered)
+
+# Show the table.
+show_producer_charts_grid(df_producer_charts_filtered)
+
+########################################################################################################################
+# 3/ SUMMARY: Display a summary to be shared with the data producer.
+########################################################################################################################
+
+# Prepare the summary to be copy-pasted.
+summary = prepare_summary(
+    df_top_10_total_views=df_top_10_total_views,
+    producers_selected=producers_selected,
+    total_views=total_views,
+    average_daily_views=average_daily_views,
+    min_date=min_date,
+    max_date=max_date,
+)
+
+# Display the content.
+st.markdown(
+    """## Summary for data producers
+
+For now, to share analytics with a data producer you can do any of the following:
+- **Table export**: Right-click on a cell in the table above and export it as a CSV or Excel file.
+- **Chart export**: Click on the camera icon on the top right of the chart to download the chart as a PNG.
+- **Copy summary**: Click on the upper right corner of the box below to copy the summary to the clipboard.
+"""
+)
+st.code(summary, language="text")
diff --git a/apps/wizard/app_pages/similar_charts/app.py b/apps/wizard/app_pages/similar_charts/app.py
index 279a8b2c6c1e..b5bad2b8c3af 100644
--- a/apps/wizard/app_pages/similar_charts/app.py
+++ b/apps/wizard/app_pages/similar_charts/app.py
@@ -1,17 +1,38 @@
+import datetime as dt
 import random
+from enum import Enum
+from typing import Dict, List, Optional, get_args
 
 import pandas as pd
 import streamlit as st
+from sqlalchemy.orm import Session
 from structlog import get_logger
 
 from apps.wizard.app_pages.similar_charts import data, scoring
 from apps.wizard.utils import embeddings as emb
+from apps.wizard.utils import start_profiler
+from apps.wizard.utils.cached import get_grapher_user
 from apps.wizard.utils.components import Pagination, st_horizontal, st_multiselect_wider, url_persist
 from etl.config import OWID_ENV
+from etl.db import get_engine
+from etl.git_helpers import log_time
+from etl.grapher import model as gm
+
+PROFILER = start_profiler()
+
+ITEMS_PER_PAGE = 20
 
 # Initialize log.
 log = get_logger()
 
+# Database engine.
+engine = get_engine()
+
+# Get reviewer's name.
+grapher_user = get_grapher_user(st.context.headers.get("X-Forwarded-For"))
+assert grapher_user, "User not found"
+reviewer = grapher_user.fullName
+
 # PAGE CONFIG
 st.set_page_config(
     page_title="Wizard: Similar Charts",
@@ -20,16 +41,35 @@
 )
 
 ########################################################################################################################
-# FUNCTIONS
+# CONSTANTS & FUNCTIONS
 ########################################################################################################################
 
-
-@st.cache_data(show_spinner=False, persist="disk")
+DISPLAY_STATE_OPTIONS = {
+    "good": {
+        "label": "Good",
+        "color": "green",
+        "icon": "✅",
+    },
+    "bad": {
+        "label": "Bad",
+        "color": "red",
+        "icon": "❌",
+    },
+    "neutral": {
+        "label": "Neutral",
+        "color": "gray",
+        "icon": "⏳",
+    },
+}
+
+CHART_LABELS = get_args(gm.RELATED_CHART_LABEL)
+
+
+@log_time
+@st.cache_data(show_spinner=False)
 def get_charts() -> list[data.Chart]:
     with st.spinner("Loading charts..."):
-        # Get charts from the database..
         df = data.get_raw_charts()
-
         charts = df.to_dict(orient="records")
 
         ret = []
@@ -40,40 +80,40 @@
     return ret
 
 
-def st_chart_info(chart: data.Chart) -> None:
+@log_time
+@st.cache_data(show_spinner=False)
+def get_coviews() -> pd.Series:
+    # Load coviews for all charts for the past 365 days.
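+    # Coviews count the sessions in which a visitor opened both charts (see
+    # data.get_coviews_sessions); min_sessions=3 drops rare, noisy pairs.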
+ with st.spinner("Loading coviews..."): + return data.get_coviews_sessions(after_date=str(dt.date.today() - dt.timedelta(days=365)), min_sessions=3) + + +def st_chart_info(chart: data.Chart, show_coviews=True) -> None: + """Displays general info about a single chart.""" chart_url = OWID_ENV.chart_site(chart.slug) - title = f"#### [{chart.title}]({chart_url})" + # title = f"#### [{chart.title}]({chart_url})" + title = f"[{chart.title}]({chart_url})" if chart.gpt_reason: title += " 🤖" - st.markdown(title) + st.subheader(title, anchor=chart.slug) st.markdown(f"Slug: {chart.slug}") st.markdown(f"Subtitle: {chart.subtitle}") st.markdown(f"Tags: **{', '.join(chart.tags)}**") st.markdown(f"Pageviews: **{chart.views_365d}**") + if show_coviews: + st.markdown(f"Coviews: **{chart.coviews}**") def st_chart_scores(chart: data.Chart, sim_components: pd.DataFrame) -> None: - st.markdown(f"#### Similarity: {chart.similarity:.0%}") + """Displays scoring info (score, breakdown table) for a single chart.""" + st.markdown(f"#### Score: {chart.similarity:.0%}") st.table(sim_components.loc[chart.chart_id].to_frame("score").style.format("{:.0%}")) if chart.gpt_reason: st.markdown(f"**GPT Diversity Reason**:\n{chart.gpt_reason}") -def st_display_chart( - chart: data.Chart, - sim_components: pd.DataFrame = pd.DataFrame(), -) -> None: - with st.container(border=True): - col1, col2 = st.columns(2) - with col1: - st_chart_info(chart) - with col2: - st_chart_scores(chart, sim_components) - - def split_input_string(input_string: str) -> tuple[str, list[str], list[str]]: """Break input string into query, includes and excludes.""" - # Break input string into query, includes and excludes query = [] includes = [] excludes = [] @@ -88,28 +128,212 @@ def split_input_string(input_string: str) -> tuple[str, list[str], list[str]]: return " ".join(query), includes, excludes -@st.cache_data(show_spinner=False, max_entries=1) +@log_time +@st.cache_data( + show_spinner=False, + max_entries=1, + hash_funcs={list[data.Chart]: lambda charts: len(charts)}, +) def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: with st.spinner("Loading model..."): scoring_model = scoring.ScoringModel(emb.get_model()) - scoring_model.fit(charts) + with st.spinner("Fitting model..."): + scoring_model.fit(charts) return scoring_model ######################################################################################################################## -# Fetch all data indicators. +# NEW COMPONENTS +######################################################################################################################## + + +class RelatedChartDisplayer: + """ + Encapsulates the logic for displaying and labeling a related chart, + including any database updates and UI feedback. + """ + + def __init__(self, engine, chosen_chart: data.Chart, sim_components: pd.DataFrame): + self.engine = engine + self.chosen_chart = chosen_chart + self.sim_components = sim_components + + def display( + self, + chart: data.Chart, + label: gm.RELATED_CHART_LABEL = "neutral", + ) -> None: + """ + Renders the chart block (info, scores, and label radio). + Also hooks up the callback for label changes. 
+ """ + with st.container(): + col1, col2 = st.columns(2) + with col1: + st_chart_info(chart) + st.radio( + label="**Review Related Chart**", + key=f"label-{chart.chart_id}", + options=CHART_LABELS, + index=CHART_LABELS.index(label), + horizontal=True, + format_func=lambda x: f":{DISPLAY_STATE_OPTIONS[x]['color']}-background[{DISPLAY_STATE_OPTIONS[x]['label']}]", + on_change=self._push_status, + kwargs={"chart": chart}, + ) + with col2: + st_chart_scores(chart, self.sim_components) + + def _push_status(self, chart: data.Chart) -> None: + """ + Callback: triggered on label change. Saves to the DB and + shows an appropriate toast. + """ + label: gm.RELATED_CHART_LABEL = st.session_state[f"label-{chart.chart_id}"] + + with Session(self.engine) as session: + gm.RelatedChart( + chartId=self.chosen_chart.chart_id, + relatedChartId=chart.chart_id, + label=label, + reviewer=reviewer, + ).upsert(session) + session.commit() + + # Notify user + with st.spinner(): + match label: + case "good": + st.toast(":green[Recommendation labeled as **good**]", icon="✅") + case "bad": + st.toast(":red[Recommendation labeled as **bad**]", icon="❌") + case "neutral": + st.toast("**Resetting** recommendation to neutral", icon=":material/restart_alt:") + + +def st_related_charts_table(related_charts: list[gm.RelatedChart], chart_map: dict[int, data.Chart]) -> None: + """ + Shows a "matrix" of reviews in a pivoted table using st.dataframe: + - Row per related chart + - Columns for slug, title, views_365d, link, and one column per reviewer (icon) + - Hides chart_id + """ + if not related_charts: + st.info("No related charts have been selected yet.") + return + + # 1) Convert the list of RelatedChart objects to a DataFrame + rows = [] + for rc in related_charts: + c = chart_map.get(rc.relatedChartId) + if not c: + # Skip if missing + continue + rows.append( + { + "chart_id": c.chart_id, + "slug": c.slug, + "title": c.title, + "views_365d": c.views_365d, + "coviews": c.coviews, + "score": c.similarity, + "reviewer": rc.reviewer, + "label": rc.label, + } + ) + df = pd.DataFrame(rows) + + # Exclude neutral reviews + df = df[df["label"] != "neutral"] + + # 2) Pivot so that each reviewer is a column, with the label as the cell value + pivot_df = df.pivot( + index=["chart_id", "slug", "title", "views_365d", "coviews", "score"], columns="reviewer", values="label" + ).fillna("neutral") + + reviewer_cols = list(pivot_df.columns) + + # 3) Map each label (good/bad/neutral) to an icon + def label_to_icon(label: str) -> str: + if label == "neutral": + return "" + else: + return DISPLAY_STATE_OPTIONS.get(label, DISPLAY_STATE_OPTIONS["neutral"])["icon"] + + pivot_df = pivot_df.applymap(label_to_icon) + + # 4) Flatten the multi-index so 'chart_id', 'slug', etc. 
become columns + pivot_df.reset_index(inplace=True) + + # 5) Remove chart_id (we're not interested in displaying it) + pivot_df.drop(columns=["chart_id"], inplace=True) + + # 6) Create a new column "link" + pivot_df["link"] = pivot_df["slug"].apply(lambda x: OWID_ENV.chart_site(x)) + # TODO: jump to anchor + # pivot_df["link"] = pivot_df["slug"].apply(lambda x: f"#{x}") + + # 7) Build the final column order + final_cols = ["slug", "title", "views_365d", "coviews", "score"] + reviewer_cols + ["link"] + + pivot_df = pivot_df[final_cols].sort_values(["score"], ascending=False) + + # 8) Configure columns for st.dataframe + column_config = { + # The link column becomes a clickable link + "link": st.column_config.LinkColumn( + "Open", + # display_text="Jump to detail", + display_text="Open", + ) + } + # You could also configure text columns or numeric columns (like "views_365d"). + + styled_df = pivot_df.style.format("{:.0%}", subset=["score"]) + + # 9) Show the result using st.dataframe + st.dataframe( + styled_df, + use_container_width=True, + hide_index=True, + column_config=column_config, + ) + + +def add_coviews_to_charts(charts: List[data.Chart], chosen_chart: data.Chart, coviews: pd.Series) -> List[data.Chart]: + try: + chosen_chart_coviews = coviews.loc[chosen_chart.slug].to_dict() + except KeyError: + chosen_chart_coviews = {} + + for c in charts: + c.coviews = chosen_chart_coviews.get(c.slug, 0) + + return charts + + +######################################################################################################################## +# FETCH DATA & MODEL +######################################################################################################################## + charts = get_charts() -# Get scoring model. +coviews = get_coviews() + scoring_model = get_and_fit_model(charts) +# Re-set charts if the model comes from cache +scoring_model.charts = charts -######################################################################################################################## +# Build a chart map for quick lookups by chart_id +chart_map = {chart.chart_id: chart for chart in charts} + +# Pick top 100 charts by pageviews. +top_100_charts: list[data.Chart] = sorted(charts, key=lambda x: x.views_365d, reverse=True)[:100] # type: ignore ######################################################################################################################## # RENDER ######################################################################################################################## -# Streamlit app layout. 
st.title(":material/search: Similar charts") col1, col2 = st.columns(2) @@ -117,6 +341,7 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: st_multiselect_wider() with st_horizontal(): random_chart = st.button("Random chart", help="Get a random chart.") + random_100_chart = st.button("Random top 100 chart", help="Get a random chart from the top 100 charts.") # Filter indicators diversity_gpt = url_persist(st.checkbox, default=True)( @@ -127,41 +352,35 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: # Random chart was pressed or no search text if random_chart or not st.query_params.get("chart_search_text"): - chart_slug = random.sample(charts, 1)[0].slug + # weighted by views + chart = random.choices(charts, weights=[c.views_365d for c in charts], k=1)[0] + # non-weighted sample + # chart = random.sample(charts, 1)[0] + st.session_state["chart_search_text"] = chart.slug + elif random_100_chart: + chart_slug = random.sample(top_100_charts, 1)[0].slug st.session_state["chart_search_text"] = chart_slug - # chart_search_text = url_persist(st.text_input)( - # key="chart_search_text", - # label="Chart slug or ID", - # placeholder="Type something...", - # ) - + # Dropdown select for chart. chart_search_text = url_persist(st.selectbox)( "Select a chart", key="chart_search_text", options=[c.slug for c in charts], ) - # Advanced expander. + # Advanced options st.session_state.sim_charts_expander_advanced_options = st.session_state.get( "sim_charts_expander_advanced_options", False ) - - # Weights for each score with st.expander("Advanced options", expanded=st.session_state.sim_charts_expander_advanced_options): - # Add text area for system prompt system_prompt = url_persist(st.text_area, default=scoring.DEFAULT_SYSTEM_PROMPT)( "GPT prompt for selecting diverse results", key="gpt_system_prompt", height=150, ) - for score_name in ["title", "subtitle", "tags", "pageviews", "share_indicator"]: - # For some reason, if the slider minimum value is zero, streamlit raises an error when the slider is - # dragged to the minimum. Set it to a small, non-zero number. 
+ for score_name in ["title", "subtitle", "tags", "share_indicator", "pageviews_score", "coviews_score"]: key = f"w_{score_name}" - - # Set default values if key not in st.session_state: st.session_state[key] = scoring.DEFAULT_WEIGHTS[score_name] @@ -169,68 +388,120 @@ def get_and_fit_model(charts: list[data.Chart]) -> scoring.ScoringModel: f"Weight for {score_name} score", min_value=1e-9, max_value=1.0, - # step=0.001, key=key, ) - scoring_model.weights[score_name] = st.session_state[key] - -# Find a chart based on inputs +# Find a chart chosen_chart = next( (chart for chart in charts if chart.slug == chart_search_text or str(chart.chart_id) == chart_search_text), None, ) if not chosen_chart: st.error(f"Chart with slug {chart_search_text} not found.") + st.stop() - # # Find a chart by title - # chart_id = scoring_model.similar_chart_by_title(chart_search_text) - # chosen_chart = next((chart for chart in charts if chart.chart_id == chart_id), None) - -assert chosen_chart - -# Display chosen chart -with col1: - st_chart_info(chosen_chart) +# Add coviews +charts = add_coviews_to_charts(charts, chosen_chart, coviews) +# Load "official" related charts from DB +with Session(engine) as session: + related_charts_db = gm.RelatedChart.load(session, chart_id=chosen_chart.chart_id) -# Horizontal divider -st.markdown("---") - +# Compute similarity for all charts sim_dict = scoring_model.similarity(chosen_chart) sim_components = scoring_model.similarity_components(chosen_chart) -for chart in charts: - chart.similarity = sim_dict[chart.chart_id] +# Assign similarity +for c in charts: + c.similarity = sim_dict[c.chart_id] +# Sort by similarity sorted_charts = sorted(charts, key=lambda x: x.similarity, reverse=True) # type: ignore -# Postprocess charts with GPT and prioritize diversity +# Add reviews to top charts by similarity +for c in sorted_charts[:5]: + if c.chart_id == chosen_chart.chart_id: + continue + # Add to related charts + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="🔮", + ) + ) + +# Possibly re-rank with GPT for diversity if diversity_gpt: with st.spinner("Diversifying chart results..."): slugs_to_reasons = scoring.gpt_diverse_charts(chosen_chart, sorted_charts, system_prompt=system_prompt) - for chart in sorted_charts: - if chart.slug in slugs_to_reasons: - chart.gpt_reason = slugs_to_reasons[chart.slug] + for c in sorted_charts: + if c.slug in slugs_to_reasons: + c.gpt_reason = slugs_to_reasons[c.slug] + + # Add to related charts + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="🤖", + ) + ) + + +# Add coviews reviewer +for chart_id in sim_components.sort_values("coviews_score", ascending=False).index[:5]: + c = chart_map[chart_id] + # Don't recommend zero coviews + if c.coviews == 0: + continue + related_charts_db.append( + gm.RelatedChart( + chartId=chosen_chart.chart_id, + relatedChartId=c.chart_id, + label="good", + reviewer="👥", + ) + ) - # Put charts that are diverse at the top - # sorted_charts = sorted(sorted_charts, key=lambda x: (x.gpt_reason is not None, x.similarity), reverse=True) +# Display chosen chart +with col1: + st_chart_info(chosen_chart, show_coviews=False) + +# Divider +st.markdown("---") +st.header("Reviewed Related Charts") +st_related_charts_table(related_charts_db, chart_map) + +# Divider +st.markdown("---") +st.header("Recommended Related Charts") + +# Create our new chart display component 
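+# NOTE: Reviews are upserted into related_charts keyed by (chartId, relatedChartId,
+# reviewer), so re-labeling a chart replaces that reviewer's previous label.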
+displayer = RelatedChartDisplayer(engine, chosen_chart, sim_components) # Use pagination -items_per_page = 20 pagination = Pagination( - items=sorted_charts, - items_per_page=items_per_page, + items=sorted_charts[:100], + items_per_page=ITEMS_PER_PAGE, pagination_key=f"pagination-di-search-{chosen_chart.slug}", ) - -if len(charts) > items_per_page: +if len(sorted_charts) > ITEMS_PER_PAGE: pagination.show_controls(mode="bar") -# Show items (only current page) +# Display only the current page for item in pagination.get_page_items(): - # Don't show the chosen chart if item.slug == chosen_chart.slug: continue - st_display_chart(item, sim_components) + + # Check if we have a DB label for the related chart from us + labels = [r.label for r in related_charts_db if r.relatedChartId == item.chart_id and r.reviewer == reviewer] + label = labels[0] if labels else "neutral" + + # Use the new component to display + displayer.display(chart=item, label=label) # type: ignore + +PROFILER.stop() diff --git a/apps/wizard/app_pages/similar_charts/data.py b/apps/wizard/app_pages/similar_charts/data.py index 3690aa6f9aa4..b0431f1a4f9d 100644 --- a/apps/wizard/app_pages/similar_charts/data.py +++ b/apps/wizard/app_pages/similar_charts/data.py @@ -1,11 +1,14 @@ from dataclasses import dataclass from datetime import datetime -from typing import Optional +from typing import Literal, Optional, get_args import pandas as pd +from apps.wizard.utils.db import read_gbq from apps.wizard.utils.embeddings import Doc +from etl.config import memory from etl.db import read_sql +from etl.grapher import model as gm @dataclass @@ -21,6 +24,7 @@ class Chart(Doc): views_14d: Optional[int] = None views_365d: Optional[int] = None gpt_reason: Optional[str] = None + coviews: Optional[int] = None def get_raw_charts() -> pd.DataFrame: @@ -66,3 +70,33 @@ def get_raw_charts() -> pd.DataFrame: assert df["chart_id"].nunique() == df.shape[0] return df + + +@memory.cache +def get_coviews_sessions(after_date: str, min_sessions: int = 5) -> pd.Series: + """ + Count of sessions in which a pair of URLs are both visited, aggregated daily + + note: this is a nondirectional network. url1 and url2 are string sorted and + do not indicate anything about whether url1 was visited before/after url2 in + the session. + """ + query = f""" + SELECT + REGEXP_EXTRACT(url1, r'grapher/([^/]+)') AS slug1, + REGEXP_EXTRACT(url2, r'grapher/([^/]+)') AS slug2, + SUM(sessions_coviewed) AS total_sessions + FROM prod_google_analytics4.coviews_by_day_page + WHERE day >= '{after_date}' + AND url1 LIKE 'https://ourworldindata.org/grapher%' + AND url2 LIKE 'https://ourworldindata.org/grapher%' + GROUP BY slug1, slug2 + HAVING total_sessions >= {min_sessions} + """ + df = read_gbq(query, project_id="owid-analytics") + + # concat with reversed slug1 and slug2 + df = pd.concat([df, df.rename(columns={"slug1": "slug2", "slug2": "slug1"})]) + + # set index for faster lookups + return df.set_index(["slug1", "slug2"]).sort_index()["total_sessions"] diff --git a/apps/wizard/app_pages/similar_charts/scoring.py b/apps/wizard/app_pages/similar_charts/scoring.py index 7c60851d2f10..99de8fa31c80 100644 --- a/apps/wizard/app_pages/similar_charts/scoring.py +++ b/apps/wizard/app_pages/similar_charts/scoring.py @@ -21,11 +21,12 @@ # These are the default thresholds for the different scores. 
 DEFAULT_WEIGHTS = {
-    "title": 0.4,
+    "title": 0.3,
     "subtitle": 0.1,
     "tags": 0.1,
-    "pageviews": 0.3,
     "share_indicator": 0.1,
+    "pageviews_score": 0.3,
+    "coviews_score": 0.1,
 }
 
 PREFIX_SYSTEM_PROMPT = """
@@ -62,7 +63,7 @@
         self.model = model
         self.weights = weights or DEFAULT_WEIGHTS.copy()
 
-    def fit(self, charts: list[Chart]):
+    def fit(self, charts: list[Chart]) -> None:
         self.charts = charts
 
         # Get embeddings for title and subtitle
@@ -121,8 +122,9 @@
                     "subtitle": subtitle_scores[i],
                     # score 1 if there is at least one tag in common, 0 otherwise
                     "tags": float(bool(set(c.tags) & set(chart.tags))),
-                    "pageviews": c.views_365d or 0,
                     "share_indicator": float(c.chart_id in charts_sharing_indicator),
+                    "pageviews": c.views_365d or 0,
+                    "coviews": c.coviews or 0,
                 }
             )
 
@@ -134,11 +136,8 @@
         if chart.subtitle == "":
             ret["subtitle"] = 0
 
-        # Scale pageviews to [0, 1]
-        ret["pageviews"] = np.log(ret["pageviews"] + 1)
-        ret["pageviews"] = (ret["pageviews"] - ret["pageviews"].min()) / (
-            ret["pageviews"].max() - ret["pageviews"].min()
-        )
+        ret["pageviews_score"] = score_pageviews(ret["pageviews"])
+        ret["coviews_score"] = score_coviews(ret["coviews"], ret["pageviews"])
 
         # Get weights and normalize them
         w = pd.Series(self.weights)
@@ -148,13 +147,30 @@
         ret = (ret * w).fillna(0)
 
         # Reorder
-        ret = ret[["title", "subtitle", "tags", "share_indicator", "pageviews"]]
+        ret = ret[["title", "subtitle", "tags", "share_indicator", "pageviews_score", "coviews_score"]]
 
         log.info("similarity_components.end", t=time.time() - t)
 
         return ret
 
 
+def score_pageviews(pageviews: pd.Series) -> pd.Series:
+    """Log-transform pageviews and scale them to [0, 1]: the chart with the most pageviews gets score 1 and
+    the chart with the fewest pageviews gets score 0.
+    """
+    pageviews = np.log(pageviews + 1)  # type: ignore
+    return (pageviews - pageviews.min()) / (pageviews.max() - pageviews.min())
+
+
+def score_coviews(coviews: pd.Series, pageviews: pd.Series, lam: float = 200000) -> pd.Series:
+    """Score coviews. First, take the ratio of coviews to pageviews, adding the regularization term `lam` to
+    pageviews to penalize charts with high pageviews that tend to show up everywhere despite not being very
+    relevant. Then, normalize the score to [0, 1].
+    """
+    p = coviews / (pageviews + lam)
+    return (p - p.min()) / (p.max() - p.min())
+
+
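To see what the `lam` regularization in `score_coviews` does, here is a toy calculation with invented numbers (not part of the patch):

```python
import pandas as pd

coviews = pd.Series([500, 500])  # same number of coviews...
pageviews = pd.Series([10_000, 1_000_000])  # ...but very different traffic
lam = 200_000

p = coviews / (pageviews + lam)
score = (p - p.min()) / (p.max() - p.min())
print(score.tolist())  # [1.0, 0.0]: equal coviews count for more on the low-traffic chart
```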
+ """ + p = coviews / (pageviews + lam) + return (p - p.min()) / (p.max() - p.min()) + + @st.cache_data(show_spinner=False, persist="disk") def gpt_diverse_charts( chosen_chart: Chart, _charts: list[Chart], _n: int = 30, system_prompt=DEFAULT_SYSTEM_PROMPT diff --git a/apps/wizard/utils/cached.py b/apps/wizard/utils/cached.py index 075f87b4722d..fce0fb1212c3 100644 --- a/apps/wizard/utils/cached.py +++ b/apps/wizard/utils/cached.py @@ -11,7 +11,7 @@ import etl.grapher.model as gm from apps.utils.map_datasets import get_grapher_changes -from etl.config import OWID_ENV, OWIDEnv +from etl.config import ENV_GRAPHER_USER_ID, OWID_ENV, OWIDEnv from etl.db import get_engine from etl.git_helpers import get_changed_files from etl.grapher import io as gio @@ -215,8 +215,14 @@ def get_tailscale_ip_to_user_map(): @st.cache_data -def get_grapher_user_id(user_ip: str) -> Optional[int]: - """Get the Grapher user ID associated with the given Tailscale IP address.""" +def get_grapher_user(user_ip: Optional[str]) -> Optional[gm.User]: + """Get the Grapher user associated with the given Tailscale IP address.""" + # Use local env variable if user_ip is not provided (when on localhost) + if user_ip is None: + with Session(get_engine()) as session: + assert ENV_GRAPHER_USER_ID, "ENV_GRAPHER_USER_ID is not set!" + return gm.User.load_user(session, id=int(ENV_GRAPHER_USER_ID)) + # Get Tailscale IP-to-User mapping ip_to_user_map = get_tailscale_ip_to_user_map() @@ -227,9 +233,4 @@ def get_grapher_user_id(user_ip: str) -> Optional[int]: return None with Session(get_engine()) as session: - grapher_user = gm.User.load_user(session, github_user_name) - - if grapher_user: - return grapher_user.id - else: - return None + return gm.User.load_user(session, github_user_name) diff --git a/apps/wizard/utils/components.py b/apps/wizard/utils/components.py index 481e274bc250..58ed77504397 100644 --- a/apps/wizard/utils/components.py +++ b/apps/wizard/utils/components.py @@ -408,7 +408,7 @@ def st_toast_success(message: str) -> None: def update_query_params(key): def _update_query_params(): value = st.session_state[key] - if value: + if value is not None: st.query_params.update({key: value}) else: st.query_params.pop(key, None) @@ -416,6 +416,10 @@ def _update_query_params(): return _update_query_params +def remove_query_params(key): + st.query_params.pop(key, None) + + def url_persist(component: Any, default: Any = None) -> Any: """Wrapper around streamlit components that persist values in the URL query string. 
@@ -451,7 +455,7 @@ def _persist(*args, **kwargs): if default is not None and key not in st.query_params: st.session_state[key] = default - if not st.session_state.get(key): + if st.session_state.get(key) is None: # Obtain params from query string params = _get_params(repeated, convert_to_bool, key, kwargs) @@ -464,8 +468,12 @@ def _persist(*args, **kwargs): else: # Set the value in query params, but only if it isn't default - if default is None or st.session_state[key] != default: + if default is None: + update_query_params(key)() + elif st.session_state.get(key) != default: update_query_params(key)() + elif st.session_state[key] == default: + remove_query_params(key) kwargs["on_change"] = update_query_params(key) diff --git a/apps/wizard/utils/db.py b/apps/wizard/utils/db.py index 1041cdc15d2e..d56f06cfe780 100644 --- a/apps/wizard/utils/db.py +++ b/apps/wizard/utils/db.py @@ -14,13 +14,15 @@ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple import pandas as pd +import pandas_gbq import streamlit as st import structlog +from google.oauth2 import service_account from sqlalchemy import text from sqlalchemy.orm import Session from apps.wizard.utils.paths import STREAMLIT_SECRETS, WIZARD_DB -from etl.config import OWID_ENV, OWIDEnv +from etl.config import GOOGLE_APPLICATION_CREDENTIALS, OWID_ENV, OWIDEnv from etl.db import get_engine, read_sql, to_sql from etl.grapher.model import Anomaly @@ -299,3 +301,13 @@ def simplify_varmap(df): mapping_no_identical = {k: v for k, v in mapping.items() if k != v} return mapping_no_identical + + +def read_gbq(*args, **kwargs) -> pd.DataFrame: + if GOOGLE_APPLICATION_CREDENTIALS: + # Use service account + credentials = service_account.Credentials.from_service_account_file(GOOGLE_APPLICATION_CREDENTIALS) + return pandas_gbq.read_gbq(*args, **kwargs, credentials=credentials) # type: ignore + else: + # Use browser authentication. 
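+        # NOTE: Without explicit credentials, pandas_gbq falls back to an interactive
+        # OAuth flow (opens a browser) and caches the resulting user token locally.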
+ return pandas_gbq.read_gbq(*args, **kwargs) # type: ignore diff --git a/etl/config.py b/etl/config.py index 95aafcd7fc34..e63a957a8614 100644 --- a/etl/config.py +++ b/etl/config.py @@ -21,15 +21,18 @@ import pandas as pd import structlog from dotenv import dotenv_values, load_dotenv +from joblib import Memory from sqlalchemy.engine import Engine from sqlalchemy.orm import Session -from etl.paths import BASE_DIR +from etl.paths import BASE_DIR, CACHE_DIR log = structlog.get_logger() ENV_FILE = Path(env.get("ENV_FILE", BASE_DIR / ".env")) +memory = Memory(CACHE_DIR, verbose=0) + def get_username(): return pwd.getpwuid(os.getuid())[0] @@ -104,6 +107,10 @@ def get_container_name(branch_name): DB_USER = env.get("DB_USER", "root") DB_PASS = env.get("DB_PASS", "") +# save original GRAPHER_USER_ID from env for later use, because it'll be overwritten when +# we use staging servers +ENV_GRAPHER_USER_ID = GRAPHER_USER_ID + DB_IS_PRODUCTION = DB_NAME == "live_grapher" # Special ENV file with access to production DB (read-only), used by chart-diff diff --git a/etl/grapher/model.py b/etl/grapher/model.py index de1f1259ee0c..18cdda3b3088 100644 --- a/etl/grapher/model.py +++ b/etl/grapher/model.py @@ -310,8 +310,15 @@ class User(Base): lastSeen: Mapped[Optional[datetime]] = mapped_column(DateTime) @classmethod - def load_user(cls, session: Session, github_username: str) -> Optional["User"]: - return session.scalars(select(cls).where(cls.githubUsername == github_username)).one_or_none() + def load_user( + cls, session: Session, id: Optional[int] = None, github_username: Optional[str] = None + ) -> Optional["User"]: + if id: + return session.scalars(select(cls).where(cls.id == id)).one() + elif github_username: + return session.scalars(select(cls).where(cls.githubUsername == github_username)).one() + else: + raise ValueError("Either id or github_username must be provided") class ChartRevisions(Base): @@ -1810,6 +1817,64 @@ def get_conflict_batch( return conflicts +RELATED_CHART_LABEL = Literal["good", "bad", "neutral"] + + +class RelatedChart(Base): + __tablename__ = "related_charts" + __table_args__ = ( + ForeignKeyConstraint( + ["chartId"], ["charts.id"], ondelete="CASCADE", onupdate="CASCADE", name="related_charts_ibfk_1" + ), + ForeignKeyConstraint( + ["relatedChartId"], ["charts.id"], ondelete="CASCADE", onupdate="CASCADE", name="related_charts_ibfk_2" + ), + # Existing Index on chartId + Index("idx_related_charts_chartId", "chartId"), + # 1) Unique index on (chartId, relatedChartId, reviewer) + Index("uq_chartId_relatedChartId_reviewer", "chartId", "relatedChartId", "reviewer", unique=True), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) + chartId: Mapped[int] = mapped_column(Integer, nullable=False) + relatedChartId: Mapped[int] = mapped_column(Integer, nullable=False) + label: Mapped[RELATED_CHART_LABEL] = mapped_column(VARCHAR(255), nullable=False) + reviewer: Mapped[Optional[str]] = mapped_column(VARCHAR(255)) + reason: Mapped[Optional[str]] = mapped_column(TEXT, default=None) + updatedAt: Mapped[datetime] = mapped_column(DateTime, default=func.utc_timestamp()) + + @classmethod + def load(cls, session: Session, chart_id: Optional[int] = None) -> list["RelatedChart"]: + if chart_id is None: + records = session.scalars(select(cls)).all() + else: + records = session.scalars(select(cls).where(cls.chartId == chart_id)).all() + return list(records) + + def upsert( + self, + session: Session, + ) -> "RelatedChart": + cls = self.__class__ + + ds = session.scalars( + 
select(cls).where( + cls.chartId == self.chartId, cls.relatedChartId == self.relatedChartId, cls.reviewer == self.reviewer + ) + ).one_or_none() + + if not ds: + ds = self + else: + ds.label = self.label + ds.reason = self.reason + ds.updatedAt = func.utc_timestamp() + + session.add(ds) + session.flush() + return ds + + class MultiDimDataPage(Base): __tablename__ = "multi_dim_data_pages" diff --git a/tests/apps/wizard/utils/test_components.py b/tests/apps/wizard/utils/test_components.py new file mode 100644 index 000000000000..8c45abe06bd8 --- /dev/null +++ b/tests/apps/wizard/utils/test_components.py @@ -0,0 +1,98 @@ +from unittest.mock import patch + +import pytest +import streamlit as st + +from apps.wizard.utils.components import url_persist + + +@pytest.fixture +def mock_component(): + """ + A simple mock component that returns a string indicating how it was called. + This helps verify it was invoked with the expected kwargs. + """ + + def _component(*args, **kwargs): + return f"mock_component_called_with_kwargs={kwargs}" + + return _component + + +@patch.object(st, "session_state", new={}) +@patch.object(st, "query_params", new={}) +def test_url_persist_sets_default_if_no_query_param_and_no_session_state(mock_component): + """ + If there's a default and no value in either query_params or session_state, + then session_state should be set to that default. + Also, because it matches the default, nothing should be added to query_params. + """ + wrapped = url_persist(mock_component, default="my_default") + result = wrapped(key="test_key") + + # Ensure the component was called, and session_state is set to default + assert "mock_component_called_with_kwargs" in result + assert st.session_state["test_key"] == "my_default" + # Because it's the default, it should NOT be in the query params + assert "test_key" not in st.query_params + + +@patch.object(st, "session_state", new={}) +@patch.object(st, "query_params", new={"test_key": "from_query"}) +def test_url_persist_uses_query_param_if_present(mock_component): + """ + If a query param exists and there's no session_state value, + session_state should adopt the query param. + """ + wrapped = url_persist(mock_component, default="my_default") + result = wrapped(key="test_key") + + # The session_state should be set from the query param + assert st.session_state["test_key"] == "from_query" + + # Component call check + assert "mock_component_called_with_kwargs" in result + + # Since 'from_query' != default, update_query_params would have been invoked, + # but the result is basically the same as the existing query_params + assert st.query_params == {"test_key": "from_query"} + + +@patch.object(st, "session_state", new={"test_key": "not_default"}) +@patch.object(st, "query_params", new={}) +def test_url_persist_when_session_state_is_non_default(mock_component): + """ + If session_state already has a non-default value, it should be used. + Then update_query_params should be called, adding that value to query_params. 
+ """ + wrapped = url_persist(mock_component, default="my_default") + result = wrapped(key="test_key") + + # Check the component was called + assert "mock_component_called_with_kwargs" in result + # Session state remains whatever was set + assert st.session_state["test_key"] == "not_default" + # Because it's non-default, update_query_params is invoked -> query_params updated + assert st.query_params == {"test_key": "not_default"} + + +@patch.object(st, "session_state", new={}) +@patch.object(st, "query_params", new={"color": "red"}) +def test_url_persist_invalid_option_raises(mock_component): + """ + If the URL contains a value not in the allowed 'options', + a ValueError should be raised by _check_options_params. + """ + + # We'll simulate a component that expects an 'options' list + def _component(*args, **kwargs): + return f"mock_component_called_with_kwargs={kwargs}" + + # We'll wrap it with url_persist, providing a set of valid options + wrapped = url_persist(_component, default=None) + + # Because "red" is not in ["blue", "green"], we expect a ValueError + with pytest.raises(ValueError) as exc_info: + wrapped(key="color", options=["blue", "green"]) + + assert "not in options" in str(exc_info.value)