
Commit

add note to context for local generation (#2604)
lalalune authored Jan 21, 2025
1 parent 4c8a60a commit cac3912
Showing 1 changed file with 47 additions and 41 deletions.
packages/plugin-node/src/services/llama.ts: 88 changes (47 additions & 41 deletions)

@@ -189,7 +189,7 @@ export class LlamaService extends Service {
        const modelName = "model.gguf";
        this.modelPath = path.join(
            process.env.LLAMALOCAL_PATH?.trim() ?? "./",
-            modelName
+            modelName,
        );
        this.ollamaModel = process.env.OLLAMA_MODEL;
    }
@@ -202,7 +202,7 @@ export class LlamaService extends Service {
    private async ensureInitialized() {
        if (!this.modelInitialized) {
            elizaLogger.info(
-                "Model not initialized, starting initialization..."
+                "Model not initialized, starting initialization...",
            );
            await this.initializeModel();
        } else {
@@ -217,16 +217,16 @@ export class LlamaService extends Service {

            const systemInfo = await si.graphics();
            const hasCUDA = systemInfo.controllers.some((controller) =>
-                controller.vendor.toLowerCase().includes("nvidia")
+                controller.vendor.toLowerCase().includes("nvidia"),
            );

            if (hasCUDA) {
                elizaLogger.info(
-                    "LlamaService: CUDA detected, using GPU acceleration"
+                    "LlamaService: CUDA detected, using GPU acceleration",
                );
            } else {
                elizaLogger.warn(
-                    "LlamaService: No CUDA detected - local response will be slow"
+                    "LlamaService: No CUDA detected - local response will be slow",
                );
            }

@@ -238,7 +238,7 @@ export class LlamaService extends Service {
            elizaLogger.info("Creating JSON schema grammar...");
            const grammar = new LlamaJsonSchemaGrammar(
                this.llama,
-                jsonSchemaGrammar as GbnfJsonSchema
+                jsonSchemaGrammar as GbnfJsonSchema,
            );
            this.grammar = grammar;

@@ -257,21 +257,21 @@ export class LlamaService extends Service {
        } catch (error) {
            elizaLogger.error(
                "Model initialization failed. Deleting model and retrying:",
-                error
+                error,
            );
            try {
                elizaLogger.info(
-                    "Attempting to delete and re-download model..."
+                    "Attempting to delete and re-download model...",
                );
                await this.deleteModel();
                await this.initializeModel();
            } catch (retryError) {
                elizaLogger.error(
                    "Model re-initialization failed:",
-                    retryError
+                    retryError,
                );
                throw new Error(
-                    `Model initialization failed after retry: ${retryError.message}`
+                    `Model initialization failed after retry: ${retryError.message}`,
                );
            }
        }
@@ -294,7 +294,7 @@ export class LlamaService extends Service {
                        response.headers.location
                    ) {
                        elizaLogger.info(
-                            `Following redirect to: ${response.headers.location}`
+                            `Following redirect to: ${response.headers.location}`,
                        );
                        downloadModel(response.headers.location);
                        return;
@@ -303,24 +303,24 @@ export class LlamaService extends Service {
                    if (response.statusCode !== 200) {
                        reject(
                            new Error(
-                                `Failed to download model: HTTP ${response.statusCode}`
-                            )
+                                `Failed to download model: HTTP ${response.statusCode}`,
+                            ),
                        );
                        return;
                    }

                    totalSize = Number.parseInt(
                        response.headers["content-length"] || "0",
-                        10
+                        10,
                    );
                    elizaLogger.info(
-                        `Downloading model: Hermes-3-Llama-3.1-8B.Q8_0.gguf`
+                        `Downloading model: Hermes-3-Llama-3.1-8B.Q8_0.gguf`,
                    );
                    elizaLogger.info(
-                        `Download location: ${this.modelPath}`
+                        `Download location: ${this.modelPath}`,
                    );
                    elizaLogger.info(
-                        `Total size: ${(totalSize / 1024 / 1024).toFixed(2)} MB`
+                        `Total size: ${(totalSize / 1024 / 1024).toFixed(2)} MB`,
                    );

                    response.pipe(file);
@@ -336,7 +336,7 @@ export class LlamaService extends Service {
                                  ).toFixed(1)
                                : "0.0";
                        const dots = ".".repeat(
-                            Math.floor(Number(progress) / 5)
+                            Math.floor(Number(progress) / 5),
                        );
                        progressString = `Downloading model: [${dots.padEnd(20, " ")}] ${progress}%`;
                        elizaLogger.progress(progressString);
@@ -353,17 +353,17 @@ export class LlamaService extends Service {
                        fs.unlink(this.modelPath, () => {});
                        reject(
                            new Error(
-                                `Model download failed: ${error.message}`
-                            )
+                                `Model download failed: ${error.message}`,
+                            ),
                        );
                    });
                })
                .on("error", (error) => {
                    fs.unlink(this.modelPath, () => {});
                    reject(
                        new Error(
-                            `Model download request failed: ${error.message}`
-                        )
+                            `Model download request failed: ${error.message}`,
+                        ),
                    );
                });
        };
@@ -393,7 +393,7 @@ export class LlamaService extends Service {
        stop: string[],
        frequency_penalty: number,
        presence_penalty: number,
-        max_tokens: number
+        max_tokens: number,
    ): Promise<any> {
        await this.ensureInitialized();
        return new Promise((resolve, reject) => {

[CodeFactor check warning on packages/plugin-node/src/services/llama.ts line 397: Unexpected any. Specify a different type. (@typescript-eslint/no-explicit-any)]

@@ -418,7 +418,7 @@ export class LlamaService extends Service {
        stop: string[],
        frequency_penalty: number,
        presence_penalty: number,
-        max_tokens: number
+        max_tokens: number,
    ): Promise<string> {
        await this.ensureInitialized();

@@ -460,7 +460,7 @@ export class LlamaService extends Service {
                    message.frequency_penalty,
                    message.presence_penalty,
                    message.max_tokens,
-                    message.useGrammar
+                    message.useGrammar,
                );
                message.resolve(response);
            } catch (error) {
@@ -509,14 +509,17 @@ export class LlamaService extends Service {
        frequency_penalty: number,
        presence_penalty: number,
        max_tokens: number,
-        useGrammar: boolean
+        useGrammar: boolean,
    ): Promise<any | string> {
+        context = context +=
+            "\nIMPORTANT: Escape any quotes in any string fields with a backslash so the JSON is valid.";
+
        const ollamaModel = process.env.OLLAMA_MODEL;
        if (ollamaModel) {
            const ollamaUrl =
                process.env.OLLAMA_SERVER_URL || "http://localhost:11434";
            elizaLogger.info(
-                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`
+                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`,
            );

            const response = await fetch(`${ollamaUrl}/api/generate`, {
@@ -538,7 +541,7 @@ export class LlamaService extends Service {

            if (!response.ok) {
                throw new Error(
-                    `Ollama request failed: ${response.statusText}`
+                    `Ollama request failed: ${response.statusText}`,
                );
            }

@@ -552,11 +555,12 @@ export class LlamaService extends Service {
        }

        const session = new LlamaChatSession({
-            contextSequence: this.sequence
+            contextSequence: this.sequence,
        });

-        const wordsToPunishTokens = wordsToPunish
-            .flatMap((word) => this.model!.tokenize(word));
+        const wordsToPunishTokens = wordsToPunish.flatMap((word) =>
+            this.model!.tokenize(word),
+        );

        const repeatPenalty: LlamaChatSessionRepeatPenalty = {
            punishTokensFilter: () => wordsToPunishTokens,
@@ -566,11 +570,12 @@ export class LlamaService extends Service {
        };

        const response = await session.prompt(context, {
-            onTextChunk(chunk) { // stream the response to the console as it's being generated
+            onTextChunk(chunk) {
+                // stream the response to the console as it's being generated
                process.stdout.write(chunk);
            },
            temperature: Number(temperature),
-            repeatPenalty: repeatPenalty
+            repeatPenalty: repeatPenalty,
        });

        if (!response) {
@@ -612,7 +617,7 @@ export class LlamaService extends Service {
        const embeddingModel =
            process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
        elizaLogger.info(
-            `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`
+            `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`,
        );

        const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -626,7 +631,7 @@ export class LlamaService extends Service {

            if (!response.ok) {
                throw new Error(
-                    `Ollama embeddings request failed: ${response.statusText}`
+                    `Ollama embeddings request failed: ${response.statusText}`,
                );
            }

@@ -644,7 +649,7 @@ export class LlamaService extends Service {
            const embeddingModel =
                process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
            elizaLogger.info(
-                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${this.ollamaModel})`
+                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${this.ollamaModel})`,
            );

            const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -671,7 +676,7 @@ export class LlamaService extends Service {
            const ollamaUrl =
                process.env.OLLAMA_SERVER_URL || "http://localhost:11434";
            elizaLogger.info(
-                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`
+                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`,
            );

            const response = await fetch(`${ollamaUrl}/api/generate`, {
@@ -706,7 +711,7 @@ export class LlamaService extends Service {
            const embeddingModel =
                process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
            elizaLogger.info(
-                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`
+                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`,
            );

            const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -720,7 +725,7 @@ export class LlamaService extends Service {

            if (!response.ok) {
                throw new Error(
-                    `Ollama embeddings request failed: ${response.statusText}`
+                    `Ollama embeddings request failed: ${response.statusText}`,
                );
            }

@@ -736,8 +741,9 @@ export class LlamaService extends Service {
        const tokens = this.model!.tokenize(prompt);

        // tokenize the words to punish
-        const wordsToPunishTokens = wordsToPunish
-            .flatMap((word) => this.model!.tokenize(word));
+        const wordsToPunishTokens = wordsToPunish.flatMap((word) =>
+            this.model!.tokenize(word),
+        );

        const repeatPenalty: LlamaContextSequenceRepeatPenalty = {
            punishTokens: () => wordsToPunishTokens,

