From 0074d9b359eefd76b2b726794a1124f28c4689a3 Mon Sep 17 00:00:00 2001 From: MirrorLimit <113953454+MirrorLimit@users.noreply.github.com> Date: Sat, 30 Nov 2024 13:33:24 -0500 Subject: [PATCH 1/2] feat(community): Added Reddit integration to LangchainJS Added Reddit integration to LangchainJS to bring it closer to parity with LangchainPY as per discussion #7043 --- docs/core_docs/.gitignore | 6 + .../document_loaders/web_loaders/reddit.mdx | 13 ++ .../docs/integrations/tools/reddit.mdx | 12 ++ examples/.env.example | 3 + examples/src/document_loaders/reddit.ts | 29 +++ examples/src/tools/reddit.ts | 34 ++++ langchain-core/src/utils/async_caller.js | 128 ++++++++++++ langchain/.env.example | 3 + libs/langchain-community/langchain.config.js | 4 + .../document_loaders/tests/reddit.int.test.ts | 47 +++++ .../src/document_loaders/web/reddit.ts | 133 +++++++++++++ .../src/load/import_constants.ts | 1 + libs/langchain-community/src/tools/reddit.ts | 123 ++++++++++++ .../src/tools/tests/reddit.int.test.ts | 28 +++ libs/langchain-community/src/utils/reddit.ts | 145 ++++++++++++++ .../src/utils/tests/reddit.test.ts | 185 ++++++++++++++++++ 16 files changed, 894 insertions(+) create mode 100644 docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx create mode 100644 docs/core_docs/docs/integrations/tools/reddit.mdx create mode 100644 examples/src/document_loaders/reddit.ts create mode 100644 examples/src/tools/reddit.ts create mode 100644 langchain-core/src/utils/async_caller.js create mode 100644 libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts create mode 100644 libs/langchain-community/src/document_loaders/web/reddit.ts create mode 100644 libs/langchain-community/src/tools/reddit.ts create mode 100644 libs/langchain-community/src/tools/tests/reddit.int.test.ts create mode 100644 libs/langchain-community/src/utils/reddit.ts create mode 100644 libs/langchain-community/src/utils/tests/reddit.test.ts diff --git a/docs/core_docs/.gitignore b/docs/core_docs/.gitignore index 8f648f622a5e..f0d4ded7fee0 100644 --- a/docs/core_docs/.gitignore +++ b/docs/core_docs/.gitignore @@ -280,6 +280,12 @@ docs/integrations/retrievers/bm25.md docs/integrations/retrievers/bm25.mdx docs/integrations/retrievers/bedrock-knowledge-bases.md docs/integrations/retrievers/bedrock-knowledge-bases.mdx +docs/integrations/toolkits/vectorstore.md +docs/integrations/toolkits/vectorstore.mdx +docs/integrations/toolkits/sql.md +docs/integrations/toolkits/sql.mdx +docs/integrations/toolkits/openapi.md +docs/integrations/toolkits/openapi.mdx docs/integrations/text_embedding/togetherai.md docs/integrations/text_embedding/togetherai.mdx docs/integrations/text_embedding/pinecone.md diff --git a/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx b/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx new file mode 100644 index 000000000000..9ec8eb8b8346 --- /dev/null +++ b/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx @@ -0,0 +1,13 @@ +--- +hide_table_of_contents: true +--- + +# Reddit + +This example goes over how to load text from the posts of subreddits or Reddit users. +You will need to make a [Reddit Application](https://www.reddit.com/prefs/apps/) and initialize the loader with with your Reddit API credentials. + +import CodeBlock from "@theme/CodeBlock"; +import Example from "@examples/document_loaders/reddit.ts"; + +{Example} diff --git a/docs/core_docs/docs/integrations/tools/reddit.mdx b/docs/core_docs/docs/integrations/tools/reddit.mdx new file mode 100644 index 000000000000..7635430110f6 --- /dev/null +++ b/docs/core_docs/docs/integrations/tools/reddit.mdx @@ -0,0 +1,12 @@ +--- +hide_table_of_contents: true +--- + +# Reddit + +This example goes over how to retrieve post(s) from a subreddit or from a particular user. +You will need to make a [Reddit Application](https://www.reddit.com/prefs/apps/) and initialize the tool with with your Reddit API credentials and user agent. Refer to https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki on user agent format. + +import CodeBlock from "@theme/CodeBlock"; +import Example from "@examples/document_loaders/reddit.ts"; +{Example} \ No newline at end of file diff --git a/examples/.env.example b/examples/.env.example index 2abb8d8e6912..317f76d243f5 100644 --- a/examples/.env.example +++ b/examples/.env.example @@ -51,6 +51,9 @@ CLICKHOUSE_PORT=ADD_YOURS_HERE CLICKHOUSE_USERNAME=ADD_YOURS_HERE CLICKHOUSE_PASSWORD=ADD_YOURS_HERE REDIS_URL=ADD_YOURS_HERE +REDDIT_CLIENT_ID=ADD_YOURS_HERE #https://www.reddit.com/prefs/apps +REDDIT_CLIENT_SECRET=ADD_YOURS_HERE #https://www.reddit.com/prefs/apps +REDDIT_USER_AGENT=ADD_YOURS_HERE #https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki SINGLESTORE_HOST=ADD_YOURS_HERE SINGLESTORE_PORT=ADD_YOURS_HERE SINGLESTORE_USERNAME=ADD_YOURS_HERE diff --git a/examples/src/document_loaders/reddit.ts b/examples/src/document_loaders/reddit.ts new file mode 100644 index 000000000000..25130d2968f4 --- /dev/null +++ b/examples/src/document_loaders/reddit.ts @@ -0,0 +1,29 @@ +import { RedditPostsLoader } from "@langchain/community/document_loaders/web/reddit"; + +// load using 'subreddit' mode +const loader = new RedditPostsLoader({ + clientId: "REDDIT_CLIENT_ID", // or load it from process.env.REDDIT_CLIENT_ID + clientSecret: "REDDIT_CLIENT_SECRET", // or load it from process.env.REDDIT_CLIENT_SECRET + userAgent: "REDDIT_USER_AGENT", // or load it from process.env.REDDIT_USER_AGENT + searchQueries: ["LangChain", "Langchaindev"], + mode: "subreddit", + categories: ["hot", "new"], + numberPosts: 5 +}); +const docs = await loader.load(); +console.log({ docs }); + +// // or load using 'username' mode +// const loader = new RedditPostsLoader({ +// clientId: "REDDIT_CLIENT_ID", // or load it from process.env.REDDIT_CLIENT_ID +// clientSecret: "REDDIT_CLIENT_SECRET", // or load it from process.env.REDDIT_CLIENT_SECRET +// userAgent: "REDDIT_USER_AGENT", // or load it from process.env.REDDIT_USER_AGENT +// searchQueries: ["AutoModerator"], +// mode: "username", +// categories: ["hot", "new"], +// numberPosts: 2 +// }); +// const docs = await loader.load(); +// console.log({ docs }); + +// Note: Categories can be only of following value - "controversial" "hot" "new" "rising" "top" \ No newline at end of file diff --git a/examples/src/tools/reddit.ts b/examples/src/tools/reddit.ts new file mode 100644 index 000000000000..c610cc2b7278 --- /dev/null +++ b/examples/src/tools/reddit.ts @@ -0,0 +1,34 @@ +import RedditSearchRun from "@langchain/community/tools/reddit"; + +// Retrieve a post from a subreddit + +// Refer to doc linked below for how to set the userAgent. +// https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki +// clientId, clientSecret and userAgent can be set in the environment variables +const search = new RedditSearchRun({ + sortMethod: "relevance", + time: "all", + subreddit: "dankmemes", + limit: 1, + clientId: "REDDIT_CLIENT_ID", // or load from process.env.REDDIT_CLIENT_ID + clientSecret: "REDDIT_CLIENT_SECRET", // or load from process.env.REDDIT_CLIENT_SECRET + userAgent: "REDDIT_USER_AGENT" // or load from process.env.REDDIT_USER_AGENT +}); + +const post = await search.invoke("College"); +console.log(post); + +// Retrieve a post from a user + +// const search = new RedditSearchRun({ +// sortMethod: "relevance", +// time: "all", +// subreddit: "dankmemes", +// limit: 1, +// clientId: "REDDIT_CLIENT_ID", // or load from process.env.REDDIT_CLIENT_ID +// clientSecret: "REDDIT_CLIENT_SECRET", // or load from process.env.REDDIT_CLIENT_SECRET +// userAgent: "REDDIT_USER_AGENT" // or load from process.env.REDDIT_USER_AGENT +// }); + +// const post = await search.fetchUserPosts("REDDIT USER TO RETRIEVE POST FROM", 1, "all"); +// console.log(post); \ No newline at end of file diff --git a/langchain-core/src/utils/async_caller.js b/langchain-core/src/utils/async_caller.js new file mode 100644 index 000000000000..e11619a2a985 --- /dev/null +++ b/langchain-core/src/utils/async_caller.js @@ -0,0 +1,128 @@ +"use strict"; +var __importDefault = (this && this.__importDefault) || function (mod) { + return (mod && mod.__esModule) ? mod : { "default": mod }; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.AsyncCaller = void 0; +const p_retry_1 = __importDefault(require("p-retry")); +const p_queue_1 = __importDefault(require("p-queue")); +const STATUS_NO_RETRY = [ + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 409, // Conflict +]; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const defaultFailedAttemptHandler = (error) => { + if (error.message.startsWith("Cancel") || + error.message.startsWith("AbortError") || + error.name === "AbortError") { + throw error; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + if (error?.code === "ECONNABORTED") { + throw error; + } + const status = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + error?.response?.status ?? error?.status; + if (status && STATUS_NO_RETRY.includes(+status)) { + throw error; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + if (error?.error?.code === "insufficient_quota") { + const err = new Error(error?.message); + err.name = "InsufficientQuotaError"; + throw err; + } +}; +/** + * A class that can be used to make async calls with concurrency and retry logic. + * + * This is useful for making calls to any kind of "expensive" external resource, + * be it because it's rate-limited, subject to network issues, etc. + * + * Concurrent calls are limited by the `maxConcurrency` parameter, which defaults + * to `Infinity`. This means that by default, all calls will be made in parallel. + * + * Retries are limited by the `maxRetries` parameter, which defaults to 6. This + * means that by default, each call will be retried up to 6 times, with an + * exponential backoff between each attempt. + */ +class AsyncCaller { + constructor(params) { + Object.defineProperty(this, "maxConcurrency", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }); + Object.defineProperty(this, "maxRetries", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }); + Object.defineProperty(this, "onFailedAttempt", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }); + Object.defineProperty(this, "queue", { + enumerable: true, + configurable: true, + writable: true, + value: void 0 + }); + this.maxConcurrency = params.maxConcurrency ?? Infinity; + this.maxRetries = params.maxRetries ?? 6; + this.onFailedAttempt = + params.onFailedAttempt ?? defaultFailedAttemptHandler; + const PQueue = "default" in p_queue_1.default ? p_queue_1.default.default : p_queue_1.default; + this.queue = new PQueue({ concurrency: this.maxConcurrency }); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + call(callable, ...args) { + return this.queue.add(() => (0, p_retry_1.default)(() => callable(...args).catch((error) => { + // eslint-disable-next-line no-instanceof/no-instanceof + if (error instanceof Error) { + throw error; + } + else { + throw new Error(error); + } + }), { + onFailedAttempt: this.onFailedAttempt, + retries: this.maxRetries, + randomize: true, + // If needed we can change some of the defaults here, + // but they're quite sensible. + }), { throwOnTimeout: true }); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + callWithOptions(options, callable, ...args) { + // Note this doesn't cancel the underlying request, + // when available prefer to use the signal option of the underlying call + if (options.signal) { + return Promise.race([ + this.call(callable, ...args), + new Promise((_, reject) => { + options.signal?.addEventListener("abort", () => { + reject(new Error("AbortError")); + }); + }), + ]); + } + return this.call(callable, ...args); + } + fetch(...args) { + return this.call(() => fetch(...args).then((res) => (res.ok ? res : Promise.reject(res)))); + } +} +exports.AsyncCaller = AsyncCaller; diff --git a/langchain/.env.example b/langchain/.env.example index 2eda74311a41..7c0819f756a8 100644 --- a/langchain/.env.example +++ b/langchain/.env.example @@ -51,6 +51,9 @@ CLICKHOUSE_USERNAME=ADD_YOURS_HERE CLICKHOUSE_PASSWORD=ADD_YOURS_HERE FIGMA_ACCESS_TOKEN=ADD_YOURS_HERE REDIS_URL=ADD_YOURS_HERE +REDDIT_CLIENT_ID=ADD_YOURS_HERE +REDDIT_CLIENT_SECRET=ADD_YOURS_HERE +REDDIT_USER_AGENT=ADD_YOURS_HERE ROCKSET_API_KEY=ADD_YOURS_HERE # defaults to "usw2a1" (oregon) ROCKSET_REGION=ADD_YOURS_HERE diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index 4a402c6941e8..9a86359d6973 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -53,6 +53,7 @@ export const config = { "tools/google_places": "tools/google_places", "tools/google_routes": "tools/google_routes", "tools/ifttt": "tools/ifttt", + "tools/reddit": "tools/reddit", "tools/searchapi": "tools/searchapi", "tools/searxng_search": "tools/searxng_search", "tools/serpapi": "tools/serpapi", @@ -288,6 +289,7 @@ export const config = { "document_loaders/web/notionapi": "document_loaders/web/notionapi", "document_loaders/web/pdf": "document_loaders/web/pdf", "document_loaders/web/recursive_url": "document_loaders/web/recursive_url", + "document_loaders/web/reddit": "document_loaders/web/reddit", "document_loaders/web/s3": "document_loaders/web/s3", "document_loaders/web/sitemap": "document_loaders/web/sitemap", "document_loaders/web/sonix_audio": "document_loaders/web/sonix_audio", @@ -338,6 +340,7 @@ export const config = { "tools/discord", "tools/gmail", "tools/google_calendar", + "tools/reddit", "agents/toolkits/aws_sfn", "agents/toolkits/stagehand", "callbacks/handlers/llmonitor", @@ -506,6 +509,7 @@ export const config = { "document_loaders/web/taskade", "document_loaders/web/notionapi", "document_loaders/web/recursive_url", + "document_loaders/web/reddit", "document_loaders/web/s3", "document_loaders/web/sitemap", "document_loaders/web/sonix_audio", diff --git a/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts b/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts new file mode 100644 index 000000000000..07385b4f594b --- /dev/null +++ b/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts @@ -0,0 +1,47 @@ +import { test } from "@jest/globals"; +import { Document } from "@langchain/core/documents"; +import { RedditPostsLoader } from "../web/reddit.js"; + +test.skip("Test RedditPostsLoader in subreddit mode", async () => { + const loader = new RedditPostsLoader({ + clientId: process.env.REDDIT_CLIENT_ID!, + clientSecret: process.env.REDDIT_CLIENT_SECRET!, + userAgent: process.env.REDDIT_USER_AGENT!, + searchQueries: ["LangChain"], + mode: "subreddit", + categories: ["new"], + numberPosts: 2, + }); + const documents = await loader.load(); + expect(documents).toHaveLength(2); + expect(documents[0]).toBeInstanceOf(Document); + expect(documents[0].metadata.post_subreddit).toMatch("LangChain"); + expect(documents[0].metadata.post_category).toMatch("new"); + expect(documents[0].metadata.post_title).toBeTruthy(); + expect(documents[0].metadata.post_score).toBeGreaterThanOrEqual(0); + expect(documents[0].metadata.post_id).toBeTruthy(); + expect(documents[0].metadata.post_author).toBeTruthy(); + expect(documents[0].metadata.post_url).toMatch(/^http/); +}); + +test.skip("Test RedditPostsLoader in username mode", async () => { + const loader = new RedditPostsLoader({ + clientId: process.env.REDDIT_CLIENT_ID!, + clientSecret: process.env.REDDIT_CLIENT_SECRET!, + userAgent: process.env.REDDIT_USER_AGENT!, + searchQueries: ["AutoModerator"], + mode: "username", + categories: ["hot", "new"], + numberPosts: 5, + }); + const documents = await loader.load(); + expect(documents).toHaveLength(10); + expect(documents[0]).toBeInstanceOf(Document); + expect(documents[0].metadata.post_author).toMatch("AutoModerator"); + expect(documents[0].metadata.post_category).toMatch("hot"); + expect(documents[0].metadata.post_title).toBeTruthy(); + expect(documents[0].metadata.post_score).toBeGreaterThanOrEqual(0); + expect(documents[0].metadata.post_id).toBeTruthy(); + expect(documents[0].metadata.post_subreddit).toBeTruthy(); + expect(documents[0].metadata.post_url).toMatch(/^http/); +}); diff --git a/libs/langchain-community/src/document_loaders/web/reddit.ts b/libs/langchain-community/src/document_loaders/web/reddit.ts new file mode 100644 index 000000000000..937e1045998f --- /dev/null +++ b/libs/langchain-community/src/document_loaders/web/reddit.ts @@ -0,0 +1,133 @@ +import { BaseDocumentLoader } from "@langchain/core/document_loaders/base"; +import { Document } from "@langchain/core/documents"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { + RedditAPIWrapper, + RedditPost, + RedditAPIConfig, +} from "../../utils/reddit.js"; + +/** + * Class representing a document loader for loading Reddit posts. It extends + * the BaseDocumentLoader and implements the RedditAPIConfig interface. + * @example + * ```typescript + * const loader = new RedditPostsLoader({ + * clientId: "REDDIT_CLIENT_ID", + * clientSecret: "REDDIT_CLIENT_SECRET", + * userAgent: "REDDIT_USER_AGENT", + * searchQueries: ["LangChain", "Langchaindev"], + * mode: "subreddit", + * categories: ["hot", "new"], + * numberPosts: 5 + * }); + * const docs = await loader.load(); + * ``` + */ +export class RedditPostsLoader + extends BaseDocumentLoader + implements RedditAPIConfig +{ + public clientId: string; + + public clientSecret: string; + + public userAgent: string; + + private redditApiWrapper: RedditAPIWrapper; + + private searchQueries: string[]; + + private mode: string; + + private categories: string[]; + + private numberPosts: number; + + constructor({ + clientId = getEnvironmentVariable("REDDIT_CLIENT_ID") as string, + clientSecret = getEnvironmentVariable("REDDIT_CLIENT_SECRET") as string, + userAgent = getEnvironmentVariable("REDDIT_USER_AGENT") as string, + searchQueries, + mode, + categories = ["new"], + numberPosts = 10, + }: RedditAPIConfig & { + searchQueries: string[]; + mode: string; + categories?: string[]; + numberPosts?: number; + }) { + super(); + this.clientId = clientId; + this.clientSecret = clientSecret; + this.userAgent = userAgent; + this.redditApiWrapper = new RedditAPIWrapper({ + clientId: this.clientId, + clientSecret: this.clientSecret, + userAgent: this.userAgent, + }); + this.searchQueries = searchQueries; + this.mode = mode; + this.categories = categories; + this.numberPosts = numberPosts; + } + + /** + * Loads Reddit posts using the Reddit API, creates a Document instance + * with the JSON representation of the post as the page content and metadata, + * and returns it. + * @returns A Promise that resolves to an array of Document instances. + */ + public async load(): Promise { + let results: Document[] = []; + for (const query of this.searchQueries) { + for (const category of this.categories) { + let posts: RedditPost[] = []; + + if (this.mode === "subreddit") { + posts = await this.redditApiWrapper.searchSubreddit( + query, + "*", + category, + this.numberPosts + ); + } else if (this.mode === "username") { + posts = await this.redditApiWrapper.fetchUserPosts( + query, + category, + this.numberPosts + ); + } else { + throw new Error( + "Invalid mode: please choose 'subreddit' or 'username'" + ); + } + results = results.concat(this._mapPostsToDocuments(posts, category)); + } + } + + return results; + } + + private _mapPostsToDocuments( + posts: RedditPost[], + category: string + ): Document[] { + return posts.map( + (post) => + new Document({ + pageContent: post.selftext, + metadata: { + post_subreddit: post.subreddit_name_prefixed, + post_category: category, + post_title: post.title, + post_score: post.score, + post_id: post.id, + post_url: post.url, + post_author: post.author, + }, + }) + ); + } +} diff --git a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts index 722dd82e678b..679384047248 100644 --- a/libs/langchain-community/src/load/import_constants.ts +++ b/libs/langchain-community/src/load/import_constants.ts @@ -160,6 +160,7 @@ export const optionalImportEntrypoints: string[] = [ "langchain_community/document_loaders/web/notionapi", "langchain_community/document_loaders/web/pdf", "langchain_community/document_loaders/web/recursive_url", + "langchain_community/document_loaders/web/reddit", "langchain_community/document_loaders/web/s3", "langchain_community/document_loaders/web/sitemap", "langchain_community/document_loaders/web/sonix_audio", diff --git a/libs/langchain-community/src/tools/reddit.ts b/libs/langchain-community/src/tools/reddit.ts new file mode 100644 index 000000000000..f3e68cda6d3c --- /dev/null +++ b/libs/langchain-community/src/tools/reddit.ts @@ -0,0 +1,123 @@ +import { getEnvironmentVariable } from "@langchain/core/utils/env"; //"../../../../langchain-core/src/utils/env.js"; +import { RedditAPIWrapper } from "../utils/reddit.js"; +import { Tool } from "@langchain/core/tools"; + +/* Interface for the search parameters. + * sortMethod: The sorting method for the search results, can be one of "relevance", "hot", "top", "new", "comments" + * time: The time period for the search results, can be one of "hour", "day", "week", "month", "year", "all" + * subreddit: The subreddit to search in like "dankmemes" for "r/dankmemes" + * limit: The number of results to return + * clientId: The client ID for the Reddit API + * clientSecret: The client secret for the Reddit API + * userAgent: The user agent for the Reddit API + */ +export interface RedditSearchRunParams { + sortMethod?: string; + time?: string; + subreddit?: string; + limit?: number; + clientId?: string; + clientSecret?: string; + userAgent?: string; +} + +/** + * Class representing a tool for searching reddit posts using the reddit API. + * It extends the Tool class. + * + * @example + * ```typescript + * const search = new RedditSearchRun({ + * sortMethod: "relevance", + * time: "all", + * subreddit: "dankmemes", + * limit: 1, + * }); + * + * const post = await search.invoke("College"); + * ``` + */ +export class RedditSearchRun extends Tool { + static lc_name() { + return "RedditSearchRun"; + } + + name = "Reddit_search"; + + description = "A tool for searching reddit posts using the reddit API"; + + // Default values for the search parameters + protected sortMethod = "relevance"; + protected time = "all"; + protected subreddit = "all"; + protected limit = 2; + protected clientId = ""; + protected clientSecret = ""; + protected userAgent = ""; + + /** + * Constructor for the RedditSearchRun class + * @description Initializes the search parameters if given + * @param params The search parameters + */ + constructor(params: RedditSearchRunParams = {}) { + super(); + + this.sortMethod = params.sortMethod ?? this.sortMethod; + this.time = params.time ?? this.time; + this.subreddit = params.subreddit ?? this.subreddit; + this.limit = params.limit ?? this.limit; + this.clientId = + params.clientId ?? (getEnvironmentVariable("REDDIT_CLIENT_ID") as string); + this.clientSecret = + params.clientSecret ?? + (getEnvironmentVariable("REDDIT_CLIENT_SECRET") as string); + this.userAgent = + params.userAgent ?? + (getEnvironmentVariable("REDDIT_USER_AGENT") as string); + } + + /** + * @param {string} query The search query to be sent to reddit + * @description Function to retrieve posts based on a search query + * @returns the search results from using the API wrapper + */ + async _call(query: string): Promise { + const apiWrapper = new RedditAPIWrapper({ + clientId: this.clientId, + clientSecret: this.clientSecret, + userAgent: this.userAgent, + }); + + return apiWrapper.searchSubreddit( + this.subreddit, + query, + this.sortMethod, + this.limit, + this.time + ); + } + + /** + * @param {string} username The username whose posts are to be retrieved + * @param {string} sortMethod The sorting method for the posts to be retrieved + * @param {number} limit The number of posts to retrieve starting from the latest post + * @param {string} time The time period for the posts to be retrieved + * @description Function to retrieve posts from a certain user + * @returns The latest limit number of posts from the user + */ + async fetchUserPosts( + username: string, + sortMethod: string = this.sortMethod, + limit: number = this.limit, + time: string = this.time + ): Promise { + const apiWrapper = new RedditAPIWrapper({ + clientId: this.clientId, + clientSecret: this.clientSecret, + userAgent: this.userAgent, + }); + + return apiWrapper.fetchUserPosts(username, sortMethod, limit, time); + } +} diff --git a/libs/langchain-community/src/tools/tests/reddit.int.test.ts b/libs/langchain-community/src/tools/tests/reddit.int.test.ts new file mode 100644 index 000000000000..a3dad748cc6a --- /dev/null +++ b/libs/langchain-community/src/tools/tests/reddit.int.test.ts @@ -0,0 +1,28 @@ +import { test, expect } from "@jest/globals"; +//import { Document } from "@langchain/core/documents"; +//import { RedditPostsLoader } from "../web/reddit.js"; +import { RedditSearchRun } from "../reddit.js"; + +test("Test fetching a post based on a query", async () => { + const search = new RedditSearchRun({ + sortMethod: "relevance", + time: "all", + subreddit: "dankmemes", + limit: 1, + }); + + const post = await search.invoke("College"); + expect(post).toHaveLength(1); +}); + +test("Test fetching a post from a user", async () => { + const search = new RedditSearchRun({ + sortMethod: "relevance", + time: "all", + subreddit: "dankmemes", + limit: 1, + }); + + const post = await search.fetchUserPosts("BloodJunkie", 1, "all"); + expect(post).toHaveLength(1); +}); diff --git a/libs/langchain-community/src/utils/reddit.ts b/libs/langchain-community/src/utils/reddit.ts new file mode 100644 index 000000000000..2e9fd1b68e04 --- /dev/null +++ b/libs/langchain-community/src/utils/reddit.ts @@ -0,0 +1,145 @@ +import dotenv from "dotenv"; +import { AsyncCaller } from "@langchain/core/utils/async_caller"; + +dotenv.config(); + +export interface RedditAPIConfig { + clientId: string; + clientSecret: string; + userAgent: string; +} + +export interface RedditPost { + title: string; + selftext: string; + subreddit_name_prefixed: string; + score: number; + id: string; + url: string; + author: string; +} + +export class RedditAPIWrapper { + private clientId: string; + + private clientSecret: string; + + private userAgent: string; + + private token: string | null = null; + + private baseUrl = "https://oauth.reddit.com"; + + private asyncCaller: AsyncCaller; // Using AsyncCaller for requests + + constructor(config: RedditAPIConfig) { + this.clientId = config.clientId; + this.clientSecret = config.clientSecret; + this.userAgent = config.userAgent; + this.asyncCaller = new AsyncCaller({ + maxConcurrency: 5, + maxRetries: 3, + onFailedAttempt: (error) => { + console.error("Attempt failed:", error.message); + }, + }); + } + + private async authenticate() { + if (this.token) return; + + const authString = btoa(`${this.clientId}:${this.clientSecret}`); + + try { + const response = await fetch( + "https://www.reddit.com/api/v1/access_token", + { + method: "POST", + headers: { + Authorization: `Basic ${authString}`, + "User-Agent": this.userAgent, + "Content-Type": "application/x-www-form-urlencoded", + }, + body: "grant_type=client_credentials", + } + ); + + if (!response.ok) { + throw new Error( + `Error authenticating with Reddit: ${response.statusText}` + ); + } + + const data = await response.json(); + this.token = data.access_token; + } catch (error) { + console.error("Error authenticating with Reddit:", error); + } + } + + private async makeRequest( + endpoint: string, + params: Record = {} + ): Promise { + await this.authenticate(); + + const url = new URL(`${this.baseUrl}${endpoint}`); + Object.keys(params).forEach((key) => + url.searchParams.append(key, params[key]) + ); + + return this.asyncCaller.call(async () => { + const response = await fetch(url.toString(), { + headers: { + Authorization: `Bearer ${this.token}`, + "User-Agent": this.userAgent, + }, + }); + + if (!response.ok) { + if (response.status === 429) { + console.warn("Rate limit exceeded, retrying..."); + throw new Error("Rate limit exceeded"); + } + throw new Error( + `Error making request to ${endpoint}: ${response.statusText}` + ); + } + + return await response.json(); + }); + } + + async searchSubreddit( + subreddit: string, + query: string, + sort: "new", + limit: 10, + time: "all" + ): Promise { + const data = await this.makeRequest(`/r/${subreddit}/search`, { + q: query, + sort, + limit, + t: time, + restrict_sr: "on", + }); + + return data.data.children.map((item: { data: any; }) => item.data); + } + + async fetchUserPosts( + username: string, + sort = "new", + limit = 10, + time = "all" + ): Promise { + const data = await this.makeRequest(`/user/${username}/submitted`, { + sort: sort, + limit: limit.toString(), + t: time, + }); + + return data.data.children.map((item: { data: any; }) => item.data); + } +} diff --git a/libs/langchain-community/src/utils/tests/reddit.test.ts b/libs/langchain-community/src/utils/tests/reddit.test.ts new file mode 100644 index 000000000000..83065e406e47 --- /dev/null +++ b/libs/langchain-community/src/utils/tests/reddit.test.ts @@ -0,0 +1,185 @@ +import { + describe, + expect, + it, + jest, + beforeEach, + afterEach, +} from "@jest/globals"; +import { RedditAPIWrapper, RedditAPIConfig } from "../reddit.js"; + +// Mocking global fetch for HTTP requests +global.fetch = jest.fn() as jest.MockedFunction; + +// Sample RedditAPIConfig for tests +const fakeConfig: RedditAPIConfig = { + clientId: "fakeClientId", + clientSecret: "fakeClientSecret", + userAgent: "test-user-agent", +}; + +describe("RedditAPIWrapper", () => { + let redditAPIWrapper: RedditAPIWrapper; + + beforeEach(() => { + redditAPIWrapper = new RedditAPIWrapper(fakeConfig); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it("should authenticate successfully and set token", async () => { + const fakeAccessToken = "fakeAccessToken"; + const fakeResponse = { + ok: true, + json: jest.fn().mockResolvedValue({ access_token: fakeAccessToken }), + } as unknown as Response; + + // Mock fetch to return the fake response for authentication + (global.fetch as jest.Mock).mockResolvedValue(fakeResponse); + + await redditAPIWrapper["authenticate"](); + + // Validate that the fetch was called correctly with the correct headers and URL + expect(global.fetch).toHaveBeenCalledWith( + "https://www.reddit.com/api/v1/access_token", + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + Authorization: expect.stringContaining("Basic"), // Checks if Basic auth is used + }), + }) + ); + expect(redditAPIWrapper["token"]).toBe(fakeAccessToken); + }); + + it("should make a request successfully", async () => { + const fakeToken = "fakeAccessToken"; + redditAPIWrapper["token"] = fakeToken; + + const fakeJsonResponse = { data: { children: [] } }; + const fakeResponse = { + ok: true, + json: jest.fn().mockResolvedValue(fakeJsonResponse), + } as unknown as Response; + + // Mock fetch to return the fake response for making requests + (global.fetch as jest.Mock).mockResolvedValue(fakeResponse); + + const response = await redditAPIWrapper["makeRequest"]("/r/test/search", { + q: "test", + }); + + // Validate that the fetch was called with the correct URL and authorization header + expect(global.fetch).toHaveBeenCalledWith( + "https://oauth.reddit.com/r/test/search?q=test", + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: `Bearer ${fakeToken}`, + }), + }) + ); + expect(response).toEqual(fakeJsonResponse); + }); + + it("should handle rate limit errors gracefully", async () => { + const fakeResponse = { + ok: false, + status: 429, // Rate limit exceeded + statusText: "Too Many Requests", + } as unknown as Response; + + // Mock fetch to return a rate-limited response + (global.fetch as jest.Mock).mockResolvedValue(fakeResponse); + + // Expect the method to throw an error when rate-limited + await expect( + redditAPIWrapper["makeRequest"]("/r/test/search", { q: "test" }) + ).rejects.toThrow("Rate limit exceeded"); + }); + + it("should search subreddit and map data correctly", async () => { + const fakeJsonResponse = { + data: { + children: [ + { + data: { + title: "Test Post", + selftext: "Test Text", + subreddit_name_prefixed: "r/test", + score: 100, + id: "123", + url: "https://test.com", + author: "test_author", + }, + }, + ], + }, + }; + + const fakeResponse = { + ok: true, + json: jest.fn().mockResolvedValue(fakeJsonResponse), + } as unknown as Response; + + // Mock fetch to return the fake response for subreddit search + (global.fetch as jest.Mock).mockResolvedValue(fakeResponse); + + const posts = await redditAPIWrapper.searchSubreddit("test", "test query"); + + // Validate the response format and data mapping + expect(posts).toHaveLength(1); + expect(posts[0]).toEqual({ + title: "Test Post", + selftext: "Test Text", + subreddit_name_prefixed: "r/test", + score: 100, + id: "123", + url: "https://test.com", + author: "test_author", + }); + }); + + it("should fetch user posts and map data correctly", async () => { + const fakeJsonResponse = { + data: { + children: [ + { + data: { + title: "User Post", + selftext: "User Post Text", + subreddit_name_prefixed: "r/test", + score: 50, + id: "456", + url: "https://test.com", + author: "user_test", + }, + }, + ], + }, + }; + + const fakeResponse = { + ok: true, + json: jest.fn().mockResolvedValue(fakeJsonResponse), + } as unknown as Response; + + // Mock fetch to return the fake response for fetching user posts + (global.fetch as jest.Mock).mockResolvedValue(fakeResponse); + + const posts = await redditAPIWrapper.fetchUserPosts("testuser", "new"); + + // Validate the response format and data mapping + expect(posts).toHaveLength(1); + expect(posts[0]).toEqual({ + title: "User Post", + selftext: "User Post Text", + subreddit_name_prefixed: "r/test", + score: 50, + id: "456", + url: "https://test.com", + author: "user_test", + }); + }); +}); From ed8f7a2869fd97d32993af679fd58cf4c463cc8e Mon Sep 17 00:00:00 2001 From: MirrorLimit <113953454+MirrorLimit@users.noreply.github.com> Date: Sat, 30 Nov 2024 13:46:32 -0500 Subject: [PATCH 2/2] ci: Added reddit to the GitHub workflow unit test --- .github/workflows/unit-tests-integrations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests-integrations.yml b/.github/workflows/unit-tests-integrations.yml index a56c954cc535..7d8742cbb845 100644 --- a/.github/workflows/unit-tests-integrations.yml +++ b/.github/workflows/unit-tests-integrations.yml @@ -44,7 +44,7 @@ jobs: needs: get-changed-files runs-on: ubuntu-latest env: - PACKAGES: "anthropic,azure-openai,cloudflare,cohere,core,community,exa,google-common,google-gauth,google-genai,google-vertexai,google-vertexai-web,google-webauth,groq,mistralai,mongo,nomic,openai,pinecone,qdrant,redis,textsplitters,weaviate,yandex,baidu-qianfan" + PACKAGES: "anthropic,azure-openai,cloudflare,cohere,core,community,exa,google-common,google-gauth,google-genai,google-vertexai,google-vertexai-web,google-webauth,groq,mistralai,mongo,nomic,openai,pinecone,qdrant,redis,textsplitters,weaviate,yandex,baidu-qianfan,reddit" outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} matrix_length: ${{ steps.set-matrix.outputs.matrix_length }}