diff --git a/src/lib/server/tools/index.ts b/src/lib/server/tools/index.ts
index 2cb06903586..306b1d142a3 100644
--- a/src/lib/server/tools/index.ts
+++ b/src/lib/server/tools/index.ts
@@ -19,6 +19,7 @@ import calculator from "./calculator";
 import directlyAnswer from "./directlyAnswer";
 import fetchUrl from "./web/url";
 import websearch from "./web/search";
+import weather from "./weather";
 import { callSpace, getIpToken } from "./utils";
 import { uploadFile } from "../files/uploadFile";
 import type { MessageFile } from "$lib/types/Message";
@@ -127,7 +128,7 @@ export const configTools = z
 		}))
 	)
 	// add the extra hardcoded tools
-	.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
+	.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch, weather]);
 
 export function getCallMethod(tool: Omit): BackendCall {
 	return async function* (params, ctx, uuid) {
diff --git a/src/lib/server/tools/utils.ts b/src/lib/server/tools/utils.ts
index 30cc9ae0ec4..74e209027b6 100644
--- a/src/lib/server/tools/utils.ts
+++ b/src/lib/server/tools/utils.ts
@@ -111,3 +111,47 @@ export async function extractJson(text: string): Promise {
 	}
 	return calls.flat();
 }
+
+/**
+ * Fetch the hourly `temperature_2m` forecast for a coordinate pair from the Open-Meteo forecast API.
+ *
+ * @param latitude - Value sent as the `latitude` query parameter.
+ * @param longitude - Value sent as the `longitude` query parameter.
+ * @returns The parsed JSON response body; its shape is not validated here.
+ * @throws Error if the API responds with a non-OK status.
+ */
+export async function fetchWeatherData(latitude: number, longitude: number): Promise<unknown> {
+	const response = await fetch(
+		`https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&hourly=temperature_2m`
+	);
+	if (!response.ok) {
+		throw new Error("Failed to fetch weather data");
+	}
+	return response.json();
+}
+
+/**
+ * Resolve a place name to coordinates through the Open-Meteo geocoding API.
+ * Only the first match is requested (`count=1`).
+ *
+ * @param location - Free-text place name; URL-encoded before it is placed in the query string.
+ * @returns Latitude and longitude of the first result.
+ * @throws Error if the request fails or the response contains no results.
+ */
+export async function fetchCoordinates(
+	location: string
+): Promise<{ latitude: number; longitude: number }> {
+	const response = await fetch(
+		`https://geocoding-api.open-meteo.com/v1/search?name=${encodeURIComponent(location)}&count=1`
+	);
+	if (!response.ok) {
+		throw new Error("Failed to fetch coordinates");
+	}
+	const data = await response.json();
+	// `results` may be absent altogether when nothing matches, so check it before reading `length`.
+	if (!Array.isArray(data?.results) || data.results.length === 0) {
+		throw new Error("Location not found");
+	}
+	const { latitude, longitude } = data.results[0];
+	return { latitude, longitude };
+}
diff --git a/src/lib/server/tools/weather.ts b/src/lib/server/tools/weather.ts
new file mode 100644
index 00000000000..09819d8717b
--- /dev/null
+++ b/src/lib/server/tools/weather.ts
@@ -0,0 +1,47 @@
+import type { ConfigTool } from "$lib/types/Tool";
+import { ObjectId } from "mongodb";
+import { fetchWeatherData, fetchCoordinates } from "./utils";
+
+/**
+ * Hardcoded tool that geocodes a location name and fetches the Open-Meteo forecast for it.
+ * The unvalidated forecast JSON is returned under the `weather` output key.
+ */
+const weather: ConfigTool = {
+	_id: new ObjectId("00000000000000000000000D"),
+	type: "config",
+	description: "Fetch the weather for a specified location",
+	color: "blue",
+	icon: "cloud",
+	displayName: "Weather",
+	name: "weather",
+	endpoint: null,
+	inputs: [
+		{
+			name: "location",
+			type: "str",
+			description: "The name of the location to fetch the weather for",
+			paramType: "required",
+		},
+	],
+	outputComponent: null,
+	outputComponentIdx: null,
+	showOutput: false,
+	async *call({ location }) {
+		// Validate outside the try block so this message is not replaced by the generic wrapper below.
+		if (typeof location !== "string") {
+			throw new Error("Location must be a string");
+		}
+		try {
+			const coordinates = await fetchCoordinates(location);
+			const weatherData = await fetchWeatherData(coordinates.latitude, coordinates.longitude);
+
+			return {
+				outputs: [{ weather: weatherData }],
+			};
+		} catch (error) {
+			throw new Error("Failed to fetch weather data", { cause: error });
+		}
+	},
+};
+
+export default weather;
diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts
index 39f203b1a38..9093e5be45a 100644
--- a/src/lib/server/websearch/runWebSearch.ts
+++ b/src/lib/server/websearch/runWebSearch.ts
@@ -20,7 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
 import { MetricsServer } from "../metrics";
 import { logger } from "$lib/server/logger";
 
-const MAX_N_PAGES_TO_SCRAPE = 8 as const;
+const MAX_N_PAGES_TO_SCRAPE = 15 as const;
 const MAX_N_PAGES_TO_EMBED = 5 as const;
 
 export async function* runWebSearch(
diff --git a/src/lib/server/websearch/search/generateQuery.ts b/src/lib/server/websearch/search/generateQuery.ts
index c71841a8c17..b77b82c2e10 100644
--- a/src/lib/server/websearch/search/generateQuery.ts
+++ b/src/lib/server/websearch/search/generateQuery.ts
@@ -2,6 +2,13 @@ import type { Message } from "$lib/types/Message";
 import { format } from "date-fns";
 import type { EndpointMessage } from "../../endpoints/endpoints";
 import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint";
+import { env } from "$env/dynamic/private";
+
+const DEFAULT_NUM_SEARCHES = 3;
+const parsedNumSearches = Number.parseInt(env.NUM_SEARCHES ?? "", 10);
+// Upper bound on the number of search queries returned per question. Falls back to the default when
+// NUM_SEARCHES is unset, non-numeric or not positive (a NaN limit would make slice() return []).
+const numSearches = parsedNumSearches > 0 ? parsedNumSearches : DEFAULT_NUM_SEARCHES;
 
 export async function generateQuery(messages: Message[]) {
 	const currentDate = format(new Date(), "MMMM d, yyyy");
@@ -47,8 +54,26 @@ Current Question: Where is it being hosted?`,
 			from: "assistant",
 			content: `news ${format(new Date(Date.now() - 864e5), "MMMM d, yyyy")}`,
 		},
-		{ from: "user", content: "What is the current weather in Paris?" },
-		{ from: "assistant", content: `weather in Paris ${currentDate}` },
+		{
+			from: "user",
+			content: `Current Question: My dog has been bitten, what should the gums look like so that he is healthy and when does he need an infusion?`,
+		},
+		{
+			from: "assistant",
+			content: `What healthy gums look like in dogs
+What unhealthy gums look like in dogs
+When dogs need an infusion, gum signals
+`,
+		},
+		{
+			from: "user",
+			content: `Current Question: Who is Elon Musk ?`,
+		},
+		{
+			from: "assistant",
+			content: `Elon Musk
+Elon Musk Biography`,
+		},
 		{
 			from: "user",
 			content:
@@ -62,13 +87,21 @@ Current Question: Where is it being hosted?`,
 		},
 	];
 
+	const preprompt = `You are tasked with generating precise and effective web search queries to answer the user's question. Provide a concise and specific query for Google search that will yield the most relevant and up-to-date results. Include key terms and related phrases, and avoid unnecessary words. Answer with only the queries split by linebreaks. Avoid duplicates, make the prompts as divers as you can. You are not allowed to repeat queries. Today is ${currentDate}`;
+
 	const webQuery = await generateFromDefaultEndpoint({
 		messages: convQuery,
-		preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`,
+		preprompt,
 		generateSettings: {
-			max_new_tokens: 30,
+			max_new_tokens: 128,
 		},
 	});
 
-	return webQuery.trim();
+	// One query per line. Blank lines (the model may end its answer with a newline, as one of the
+	// few-shot examples does) and duplicates are dropped before capping at the configured limit.
+	const queries = webQuery
+		.split("\n")
+		.map((line) => line.trim())
+		.filter((line) => line.length > 0);
+	return Array.from(new Set(queries)).slice(0, numSearches);
 }
diff --git a/src/lib/server/websearch/search/search.ts b/src/lib/server/websearch/search/search.ts
index 9f232a0ea98..dba61807877 100644
--- a/src/lib/server/websearch/search/search.ts
+++ b/src/lib/server/websearch/search/search.ts
@@ -25,32 +25,80 @@ export async function* search(
 	{ searchQuery: string; pages: WebSearchSource[] },
 	undefined
 > {
-	if (ragSettings && ragSettings?.allowedLinks.length > 0) {
+	const allowedLinks = ragSettings?.allowedLinks ?? [];
+
+	// Links without a `[query]` placeholder can be fetched as-is, so no search query is needed.
+	if (allowedLinks.length > 0 && !allowedLinks.some((link) => link.includes("[query]"))) {
 		yield makeGeneralUpdate({ message: "Using links specified in Assistant" });
 		return {
 			searchQuery: "",
-			pages: await directLinksToSource(ragSettings.allowedLinks).then(filterByBlockList),
+			pages: await directLinksToSource(allowedLinks).then(filterByBlockList),
 		};
 	}
 
-	const searchQuery = query ?? (await generateQuery(messages));
-	yield makeGeneralUpdate({ message: `Searching ${getWebSearchProvider()}`, args: [searchQuery] });
+	// A query supplied by the caller takes precedence; otherwise ask the model for queries.
+	const searchQueries = query != null ? [query] : await generateQuery(messages);
+
+	if (allowedLinks.length > 0) {
+		// Expand every link template once per query; the Set drops links that end up identical.
+		const expandedLinks = new Set<string>();
+		for (const searchQuery of searchQueries) {
+			const encodedQuery = encodeURIComponent(searchQuery);
+			for (const link of allowedLinks) {
+				// split/join substitutes every `[query]` occurrence, not only the first one
+				expandedLinks.add(link.split("[query]").join(encodedQuery));
+			}
+			yield makeGeneralUpdate({
+				message: `Querying provided Endpoints with`,
+				args: [searchQuery],
+			});
+		}
+		if (expandedLinks.size > 0) {
+			yield makeGeneralUpdate({ message: "Using links specified in Assistant" });
+			return {
+				searchQuery: "",
+				pages: await directLinksToSource(Array.from(expandedLinks)).then(filterByBlockList),
+			};
+		}
+	}
 
 	// handle the global and (optional) rag lists
 	if (ragSettings && ragSettings?.allowedDomains.length > 0) {
 		yield makeGeneralUpdate({ message: "Filtering on specified domains" });
 	}
+	// The site filters do not depend on the query, so they are built once for all searches.
 	const filters = buildQueryFromSiteFilters(
 		[...(ragSettings?.allowedDomains ?? []), ...allowList],
 		blockList
 	);
 
-	const searchQueryWithFilters = `${filters} ${searchQuery}`;
-	const searchResults = await searchWeb(searchQueryWithFilters).then(filterByBlockList);
+	const resultsPerQuery: WebSearchSource[][] = [];
+	for (const searchQuery of searchQueries) {
+		yield makeGeneralUpdate({
+			message: `Searching ${getWebSearchProvider()}`,
+			args: [searchQuery],
+		});
+		const searchQueryWithFilters = `${filters} ${searchQuery}`;
+		resultsPerQuery.push(await searchWeb(searchQueryWithFilters).then(filterByBlockList));
+	}
+
+	// Interleave the per-query result lists round-robin so the top hit of every query comes
+	// before any second hit. Lists may differ in length; shorter ones simply run out early.
+	// example input:  [a1,a2,a3], [b1], [c1,c2]
+	// example output: [a1,b1,c1,a2,c2,a3]
+	const longestList = Math.max(0, ...resultsPerQuery.map((results) => results.length));
+	const pages: WebSearchSource[] = [];
+	for (let rank = 0; rank < longestList; rank++) {
+		for (const results of resultsPerQuery) {
+			if (rank < results.length) {
+				pages.push(results[rank]);
+			}
+		}
+	}
 
 	return {
-		searchQuery: searchQueryWithFilters,
-		pages: searchResults,
+		searchQuery: searchQueries.join(" | "),
+		pages,
 	};
 }
 