Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/chainlit/data/chainlit_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ async def create_step(self, step_dict: StepDict):
query = """
INSERT INTO "Step" (
id, "threadId", "parentId", input, metadata, name, output,
type, "startTime", "endTime", "showInput", "isError"
type, "startTime", "endTime", "showInput", "isError", icon
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
)
ON CONFLICT (id) DO UPDATE SET
"parentId" = COALESCE(EXCLUDED."parentId", "Step"."parentId"),
Expand All @@ -359,7 +359,8 @@ async def create_step(self, step_dict: StepDict):
"endTime" = COALESCE(EXCLUDED."endTime", "Step"."endTime"),
"startTime" = LEAST(EXCLUDED."startTime", "Step"."startTime"),
"showInput" = COALESCE(EXCLUDED."showInput", "Step"."showInput"),
"isError" = COALESCE(EXCLUDED."isError", "Step"."isError")
"isError" = COALESCE(EXCLUDED."isError", "Step"."isError"),
icon = COALESCE(EXCLUDED.icon, "Step".icon)
"""

timestamp = await self.get_current_timestamp()
Expand All @@ -380,6 +381,7 @@ async def create_step(self, step_dict: StepDict):
"end_time": timestamp,
"show_input": str(step_dict.get("showInput", "json")),
"is_error": step_dict.get("isError", False),
"icon": step_dict.get("icon"),
}
await self.execute_query(query, params)

Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/data/literalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ async def create_step(self, step_dict: "StepDict"):
waitForAnswer=step_dict.get("waitForAnswer"),
language=step_dict.get("language"),
showInput=step_dict.get("showInput"),
icon=step_dict.get("icon"),
)

step: LiteralStepDict = {
Expand Down
8 changes: 8 additions & 0 deletions backend/chainlit/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class StepDict(TypedDict, total=False):
showInput: Optional[Union[bool, str]]
defaultOpen: Optional[bool]
language: Optional[str]
icon: Optional[str]
feedback: Optional[FeedbackDict]


Expand All @@ -83,6 +84,7 @@ def step(
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
language: Optional[str] = None,
icon: Optional[str] = None,
show_input: Union[bool, str] = "json",
default_open: bool = False,
):
Expand All @@ -106,6 +108,7 @@ async def async_wrapper(*args, **kwargs):
parent_id=parent_id,
tags=tags,
language=language,
icon=icon,
show_input=show_input,
default_open=default_open,
metadata=metadata,
Expand Down Expand Up @@ -135,6 +138,7 @@ def sync_wrapper(*args, **kwargs):
parent_id=parent_id,
tags=tags,
language=language,
icon=icon,
show_input=show_input,
default_open=default_open,
metadata=metadata,
Expand Down Expand Up @@ -182,6 +186,7 @@ class Step:
end: Union[str, None]
generation: Optional[BaseGeneration]
language: Optional[str]
icon: Optional[str]
default_open: Optional[bool]
elements: Optional[List[Element]]
fail_on_persist_error: bool
Expand All @@ -196,6 +201,7 @@ def __init__(
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
language: Optional[str] = None,
icon: Optional[str] = None,
default_open: Optional[bool] = False,
show_input: Union[bool, str] = "json",
thread_id: Optional[str] = None,
Expand All @@ -214,6 +220,7 @@ def __init__(
self.parent_id = parent_id

self.language = language
self.icon = icon
self.default_open = default_open
self.generation = None
self.elements = elements or []
Expand Down Expand Up @@ -303,6 +310,7 @@ def to_dict(self) -> StepDict:
"start": self.start,
"end": self.end,
"language": self.language,
"icon": self.icon,
"defaultOpen": self.default_open,
"showInput": self.show_input,
"generation": self.generation.to_dict() if self.generation else None,
Expand Down
3 changes: 3 additions & 0 deletions backend/tests/data/test_literalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_step_dict(test_thread) -> StepDict:
"waitForAnswer": True,
"showInput": True,
"language": "en",
"icon": "search",
}


Expand Down Expand Up @@ -179,6 +180,7 @@ async def test_create_step(
"waitForAnswer": True,
"language": "en",
"showInput": True,
"icon": "search",
},
"input": {"content": "test input"},
"output": {"content": "test output"},
Expand Down Expand Up @@ -768,6 +770,7 @@ async def test_update_step(
"waitForAnswer": True,
"language": "en",
"showInput": True,
"icon": "search",
},
"input": {"content": "test input"},
"output": {"content": "test output"},
Expand Down
1 change: 1 addition & 0 deletions backend/tests/data/test_sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"generation" JSONB,
"showInput" TEXT,
"language" TEXT,
"icon" TEXT,
"indent" INT
);
"""
Expand Down
46 changes: 46 additions & 0 deletions backend/tests/test_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ async def test_send_step(
mock_websocket_session.emit.assert_called_once_with("new_message", step_dict)


async def test_send_step_with_icon(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
step_dict: StepDict = {
"id": "test_step_with_icon",
"type": "tool",
"name": "Test Step with Icon",
"output": "This is a test step with an icon",
"icon": "search",
}

await emitter.send_step(step_dict)

mock_websocket_session.emit.assert_called_once_with("new_message", step_dict)


async def test_update_step(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
Expand All @@ -69,6 +85,22 @@ async def test_update_step(
mock_websocket_session.emit.assert_called_once_with("update_message", step_dict)


async def test_update_step_with_icon(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
step_dict: StepDict = {
"id": "test_step_with_icon",
"type": "tool",
"name": "Updated Test Step with Icon",
"output": "This is an updated test step with an icon",
"icon": "database",
}

await emitter.update_step(step_dict)

mock_websocket_session.emit.assert_called_once_with("update_message", step_dict)


async def test_delete_step(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
Expand Down Expand Up @@ -139,6 +171,20 @@ async def test_stream_start(
mock_websocket_session.emit.assert_called_once_with("stream_start", step_dict)


async def test_stream_start_with_icon(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
step_dict: StepDict = {
"id": "test_stream_with_icon",
"type": "tool",
"name": "Test Stream with Icon",
"output": "This is a test stream with an icon",
"icon": "cpu",
}
await emitter.stream_start(step_dict)
mock_websocket_session.emit.assert_called_once_with("stream_start", step_dict)


async def test_send_toast(
emitter: ChainlitEmitter, mock_websocket_session: MagicMock
) -> None:
Expand Down
33 changes: 33 additions & 0 deletions cypress/e2e/step_icon/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import chainlit as cl


@cl.step(name="search", type="tool", icon="search")
async def search():
await cl.sleep(1)
return "Response from search"


@cl.step(name="database", type="tool", icon="database")
async def database():
await cl.sleep(1)
return "Response from database"


@cl.step(name="regular", type="tool")
async def regular():
await cl.sleep(1)
return "Response from regular"


async def cpu():
async with cl.Step(name="cpu", type="tool", icon="cpu") as s:
await cl.sleep(1)
s.output = "Response from cpu"


@cl.on_message
async def main(message: cl.Message):
await search()
await database()
await regular()
await cpu()
42 changes: 42 additions & 0 deletions cypress/e2e/step_icon/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { submitMessage } from '../../support/testUtils';

describe('Step with Icon', () => {
it('should display icons for steps with icon property', () => {
submitMessage('Hello');

cy.get('.step').should('have.length', 5);

// Check that steps with icons have SVG icons (not avatar images)
// The avatar is a sibling of the step content in the .ai-message container
cy.get('#step-search')
.closest('.ai-message')
.within(() => {
// Should have an svg icon (Lucide icons are SVGs)
cy.get('svg').should('exist');
// Should NOT have an avatar image
cy.get('img').should('not.exist');
});

cy.get('#step-database')
.closest('.ai-message')
.within(() => {
cy.get('svg').should('exist');
cy.get('img').should('not.exist');
});

// Check that step without icon has avatar (image)
cy.get('#step-regular')
.closest('.ai-message')
.within(() => {
// Should have an avatar image
cy.get('img').should('exist');
});

cy.get('#step-cpu')
.closest('.ai-message')
.within(() => {
cy.get('svg').should('exist');
cy.get('img').should('not.exist');
});
});
});
35 changes: 22 additions & 13 deletions frontend/src/components/chat/Messages/Message/Avatar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
useConfig
} from '@chainlit/react-client';

import Icon from '@/components/Icon';
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
import { Skeleton } from '@/components/ui/skeleton';
import {
Expand All @@ -21,9 +22,10 @@ interface Props {
author?: string;
hide?: boolean;
isError?: boolean;
iconName?: string;
}

const MessageAvatar = ({ author, hide, isError }: Props) => {
const MessageAvatar = ({ author, hide, isError, iconName }: Props) => {
const apiClient = useContext(ChainlitContext);
const { chatProfile } = useChatSession();
const { config } = useConfig();
Expand All @@ -48,22 +50,29 @@ const MessageAvatar = ({ author, hide, isError }: Props) => {
);
}

// Render icon or avatar based on iconName
const avatarContent = iconName ? (
<span className="inline-flex">
<Icon name={iconName} className="h-5 w-5 mt-[3px]" />
</span>
) : (
<Avatar className="h-5 w-5 mt-[3px]">
<AvatarImage
src={avatarUrl}
alt={`Avatar for ${author || 'default'}`}
className="bg-transparent"
/>
<AvatarFallback className="bg-transparent">
<Skeleton className="h-full w-full rounded-full" />
</AvatarFallback>
</Avatar>
);

return (
<span className={cn('inline-block', hide && 'invisible')}>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Avatar className="h-5 w-5 mt-[3px]">
<AvatarImage
src={avatarUrl}
alt={`Avatar for ${author || 'default'}`}
className="bg-transparent"
/>
<AvatarFallback className="bg-transparent">
<Skeleton className="h-full w-full rounded-full" />
</AvatarFallback>
</Avatar>
</TooltipTrigger>
<TooltipTrigger asChild>{avatarContent}</TooltipTrigger>
<TooltipContent>
<p>{author}</p>
</TooltipContent>
Expand Down
1 change: 1 addition & 0 deletions frontend/src/components/chat/Messages/Message/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ const Message = memo(
<MessageAvatar
author={message.metadata?.avatarName || message.name}
isError={message.isError}
iconName={message.icon}
/>
) : null}
{/* Display the step and its children */}
Expand Down
1 change: 1 addition & 0 deletions libs/react-client/src/types/step.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export interface IStep {
id: string;
name: string;
type: StepType;
icon?: string;
threadId?: string;
parentId?: string;
isError?: boolean;
Expand Down
Loading