From 49009ad70a46aa47148dae83aba33fdc73bdaca7 Mon Sep 17 00:00:00 2001 From: BalanaguYashwanth Date: Tue, 31 Dec 2024 19:22:03 +0530 Subject: [PATCH] Add grok model --- scripts/tweets2character.js | 41 +++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/scripts/tweets2character.js b/scripts/tweets2character.js index 9fa87ca..d5bb0d9 100755 --- a/scripts/tweets2character.js +++ b/scripts/tweets2character.js @@ -124,6 +124,40 @@ const runChatCompletion = async (messages, useGrammar = false, model) => { const content = data.content[0].text; const parsed = parseJsonFromMarkdown(content) || JSON.parse(content); return parsed; + } else if(model === 'grok'){ + const modelName = 'grok-beta'; + const response = await fetch('https://api.x.ai/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${process.env.GROK_API_KEY}`, + }, + body: JSON.stringify({ + model: modelName, + temperature: 0, + stream: false, + messages: [ + { + role: "user", + content: messages[0].content + } + ], + }), + }); + + if (response.status === 429) { + await new Promise(resolve => setTimeout(resolve, 30000)); + return runChatCompletion(messages, useGrammar, model); + } + + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + const data = await response.json(); + const content = data.choices[0].message.content.trim(); + const parsed = parseJsonFromMarkdown(content) || JSON.parse(content); + return parsed; } }; @@ -407,6 +441,7 @@ process.on('unhandledRejection', (reason, promise) => { program .option('--openai ', 'OpenAI API key') .option('--claude ', 'Claude API key') + .option('--grok ', 'Grok API key') .parse(process.argv); const limitConcurrency = async (tasks, concurrencyLimit) => { @@ -472,6 +507,8 @@ const validateApiKey = (apiKey, model) => { return apiKey.trim().startsWith('sk-'); } else if (model === 'claude') { return apiKey.trim().length > 0; + } else if (model === 'grok') { + return apiKey.trim().startsWith('xai-'); } return false; }; @@ -494,7 +531,7 @@ const resumeOrStartNewSession = async (projectCache, archivePath) => { } if (!projectCache.unfinishedSession) { - projectCache.model = await promptUser('Select model (openai/claude): '); + projectCache.model = await promptUser('Select model (openai/claude/grok): '); projectCache.basicUserInfo = await promptUser('Enter additional user info that might help the summarizer (real name, nicknames and handles, age, past employment vs current, etc): '); projectCache.unfinishedSession = { currentChunk: 0, @@ -531,7 +568,7 @@ const main = async () => { let projectCache = loadProjectCache(archivePath) || {}; projectCache = await resumeOrStartNewSession(projectCache, archivePath); - + const apiKey = await getApiKey(projectCache.model); if (!apiKey) { throw new Error(`Failed to get a valid API key for ${projectCache.model}`);