Skip to content

Commit 3c3515c

Browse files
authored
feat(playground): persist thread on storage and adjust tools (#874)
* chore(deps): upgrade ai package * feat(playground): persist thread on storage and adjust tools * refactor: clean up unused stuff and extra exports * refactor: remove generate message id function * test: adjust with id * refactor: remove extra export * feat(playground): add confirmation for clear the messages
1 parent 01b1719 commit 3c3515c

19 files changed

+1227
-136
lines changed

main/src/chat/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Export all chat-related functionality
22
export * from './types'
33
export * from './providers'
4-
export * from './storage'
4+
export * from './settings-storage'
5+
export * from './threads-storage'
6+
export * from './thread-integration'
57
export * from './mcp-tools'
68
export * from './streaming'
79
export * from './stream-utils'

main/src/chat/mcp-tools.ts

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@ import {
33
type experimental_MCPClient as MCPClient,
44
} from 'ai'
55
import type { ToolSet } from 'ai'
6-
76
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
87
import { createClient } from '@api/client'
98
import { getApiV1BetaWorkloads } from '@api/sdk.gen'
109
import { getHeaders } from '../headers'
1110
import { getToolhivePort, getToolhiveMcpPort } from '../toolhive-manager'
1211
import log from '../logger'
1312
import type { McpToolInfo } from './types'
14-
import { getEnabledMcpTools } from './storage'
13+
import { getEnabledMcpTools } from './settings-storage'
1514
import {
1615
type McpToolDefinition,
1716
createTransport,
@@ -121,7 +120,7 @@ export async function getMcpServerTools(serverName?: string): Promise<
121120
}
122121

123122
// Get enabled tools for this server
124-
const enabledTools = getEnabledMcpTools()
123+
const enabledTools = await getEnabledMcpTools()
125124
const enabledToolNames = enabledTools[serverName] || []
126125

127126
// If workload.tools is empty, try to discover tools by connecting to the server
@@ -162,10 +161,7 @@ export async function getMcpServerTools(serverName?: string): Promise<
162161

163162
// Otherwise return the original format for backward compatibility
164163
const mcpTools = (workloads || [])
165-
.filter(
166-
(workload) =>
167-
workload.name && workload.tools && workload.tool_type === 'mcp'
168-
)
164+
.filter((workload) => workload.name && workload.tools)
169165
.flatMap((workload) =>
170166
workload.tools!.map((toolName) => ({
171167
name: `mcp_${workload.name}_${toolName}`,
@@ -229,7 +225,7 @@ export async function createMcpTools(): Promise<{
229225
const workloads = data?.workloads
230226

231227
// Get enabled tools from storage
232-
const enabledTools = getEnabledMcpTools()
228+
const enabledTools = await getEnabledMcpTools()
233229

234230
// Continue with regular MCP servers even if no enabled tools (since we might have Toolhive MCP)
235231
if (Object.keys(enabledTools).length === 0) {
@@ -241,7 +237,12 @@ export async function createMcpTools(): Promise<{
241237
if (toolNames.length === 0) continue
242238

243239
const workload = workloads?.find((w) => w.name === serverName)
244-
if (!workload || workload.tool_type !== 'mcp') continue
240+
if (!workload) {
241+
log.debug(`Skipping ${serverName}: workload not found`)
242+
continue
243+
}
244+
245+
log.debug(`Found MCP workload for ${serverName}:`, workload.package)
245246

246247
try {
247248
const config = createTransport(workload)
@@ -254,13 +255,19 @@ export async function createMcpTools(): Promise<{
254255
const serverMcpTools = await mcpClient.tools()
255256

256257
// Add only the enabled tools from this server
258+
let addedToolsCount = 0
257259
for (const toolName of toolNames) {
258260
if (serverMcpTools[toolName]) {
259261
mcpTools[toolName] = serverMcpTools[toolName]
262+
addedToolsCount++
263+
} else {
264+
log.warn(`Tool ${toolName} not found in server ${serverName}`)
260265
}
261266
}
262267

263-
// MCP client created successfully
268+
log.debug(
269+
`Added ${addedToolsCount}/${toolNames.length} tools from ${serverName}`
270+
)
264271
} catch (error) {
265272
log.error(`Failed to create MCP client for ${serverName}:`, error)
266273
}

main/src/chat/storage.ts renamed to main/src/chat/settings-storage.ts

Lines changed: 103 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
import Store from 'electron-store'
22
import log from '../logger'
3+
import { getToolhivePort } from '../toolhive-manager'
4+
import { createClient } from '@api/client'
5+
import { getApiV1BetaWorkloads } from '@api/sdk.gen'
6+
import type { CoreWorkload } from '@api/types.gen'
7+
import { getHeaders } from '../headers'
8+
import { getTearingDownState } from '../app-state'
9+
10+
// Chat store types
11+
interface ChatSettingsProvider {
12+
apiKey: string
13+
enabledTools: string[]
14+
}
15+
16+
interface ChatSettingsSelectedModel {
17+
provider: string
18+
model: string
19+
}
20+
21+
interface ChatSettings {
22+
providers: Record<string, ChatSettingsProvider>
23+
selectedModel: ChatSettingsSelectedModel
24+
enabledMcpTools: Record<string, string[]> // serverName -> [toolName1, toolName2]
25+
}
326

427
// Type guard functions
528
function isRecord(value: unknown): value is Record<string, unknown> {
@@ -10,9 +33,7 @@ function isStringArray(value: unknown): value is string[] {
1033
return Array.isArray(value) && value.every((item) => typeof item === 'string')
1134
}
1235

13-
function isProvidersRecord(
14-
value: unknown
15-
): value is Record<string, ChatSettings> {
36+
function isProvidersRecord(value: unknown): value is ChatSettings['providers'] {
1637
if (!isRecord(value)) return false
1738
return Object.values(value).every(
1839
(item) =>
@@ -22,14 +43,14 @@ function isProvidersRecord(
2243
)
2344
}
2445

25-
function isToolsRecord(value: unknown): value is Record<string, string[]> {
46+
function isToolsRecord(
47+
value: unknown
48+
): value is ChatSettings['enabledMcpTools'] {
2649
if (!isRecord(value)) return false
2750
return Object.values(value).every((item) => isStringArray(item))
2851
}
2952

30-
function isSelectedModel(
31-
value: unknown
32-
): value is { provider: string; model: string } {
53+
function isSelectedModel(value: unknown): value is ChatSettingsSelectedModel {
3354
return (
3455
isRecord(value) &&
3556
typeof value.provider === 'string' &&
@@ -38,34 +59,21 @@ function isSelectedModel(
3859
}
3960

4061
// Create a secure store for chat settings (API keys and model selection)
41-
const chatStore = new Store({
62+
const chatStore = new Store<ChatSettings>({
4263
name: 'chat-settings',
4364
encryptionKey: 'toolhive-chat-encryption-key', // Basic encryption for API keys
4465
defaults: {
45-
providers: {} as Record<
46-
string,
47-
{
48-
apiKey: string
49-
enabledTools: string[]
50-
}
51-
>,
66+
providers: {},
5267
selectedModel: {
5368
provider: '',
5469
model: '',
5570
},
56-
// Individual tool enablement per server (single source of truth)
57-
enabledMcpTools: {} as Record<string, string[]>, // serverName -> [toolName1, toolName2]
71+
enabledMcpTools: {},
5872
},
5973
})
6074

61-
// Chat settings interface
62-
interface ChatSettings {
63-
apiKey: string
64-
enabledTools: string[]
65-
}
66-
6775
// Get chat settings for a provider
68-
export function getChatSettings(providerId: string): ChatSettings {
76+
export function getChatSettings(providerId: string): ChatSettingsProvider {
6977
try {
7078
const providers = chatStore.get('providers')
7179
if (isProvidersRecord(providers)) {
@@ -81,7 +89,7 @@ export function getChatSettings(providerId: string): ChatSettings {
8189
// Save chat settings for a provider
8290
export function saveChatSettings(
8391
providerId: string,
84-
settings: ChatSettings
92+
settings: ChatSettingsProvider
8593
): { success: boolean; error?: string } {
8694
try {
8795
const providers = chatStore.get('providers')
@@ -126,7 +134,7 @@ export function clearChatSettings(providerId?: string): {
126134
}
127135

128136
// Get selected model
129-
export function getSelectedModel(): { provider: string; model: string } {
137+
export function getSelectedModel(): ChatSettingsSelectedModel {
130138
try {
131139
const selectedModel = chatStore.get('selectedModel')
132140
if (
@@ -159,20 +167,6 @@ export function saveSelectedModel(
159167
}
160168
}
161169

162-
// Get enabled MCP tools for a specific server
163-
// function getEnabledMcpToolsForServer(serverName: string): string[] {
164-
// try {
165-
// const enabledMcpTools = chatStore.get('enabledMcpTools')
166-
// if (isToolsRecord(enabledMcpTools)) {
167-
// return enabledMcpTools[serverName] || []
168-
// }
169-
// return []
170-
// } catch (error) {
171-
// log.error('Failed to get enabled MCP tools:', error)
172-
// return []
173-
// }
174-
// }
175-
176170
// Save enabled MCP tools for a server
177171
export function saveEnabledMcpTools(
178172
serverName: string,
@@ -193,24 +187,87 @@ export function saveEnabledMcpTools(
193187
}
194188
}
195189

196-
// Get all enabled MCP tools (global)
197-
export function getEnabledMcpTools(): Record<string, string[]> {
190+
// Get all enabled MCP tools (global) - filters out tools from stopped servers
191+
export async function getEnabledMcpTools(): Promise<
192+
ChatSettings['enabledMcpTools']
193+
> {
198194
try {
199195
const enabledMcpTools = chatStore.get('enabledMcpTools')
200-
if (isToolsRecord(enabledMcpTools)) {
196+
if (!isToolsRecord(enabledMcpTools)) {
197+
return {}
198+
}
199+
200+
// Skip validation during shutdown to prevent interrupting teardown
201+
if (getTearingDownState()) {
202+
log.debug('Skipping MCP tools validation during teardown')
203+
return enabledMcpTools
204+
}
205+
206+
// Get running servers to filter out tools from stopped servers
207+
const port = getToolhivePort()
208+
209+
if (!port) {
210+
// If ToolHive is not running, return stored tools without validation
211+
return enabledMcpTools
212+
}
213+
214+
try {
215+
const client = createClient({
216+
baseUrl: `http://localhost:${port}`,
217+
headers: getHeaders(),
218+
})
219+
220+
const { data } = await getApiV1BetaWorkloads({
221+
client,
222+
query: { all: true },
223+
})
224+
225+
const runningServerNames = (data?.workloads || [])
226+
.filter((w: CoreWorkload) => w.status === 'running')
227+
.map((w: CoreWorkload) => w.name)
228+
229+
// Filter enabled tools to only include tools from running servers
230+
const filteredTools: ChatSettings['enabledMcpTools'] = {}
231+
const serversToRemove: string[] = []
232+
233+
for (const [serverName, tools] of Object.entries(enabledMcpTools)) {
234+
if (runningServerNames.includes(serverName)) {
235+
filteredTools[serverName] = tools
236+
} else if (tools.length > 0) {
237+
// Only log if server actually had tools to clean up
238+
log.info(`Cleaning up tools for stopped server: ${serverName}`)
239+
serversToRemove.push(serverName)
240+
}
241+
}
242+
243+
// Remove stopped servers from storage in one operation
244+
if (serversToRemove.length > 0) {
245+
const updatedTools = { ...enabledMcpTools }
246+
for (const serverName of serversToRemove) {
247+
delete updatedTools[serverName]
248+
}
249+
chatStore.set('enabledMcpTools', updatedTools)
250+
}
251+
252+
return filteredTools
253+
} catch (apiError) {
254+
log.warn(
255+
'Failed to check running servers during shutdown, returning stored tools:',
256+
apiError
257+
)
258+
// During shutdown, just return stored tools without validation
201259
return enabledMcpTools
202260
}
203-
return {}
204261
} catch (error) {
205262
log.error('Failed to get all enabled MCP tools:', error)
206263
return {}
207264
}
208265
}
209266

210267
// Get enabled MCP servers from tools (get servers that have enabled tools)
211-
export function getEnabledMcpServersFromTools(): string[] {
268+
export async function getEnabledMcpServersFromTools(): Promise<string[]> {
212269
try {
213-
const allEnabledTools = getEnabledMcpTools()
270+
const allEnabledTools = await getEnabledMcpTools()
214271
const enabledServerNames = Object.keys(allEnabledTools).filter(
215272
(serverName) => {
216273
const tools = allEnabledTools[serverName]

0 commit comments

Comments
 (0)