Skip to content

Commit

Permalink
chore: revamp upsert hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Fontanier committed Dec 9, 2024
1 parent 5c6aa4e commit 24e67bb
Show file tree
Hide file tree
Showing 27 changed files with 310 additions and 682 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import * as t from "io-ts";

import { callAction } from "@app/lib/actions/helpers";
import type { Authenticator } from "@app/lib/auth";
import { getTrackableDataSourceViews } from "@app/lib/documents_post_process_hooks/hooks/document_tracker/lib";
import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry";
import type { DataSourceViewResource } from "@app/lib/resources/data_source_view_resource";

// Part of the new doc tracker pipeline, performs the retrieval (semantic search) step
// it takes {input_text: string} as input
Expand All @@ -14,23 +14,25 @@ export async function callDocTrackerRetrievalAction(
inputText,
targetDocumentTokens,
topK,
}: { inputText: string; targetDocumentTokens: number; topK: number }
dataSourceViews,
}: {
inputText: string;
targetDocumentTokens: number;
topK: number;
dataSourceViews: DataSourceViewResource[];
}
): Promise<t.TypeOf<typeof DocTrackerRetrievalActionValueSchema>> {
const action = DustProdActionRegistry["doc-tracker-retrieval"];
const config = cloneBaseConfig(action.config);

const trackableDataSourceViews = await getTrackableDataSourceViews(auth);

if (!trackableDataSourceViews.length) {
if (!dataSourceViews.length) {
return [];
}

config.SEMANTIC_SEARCH.data_sources = trackableDataSourceViews.map(
(view) => ({
workspace_id: auth.getNonNullableWorkspace().sId,
data_source_id: view.sId,
})
);
const action = DustProdActionRegistry["doc-tracker-retrieval"];
const config = cloneBaseConfig(action.config);

config.SEMANTIC_SEARCH.data_sources = dataSourceViews.map((view) => ({
workspace_id: auth.getNonNullableWorkspace().sId,
data_source_id: view.sId,
}));

config.SEMANTIC_SEARCH.target_document_tokens = targetDocumentTokens;
config.SEMANTIC_SEARCH.top_k = topK;
Expand Down
23 changes: 23 additions & 0 deletions front/lib/document_upsert_hooks/hooks/document_tracker/consts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import type { ConnectorProvider } from "@dust-tt/types";
import { assertNever } from "@dust-tt/types";

/**
 * Tells whether documents coming from the given connector provider are
 * eligible for document tracking.
 *
 * Each provider is listed explicitly. Any value not covered by a `case` is
 * handed to `assertNever` in the `default` branch.
 *
 * @param connectorType - The connector provider to check.
 * @returns `false` for `"slack"`, `true` for every other listed provider.
 */
export function isConnectorTypeTrackable(
  connectorType: ConnectorProvider
): boolean {
  switch (connectorType) {
    // Slack is the only listed provider that is not trackable.
    case "slack": {
      return false;
    }

    case "zendesk":
    case "snowflake":
    case "webcrawler":
    case "intercom":
    case "confluence":
    case "microsoft":
    case "notion":
    case "github":
    case "google_drive": {
      return true;
    }

    default: {
      assertNever(connectorType);
    }
  }
}
35 changes: 35 additions & 0 deletions front/lib/document_upsert_hooks/hooks/document_tracker/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { getFeatureFlags } from "@app/lib/auth";
import type { DocumentUpsertHook } from "@app/lib/document_upsert_hooks/hooks";
import { launchRunDocumentTrackerWorkflow } from "@app/temporal/document_tracker/client";

/**
 * Upsert hook for the document tracker.
 *
 * Its purpose is to suggest changes to tracked documents based on new
 * information added to other documents. The hook itself only launches the
 * document tracker workflow; it does nothing when the authenticator carries no
 * workspace or when the workspace lacks the "document_tracker" feature flag.
 */
export const documentTrackerUpsertHook: DocumentUpsertHook = {
  type: "document_tracker",
  fn: async (params) => {
    const {
      auth,
      dataSourceId,
      documentId,
      documentHash,
      dataSourceConnectorProvider,
    } = params;

    const workspace = auth.workspace();
    if (workspace) {
      // Feature flags are only looked up once a workspace is known.
      const featureFlags = await getFeatureFlags(workspace);
      const isTrackerEnabled = featureFlags.includes("document_tracker");

      if (isTrackerEnabled) {
        await launchRunDocumentTrackerWorkflow({
          workspaceId: workspace.sId,
          dataSourceId,
          documentId,
          documentHash,
          dataSourceConnectorProvider,
        });
      }
    }
  },
};
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import type { ConnectorProvider, UpsertContext } from "@dust-tt/types";
import { CoreAPI } from "@dust-tt/types";
import sgMail from "@sendgrid/mail";
import { Op } from "sequelize";
import showdown from "showdown";

import config from "@app/lib/api/config";
import type { Authenticator } from "@app/lib/auth";
import { getFeatureFlags } from "@app/lib/auth";
import type {
DocumentsPostProcessHookFilterParams,
DocumentsPostProcessHookOnUpsertParams,
} from "@app/lib/documents_post_process_hooks/hooks";
import {
getDatasource,
getDocumentDiff,
} from "@app/lib/documents_post_process_hooks/hooks/data_source_helpers";
} from "@app/lib/document_upsert_hooks/hooks/data_source_helpers";
import { callDocTrackerRetrievalAction } from "@app/lib/document_upsert_hooks/hooks/document_tracker/actions/doc_tracker_retrieval";
import { callDocTrackerSuggestChangesAction } from "@app/lib/document_upsert_hooks/hooks/document_tracker/actions/doc_tracker_suggest_changes";
import { isConnectorTypeTrackable } from "@app/lib/document_upsert_hooks/hooks/document_tracker/consts";
import {
DocumentTrackerChangeSuggestion,
TrackedDocument,
} from "@app/lib/models/doc_tracker";
import { DataSourceResource } from "@app/lib/resources/data_source_resource";
import { DataSourceViewResource } from "@app/lib/resources/data_source_view_resource";
import { SpaceResource } from "@app/lib/resources/space_resource";
import { UserModel } from "@app/lib/resources/storage/models/user";
import mainLogger from "@app/logger/logger";

import { callDocTrackerRetrievalAction } from "./actions/doc_tracker_retrieval";
import { callDocTrackerSuggestChangesAction } from "./actions/doc_tracker_suggest_changes";

const { SENDGRID_API_KEY } = process.env;

// If the sum of all INSERT diffs is less than this number, we skip the hook.
Expand All @@ -42,33 +42,30 @@ const TOTAL_TARGET_TOKENS = 8192;
const MAX_TRACKED_DOCUMENTS = 1;

const logger = mainLogger.child({
postProcessHook: "document_tracker_suggest_changes",
hookType: "document_tracker",
});

export async function shouldDocumentTrackerSuggestChangesRun(
params: DocumentsPostProcessHookFilterParams
auth: Authenticator,
{
dataSourceId,
documentId,
dataSourceConnectorProvider,
upsertContext,
}: {
dataSourceId: string;
documentId: string;
dataSourceConnectorProvider: ConnectorProvider;
upsertContext: UpsertContext;
}
): Promise<boolean> {
const auth = params.auth;
const owner = auth.getNonNullableWorkspace();
const flags = await getFeatureFlags(owner);

if (!flags.includes("document_tracker")) {
return false;
}

if (params.verb !== "upsert") {
logger.info(
"document_tracker_suggest_changes post process hook should only run for upsert."
);
return false;
}

const {
upsertContext,
dataSourceId,
documentId,
dataSourceConnectorProvider,
} = params;
const isBatchSync = upsertContext?.sync_type === "batch";

const localLogger = logger.child({
Expand Down Expand Up @@ -144,13 +141,19 @@ export async function shouldDocumentTrackerSuggestChangesRun(
return false;
}

export async function documentTrackerSuggestChangesOnUpsert({
export async function documentTrackerSuggestChanges({
auth,
dataSourceId,
documentId,
documentHash,
documentSourceUrl,
}: DocumentsPostProcessHookOnUpsertParams): Promise<void> {
}: {
auth: Authenticator;
dataSourceId: string;
documentId: string;
documentHash: string;
documentSourceUrl: string;
}): Promise<void> {
const owner = auth.workspace();
if (!owner) {
throw new Error("Workspace not found.");
Expand Down Expand Up @@ -261,10 +264,12 @@ export async function documentTrackerSuggestChangesOnUpsert({
},
"Calling doc tracker retrieval action."
);
const dataSourceViews = await getTrackableDataSourceViews(auth);
const retrievalResult = await callDocTrackerRetrievalAction(auth, {
inputText: diffString,
targetDocumentTokens: targetTrackedDocumentTokens,
topK: MAX_TRACKED_DOCUMENTS,
dataSourceViews,
});

if (!retrievalResult.length) {
Expand Down Expand Up @@ -510,3 +515,20 @@ function getDocumentTitle(tags: string[]): string | null {
}
return maybeTitleTag.split("title:")[1];
}

/**
 * Lists the data source views of the workspace's global space whose data
 * source has a connector provider accepted by `isConnectorTypeTrackable`.
 *
 * Views whose data source has no connector provider are dropped.
 *
 * @param auth - Authenticator used to fetch the global space and its views.
 * @returns The trackable data source views (possibly empty).
 */
export async function getTrackableDataSourceViews(
  auth: Authenticator
): Promise<DataSourceViewResource[]> {
  // TODO(DOC_TRACKER): only the global space is looked at here.
  const space = await SpaceResource.fetchWorkspaceGlobalSpace(auth);
  const allViews = await DataSourceViewResource.listBySpace(auth, space);

  return allViews.filter((candidate) => {
    const provider = candidate.dataSource.connectorProvider;
    if (!provider) {
      // No connector behind this data source: not trackable.
      return false;
    }
    return isConnectorTypeTrackable(provider);
  });
}
49 changes: 49 additions & 0 deletions front/lib/document_upsert_hooks/hooks/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import type { ConnectorProvider, UpsertContext } from "@dust-tt/types";

import type { Authenticator } from "@app/lib/auth";
import { documentTrackerUpsertHook } from "@app/lib/document_upsert_hooks/hooks/document_tracker";
import { wakeLock } from "@app/lib/wake_lock";
import logger from "@app/logger/logger";

// Workspace sId that `runDocumentUpsertHooks` compares against: hooks are
// skipped for every other workspace.
// NOTE(review): presumably the Dust internal workspace, going by the name —
// confirm.
const DUST_WORKSPACE = "0ec9852c2f";

// Registry of the available upsert hooks, keyed by an internal identifier.
// NOTE(review): the key here ("document_tracker_suggest_changes") differs from
// the `type` string the hook declares for itself ("document_tracker") — confirm
// this is intentional.
const _hooks = {
document_tracker_suggest_changes: documentTrackerUpsertHook,
} as const;

// Flat list of every registered hook; this is what `runDocumentUpsertHooks`
// iterates over.
export const DOCUMENT_UPSERT_HOOKS: Array<DocumentUpsertHook> =
Object.values(_hooks);

/**
 * Shape of a hook run when a document is upserted.
 *
 * `type` is a label for the hook (it is included in the error log emitted by
 * `runDocumentUpsertHooks`), and `fn` is the asynchronous body of the hook.
 */
export type DocumentUpsertHook = {
// Label identifying the hook, e.g. in logs.
type: string;
// Hook body; receives the authenticator plus the identifiers of the upserted
// document and of its data source.
fn: (params: {
auth: Authenticator;
dataSourceId: string;
documentId: string;
documentHash: string;
// `null` is allowed — presumably for data sources that are not backed by a
// connector (TODO confirm).
dataSourceConnectorProvider: ConnectorProvider | null;
// Optional context describing the upsert.
upsertContext?: UpsertContext;
}) => Promise<void>;
};

/**
 * Fires every registered document upsert hook without waiting for completion.
 *
 * Nothing runs unless the authenticator's workspace sId equals
 * `DUST_WORKSPACE`. Each hook is started through `wakeLock`, and its promise is
 * deliberately not awaited. An error thrown by a hook is logged and does not
 * propagate to the caller.
 *
 * @param params - Arguments forwarded unchanged to each hook's `fn`.
 */
export function runDocumentUpsertHooks(
  params: Parameters<DocumentUpsertHook["fn"]>[0]
): void {
  // TODO(document-tracker): remove this once we have a way to enable/disable
  const isDustWorkspace = params.auth.workspace()?.sId === DUST_WORKSPACE;
  if (!isDustWorkspace) {
    return;
  }

  DOCUMENT_UPSERT_HOOKS.forEach((hook) => {
    // Wraps a single hook so that its failure is logged instead of rejecting.
    const runHookSafely = async (): Promise<void> => {
      try {
        await hook.fn(params);
      } catch (error) {
        logger.error(
          { hookType: hook.type, error },
          `Error running document upsert hook`
        );
      }
    };

    void wakeLock(runHookSafely);
  });
}

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit 24e67bb

Please sign in to comment.