Skip to content

Commit

Permalink
feat: Implement batch processing for Ollama embeddings
Browse files Browse the repository at this point in the history
This change reworks the Ollama embeddings node to handle multiple documents in one
execution. Previously, the node embedded only the first input document. The updated
implementation iterates over every input document, fetches its embedding from the
Ollama API, and collects the results into a single output array.

The key changes are:

- Create a single array to store the embeddings for all documents
- Iterate over the input documents and fetch the embeddings for each one
- Populate the embeddings array with the results from the Ollama API
- Return the embeddings array as the output of the node

This change ensures the node embeds every input document rather than only the
first one, and returns one embedding vector per document as a single output
array.
  • Loading branch information
cddigi authored and gitbutler-client committed Apr 21, 2024
1 parent bf94543 commit fda9a6d
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions src/nodes/OllamaEmbeddingsNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,56 +137,63 @@ export const ollamaEmbed = (rivet: typeof Rivet) => {
"string[]",
);

const prompt = docs[0];
const embeddings: number[][] = new Array(docs.length)
.fill(0)
.map(() => new Array(512).fill(0.0));

let apiResponse: Response;
for (let i = 0; docs.length - 1; i++) {
const prompt = docs[i];
let apiResponse: Response;

type RequestBodyType = {
model: string;
prompt: string;
};
type RequestBodyType = {
model: string;
prompt: string;
};

const requestBody: RequestBodyType = {
model,
prompt,
};
const requestBody: RequestBodyType = {
model,
prompt,
};

try {
apiResponse = await fetch(`${host}/api/embeddings`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
});
} catch (err) {
throw new Error(`Error from Ollama: ${rivet.getError(err).message}`);
}

if (!apiResponse.ok) {
try {
const error = await apiResponse.json();
throw new Error(`Error from Ollama: ${error.message}`);
apiResponse = await fetch(`${host}/api/embeddings`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
});
} catch (err) {
throw new Error(`Error from Ollama: ${apiResponse.statusText}`);
throw new Error(`Error from Ollama: ${rivet.getError(err).message}`);
}

if (!apiResponse.ok) {
try {
const error = await apiResponse.json();
throw new Error(`Error from Ollama: ${error.message}`);
} catch (err) {
throw new Error(`Error from Ollama: ${apiResponse.statusText}`);
}
}
}

const reader = apiResponse.body?.getReader();
const reader = apiResponse.body?.getReader();

if (!reader) {
throw new Error("No response body!");
}
if (!reader) {
throw new Error("No response body!");
}

let finalResponse: OllamaEmmbeddingsResponse | undefined;
let finalResponse: OllamaEmmbeddingsResponse | undefined;

if (!finalResponse) {
throw new Error("No final response from Ollama!");
if (!finalResponse) {
throw new Error("No final response from Ollama!");
}

embeddings[i] = finalResponse.embedding;
}

outputs["embeddings" as PortId] = {
type: "vector[]",
value: docsToEmbeddings(docs).embedding,
value: embeddings,
};

return outputs;
Expand All @@ -195,10 +202,3 @@ export const ollamaEmbed = (rivet: typeof Rivet) => {

return rivet.pluginNodeDefinition(impl, "Ollama Embeddings");
};

function docsToEmbeddings(document: string): OllamaEmmbeddingsResponse[] {
const result: OllamaEmmbeddingsResponse = {
embedding: [1, 2, 3],
};
return [result, result];
}

0 comments on commit fda9a6d

Please sign in to comment.