Skip to content

Commit

Permalink
fix tool nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard committed Jul 2, 2024
1 parent da5b3b1 commit 8ecd415
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 109 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Nothing unreleased!

## [1.1.306] - 2024-07-02
## [1.1.306] - 2024-07-03

### Added

Expand All @@ -19,6 +19,8 @@ Nothing unreleased!
### Fixed

- Message are now collapsible if too long
- Only first level tool calls are displayed
- OAuth redirection when mounting Chainlit on a FastAPI app should now work
- The Langchain callback handler should better capture chain runs
- The Llama Index callback handler should now work with other decorators

Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def get_user_facing_url(url: URL):
if config_url.path.endswith("/"):
config_url = config_url.replace(path=config_url.path[:-1])

return config_url.__str__() + url.path
return config_url.__str__()


@router.get("/auth/config")
Expand Down
4 changes: 2 additions & 2 deletions cypress/e2e/step/spec.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ describeSyncAsync('Step', () => {

cy.get('#tool-call-tool1').should('exist').click();

cy.get('#tool-call-tool2').should('exist').click();
cy.get('#tool-call-tool2').should('not.exist');

cy.get('#tool-call-tool3').should('exist').click();
cy.get('#tool-call-tool3').should('not.exist');

cy.get('.step').should('have.length', 2);
});
Expand Down
7 changes: 0 additions & 7 deletions cypress/e2e/streaming/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,3 @@ async def main():
await cl.sleep(0.2)

await step.send()

step = cl.Step(type="tool", name="tool2")
for seq in sequence_list:
await step.stream_token(token=seq, is_sequence=True)
await cl.sleep(0.2)

await step.send()
2 changes: 0 additions & 2 deletions cypress/e2e/streaming/spec.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ describe('Streaming', () => {

toolStream('tool1');

toolStream('tool2');

cy.get('.step').should('have.length', 3);
});
});
10 changes: 2 additions & 8 deletions frontend/src/components/molecules/messages/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { MessageContext } from 'contexts/MessageContext';
import { memo, useContext } from 'react';

import Box from '@mui/material/Box';
import Skeleton from '@mui/material/Skeleton';
import Stack from '@mui/material/Stack';

import { AskUploadButton } from './components/AskUploadButton';
Expand Down Expand Up @@ -40,7 +39,6 @@ const Message = memo(
onError
} = useContext(MessageContext);
const layoutMaxWidth = useLayoutMaxWidth();

const isAsk = message.waitForAnswer;
const isUserMessage = message.type === 'user_message';

Expand Down Expand Up @@ -94,14 +92,10 @@ const Message = memo(
direction="row"
gap="1rem"
alignItems="center"
my={0.5}
my={2}
width="100%"
>
<Skeleton
variant="circular"
width="1.6rem"
height="1.6rem"
/>
<MessageAvatar />
<BlinkingCursor />
</Stack>
)}
Expand Down
10 changes: 3 additions & 7 deletions frontend/src/components/molecules/messages/Messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ const Messages = memo(
({ messages, elements, actions, indent, isRunning }: Props) => {
const messageContext = useContext(MessageContext);

const isRoot = indent === 0;

const filtered = messages.filter((m, i) => {
const content = m.output;
const hasContent = !!content;
Expand All @@ -38,7 +36,7 @@ const Messages = memo(
(!hasContent && messageRunning)
);
});

console.log(messages);
return (
<>
{filtered.map((m, i) => {
Expand All @@ -51,11 +49,9 @@ const Messages = memo(
const showAvatar = typeIsDifferent || authorIsDifferent;

const isLast = filtered.length - 1 === i;
let messageRunning =
const messageRunning =
isRunning === undefined ? messageContext.loading : isRunning;
if (isRoot) {
messageRunning = messageRunning && isLast;
}

return (
<Message
message={m}
Expand Down
23 changes: 11 additions & 12 deletions frontend/src/components/molecules/messages/ToolCalls.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ interface Props {
isRunning?: boolean;
}

function groupToolSteps(step: IStep) {
function groupToolSteps(step: IStep): IStep[][] {
const groupedSteps: IStep[][] = [];

let currentGroup: IStep[] = [];

function traverseAndGroup(currentStep: IStep) {
function processStep(currentStep: IStep) {
if (currentStep.type === 'tool') {
if (
currentGroup.length === 0 ||
Expand All @@ -26,19 +25,19 @@ function groupToolSteps(step: IStep) {
currentGroup.push(currentStep);
} else {
groupedSteps.push(currentGroup);

currentGroup = [currentStep];
}
}

if (currentStep.steps) {
for (const childStep of currentStep.steps) {
traverseAndGroup(childStep);
} else if (currentStep.steps) {
// If we haven't found any tools yet, recurse into the steps
if (groupedSteps.length === 0 && currentGroup.length === 0) {
for (const childStep of currentStep.steps) {
processStep(childStep);
}
}
}
}

traverseAndGroup(step);
processStep(step);

// Push the last group if it exists
if (currentGroup.length > 0) {
Expand All @@ -59,10 +58,10 @@ export default function ToolCalls({ message, elements, isRunning }: Props) {

return (
<Stack width="100%" direction="column" gap={1}>
{toolCalls.map((toolCall, index) => (
{toolCalls.map((toolCalls, index) => (
<ToolCall
key={index}
steps={toolCall}
steps={toolCalls}
elements={elements}
isRunning={isRunning}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
} from '@chainlit/react-client';

interface Props {
author: string;
author?: string;
hide?: boolean;
}

Expand Down

This file was deleted.

67 changes: 30 additions & 37 deletions libs/react-client/src/utils/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ const isLastMessage = (messages: IStep[], index: number) => {
return true;
};

const addToolMessage = (messages: IStep[], toolMessage: IStep): IStep[] => {
const parentMessage = toolMessage.parentId
? findMessageById(messages, toolMessage.parentId)
: undefined;
if (parentMessage && parentMessage.type !== 'user_message') {
return addMessageToParent(messages, parentMessage.id, toolMessage);
}

return [
...messages,
{
...toolMessage,
name: '',
input: '',
output: '',
id: 'wrap_' + toolMessage.id,
type: 'assistant_message',
steps: [toolMessage]
}
];
};

// Nested messages utils

const addMessage = (messages: IStep[], message: IStep): IStep[] => {
Expand All @@ -41,43 +63,10 @@ const addMessage = (messages: IStep[], message: IStep): IStep[] => {
return messages;
}

const parentMessage = !isRoot
? findMessageById(messages, message.parentId!)
: undefined;

const shouldWrap =
(isRoot || parentMessage?.type !== 'assistant_message') &&
message.type === 'tool';

if (hasMessageById(messages, message.id)) {
return updateMessageById(messages, message.id, message);
} else if (shouldWrap) {
const lastMessage =
messages.length > 0 ? messages[messages.length - 1] : undefined;
const collapseTool =
lastMessage?.type === 'assistant_message' &&
lastMessage?.id.startsWith('wrap_');
if (lastMessage && collapseTool) {
return [
...messages.slice(0, messages.length - 1),
{
...lastMessage,
steps: [...(lastMessage.steps || []), message]
}
];
}
return [
...messages,
{
...message,
name: '',
input: '',
output: '',
id: 'wrap_' + message.id,
type: 'assistant_message',
steps: [message]
}
];
} else if (message.type === 'tool') {
return addToolMessage(messages, message);
} else if (!isMessageType && 'parentId' in message && message.parentId) {
return addMessageToParent(messages, message.parentId, message);
} else if ('indent' in message && message.indent && message.indent > 0) {
Expand Down Expand Up @@ -150,7 +139,11 @@ const findMessageById = (
for (const message of messages) {
if (isEqual(message.id, messageId)) {
return message;
} else if (message.steps && message.steps.length > 0) {
} else if (
message.steps &&
message.type !== 'user_message' &&
message.steps.length > 0
) {
const foundMessage = findMessageById(message.steps, messageId);
if (foundMessage) {
return foundMessage;
Expand Down Expand Up @@ -191,7 +184,7 @@ const deleteMessageById = (messages: IStep[], messageId: string) => {
for (let index = 0; index < nextMessages.length; index++) {
const msg = nextMessages[index];

if (msg.id === messageId || msg.id === 'wrap_' + messageId) {
if (msg.id === messageId) {
nextMessages = [
...nextMessages.slice(0, index),
...nextMessages.slice(index + 1)
Expand Down

0 comments on commit 8ecd415

Please sign in to comment.