Skip to content

Commit

Permalink
test: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hoangvvo committed Nov 8, 2024
1 parent 64fd33a commit 2133b92
Show file tree
Hide file tree
Showing 10 changed files with 693 additions and 390 deletions.
48 changes: 36 additions & 12 deletions javascript/examples/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,39 @@ export const openaiAudioModel = new OpenAIModel(
},
},
);
export const anthropicModel = new AnthropicModel({
modelId: "claude-3-5-sonnet-20241022",
apiKey: process.env["ANTHROPIC_API_KEY"] as string,
});
export const googleModel = new GoogleModel({
modelId: "gemini-1.5-pro",
apiKey: process.env["GOOGLE_API_KEY"] as string,
});
export const cohereModel = new CohereModel({
modelId: "command-r-08-2024",
apiKey: process.env["CO_API_KEY"] as string,
});
export const anthropicModel = new AnthropicModel(
{
modelId: "claude-3-5-sonnet-20241022",
apiKey: process.env["ANTHROPIC_API_KEY"] as string,
},
{
pricing: {
inputCostPerTextToken: 3.0 / 1_000_000,
outputCostPerTextToken: 15.0 / 1_000_000,
},
},
);
export const googleModel = new GoogleModel(
{
modelId: "gemini-1.5-pro",
apiKey: process.env["GOOGLE_API_KEY"] as string,
},
{
pricing: {
inputCostPerTextToken: 1.25 / 1_000_000,
outputCostPerTextToken: 5.0 / 1_000_000,
},
},
);
export const cohereModel = new CohereModel(
{
modelId: "command-r-08-2024",
apiKey: process.env["CO_API_KEY"] as string,
},
{
pricing: {
inputCostPerTextToken: 0.16 / 1_000_000,
outputCostPerTextToken: 0.6 / 1_000_000,
},
},
);
2 changes: 1 addition & 1 deletion javascript/src/anthropic/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ export function convertToAnthropicMessages(
options: AnthropicModelOptions,
): Anthropic.Messages.MessageParam[] {
if (options.convertAudioPartsToTextParts) {
messages = messages.map((message) => convertAudioPartsToTextParts(message));
messages = messages.map(convertAudioPartsToTextParts);
}

return messages.map((message): Anthropic.Messages.MessageParam => {
Expand Down
9 changes: 7 additions & 2 deletions javascript/src/cohere/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
Tool,
ToolCallPart,
} from "../schemas/types.gen.js";
import { convertAudioPartsToTextParts } from "../utils/message.utils.js";
import { ContentDeltaAccumulator } from "../utils/stream.utils.js";
import { calculateCost } from "../utils/usage.utils.js";
import { CohereModelOptions } from "./types.js";
Expand Down Expand Up @@ -114,7 +115,7 @@ export function convertToCohereParams(

return {
model: options.modelId,
messages: convertToCohereMessages(input),
messages: convertToCohereMessages(input, options),
...(input.tools && { tools: input.tools.map(convertToCohereTool) }),
...samplingParams,
...(response_format && {
Expand All @@ -126,10 +127,14 @@ export function convertToCohereParams(

export function convertToCohereMessages(
input: Pick<CohereLanguageModelInput, "messages" | "systemPrompt">,
options: CohereModelOptions,
): Cohere.ChatMessageV2[] {
const cohereMessages: Cohere.ChatMessageV2[] = [];

const messages = input.messages;
let messages = input.messages;
if (options.convertAudioPartsToTextParts) {
messages = messages.map(convertAudioPartsToTextParts);
}

if (input.systemPrompt) {
cohereMessages.push({
Expand Down
4 changes: 4 additions & 0 deletions javascript/src/cohere/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export interface CohereModelOptions {
apiKey: string;
modelId: string;
/**
* If the AudioPart has a transcript, convert it to an TextPart
*/
convertAudioPartsToTextParts?: boolean;
}
4 changes: 2 additions & 2 deletions javascript/src/utils/stream.utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export class ContentDeltaAccumulator {
case "tool-call":
if (!delta.part.toolCallId || !delta.part.toolName) {
throw new Error(
`missing toolCallId or toolName at index ${String(delta.index)}. toolCallId: ${delta.part.toolCallId || "undefined"}, toolName: ${delta.part.toolName || "undefined"}`,
`missing toolCallId or toolName at index ${String(delta.index)}. toolCallId: ${String(delta.part.toolCallId)}, toolName: ${String(delta.part.toolName)}`,
);
}
return {
Expand All @@ -123,7 +123,7 @@ export class ContentDeltaAccumulator {
case "audio": {
if (delta.part.encoding !== "linear16") {
throw new Error(
`only linear16 encoding is supported for audio concatenation. encoding: ${delta.part.encoding || "undefined"}`,
`only linear16 encoding is supported for audio concatenation. encoding: ${String(delta.part.encoding)}`,
);
}
const concatenatedAudioData = mergeInt16Arrays(delta.part.audioData);
Expand Down
71 changes: 61 additions & 10 deletions javascript/test/anthropic.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,67 @@
/* eslint-disable @typescript-eslint/no-floating-promises */
import { describe, it } from "node:test";
import test, { suite } from "node:test";
import { AnthropicModel } from "../src/anthropic/anthropic.js";
import { getLanguageModelTests } from "./test-language-model.js";
import { log, testLanguageModel } from "./test-language-model.js";

const model = new AnthropicModel({
apiKey: process.env["ANTHROPIC_API_KEY"] as string,
modelId: "claude-3-5-sonnet-20241022",
});
const model = new AnthropicModel(
{
apiKey: process.env["ANTHROPIC_API_KEY"] as string,
modelId: "claude-3-5-sonnet-20241022",
},
{
pricing: {
inputCostPerTextToken: 3.0 / 1_000_000,
outputCostPerTextToken: 15.0 / 1_000_000,
},
},
);

suite("AnthropicModel", () => {
testLanguageModel(model);

test("convert audio part to text part if enabled", async () => {
const model = new AnthropicModel({
apiKey: process.env["ANTHROPIC_API_KEY"] as string,
// not an audio model
modelId: "claude-3-5-sonnet-20241022",
convertAudioPartsToTextParts: true,
});

const response = await model.generate({
messages: [
{
role: "user",
content: [
{
type: "text",
text: "Hello",
},
],
},
{
role: "assistant",
content: [
{
type: "audio",
audioData: "",
transcript: "Hi there, how can I help you?",
},
],
},
{
role: "user",
content: [
{
type: "text",
text: "Goodbye",
},
],
},
],
});

log(response);

describe("AnthropicModel", () => {
const tests = getLanguageModelTests(model);
tests.forEach(({ name, fn }) => {
it(name, fn);
// it should not throw a part unsupported error
});
});
71 changes: 61 additions & 10 deletions javascript/test/cohere.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,67 @@
/* eslint-disable @typescript-eslint/no-floating-promises */
import { describe, it } from "node:test";
import test, { suite } from "node:test";
import { CohereModel } from "../src/cohere/cohere.js";
import { getLanguageModelTests } from "./test-language-model.js";
import { log, testLanguageModel } from "./test-language-model.js";

const model = new CohereModel({
apiKey: process.env["CO_API_KEY"] as string,
modelId: "command-r-08-2024",
});
const model = new CohereModel(
{
apiKey: process.env["CO_API_KEY"] as string,
modelId: "command-r-08-2024",
},
{
pricing: {
inputCostPerTextToken: 0.16 / 1_000_000,
outputCostPerTextToken: 0.6 / 1_000_000,
},
},
);

suite("CohereModel", () => {
testLanguageModel(model);

test("convert audio part to text part if enabled", async () => {
const model = new CohereModel({
apiKey: process.env["CO_API_KEY"] as string,
// not an audio model
modelId: "command-r-08-2024",
convertAudioPartsToTextParts: true,
});

const response = await model.generate({
messages: [
{
role: "user",
content: [
{
type: "text",
text: "Hello",
},
],
},
{
role: "assistant",
content: [
{
type: "audio",
audioData: "",
transcript: "Hi there, how can I help you?",
},
],
},
{
role: "user",
content: [
{
type: "text",
text: "Goodbye",
},
],
},
],
});

log(response);

describe("CohereModel", () => {
const tests = getLanguageModelTests(model);
tests.forEach(({ name, fn }) => {
it(name, fn);
// it should not throw a part unsupported error
});
});
67 changes: 67 additions & 0 deletions javascript/test/google.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/* eslint-disable @typescript-eslint/no-floating-promises */
import test, { suite } from "node:test";
import { GoogleModel } from "../src/google/google.js";
import { log, testLanguageModel } from "./test-language-model.js";

const model = new GoogleModel(
{
apiKey: process.env["GOOGLE_API_KEY"] as string,
modelId: "gemini-1.5-pro",
},
{
pricing: {
inputCostPerTextToken: 1.25 / 1_000_000,
outputCostPerTextToken: 5.0 / 1_000_000,
},
},
);

suite("GoogleModel", () => {
testLanguageModel(model);

test("convert audio part to text part if enabled", async () => {
const model = new GoogleModel({
apiKey: process.env["GOOGLE_API_KEY"] as string,
// not an audio model
modelId: "gemini-1.5-pro",
convertAudioPartsToTextParts: true,
});

const response = await model.generate({
messages: [
{
role: "user",
content: [
{
type: "text",
text: "Hello",
},
],
},
{
role: "assistant",
content: [
{
type: "audio",
audioData: "",
transcript: "Hi there, how can I help you?",
},
],
},
{
role: "user",
content: [
{
type: "text",
text: "Goodbye",
},
],
},
],
});

log(response);

// it should not throw a part unsupported error
});
});
Loading

0 comments on commit 2133b92

Please sign in to comment.