Skip to content

Commit

Permalink
Upg: allow more flexibility when parsing enum in public api client wh…
Browse files Browse the repository at this point in the history
…ile keeping strong typing. (#8446)
  • Loading branch information
Fraggle authored Nov 5, 2024
1 parent 7effe40 commit 1329c3e
Showing 1 changed file with 77 additions and 63 deletions.
140 changes: 77 additions & 63 deletions sdks/js/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,42 @@
import moment from "moment-timezone";
import { z } from "zod";

const ModelProviderIdSchema = z.enum([
const FlexibleEnumSchema = <U extends string>(values: readonly [U, ...U[]]) =>
z.enum(values).transform((val) => val); // Transform bypass for the enum validation when parsing but doesn't affect the inferred type

const ModelProviderIdSchema = FlexibleEnumSchema([
"openai",
"anthropic",
"mistral",
"google_ai_studio",
]);

const EmbeddingProviderIdSchema = z.enum(["openai", "mistral"]);
const ModelLLMIdSchema = FlexibleEnumSchema([
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o-2024-08-06",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",
"o1-mini",
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-instant-1.2",
"mistral-large-latest",
"mistral-medium",
"mistral-small-latest",
"codestral-latest",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
]);

const EmbeddingProviderIdSchema = FlexibleEnumSchema(["openai", "mistral"]);

const ConnectorsAPIErrorTypeSchema = z.enum([
const ConnectorsAPIErrorTypeSchema = FlexibleEnumSchema([
"authorization_error",
"not_found",
"internal_server_error",
Expand Down Expand Up @@ -42,7 +68,7 @@ export type ConnectorsAPIErrorType = z.infer<
typeof ConnectorsAPIErrorTypeSchema
>;

const SupportedContentFragmentTypeSchema = z.enum([
const SupportedContentFragmentTypeSchema = FlexibleEnumSchema([
...([
// Text content types.
"application/msword",
Expand All @@ -65,21 +91,21 @@ const SupportedContentFragmentTypeSchema = z.enum([
] as const),
]);

const UserMessageOriginSchema = z.union([
z.enum([
"slack",
"web",
"api",
"gsheet",
"zapier",
"make",
"zendesk",
"raycast",
"extension",
]),
z.null(),
z.undefined(),
]);
const UserMessageOriginSchema = FlexibleEnumSchema([
"slack",
"web",
"api",
"gsheet",
"zapier",
"make",
"zendesk",
"raycast",
"extension",
])
.or(z.null())
.or(z.undefined());

const VisibilitySchema = FlexibleEnumSchema(["visible", "deleted"]);

const RankSchema = z.object({
rank: z.number(),
Expand Down Expand Up @@ -116,7 +142,7 @@ const Timezone = z.string().refine((s) => moment.tz.names().includes(s), {
message: "Invalid timezone",
});

const ConnectorProvidersSchema = z.enum([
const ConnectorProvidersSchema = FlexibleEnumSchema([
"confluence",
"github",
"google_drive",
Expand Down Expand Up @@ -226,14 +252,14 @@ export interface LoggerInterface {
warn: (args: Record<string, unknown>, message: string) => void;
}

const DataSourceViewCategoriesSchema = z.enum([
const DataSourceViewCategoriesSchema = FlexibleEnumSchema([
"managed",
"folder",
"website",
"apps",
]);

const BlockTypeSchema = z.enum([
const BlockTypeSchema = FlexibleEnumSchema([
"input",
"data",
"data_source",
Expand Down Expand Up @@ -318,7 +344,10 @@ const FunctionMessageTypeModelSchema = z.object({
content: z.string(),
});

const TokensClassificationSchema = z.enum(["tokens", "chain_of_thought"]);
const TokensClassificationSchema = FlexibleEnumSchema([
"tokens",
"chain_of_thought",
]);

export const GenerationTokensEventSchema = z.object({
type: z.literal("generation_tokens"),
Expand All @@ -334,7 +363,7 @@ export const GenerationTokensEventSchema = z.object({
});
export type GenerationTokensEvent = z.infer<typeof GenerationTokensEventSchema>;

const BaseActionTypeSchema = z.enum([
const BaseActionTypeSchema = FlexibleEnumSchema([
"dust_app_run_action",
"tables_query_action",
"retrieval_action",
Expand Down Expand Up @@ -402,7 +431,7 @@ const DustAppRunActionTypeSchema = BaseActionSchema.extend({
output: o.output,
}));

const DataSourceViewKindSchema = z.enum(["default", "custom"]);
const DataSourceViewKindSchema = FlexibleEnumSchema(["default", "custom"]);

const DataSourceViewSchema = z.object({
category: DataSourceViewCategoriesSchema,
Expand Down Expand Up @@ -520,7 +549,7 @@ const TablesQueryActionTypeSchema = BaseActionSchema.extend({
type: z.literal("tables_query_action"),
});

const WhitelistableFeaturesSchema = z.enum([
const WhitelistableFeaturesSchema = FlexibleEnumSchema([
"usage_data_api",
"okta_enterprise_connection",
"labs_transcripts",
Expand All @@ -532,9 +561,12 @@ const WhitelistableFeaturesSchema = z.enum([
"snowflake_connector_feature",
"zendesk_connector_feature",
]);

export type WhitelistableFeature = z.infer<typeof WhitelistableFeaturesSchema>;

const WorkspaceSegmentationSchema = z.enum(["interesting"]).nullable();
const WorkspaceSegmentationSchema = FlexibleEnumSchema([
"interesting",
]).nullable();

const RoleSchema = z.enum(["admin", "builder", "user", "none"]);

Expand All @@ -552,7 +584,7 @@ const WorkspaceSchema = LightWorkspaceSchema.extend({
ssoEnforced: z.boolean().optional(),
});

const UserProviderSchema = z.enum(["github", "google"]).nullable();
const UserProviderSchema = FlexibleEnumSchema(["github", "google"]).nullable();

const UserSchema = z.object({
sId: z.string(),
Expand Down Expand Up @@ -601,21 +633,21 @@ const WebsearchActionTypeSchema = BaseActionSchema.extend({
type: z.literal("websearch_action"),
});

const GlobalAgentStatusSchema = z.enum([
const GlobalAgentStatusSchema = FlexibleEnumSchema([
"active",
"disabled_by_admin",
"disabled_missing_datasource",
"disabled_free_workspace",
]);

const AgentStatusSchema = z.enum(["active", "archived", "draft"]);
const AgentStatusSchema = FlexibleEnumSchema(["active", "archived", "draft"]);

const AgentConfigurationStatusSchema = z.union([
AgentStatusSchema,
GlobalAgentStatusSchema,
]);

const AgentConfigurationScopeSchema = z.enum([
const AgentConfigurationScopeSchema = FlexibleEnumSchema([
"global",
"workspace",
"published",
Expand All @@ -631,28 +663,7 @@ const AgentRecentAuthorsSchema = z.array(z.string()).readonly();

const AgentModelConfigurationSchema = z.object({
providerId: ModelProviderIdSchema,
modelId: z.enum([
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o-2024-08-06",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",
"o1-mini",
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-instant-1.2",
"mistral-large-latest",
"mistral-medium",
"mistral-small-latest",
"codestral-latest",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
]),
modelId: ModelLLMIdSchema,
temperature: z.number(),
});

Expand Down Expand Up @@ -692,7 +703,7 @@ const ContentFragmentSchema = z.object({
fileId: z.string().nullable(),
created: z.number(),
type: z.literal("content_fragment"),
visibility: z.enum(["visible", "deleted"]),
visibility: VisibilitySchema,
version: z.number(),
sourceUrl: z.string().nullable(),
textUrl: z.string(),
Expand All @@ -708,8 +719,6 @@ const AgentMentionSchema = z.object({

const MentionTypeSchema = AgentMentionSchema;

const MessageVisibilitySchema = z.enum(["visible", "deleted"]);

const UserMessageContextSchema = z.object({
username: z.string(),
timezone: Timezone,
Expand All @@ -724,7 +733,7 @@ const UserMessageSchema = z.object({
created: z.number(),
type: z.literal("user_message"),
sId: z.string(),
visibility: MessageVisibilitySchema,
visibility: VisibilitySchema,
version: z.number(),
user: UserSchema.nullable(),
mentions: z.array(MentionTypeSchema),
Expand All @@ -742,7 +751,7 @@ const AgentActionTypeSchema = z.union([
BrowseActionTypeSchema,
]);

const AgentMessageStatusSchema = z.enum([
const AgentMessageStatusSchema = FlexibleEnumSchema([
"created",
"succeeded",
"failed",
Expand All @@ -755,7 +764,7 @@ const AgentMessageTypeSchema = z.object({
created: z.number(),
type: z.literal("agent_message"),
sId: z.string(),
visibility: MessageVisibilitySchema,
visibility: VisibilitySchema,
version: z.number(),
parentMessageId: z.string().nullable(),
configuration: LightAgentConfigurationSchema,
Expand All @@ -778,7 +787,7 @@ const AgentMessageTypeSchema = z.object({
});
export type AgentMessageType = z.infer<typeof AgentMessageTypeSchema>;

const ConversationVisibilitySchema = z.enum([
const ConversationVisibilitySchema = FlexibleEnumSchema([
"unlisted",
"workspace",
"deleted",
Expand Down Expand Up @@ -1016,7 +1025,7 @@ export const CoreAPIErrorSchema = z.object({
export const CoreAPITokenTypeSchema = z.tuple([z.number(), z.string()]);
export type CoreAPITokenType = z.infer<typeof CoreAPITokenTypeSchema>;

const APIErrorTypeSchema = z.enum([
const APIErrorTypeSchema = FlexibleEnumSchema([
"action_api_error",
"action_failed",
"action_unknown_error",
Expand Down Expand Up @@ -1253,7 +1262,12 @@ export type DustAPICredentials = {
userEmail?: string;
};

const SpaceKindSchema = z.enum(["regular", "global", "system", "public"]);
const SpaceKindSchema = FlexibleEnumSchema([
"regular",
"global",
"system",
"public",
]);

const SpaceTypeSchema = z.object({
name: z.string(),
Expand Down Expand Up @@ -1797,7 +1811,7 @@ const usageTables = [
"all",
] as const;

const SupportedUsageTablesSchema = z.enum(usageTables);
const SupportedUsageTablesSchema = FlexibleEnumSchema(usageTables);

export type UsageTableType = z.infer<typeof SupportedUsageTablesSchema>;

Expand Down

0 comments on commit 1329c3e

Please sign in to comment.