Commit
Add OpenAI as an additional AI provider
lightningRalf committed May 17, 2024
1 parent 33d4879 commit 5a9bb84
Showing 5 changed files with 165 additions and 11 deletions.
3 changes: 2 additions & 1 deletion package.json
@@ -10,7 +10,8 @@
"@inquirer/prompts": "^4.3.0",
"cross-spawn": "^7.0.3",
"typescript": "^5.4.3",
"youtube-transcript": "^1.1.0"
"youtube-transcript": "^1.1.0",
"openai": "^4.47.1"
},
"devDependencies": {
"@biomejs/biome": "latest",
117 changes: 110 additions & 7 deletions src/ai.ts
@@ -1,12 +1,15 @@
import fs from "node:fs";
import path from "node:path";
import Anthropic from "@anthropic-ai/sdk";
import OpenAI from "openai";
import { MessageParam } from "@anthropic-ai/sdk/resources";
import { countTokens } from "@anthropic-ai/tokenizer";
import {
MESSAGES_FOLDER,
NUMBER_OF_CHARACTERS_TO_FLUSH_TO_FILE,
lumentisFolderPath,
ModelProvider
} from "./constants";
import { config } from "./config";
import {
getOutlineInferenceMessages,
@@ -198,23 +201,123 @@ export async function runClaudeInference(
}
}

/**
* Runs an OpenAI inference with the given messages and parameters.
* @param messages The messages to send to the OpenAI API.
* @param model The OpenAI model to use for the inference.
* @param maxOutputTokens The maximum number of output tokens to generate.
* @returns The generated response from the OpenAI API.
* @throws An error if the OpenAI API call fails.
*/
export async function runOpenAIInference(
messages: MessageParam[],
model: string,
maxOutputTokens: number
) {
try {
// The v4 SDK (matching the "openai": "^4.47.1" dependency above) takes the
// key and base URL directly in the client constructor
const openai = new OpenAI({
apiKey: config.openai.apiKey,
baseURL: config.openai.apiEndpoint
});

// Make the API call to OpenAI; Anthropic's MessageParam content can be a
// content-block array, so plain string content is assumed here
const response = await openai.chat.completions.create({
model: model,
messages: messages.map((message) => ({ role: message.role, content: message.content as string })),
max_tokens: maxOutputTokens
});

// Process the response
return {
success: true,
response: response.choices[0].message.content
};
} catch (error) {
console.error("Error during OpenAI inference:", error);
// Rethrow so callers can surface the failure
throw error;
}
}

/**
* Calculates the cost of running a model inference based on the provider and model.
* @param messages The messages to send to the model API.
* @param outputTokensExpected The expected number of output tokens.
* @param model The model to use for the inference.
* @param provider The provider of the model (claude or openai).
* @returns The calculated cost of the inference.
* @throws An error if an unsupported model provider is specified.
*/
export function getModelCosts(
messages: MessageParam[],
outputTokensExpected: number,
model: string,
provider: ModelProvider
) {
if (provider === "claude") {
return getClaudeCosts(messages, outputTokensExpected, model);
} else if (provider === "openai") {
return getOpenAICosts(messages, outputTokensExpected, model);
} else {
throw new Error(`Unsupported model provider: ${provider}`);
}
}

/**
* Calculates the cost of running a model inference based on the provider, model, and input prompt.
* @param inputPrompt The input prompt to send to the model API.
* @param outputTokensExpected The expected number of output tokens.
* @param model The model to use for the inference.
* @param provider The provider of the model (claude or openai).
* @returns The calculated cost of the inference.
* @throws An error if an unsupported model provider is specified.
*/
export function getModelCostsFromText(
inputPrompt: string,
outputTokensExpected: number,
model: string,
provider: ModelProvider
) {
if (provider === "claude") {
return getClaudeCostsFromText(inputPrompt, outputTokensExpected, model);
} else if (provider === "openai") {
return getOpenAICostsFromText(inputPrompt, outputTokensExpected, model);
} else {
throw new Error(`Unsupported model provider: ${provider}`);
}
}

/**
* Calculates the cost of running an OpenAI inference based on the messages and model.
* @param messages The messages to send to the OpenAI API.
* @param outputTokensExpected The expected number of output tokens.
* @param model The OpenAI model to use for the inference.
* @returns The calculated cost of the OpenAI inference.
*/
function getOpenAICosts(
messages: MessageParam[],
outputTokensExpected: number,
model: string
) {
const inputText: string = messages.map((m) => m.content).join("\n");
return getOpenAICostsFromText(inputText, outputTokensExpected, model);
}

/**
* Calculates the cost of running an OpenAI inference based on the input prompt and model.
* @param inputPrompt The input prompt to send to the OpenAI API.
* @param outputTokensExpected The expected number of output tokens.
* @param model The OpenAI model to use for the inference.
* @returns The calculated cost of the OpenAI inference.
*/
function getOpenAICostsFromText(
inputPrompt: string,
outputTokensExpected: number,
model: string
) {
// Note: approximates OpenAI token counts with Anthropic's tokenizer
const inputTokens = countTokens(inputPrompt);

return getOpenAICostsWithTokens(inputTokens, outputTokensExpected, model);
}

function getClaudeCostsWithTokens(
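
A minimal usage sketch for the new provider-aware helpers — the messages, model name, and token budget below are illustrative, not taken from this commit:

import { getModelCosts, runOpenAIInference } from "./ai";

const messages = [{ role: "user" as const, content: "Outline this transcript." }];

// Estimate spend before committing to the call
const estimate = getModelCosts(messages, 1024, "gpt-4o", "openai");
console.log("Estimated cost:", estimate);

const result = await runOpenAIInference(messages, "gpt-4o", 1024);
if (result.success) console.log(result.response);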
14 changes: 11 additions & 3 deletions src/app.ts
@@ -16,16 +16,19 @@ import {
import {
CLAUDE_PRIMARYSOURCE_BUDGET,
getClaudeCosts,
runClaudeInference,
runOpenAIInference
} from "./ai";
import {
CLAUDE_MODELS,
OPENAI_MODELS,
EDITORS,
LUMENTIS_FOLDER,
RUNNERS,
WRITING_STYLE_SIZE_LIMIT,
lumentisFolderPath,
wizardStatePath,
ModelProvider
} from "./constants";
import { generatePages, idempotentlySetupNextraDocs } from "./page-generator";
import {
@@ -91,7 +94,12 @@ async function runWizard() {
choices: [
...CLAUDE_MODELS.map((model) => ({
name: model.name,
value: { provider: "claude", model: model.model }, // value now carries the provider alongside the model
description: model.smarterDescription
})),
// Offer the OpenAI models alongside the Claude ones
...OPENAI_MODELS.map((model) => ({
name: model.name,
value: { provider: "openai", model: model.model },
description: model.smarterDescription
})),
new Separator()
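
Because the select value is now a { provider, model } object rather than a bare model string, the code that consumes the wizard's answer would destructure it before dispatching. A sketch — `smarterModel`, `messagesToSend`, and `maxOutputTokens` are hypothetical names, and the runner signatures are assumed to match runOpenAIInference above:

const { provider, model } = smarterModel as { provider: ModelProvider; model: string };

const result = provider === "openai"
  ? await runOpenAIInference(messagesToSend, model, maxOutputTokens)
  : await runClaudeInference(messagesToSend, model, maxOutputTokens);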
19 changes: 19 additions & 0 deletions src/config.ts
@@ -0,0 +1,19 @@
import dotenv from 'dotenv';

dotenv.config();

/**
* Configuration object holding OpenAI and Claude specific settings.
*/
export const config = {
openai: {
apiKey: process.env.OPENAI_API_KEY || '',
apiEndpoint: process.env.OPENAI_API_ENDPOINT || 'https://api.openai.com/v1', // the v4 SDK expects the /v1 path in its base URL
defaultModel: 'gpt-3.5-turbo',
},
claude: {
apiKey: process.env.ANTHROPIC_API_KEY || '',
apiEndpoint: process.env.ANTHROPIC_API_ENDPOINT || 'https://api.anthropic.com',
defaultModel: 'claude-3-haiku-20240307',
},
};
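
With this module in place, keys and endpoints come from the environment rather than being hard-coded. A usage sketch (the key values are placeholders):

// .env — read by dotenv.config() above:
//   OPENAI_API_KEY=sk-...
//   ANTHROPIC_API_KEY=sk-ant-...

import { config } from './config';

// Fall back to Claude when no OpenAI key is configured
const provider = config.openai.apiKey ? 'openai' : 'claude';
const model = config[provider].defaultModel;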
23 changes: 23 additions & 0 deletions src/constants.ts
@@ -32,6 +32,29 @@ export const CLAUDE_MODELS = [
}
] as const;

/**
* Supported OpenAI models for meta inference and page generation.
*/
export const OPENAI_MODELS = [
{
name: "GPT-4 omni",
model: "gpt-4o",
smarterDescription: "Most capable GPT-4o model for complex tasks.",
pageDescription: "Smartest - Ideal for nuanced writing and analysis"
},
{
name: "GPT-3.5 Turbo",
model: "gpt-3.5-turbo",
smarterDescription: "Versatile and capable GPT-3.5 model.",
pageDescription: "Balanced - Good for most use cases"
}
] as const;

/**
* Type representing the supported model providers.
*/
export type ModelProvider = "claude" | "openai";

export const EDITORS = [
{ name: "nano", command: "nano" },
{ name: "vim but know you can never leave", command: "vim" },
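
Since OPENAI_MODELS is declared as const, a union of its model names can be derived directly, and ModelProvider can drive a simple lookup. A sketch (the helper is illustrative, not part of this commit):

import { CLAUDE_MODELS, OPENAI_MODELS, ModelProvider } from "./constants";

// "gpt-4o" | "gpt-3.5-turbo", derived from the array above
type OpenAIModelName = (typeof OPENAI_MODELS)[number]["model"];

function modelsForProvider(provider: ModelProvider) {
  return provider === "openai" ? OPENAI_MODELS : CLAUDE_MODELS;
}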
