From 0074d9b359eefd76b2b726794a1124f28c4689a3 Mon Sep 17 00:00:00 2001
From: MirrorLimit <113953454+MirrorLimit@users.noreply.github.com>
Date: Sat, 30 Nov 2024 13:33:24 -0500
Subject: [PATCH 1/2] feat(community): Added Reddit integration to LangchainJS
Added Reddit integration to LangchainJS to bring it closer to parity with LangchainPY as per discussion #7043
---
docs/core_docs/.gitignore | 6 +
.../document_loaders/web_loaders/reddit.mdx | 13 ++
.../docs/integrations/tools/reddit.mdx | 12 ++
examples/.env.example | 3 +
examples/src/document_loaders/reddit.ts | 29 +++
examples/src/tools/reddit.ts | 34 ++++
langchain-core/src/utils/async_caller.js | 128 ++++++++++++
langchain/.env.example | 3 +
libs/langchain-community/langchain.config.js | 4 +
.../document_loaders/tests/reddit.int.test.ts | 47 +++++
.../src/document_loaders/web/reddit.ts | 133 +++++++++++++
.../src/load/import_constants.ts | 1 +
libs/langchain-community/src/tools/reddit.ts | 123 ++++++++++++
.../src/tools/tests/reddit.int.test.ts | 28 +++
libs/langchain-community/src/utils/reddit.ts | 145 ++++++++++++++
.../src/utils/tests/reddit.test.ts | 185 ++++++++++++++++++
16 files changed, 894 insertions(+)
create mode 100644 docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx
create mode 100644 docs/core_docs/docs/integrations/tools/reddit.mdx
create mode 100644 examples/src/document_loaders/reddit.ts
create mode 100644 examples/src/tools/reddit.ts
create mode 100644 langchain-core/src/utils/async_caller.js
create mode 100644 libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts
create mode 100644 libs/langchain-community/src/document_loaders/web/reddit.ts
create mode 100644 libs/langchain-community/src/tools/reddit.ts
create mode 100644 libs/langchain-community/src/tools/tests/reddit.int.test.ts
create mode 100644 libs/langchain-community/src/utils/reddit.ts
create mode 100644 libs/langchain-community/src/utils/tests/reddit.test.ts
diff --git a/docs/core_docs/.gitignore b/docs/core_docs/.gitignore
index 8f648f622a5e..f0d4ded7fee0 100644
--- a/docs/core_docs/.gitignore
+++ b/docs/core_docs/.gitignore
@@ -280,6 +280,12 @@ docs/integrations/retrievers/bm25.md
docs/integrations/retrievers/bm25.mdx
docs/integrations/retrievers/bedrock-knowledge-bases.md
docs/integrations/retrievers/bedrock-knowledge-bases.mdx
+docs/integrations/toolkits/vectorstore.md
+docs/integrations/toolkits/vectorstore.mdx
+docs/integrations/toolkits/sql.md
+docs/integrations/toolkits/sql.mdx
+docs/integrations/toolkits/openapi.md
+docs/integrations/toolkits/openapi.mdx
docs/integrations/text_embedding/togetherai.md
docs/integrations/text_embedding/togetherai.mdx
docs/integrations/text_embedding/pinecone.md
diff --git a/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx b/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx
new file mode 100644
index 000000000000..9ec8eb8b8346
--- /dev/null
+++ b/docs/core_docs/docs/integrations/document_loaders/web_loaders/reddit.mdx
@@ -0,0 +1,13 @@
+---
+hide_table_of_contents: true
+---
+
+# Reddit
+
+This example goes over how to load text from the posts of subreddits or Reddit users.
+You will need to make a [Reddit Application](https://www.reddit.com/prefs/apps/) and initialize the loader with your Reddit API credentials.
+
+import CodeBlock from "@theme/CodeBlock";
+import Example from "@examples/document_loaders/reddit.ts";
+
+<CodeBlock language="typescript">{Example}</CodeBlock>
diff --git a/docs/core_docs/docs/integrations/tools/reddit.mdx b/docs/core_docs/docs/integrations/tools/reddit.mdx
new file mode 100644
index 000000000000..7635430110f6
--- /dev/null
+++ b/docs/core_docs/docs/integrations/tools/reddit.mdx
@@ -0,0 +1,13 @@
+---
+hide_table_of_contents: true
+---
+
+# Reddit
+
+This example goes over how to retrieve post(s) from a subreddit or from a particular user.
+You will need to make a [Reddit Application](https://www.reddit.com/prefs/apps/) and initialize the tool with your Reddit API credentials and user agent. Refer to the [Reddit Data API Wiki](https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki) for the user agent format.
+
+import CodeBlock from "@theme/CodeBlock";
+import Example from "@examples/tools/reddit.ts";
+
+<CodeBlock language="typescript">{Example}</CodeBlock>
\ No newline at end of file
diff --git a/examples/.env.example b/examples/.env.example
index 2abb8d8e6912..317f76d243f5 100644
--- a/examples/.env.example
+++ b/examples/.env.example
@@ -51,6 +51,9 @@ CLICKHOUSE_PORT=ADD_YOURS_HERE
CLICKHOUSE_USERNAME=ADD_YOURS_HERE
CLICKHOUSE_PASSWORD=ADD_YOURS_HERE
REDIS_URL=ADD_YOURS_HERE
+REDDIT_CLIENT_ID=ADD_YOURS_HERE #https://www.reddit.com/prefs/apps
+REDDIT_CLIENT_SECRET=ADD_YOURS_HERE #https://www.reddit.com/prefs/apps
+REDDIT_USER_AGENT=ADD_YOURS_HERE #https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki
SINGLESTORE_HOST=ADD_YOURS_HERE
SINGLESTORE_PORT=ADD_YOURS_HERE
SINGLESTORE_USERNAME=ADD_YOURS_HERE
diff --git a/examples/src/document_loaders/reddit.ts b/examples/src/document_loaders/reddit.ts
new file mode 100644
index 000000000000..25130d2968f4
--- /dev/null
+++ b/examples/src/document_loaders/reddit.ts
@@ -0,0 +1,29 @@
+import { RedditPostsLoader } from "@langchain/community/document_loaders/web/reddit";
+
+// load using 'subreddit' mode
+const loader = new RedditPostsLoader({
+ clientId: "REDDIT_CLIENT_ID", // or load it from process.env.REDDIT_CLIENT_ID
+ clientSecret: "REDDIT_CLIENT_SECRET", // or load it from process.env.REDDIT_CLIENT_SECRET
+ userAgent: "REDDIT_USER_AGENT", // or load it from process.env.REDDIT_USER_AGENT
+ searchQueries: ["LangChain", "Langchaindev"],
+ mode: "subreddit",
+ categories: ["hot", "new"],
+ numberPosts: 5
+});
+const docs = await loader.load();
+console.log({ docs });
+
+// // or load using 'username' mode
+// const loader = new RedditPostsLoader({
+// clientId: "REDDIT_CLIENT_ID", // or load it from process.env.REDDIT_CLIENT_ID
+// clientSecret: "REDDIT_CLIENT_SECRET", // or load it from process.env.REDDIT_CLIENT_SECRET
+// userAgent: "REDDIT_USER_AGENT", // or load it from process.env.REDDIT_USER_AGENT
+// searchQueries: ["AutoModerator"],
+// mode: "username",
+// categories: ["hot", "new"],
+// numberPosts: 2
+// });
+// const docs = await loader.load();
+// console.log({ docs });
+
+// Note: Categories can be only of following value - "controversial" "hot" "new" "rising" "top"
\ No newline at end of file
diff --git a/examples/src/tools/reddit.ts b/examples/src/tools/reddit.ts
new file mode 100644
index 000000000000..c610cc2b7278
--- /dev/null
+++ b/examples/src/tools/reddit.ts
@@ -0,0 +1,34 @@
+import { RedditSearchRun } from "@langchain/community/tools/reddit";
+
+// Retrieve a post from a subreddit
+
+// Refer to doc linked below for how to set the userAgent.
+// https://support.reddithelp.com/hc/en-us/articles/16160319875092-Reddit-Data-API-Wiki
+// clientId, clientSecret and userAgent can be set in the environment variables
+const search = new RedditSearchRun({
+  sortMethod: "relevance",
+  time: "all",
+  subreddit: "dankmemes",
+  limit: 1,
+  clientId: "REDDIT_CLIENT_ID", // or load from process.env.REDDIT_CLIENT_ID
+  clientSecret: "REDDIT_CLIENT_SECRET", // or load from process.env.REDDIT_CLIENT_SECRET
+  userAgent: "REDDIT_USER_AGENT", // or load from process.env.REDDIT_USER_AGENT
+});
+
+const post = await search.invoke("College");
+console.log(post);
+
+// Retrieve a post from a user (arguments: username, sortMethod, limit, time)
+
+// const search = new RedditSearchRun({
+//   sortMethod: "relevance",
+//   time: "all",
+//   subreddit: "dankmemes",
+//   limit: 1,
+//   clientId: "REDDIT_CLIENT_ID", // or load from process.env.REDDIT_CLIENT_ID
+//   clientSecret: "REDDIT_CLIENT_SECRET", // or load from process.env.REDDIT_CLIENT_SECRET
+//   userAgent: "REDDIT_USER_AGENT", // or load from process.env.REDDIT_USER_AGENT
+// });
+
+// const post = await search.fetchUserPosts("REDDIT USER TO RETRIEVE POST FROM", "new", 1, "all");
+// console.log(post);
\ No newline at end of file
diff --git a/langchain-core/src/utils/async_caller.js b/langchain-core/src/utils/async_caller.js
new file mode 100644
index 000000000000..e11619a2a985
--- /dev/null
+++ b/langchain-core/src/utils/async_caller.js
@@ -0,0 +1,128 @@
+"use strict";
+var __importDefault = (this && this.__importDefault) || function (mod) {
+ return (mod && mod.__esModule) ? mod : { "default": mod };
+};
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.AsyncCaller = void 0;
+const p_retry_1 = __importDefault(require("p-retry"));
+const p_queue_1 = __importDefault(require("p-queue"));
+const STATUS_NO_RETRY = [
+ 400,
+ 401,
+ 402,
+ 403,
+ 404,
+ 405,
+ 406,
+ 407,
+ 409, // Conflict
+];
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+const defaultFailedAttemptHandler = (error) => {
+ if (error.message.startsWith("Cancel") ||
+ error.message.startsWith("AbortError") ||
+ error.name === "AbortError") {
+ throw error;
+ }
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ if (error?.code === "ECONNABORTED") {
+ throw error;
+ }
+ const status =
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ error?.response?.status ?? error?.status;
+ if (status && STATUS_NO_RETRY.includes(+status)) {
+ throw error;
+ }
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ if (error?.error?.code === "insufficient_quota") {
+ const err = new Error(error?.message);
+ err.name = "InsufficientQuotaError";
+ throw err;
+ }
+};
+/**
+ * A class that can be used to make async calls with concurrency and retry logic.
+ *
+ * This is useful for making calls to any kind of "expensive" external resource,
+ * be it because it's rate-limited, subject to network issues, etc.
+ *
+ * Concurrent calls are limited by the `maxConcurrency` parameter, which defaults
+ * to `Infinity`. This means that by default, all calls will be made in parallel.
+ *
+ * Retries are limited by the `maxRetries` parameter, which defaults to 6. This
+ * means that by default, each call will be retried up to 6 times, with an
+ * exponential backoff between each attempt.
+ */
+class AsyncCaller {
+ constructor(params) {
+ Object.defineProperty(this, "maxConcurrency", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
+ Object.defineProperty(this, "maxRetries", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
+ Object.defineProperty(this, "onFailedAttempt", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
+ Object.defineProperty(this, "queue", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
+ this.maxConcurrency = params.maxConcurrency ?? Infinity;
+ this.maxRetries = params.maxRetries ?? 6;
+ this.onFailedAttempt =
+ params.onFailedAttempt ?? defaultFailedAttemptHandler;
+ const PQueue = "default" in p_queue_1.default ? p_queue_1.default.default : p_queue_1.default;
+ this.queue = new PQueue({ concurrency: this.maxConcurrency });
+ }
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ call(callable, ...args) {
+ return this.queue.add(() => (0, p_retry_1.default)(() => callable(...args).catch((error) => {
+ // eslint-disable-next-line no-instanceof/no-instanceof
+ if (error instanceof Error) {
+ throw error;
+ }
+ else {
+ throw new Error(error);
+ }
+ }), {
+ onFailedAttempt: this.onFailedAttempt,
+ retries: this.maxRetries,
+ randomize: true,
+ // If needed we can change some of the defaults here,
+ // but they're quite sensible.
+ }), { throwOnTimeout: true });
+ }
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ callWithOptions(options, callable, ...args) {
+ // Note this doesn't cancel the underlying request,
+ // when available prefer to use the signal option of the underlying call
+ if (options.signal) {
+ return Promise.race([
+ this.call(callable, ...args),
+ new Promise((_, reject) => {
+ options.signal?.addEventListener("abort", () => {
+ reject(new Error("AbortError"));
+ });
+ }),
+ ]);
+ }
+ return this.call(callable, ...args);
+ }
+ fetch(...args) {
+ return this.call(() => fetch(...args).then((res) => (res.ok ? res : Promise.reject(res))));
+ }
+}
+exports.AsyncCaller = AsyncCaller;
diff --git a/langchain/.env.example b/langchain/.env.example
index 2eda74311a41..7c0819f756a8 100644
--- a/langchain/.env.example
+++ b/langchain/.env.example
@@ -51,6 +51,9 @@ CLICKHOUSE_USERNAME=ADD_YOURS_HERE
CLICKHOUSE_PASSWORD=ADD_YOURS_HERE
FIGMA_ACCESS_TOKEN=ADD_YOURS_HERE
REDIS_URL=ADD_YOURS_HERE
+REDDIT_CLIENT_ID=ADD_YOURS_HERE
+REDDIT_CLIENT_SECRET=ADD_YOURS_HERE
+REDDIT_USER_AGENT=ADD_YOURS_HERE
ROCKSET_API_KEY=ADD_YOURS_HERE
# defaults to "usw2a1" (oregon)
ROCKSET_REGION=ADD_YOURS_HERE
diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js
index 4a402c6941e8..9a86359d6973 100644
--- a/libs/langchain-community/langchain.config.js
+++ b/libs/langchain-community/langchain.config.js
@@ -53,6 +53,7 @@ export const config = {
"tools/google_places": "tools/google_places",
"tools/google_routes": "tools/google_routes",
"tools/ifttt": "tools/ifttt",
+ "tools/reddit": "tools/reddit",
"tools/searchapi": "tools/searchapi",
"tools/searxng_search": "tools/searxng_search",
"tools/serpapi": "tools/serpapi",
@@ -288,6 +289,7 @@ export const config = {
"document_loaders/web/notionapi": "document_loaders/web/notionapi",
"document_loaders/web/pdf": "document_loaders/web/pdf",
"document_loaders/web/recursive_url": "document_loaders/web/recursive_url",
+ "document_loaders/web/reddit": "document_loaders/web/reddit",
"document_loaders/web/s3": "document_loaders/web/s3",
"document_loaders/web/sitemap": "document_loaders/web/sitemap",
"document_loaders/web/sonix_audio": "document_loaders/web/sonix_audio",
@@ -338,6 +340,7 @@ export const config = {
"tools/discord",
"tools/gmail",
"tools/google_calendar",
+ "tools/reddit",
"agents/toolkits/aws_sfn",
"agents/toolkits/stagehand",
"callbacks/handlers/llmonitor",
@@ -506,6 +509,7 @@ export const config = {
"document_loaders/web/taskade",
"document_loaders/web/notionapi",
"document_loaders/web/recursive_url",
+ "document_loaders/web/reddit",
"document_loaders/web/s3",
"document_loaders/web/sitemap",
"document_loaders/web/sonix_audio",
diff --git a/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts b/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts
new file mode 100644
index 000000000000..07385b4f594b
--- /dev/null
+++ b/libs/langchain-community/src/document_loaders/tests/reddit.int.test.ts
@@ -0,0 +1,47 @@
+import { test } from "@jest/globals";
+import { Document } from "@langchain/core/documents";
+import { RedditPostsLoader } from "../web/reddit.js";
+
+test.skip("Test RedditPostsLoader in subreddit mode", async () => {
+ const loader = new RedditPostsLoader({
+ clientId: process.env.REDDIT_CLIENT_ID!,
+ clientSecret: process.env.REDDIT_CLIENT_SECRET!,
+ userAgent: process.env.REDDIT_USER_AGENT!,
+ searchQueries: ["LangChain"],
+ mode: "subreddit",
+ categories: ["new"],
+ numberPosts: 2,
+ });
+ const documents = await loader.load();
+ expect(documents).toHaveLength(2);
+ expect(documents[0]).toBeInstanceOf(Document);
+ expect(documents[0].metadata.post_subreddit).toMatch("LangChain");
+ expect(documents[0].metadata.post_category).toMatch("new");
+ expect(documents[0].metadata.post_title).toBeTruthy();
+ expect(documents[0].metadata.post_score).toBeGreaterThanOrEqual(0);
+ expect(documents[0].metadata.post_id).toBeTruthy();
+ expect(documents[0].metadata.post_author).toBeTruthy();
+ expect(documents[0].metadata.post_url).toMatch(/^http/);
+});
+
+test.skip("Test RedditPostsLoader in username mode", async () => {
+ const loader = new RedditPostsLoader({
+ clientId: process.env.REDDIT_CLIENT_ID!,
+ clientSecret: process.env.REDDIT_CLIENT_SECRET!,
+ userAgent: process.env.REDDIT_USER_AGENT!,
+ searchQueries: ["AutoModerator"],
+ mode: "username",
+ categories: ["hot", "new"],
+ numberPosts: 5,
+ });
+ const documents = await loader.load();
+ expect(documents).toHaveLength(10);
+ expect(documents[0]).toBeInstanceOf(Document);
+ expect(documents[0].metadata.post_author).toMatch("AutoModerator");
+ expect(documents[0].metadata.post_category).toMatch("hot");
+ expect(documents[0].metadata.post_title).toBeTruthy();
+ expect(documents[0].metadata.post_score).toBeGreaterThanOrEqual(0);
+ expect(documents[0].metadata.post_id).toBeTruthy();
+ expect(documents[0].metadata.post_subreddit).toBeTruthy();
+ expect(documents[0].metadata.post_url).toMatch(/^http/);
+});
diff --git a/libs/langchain-community/src/document_loaders/web/reddit.ts b/libs/langchain-community/src/document_loaders/web/reddit.ts
new file mode 100644
index 000000000000..937e1045998f
--- /dev/null
+++ b/libs/langchain-community/src/document_loaders/web/reddit.ts
@@ -0,0 +1,133 @@
+import { BaseDocumentLoader } from "@langchain/core/document_loaders/base";
+import { Document } from "@langchain/core/documents";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import {
+ RedditAPIWrapper,
+ RedditPost,
+ RedditAPIConfig,
+} from "../../utils/reddit.js";
+
+/**
+ * Class representing a document loader for loading Reddit posts. It extends
+ * the BaseDocumentLoader and implements the RedditAPIConfig interface.
+ * @example
+ * ```typescript
+ * const loader = new RedditPostsLoader({
+ * clientId: "REDDIT_CLIENT_ID",
+ * clientSecret: "REDDIT_CLIENT_SECRET",
+ * userAgent: "REDDIT_USER_AGENT",
+ * searchQueries: ["LangChain", "Langchaindev"],
+ * mode: "subreddit",
+ * categories: ["hot", "new"],
+ * numberPosts: 5
+ * });
+ * const docs = await loader.load();
+ * ```
+ */
+export class RedditPostsLoader
+  extends BaseDocumentLoader
+  implements RedditAPIConfig
+{
+  public clientId: string;
+
+  public clientSecret: string;
+
+  public userAgent: string;
+
+  private redditApiWrapper: RedditAPIWrapper;
+  /** Subreddit names or usernames to load posts from, depending on `mode`. */
+  private searchQueries: string[];
+  /** Either "subreddit" or "username"; any other value makes `load()` throw. */
+  private mode: string;
+  /** Sort orders passed to the API wrapper; one request per query and category. */
+  private categories: string[];
+  /** Number of posts requested (as the API `limit`) per query/category pair. */
+  private numberPosts: number;
+  /** Omitted credentials are read from REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT. */
+  constructor({
+    clientId = getEnvironmentVariable("REDDIT_CLIENT_ID") as string,
+    clientSecret = getEnvironmentVariable("REDDIT_CLIENT_SECRET") as string,
+    userAgent = getEnvironmentVariable("REDDIT_USER_AGENT") as string,
+    searchQueries,
+    mode,
+    categories = ["new"],
+    numberPosts = 10,
+  }: Partial<RedditAPIConfig> & {
+    searchQueries: string[];
+    mode: string;
+    categories?: string[];
+    numberPosts?: number;
+  }) {
+    super();
+    this.clientId = clientId;
+    this.clientSecret = clientSecret;
+    this.userAgent = userAgent;
+    this.redditApiWrapper = new RedditAPIWrapper({
+      clientId: this.clientId,
+      clientSecret: this.clientSecret,
+      userAgent: this.userAgent,
+    });
+    this.searchQueries = searchQueries;
+    this.mode = mode;
+    this.categories = categories;
+    this.numberPosts = numberPosts;
+  }
+
+  /**
+   * Fetches posts for every search query and category combination and maps
+   * each post to a Document whose page content is the post's self text.
+   * @returns A Promise that resolves to an array of Document instances.
+   * @throws If `mode` is neither "subreddit" nor "username".
+   */
+  public async load(): Promise<Document[]> {
+    const results: Document[] = [];
+    for (const query of this.searchQueries) {
+      for (const category of this.categories) {
+        let posts: RedditPost[] = [];
+        // In subreddit mode the query is the subreddit name and "*" is the search term.
+        if (this.mode === "subreddit") {
+          posts = await this.redditApiWrapper.searchSubreddit(
+            query,
+            "*",
+            category,
+            this.numberPosts
+          );
+        } else if (this.mode === "username") {
+          posts = await this.redditApiWrapper.fetchUserPosts(
+            query,
+            category,
+            this.numberPosts
+          );
+        } else {
+          throw new Error(
+            "Invalid mode: please choose 'subreddit' or 'username'"
+          );
+        }
+        results.push(...this._mapPostsToDocuments(posts, category));
+      }
+    }
+
+    return results;
+  }
+  /** Wraps each post in a Document, recording the category it was fetched under in the metadata. */
+  private _mapPostsToDocuments(
+    posts: RedditPost[],
+    category: string
+  ): Document[] {
+    return posts.map(
+      (post) =>
+        new Document({
+          pageContent: post.selftext,
+          metadata: {
+            post_subreddit: post.subreddit_name_prefixed,
+            post_category: category,
+            post_title: post.title,
+            post_score: post.score,
+            post_id: post.id,
+            post_url: post.url,
+            post_author: post.author,
+          },
+        })
+    );
+  }
+}
diff --git a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts
index 722dd82e678b..679384047248 100644
--- a/libs/langchain-community/src/load/import_constants.ts
+++ b/libs/langchain-community/src/load/import_constants.ts
@@ -160,6 +160,7 @@ export const optionalImportEntrypoints: string[] = [
"langchain_community/document_loaders/web/notionapi",
"langchain_community/document_loaders/web/pdf",
"langchain_community/document_loaders/web/recursive_url",
+ "langchain_community/document_loaders/web/reddit",
"langchain_community/document_loaders/web/s3",
"langchain_community/document_loaders/web/sitemap",
"langchain_community/document_loaders/web/sonix_audio",
diff --git a/libs/langchain-community/src/tools/reddit.ts b/libs/langchain-community/src/tools/reddit.ts
new file mode 100644
index 000000000000..f3e68cda6d3c
--- /dev/null
+++ b/libs/langchain-community/src/tools/reddit.ts
@@ -0,0 +1,123 @@
+import { getEnvironmentVariable } from "@langchain/core/utils/env"; //"../../../../langchain-core/src/utils/env.js";
+import { Tool } from "@langchain/core/tools";
+import { RedditAPIWrapper, RedditPost } from "../utils/reddit.js";
+
+/** Interface for the search parameters; every field is optional.
+ * sortMethod: The sorting method for the search results, can be one of "relevance", "hot", "top", "new", "comments"
+ * time: The time period for the search results, can be one of "hour", "day", "week", "month", "year", "all"
+ * subreddit: The subreddit to search in like "dankmemes" for "r/dankmemes"
+ * limit: The number of results to return
+ * clientId: The client ID for the Reddit API
+ * clientSecret: The client secret for the Reddit API
+ * userAgent: The user agent for the Reddit API
+ */
+export interface RedditSearchRunParams {
+ sortMethod?: string;
+ time?: string;
+ subreddit?: string;
+ limit?: number;
+ clientId?: string;
+ clientSecret?: string;
+ userAgent?: string;
+}
+
+/**
+ * Class representing a tool for searching reddit posts using the reddit API.
+ * It extends the Tool class.
+ *
+ * @example
+ * ```typescript
+ * const search = new RedditSearchRun({
+ * sortMethod: "relevance",
+ * time: "all",
+ * subreddit: "dankmemes",
+ * limit: 1,
+ * });
+ *
+ * const post = await search.invoke("College");
+ * ```
+ */
+export class RedditSearchRun extends Tool {
+  static lc_name() {
+    return "RedditSearchRun";
+  }
+
+  name = "Reddit_search";
+
+  description = "A tool for searching reddit posts using the reddit API";
+
+  // Default values for the search parameters
+  protected sortMethod = "relevance";
+  protected time = "all";
+  protected subreddit = "all";
+  protected limit = 2;
+  protected clientId = "";
+  protected clientSecret = "";
+  protected userAgent = "";
+
+  /**
+   * Creates the tool; omitted search parameters keep their defaults and
+   * omitted credentials are read from the REDDIT_* environment variables.
+   * @param params The search parameters
+   */
+  constructor(params: RedditSearchRunParams = {}) {
+    super();
+
+    this.sortMethod = params.sortMethod ?? this.sortMethod;
+    this.time = params.time ?? this.time;
+    this.subreddit = params.subreddit ?? this.subreddit;
+    this.limit = params.limit ?? this.limit;
+    this.clientId =
+      params.clientId ?? (getEnvironmentVariable("REDDIT_CLIENT_ID") as string);
+    this.clientSecret =
+      params.clientSecret ??
+      (getEnvironmentVariable("REDDIT_CLIENT_SECRET") as string);
+    this.userAgent =
+      params.userAgent ??
+      (getEnvironmentVariable("REDDIT_USER_AGENT") as string);
+  }
+
+  /**
+   * Searches the configured subreddit for posts matching the query.
+   * @param query The search query to be sent to reddit
+   * @returns The matching posts as returned by the Reddit API wrapper
+   */
+  async _call(query: string): Promise<RedditPost[]> {
+    const apiWrapper = new RedditAPIWrapper({
+      clientId: this.clientId,
+      clientSecret: this.clientSecret,
+      userAgent: this.userAgent,
+    });
+
+    return apiWrapper.searchSubreddit(
+      this.subreddit,
+      query,
+      this.sortMethod,
+      this.limit,
+      this.time
+    );
+  }
+
+  /**
+   * Retrieves posts submitted by a given user.
+   * @param username The username whose posts are to be retrieved
+   * @param sortMethod The sorting method for the posts to be retrieved
+   * @param limit The maximum number of posts to retrieve
+   * @param time The time period for the posts to be retrieved
+   * @returns The user's posts as returned by the Reddit API wrapper
+   */
+  async fetchUserPosts(
+    username: string,
+    sortMethod: string = this.sortMethod,
+    limit: number = this.limit,
+    time: string = this.time
+  ): Promise<RedditPost[]> {
+    const apiWrapper = new RedditAPIWrapper({
+      clientId: this.clientId,
+      clientSecret: this.clientSecret,
+      userAgent: this.userAgent,
+    });
+
+    return apiWrapper.fetchUserPosts(username, sortMethod, limit, time);
+  }
+}
diff --git a/libs/langchain-community/src/tools/tests/reddit.int.test.ts b/libs/langchain-community/src/tools/tests/reddit.int.test.ts
new file mode 100644
index 000000000000..a3dad748cc6a
--- /dev/null
+++ b/libs/langchain-community/src/tools/tests/reddit.int.test.ts
@@ -0,0 +1,28 @@
+import { test, expect } from "@jest/globals";
+//import { Document } from "@langchain/core/documents";
+//import { RedditPostsLoader } from "../web/reddit.js";
+import { RedditSearchRun } from "../reddit.js";
+
+test("Test fetching a post based on a query", async () => {
+ const search = new RedditSearchRun({
+ sortMethod: "relevance",
+ time: "all",
+ subreddit: "dankmemes",
+ limit: 1,
+ });
+
+ const post = await search.invoke("College");
+ expect(post).toHaveLength(1);
+});
+
+test("Test fetching a post from a user", async () => {
+  const search = new RedditSearchRun({
+    sortMethod: "relevance",
+    time: "all",
+    subreddit: "dankmemes",
+    limit: 1,
+  });
+  // Argument order is (username, sortMethod, limit, time).
+  const post = await search.fetchUserPosts("BloodJunkie", "new", 1, "all");
+  expect(post).toHaveLength(1);
+});
diff --git a/libs/langchain-community/src/utils/reddit.ts b/libs/langchain-community/src/utils/reddit.ts
new file mode 100644
index 000000000000..2e9fd1b68e04
--- /dev/null
+++ b/libs/langchain-community/src/utils/reddit.ts
@@ -0,0 +1,145 @@
+import dotenv from "dotenv";
+import { AsyncCaller } from "@langchain/core/utils/async_caller";
+
+dotenv.config();
+
+export interface RedditAPIConfig {
+ clientId: string;
+ clientSecret: string;
+ userAgent: string;
+}
+
+export interface RedditPost {
+ title: string;
+ selftext: string;
+ subreddit_name_prefixed: string;
+ score: number;
+ id: string;
+ url: string;
+ author: string;
+}
+
+/** Reddit OAuth API client: client-credentials login plus post listing requests. */
+export class RedditAPIWrapper {
+  private clientId: string;
+  private clientSecret: string;
+  private userAgent: string;
+  private token: string | null = null;
+  private baseUrl = "https://oauth.reddit.com";
+  private asyncCaller: AsyncCaller; // Using AsyncCaller for requests
+
+  constructor(config: RedditAPIConfig) {
+    this.clientId = config.clientId;
+    this.clientSecret = config.clientSecret;
+    this.userAgent = config.userAgent;
+    this.asyncCaller = new AsyncCaller({
+      maxConcurrency: 5,
+      maxRetries: 3,
+      onFailedAttempt: (error) => {
+        console.error("Attempt failed:", error.message);
+      },
+    });
+  }
+
+  /**
+   * Requests an access token unless one is already cached. NOTE(review):
+   * failures are only logged, so later requests go out without a valid token.
+   */
+  private async authenticate() {
+    if (this.token) return;
+
+    const authString = btoa(`${this.clientId}:${this.clientSecret}`);
+    try {
+      const response = await fetch(
+        "https://www.reddit.com/api/v1/access_token",
+        {
+          method: "POST",
+          headers: {
+            Authorization: `Basic ${authString}`,
+            "User-Agent": this.userAgent,
+            "Content-Type": "application/x-www-form-urlencoded",
+          },
+          body: "grant_type=client_credentials",
+        }
+      );
+      if (!response.ok) {
+        throw new Error(
+          `Error authenticating with Reddit: ${response.statusText}`
+        );
+      }
+      const data = await response.json();
+      this.token = data.access_token;
+    } catch (error) {
+      console.error("Error authenticating with Reddit:", error);
+    }
+  }
+
+  /**
+   * Sends an authenticated GET request and returns the parsed JSON listing.
+   * @throws If the response is still not OK once the AsyncCaller gives up.
+   */
+  private async makeRequest(
+    endpoint: string,
+    params: Record<string, string> = {}
+  ): Promise<{ data: { children: { data: RedditPost }[] } }> {
+    await this.authenticate();
+    const url = new URL(`${this.baseUrl}${endpoint}`);
+    for (const [key, value] of Object.entries(params)) {
+      url.searchParams.append(key, value);
+    }
+
+    return this.asyncCaller.call(async () => {
+      const response = await fetch(url.toString(), {
+        headers: {
+          Authorization: `Bearer ${this.token}`,
+          "User-Agent": this.userAgent,
+        },
+      });
+      if (!response.ok) {
+        if (response.status === 429) {
+          console.warn("Rate limit exceeded, retrying...");
+          throw new Error("Rate limit exceeded");
+        }
+        throw new Error(
+          `Error making request to ${endpoint}: ${response.statusText}`
+        );
+      }
+      return await response.json();
+    });
+  }
+
+  /**
+   * Searches one subreddit; sort, limit and time default to "new", 10, "all".
+   */
+  async searchSubreddit(
+    subreddit: string,
+    query: string,
+    sort = "new",
+    limit = 10,
+    time = "all"
+  ): Promise<RedditPost[]> {
+    const data = await this.makeRequest(`/r/${subreddit}/search`, {
+      q: query,
+      sort,
+      limit: limit.toString(),
+      t: time,
+      restrict_sr: "on",
+    });
+    return data.data.children.map((item) => item.data);
+  }
+
+  /** Lists posts submitted by `username`; defaults mirror `searchSubreddit`. */
+  async fetchUserPosts(
+    username: string,
+    sort = "new",
+    limit = 10,
+    time = "all"
+  ): Promise<RedditPost[]> {
+    const data = await this.makeRequest(`/user/${username}/submitted`, {
+      sort,
+      limit: limit.toString(),
+      t: time,
+    });
+    return data.data.children.map((item) => item.data);
+  }
+}
diff --git a/libs/langchain-community/src/utils/tests/reddit.test.ts b/libs/langchain-community/src/utils/tests/reddit.test.ts
new file mode 100644
index 000000000000..83065e406e47
--- /dev/null
+++ b/libs/langchain-community/src/utils/tests/reddit.test.ts
@@ -0,0 +1,185 @@
+import {
+ describe,
+ expect,
+ it,
+ jest,
+ beforeEach,
+ afterEach,
+} from "@jest/globals";
+import { RedditAPIWrapper, RedditAPIConfig } from "../reddit.js";
+
+// Mocking global fetch so no test performs a real HTTP request
+global.fetch = jest.fn() as jest.MockedFunction<typeof fetch>;
+
+// Sample RedditAPIConfig for tests
+const fakeConfig: RedditAPIConfig = {
+ clientId: "fakeClientId",
+ clientSecret: "fakeClientSecret",
+ userAgent: "test-user-agent",
+};
+
+describe("RedditAPIWrapper", () => {
+ let redditAPIWrapper: RedditAPIWrapper;
+
+ beforeEach(() => {
+ redditAPIWrapper = new RedditAPIWrapper(fakeConfig);
+ });
+
+ afterEach(() => {
+ jest.clearAllMocks();
+ });
+
+ it("should authenticate successfully and set token", async () => {
+ const fakeAccessToken = "fakeAccessToken";
+ const fakeResponse = {
+ ok: true,
+ json: jest.fn().mockResolvedValue({ access_token: fakeAccessToken }),
+ } as unknown as Response;
+
+ // Mock fetch to return the fake response for authentication
+ (global.fetch as jest.Mock).mockResolvedValue(fakeResponse);
+
+ await redditAPIWrapper["authenticate"]();
+
+ // Validate that the fetch was called correctly with the correct headers and URL
+ expect(global.fetch).toHaveBeenCalledWith(
+ "https://www.reddit.com/api/v1/access_token",
+ expect.objectContaining({
+ method: "POST",
+ headers: expect.objectContaining({
+ Authorization: expect.stringContaining("Basic"), // Checks if Basic auth is used
+ }),
+ })
+ );
+ expect(redditAPIWrapper["token"]).toBe(fakeAccessToken);
+ });
+
+ it("should make a request successfully", async () => {
+ const fakeToken = "fakeAccessToken";
+ redditAPIWrapper["token"] = fakeToken;
+
+ const fakeJsonResponse = { data: { children: [] } };
+ const fakeResponse = {
+ ok: true,
+ json: jest.fn().mockResolvedValue(fakeJsonResponse),
+ } as unknown as Response;
+
+ // Mock fetch to return the fake response for making requests
+ (global.fetch as jest.Mock).mockResolvedValue(fakeResponse);
+
+ const response = await redditAPIWrapper["makeRequest"]("/r/test/search", {
+ q: "test",
+ });
+
+ // Validate that the fetch was called with the correct URL and authorization header
+ expect(global.fetch).toHaveBeenCalledWith(
+ "https://oauth.reddit.com/r/test/search?q=test",
+ expect.objectContaining({
+ headers: expect.objectContaining({
+ Authorization: `Bearer ${fakeToken}`,
+ }),
+ })
+ );
+ expect(response).toEqual(fakeJsonResponse);
+ });
+
+ it("should handle rate limit errors gracefully", async () => {
+ const fakeResponse = {
+ ok: false,
+ status: 429, // Rate limit exceeded
+ statusText: "Too Many Requests",
+ } as unknown as Response;
+
+ // Mock fetch to return a rate-limited response
+ (global.fetch as jest.Mock).mockResolvedValue(fakeResponse);
+
+ // Expect the method to throw an error when rate-limited
+ await expect(
+ redditAPIWrapper["makeRequest"]("/r/test/search", { q: "test" })
+ ).rejects.toThrow("Rate limit exceeded");
+ });
+
+ it("should search subreddit and map data correctly", async () => {
+ const fakeJsonResponse = {
+ data: {
+ children: [
+ {
+ data: {
+ title: "Test Post",
+ selftext: "Test Text",
+ subreddit_name_prefixed: "r/test",
+ score: 100,
+ id: "123",
+ url: "https://test.com",
+ author: "test_author",
+ },
+ },
+ ],
+ },
+ };
+
+ const fakeResponse = {
+ ok: true,
+ json: jest.fn().mockResolvedValue(fakeJsonResponse),
+ } as unknown as Response;
+
+ // Mock fetch to return the fake response for subreddit search
+ (global.fetch as jest.Mock).mockResolvedValue(fakeResponse);
+
+ const posts = await redditAPIWrapper.searchSubreddit("test", "test query");
+
+ // Validate the response format and data mapping
+ expect(posts).toHaveLength(1);
+ expect(posts[0]).toEqual({
+ title: "Test Post",
+ selftext: "Test Text",
+ subreddit_name_prefixed: "r/test",
+ score: 100,
+ id: "123",
+ url: "https://test.com",
+ author: "test_author",
+ });
+ });
+
+ it("should fetch user posts and map data correctly", async () => {
+ const fakeJsonResponse = {
+ data: {
+ children: [
+ {
+ data: {
+ title: "User Post",
+ selftext: "User Post Text",
+ subreddit_name_prefixed: "r/test",
+ score: 50,
+ id: "456",
+ url: "https://test.com",
+ author: "user_test",
+ },
+ },
+ ],
+ },
+ };
+
+ const fakeResponse = {
+ ok: true,
+ json: jest.fn().mockResolvedValue(fakeJsonResponse),
+ } as unknown as Response;
+
+ // Mock fetch to return the fake response for fetching user posts
+ (global.fetch as jest.Mock).mockResolvedValue(fakeResponse);
+
+ const posts = await redditAPIWrapper.fetchUserPosts("testuser", "new");
+
+ // Validate the response format and data mapping
+ expect(posts).toHaveLength(1);
+ expect(posts[0]).toEqual({
+ title: "User Post",
+ selftext: "User Post Text",
+ subreddit_name_prefixed: "r/test",
+ score: 50,
+ id: "456",
+ url: "https://test.com",
+ author: "user_test",
+ });
+ });
+});
From ed8f7a2869fd97d32993af679fd58cf4c463cc8e Mon Sep 17 00:00:00 2001
From: MirrorLimit <113953454+MirrorLimit@users.noreply.github.com>
Date: Sat, 30 Nov 2024 13:46:32 -0500
Subject: [PATCH 2/2] ci: Added reddit to the GitHub workflow unit test
---
.github/workflows/unit-tests-integrations.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/unit-tests-integrations.yml b/.github/workflows/unit-tests-integrations.yml
index a56c954cc535..7d8742cbb845 100644
--- a/.github/workflows/unit-tests-integrations.yml
+++ b/.github/workflows/unit-tests-integrations.yml
@@ -44,7 +44,7 @@ jobs:
needs: get-changed-files
runs-on: ubuntu-latest
env:
- PACKAGES: "anthropic,azure-openai,cloudflare,cohere,core,community,exa,google-common,google-gauth,google-genai,google-vertexai,google-vertexai-web,google-webauth,groq,mistralai,mongo,nomic,openai,pinecone,qdrant,redis,textsplitters,weaviate,yandex,baidu-qianfan"
+ PACKAGES: "anthropic,azure-openai,cloudflare,cohere,core,community,exa,google-common,google-gauth,google-genai,google-vertexai,google-vertexai-web,google-webauth,groq,mistralai,mongo,nomic,openai,pinecone,qdrant,redis,textsplitters,weaviate,yandex,baidu-qianfan,reddit"
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
matrix_length: ${{ steps.set-matrix.outputs.matrix_length }}