From 8e5d242d1d937bb369d7ccc925a2f71a3594156b Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:23:34 -0700 Subject: [PATCH] Add custom fetch options (#538) --- js/src/client.ts | 34 +++++++++++++++++++++++++++++++++ js/src/tests/client.int.test.ts | 18 ++++++++--------- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/js/src/client.ts b/js/src/client.ts index 6928149a7..8bc650296 100644 --- a/js/src/client.ts +++ b/js/src/client.ts @@ -46,6 +46,7 @@ interface ClientConfig { hideOutputs?: boolean; autoBatchTracing?: boolean; pendingAutoBatchedRunLimit?: number; + fetchOptions?: RequestInit; } /** @@ -396,6 +397,8 @@ export class Client { private serverInfo: Record | undefined; + private fetchOptions: RequestInit; + constructor(config: ClientConfig = {}) { const defaultConfig = Client.getDefaultClientConfig(); @@ -414,6 +417,7 @@ export class Client { this.autoBatchTracing = config.autoBatchTracing ?? this.autoBatchTracing; this.pendingAutoBatchedRunLimit = config.pendingAutoBatchedRunLimit ?? this.pendingAutoBatchedRunLimit; + this.fetchOptions = config.fetchOptions || {}; } public static getDefaultClientConfig(): { @@ -522,6 +526,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { throw new Error( @@ -553,6 +558,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { throw new Error( @@ -584,6 +590,7 @@ export class Client { method: requestMethod, headers: { ...this.headers, "Content-Type": "application/json" }, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, body: JSON.stringify(bodyParams), }); const responseBody = await response.json(); @@ -693,6 +700,7 @@ export class Client { method: "GET", headers: { Accept: "application/json" }, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { // consume the response body to release the connection @@ -745,6 +753,7 @@ export class Client { headers, body: JSON.stringify(mergedRunCreateParams[0]), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); await raiseForStatus(response, "create run"); } @@ -872,6 +881,7 @@ export class Client { headers, body: body, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus(response, "batch create run"); @@ -917,6 +927,7 @@ export class Client { headers, body: JSON.stringify(run), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus(response, "update run"); @@ -1161,6 +1172,7 @@ export class Client { headers: this.headers, body: JSON.stringify(data), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const result = await response.json(); @@ -1179,6 +1191,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus(response, "unshare run"); @@ -1193,6 +1206,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const result = await response.json(); @@ -1226,6 +1240,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const runs = await response.json(); @@ -1251,6 +1266,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const shareSchema = await response.json(); @@ -1283,6 +1299,7 @@ export class Client { headers: this.headers, body: JSON.stringify(data), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const shareSchema = await response.json(); @@ -1301,6 +1318,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus(response, "unshare dataset"); @@ -1315,6 +1333,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const dataset = await response.json(); @@ -1355,6 +1374,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(body), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); const result = await response.json(); if (!response.ok) { @@ -1397,6 +1417,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(body), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); const result = await response.json(); if (!response.ok) { @@ -1434,6 +1455,7 @@ export class Client { method: "GET", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); // consume the response body to release the connection @@ -1583,6 +1605,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus( @@ -1625,6 +1648,7 @@ export class Client { headers: this.headers, body: formData, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { @@ -1660,6 +1684,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(body), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { @@ -1829,6 +1854,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { throw new Error( @@ -1867,6 +1893,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(data), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { @@ -1923,6 +1950,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(formattedExamples), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); @@ -2026,6 +2054,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { throw new Error( @@ -2048,6 +2077,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(update), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); if (!response.ok) { @@ -2163,6 +2193,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(feedback), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); await raiseForStatus(response, "create feedback"); return feedback as Feedback; @@ -2204,6 +2235,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(feedbackUpdate), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); await raiseForStatus(response, "update feedback"); @@ -2223,6 +2255,7 @@ export class Client { method: "DELETE", headers: this.headers, signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, }); if (!response.ok) { throw new Error( @@ -2314,6 +2347,7 @@ export class Client { headers: { ...this.headers, "Content-Type": "application/json" }, body: JSON.stringify(body), signal: AbortSignal.timeout(this.timeout_ms), + ...this.fetchOptions, } ); const result = await response.json(); diff --git a/js/src/tests/client.int.test.ts b/js/src/tests/client.int.test.ts index 77c9b3a2b..144528f33 100644 --- a/js/src/tests/client.int.test.ts +++ b/js/src/tests/client.int.test.ts @@ -279,17 +279,13 @@ test.concurrent( outputs: { generation: "hi there 2" }, end_time: new Date().getTime(), }); - await waitUntilRunFound(langchainClient, runId, true); + await waitUntilRunFound(langchainClient, runId, false); const run1 = await langchainClient.readRun(runId); - expect(run1.inputs).toBeDefined(); - expect(Object.keys(run1.inputs)).toHaveLength(0); - expect(run1.outputs).toBeDefined(); + expect(Object.keys(run1.inputs ?? {})).toHaveLength(0); expect(Object.keys(run1.outputs ?? {})).toHaveLength(0); - await waitUntilRunFound(langchainClient, runId2, true); + await waitUntilRunFound(langchainClient, runId2, false); const run2 = await langchainClient.readRun(runId2); - expect(run2.inputs).toBeDefined(); - expect(Object.keys(run2.inputs)).toHaveLength(0); - expect(run2.outputs).toBeDefined(); + expect(Object.keys(run2.inputs ?? {})).toHaveLength(0); expect(Object.keys(run2.outputs ?? {})).toHaveLength(0); }, 240_000 @@ -355,7 +351,11 @@ test.concurrent( describe("createChatExample", () => { it("should convert LangChainBaseMessage objects to examples", async () => { - const langchainClient = new Client({ autoBatchTracing: false }); + const langchainClient = new Client({ + autoBatchTracing: false, + // Test the fetch options option + fetchOptions: { cache: "no-store" }, + }); const datasetName = "__createChatExample-test-dataset JS"; await deleteDataset(langchainClient, datasetName);