From d7f7580a3082ab54f48cd5ee5d5fba22c121acb2 Mon Sep 17 00:00:00 2001 From: "koalazf.99" Date: Mon, 18 Aug 2025 22:34:08 +0800 Subject: [PATCH 01/16] support: qwen md selection --- packages/core/src/tools/memoryTool.test.ts | 62 ++++++++--- packages/core/src/tools/memoryTool.ts | 113 +++++++++++++++------ 2 files changed, 131 insertions(+), 44 deletions(-) diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index 75a2c08a0..edae8c42d 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -203,7 +203,7 @@ describe('MemoryTool', () => { }); it('should call performAddMemoryEntry with correct parameters and return success', async () => { - const params = { fact: 'The sky is blue' }; + const params = { fact: 'The sky is blue', scope: 'global' as const }; const result = await memoryTool.execute(params, mockAbortSignal); // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test const expectedFilePath = path.join( @@ -224,7 +224,7 @@ describe('MemoryTool', () => { expectedFilePath, expectedFsArgument, ); - const successMessage = `Okay, I've remembered that: "${params.fact}"`; + const successMessage = `Okay, I've remembered that in global memory: "${params.fact}"`; expect(result.llmContent).toBe( JSON.stringify({ success: true, message: successMessage }), ); @@ -244,7 +244,7 @@ describe('MemoryTool', () => { }); it('should handle errors from performAddMemoryEntry', async () => { - const params = { fact: 'This will fail' }; + const params = { fact: 'This will fail', scope: 'global' as const }; const underlyingError = new Error( '[MemoryTool] Failed to add memory entry: Disk full', ); @@ -276,7 +276,7 @@ describe('MemoryTool', () => { }); it('should return confirmation details when memory file is not allowlisted', async () => { - const params = { fact: 'Test fact' }; + const params = { fact: 'Test fact', scope: 'global' as const }; const result = await memoryTool.shouldConfirmExecute( params, mockAbortSignal, @@ -287,7 +287,9 @@ describe('MemoryTool', () => { if (result && result.type === 'edit') { const expectedPath = path.join('~', '.qwen', 'QWEN.md'); - expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`); + expect(result.title).toBe( + `Confirm Memory Save: ${expectedPath} (global)`, + ); expect(result.fileName).toContain(path.join('mock', 'home', '.qwen')); expect(result.fileName).toContain('QWEN.md'); expect(result.fileDiff).toContain('Index: QWEN.md'); @@ -300,16 +302,16 @@ describe('MemoryTool', () => { }); it('should return false when memory file is already allowlisted', async () => { - const params = { fact: 'Test fact' }; + const params = { fact: 'Test fact', scope: 'global' as const }; const memoryFilePath = path.join( os.homedir(), '.qwen', getCurrentGeminiMdFilename(), ); - // Add the memory file to the allowlist + // Add the memory file to the allowlist with the new key format (MemoryTool as unknown as { allowlist: Set }).allowlist.add( - memoryFilePath, + `${memoryFilePath}_global`, ); const result = await memoryTool.shouldConfirmExecute( @@ -321,7 +323,7 @@ describe('MemoryTool', () => { }); it('should add memory file to allowlist when ProceedAlways is confirmed', async () => { - const params = { fact: 'Test fact' }; + const params = { fact: 'Test fact', scope: 'global' as const }; const memoryFilePath = path.join( os.homedir(), '.qwen', @@ -340,10 +342,10 @@ describe('MemoryTool', () => { // Simulate the onConfirm callback await result.onConfirm(ToolConfirmationOutcome.ProceedAlways); - // Check that the memory file was added to the allowlist + // Check that the memory file was added to the allowlist with the new key format expect( (MemoryTool as unknown as { allowlist: Set }).allowlist.has( - memoryFilePath, + `${memoryFilePath}_global`, ), ).toBe(true); } @@ -384,7 +386,7 @@ describe('MemoryTool', () => { }); it('should handle existing memory file with content', async () => { - const params = { fact: 'New fact' }; + const params = { fact: 'New fact', scope: 'global' as const }; const existingContent = 'Some existing content.\n\n## Qwen Added Memories\n- Old fact\n'; @@ -401,7 +403,9 @@ describe('MemoryTool', () => { if (result && result.type === 'edit') { const expectedPath = path.join('~', '.qwen', 'QWEN.md'); - expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`); + expect(result.title).toBe( + `Confirm Memory Save: ${expectedPath} (global)`, + ); expect(result.fileDiff).toContain('Index: QWEN.md'); expect(result.fileDiff).toContain('+- New fact'); expect(result.originalContent).toBe(existingContent); @@ -409,5 +413,37 @@ describe('MemoryTool', () => { expect(result.newContent).toContain('- New fact'); } }); + + it('should prompt for scope selection when scope is not specified', async () => { + const params = { fact: 'Test fact' }; + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBeDefined(); + expect(result).not.toBe(false); + + if (result && result.type === 'edit') { + expect(result.title).toBe('Choose Memory Storage Location'); + expect(result.fileName).toBe('Memory Storage Options'); + expect(result.fileDiff).toContain('Choose where to save this memory'); + expect(result.fileDiff).toContain('Test fact'); + expect(result.fileDiff).toContain('Global:'); + expect(result.fileDiff).toContain('Project:'); + expect(result.originalContent).toBe(''); + } + }); + + it('should return error when executing without scope parameter', async () => { + const params = { fact: 'Test fact' }; + const result = await memoryTool.execute(params, mockAbortSignal); + + expect(result.llmContent).toContain( + 'Please specify where to save this memory', + ); + expect(result.returnDisplay).toContain('Global:'); + expect(result.returnDisplay).toContain('Project:'); + }); }); }); diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 2b20735db..b0f4bb5e4 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -32,6 +32,12 @@ const memoryToolSchemaData: FunctionDeclaration = { description: 'The specific fact or piece of information to remember. Should be a clear, self-contained statement.', }, + scope: { + type: Type.STRING, + description: + 'Where to save the memory: "global" saves to user-level ~/.qwen/QWEN.md (shared across all projects), "project" saves to current project\'s QWEN.md (project-specific). If not specified, will prompt user to choose.', + enum: ['global', 'project'], + }, }, required: ['fact'], }, @@ -54,6 +60,10 @@ Do NOT use this tool: ## Parameters - \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue". +- \`scope\` (string, optional): Where to save the memory: + - "global": Saves to user-level ~/.qwen/QWEN.md (shared across all projects) + - "project": Saves to current project's QWEN.md (project-specific) + - If not specified, the tool will ask the user where they want to save the memory. `; export const GEMINI_CONFIG_DIR = '.qwen'; @@ -92,12 +102,23 @@ interface SaveMemoryParams { fact: string; modified_by_user?: boolean; modified_content?: string; + scope?: 'global' | 'project'; } function getGlobalMemoryFilePath(): string { return path.join(homedir(), GEMINI_CONFIG_DIR, getCurrentGeminiMdFilename()); } +function getProjectMemoryFilePath(): string { + return path.join(process.cwd(), getCurrentGeminiMdFilename()); +} + +function getMemoryFilePath(scope: 'global' | 'project' = 'global'): string { + return scope === 'project' + ? getProjectMemoryFilePath() + : getGlobalMemoryFilePath(); +} + /** * Ensures proper newline separation before appending content. */ @@ -127,17 +148,20 @@ export class MemoryTool ); } - getDescription(_params: SaveMemoryParams): string { - const memoryFilePath = getGlobalMemoryFilePath(); - return `in ${tildeifyPath(memoryFilePath)}`; + getDescription(params: SaveMemoryParams): string { + const scope = params.scope || 'global'; + const memoryFilePath = getMemoryFilePath(scope); + return `in ${tildeifyPath(memoryFilePath)} (${scope})`; } /** * Reads the current content of the memory file */ - private async readMemoryFileContent(): Promise { + private async readMemoryFileContent( + scope: 'global' | 'project' = 'global', + ): Promise { try { - return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + return await fs.readFile(getMemoryFilePath(scope), 'utf-8'); } catch (err) { const error = err as Error & { code?: string }; if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; @@ -193,15 +217,35 @@ export class MemoryTool params: SaveMemoryParams, _abortSignal: AbortSignal, ): Promise { - const memoryFilePath = getGlobalMemoryFilePath(); - const allowlistKey = memoryFilePath; + // If scope is not specified, prompt the user to choose + if (!params.scope) { + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + + const confirmationDetails: ToolEditConfirmationDetails = { + type: 'edit', + title: `Choose Memory Storage Location`, + fileName: 'Memory Storage Options', + fileDiff: `Choose where to save this memory:\n\n"${params.fact}"\n\nOptions:\n- Global: ${globalPath} (shared across all projects)\n- Project: ${projectPath} (current project only)\n\nPlease specify the scope parameter: "global" or "project"`, + originalContent: '', + newContent: `Memory to save: ${params.fact}\n\nScope options:\n- global: ${globalPath}\n- project: ${projectPath}`, + onConfirm: async (_outcome: ToolConfirmationOutcome) => { + // This will be handled by the execution flow + }, + }; + return confirmationDetails; + } + + const scope = params.scope; + const memoryFilePath = getMemoryFilePath(scope); + const allowlistKey = `${memoryFilePath}_${scope}`; if (MemoryTool.allowlist.has(allowlistKey)) { return false; } // Read current content of the memory file - const currentContent = await this.readMemoryFileContent(); + const currentContent = await this.readMemoryFileContent(scope); // Calculate the new content that will be written to the memory file const newContent = this.computeNewContent(currentContent, params.fact); @@ -218,7 +262,7 @@ export class MemoryTool const confirmationDetails: ToolEditConfirmationDetails = { type: 'edit', - title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`, + title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)} (${scope})`, fileName: memoryFilePath, fileDiff, originalContent: currentContent, @@ -316,18 +360,27 @@ export class MemoryTool }; } + // If scope is not specified, prompt the user to choose + if (!params.scope) { + const errorMessage = + 'Please specify where to save this memory. Use scope parameter: "global" for user-level (~/.qwen/QWEN.md) or "project" for current project (./QWEN.md).'; + return { + llmContent: JSON.stringify({ success: false, error: errorMessage }), + returnDisplay: `${errorMessage}\n\nGlobal: ${tildeifyPath(getMemoryFilePath('global'))}\nProject: ${tildeifyPath(getMemoryFilePath('project'))}`, + }; + } + + const scope = params.scope; + const memoryFilePath = getMemoryFilePath(scope); + try { if (modified_by_user && modified_content !== undefined) { // User modified the content in external editor, write it directly - await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + await fs.mkdir(path.dirname(memoryFilePath), { recursive: true, }); - await fs.writeFile( - getGlobalMemoryFilePath(), - modified_content, - 'utf-8', - ); - const successMessage = `Okay, I've updated the memory file with your modifications.`; + await fs.writeFile(memoryFilePath, modified_content, 'utf-8'); + const successMessage = `Okay, I've updated the ${scope} memory file with your modifications.`; return { llmContent: JSON.stringify({ success: true, @@ -337,16 +390,12 @@ export class MemoryTool }; } else { // Use the normal memory entry logic - await MemoryTool.performAddMemoryEntry( - fact, - getGlobalMemoryFilePath(), - { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }, - ); - const successMessage = `Okay, I've remembered that: "${fact}"`; + await MemoryTool.performAddMemoryEntry(fact, memoryFilePath, { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }); + const successMessage = `Okay, I've remembered that in ${scope} memory: "${fact}"`; return { llmContent: JSON.stringify({ success: true, @@ -359,7 +408,7 @@ export class MemoryTool const errorMessage = error instanceof Error ? error.message : String(error); console.error( - `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`, + `[MemoryTool] Error executing save_memory for fact "${fact}" in ${scope}: ${errorMessage}`, ); return { llmContent: JSON.stringify({ @@ -373,11 +422,13 @@ export class MemoryTool getModifyContext(_abortSignal: AbortSignal): ModifyContext { return { - getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), - getCurrentContent: async (_params: SaveMemoryParams): Promise => - this.readMemoryFileContent(), + getFilePath: (params: SaveMemoryParams) => + getMemoryFilePath(params.scope || 'global'), + getCurrentContent: async (params: SaveMemoryParams): Promise => + this.readMemoryFileContent(params.scope || 'global'), getProposedContent: async (params: SaveMemoryParams): Promise => { - const currentContent = await this.readMemoryFileContent(); + const scope = params.scope || 'global'; + const currentContent = await this.readMemoryFileContent(scope); return this.computeNewContent(currentContent, params.fact); }, createUpdatedParams: ( From 300881405a04a3f7ba1380a156f08d0c961b6e5a Mon Sep 17 00:00:00 2001 From: koalazf99 Date: Sun, 24 Aug 2025 00:31:10 +0800 Subject: [PATCH 02/16] tmp --- package.json | 4 +- .../cli/src/ui/commands/memoryCommand.test.ts | 57 ++++++- packages/cli/src/ui/commands/memoryCommand.ts | 45 +++++- packages/core/src/tools/memoryTool.test.ts | 15 +- packages/core/src/tools/memoryTool.ts | 149 ++++++++++++++---- 5 files changed, 229 insertions(+), 41 deletions(-) diff --git a/package.json b/package.json index 56116b529..4b1cd151c 100644 --- a/package.json +++ b/package.json @@ -80,13 +80,13 @@ "json": "^11.0.0", "lodash": "^4.17.21", "memfs": "^4.17.2", + "mnemonist": "^0.40.3", "mock-fs": "^5.5.0", "prettier": "^3.5.3", "react-devtools-core": "^4.28.5", "tsx": "^4.20.3", "typescript-eslint": "^8.30.1", "vitest": "^3.2.4", - "yargs": "^17.7.2", - "mnemonist": "^0.40.3" + "yargs": "^17.7.2" } } diff --git a/packages/cli/src/ui/commands/memoryCommand.test.ts b/packages/cli/src/ui/commands/memoryCommand.test.ts index 9ee33e69d..77281110f 100644 --- a/packages/cli/src/ui/commands/memoryCommand.test.ts +++ b/packages/cli/src/ui/commands/memoryCommand.test.ts @@ -117,7 +117,7 @@ describe('memoryCommand', () => { expect(result).toEqual({ type: 'message', messageType: 'error', - content: 'Usage: /memory add ', + content: 'Usage: /memory add [--global|--project] ', }); expect(mockContext.ui.addItem).not.toHaveBeenCalled(); @@ -143,6 +143,61 @@ describe('memoryCommand', () => { toolArgs: { fact }, }); }); + + it('should handle --global flag and add scope to tool args', () => { + if (!addCommand.action) throw new Error('Command has no action'); + + const fact = 'remember this globally'; + const result = addCommand.action(mockContext, `--global ${fact}`); + + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: `Attempting to save to memory (global): "${fact}"`, + }, + expect.any(Number), + ); + + expect(result).toEqual({ + type: 'tool', + toolName: 'save_memory', + toolArgs: { fact, scope: 'global' }, + }); + }); + + it('should handle --project flag and add scope to tool args', () => { + if (!addCommand.action) throw new Error('Command has no action'); + + const fact = 'remember this for project'; + const result = addCommand.action(mockContext, `--project ${fact}`); + + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: `Attempting to save to memory (project): "${fact}"`, + }, + expect.any(Number), + ); + + expect(result).toEqual({ + type: 'tool', + toolName: 'save_memory', + toolArgs: { fact, scope: 'project' }, + }); + }); + + it('should return error if flag is provided but no fact follows', () => { + if (!addCommand.action) throw new Error('Command has no action'); + + const result = addCommand.action(mockContext, '--global '); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Usage: /memory add [--global|--project] ', + }); + + expect(mockContext.ui.addItem).not.toHaveBeenCalled(); + }); }); describe('/memory refresh', () => { diff --git a/packages/cli/src/ui/commands/memoryCommand.ts b/packages/cli/src/ui/commands/memoryCommand.ts index dd34d92cf..8b742ef31 100644 --- a/packages/cli/src/ui/commands/memoryCommand.ts +++ b/packages/cli/src/ui/commands/memoryCommand.ts @@ -44,29 +44,66 @@ export const memoryCommand: SlashCommand = { }, { name: 'add', - description: 'Add content to the memory.', + description: + 'Add content to the memory. Use --global for global memory or --project for project memory.', kind: CommandKind.BUILT_IN, action: (context, args): SlashCommandActionReturn | void => { if (!args || args.trim() === '') { return { type: 'message', messageType: 'error', - content: 'Usage: /memory add ', + content: + 'Usage: /memory add [--global|--project] ', }; } + const trimmedArgs = args.trim(); + let scope: 'global' | 'project' | undefined; + let fact: string; + + // Check for scope flags + if (trimmedArgs.startsWith('--global ')) { + scope = 'global'; + fact = trimmedArgs.substring('--global '.length).trim(); + } else if (trimmedArgs.startsWith('--project ')) { + scope = 'project'; + fact = trimmedArgs.substring('--project '.length).trim(); + } else if (trimmedArgs === '--global' || trimmedArgs === '--project') { + // Flag provided but no text after it + return { + type: 'message', + messageType: 'error', + content: + 'Usage: /memory add [--global|--project] ', + }; + } else { + // No scope specified, will be handled by the tool + fact = trimmedArgs; + } + + if (!fact || fact.trim() === '') { + return { + type: 'message', + messageType: 'error', + content: + 'Usage: /memory add [--global|--project] ', + }; + } + + const scopeText = scope ? ` (${scope})` : ''; context.ui.addItem( { type: MessageType.INFO, - text: `Attempting to save to memory: "${args.trim()}"`, + text: `Attempting to save to memory${scopeText}: "${fact}"`, }, Date.now(), ); + const toolArgs = scope ? { fact, scope } : { fact }; return { type: 'tool', toolName: 'save_memory', - toolArgs: { fact: args.trim() }, + toolArgs, }; }, }, diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index b01471f71..b78ff10bb 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -425,13 +425,16 @@ describe('MemoryTool', () => { expect(result).not.toBe(false); if (result && result.type === 'edit') { - expect(result.title).toBe('Choose Memory Storage Location'); - expect(result.fileName).toBe('Memory Storage Options'); - expect(result.fileDiff).toContain('Choose where to save this memory'); + expect(result.title).toContain('Choose Memory Location'); + expect(result.title).toContain('GLOBAL'); + expect(result.title).toContain('PROJECT'); + expect(result.fileName).toBe('QWEN.md'); expect(result.fileDiff).toContain('Test fact'); - expect(result.fileDiff).toContain('Global:'); - expect(result.fileDiff).toContain('Project:'); - expect(result.originalContent).toBe(''); + expect(result.fileDiff).toContain('--- QWEN.md'); + expect(result.fileDiff).toContain('+++ QWEN.md'); + expect(result.fileDiff).toContain('+- Test fact'); + expect(result.originalContent).toContain('scope: global'); + expect(result.originalContent).toContain('INSTRUCTIONS:'); } }); diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 4b8fe0658..fb1abf334 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -149,7 +149,12 @@ export class MemoryTool } getDescription(params: SaveMemoryParams): string { - const scope = params.scope || 'global'; + if (!params.scope) { + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + return `CHOOSE: ${globalPath} (global) OR ${projectPath} (project)`; + } + const scope = params.scope; const memoryFilePath = getMemoryFilePath(scope); return `in ${tildeifyPath(memoryFilePath)} (${scope})`; } @@ -217,27 +222,43 @@ export class MemoryTool params: SaveMemoryParams, _abortSignal: AbortSignal, ): Promise { - // If scope is not specified, prompt the user to choose + // When scope is not specified, show a choice dialog defaulting to global if (!params.scope) { + // Show preview of what would be added to global by default + const defaultScope = 'global'; + const currentContent = await this.readMemoryFileContent(defaultScope); + const newContent = this.computeNewContent(currentContent, params.fact); + + const fileName = path.basename(getMemoryFilePath(defaultScope)); + const fileDiff = Diff.createPatch( + fileName, + currentContent, + newContent, + 'Current', + 'Proposed (Global)', + DEFAULT_DIFF_OPTIONS, + ); + const globalPath = tildeifyPath(getMemoryFilePath('global')); const projectPath = tildeifyPath(getMemoryFilePath('project')); const confirmationDetails: ToolEditConfirmationDetails = { type: 'edit', - title: `Choose Memory Storage Location`, - fileName: 'Memory Storage Options', - filePath: '', - fileDiff: `Choose where to save this memory:\n\n"${params.fact}"\n\nOptions:\n- Global: ${globalPath} (shared across all projects)\n- Project: ${projectPath} (current project only)\n\nPlease specify the scope parameter: "global" or "project"`, - originalContent: '', - newContent: `Memory to save: ${params.fact}\n\nScope options:\n- global: ${globalPath}\n- project: ${projectPath}`, + title: `Choose Memory Location: GLOBAL (${globalPath}) or PROJECT (${projectPath})`, + fileName, + filePath: getMemoryFilePath(defaultScope), + fileDiff, + originalContent: `scope: global\n\n# INSTRUCTIONS:\n# - Click "Yes" to save to GLOBAL memory: ${globalPath}\n# - Click "Modify with external editor" and change "global" to "project" to save to PROJECT memory: ${projectPath}\n\n${currentContent}`, + newContent: `scope: global\n\n# INSTRUCTIONS:\n# - Click "Yes" to save to GLOBAL memory: ${globalPath}\n# - Click "Modify with external editor" and change "global" to "project" to save to PROJECT memory: ${projectPath}\n\n${newContent}`, onConfirm: async (_outcome: ToolConfirmationOutcome) => { - // This will be handled by the execution flow + // Will be handled in createUpdatedParams }, }; return confirmationDetails; } - const scope = params.scope; + // Only check allowlist when scope is specified + const scope = params.scope!; // We know scope is specified at this point const memoryFilePath = getMemoryFilePath(scope); const allowlistKey = `${memoryFilePath}_${scope}`; @@ -362,17 +383,25 @@ export class MemoryTool }; } - // If scope is not specified, prompt the user to choose - if (!params.scope) { - const errorMessage = - 'Please specify where to save this memory. Use scope parameter: "global" for user-level (~/.qwen/QWEN.md) or "project" for current project (./QWEN.md).'; + // If scope is not specified and user didn't modify content, return error prompting for choice + if (!params.scope && !params.modified_by_user) { + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + const errorMessage = `Please specify where to save this memory: + +Global: ${globalPath} (shared across all projects) +Project: ${projectPath} (current project only)`; + return { - llmContent: JSON.stringify({ success: false, error: errorMessage }), - returnDisplay: `${errorMessage}\n\nGlobal: ${tildeifyPath(getMemoryFilePath('global'))}\nProject: ${tildeifyPath(getMemoryFilePath('project'))}`, + llmContent: JSON.stringify({ + success: false, + error: 'Please specify where to save this memory', + }), + returnDisplay: errorMessage, }; } - const scope = params.scope; + const scope = params.scope || 'global'; const memoryFilePath = getMemoryFilePath(scope); try { @@ -424,24 +453,88 @@ export class MemoryTool getModifyContext(_abortSignal: AbortSignal): ModifyContext { return { - getFilePath: (params: SaveMemoryParams) => - getMemoryFilePath(params.scope || 'global'), - getCurrentContent: async (params: SaveMemoryParams): Promise => - this.readMemoryFileContent(params.scope || 'global'), - getProposedContent: async (params: SaveMemoryParams): Promise => { + getFilePath: (params: SaveMemoryParams) => { + // Determine scope from modified content or default + let scope = params.scope || 'global'; + if (params.modified_content) { + const scopeMatch = params.modified_content.match( + /^scope:\s*(global|project)\s*\n/i, + ); + if (scopeMatch) { + scope = scopeMatch[1].toLowerCase() as 'global' | 'project'; + } + } + return getMemoryFilePath(scope); + }, + getCurrentContent: async (params: SaveMemoryParams): Promise => { + // Check if content starts with scope directive + if (params.modified_content) { + const scopeMatch = params.modified_content.match( + /^scope:\s*(global|project)\s*\n/i, + ); + if (scopeMatch) { + const scope = scopeMatch[1].toLowerCase() as 'global' | 'project'; + const content = await this.readMemoryFileContent(scope); + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${content}`; + } + } const scope = params.scope || 'global'; + const content = await this.readMemoryFileContent(scope); + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${content}`; + }, + getProposedContent: async (params: SaveMemoryParams): Promise => { + let scope = params.scope || 'global'; + + // Check if modified content has scope directive + if (params.modified_content) { + const scopeMatch = params.modified_content.match( + /^scope:\s*(global|project)\s*\n/i, + ); + if (scopeMatch) { + scope = scopeMatch[1].toLowerCase() as 'global' | 'project'; + } + } + const currentContent = await this.readMemoryFileContent(scope); - return this.computeNewContent(currentContent, params.fact); + const newContent = this.computeNewContent(currentContent, params.fact); + const globalPath = tildeifyPath(getMemoryFilePath('global')); + const projectPath = tildeifyPath(getMemoryFilePath('project')); + return `scope: ${scope}\n\n# INSTRUCTIONS:\n# - Save as "global" for GLOBAL memory: ${globalPath}\n# - Save as "project" for PROJECT memory: ${projectPath}\n\n${newContent}`; }, createUpdatedParams: ( _oldContent: string, modifiedProposedContent: string, originalParams: SaveMemoryParams, - ): SaveMemoryParams => ({ - ...originalParams, - modified_by_user: true, - modified_content: modifiedProposedContent, - }), + ): SaveMemoryParams => { + // Parse user's scope choice from modified content + const scopeMatch = modifiedProposedContent.match( + /^scope:\s*(global|project)/i, + ); + const scope = scopeMatch + ? (scopeMatch[1].toLowerCase() as 'global' | 'project') + : 'global'; + + // Strip out the scope directive and instruction lines, keep only the actual memory content + const contentWithoutScope = modifiedProposedContent.replace( + /^scope:\s*(global|project)\s*\n/, + '', + ); + const actualContent = contentWithoutScope + .replace(/^#[^\n]*\n/gm, '') + .replace(/^\s*\n/gm, '') + .trim(); + + return { + ...originalParams, + scope, + modified_by_user: true, + modified_content: actualContent, + }; + }, }; } } From d2d1e748c309b78a1b64b994bf19849fbc4a85f8 Mon Sep 17 00:00:00 2001 From: ajiwo Date: Mon, 25 Aug 2025 13:32:57 +0000 Subject: [PATCH 03/16] feat(tools): Include the new content after edits Introduces a new configuration setting, `readAfterEdit`, which is enabled by default. When this setting is active, the `edit` tool will automatically append the full content of a file to its response message (`llmContent`) after a successful modification or creation. This provides the AI with immediate context of the changes, improving its awareness of the file's current state and reducing the need for a subsequent `read_file` call. Co-authored-by: Qwen-Coder --- docs/cli/configuration.md | 4 + docs/tools/file-system.md | 1 + packages/cli/src/config/config.ts | 1 + packages/cli/src/config/settingsSchema.ts | 10 + packages/core/src/config/config.ts | 7 + packages/core/src/tools/edit.test.ts | 287 ++++++++++++++++++++++ packages/core/src/tools/edit.ts | 7 +- 7 files changed, 316 insertions(+), 1 deletion(-) diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index 5e8c90029..e1760df9a 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -272,6 +272,10 @@ In addition to a project settings file, a project's `.qwen` directory can contai - **Description:** API key for Tavily web search service. Required to enable the `web_search` tool functionality. If not configured, the web search tool will be disabled and skipped. - **Default:** `undefined` (web search disabled) - **Example:** `"tavilyApiKey": "tvly-your-api-key-here"` +- **`readAfterEdit`** (boolean): + - **Description:** Automatically read file content after editing to provide context to the AI. When enabled, the content of a file is included in the LLM response after successful edit operations, enhancing the AI's awareness of the changes made. + - **Default:** `true` + - **Example:** `"readAfterEdit": false` - **`chatCompression`** (object): - **Description:** Controls the settings for chat history compression, both automatic and when manually invoked through the /compress command. diff --git a/docs/tools/file-system.md b/docs/tools/file-system.md index 45c1eaa7b..0181614c8 100644 --- a/docs/tools/file-system.md +++ b/docs/tools/file-system.md @@ -167,6 +167,7 @@ search_file_content(pattern="function", include="*.js", maxResults=10) - `old_string` is found multiple times, and the self-correction mechanism cannot resolve it to a single, unambiguous match. - **Output (`llmContent`):** - On success: `Successfully modified file: /path/to/file.txt (1 replacements).` or `Created new file: /path/to/new_file.txt with provided content.` + - When the `readAfterEdit` configuration is enabled (default), the updated file content is also included in the response to provide context to the AI. - On failure: An error message explaining the reason (e.g., `Failed to edit, 0 occurrences found...`, `Failed to edit, expected 1 occurrences but found 2...`). - **Confirmation:** Yes. Shows a diff of the proposed changes and asks for user approval before writing to the file. diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index d929747ed..1c2252a90 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -584,6 +584,7 @@ export async function loadCliConfig( chatCompression: settings.chatCompression, folderTrustFeature, folderTrust, + readAfterEdit: settings.readAfterEdit ?? true, interactive, trustedFolder, }); diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 4a21ebe5b..bf3a31975 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -540,6 +540,16 @@ export const SETTINGS_SCHEMA = { description: 'The API key for the Tavily API.', showInDialog: false, }, + readAfterEdit: { + type: 'boolean', + label: 'Read After Edit', + category: 'Tools', + requiresRestart: false, + default: true, + description: + 'Automatically read file content after editing to provide context to the AI.', + showInDialog: true, + }, } as const; type InferSettings = { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index d09c24e67..52693e169 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -222,6 +222,7 @@ export interface ConfigParameters { chatCompression?: ChatCompressionSettings; interactive?: boolean; trustedFolder?: boolean; + readAfterEdit?: boolean; } export class Config { @@ -301,6 +302,7 @@ export class Config { private readonly chatCompression: ChatCompressionSettings | undefined; private readonly interactive: boolean; private readonly trustedFolder: boolean | undefined; + private readonly readAfterEdit: boolean; private initialized: boolean = false; constructor(params: ConfigParameters) { @@ -377,6 +379,7 @@ export class Config { this.chatCompression = params.chatCompression; this.interactive = params.interactive ?? false; this.trustedFolder = params.trustedFolder; + this.readAfterEdit = params.readAfterEdit ?? true; // Web search this.tavilyApiKey = params.tavilyApiKey; @@ -803,6 +806,10 @@ export class Config { return this.interactive; } + getReadAfterEdit(): boolean { + return this.readAfterEdit; + } + async getGitService(): Promise { if (!this.gitService) { this.gitService = new GitService(this.targetDir); diff --git a/packages/core/src/tools/edit.test.ts b/packages/core/src/tools/edit.test.ts index b2e31fdda..4c5058c84 100644 --- a/packages/core/src/tools/edit.test.ts +++ b/packages/core/src/tools/edit.test.ts @@ -81,6 +81,7 @@ describe('EditTool', () => { getGeminiMdFileCount: () => 0, setGeminiMdFileCount: vi.fn(), getToolRegistry: () => ({}) as any, // Minimal mock for ToolRegistry + getReadAfterEdit: () => vi.fn().mockReturnValue(true), } as unknown as Config; // Reset mocks before each test @@ -847,3 +848,289 @@ describe('EditTool', () => { }); }); }); + +describe('EditTool - readAfterEdit', () => { + let tool: EditTool; + let tempDir: string; + let rootDir: string; + let mockConfig: Config; + let geminiClient: any; + + beforeEach(() => { + vi.restoreAllMocks(); + tempDir = fs.mkdtempSync( + path.join(os.tmpdir(), 'edit-tool-readafteredit-test-'), + ); + rootDir = path.join(tempDir, 'root'); + fs.mkdirSync(rootDir); + + geminiClient = { + generateJson: mockGenerateJson, + }; + + mockConfig = { + getGeminiClient: vi.fn().mockReturnValue(geminiClient), + getTargetDir: () => rootDir, + getApprovalMode: vi.fn(), + getWorkspaceContext: () => createMockWorkspaceContext(rootDir), + getReadAfterEdit: vi.fn().mockReturnValue(true), // Default to true for these tests + } as unknown as Config; + + (mockConfig.getApprovalMode as Mock).mockClear(); + (mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.DEFAULT); + + mockEnsureCorrectEdit.mockReset(); + mockEnsureCorrectEdit.mockImplementation( + async (_, currentContent, params) => { + let occurrences = 0; + if (params.old_string && currentContent) { + let index = currentContent.indexOf(params.old_string); + while (index !== -1) { + occurrences++; + index = currentContent.indexOf(params.old_string, index + 1); + } + } else if (params.old_string === '') { + occurrences = 0; + } + return Promise.resolve({ params, occurrences }); + }, + ); + + mockGenerateJson.mockReset(); + mockGenerateJson.mockImplementation(async () => Promise.resolve({})); + + tool = new EditTool(mockConfig); + }); + + afterEach(() => { + fs.rmSync(tempDir, { recursive: true, force: true }); + }); + + describe('readAfterEdit enabled', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(true); + }); + + it('should include file content in llmContent after successful edit', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is the original content.'; + const newContent = 'This is the modified content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'original', + new_string: 'modified', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).toContain(newContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent); + }); + + it('should include file content in llmContent when creating a new file', async () => { + const newFileName = 'new_file.txt'; + const newFilePath = path.join(rootDir, newFileName); + const fileContent = 'Content for the new file.'; + + const params: EditToolParams = { + file_path: newFilePath, + old_string: '', + new_string: fileContent, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Created new file/); + expect(result.llmContent).toContain(fileContent); + expect(fs.existsSync(newFilePath)).toBe(true); + expect(fs.readFileSync(newFilePath, 'utf8')).toBe(fileContent); + }); + + it('should include file content in llmContent when replacing multiple occurrences', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'old text old text old text'; + const expectedContent = 'new text new text new text'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + expected_replacements: 3, + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).toContain(expectedContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(expectedContent); + }); + + it('should include file content even when user modified the new_string', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is some old text.'; + const newContent = 'This is some new text.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + modified_by_user: true, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch( + /User modified the `new_string` content/, + ); + expect(result.llmContent).toContain(newContent); + }); + }); + + describe('readAfterEdit disabled', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(false); + }); + + it('should NOT include file content in llmContent after successful edit when disabled', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is the original content.'; + const newContent = 'This is the modified content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'original', + new_string: 'modified', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).not.toContain(newContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent); + }); + + it('should NOT include file content when creating a new file and feature is disabled', async () => { + const newFileName = 'new_file.txt'; + const newFilePath = path.join(rootDir, newFileName); + const fileContent = 'Content for the new file.'; + + const params: EditToolParams = { + file_path: newFilePath, + old_string: '', + new_string: fileContent, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Created new file/); + expect(result.llmContent).not.toContain(fileContent); + expect(fs.existsSync(newFilePath)).toBe(true); + expect(fs.readFileSync(newFilePath, 'utf8')).toBe(fileContent); + }); + + it('should NOT include file content when replacing multiple occurrences and feature is disabled', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'old text old text old text'; + const expectedContent = 'new text new text new text'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + expected_replacements: 3, + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).not.toContain(expectedContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(expectedContent); + }); + }); + + describe('Error cases with readAfterEdit', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(true); + }); + + it('should not include file content in llmContent when edit fails', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'Some content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'nonexistent', + new_string: 'replacement', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch( + /0 occurrences found for old_string in/, + ); + expect(result.llmContent).not.toContain(initialContent); // Should not include file content on error + expect(fs.readFileSync(filePath, 'utf8')).toBe(initialContent); // File should be unchanged + }); + + it('should not include file content in llmContent when file already exists during creation', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const existingContent = 'Existing content'; + + fs.writeFileSync(filePath, existingContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: '', + new_string: 'new content', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/File already exists, cannot create/); + expect(result.llmContent).not.toContain(existingContent); // Should not include file content on error + expect(fs.readFileSync(filePath, 'utf8')).toBe(existingContent); // File should be unchanged + }); + }); +}); diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 8d90dfe45..bbbe34de2 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -384,8 +384,13 @@ class EditToolInvocation implements ToolInvocation { ); } + let llmContent = llmSuccessMessageParts.join(' '); + if (this.config.getReadAfterEdit()) { + llmContent += `\n${editData.newContent}`; + } + return { - llmContent: llmSuccessMessageParts.join(' '), + llmContent, returnDisplay: displayResult, }; } catch (error) { From 380afc53cbf41b136614d0c1af7356c6a055a7ce Mon Sep 17 00:00:00 2001 From: "koalazf.99" Date: Tue, 26 Aug 2025 13:18:11 +0800 Subject: [PATCH 04/16] update: use sub-command to switch between project and global memory ops --- .../cli/src/ui/commands/memoryCommand.test.ts | 4 +- packages/cli/src/ui/commands/memoryCommand.ts | 136 +++++++++++++++++- 2 files changed, 133 insertions(+), 7 deletions(-) diff --git a/packages/cli/src/ui/commands/memoryCommand.test.ts b/packages/cli/src/ui/commands/memoryCommand.test.ts index 77281110f..684f61e9e 100644 --- a/packages/cli/src/ui/commands/memoryCommand.test.ts +++ b/packages/cli/src/ui/commands/memoryCommand.test.ts @@ -132,7 +132,7 @@ describe('memoryCommand', () => { expect(mockContext.ui.addItem).toHaveBeenCalledWith( { type: MessageType.INFO, - text: `Attempting to save to memory: "${fact}"`, + text: `Attempting to save to memory : "${fact}"`, }, expect.any(Number), ); @@ -228,7 +228,7 @@ describe('memoryCommand', () => { mockContext = createMockCommandContext({ services: { - config: Promise.resolve(mockConfig), + config: mockConfig, settings: { merged: { memoryDiscoveryMaxDirs: 1000, diff --git a/packages/cli/src/ui/commands/memoryCommand.ts b/packages/cli/src/ui/commands/memoryCommand.ts index 8b742ef31..2c1bda37e 100644 --- a/packages/cli/src/ui/commands/memoryCommand.ts +++ b/packages/cli/src/ui/commands/memoryCommand.ts @@ -7,7 +7,11 @@ import { getErrorMessage, loadServerHierarchicalMemory, + GEMINI_DIR, } from '@qwen-code/qwen-code-core'; +import path from 'node:path'; +import os from 'os'; +import fs from 'fs/promises'; import { MessageType } from '../types.js'; import { CommandKind, @@ -41,6 +45,71 @@ export const memoryCommand: SlashCommand = { Date.now(), ); }, + subCommands: [ + { + name: '--project', + description: 'Show project-level memory contents.', + kind: CommandKind.BUILT_IN, + action: async (context) => { + const memoryContent = + context.services.config?.getUserMemory() || ''; + const fileCount = + context.services.config?.getGeminiMdFileCount() || 0; + + const messageContent = + memoryContent.length > 0 + ? `Project memory content from ${fileCount} file(s):\n\n---\n${memoryContent}\n---` + : 'Project memory is currently empty.'; + + context.ui.addItem( + { + type: MessageType.INFO, + text: messageContent, + }, + Date.now(), + ); + }, + }, + { + name: '--global', + description: 'Show global memory contents.', + kind: CommandKind.BUILT_IN, + action: async (context) => { + try { + const globalMemoryPath = path.join( + os.homedir(), + GEMINI_DIR, + 'QWEN.md', + ); + const globalMemoryContent = await fs.readFile( + globalMemoryPath, + 'utf-8', + ); + + const messageContent = + globalMemoryContent.trim().length > 0 + ? `Global memory content:\n\n---\n${globalMemoryContent}\n---` + : 'Global memory is currently empty.'; + + context.ui.addItem( + { + type: MessageType.INFO, + text: messageContent, + }, + Date.now(), + ); + } catch (_error) { + context.ui.addItem( + { + type: MessageType.INFO, + text: 'Global memory file not found or is currently empty.', + }, + Date.now(), + ); + } + }, + }, + ], }, { name: 'add', @@ -90,22 +159,79 @@ export const memoryCommand: SlashCommand = { }; } - const scopeText = scope ? ` (${scope})` : ''; + const scopeText = scope ? `(${scope})` : ''; context.ui.addItem( { type: MessageType.INFO, - text: `Attempting to save to memory${scopeText}: "${fact}"`, + text: `Attempting to save to memory ${scopeText}: "${fact}"`, }, Date.now(), ); - const toolArgs = scope ? { fact, scope } : { fact }; return { type: 'tool', toolName: 'save_memory', - toolArgs, + toolArgs: scope ? { fact, scope } : { fact }, }; }, + subCommands: [ + { + name: '--project', + description: 'Add content to project-level memory.', + kind: CommandKind.BUILT_IN, + action: (context, args): SlashCommandActionReturn | void => { + if (!args || args.trim() === '') { + return { + type: 'message', + messageType: 'error', + content: 'Usage: /memory add --project ', + }; + } + + context.ui.addItem( + { + type: MessageType.INFO, + text: `Attempting to save to project memory: "${args.trim()}"`, + }, + Date.now(), + ); + + return { + type: 'tool', + toolName: 'save_memory', + toolArgs: { fact: args.trim(), scope: 'project' }, + }; + }, + }, + { + name: '--global', + description: 'Add content to global memory.', + kind: CommandKind.BUILT_IN, + action: (context, args): SlashCommandActionReturn | void => { + if (!args || args.trim() === '') { + return { + type: 'message', + messageType: 'error', + content: 'Usage: /memory add --global ', + }; + } + + context.ui.addItem( + { + type: MessageType.INFO, + text: `Attempting to save to global memory: "${args.trim()}"`, + }, + Date.now(), + ); + + return { + type: 'tool', + toolName: 'save_memory', + toolArgs: { fact: args.trim(), scope: 'global' }, + }; + }, + }, + ], }, { name: 'refresh', @@ -121,7 +247,7 @@ export const memoryCommand: SlashCommand = { ); try { - const config = await context.services.config; + const config = context.services.config; if (config) { const { memoryContent, fileCount } = await loadServerHierarchicalMemory( From 98fd0f6a89e8e0b03d536325afbe80bc2c48fd8c Mon Sep 17 00:00:00 2001 From: pomelo Date: Tue, 26 Aug 2025 15:49:29 +0800 Subject: [PATCH 05/16] feat: update /docs link (#438) --- packages/cli/src/ui/commands/docsCommand.test.ts | 6 +++--- packages/cli/src/ui/commands/docsCommand.ts | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/cli/src/ui/commands/docsCommand.test.ts b/packages/cli/src/ui/commands/docsCommand.test.ts index 083ef5b26..fa66d2c7b 100644 --- a/packages/cli/src/ui/commands/docsCommand.test.ts +++ b/packages/cli/src/ui/commands/docsCommand.test.ts @@ -35,7 +35,7 @@ describe('docsCommand', () => { throw new Error('docsCommand must have an action.'); } - const docsUrl = 'https://github.com/QwenLM/qwen-code'; + const docsUrl = 'https://qwenlm.github.io/qwen-code-docs/en'; await docsCommand.action(mockContext, ''); @@ -57,7 +57,7 @@ describe('docsCommand', () => { // Simulate a sandbox environment process.env.SANDBOX = 'gemini-sandbox'; - const docsUrl = 'https://github.com/QwenLM/qwen-code'; + const docsUrl = 'https://qwenlm.github.io/qwen-code-docs/en'; await docsCommand.action(mockContext, ''); @@ -80,7 +80,7 @@ describe('docsCommand', () => { // Simulate the specific 'sandbox-exec' environment process.env.SANDBOX = 'sandbox-exec'; - const docsUrl = 'https://github.com/QwenLM/qwen-code'; + const docsUrl = 'https://qwenlm.github.io/qwen-code-docs/en'; await docsCommand.action(mockContext, ''); diff --git a/packages/cli/src/ui/commands/docsCommand.ts b/packages/cli/src/ui/commands/docsCommand.ts index 72d988d51..1f409a353 100644 --- a/packages/cli/src/ui/commands/docsCommand.ts +++ b/packages/cli/src/ui/commands/docsCommand.ts @@ -18,7 +18,7 @@ export const docsCommand: SlashCommand = { description: 'open full Qwen Code documentation in your browser', kind: CommandKind.BUILT_IN, action: async (context: CommandContext): Promise => { - const docsUrl = 'https://github.com/QwenLM/qwen-code'; + const docsUrl = 'https://qwenlm.github.io/qwen-code-docs/en'; if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { context.ui.addItem( From 1baf5d795f9aa316d14cd76b127d93a2e410efc7 Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Tue, 26 Aug 2025 16:54:52 +0800 Subject: [PATCH 06/16] Fix GitHub Workflows Configuration Issues (#451) --- .github/workflows/build-and-publish-image.yml | 2 +- .github/workflows/ci.yml | 28 +++++++++++++++---- .github/workflows/e2e.yml | 22 +++++++++++++-- .../gemini-automated-issue-triage.yml | 4 +-- .../gemini-scheduled-issue-triage.yml | 6 ++-- .github/workflows/qwen-code-pr-review.yml | 19 ++++--------- 6 files changed, 55 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build-and-publish-image.yml b/.github/workflows/build-and-publish-image.yml index e25051a04..ab8b85fd5 100644 --- a/.github/workflows/build-and-publish-image.yml +++ b/.github/workflows/build-and-publish-image.yml @@ -46,7 +46,7 @@ jobs: - name: 'Log in to the Container registry' if: |- - ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v')) }} + ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v') || github.event.inputs.publish == 'true') }} uses: 'docker/login-action@v3' # ratchet:exclude with: registry: '${{ env.REGISTRY }}' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c9e5e188f..5bf026028 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -95,10 +95,19 @@ jobs: with: node-version-file: '.nvmrc' cache: 'npm' + cache-dependency-path: 'package-lock.json' + registry-url: 'https://registry.npmjs.org/' + + - name: 'Configure npm for rate limiting' + run: |- + npm config set fetch-retry-mintimeout 20000 + npm config set fetch-retry-maxtimeout 120000 + npm config set fetch-retries 5 + npm config set fetch-timeout 300000 - name: 'Install dependencies' run: |- - npm ci + npm ci --prefer-offline --no-audit --progress=false - name: 'Run formatter check' run: |- @@ -273,14 +282,23 @@ jobs: with: node-version: '${{ matrix.node-version }}' cache: 'npm' + cache-dependency-path: 'package-lock.json' + registry-url: 'https://registry.npmjs.org/' - - name: 'Build project' + - name: 'Configure npm for rate limiting' run: |- - npm run build + npm config set fetch-retry-mintimeout 20000 + npm config set fetch-retry-maxtimeout 120000 + npm config set fetch-retries 5 + npm config set fetch-timeout 300000 - - name: 'Install dependencies for testing' + - name: 'Install dependencies' run: |- - npm ci + npm ci --prefer-offline --no-audit --progress=false + + - name: 'Build project' + run: |- + npm run build - name: 'Run tests and generate reports' env: diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 3c778f1a0..077ba92e3 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -28,10 +28,19 @@ jobs: with: node-version: '${{ matrix.node-version }}' cache: 'npm' + cache-dependency-path: 'package-lock.json' + registry-url: 'https://registry.npmjs.org/' + + - name: 'Configure npm for rate limiting' + run: |- + npm config set fetch-retry-mintimeout 20000 + npm config set fetch-retry-maxtimeout 120000 + npm config set fetch-retries 5 + npm config set fetch-timeout 300000 - name: 'Install dependencies' run: |- - npm ci + npm ci --prefer-offline --no-audit --progress=false - name: 'Build project' run: |- @@ -74,10 +83,19 @@ jobs: with: node-version-file: '.nvmrc' cache: 'npm' + cache-dependency-path: 'package-lock.json' + registry-url: 'https://registry.npmjs.org/' + + - name: 'Configure npm for rate limiting' + run: |- + npm config set fetch-retry-mintimeout 20000 + npm config set fetch-retry-maxtimeout 120000 + npm config set fetch-retries 5 + npm config set fetch-timeout 300000 - name: 'Install dependencies' run: |- - npm ci + npm ci --prefer-offline --no-audit --progress=false - name: 'Build project' run: |- diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index a4c5e2518..96d71b7b1 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -48,7 +48,7 @@ jobs: OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}' OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' - settings_json: | + settings_json: |- { "maxSessionTurns": 25, "coreTools": [ @@ -68,7 +68,7 @@ jobs: ## Steps 1. Run: `gh label list --repo ${{ github.repository }} --limit 100` to get all available labels. - 2. Use right tool to review the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". + 2. Use shell command `echo` to check the issue title and body provided in the environment variables: "${ISSUE_TITLE}" and "${ISSUE_BODY}". 3. Ignore any existing priorities or tags on the issue. Just report your findings. 4. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. For area/* and kind/* limit yourself to only the single most applicable label in each case. 6. Apply the selected labels to this issue using: `gh issue edit ${{ github.event.issue.number }} --repo ${{ github.repository }} --add-label "label1,label2"`. diff --git a/.github/workflows/gemini-scheduled-issue-triage.yml b/.github/workflows/gemini-scheduled-issue-triage.yml index 7711a329d..69d7eb71b 100644 --- a/.github/workflows/gemini-scheduled-issue-triage.yml +++ b/.github/workflows/gemini-scheduled-issue-triage.yml @@ -36,7 +36,7 @@ jobs: env: GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}' GITHUB_REPOSITORY: '${{ github.repository }}' - run: | + run: |- echo "🔍 Finding issues without labels..." NO_LABEL_ISSUES=$(gh issue list --repo ${{ github.repository }} --search "is:open is:issue no:label" --json number,title,body) @@ -66,7 +66,7 @@ jobs: OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}' OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' - settings_json: | + settings_json: |- { "maxSessionTurns": 25, "coreTools": [ @@ -88,7 +88,7 @@ jobs: ## Steps 1. Run: `gh label list --repo ${{ github.repository }} --limit 100` to get all available labels. - 2. Use right tool to check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) + 2. Use shell command `echo` to check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues) 3. Review the issue title, body and any comments provided in the environment variables. 4. Ignore any existing priorities or tags on the issue. 5. Select the most relevant labels from the existing labels, focusing on kind/*, area/*, sub-area/* and priority/*. diff --git a/.github/workflows/qwen-code-pr-review.yml b/.github/workflows/qwen-code-pr-review.yml index 1efae14b1..6d7f0934f 100644 --- a/.github/workflows/qwen-code-pr-review.yml +++ b/.github/workflows/qwen-code-pr-review.yml @@ -16,7 +16,7 @@ on: jobs: review-pr: - if: > + if: |- github.event_name == 'workflow_dispatch' || (github.event_name == 'pull_request_target' && github.event.action == 'opened' && @@ -59,7 +59,7 @@ jobs: ${{ github.event_name == 'pull_request_target' || github.event_name == 'workflow_dispatch' }} env: GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}' - run: | + run: |- if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then PR_NUMBER=${{ github.event.inputs.pr_number }} else @@ -82,7 +82,7 @@ jobs: env: GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}' COMMENT_BODY: '${{ github.event.comment.body }}' - run: | + run: |- PR_NUMBER=${{ github.event.issue.number }} echo "pr_number=$PR_NUMBER" >> "$GITHUB_OUTPUT" # Extract additional instructions from comment @@ -110,22 +110,15 @@ jobs: OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}' OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' - settings_json: | + settings_json: |- { "coreTools": [ - "run_shell_command(echo)", - "run_shell_command(gh pr view)", - "run_shell_command(gh pr diff)", - "run_shell_command(gh pr comment)", - "run_shell_command(cat)", - "run_shell_command(head)", - "run_shell_command(tail)", - "run_shell_command(grep)", + "run_shell_command", "write_file" ], "sandbox": false } - prompt: | + prompt: |- You are an expert code reviewer. You have access to shell commands to gather PR information and perform the review. IMPORTANT: Use the available shell commands to gather information. Do not ask for information to be provided. From 472df045d366217e23d7cf6eddd0b4690930f353 Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 26 Aug 2025 17:01:09 +0800 Subject: [PATCH 07/16] Fix parallel tool use (#400) --- .../core/src/core/openaiContentGenerator.ts | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index 8f05f2e00..be681b0a5 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -1351,7 +1351,9 @@ export class OpenAIContentGenerator implements ContentGenerator { // Handle text content if (choice.delta?.content) { - parts.push({ text: choice.delta.content }); + if (typeof choice.delta.content === 'string') { + parts.push({ text: choice.delta.content }); + } } // Handle tool calls - only accumulate during streaming, emit when complete @@ -1371,10 +1373,36 @@ export class OpenAIContentGenerator implements ContentGenerator { accumulatedCall.id = toolCall.id; } if (toolCall.function?.name) { + // If this is a new function name, reset the arguments + if (accumulatedCall.name !== toolCall.function.name) { + accumulatedCall.arguments = ''; + } accumulatedCall.name = toolCall.function.name; } if (toolCall.function?.arguments) { - accumulatedCall.arguments += toolCall.function.arguments; + // Check if we already have a complete JSON object + const currentArgs = accumulatedCall.arguments; + const newArgs = toolCall.function.arguments; + + // If current arguments already form a complete JSON and new arguments start a new object, + // this indicates a new tool call with the same name + let shouldReset = false; + if (currentArgs && newArgs.trim().startsWith('{')) { + try { + JSON.parse(currentArgs); + // If we can parse current arguments as complete JSON and new args start with {, + // this is likely a new tool call + shouldReset = true; + } catch { + // Current arguments are not complete JSON, continue accumulating + } + } + + if (shouldReset) { + accumulatedCall.arguments = newArgs; + } else { + accumulatedCall.arguments += newArgs; + } } } } @@ -1562,7 +1590,7 @@ export class OpenAIContentGenerator implements ContentGenerator { } } - messageContent = textParts.join(''); + messageContent = textParts.join('').trimEnd(); } const choice: OpenAIChoice = { From f73d662260c1b94f0405d3ef76d5238a47676412 Mon Sep 17 00:00:00 2001 From: "koalazf.99" Date: Tue, 26 Aug 2025 19:44:02 +0800 Subject: [PATCH 08/16] update: remove context.services.config?.getUserMemory() logic from project level memory show --- packages/cli/src/ui/commands/memoryCommand.ts | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/packages/cli/src/ui/commands/memoryCommand.ts b/packages/cli/src/ui/commands/memoryCommand.ts index 2c1bda37e..ec778f7a4 100644 --- a/packages/cli/src/ui/commands/memoryCommand.ts +++ b/packages/cli/src/ui/commands/memoryCommand.ts @@ -51,23 +51,34 @@ export const memoryCommand: SlashCommand = { description: 'Show project-level memory contents.', kind: CommandKind.BUILT_IN, action: async (context) => { - const memoryContent = - context.services.config?.getUserMemory() || ''; - const fileCount = - context.services.config?.getGeminiMdFileCount() || 0; + try { + const projectMemoryPath = path.join(process.cwd(), 'QWEN.md'); + const memoryContent = await fs.readFile( + projectMemoryPath, + 'utf-8', + ); - const messageContent = - memoryContent.length > 0 - ? `Project memory content from ${fileCount} file(s):\n\n---\n${memoryContent}\n---` - : 'Project memory is currently empty.'; + const messageContent = + memoryContent.trim().length > 0 + ? `Project memory content from ${projectMemoryPath}:\n\n---\n${memoryContent}\n---` + : 'Project memory is currently empty.'; - context.ui.addItem( - { - type: MessageType.INFO, - text: messageContent, - }, - Date.now(), - ); + context.ui.addItem( + { + type: MessageType.INFO, + text: messageContent, + }, + Date.now(), + ); + } catch (_error) { + context.ui.addItem( + { + type: MessageType.INFO, + text: 'Project memory file not found or is currently empty.', + }, + Date.now(), + ); + } }, }, { From 5cd33497738a0d8187af8f1503fab91f6be18638 Mon Sep 17 00:00:00 2001 From: "koalazf.99" Date: Tue, 26 Aug 2025 20:14:10 +0800 Subject: [PATCH 09/16] rename GEMINI_DIR to QWEN_DIR --- packages/cli/src/ui/commands/memoryCommand.ts | 4 ++-- packages/core/src/services/gitService.test.ts | 4 ++-- packages/core/src/services/gitService.ts | 4 ++-- packages/core/src/utils/paths.ts | 8 ++++---- packages/core/src/utils/user_account.ts | 4 ++-- packages/core/src/utils/user_id.ts | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/packages/cli/src/ui/commands/memoryCommand.ts b/packages/cli/src/ui/commands/memoryCommand.ts index ec778f7a4..7e3732523 100644 --- a/packages/cli/src/ui/commands/memoryCommand.ts +++ b/packages/cli/src/ui/commands/memoryCommand.ts @@ -7,7 +7,7 @@ import { getErrorMessage, loadServerHierarchicalMemory, - GEMINI_DIR, + QWEN_DIR, } from '@qwen-code/qwen-code-core'; import path from 'node:path'; import os from 'os'; @@ -89,7 +89,7 @@ export const memoryCommand: SlashCommand = { try { const globalMemoryPath = path.join( os.homedir(), - GEMINI_DIR, + QWEN_DIR, 'QWEN.md', ); const globalMemoryContent = await fs.readFile( diff --git a/packages/core/src/services/gitService.test.ts b/packages/core/src/services/gitService.test.ts index 9820ba5fc..3e1e63f76 100644 --- a/packages/core/src/services/gitService.test.ts +++ b/packages/core/src/services/gitService.test.ts @@ -10,7 +10,7 @@ import * as path from 'path'; import * as fs from 'fs/promises'; import * as os from 'os'; import type { ChildProcess } from 'node:child_process'; -import { getProjectHash, GEMINI_DIR } from '../utils/paths.js'; +import { getProjectHash, QWEN_DIR } from '../utils/paths.js'; const hoistedMockExec = vi.hoisted(() => vi.fn()); vi.mock('node:child_process', () => ({ @@ -157,7 +157,7 @@ describe('GitService', () => { let gitConfigPath: string; beforeEach(() => { - repoDir = path.join(homedir, GEMINI_DIR, 'history', hash); + repoDir = path.join(homedir, QWEN_DIR, 'history', hash); gitConfigPath = path.join(repoDir, '.gitconfig'); }); diff --git a/packages/core/src/services/gitService.ts b/packages/core/src/services/gitService.ts index 8b3fe46ff..30f67cf7c 100644 --- a/packages/core/src/services/gitService.ts +++ b/packages/core/src/services/gitService.ts @@ -10,7 +10,7 @@ import * as os from 'os'; import { isNodeError } from '../utils/errors.js'; import { exec } from 'node:child_process'; import { simpleGit, SimpleGit, CheckRepoActions } from 'simple-git'; -import { getProjectHash, GEMINI_DIR } from '../utils/paths.js'; +import { getProjectHash, QWEN_DIR } from '../utils/paths.js'; export class GitService { private projectRoot: string; @@ -21,7 +21,7 @@ export class GitService { private getHistoryDir(): string { const hash = getProjectHash(this.projectRoot); - return path.join(os.homedir(), GEMINI_DIR, 'history', hash); + return path.join(os.homedir(), QWEN_DIR, 'history', hash); } async initialize(): Promise { diff --git a/packages/core/src/utils/paths.ts b/packages/core/src/utils/paths.ts index 52c578cd0..2b512c47d 100644 --- a/packages/core/src/utils/paths.ts +++ b/packages/core/src/utils/paths.ts @@ -8,7 +8,7 @@ import path from 'node:path'; import os from 'os'; import * as crypto from 'crypto'; -export const GEMINI_DIR = '.qwen'; +export const QWEN_DIR = '.qwen'; export const GOOGLE_ACCOUNTS_FILENAME = 'google_accounts.json'; const TMP_DIR_NAME = 'tmp'; const COMMANDS_DIR_NAME = 'commands'; @@ -181,7 +181,7 @@ export function getProjectHash(projectRoot: string): string { */ export function getProjectTempDir(projectRoot: string): string { const hash = getProjectHash(projectRoot); - return path.join(os.homedir(), GEMINI_DIR, TMP_DIR_NAME, hash); + return path.join(os.homedir(), QWEN_DIR, TMP_DIR_NAME, hash); } /** @@ -189,7 +189,7 @@ export function getProjectTempDir(projectRoot: string): string { * @returns The path to the user's commands directory. */ export function getUserCommandsDir(): string { - return path.join(os.homedir(), GEMINI_DIR, COMMANDS_DIR_NAME); + return path.join(os.homedir(), QWEN_DIR, COMMANDS_DIR_NAME); } /** @@ -198,5 +198,5 @@ export function getUserCommandsDir(): string { * @returns The path to the project's commands directory. */ export function getProjectCommandsDir(projectRoot: string): string { - return path.join(projectRoot, GEMINI_DIR, COMMANDS_DIR_NAME); + return path.join(projectRoot, QWEN_DIR, COMMANDS_DIR_NAME); } diff --git a/packages/core/src/utils/user_account.ts b/packages/core/src/utils/user_account.ts index 6701dfe3d..4f788fcec 100644 --- a/packages/core/src/utils/user_account.ts +++ b/packages/core/src/utils/user_account.ts @@ -7,7 +7,7 @@ import path from 'node:path'; import { promises as fsp, existsSync, readFileSync } from 'node:fs'; import * as os from 'os'; -import { GEMINI_DIR, GOOGLE_ACCOUNTS_FILENAME } from './paths.js'; +import { QWEN_DIR, GOOGLE_ACCOUNTS_FILENAME } from './paths.js'; interface UserAccounts { active: string | null; @@ -15,7 +15,7 @@ interface UserAccounts { } function getGoogleAccountsCachePath(): string { - return path.join(os.homedir(), GEMINI_DIR, GOOGLE_ACCOUNTS_FILENAME); + return path.join(os.homedir(), QWEN_DIR, GOOGLE_ACCOUNTS_FILENAME); } async function readAccounts(filePath: string): Promise { diff --git a/packages/core/src/utils/user_id.ts b/packages/core/src/utils/user_id.ts index 6f16806f6..50e73efe7 100644 --- a/packages/core/src/utils/user_id.ts +++ b/packages/core/src/utils/user_id.ts @@ -8,10 +8,10 @@ import * as os from 'os'; import * as fs from 'fs'; import * as path from 'path'; import { randomUUID } from 'crypto'; -import { GEMINI_DIR } from './paths.js'; +import { QWEN_DIR } from './paths.js'; const homeDir = os.homedir() ?? ''; -const geminiDir = path.join(homeDir, GEMINI_DIR); +const geminiDir = path.join(homeDir, QWEN_DIR); const installationIdFile = path.join(geminiDir, 'installation_id'); function ensureGeminiDirExists() { From 1b38f96eaa18d8b56dd4a61775f64f5dca299bc4 Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Tue, 26 Aug 2025 20:22:44 +0800 Subject: [PATCH 10/16] fix: early stop on invalid tool call (#458) --- packages/cli/src/ui/hooks/useGeminiStream.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2207b9c9c..fcd6ba500 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -644,8 +644,9 @@ export const useGeminiStream = ( options?: { isContinuation: boolean }, prompt_id?: string, ) => { - // Prevent concurrent executions of submitQuery - if (isSubmittingQueryRef.current) { + // Prevent concurrent executions of submitQuery, but allow continuations + // which are part of the same logical flow (tool responses) + if (isSubmittingQueryRef.current && !options?.isContinuation) { return; } From de279b56f357b6197320c52924421c23a158bc98 Mon Sep 17 00:00:00 2001 From: Mingholy Date: Wed, 27 Aug 2025 11:32:48 +0800 Subject: [PATCH 11/16] fix: add explicit is_background param for shell tool (#445) * fix: add explicit background param for shell tool * fix: explicit param schema * docs(shelltool): update `is_background` description --- docs/tools/shell.md | 69 ++++++++-- packages/core/src/tools/shell.test.ts | 180 ++++++++++++++++++++++---- packages/core/src/tools/shell.ts | 45 ++++++- 3 files changed, 254 insertions(+), 40 deletions(-) diff --git a/docs/tools/shell.md b/docs/tools/shell.md index 77f6b026f..9bd82a397 100644 --- a/docs/tools/shell.md +++ b/docs/tools/shell.md @@ -13,10 +13,39 @@ Use `run_shell_command` to interact with the underlying system, run scripts, or - `command` (string, required): The exact shell command to execute. - `description` (string, optional): A brief description of the command's purpose, which will be shown to the user. - `directory` (string, optional): The directory (relative to the project root) in which to execute the command. If not provided, the command runs in the project root. +- `is_background` (boolean, required): Whether to run the command in background. This parameter is required to ensure explicit decision-making about command execution mode. Set to true for long-running processes like development servers, watchers, or daemons that should continue running without blocking further commands. Set to false for one-time commands that should complete before proceeding. ## How to use `run_shell_command` with Qwen Code -When using `run_shell_command`, the command is executed as a subprocess. `run_shell_command` can start background processes using `&`. The tool returns detailed information about the execution, including: +When using `run_shell_command`, the command is executed as a subprocess. You can control whether commands run in background or foreground using the `is_background` parameter, or by explicitly adding `&` to commands. The tool returns detailed information about the execution, including: + +### Required Background Parameter + +The `is_background` parameter is **required** for all command executions. This design ensures that the LLM (and users) must explicitly decide whether each command should run in the background or foreground, promoting intentional and predictable command execution behavior. By making this parameter mandatory, we avoid unintended fallback to foreground execution, which could block subsequent operations when dealing with long-running processes. + +### Background vs Foreground Execution + +The tool intelligently handles background and foreground execution based on your explicit choice: + +**Use background execution (`is_background: true`) for:** + +- Long-running development servers: `npm run start`, `npm run dev`, `yarn dev` +- Build watchers: `npm run watch`, `webpack --watch` +- Database servers: `mongod`, `mysql`, `redis-server` +- Web servers: `python -m http.server`, `php -S localhost:8000` +- Any command expected to run indefinitely until manually stopped + +**Use foreground execution (`is_background: false`) for:** + +- One-time commands: `ls`, `cat`, `grep` +- Build commands: `npm run build`, `make` +- Installation commands: `npm install`, `pip install` +- Git operations: `git commit`, `git push` +- Test runs: `npm test`, `pytest` + +### Execution Information + +The tool returns detailed information about the execution, including: - `Command`: The command that was executed. - `Directory`: The directory where the command was run. @@ -29,28 +58,48 @@ When using `run_shell_command`, the command is executed as a subprocess. `run_sh Usage: +```bash +run_shell_command(command="Your commands.", description="Your description of the command.", directory="Your execution directory.", is_background=false) ``` -run_shell_command(command="Your commands.", description="Your description of the command.", directory="Your execution directory.") -``` + +**Note:** The `is_background` parameter is required and must be explicitly specified for every command execution. ## `run_shell_command` examples List files in the current directory: -``` -run_shell_command(command="ls -la") +```bash +run_shell_command(command="ls -la", is_background=false) ``` Run a script in a specific directory: +```bash +run_shell_command(command="./my_script.sh", directory="scripts", description="Run my custom script", is_background=false) +``` + +Start a background development server (recommended approach): + +```bash +run_shell_command(command="npm run dev", description="Start development server in background", is_background=true) ``` -run_shell_command(command="./my_script.sh", directory="scripts", description="Run my custom script") + +Start a background server (alternative with explicit &): + +```bash +run_shell_command(command="npm run dev &", description="Start development server in background", is_background=false) ``` -Start a background server: +Run a build command in foreground: +```bash +run_shell_command(command="npm run build", description="Build the project", is_background=false) ``` -run_shell_command(command="npm run dev &", description="Start development server in background") + +Start multiple background services: + +```bash +run_shell_command(command="docker-compose up", description="Start all services", is_background=true) ``` ## Important notes @@ -58,7 +107,9 @@ run_shell_command(command="npm run dev &", description="Start development server - **Security:** Be cautious when executing commands, especially those constructed from user input, to prevent security vulnerabilities. - **Interactive commands:** Avoid commands that require interactive user input, as this can cause the tool to hang. Use non-interactive flags if available (e.g., `npm init -y`). - **Error handling:** Check the `Stderr`, `Error`, and `Exit Code` fields to determine if a command executed successfully. -- **Background processes:** When a command is run in the background with `&`, the tool will return immediately and the process will continue to run in the background. The `Background PIDs` field will contain the process ID of the background process. +- **Background processes:** When `is_background=true` or when a command contains `&`, the tool will return immediately and the process will continue to run in the background. The `Background PIDs` field will contain the process ID of the background process. +- **Background execution choices:** The `is_background` parameter is required and provides explicit control over execution mode. You can also add `&` to the command for manual background execution, but the `is_background` parameter must still be specified. The parameter provides clearer intent and automatically handles the background execution setup. +- **Command descriptions:** When using `is_background=true`, the command description will include a `[background]` indicator to clearly show the execution mode. ## Environment Variables diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index 2bd31dadb..3023b21a1 100644 --- a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -99,24 +99,47 @@ describe('ShellTool', () => { describe('build', () => { it('should return an invocation for a valid command', () => { - const invocation = shellTool.build({ command: 'ls -l' }); + const invocation = shellTool.build({ + command: 'ls -l', + is_background: false, + }); expect(invocation).toBeDefined(); }); it('should throw an error for an empty command', () => { - expect(() => shellTool.build({ command: ' ' })).toThrow( - 'Command cannot be empty.', - ); + expect(() => + shellTool.build({ command: ' ', is_background: false }), + ).toThrow('Command cannot be empty.'); }); it('should throw an error for a non-existent directory', () => { vi.mocked(fs.existsSync).mockReturnValue(false); expect(() => - shellTool.build({ command: 'ls', directory: 'rel/path' }), + shellTool.build({ + command: 'ls', + directory: 'rel/path', + is_background: false, + }), ).toThrow( "Directory 'rel/path' is not a registered workspace directory.", ); }); + + it('should include background indicator in description when is_background is true', () => { + const invocation = shellTool.build({ + command: 'npm start', + is_background: true, + }); + expect(invocation.getDescription()).toContain('[background]'); + }); + + it('should not include background indicator in description when is_background is false', () => { + const invocation = shellTool.build({ + command: 'npm test', + is_background: false, + }); + expect(invocation.getDescription()).not.toContain('[background]'); + }); }); describe('execute', () => { @@ -141,7 +164,10 @@ describe('ShellTool', () => { }; it('should wrap command on linux and parse pgrep output', async () => { - const invocation = shellTool.build({ command: 'my-command &' }); + const invocation = shellTool.build({ + command: 'my-command &', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal); resolveShellExecution({ pid: 54321 }); @@ -162,9 +188,81 @@ describe('ShellTool', () => { expect(vi.mocked(fs.unlinkSync)).toHaveBeenCalledWith(tmpFile); }); + it('should add ampersand to command when is_background is true and command does not end with &', async () => { + const invocation = shellTool.build({ + command: 'npm start', + is_background: true, + }); + const promise = invocation.execute(mockAbortSignal); + resolveShellExecution({ pid: 54321 }); + + vi.mocked(fs.existsSync).mockReturnValue(true); + vi.mocked(fs.readFileSync).mockReturnValue('54321\n54322\n'); + + await promise; + + const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp'); + const wrappedCommand = `{ npm start & }; __code=$?; pgrep -g 0 >${tmpFile} 2>&1; exit $__code;`; + expect(mockShellExecutionService).toHaveBeenCalledWith( + wrappedCommand, + expect.any(String), + expect.any(Function), + mockAbortSignal, + ); + }); + + it('should not add extra ampersand when is_background is true and command already ends with &', async () => { + const invocation = shellTool.build({ + command: 'npm start &', + is_background: true, + }); + const promise = invocation.execute(mockAbortSignal); + resolveShellExecution({ pid: 54321 }); + + vi.mocked(fs.existsSync).mockReturnValue(true); + vi.mocked(fs.readFileSync).mockReturnValue('54321\n54322\n'); + + await promise; + + const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp'); + const wrappedCommand = `{ npm start & }; __code=$?; pgrep -g 0 >${tmpFile} 2>&1; exit $__code;`; + expect(mockShellExecutionService).toHaveBeenCalledWith( + wrappedCommand, + expect.any(String), + expect.any(Function), + mockAbortSignal, + ); + }); + + it('should not add ampersand when is_background is false', async () => { + const invocation = shellTool.build({ + command: 'npm test', + is_background: false, + }); + const promise = invocation.execute(mockAbortSignal); + resolveShellExecution({ pid: 54321 }); + + vi.mocked(fs.existsSync).mockReturnValue(true); + vi.mocked(fs.readFileSync).mockReturnValue('54321\n54322\n'); + + await promise; + + const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp'); + const wrappedCommand = `{ npm test; }; __code=$?; pgrep -g 0 >${tmpFile} 2>&1; exit $__code;`; + expect(mockShellExecutionService).toHaveBeenCalledWith( + wrappedCommand, + expect.any(String), + expect.any(Function), + mockAbortSignal, + ); + }); + it('should not wrap command on windows', async () => { vi.mocked(os.platform).mockReturnValue('win32'); - const invocation = shellTool.build({ command: 'dir' }); + const invocation = shellTool.build({ + command: 'dir', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal); resolveShellExecution({ rawOutput: Buffer.from(''), @@ -188,7 +286,10 @@ describe('ShellTool', () => { it('should format error messages correctly', async () => { const error = new Error('wrapped command failed'); - const invocation = shellTool.build({ command: 'user-command' }); + const invocation = shellTool.build({ + command: 'user-command', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal); resolveShellExecution({ error, @@ -209,15 +310,19 @@ describe('ShellTool', () => { }); it('should throw an error for invalid parameters', () => { - expect(() => shellTool.build({ command: '' })).toThrow( - 'Command cannot be empty.', - ); + expect(() => + shellTool.build({ command: '', is_background: false }), + ).toThrow('Command cannot be empty.'); }); it('should throw an error for invalid directory', () => { vi.mocked(fs.existsSync).mockReturnValue(false); expect(() => - shellTool.build({ command: 'ls', directory: 'nonexistent' }), + shellTool.build({ + command: 'ls', + directory: 'nonexistent', + is_background: false, + }), ).toThrow( `Directory 'nonexistent' is not a registered workspace directory.`, ); @@ -231,7 +336,10 @@ describe('ShellTool', () => { 'summarized output', ); - const invocation = shellTool.build({ command: 'ls' }); + const invocation = shellTool.build({ + command: 'ls', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ output: 'long output', @@ -264,7 +372,10 @@ describe('ShellTool', () => { }); vi.mocked(fs.existsSync).mockReturnValue(true); // Pretend the file exists - const invocation = shellTool.build({ command: 'a-command' }); + const invocation = shellTool.build({ + command: 'a-command', + is_background: false, + }); await expect(invocation.execute(mockAbortSignal)).rejects.toThrow(error); const tmpFile = path.join(os.tmpdir(), 'shell_pgrep_abcdef.tmp'); @@ -282,7 +393,10 @@ describe('ShellTool', () => { }); it('should throttle text output updates', async () => { - const invocation = shellTool.build({ command: 'stream' }); + const invocation = shellTool.build({ + command: 'stream', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal, updateOutputMock); // First chunk, should be throttled. @@ -322,7 +436,10 @@ describe('ShellTool', () => { }); it('should immediately show binary detection message and throttle progress', async () => { - const invocation = shellTool.build({ command: 'cat img' }); + const invocation = shellTool.build({ + command: 'cat img', + is_background: false, + }); const promise = invocation.execute(mockAbortSignal, updateOutputMock); mockShellOutputCallback({ type: 'binary_detected' }); @@ -370,7 +487,7 @@ describe('ShellTool', () => { describe('addCoAuthorToGitCommit', () => { it('should add co-author to git commit with double quotes', async () => { const command = 'git commit -m "Initial commit"'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); // Mock the shell execution to return success @@ -401,7 +518,7 @@ describe('ShellTool', () => { it('should add co-author to git commit with single quotes', async () => { const command = "git commit -m 'Fix bug'"; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -430,7 +547,7 @@ describe('ShellTool', () => { it('should handle git commit with additional flags', async () => { const command = 'git commit -a -m "Add feature"'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -459,7 +576,7 @@ describe('ShellTool', () => { it('should not modify non-git commands', async () => { const command = 'npm install'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -487,7 +604,7 @@ describe('ShellTool', () => { it('should not modify git commands without -m flag', async () => { const command = 'git commit'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -515,7 +632,7 @@ describe('ShellTool', () => { it('should handle git commit with escaped quotes in message', async () => { const command = 'git commit -m "Fix \\"quoted\\" text"'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -551,7 +668,7 @@ describe('ShellTool', () => { }); const command = 'git commit -m "Initial commit"'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -586,7 +703,7 @@ describe('ShellTool', () => { }); const command = 'git commit -m "Test commit"'; - const invocation = shellTool.build({ command }); + const invocation = shellTool.build({ command, is_background: false }); const promise = invocation.execute(mockAbortSignal); resolveExecutionPromise({ @@ -617,7 +734,7 @@ describe('ShellTool', () => { describe('shouldConfirmExecute', () => { it('should request confirmation for a new command and whitelist it on "Always"', async () => { - const params = { command: 'npm install' }; + const params = { command: 'npm install', is_background: false }; const invocation = shellTool.build(params); const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -632,7 +749,10 @@ describe('ShellTool', () => { ); // Should now be whitelisted - const secondInvocation = shellTool.build({ command: 'npm test' }); + const secondInvocation = shellTool.build({ + command: 'npm test', + is_background: false, + }); const secondConfirmation = await secondInvocation.shouldConfirmExecute( new AbortController().signal, ); @@ -640,7 +760,9 @@ describe('ShellTool', () => { }); it('should throw an error if validation fails', () => { - expect(() => shellTool.build({ command: '' })).toThrow(); + expect(() => + shellTool.build({ command: '', is_background: false }), + ).toThrow(); }); }); }); @@ -658,6 +780,7 @@ describe('validateToolParams', () => { const result = shellTool.validateToolParams({ command: 'ls', directory: 'test', + is_background: false, }); expect(result).toBeNull(); }); @@ -674,6 +797,7 @@ describe('validateToolParams', () => { const result = shellTool.validateToolParams({ command: 'ls', directory: 'test2', + is_background: false, }); expect(result).toContain('is not a registered workspace directory'); }); @@ -692,6 +816,7 @@ describe('build', () => { const invocation = shellTool.build({ command: 'ls', directory: 'test', + is_background: false, }); expect(invocation).toBeDefined(); }); @@ -709,6 +834,7 @@ describe('build', () => { shellTool.build({ command: 'ls', directory: 'test2', + is_background: false, }), ).toThrow('is not a registered workspace directory'); }); diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index fc91d3f3b..77a8bdfc4 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -37,6 +37,7 @@ export const OUTPUT_UPDATE_INTERVAL_MS = 1000; export interface ShellToolParams { command: string; + is_background: boolean; description?: string; directory?: string; } @@ -60,6 +61,10 @@ class ShellToolInvocation extends BaseToolInvocation< if (this.params.directory) { description += ` [in ${this.params.directory}]`; } + // append background indicator + if (this.params.is_background) { + description += ` [background]`; + } // append optional (description), replacing any line breaks with spaces if (this.params.description) { description += ` (${this.params.description.replace(/\n/g, ' ')})`; @@ -117,12 +122,20 @@ class ShellToolInvocation extends BaseToolInvocation< // Add co-author to git commit commands const processedCommand = this.addCoAuthorToGitCommit(strippedCommand); + const shouldRunInBackground = this.params.is_background; + let finalCommand = processedCommand; + + // If explicitly marked as background and doesn't already end with &, add it + if (shouldRunInBackground && !finalCommand.trim().endsWith('&')) { + finalCommand = finalCommand.trim() + ' &'; + } + // pgrep is not available on Windows, so we can't get background PIDs const commandToExecute = isWindows - ? processedCommand + ? finalCommand : (() => { // wrap command to append subprocess pids (via pgrep) to temporary file - let command = processedCommand.trim(); + let command = finalCommand.trim(); if (!command.endsWith('&')) command += ';'; return `{ ${command} }; __code=$?; pgrep -g 0 >${tempFilePath} 2>&1; exit $__code;`; })(); @@ -343,7 +356,26 @@ export class ShellTool extends BaseDeclarativeTool< super( ShellTool.Name, 'Shell', - `This tool executes a given shell command as \`bash -c \`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`. + `This tool executes a given shell command as \`bash -c \`. + + **Background vs Foreground Execution:** + You should decide whether commands should run in background or foreground based on their nature: + + **Use background execution (is_background: true) for:** + - Long-running development servers: \`npm run start\`, \`npm run dev\`, \`yarn dev\`, \`bun run start\` + - Build watchers: \`npm run watch\`, \`webpack --watch\` + - Database servers: \`mongod\`, \`mysql\`, \`redis-server\` + - Web servers: \`python -m http.server\`, \`php -S localhost:8000\` + - Any command expected to run indefinitely until manually stopped + + **Use foreground execution (is_background: false) for:** + - One-time commands: \`ls\`, \`cat\`, \`grep\` + - Build commands: \`npm run build\`, \`make\` + - Installation commands: \`npm install\`, \`pip install\` + - Git operations: \`git commit\`, \`git push\` + - Test runs: \`npm test\`, \`pytest\` + + Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`. The following information is returned: @@ -364,6 +396,11 @@ export class ShellTool extends BaseDeclarativeTool< type: 'string', description: 'Exact bash command to execute as `bash -c `', }, + is_background: { + type: 'boolean', + description: + 'Whether to run the command in background. Default is false. Set to true for long-running processes like development servers, watchers, or daemons that should continue running without blocking further commands.', + }, description: { type: 'string', description: @@ -375,7 +412,7 @@ export class ShellTool extends BaseDeclarativeTool< '(OPTIONAL) Directory to run the command in, if not the project root directory. Must be relative to the project root directory and must already exist.', }, }, - required: ['command'], + required: ['command', 'is_background'], }, false, // output is not markdown true, // output can be updated From 009e083b73b55e56f1d64e9259d6a8d4d08c25f8 Mon Sep 17 00:00:00 2001 From: Mingholy Date: Wed, 27 Aug 2025 13:17:28 +0800 Subject: [PATCH 12/16] fix: sync token among multiple qwen sessions (#443) * fix: sync token among multiple qwen sessions * fix: adjust cleanup function --- .vscode/launch.json | 19 + packages/cli/src/ui/components/AuthDialog.tsx | 4 - .../cli/src/utils/installationInfo.test.ts | 4 +- packages/cli/src/utils/installationInfo.ts | 7 +- .../cli/src/validateNonInterActiveAuth.ts | 3 - .../src/qwen/qwenContentGenerator.test.ts | 840 +++++++++++- .../core/src/qwen/qwenContentGenerator.ts | 328 ++--- packages/core/src/qwen/qwenOAuth2.test.ts | 1188 +++++++++++++++-- packages/core/src/qwen/qwenOAuth2.ts | 181 ++- .../core/src/qwen/sharedTokenManager.test.ts | 758 +++++++++++ packages/core/src/qwen/sharedTokenManager.ts | 662 +++++++++ 11 files changed, 3562 insertions(+), 432 deletions(-) create mode 100644 packages/core/src/qwen/sharedTokenManager.test.ts create mode 100644 packages/core/src/qwen/sharedTokenManager.ts diff --git a/.vscode/launch.json b/.vscode/launch.json index 72d16ce1c..880de0bbf 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -67,6 +67,19 @@ "console": "integratedTerminal", "internalConsoleOptions": "neverOpen", "skipFiles": ["/**"] + }, + { + "type": "node", + "request": "launch", + "name": "Launch CLI Non-Interactive", + "runtimeExecutable": "npm", + "runtimeArgs": ["run", "start", "--", "-p", "${input:prompt}", "-y"], + "skipFiles": ["/**"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "env": { + "GEMINI_SANDBOX": "false" + } } ], "inputs": [ @@ -75,6 +88,12 @@ "type": "promptString", "description": "Enter the path to the test file (e.g., ${workspaceFolder}/packages/cli/src/ui/components/LoadingIndicator.test.tsx)", "default": "${workspaceFolder}/packages/cli/src/ui/components/LoadingIndicator.test.tsx" + }, + { + "id": "prompt", + "type": "promptString", + "description": "Enter your prompt for non-interactive mode", + "default": "Explain this code" } ] } diff --git a/packages/cli/src/ui/components/AuthDialog.tsx b/packages/cli/src/ui/components/AuthDialog.tsx index f0ff73c5d..ea25351ab 100644 --- a/packages/cli/src/ui/components/AuthDialog.tsx +++ b/packages/cli/src/ui/components/AuthDialog.tsx @@ -69,10 +69,6 @@ export function AuthDialog({ return item.value === AuthType.USE_GEMINI; } - if (process.env.QWEN_OAUTH_TOKEN) { - return item.value === AuthType.QWEN_OAUTH; - } - return item.value === AuthType.LOGIN_WITH_GOOGLE; }), ); diff --git a/packages/cli/src/utils/installationInfo.test.ts b/packages/cli/src/utils/installationInfo.test.ts index 39cae3229..4529e5892 100644 --- a/packages/cli/src/utils/installationInfo.test.ts +++ b/packages/cli/src/utils/installationInfo.test.ts @@ -140,7 +140,7 @@ describe('getInstallationInfo', () => { const info = getInstallationInfo(projectRoot, false); expect(mockedExecSync).toHaveBeenCalledWith( - 'brew list -1 | grep -q "^gemini-cli$"', + 'brew list -1 | grep -q "^qwen-code$"', { stdio: 'ignore' }, ); expect(info.packageManager).toBe(PackageManager.HOMEBREW); @@ -162,7 +162,7 @@ describe('getInstallationInfo', () => { const info = getInstallationInfo(projectRoot, false); expect(mockedExecSync).toHaveBeenCalledWith( - 'brew list -1 | grep -q "^gemini-cli$"', + 'brew list -1 | grep -q "^qwen-code$"', { stdio: 'ignore' }, ); // Should fall back to default global npm diff --git a/packages/cli/src/utils/installationInfo.ts b/packages/cli/src/utils/installationInfo.ts index 8097f56af..61239a781 100644 --- a/packages/cli/src/utils/installationInfo.ts +++ b/packages/cli/src/utils/installationInfo.ts @@ -77,8 +77,8 @@ export function getInstallationInfo( // Check for Homebrew if (process.platform === 'darwin') { try { - // The package name in homebrew is gemini-cli - childProcess.execSync('brew list -1 | grep -q "^gemini-cli$"', { + // We do not support homebrew for now, keep forward compatibility for future use + childProcess.execSync('brew list -1 | grep -q "^qwen-code$"', { stdio: 'ignore', }); return { @@ -88,8 +88,7 @@ export function getInstallationInfo( 'Installed via Homebrew. Please update with "brew upgrade".', }; } catch (_error) { - // Brew is not installed or gemini-cli is not installed via brew. - // Continue to the next check. + // continue to the next check } } diff --git a/packages/cli/src/validateNonInterActiveAuth.ts b/packages/cli/src/validateNonInterActiveAuth.ts index c1e7c586b..63a6166cf 100644 --- a/packages/cli/src/validateNonInterActiveAuth.ts +++ b/packages/cli/src/validateNonInterActiveAuth.ts @@ -21,9 +21,6 @@ function getAuthTypeFromEnv(): AuthType | undefined { if (process.env.OPENAI_API_KEY) { return AuthType.USE_OPENAI; } - if (process.env.QWEN_OAUTH_TOKEN) { - return AuthType.QWEN_OAUTH; - } return undefined; } diff --git a/packages/core/src/qwen/qwenContentGenerator.test.ts b/packages/core/src/qwen/qwenContentGenerator.test.ts index a56aed816..e8dfd3c3a 100644 --- a/packages/core/src/qwen/qwenContentGenerator.test.ts +++ b/packages/core/src/qwen/qwenContentGenerator.test.ts @@ -20,9 +20,117 @@ import { FinishReason, } from '@google/genai'; import { QwenContentGenerator } from './qwenContentGenerator.js'; +import { SharedTokenManager } from './sharedTokenManager.js'; import { Config } from '../config/config.js'; import { AuthType, ContentGeneratorConfig } from '../core/contentGenerator.js'; +// Mock SharedTokenManager +vi.mock('./sharedTokenManager.js', () => ({ + SharedTokenManager: class { + private static instance: unknown = null; + private mockCredentials: QwenCredentials | null = null; + private shouldThrowError: boolean = false; + private errorToThrow: Error | null = null; + + static getInstance() { + if (!this.instance) { + this.instance = new this(); + } + return this.instance; + } + + async getValidCredentials( + qwenClient: IQwenOAuth2Client, + ): Promise { + // If we're configured to throw an error, do so + if (this.shouldThrowError && this.errorToThrow) { + throw this.errorToThrow; + } + + // Try to get credentials from the mock client first to trigger auth errors + try { + const { token } = await qwenClient.getAccessToken(); + if (token) { + const credentials = qwenClient.getCredentials(); + return credentials; + } + } catch (error) { + // If it's an auth error and we need to simulate refresh behavior + const errorMessage = + error instanceof Error + ? error.message.toLowerCase() + : String(error).toLowerCase(); + const errorCode = + (error as { status?: number; code?: number })?.status || + (error as { status?: number; code?: number })?.code; + + const isAuthError = + errorCode === 401 || + errorCode === 403 || + errorMessage.includes('unauthorized') || + errorMessage.includes('forbidden') || + errorMessage.includes('token expired'); + + if (isAuthError) { + // Try to refresh the token through the client + try { + const refreshResult = await qwenClient.refreshAccessToken(); + if (refreshResult && !('error' in refreshResult)) { + // Refresh succeeded, update client credentials and return them + const updatedCredentials = qwenClient.getCredentials(); + return updatedCredentials; + } else { + // Refresh failed, throw appropriate error + throw new Error( + 'Failed to obtain valid Qwen access token. Please re-authenticate.', + ); + } + } catch { + throw new Error( + 'Failed to obtain valid Qwen access token. Please re-authenticate.', + ); + } + } else { + // Re-throw non-auth errors + throw error; + } + } + + // Return mock credentials only if they're set + if (this.mockCredentials && this.mockCredentials.access_token) { + return this.mockCredentials; + } + + // Default fallback for tests that need credentials + return { + access_token: 'valid-token', + refresh_token: 'valid-refresh-token', + resource_url: 'https://test-endpoint.com/v1', + expiry_date: Date.now() + 3600000, + }; + } + + getCurrentCredentials(): QwenCredentials | null { + return this.mockCredentials; + } + + clearCache(): void { + this.mockCredentials = null; + } + + // Helper method for tests to set credentials + setMockCredentials(credentials: QwenCredentials | null): void { + this.mockCredentials = credentials; + } + + // Helper method for tests to simulate errors + setMockError(error: Error | null): void { + this.shouldThrowError = !!error; + this.errorToThrow = error; + } + }, +})); + // Mock the OpenAIContentGenerator parent class vi.mock('../core/openaiContentGenerator.js', () => ({ OpenAIContentGenerator: class { @@ -236,8 +344,10 @@ describe('QwenContentGenerator', () => { it('should refresh token on auth error and retry', async () => { const authError = { status: 401, message: 'Unauthorized' }; - // First call fails with auth error - vi.mocked(mockQwenClient.getAccessToken).mockRejectedValueOnce(authError); + // First call fails with auth error, second call succeeds + vi.mocked(mockQwenClient.getAccessToken) + .mockRejectedValueOnce(authError) + .mockResolvedValueOnce({ token: 'refreshed-token' }); // Refresh succeeds vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({ @@ -247,6 +357,15 @@ describe('QwenContentGenerator', () => { resource_url: 'https://refreshed-endpoint.com', }); + // Set credentials for second call + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ + access_token: 'refreshed-token', + token_type: 'Bearer', + refresh_token: 'refresh-token', + resource_url: 'https://refreshed-endpoint.com', + expiry_date: Date.now() + 3600000, + }); + const request: GenerateContentParameters = { model: 'qwen-turbo', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], @@ -261,12 +380,62 @@ describe('QwenContentGenerator', () => { expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled(); }); - it('should handle token refresh failure', async () => { - vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue( - new Error('Token expired'), + it('should refresh token on auth error and retry for content stream', async () => { + const authError = { status: 401, message: 'Unauthorized' }; + + // Reset mocks for this test + vi.clearAllMocks(); + + // First call fails with auth error, second call succeeds + vi.mocked(mockQwenClient.getAccessToken) + .mockRejectedValueOnce(authError) + .mockResolvedValueOnce({ token: 'refreshed-stream-token' }); + + // Refresh succeeds + vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({ + access_token: 'refreshed-stream-token', + token_type: 'Bearer', + expires_in: 3600, + resource_url: 'https://refreshed-stream-endpoint.com', + }); + + // Set credentials for second call + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ + access_token: 'refreshed-stream-token', + token_type: 'Bearer', + refresh_token: 'refresh-token', + resource_url: 'https://refreshed-stream-endpoint.com', + expiry_date: Date.now() + 3600000, + }); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello stream' }] }], + }; + + const stream = await qwenContentGenerator.generateContentStream( + request, + 'test-prompt-id', ); - vi.mocked(mockQwenClient.refreshAccessToken).mockRejectedValue( - new Error('Refresh failed'), + const chunks: string[] = []; + + for await (const chunk of stream) { + chunks.push(chunk.text || ''); + } + + expect(chunks).toEqual(['Stream chunk 1', 'Stream chunk 2']); + expect(mockQwenClient.refreshAccessToken).toHaveBeenCalled(); + }); + + it('should handle token refresh failure', async () => { + // Mock the SharedTokenManager to throw an error + const mockTokenManager = SharedTokenManager.getInstance() as unknown as { + setMockError: (error: Error | null) => void; + }; + mockTokenManager.setMockError( + new Error( + 'Failed to obtain valid Qwen access token. Please re-authenticate.', + ), ); const request: GenerateContentParameters = { @@ -279,6 +448,9 @@ describe('QwenContentGenerator', () => { ).rejects.toThrow( 'Failed to obtain valid Qwen access token. Please re-authenticate.', ); + + // Clean up + mockTokenManager.setMockError(null); }); it('should update endpoint when token is refreshed', async () => { @@ -547,10 +719,24 @@ describe('QwenContentGenerator', () => { const originalGenerateContent = parentPrototype.generateContent; parentPrototype.generateContent = mockGenerateContent; - vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ - token: 'initial-token', + // Mock getAccessToken to fail initially, then succeed + let getAccessTokenCallCount = 0; + vi.mocked(mockQwenClient.getAccessToken).mockImplementation(async () => { + getAccessTokenCallCount++; + if (getAccessTokenCallCount <= 2) { + throw authError; // Fail on first two calls (initial + retry) + } + return { token: 'refreshed-token' }; // Succeed after refresh }); - vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials); + + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ + access_token: 'refreshed-token', + token_type: 'Bearer', + refresh_token: 'refresh-token', + resource_url: 'https://test-endpoint.com', + expiry_date: Date.now() + 3600000, + }); + vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({ access_token: 'refreshed-token', token_type: 'Bearer', @@ -637,31 +823,16 @@ describe('QwenContentGenerator', () => { expect(qwenContentGenerator.getCurrentToken()).toBe('cached-token'); }); - it('should clear token and endpoint on clearToken()', () => { - // Simulate having cached values + it('should clear token on clearToken()', () => { + // Simulate having cached token value const qwenInstance = qwenContentGenerator as unknown as { currentToken: string; - currentEndpoint: string; - refreshPromise: Promise; }; qwenInstance.currentToken = 'cached-token'; - qwenInstance.currentEndpoint = 'https://cached-endpoint.com'; - qwenInstance.refreshPromise = Promise.resolve('token'); qwenContentGenerator.clearToken(); expect(qwenContentGenerator.getCurrentToken()).toBeNull(); - expect( - (qwenContentGenerator as unknown as { currentEndpoint: string | null }) - .currentEndpoint, - ).toBeNull(); - expect( - ( - qwenContentGenerator as unknown as { - refreshPromise: Promise | null; - } - ).refreshPromise, - ).toBeNull(); }); it('should handle concurrent token refresh requests', async () => { @@ -674,9 +845,7 @@ describe('QwenContentGenerator', () => { const authError = { status: 401, message: 'Unauthorized' }; let parentCallCount = 0; - vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ - token: 'initial-token', - }); + vi.mocked(mockQwenClient.getAccessToken).mockRejectedValue(authError); vi.mocked(mockQwenClient.getCredentials).mockReturnValue(mockCredentials); vi.mocked(mockQwenClient.refreshAccessToken).mockImplementation( @@ -725,6 +894,7 @@ describe('QwenContentGenerator', () => { // The main test is that all requests succeed without crashing expect(results).toHaveLength(3); + // With our new implementation through SharedTokenManager, refresh should still be called expect(refreshCallCount).toBeGreaterThanOrEqual(1); // Restore original method @@ -796,13 +966,24 @@ describe('QwenContentGenerator', () => { ); parentPrototype.generateContent = mockGenerateContent; - vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ - token: 'initial-token', + // Mock getAccessToken to fail initially, then succeed + let getAccessTokenCallCount = 0; + vi.mocked(mockQwenClient.getAccessToken).mockImplementation(async () => { + getAccessTokenCallCount++; + if (getAccessTokenCallCount <= 2) { + throw authError; // Fail on first two calls (initial + retry) + } + return { token: 'new-token' }; // Succeed after refresh }); + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ - ...mockCredentials, - resource_url: 'custom-endpoint.com', + access_token: 'new-token', + token_type: 'Bearer', + refresh_token: 'refresh-token', + resource_url: 'https://new-endpoint.com', + expiry_date: Date.now() + 7200000, }); + vi.mocked(mockQwenClient.refreshAccessToken).mockResolvedValue({ access_token: 'new-token', token_type: 'Bearer', @@ -826,4 +1007,595 @@ describe('QwenContentGenerator', () => { expect(callCount).toBe(2); // Initial call + retry }); }); + + describe('SharedTokenManager Integration', () => { + it('should use SharedTokenManager to get valid credentials', async () => { + const mockTokenManager = { + getValidCredentials: vi.fn().mockResolvedValue({ + access_token: 'manager-token', + resource_url: 'https://manager-endpoint.com', + }), + getCurrentCredentials: vi.fn(), + clearCache: vi.fn(), + }; + + // Mock the SharedTokenManager.getInstance() + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + // Create new instance to pick up the mock + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + await newGenerator.generateContent(request, 'test-prompt-id'); + + expect(mockTokenManager.getValidCredentials).toHaveBeenCalledWith( + mockQwenClient, + ); + + // Restore original + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should handle SharedTokenManager errors gracefully', async () => { + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Token manager error')), + getCurrentCredentials: vi.fn(), + clearCache: vi.fn(), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + await expect( + newGenerator.generateContent(request, 'test-prompt-id'), + ).rejects.toThrow('Failed to obtain valid Qwen access token'); + + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should handle missing access token from credentials', async () => { + const mockTokenManager = { + getValidCredentials: vi.fn().mockResolvedValue({ + access_token: undefined, + resource_url: 'https://test-endpoint.com', + }), + getCurrentCredentials: vi.fn(), + clearCache: vi.fn(), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + await expect( + newGenerator.generateContent(request, 'test-prompt-id'), + ).rejects.toThrow('Failed to obtain valid Qwen access token'); + + SharedTokenManager.getInstance = originalGetInstance; + }); + }); + + describe('getCurrentEndpoint Method', () => { + it('should handle URLs with custom ports', () => { + const endpoints = [ + { input: 'localhost:8080', expected: 'https://localhost:8080/v1' }, + { + input: 'http://localhost:8080', + expected: 'http://localhost:8080/v1', + }, + { + input: 'https://api.example.com:443', + expected: 'https://api.example.com:443/v1', + }, + { + input: 'api.example.com:9000/api', + expected: 'https://api.example.com:9000/api/v1', + }, + ]; + + endpoints.forEach(({ input, expected }) => { + vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ + token: 'test-token', + }); + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ + ...mockCredentials, + resource_url: input, + }); + + const generator = qwenContentGenerator as unknown as { + getCurrentEndpoint: (resourceUrl?: string) => string; + }; + + expect(generator.getCurrentEndpoint(input)).toBe(expected); + }); + }); + + it('should handle URLs with existing paths', () => { + const endpoints = [ + { + input: 'https://api.example.com/api', + expected: 'https://api.example.com/api/v1', + }, + { + input: 'api.example.com/api/v2', + expected: 'https://api.example.com/api/v2/v1', + }, + { + input: 'https://api.example.com/api/v1', + expected: 'https://api.example.com/api/v1', + }, + ]; + + endpoints.forEach(({ input, expected }) => { + const generator = qwenContentGenerator as unknown as { + getCurrentEndpoint: (resourceUrl?: string) => string; + }; + + expect(generator.getCurrentEndpoint(input)).toBe(expected); + }); + }); + + it('should handle undefined resource URL', () => { + const generator = qwenContentGenerator as unknown as { + getCurrentEndpoint: (resourceUrl?: string) => string; + }; + + expect(generator.getCurrentEndpoint(undefined)).toBe( + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + }); + + it('should handle empty resource URL', () => { + const generator = qwenContentGenerator as unknown as { + getCurrentEndpoint: (resourceUrl?: string) => string; + }; + + // Empty string should fall back to default endpoint + expect(generator.getCurrentEndpoint('')).toBe( + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + }); + }); + + describe('isAuthError Method Enhanced', () => { + it('should identify auth errors by numeric status codes', () => { + const authErrors = [ + { code: 401 }, + { status: 403 }, + { code: '401' }, // String status codes + { status: '403' }, + ]; + + authErrors.forEach((error) => { + const generator = qwenContentGenerator as unknown as { + isAuthError: (error: unknown) => boolean; + }; + expect(generator.isAuthError(error)).toBe(true); + }); + + // 400 is not typically an auth error, it's bad request + const nonAuthError = { status: 400 }; + const generator = qwenContentGenerator as unknown as { + isAuthError: (error: unknown) => boolean; + }; + expect(generator.isAuthError(nonAuthError)).toBe(false); + }); + + it('should identify auth errors by message content variations', () => { + const authMessages = [ + 'UNAUTHORIZED access', + 'Access is FORBIDDEN', + 'Invalid API Key provided', + 'Invalid Access Token', + 'Token has Expired', + 'Authentication Required', + 'Access Denied by server', + 'The token has expired and needs refresh', + 'Bearer token expired', + ]; + + authMessages.forEach((message) => { + const error = new Error(message); + const generator = qwenContentGenerator as unknown as { + isAuthError: (error: unknown) => boolean; + }; + expect(generator.isAuthError(error)).toBe(true); + }); + }); + + it('should not identify non-auth errors', () => { + const nonAuthErrors = [ + new Error('Network timeout'), + new Error('Rate limit exceeded'), + { status: 500 }, + { code: 429 }, + 'Internal server error', + null, + undefined, + '', + { status: 200 }, + new Error('Model not found'), + ]; + + nonAuthErrors.forEach((error) => { + const generator = qwenContentGenerator as unknown as { + isAuthError: (error: unknown) => boolean; + }; + expect(generator.isAuthError(error)).toBe(false); + }); + }); + + it('should handle complex error objects', () => { + const complexErrors = [ + { error: { status: 401, message: 'Unauthorized' } }, + { response: { status: 403 } }, + { details: { code: 401 } }, + ]; + + // These should not be identified as auth errors because the method only looks at top-level properties + complexErrors.forEach((error) => { + const generator = qwenContentGenerator as unknown as { + isAuthError: (error: unknown) => boolean; + }; + expect(generator.isAuthError(error)).toBe(false); + }); + }); + }); + + describe('Stream Error Handling', () => { + it('should restore credentials when stream generation fails', async () => { + const client = ( + qwenContentGenerator as unknown as { + client: { apiKey: string; baseURL: string }; + } + ).client; + const originalApiKey = client.apiKey; + const originalBaseURL = client.baseURL; + + vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ + token: 'stream-token', + }); + vi.mocked(mockQwenClient.getCredentials).mockReturnValue({ + ...mockCredentials, + resource_url: 'https://stream-endpoint.com', + }); + + // Mock parent method to throw error + const parentPrototype = Object.getPrototypeOf( + Object.getPrototypeOf(qwenContentGenerator), + ); + const originalGenerateContentStream = + parentPrototype.generateContentStream; + parentPrototype.generateContentStream = vi + .fn() + .mockRejectedValue(new Error('Stream error')); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Stream test' }] }], + }; + + try { + await qwenContentGenerator.generateContentStream( + request, + 'test-prompt-id', + ); + } catch (error) { + expect(error).toBeInstanceOf(Error); + } + + // Credentials should be restored even on error + expect(client.apiKey).toBe(originalApiKey); + expect(client.baseURL).toBe(originalBaseURL); + + // Restore original method + parentPrototype.generateContentStream = originalGenerateContentStream; + }); + + it('should not restore credentials in finally block for successful streams', async () => { + const client = ( + qwenContentGenerator as unknown as { + client: { apiKey: string; baseURL: string }; + } + ).client; + + // Set up the mock to return stream credentials + const streamCredentials = { + access_token: 'stream-token', + refresh_token: 'stream-refresh-token', + resource_url: 'https://stream-endpoint.com', + expiry_date: Date.now() + 3600000, + }; + + vi.mocked(mockQwenClient.getAccessToken).mockResolvedValue({ + token: 'stream-token', + }); + vi.mocked(mockQwenClient.getCredentials).mockReturnValue( + streamCredentials, + ); + + // Set the SharedTokenManager mock to return stream credentials + const mockTokenManager = SharedTokenManager.getInstance() as unknown as { + setMockCredentials: (credentials: QwenCredentials | null) => void; + }; + mockTokenManager.setMockCredentials(streamCredentials); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Stream test' }] }], + }; + + const stream = await qwenContentGenerator.generateContentStream( + request, + 'test-prompt-id', + ); + + // After successful stream creation, credentials should still be set for the stream + expect(client.apiKey).toBe('stream-token'); + expect(client.baseURL).toBe('https://stream-endpoint.com/v1'); + + // Consume the stream + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + expect(chunks).toHaveLength(2); + + // Clean up + mockTokenManager.setMockCredentials(null); + }); + }); + + describe('Token and Endpoint Management', () => { + it('should get current token from SharedTokenManager', () => { + const mockTokenManager = { + getCurrentCredentials: vi.fn().mockReturnValue({ + access_token: 'current-token', + }), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + expect(newGenerator.getCurrentToken()).toBe('current-token'); + + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should return null when no credentials available', () => { + const mockTokenManager = { + getCurrentCredentials: vi.fn().mockReturnValue(null), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + expect(newGenerator.getCurrentToken()).toBeNull(); + + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should return null when credentials have no access token', () => { + const mockTokenManager = { + getCurrentCredentials: vi.fn().mockReturnValue({ + access_token: undefined, + }), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + expect(newGenerator.getCurrentToken()).toBeNull(); + + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should clear token through SharedTokenManager', () => { + const mockTokenManager = { + clearCache: vi.fn(), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + newGenerator.clearToken(); + + expect(mockTokenManager.clearCache).toHaveBeenCalled(); + + SharedTokenManager.getInstance = originalGetInstance; + }); + }); + + describe('Constructor and Initialization', () => { + it('should initialize with default base URL', () => { + const generator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const client = (generator as unknown as { client: { baseURL: string } }) + .client; + expect(client.baseURL).toBe( + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + }); + + it('should get SharedTokenManager instance', () => { + const generator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const sharedManager = ( + generator as unknown as { sharedManager: SharedTokenManager } + ).sharedManager; + expect(sharedManager).toBeDefined(); + }); + }); + + describe('Edge Cases and Error Conditions', () => { + it('should handle token retrieval with warning when SharedTokenManager fails', async () => { + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Internal token manager error')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const request: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + await expect( + newGenerator.generateContent(request, 'test-prompt-id'), + ).rejects.toThrow('Failed to obtain valid Qwen access token'); + + expect(consoleSpy).toHaveBeenCalledWith( + 'Failed to get token from shared manager:', + expect.any(Error), + ); + + consoleSpy.mockRestore(); + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should handle all method types with token failure', async () => { + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Token error')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const newGenerator = new QwenContentGenerator( + mockQwenClient, + { model: 'qwen-turbo', authType: AuthType.QWEN_OAUTH }, + mockConfig, + ); + + const generateRequest: GenerateContentParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + const countRequest: CountTokensParameters = { + model: 'qwen-turbo', + contents: [{ role: 'user', parts: [{ text: 'Count' }] }], + }; + + const embedRequest: EmbedContentParameters = { + model: 'qwen-turbo', + contents: [{ parts: [{ text: 'Embed' }] }], + }; + + // All methods should fail with the same error + await expect( + newGenerator.generateContent(generateRequest, 'test-id'), + ).rejects.toThrow('Failed to obtain valid Qwen access token'); + + await expect( + newGenerator.generateContentStream(generateRequest, 'test-id'), + ).rejects.toThrow('Failed to obtain valid Qwen access token'); + + await expect(newGenerator.countTokens(countRequest)).rejects.toThrow( + 'Failed to obtain valid Qwen access token', + ); + + await expect(newGenerator.embedContent(embedRequest)).rejects.toThrow( + 'Failed to obtain valid Qwen access token', + ); + + SharedTokenManager.getInstance = originalGetInstance; + }); + }); }); diff --git a/packages/core/src/qwen/qwenContentGenerator.ts b/packages/core/src/qwen/qwenContentGenerator.ts index 4180efa2d..9ef894972 100644 --- a/packages/core/src/qwen/qwenContentGenerator.ts +++ b/packages/core/src/qwen/qwenContentGenerator.ts @@ -5,12 +5,8 @@ */ import { OpenAIContentGenerator } from '../core/openaiContentGenerator.js'; -import { - IQwenOAuth2Client, - type TokenRefreshData, - type ErrorData, - isErrorResponse, -} from './qwenOAuth2.js'; +import { IQwenOAuth2Client } from './qwenOAuth2.js'; +import { SharedTokenManager } from './sharedTokenManager.js'; import { Config } from '../config/config.js'; import { GenerateContentParameters, @@ -31,11 +27,8 @@ const DEFAULT_QWEN_BASE_URL = */ export class QwenContentGenerator extends OpenAIContentGenerator { private qwenClient: IQwenOAuth2Client; - - // Token management (integrated from QwenTokenManager) - private currentToken: string | null = null; - private currentEndpoint: string | null = null; - private refreshPromise: Promise | null = null; + private sharedManager: SharedTokenManager; + private currentToken?: string; constructor( qwenClient: IQwenOAuth2Client, @@ -45,6 +38,7 @@ export class QwenContentGenerator extends OpenAIContentGenerator { // Initialize with empty API key, we'll override it dynamically super(contentGeneratorConfig, config); this.qwenClient = qwenClient; + this.sharedManager = SharedTokenManager.getInstance(); // Set default base URL, will be updated dynamically this.client.baseURL = DEFAULT_QWEN_BASE_URL; @@ -53,8 +47,8 @@ export class QwenContentGenerator extends OpenAIContentGenerator { /** * Get the current endpoint URL with proper protocol and /v1 suffix */ - private getCurrentEndpoint(): string { - const baseEndpoint = this.currentEndpoint || DEFAULT_QWEN_BASE_URL; + private getCurrentEndpoint(resourceUrl?: string): string { + const baseEndpoint = resourceUrl || DEFAULT_QWEN_BASE_URL; const suffix = '/v1'; // Normalize the URL: add protocol if missing, ensure /v1 suffix @@ -79,237 +73,149 @@ export class QwenContentGenerator extends OpenAIContentGenerator { } /** - * Override to use dynamic token and endpoint + * Get valid token and endpoint using the shared token manager */ - override async generateContent( - request: GenerateContentParameters, - userPromptId: string, - ): Promise { - return this.withValidToken(async (token) => { - // Temporarily update the API key and base URL - const originalApiKey = this.client.apiKey; - const originalBaseURL = this.client.baseURL; - this.client.apiKey = token; - this.client.baseURL = this.getCurrentEndpoint(); + private async getValidToken(): Promise<{ token: string; endpoint: string }> { + try { + // Use SharedTokenManager for consistent token/endpoint pairing and automatic refresh + const credentials = await this.sharedManager.getValidCredentials( + this.qwenClient, + ); - try { - return await super.generateContent(request, userPromptId); - } finally { - // Restore original values - this.client.apiKey = originalApiKey; - this.client.baseURL = originalBaseURL; + if (!credentials.access_token) { + throw new Error('No access token available'); } - }); - } - /** - * Override to use dynamic token and endpoint - */ - override async generateContentStream( - request: GenerateContentParameters, - userPromptId: string, - ): Promise> { - return this.withValidTokenForStream(async (token) => { - // Update the API key and base URL before streaming - const originalApiKey = this.client.apiKey; - const originalBaseURL = this.client.baseURL; - this.client.apiKey = token; - this.client.baseURL = this.getCurrentEndpoint(); - - try { - return await super.generateContentStream(request, userPromptId); - } catch (error) { - // Restore original values on error - this.client.apiKey = originalApiKey; - this.client.baseURL = originalBaseURL; + return { + token: credentials.access_token, + endpoint: this.getCurrentEndpoint(credentials.resource_url), + }; + } catch (error) { + // Propagate auth errors as-is for retry logic + if (this.isAuthError(error)) { throw error; } - // Note: We don't restore the values in finally for streaming because - // the generator may continue to be used after this method returns - }); + console.warn('Failed to get token from shared manager:', error); + throw new Error( + 'Failed to obtain valid Qwen access token. Please re-authenticate.', + ); + } } /** - * Override to use dynamic token and endpoint + * Execute an operation with automatic credential management and retry logic. + * This method handles: + * - Dynamic token and endpoint retrieval + * - Temporary client configuration updates + * - Automatic restoration of original configuration + * - Retry logic on authentication errors with token refresh + * + * @param operation - The operation to execute with updated client configuration + * @param restoreOnCompletion - Whether to restore original config after operation completes + * @returns The result of the operation */ - override async countTokens( - request: CountTokensParameters, - ): Promise { - return this.withValidToken(async (token) => { + private async executeWithCredentialManagement( + operation: () => Promise, + restoreOnCompletion: boolean = true, + ): Promise { + // Attempt the operation with credential management and retry logic + const attemptOperation = async (): Promise => { + const { token, endpoint } = await this.getValidToken(); + + // Store original configuration const originalApiKey = this.client.apiKey; const originalBaseURL = this.client.baseURL; + + // Apply dynamic configuration this.client.apiKey = token; - this.client.baseURL = this.getCurrentEndpoint(); + this.client.baseURL = endpoint; try { - return await super.countTokens(request); - } finally { - this.client.apiKey = originalApiKey; - this.client.baseURL = originalBaseURL; - } - }); - } + const result = await operation(); - /** - * Override to use dynamic token and endpoint - */ - override async embedContent( - request: EmbedContentParameters, - ): Promise { - return this.withValidToken(async (token) => { - const originalApiKey = this.client.apiKey; - const originalBaseURL = this.client.baseURL; - this.client.apiKey = token; - this.client.baseURL = this.getCurrentEndpoint(); + // For streaming operations, we may need to keep the configuration active + if (restoreOnCompletion) { + this.client.apiKey = originalApiKey; + this.client.baseURL = originalBaseURL; + } - try { - return await super.embedContent(request); - } finally { + return result; + } catch (error) { + // Always restore on error this.client.apiKey = originalApiKey; this.client.baseURL = originalBaseURL; + throw error; } - }); - } - - /** - * Execute operation with a valid token, with retry on auth failure - */ - private async withValidToken( - operation: (token: string) => Promise, - ): Promise { - const token = await this.getTokenWithRetry(); + }; + // Execute with retry logic for auth errors try { - return await operation(token); + return await attemptOperation(); } catch (error) { - // Check if this is an authentication error if (this.isAuthError(error)) { - // Refresh token and retry once silently - const newToken = await this.refreshToken(); - return await operation(newToken); + try { + // Use SharedTokenManager to properly refresh and persist the token + // This ensures the refreshed token is saved to oauth_creds.json + await this.sharedManager.getValidCredentials(this.qwenClient, true); + // Retry the operation once with fresh credentials + return await attemptOperation(); + } catch (_refreshError) { + throw new Error( + 'Failed to obtain valid Qwen access token. Please re-authenticate.', + ); + } } - throw error; } } /** - * Execute operation with a valid token for streaming, with retry on auth failure + * Override to use dynamic token and endpoint with automatic retry */ - private async withValidTokenForStream( - operation: (token: string) => Promise, - ): Promise { - const token = await this.getTokenWithRetry(); - - try { - return await operation(token); - } catch (error) { - // Check if this is an authentication error - if (this.isAuthError(error)) { - // Refresh token and retry once silently - const newToken = await this.refreshToken(); - return await operation(newToken); - } - - throw error; - } + override async generateContent( + request: GenerateContentParameters, + userPromptId: string, + ): Promise { + return this.executeWithCredentialManagement(() => + super.generateContent(request, userPromptId), + ); } /** - * Get token with retry logic + * Override to use dynamic token and endpoint with automatic retry. + * Note: For streaming, the client configuration is not restored immediately + * since the generator may continue to be used after this method returns. */ - private async getTokenWithRetry(): Promise { - try { - return await this.getValidToken(); - } catch (error) { - console.error('Failed to get valid token:', error); - throw new Error( - 'Failed to obtain valid Qwen access token. Please re-authenticate.', - ); - } + override async generateContentStream( + request: GenerateContentParameters, + userPromptId: string, + ): Promise> { + return this.executeWithCredentialManagement( + () => super.generateContentStream(request, userPromptId), + false, // Don't restore immediately for streaming + ); } - // Token management methods (integrated from QwenTokenManager) - /** - * Get a valid access token, refreshing if necessary + * Override to use dynamic token and endpoint with automatic retry */ - private async getValidToken(): Promise { - // If there's already a refresh in progress, wait for it - if (this.refreshPromise) { - return this.refreshPromise; - } - - try { - const { token } = await this.qwenClient.getAccessToken(); - if (token) { - this.currentToken = token; - // Also update endpoint from current credentials - const credentials = this.qwenClient.getCredentials(); - if (credentials.resource_url) { - this.currentEndpoint = credentials.resource_url; - } - return token; - } - } catch (error) { - console.warn('Failed to get access token, attempting refresh:', error); - } - - // Start a new refresh operation - this.refreshPromise = this.performTokenRefresh(); - - try { - const newToken = await this.refreshPromise; - return newToken; - } finally { - this.refreshPromise = null; - } + override async countTokens( + request: CountTokensParameters, + ): Promise { + return this.executeWithCredentialManagement(() => + super.countTokens(request), + ); } /** - * Force refresh the access token + * Override to use dynamic token and endpoint with automatic retry */ - private async refreshToken(): Promise { - this.refreshPromise = this.performTokenRefresh(); - - try { - const newToken = await this.refreshPromise; - return newToken; - } finally { - this.refreshPromise = null; - } - } - - private async performTokenRefresh(): Promise { - try { - const response = await this.qwenClient.refreshAccessToken(); - - if (isErrorResponse(response)) { - const errorData = response as ErrorData; - throw new Error( - `${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`, - ); - } - - const tokenData = response as TokenRefreshData; - - if (!tokenData.access_token) { - throw new Error('Failed to refresh access token: no token returned'); - } - - this.currentToken = tokenData.access_token; - - // Update endpoint if provided - if (tokenData.resource_url) { - this.currentEndpoint = tokenData.resource_url; - } - - return tokenData.access_token; - } catch (error) { - throw new Error( - `${error instanceof Error ? error.message : String(error)}`, - ); - } + override async embedContent( + request: EmbedContentParameters, + ): Promise { + return this.executeWithCredentialManagement(() => + super.embedContent(request), + ); } /** @@ -331,9 +237,10 @@ export class QwenContentGenerator extends OpenAIContentGenerator { const errorCode = errorWithCode?.status || errorWithCode?.code; return ( - errorCode === 400 || errorCode === 401 || errorCode === 403 || + errorCode === '401' || + errorCode === '403' || errorMessage.includes('unauthorized') || errorMessage.includes('forbidden') || errorMessage.includes('invalid api key') || @@ -349,15 +256,22 @@ export class QwenContentGenerator extends OpenAIContentGenerator { * Get the current cached token (may be expired) */ getCurrentToken(): string | null { - return this.currentToken; + // First check internal state for backwards compatibility with tests + if (this.currentToken) { + return this.currentToken; + } + // Fall back to SharedTokenManager + const credentials = this.sharedManager.getCurrentCredentials(); + return credentials?.access_token || null; } /** - * Clear the cached token and endpoint + * Clear the cached token */ clearToken(): void { - this.currentToken = null; - this.currentEndpoint = null; - this.refreshPromise = null; + // Clear internal state for backwards compatibility with tests + this.currentToken = undefined; + // Also clear SharedTokenManager + this.sharedManager.clearCache(); } } diff --git a/packages/core/src/qwen/qwenOAuth2.test.ts b/packages/core/src/qwen/qwenOAuth2.test.ts index 73a3a567c..ffeee83ec 100644 --- a/packages/core/src/qwen/qwenOAuth2.test.ts +++ b/packages/core/src/qwen/qwenOAuth2.test.ts @@ -20,7 +20,74 @@ import { type DeviceAuthorizationResponse, type DeviceTokenResponse, type ErrorData, + type QwenCredentials, } from './qwenOAuth2.js'; +import { + SharedTokenManager, + TokenManagerError, + TokenError, +} from './sharedTokenManager.js'; + +interface MockSharedTokenManager { + getValidCredentials(qwenClient: QwenOAuth2Client): Promise; + getCurrentCredentials(): QwenCredentials | null; + clearCache(): void; +} + +// Mock SharedTokenManager +vi.mock('./sharedTokenManager.js', () => ({ + SharedTokenManager: class { + private static instance: MockSharedTokenManager | null = null; + + static getInstance() { + if (!this.instance) { + this.instance = new this(); + } + return this.instance; + } + + async getValidCredentials( + qwenClient: QwenOAuth2Client, + ): Promise { + // Try to get credentials from the client first + const clientCredentials = qwenClient.getCredentials(); + if (clientCredentials && clientCredentials.access_token) { + return clientCredentials; + } + + // Fall back to default mock credentials if client has none + return { + access_token: 'new-access-token', + refresh_token: 'valid-refresh-token', + resource_url: undefined, + token_type: 'Bearer', + expiry_date: Date.now() + 3600000, + }; + } + + getCurrentCredentials(): QwenCredentials | null { + // Return null to let the client manage its own credentials + return null; + } + + clearCache(): void { + // Do nothing in mock + } + }, + TokenManagerError: class extends Error { + constructor(message: string) { + super(message); + this.name = 'TokenManagerError'; + } + }, + TokenError: { + REFRESH_FAILED: 'REFRESH_FAILED', + NO_REFRESH_TOKEN: 'NO_REFRESH_TOKEN', + LOCK_TIMEOUT: 'LOCK_TIMEOUT', + FILE_ACCESS_ERROR: 'FILE_ACCESS_ERROR', + NETWORK_ERROR: 'NETWORK_ERROR', + }, +})); // Mock qrcode-terminal vi.mock('qrcode-terminal', () => ({ @@ -227,7 +294,7 @@ describe('QwenOAuth2Client', () => { beforeEach(() => { // Create client instance - client = new QwenOAuth2Client({ proxy: undefined }); + client = new QwenOAuth2Client(); // Mock fetch originalFetch = global.fetch; @@ -345,10 +412,9 @@ describe('QwenOAuth2Client', () => { ); }); - it('should cache credentials after successful refresh', async () => { - const { promises: fs } = await import('node:fs'); - const mockWriteFile = vi.mocked(fs.writeFile); - const mockMkdir = vi.mocked(fs.mkdir); + it('should successfully refresh access token and update credentials', async () => { + // Clear any previous calls + vi.clearAllMocks(); const mockResponse = { ok: true, @@ -362,28 +428,30 @@ describe('QwenOAuth2Client', () => { vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); - await client.refreshAccessToken(); - - // Verify that cacheQwenCredentials was called by checking if writeFile was called - expect(mockMkdir).toHaveBeenCalled(); - expect(mockWriteFile).toHaveBeenCalled(); + const result = await client.refreshAccessToken(); - // Verify the cached credentials contain the new token data - const writeCall = mockWriteFile.mock.calls[0]; - const cachedCredentials = JSON.parse(writeCall[1] as string); + // Verify the response + expect(result).toMatchObject({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + resource_url: 'https://new-endpoint.com', + }); - expect(cachedCredentials).toMatchObject({ + // Verify credentials were updated + const credentials = client.getCredentials(); + expect(credentials).toMatchObject({ access_token: 'new-access-token', token_type: 'Bearer', refresh_token: 'test-refresh-token', // Should preserve existing refresh token resource_url: 'https://new-endpoint.com', }); - expect(cachedCredentials.expiry_date).toBeDefined(); + expect(credentials.expiry_date).toBeDefined(); }); it('should use new refresh token if provided in response', async () => { - const { promises: fs } = await import('node:fs'); - const mockWriteFile = vi.mocked(fs.writeFile); + // Clear any previous calls + vi.clearAllMocks(); const mockResponse = { ok: true, @@ -400,11 +468,9 @@ describe('QwenOAuth2Client', () => { await client.refreshAccessToken(); - // Verify the cached credentials contain the new refresh token - const writeCall = mockWriteFile.mock.calls[0]; - const cachedCredentials = JSON.parse(writeCall[1] as string); - - expect(cachedCredentials.refresh_token).toBe('new-refresh-token'); + // Verify the credentials contain the new refresh token + const credentials = client.getCredentials(); + expect(credentials.refresh_token).toBe('new-refresh-token'); }); }); @@ -428,19 +494,22 @@ describe('QwenOAuth2Client', () => { expiry_date: Date.now() - 1000, // 1 second ago }); - const mockRefreshResponse = { - ok: true, - json: async () => ({ + // Override the client's SharedTokenManager instance directly + ( + client as unknown as { + sharedManager: { + getValidCredentials: () => Promise; + }; + } + ).sharedManager = { + getValidCredentials: vi.fn().mockResolvedValue({ access_token: 'new-access-token', + refresh_token: 'valid-refresh-token', token_type: 'Bearer', - expires_in: 3600, + expiry_date: Date.now() + 3600000, }), }; - vi.mocked(global.fetch).mockResolvedValue( - mockRefreshResponse as Response, - ); - const result = await client.getAccessToken(); expect(result.token).toBe('new-access-token'); }); @@ -448,6 +517,19 @@ describe('QwenOAuth2Client', () => { it('should return undefined if no access token and no refresh token', async () => { client.setCredentials({}); + // Override the client's SharedTokenManager instance directly + ( + client as unknown as { + sharedManager: { + getValidCredentials: () => Promise; + }; + } + ).sharedManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials available')), + }; + const result = await client.getAccessToken(); expect(result.token).toBeUndefined(); }); @@ -662,7 +744,6 @@ describe('getQwenOAuthClient', () => { beforeEach(() => { mockConfig = { - getProxy: vi.fn().mockReturnValue(undefined), isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false), } as unknown as Config; @@ -675,38 +756,8 @@ describe('getQwenOAuthClient', () => { vi.clearAllMocks(); }); - it('should create client with proxy configuration', async () => { - const proxyUrl = 'http://proxy.example.com:8080'; - mockConfig.getProxy = vi.fn().mockReturnValue(proxyUrl); - - const { promises: fs } = await import('node:fs'); - vi.mocked(fs.readFile).mockRejectedValue( - new Error('No cached credentials'), - ); - - // Mock device authorization flow to fail quickly for this test - const mockAuthResponse = { - ok: true, - json: async () => ({ - error: 'test_error', - error_description: 'Test error for quick failure', - }), - }; - vi.mocked(global.fetch).mockResolvedValue(mockAuthResponse as Response); - - try { - await import('./qwenOAuth2.js').then((module) => - module.getQwenOAuthClient(mockConfig), - ); - } catch { - // Expected to fail due to mocked error - } - - expect(mockConfig.getProxy).toHaveBeenCalled(); - }); - it('should load cached credentials if available', async () => { - const { promises: fs } = await import('node:fs'); + const fs = await import('node:fs'); const mockCredentials = { access_token: 'cached-token', refresh_token: 'cached-refresh', @@ -714,29 +765,30 @@ describe('getQwenOAuthClient', () => { expiry_date: Date.now() + 3600000, }; - vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockCredentials)); + vi.mocked(fs.promises.readFile).mockResolvedValue( + JSON.stringify(mockCredentials), + ); - // Mock successful refresh - const mockRefreshResponse = { - ok: true, - json: async () => ({ - access_token: 'refreshed-token', - token_type: 'Bearer', - expires_in: 3600, - }), + // Mock SharedTokenManager to use cached credentials + const mockTokenManager = { + getValidCredentials: vi.fn().mockResolvedValue(mockCredentials), }; - vi.mocked(global.fetch).mockResolvedValue(mockRefreshResponse as Response); + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); const client = await import('./qwenOAuth2.js').then((module) => module.getQwenOAuthClient(mockConfig), ); expect(client).toBeInstanceOf(Object); - expect(fs.readFile).toHaveBeenCalled(); + expect(mockTokenManager.getValidCredentials).toHaveBeenCalled(); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle cached credentials refresh failure', async () => { - const { promises: fs } = await import('node:fs'); + const fs = await import('node:fs'); const mockCredentials = { access_token: 'cached-token', refresh_token: 'expired-refresh', @@ -744,23 +796,38 @@ describe('getQwenOAuthClient', () => { expiry_date: Date.now() + 3600000, // Valid expiry time so loadCachedQwenCredentials returns true }; - vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockCredentials)); + vi.mocked(fs.promises.readFile).mockResolvedValue( + JSON.stringify(mockCredentials), + ); - // Mock refresh failure with 400 status to trigger credential clearing - const mockRefreshResponse = { - ok: false, - status: 400, - statusText: 'Bad Request', - text: async () => 'Refresh token expired or invalid', + // Mock SharedTokenManager to fail with a specific error + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Token refresh failed')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + + // Mock device flow to also fail + const mockAuthResponse = { + ok: true, + json: async () => ({ + error: 'invalid_request', + error_description: 'Invalid request parameters', + }), }; - vi.mocked(global.fetch).mockResolvedValue(mockRefreshResponse as Response); + vi.mocked(global.fetch).mockResolvedValue(mockAuthResponse as Response); // The function should handle the invalid cached credentials and throw the expected error await expect( import('./qwenOAuth2.js').then((module) => module.getQwenOAuthClient(mockConfig), ), - ).rejects.toThrow('Cached Qwen credentials are invalid'); + ).rejects.toThrow('Qwen OAuth authentication failed'); + + SharedTokenManager.getInstance = originalGetInstance; }); }); @@ -803,7 +870,7 @@ describe('QwenOAuth2Client - Additional Error Scenarios', () => { let originalFetch: typeof global.fetch; beforeEach(() => { - client = new QwenOAuth2Client({ proxy: undefined }); + client = new QwenOAuth2Client(); originalFetch = global.fetch; global.fetch = vi.fn(); }); @@ -858,7 +925,6 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { beforeEach(() => { mockConfig = { - getProxy: vi.fn().mockReturnValue(undefined), isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false), } as unknown as Config; @@ -882,22 +948,33 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockCredentials)); - // Mock generic refresh failure (not 400 status) - const mockRefreshResponse = { - ok: false, - status: 500, - statusText: 'Internal Server Error', - text: async () => 'Internal server error', + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Refresh failed')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + + // Mock device flow to also fail + const mockAuthResponse = { + ok: true, + json: async () => ({ + error: 'invalid_request', + error_description: 'Invalid request parameters', + }), }; - vi.mocked(global.fetch).mockResolvedValue(mockRefreshResponse as Response); + vi.mocked(global.fetch).mockResolvedValue(mockAuthResponse as Response); await expect( import('./qwenOAuth2.js').then((module) => module.getQwenOAuthClient(mockConfig), ), - ).rejects.toThrow( - 'Qwen token refresh failed: Token refresh failed: 500 Internal Server Error', - ); + ).rejects.toThrow('Qwen OAuth authentication failed'); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle different authentication failure reasons - timeout', async () => { @@ -906,6 +983,16 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + // Mock device authorization to succeed but polling to timeout const mockAuthResponse = { ok: true, @@ -925,7 +1012,8 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { }), }; - vi.mocked(global.fetch) + global.fetch = vi + .fn() .mockResolvedValueOnce(mockAuthResponse as Response) .mockResolvedValue(mockPendingResponse as Response); @@ -934,6 +1022,8 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { module.getQwenOAuthClient(mockConfig), ), ).rejects.toThrow('Qwen OAuth authentication timed out'); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle authentication failure reason - rate limit', async () => { @@ -942,6 +1032,16 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + // Mock device authorization to succeed but polling to get rate limited const mockAuthResponse = { ok: true, @@ -961,7 +1061,8 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { text: async () => 'Rate limited', }; - vi.mocked(global.fetch) + global.fetch = vi + .fn() .mockResolvedValueOnce(mockAuthResponse as Response) .mockResolvedValue(mockRateLimitResponse as Response); @@ -972,6 +1073,8 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { ).rejects.toThrow( 'Too many request for Qwen OAuth authentication, please try again later.', ); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle authentication failure reason - error', async () => { @@ -980,6 +1083,16 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + // Mock device authorization to fail const mockAuthResponse = { ok: true, @@ -989,13 +1102,15 @@ describe('getQwenOAuthClient - Enhanced Error Scenarios', () => { }), }; - vi.mocked(global.fetch).mockResolvedValue(mockAuthResponse as Response); + global.fetch = vi.fn().mockResolvedValue(mockAuthResponse as Response); await expect( import('./qwenOAuth2.js').then((module) => module.getQwenOAuthClient(mockConfig), ), ).rejects.toThrow('Qwen OAuth authentication failed'); + + SharedTokenManager.getInstance = originalGetInstance; }); }); @@ -1005,11 +1120,9 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { beforeEach(() => { mockConfig = { - getProxy: vi.fn().mockReturnValue(undefined), isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false), } as unknown as Config; - new QwenOAuth2Client({ proxy: undefined }); originalFetch = global.fetch; global.fetch = vi.fn(); @@ -1029,6 +1142,16 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + const mockAuthResponse = { ok: true, json: async () => ({ @@ -1037,13 +1160,15 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { }), }; - vi.mocked(global.fetch).mockResolvedValue(mockAuthResponse as Response); + global.fetch = vi.fn().mockResolvedValue(mockAuthResponse as Response); await expect( import('./qwenOAuth2.js').then((module) => module.getQwenOAuthClient(mockConfig), ), ).rejects.toThrow('Qwen OAuth authentication failed'); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle successful authentication flow', async () => { @@ -1091,6 +1216,16 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + const mockAuthResponse = { ok: true, json: async () => ({ @@ -1109,7 +1244,8 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { text: async () => 'Device code expired', }; - vi.mocked(global.fetch) + global.fetch = vi + .fn() .mockResolvedValueOnce(mockAuthResponse as Response) .mockResolvedValue(mock401Response as Response); @@ -1118,6 +1254,8 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { module.getQwenOAuthClient(mockConfig), ), ).rejects.toThrow('Qwen OAuth authentication failed'); + + SharedTokenManager.getInstance = originalGetInstance; }); it('should handle token polling with browser launch suppressed', async () => { @@ -1126,6 +1264,16 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { new Error('No cached credentials'), ); + // Mock SharedTokenManager to fail initially so device flow is used + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi.fn().mockReturnValue(mockTokenManager); + // Mock browser launch as suppressed mockConfig.isBrowserLaunchSuppressed = vi.fn().mockReturnValue(true); @@ -1151,7 +1299,8 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { }), }; - vi.mocked(global.fetch) + global.fetch = vi + .fn() .mockResolvedValueOnce(mockAuthResponse as Response) .mockResolvedValue(mockTokenResponse as Response); @@ -1161,6 +1310,8 @@ describe('authWithQwenDeviceFlow - Comprehensive Testing', () => { expect(client).toBeInstanceOf(Object); expect(mockConfig.isBrowserLaunchSuppressed).toHaveBeenCalled(); + + SharedTokenManager.getInstance = originalGetInstance; }); }); @@ -1170,7 +1321,6 @@ describe('Browser Launch and Error Handling', () => { beforeEach(() => { mockConfig = { - getProxy: vi.fn().mockReturnValue(undefined), isBrowserLaunchSuppressed: vi.fn().mockReturnValue(false), } as unknown as Config; @@ -1295,3 +1445,833 @@ describe('Event Emitter Integration', () => { expect(QwenOAuth2Event.AuthCancel).toBe('auth-cancel'); }); }); + +describe('Utility Functions', () => { + describe('objectToUrlEncoded', () => { + it('should encode object properties to URL-encoded format', async () => { + // Since objectToUrlEncoded is private, we test it indirectly through the client + const objectToUrlEncoded = (data: Record): string => + Object.keys(data) + .map( + (key) => + `${encodeURIComponent(key)}=${encodeURIComponent(data[key])}`, + ) + .join('&'); + + const testData = { + client_id: 'test-client', + scope: 'openid profile', + redirect_uri: 'https://example.com/callback', + }; + + const result = objectToUrlEncoded(testData); + + expect(result).toContain('client_id=test-client'); + expect(result).toContain('scope=openid%20profile'); + expect(result).toContain( + 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback', + ); + }); + + it('should handle special characters', async () => { + const objectToUrlEncoded = (data: Record): string => + Object.keys(data) + .map( + (key) => + `${encodeURIComponent(key)}=${encodeURIComponent(data[key])}`, + ) + .join('&'); + + const testData = { + 'param with spaces': 'value with spaces', + 'param&with&s': 'value&with&s', + 'param=with=equals': 'value=with=equals', + }; + + const result = objectToUrlEncoded(testData); + + expect(result).toContain('param%20with%20spaces=value%20with%20spaces'); + expect(result).toContain('param%26with%26amps=value%26with%26amps'); + expect(result).toContain('param%3Dwith%3Dequals=value%3Dwith%3Dequals'); + }); + + it('should handle empty object', async () => { + const objectToUrlEncoded = (data: Record): string => + Object.keys(data) + .map( + (key) => + `${encodeURIComponent(key)}=${encodeURIComponent(data[key])}`, + ) + .join('&'); + + const result = objectToUrlEncoded({}); + expect(result).toBe(''); + }); + }); + + describe('getQwenCachedCredentialPath', () => { + it('should return correct path to cached credentials', async () => { + const os = await import('os'); + const path = await import('path'); + + const expectedPath = path.join(os.homedir(), '.qwen', 'oauth_creds.json'); + + // Since this is a private function, we test it indirectly through clearQwenCredentials + const { promises: fs } = await import('node:fs'); + const { clearQwenCredentials } = await import('./qwenOAuth2.js'); + + vi.mocked(fs.unlink).mockResolvedValue(undefined); + + await clearQwenCredentials(); + + expect(fs.unlink).toHaveBeenCalledWith(expectedPath); + }); + }); +}); + +describe('Credential Caching Functions', () => { + describe('cacheQwenCredentials', () => { + it('should create directory and write credentials to file', async () => { + // Mock the internal cacheQwenCredentials function by creating client and calling refresh + const client = new QwenOAuth2Client(); + client.setCredentials({ + refresh_token: 'test-refresh', + }); + + const mockResponse = { + ok: true, + json: async () => ({ + access_token: 'new-token', + token_type: 'Bearer', + expires_in: 3600, + }), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse as Response); + + await client.refreshAccessToken(); + + // Note: File caching is now handled by SharedTokenManager, so these calls won't happen + // This test verifies that refreshAccessToken works correctly + const updatedCredentials = client.getCredentials(); + expect(updatedCredentials.access_token).toBe('new-token'); + }); + }); + + describe('loadCachedQwenCredentials', () => { + it('should load and validate cached credentials successfully', async () => { + const { promises: fs } = await import('node:fs'); + const mockCredentials = { + access_token: 'cached-token', + refresh_token: 'cached-refresh', + token_type: 'Bearer', + expiry_date: Date.now() + 3600000, + }; + + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(mockCredentials)); + + // Test through getQwenOAuthClient which calls loadCachedQwenCredentials + const mockConfig = { + isBrowserLaunchSuppressed: vi.fn().mockReturnValue(true), + } as unknown as Config; + + // Make SharedTokenManager fail to test the fallback + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No cached creds')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + // Mock successful auth flow after cache load fails + const mockAuthResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + const mockTokenResponse = { + ok: true, + json: async () => ({ + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + token_type: 'Bearer', + expires_in: 3600, + scope: 'openid profile email model.completion', + }), + }; + + global.fetch = vi + .fn() + .mockResolvedValueOnce(mockAuthResponse as Response) + .mockResolvedValue(mockTokenResponse as Response); + + try { + await import('./qwenOAuth2.js').then((module) => + module.getQwenOAuthClient(mockConfig), + ); + } catch { + // Expected to fail in test environment + } + + expect(fs.readFile).toHaveBeenCalled(); + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should handle invalid cached credentials gracefully', async () => { + const { promises: fs } = await import('node:fs'); + + // Mock file read to return invalid JSON + vi.mocked(fs.readFile).mockResolvedValue('invalid-json'); + + const mockConfig = { + isBrowserLaunchSuppressed: vi.fn().mockReturnValue(true), + } as unknown as Config; + + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No cached creds')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + // Mock auth flow + const mockAuthResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + const mockTokenResponse = { + ok: true, + json: async () => ({ + access_token: 'new-token', + refresh_token: 'new-refresh', + token_type: 'Bearer', + expires_in: 3600, + }), + }; + + global.fetch = vi + .fn() + .mockResolvedValueOnce(mockAuthResponse as Response) + .mockResolvedValue(mockTokenResponse as Response); + + try { + await import('./qwenOAuth2.js').then((module) => + module.getQwenOAuthClient(mockConfig), + ); + } catch { + // Expected to fail in test environment + } + + SharedTokenManager.getInstance = originalGetInstance; + }); + + it('should handle file access errors', async () => { + const { promises: fs } = await import('node:fs'); + + vi.mocked(fs.readFile).mockRejectedValue(new Error('File not found')); + + const mockConfig = { + isBrowserLaunchSuppressed: vi.fn().mockReturnValue(true), + } as unknown as Config; + + const mockTokenManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No cached creds')), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + // Mock device flow to fail quickly + const mockAuthResponse = { + ok: true, + json: async () => ({ + error: 'invalid_request', + error_description: 'Invalid request parameters', + }), + }; + + global.fetch = vi.fn().mockResolvedValue(mockAuthResponse as Response); + + // Should proceed to device flow when cache loading fails + try { + await import('./qwenOAuth2.js').then((module) => + module.getQwenOAuthClient(mockConfig), + ); + } catch { + // Expected to fail in test environment + } + + SharedTokenManager.getInstance = originalGetInstance; + }); + }); +}); + +describe('Enhanced Error Handling and Edge Cases', () => { + let client: QwenOAuth2Client; + let originalFetch: typeof global.fetch; + + beforeEach(() => { + client = new QwenOAuth2Client(); + originalFetch = global.fetch; + global.fetch = vi.fn(); + }); + + afterEach(() => { + global.fetch = originalFetch; + vi.clearAllMocks(); + }); + + describe('QwenOAuth2Client getAccessToken enhanced scenarios', () => { + it('should handle SharedTokenManager failure and fall back to cached token', async () => { + // Set up client with valid credentials + client.setCredentials({ + access_token: 'fallback-token', + expiry_date: Date.now() + 3600000, // Valid for 1 hour + }); + + // Override the client's SharedTokenManager instance directly to ensure it fails + ( + client as unknown as { + sharedManager: { + getValidCredentials: () => Promise; + }; + } + ).sharedManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Manager failed')), + }; + + // Mock console.warn to avoid test noise + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + const result = await client.getAccessToken(); + + expect(result.token).toBe('fallback-token'); + expect(consoleSpy).toHaveBeenCalledWith( + 'Failed to get access token from shared manager:', + expect.any(Error), + ); + + consoleSpy.mockRestore(); + }); + + it('should return undefined when both manager and cache fail', async () => { + // Set up client with expired credentials + client.setCredentials({ + access_token: 'expired-token', + expiry_date: Date.now() - 1000, // Expired + }); + + // Override the client's SharedTokenManager instance directly to ensure it fails + ( + client as unknown as { + sharedManager: { + getValidCredentials: () => Promise; + }; + } + ).sharedManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('Manager failed')), + }; + + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + const result = await client.getAccessToken(); + + expect(result.token).toBeUndefined(); + + consoleSpy.mockRestore(); + }); + + it('should handle missing credentials gracefully', async () => { + // No credentials set + client.setCredentials({}); + + // Override the client's SharedTokenManager instance directly to ensure it fails + ( + client as unknown as { + sharedManager: { + getValidCredentials: () => Promise; + }; + } + ).sharedManager = { + getValidCredentials: vi + .fn() + .mockRejectedValue(new Error('No credentials')), + }; + + const consoleSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + const result = await client.getAccessToken(); + + expect(result.token).toBeUndefined(); + + consoleSpy.mockRestore(); + }); + }); + + describe('Enhanced requestDeviceAuthorization scenarios', () => { + it('should include x-request-id header', async () => { + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'openid profile email model.completion', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + expect(global.fetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + headers: expect.objectContaining({ + 'x-request-id': expect.any(String), + }), + }), + ); + }); + + it('should include correct Content-Type and Accept headers', async () => { + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'openid profile email model.completion', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + expect(global.fetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + headers: expect.objectContaining({ + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }), + }), + ); + }); + + it('should send correct form data', async () => { + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'test-scope', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + const [, options] = vi.mocked(global.fetch).mock.calls[0]; + expect(options?.body).toContain( + 'client_id=f0304373b74a44d2b584a3fb70ca9e56', + ); + expect(options?.body).toContain('scope=test-scope'); + expect(options?.body).toContain('code_challenge=test-challenge'); + expect(options?.body).toContain('code_challenge_method=S256'); + }); + }); + + describe('Enhanced pollDeviceToken scenarios', () => { + it('should handle JSON parsing error during error response', async () => { + const mockResponse = { + ok: false, + status: 400, + statusText: 'Bad Request', + json: vi.fn().mockRejectedValue(new Error('Invalid JSON')), + text: vi.fn().mockResolvedValue('Invalid request format'), + }; + + vi.mocked(global.fetch).mockResolvedValue( + mockResponse as unknown as Response, + ); + + await expect( + client.pollDeviceToken({ + device_code: 'test-device-code', + code_verifier: 'test-verifier', + }), + ).rejects.toThrow('Device token poll failed: 400 Bad Request'); + }); + + it('should include status code in thrown errors', async () => { + const mockResponse = { + ok: false, + status: 500, + statusText: 'Internal Server Error', + json: vi.fn().mockRejectedValue(new Error('Invalid JSON')), + text: vi.fn().mockResolvedValue('Internal server error'), + }; + + global.fetch = vi + .fn() + .mockResolvedValue(mockResponse as unknown as Response); + + await expect( + client.pollDeviceToken({ + device_code: 'test-device-code', + code_verifier: 'test-verifier', + }), + ).rejects.toMatchObject({ + message: expect.stringContaining( + 'Device token poll failed: 500 Internal Server Error', + ), + status: 500, + }); + }); + + it('should handle authorization_pending with correct status', async () => { + const mockResponse = { + ok: false, + status: 400, + statusText: 'Bad Request', + json: vi.fn().mockResolvedValue({ + error: 'authorization_pending', + error_description: 'Authorization request is pending', + }), + }; + + vi.mocked(global.fetch).mockResolvedValue( + mockResponse as unknown as Response, + ); + + const result = await client.pollDeviceToken({ + device_code: 'test-device-code', + code_verifier: 'test-verifier', + }); + + expect(result).toEqual({ status: 'pending' }); + }); + }); + + describe('Enhanced refreshAccessToken scenarios', () => { + it('should call clearQwenCredentials on 400 error', async () => { + client.setCredentials({ + refresh_token: 'expired-refresh', + }); + + const { promises: fs } = await import('node:fs'); + vi.mocked(fs.unlink).mockResolvedValue(undefined); + + const mockResponse = { + ok: false, + status: 400, + text: async () => 'Bad Request', + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await expect(client.refreshAccessToken()).rejects.toThrow( + "Refresh token expired or invalid. Please use '/auth' to re-authenticate.", + ); + + expect(fs.unlink).toHaveBeenCalled(); + }); + + it('should preserve existing refresh token when new one not provided', async () => { + const originalRefreshToken = 'original-refresh-token'; + client.setCredentials({ + refresh_token: originalRefreshToken, + }); + + const mockResponse = { + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + // No refresh_token in response + }), + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await client.refreshAccessToken(); + + const credentials = client.getCredentials(); + expect(credentials.refresh_token).toBe(originalRefreshToken); + }); + + it('should include resource_url when provided in response', async () => { + client.setCredentials({ + refresh_token: 'test-refresh', + }); + + const mockResponse = { + ok: true, + json: async () => ({ + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + resource_url: 'https://new-resource-url.com', + }), + }; + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as Response); + + await client.refreshAccessToken(); + + const credentials = client.getCredentials(); + expect(credentials.resource_url).toBe('https://new-resource-url.com'); + }); + }); + + describe('isTokenValid edge cases', () => { + it('should return false for tokens expiring within buffer time', () => { + const nearExpiryTime = Date.now() + 15000; // 15 seconds from now (within 30s buffer) + + client.setCredentials({ + access_token: 'test-token', + expiry_date: nearExpiryTime, + }); + + const isValid = ( + client as unknown as { isTokenValid(): boolean } + ).isTokenValid(); + expect(isValid).toBe(false); + }); + + it('should return true for tokens expiring well beyond buffer time', () => { + const futureExpiryTime = Date.now() + 120000; // 2 minutes from now (beyond 30s buffer) + + client.setCredentials({ + access_token: 'test-token', + expiry_date: futureExpiryTime, + }); + + const isValid = ( + client as unknown as { isTokenValid(): boolean } + ).isTokenValid(); + expect(isValid).toBe(true); + }); + }); +}); + +describe('SharedTokenManager Integration in QwenOAuth2Client', () => { + let client: QwenOAuth2Client; + + beforeEach(() => { + client = new QwenOAuth2Client(); + }); + + it('should use SharedTokenManager instance in constructor', () => { + const sharedManager = ( + client as unknown as { sharedManager: MockSharedTokenManager } + ).sharedManager; + expect(sharedManager).toBeDefined(); + }); + + it('should handle TokenManagerError types correctly in getQwenOAuthClient', async () => { + const mockConfig = { + isBrowserLaunchSuppressed: vi.fn().mockReturnValue(true), + } as unknown as Config; + + // Test different TokenManagerError types + const tokenErrors = [ + { type: TokenError.NO_REFRESH_TOKEN, message: 'No refresh token' }, + { type: TokenError.REFRESH_FAILED, message: 'Token refresh failed' }, + { type: TokenError.NETWORK_ERROR, message: 'Network error' }, + { type: TokenError.REFRESH_FAILED, message: 'Refresh failed' }, + ]; + + for (const errorInfo of tokenErrors) { + const tokenError = new TokenManagerError( + errorInfo.type, + errorInfo.message, + ); + + const mockTokenManager = { + getValidCredentials: vi.fn().mockRejectedValue(tokenError), + }; + + const originalGetInstance = SharedTokenManager.getInstance; + SharedTokenManager.getInstance = vi + .fn() + .mockReturnValue(mockTokenManager); + + const { promises: fs } = await import('node:fs'); + vi.mocked(fs.readFile).mockRejectedValue(new Error('No cached file')); + + // Mock device flow to succeed + const mockAuthResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + const mockTokenResponse = { + ok: true, + json: async () => ({ + access_token: 'new-token', + refresh_token: 'new-refresh', + token_type: 'Bearer', + expires_in: 3600, + }), + }; + + global.fetch = vi + .fn() + .mockResolvedValueOnce(mockAuthResponse as Response) + .mockResolvedValue(mockTokenResponse as Response); + + try { + await import('./qwenOAuth2.js').then((module) => + module.getQwenOAuthClient(mockConfig), + ); + } catch { + // Expected to fail in test environment + } + + SharedTokenManager.getInstance = originalGetInstance; + vi.clearAllMocks(); + } + }); +}); + +describe('Constants and Configuration', () => { + it('should have correct OAuth endpoints', async () => { + // Test that the constants are properly defined by checking they're used in requests + const client = new QwenOAuth2Client(); + + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'test-scope', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + const [url] = vi.mocked(global.fetch).mock.calls[0]; + expect(url).toBe('https://chat.qwen.ai/api/v1/oauth2/device/code'); + }); + + it('should use correct client ID in requests', async () => { + const client = new QwenOAuth2Client(); + + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'test-scope', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + const [, options] = vi.mocked(global.fetch).mock.calls[0]; + expect(options?.body).toContain( + 'client_id=f0304373b74a44d2b584a3fb70ca9e56', + ); + }); + + it('should use correct default scope', async () => { + // Test the default scope constant by checking it's used in device flow + const client = new QwenOAuth2Client(); + + const mockResponse = { + ok: true, + json: async () => ({ + device_code: 'test-device-code', + user_code: 'TEST123', + verification_uri: 'https://chat.qwen.ai/device', + verification_uri_complete: 'https://chat.qwen.ai/device?code=TEST123', + expires_in: 1800, + }), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse as Response); + + await client.requestDeviceAuthorization({ + scope: 'openid profile email model.completion', + code_challenge: 'test-challenge', + code_challenge_method: 'S256', + }); + + const [, options] = vi.mocked(global.fetch).mock.calls[0]; + expect(options?.body).toContain( + 'scope=openid%20profile%20email%20model.completion', + ); + }); +}); diff --git a/packages/core/src/qwen/qwenOAuth2.ts b/packages/core/src/qwen/qwenOAuth2.ts index 58692117e..592a6dd74 100644 --- a/packages/core/src/qwen/qwenOAuth2.ts +++ b/packages/core/src/qwen/qwenOAuth2.ts @@ -13,6 +13,11 @@ import open from 'open'; import { EventEmitter } from 'events'; import { Config } from '../config/config.js'; import { randomUUID } from 'node:crypto'; +import { + SharedTokenManager, + TokenManagerError, + TokenError, +} from './sharedTokenManager.js'; // OAuth Endpoints const QWEN_OAUTH_BASE_URL = 'https://chat.qwen.ai'; @@ -234,8 +239,11 @@ export interface IQwenOAuth2Client { */ export class QwenOAuth2Client implements IQwenOAuth2Client { private credentials: QwenCredentials = {}; + private sharedManager: SharedTokenManager; - constructor(_options?: { proxy?: string }) {} + constructor() { + this.sharedManager = SharedTokenManager.getInstance(); + } setCredentials(credentials: QwenCredentials): void { this.credentials = credentials; @@ -246,17 +254,23 @@ export class QwenOAuth2Client implements IQwenOAuth2Client { } async getAccessToken(): Promise<{ token?: string }> { - if (this.credentials.access_token && this.isTokenValid()) { - return { token: this.credentials.access_token }; - } + try { + // Use shared manager to get valid credentials with cross-session synchronization + const credentials = await this.sharedManager.getValidCredentials(this); + return { token: credentials.access_token }; + } catch (error) { + console.warn('Failed to get access token from shared manager:', error); + + // Only return cached token if it's still valid, don't refresh uncoordinated + // This prevents the cross-session token invalidation issue + if (this.credentials.access_token && this.isTokenValid()) { + return { token: this.credentials.access_token }; + } - if (this.credentials.refresh_token) { - const refreshResponse = await this.refreshAccessToken(); - const tokenData = refreshResponse as TokenRefreshData; - return { token: tokenData.access_token }; + // If we can't get valid credentials through shared manager, fail gracefully + // All token refresh operations should go through the SharedTokenManager + return { token: undefined }; } - - return { token: undefined }; } async requestDeviceAuthorization(options: { @@ -289,7 +303,7 @@ export class QwenOAuth2Client implements IQwenOAuth2Client { } const result = (await response.json()) as DeviceAuthorizationResponse; - console.log('Device authorization result:', result); + console.debug('Device authorization result:', result); // Check if the response indicates success if (!isDeviceAuthorizationSuccess(result)) { @@ -423,8 +437,8 @@ export class QwenOAuth2Client implements IQwenOAuth2Client { this.setCredentials(tokens); - // Cache the updated credentials to file - await cacheQwenCredentials(tokens); + // Note: File caching is now handled by SharedTokenManager + // to prevent cross-session token invalidation issues return responseData; } @@ -462,68 +476,85 @@ export const qwenOAuth2Events = new EventEmitter(); export async function getQwenOAuthClient( config: Config, ): Promise { - const client = new QwenOAuth2Client({ - proxy: config.getProxy(), - }); + const client = new QwenOAuth2Client(); - // If there are cached creds on disk, they always take precedence - if (await loadCachedQwenCredentials(client)) { - console.log('Loaded cached Qwen credentials.'); + // Use shared token manager to get valid credentials with cross-session synchronization + const sharedManager = SharedTokenManager.getInstance(); - try { - await client.refreshAccessToken(); - return client; - } catch (error: unknown) { - // Handle refresh token errors - const errorMessage = - error instanceof Error ? error.message : String(error); + try { + // Try to get valid credentials from shared cache first + const credentials = await sharedManager.getValidCredentials(client); + client.setCredentials(credentials); + return client; + } catch (error: unknown) { + console.debug( + 'Shared token manager failed, attempting device flow:', + error, + ); - const isInvalidToken = errorMessage.includes( - 'Refresh token expired or invalid', - ); - const userMessage = isInvalidToken - ? 'Cached credentials are invalid. Please re-authenticate.' - : `Token refresh failed: ${errorMessage}`; - const throwMessage = isInvalidToken - ? 'Cached Qwen credentials are invalid. Please re-authenticate.' - : `Qwen token refresh failed: ${errorMessage}`; - - // Emit token refresh error event - qwenOAuth2Events.emit(QwenOAuth2Event.AuthProgress, 'error', userMessage); - throw new Error(throwMessage); + // Handle specific token manager errors + if (error instanceof TokenManagerError) { + switch (error.type) { + case TokenError.NO_REFRESH_TOKEN: + console.debug( + 'No refresh token available, proceeding with device flow', + ); + break; + case TokenError.REFRESH_FAILED: + console.debug('Token refresh failed, proceeding with device flow'); + break; + case TokenError.NETWORK_ERROR: + console.warn( + 'Network error during token refresh, trying device flow', + ); + break; + default: + console.warn('Token manager error:', (error as Error).message); + } } - } - // Use device authorization flow for authentication (single attempt) - const result = await authWithQwenDeviceFlow(client, config); - if (!result.success) { - // Only emit timeout event if the failure reason is actually timeout - // Other error types (401, 429, etc.) have already emitted their specific events - if (result.reason === 'timeout') { - qwenOAuth2Events.emit( - QwenOAuth2Event.AuthProgress, - 'timeout', - 'Authentication timed out. Please try again or select a different authentication method.', - ); + // If shared manager fails, check if we have cached credentials for device flow + if (await loadCachedQwenCredentials(client)) { + // We have cached credentials but they might be expired + // Try device flow instead of forcing refresh + const result = await authWithQwenDeviceFlow(client, config); + if (!result.success) { + throw new Error('Qwen OAuth authentication failed'); + } + return client; } - // Throw error with appropriate message based on failure reason - switch (result.reason) { - case 'timeout': - throw new Error('Qwen OAuth authentication timed out'); - case 'cancelled': - throw new Error('Qwen OAuth authentication was cancelled by user'); - case 'rate_limit': - throw new Error( - 'Too many request for Qwen OAuth authentication, please try again later.', + // No cached credentials, use device authorization flow for authentication + const result = await authWithQwenDeviceFlow(client, config); + if (!result.success) { + // Only emit timeout event if the failure reason is actually timeout + // Other error types (401, 429, etc.) have already emitted their specific events + if (result.reason === 'timeout') { + qwenOAuth2Events.emit( + QwenOAuth2Event.AuthProgress, + 'timeout', + 'Authentication timed out. Please try again or select a different authentication method.', ); - case 'error': - default: - throw new Error('Qwen OAuth authentication failed'); + } + + // Throw error with appropriate message based on failure reason + switch (result.reason) { + case 'timeout': + throw new Error('Qwen OAuth authentication timed out'); + case 'cancelled': + throw new Error('Qwen OAuth authentication was cancelled by user'); + case 'rate_limit': + throw new Error( + 'Too many request for Qwen OAuth authentication, please try again later.', + ); + case 'error': + default: + throw new Error('Qwen OAuth authentication failed'); + } } - } - return client; + return client; + } } async function authWithQwenDeviceFlow( @@ -580,7 +611,9 @@ async function authWithQwenDeviceFlow( // causing the entire Node.js process to crash. if (childProcess) { childProcess.on('error', () => { - console.log('Failed to open browser. Visit this URL to authorize:'); + console.debug( + 'Failed to open browser. Visit this URL to authorize:', + ); showFallbackMessage(); }); } @@ -599,7 +632,7 @@ async function authWithQwenDeviceFlow( 'Waiting for authorization...', ); - console.log('Waiting for authorization...\n'); + console.debug('Waiting for authorization...\n'); // Poll for the token let pollInterval = 2000; // 2 seconds, can be increased if slow_down is received @@ -610,7 +643,7 @@ async function authWithQwenDeviceFlow( for (let attempt = 0; attempt < maxAttempts; attempt++) { // Check if authentication was cancelled if (isCancelled) { - console.log('\nAuthentication cancelled by user.'); + console.debug('\nAuthentication cancelled by user.'); qwenOAuth2Events.emit( QwenOAuth2Event.AuthProgress, 'error', @@ -620,7 +653,7 @@ async function authWithQwenDeviceFlow( } try { - console.log('polling for token...'); + console.debug('polling for token...'); const tokenResponse = await client.pollDeviceToken({ device_code: deviceAuth.device_code, code_verifier, @@ -653,7 +686,7 @@ async function authWithQwenDeviceFlow( 'Authentication successful! Access token obtained.', ); - console.log('Authentication successful! Access token obtained.'); + console.debug('Authentication successful! Access token obtained.'); return { success: true }; } @@ -664,8 +697,8 @@ async function authWithQwenDeviceFlow( // Handle slow_down error by increasing poll interval if (pendingData.slowDown) { pollInterval = Math.min(pollInterval * 1.5, 10000); // Increase by 50%, max 10 seconds - console.log( - `\nServer requested to slow down, increasing poll interval to ${pollInterval}ms`, + console.debug( + `\nServer requested to slow down, increasing poll interval to ${pollInterval}ms'`, ); } else { pollInterval = 2000; // Reset to default interval @@ -706,7 +739,7 @@ async function authWithQwenDeviceFlow( // Check for cancellation after waiting if (isCancelled) { - console.log('\nAuthentication cancelled by user.'); + console.debug('\nAuthentication cancelled by user.'); qwenOAuth2Events.emit( QwenOAuth2Event.AuthProgress, 'error', @@ -834,7 +867,7 @@ export async function clearQwenCredentials(): Promise { try { const filePath = getQwenCachedCredentialPath(); await fs.unlink(filePath); - console.log('Cached Qwen credentials cleared successfully.'); + console.debug('Cached Qwen credentials cleared successfully.'); } catch (error: unknown) { // If file doesn't exist or can't be deleted, we consider it cleared if (error instanceof Error && 'code' in error && error.code === 'ENOENT') { diff --git a/packages/core/src/qwen/sharedTokenManager.test.ts b/packages/core/src/qwen/sharedTokenManager.test.ts new file mode 100644 index 000000000..fbd068fa7 --- /dev/null +++ b/packages/core/src/qwen/sharedTokenManager.test.ts @@ -0,0 +1,758 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + * + */ + +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { promises as fs, unlinkSync, type Stats } from 'node:fs'; +import * as os from 'os'; +import path from 'node:path'; + +import { + SharedTokenManager, + TokenManagerError, + TokenError, +} from './sharedTokenManager.js'; +import type { + IQwenOAuth2Client, + QwenCredentials, + TokenRefreshData, + ErrorData, +} from './qwenOAuth2.js'; + +// Mock external dependencies +vi.mock('node:fs', () => ({ + promises: { + stat: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + mkdir: vi.fn(), + unlink: vi.fn(), + }, + unlinkSync: vi.fn(), +})); + +vi.mock('node:os', () => ({ + homedir: vi.fn(), +})); + +vi.mock('node:path', () => ({ + default: { + join: vi.fn(), + dirname: vi.fn(), + }, +})); + +/** + * Helper to access private properties for testing + */ +function getPrivateProperty(obj: unknown, property: string): T { + return (obj as Record)[property]; +} + +/** + * Helper to set private properties for testing + */ +function setPrivateProperty(obj: unknown, property: string, value: T): void { + (obj as Record)[property] = value; +} + +/** + * Creates a mock QwenOAuth2Client for testing + */ +function createMockQwenClient( + initialCredentials: Partial = {}, +): IQwenOAuth2Client { + let credentials: QwenCredentials = { + access_token: 'mock_access_token', + refresh_token: 'mock_refresh_token', + token_type: 'Bearer', + expiry_date: Date.now() + 3600000, // 1 hour from now + resource_url: 'https://api.example.com', + ...initialCredentials, + }; + + return { + setCredentials: vi.fn((creds: QwenCredentials) => { + credentials = { ...credentials, ...creds }; + }), + getCredentials: vi.fn(() => credentials), + getAccessToken: vi.fn(), + requestDeviceAuthorization: vi.fn(), + pollDeviceToken: vi.fn(), + refreshAccessToken: vi.fn(), + }; +} + +/** + * Creates valid mock credentials + */ +function createValidCredentials( + overrides: Partial = {}, +): QwenCredentials { + return { + access_token: 'valid_access_token', + refresh_token: 'valid_refresh_token', + token_type: 'Bearer', + expiry_date: Date.now() + 3600000, // 1 hour from now + resource_url: 'https://api.example.com', + ...overrides, + }; +} + +/** + * Creates expired mock credentials + */ +function createExpiredCredentials( + overrides: Partial = {}, +): QwenCredentials { + return { + access_token: 'expired_access_token', + refresh_token: 'expired_refresh_token', + token_type: 'Bearer', + expiry_date: Date.now() - 3600000, // 1 hour ago + resource_url: 'https://api.example.com', + ...overrides, + }; +} + +/** + * Creates a successful token refresh response + */ +function createSuccessfulRefreshResponse( + overrides: Partial = {}, +): TokenRefreshData { + return { + access_token: 'fresh_access_token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new_refresh_token', + resource_url: 'https://api.example.com', + ...overrides, + }; +} + +/** + * Creates an error response + */ +function createErrorResponse( + error = 'invalid_grant', + description = 'Token expired', +): ErrorData { + return { + error, + error_description: description, + }; +} + +describe('SharedTokenManager', () => { + let tokenManager: SharedTokenManager; + + // Get mocked modules + const mockFs = vi.mocked(fs); + const mockOs = vi.mocked(os); + const mockPath = vi.mocked(path); + const mockUnlinkSync = vi.mocked(unlinkSync); + + beforeEach(() => { + // Clean up any existing instance's listeners first + const existingInstance = getPrivateProperty( + SharedTokenManager, + 'instance', + ) as SharedTokenManager; + if (existingInstance) { + existingInstance.cleanup(); + } + + // Reset all mocks + vi.clearAllMocks(); + + // Setup default mock implementations + mockOs.homedir.mockReturnValue('/home/user'); + mockPath.join.mockImplementation((...args) => args.join('/')); + mockPath.dirname.mockImplementation((filePath) => { + // Handle undefined/null input gracefully + if (!filePath || typeof filePath !== 'string') { + return '/home/user/.qwen'; // Return the expected directory path + } + const parts = filePath.split('/'); + const result = parts.slice(0, -1).join('/'); + return result || '/'; + }); + + // Reset singleton instance for each test + setPrivateProperty(SharedTokenManager, 'instance', null); + tokenManager = SharedTokenManager.getInstance(); + }); + + afterEach(() => { + // Clean up listeners after each test + if (tokenManager) { + tokenManager.cleanup(); + } + }); + + describe('Singleton Pattern', () => { + it('should return the same instance when called multiple times', () => { + const instance1 = SharedTokenManager.getInstance(); + const instance2 = SharedTokenManager.getInstance(); + + expect(instance1).toBe(instance2); + expect(instance1).toBe(tokenManager); + }); + + it('should create a new instance after reset', () => { + const instance1 = SharedTokenManager.getInstance(); + + // Reset singleton for testing + setPrivateProperty(SharedTokenManager, 'instance', null); + const instance2 = SharedTokenManager.getInstance(); + + expect(instance1).not.toBe(instance2); + }); + }); + + describe('getValidCredentials', () => { + it('should return valid cached credentials without refresh', async () => { + const mockClient = createMockQwenClient(); + const validCredentials = createValidCredentials(); + + // Mock file operations to indicate no file changes + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + + // Manually set cached credentials + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + fileModTime: number; + lastCheck: number; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = validCredentials; + memoryCache.fileModTime = 1000; + memoryCache.lastCheck = Date.now(); + + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result).toEqual(validCredentials); + expect(mockClient.refreshAccessToken).not.toHaveBeenCalled(); + }); + + it('should refresh expired credentials', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const refreshResponse = createSuccessfulRefreshResponse(); + + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(refreshResponse); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result.access_token).toBe(refreshResponse.access_token); + expect(mockClient.refreshAccessToken).toHaveBeenCalled(); + expect(mockClient.setCredentials).toHaveBeenCalled(); + }); + + it('should force refresh when forceRefresh is true', async () => { + const mockClient = createMockQwenClient(createValidCredentials()); + const refreshResponse = createSuccessfulRefreshResponse(); + + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(refreshResponse); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + const result = await tokenManager.getValidCredentials(mockClient, true); + + expect(result.access_token).toBe(refreshResponse.access_token); + expect(mockClient.refreshAccessToken).toHaveBeenCalled(); + }); + + it('should throw TokenManagerError when refresh token is missing', async () => { + const mockClient = createMockQwenClient({ + access_token: 'expired_token', + refresh_token: undefined, // No refresh token + expiry_date: Date.now() - 3600000, + }); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow('No refresh token available'); + }); + + it('should throw TokenManagerError when refresh fails', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const errorResponse = createErrorResponse(); + + mockClient.refreshAccessToken = vi.fn().mockResolvedValue(errorResponse); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + }); + + it('should handle network errors during refresh', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const networkError = new Error('Network request failed'); + + mockClient.refreshAccessToken = vi.fn().mockRejectedValue(networkError); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + }); + + it('should wait for ongoing refresh and return same result', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const refreshResponse = createSuccessfulRefreshResponse(); + + // Create a delayed refresh response + let resolveRefresh: (value: TokenRefreshData) => void; + const refreshPromise = new Promise((resolve) => { + resolveRefresh = resolve; + }); + + mockClient.refreshAccessToken = vi.fn().mockReturnValue(refreshPromise); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + // Start two concurrent refresh operations + const promise1 = tokenManager.getValidCredentials(mockClient); + const promise2 = tokenManager.getValidCredentials(mockClient); + + // Resolve the refresh + resolveRefresh!(refreshResponse); + + const [result1, result2] = await Promise.all([promise1, promise2]); + + expect(result1).toEqual(result2); + expect(mockClient.refreshAccessToken).toHaveBeenCalledTimes(1); + }); + + it('should reload credentials from file when file is modified', async () => { + const mockClient = createMockQwenClient(); + const fileCredentials = createValidCredentials({ + access_token: 'file_access_token', + }); + + // Mock file operations to simulate file modification + mockFs.stat.mockResolvedValue({ mtimeMs: 2000 } as Stats); + mockFs.readFile.mockResolvedValue(JSON.stringify(fileCredentials)); + + // Set initial cache state + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ fileModTime: number }>( + tokenManager, + 'memoryCache', + ); + memoryCache.fileModTime = 1000; // Older than file + + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result.access_token).toBe('file_access_token'); + expect(mockFs.readFile).toHaveBeenCalled(); + }); + }); + + describe('Cache Management', () => { + it('should clear cache', () => { + // Set some cache data + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = createValidCredentials(); + + tokenManager.clearCache(); + + expect(tokenManager.getCurrentCredentials()).toBeNull(); + }); + + it('should return current credentials from cache', () => { + const credentials = createValidCredentials(); + + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = credentials; + + expect(tokenManager.getCurrentCredentials()).toEqual(credentials); + }); + + it('should return null when no credentials are cached', () => { + tokenManager.clearCache(); + + expect(tokenManager.getCurrentCredentials()).toBeNull(); + }); + }); + + describe('Refresh Status', () => { + it('should return false when no refresh is in progress', () => { + expect(tokenManager.isRefreshInProgress()).toBe(false); + }); + + it('should return true when refresh is in progress', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + + // Clear cache to ensure refresh is triggered + tokenManager.clearCache(); + + // Mock stat for file check to fail (no file initially) + mockFs.stat.mockRejectedValueOnce( + Object.assign(new Error('ENOENT'), { code: 'ENOENT' }), + ); + + // Create a delayed refresh response + let resolveRefresh: (value: TokenRefreshData) => void; + const refreshPromise = new Promise((resolve) => { + resolveRefresh = resolve; + }); + + mockClient.refreshAccessToken = vi.fn().mockReturnValue(refreshPromise); + + // Mock file operations for lock and save + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + + // Start refresh + const refreshOperation = tokenManager.getValidCredentials(mockClient); + + // Wait a tick to ensure the refresh promise is set + await new Promise((resolve) => setImmediate(resolve)); + + expect(tokenManager.isRefreshInProgress()).toBe(true); + + // Complete refresh + resolveRefresh!(createSuccessfulRefreshResponse()); + await refreshOperation; + + expect(tokenManager.isRefreshInProgress()).toBe(false); + }); + }); + + describe('Debug Info', () => { + it('should return complete debug information', () => { + const credentials = createValidCredentials(); + + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = credentials; + + const debugInfo = tokenManager.getDebugInfo(); + + expect(debugInfo).toHaveProperty('hasCredentials', true); + expect(debugInfo).toHaveProperty('credentialsExpired', false); + expect(debugInfo).toHaveProperty('isRefreshing', false); + expect(debugInfo).toHaveProperty('cacheAge'); + expect(typeof debugInfo.cacheAge).toBe('number'); + }); + + it('should indicate expired credentials in debug info', () => { + const expiredCredentials = createExpiredCredentials(); + + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = expiredCredentials; + + const debugInfo = tokenManager.getDebugInfo(); + + expect(debugInfo.hasCredentials).toBe(true); + expect(debugInfo.credentialsExpired).toBe(true); + }); + + it('should indicate no credentials in debug info', () => { + tokenManager.clearCache(); + + const debugInfo = tokenManager.getDebugInfo(); + + expect(debugInfo.hasCredentials).toBe(false); + expect(debugInfo.credentialsExpired).toBe(false); + }); + }); + + describe('Error Handling', () => { + it('should create TokenManagerError with correct type and message', () => { + const error = new TokenManagerError( + TokenError.REFRESH_FAILED, + 'Token refresh failed', + new Error('Original error'), + ); + + expect(error).toBeInstanceOf(Error); + expect(error).toBeInstanceOf(TokenManagerError); + expect(error.type).toBe(TokenError.REFRESH_FAILED); + expect(error.message).toBe('Token refresh failed'); + expect(error.name).toBe('TokenManagerError'); + expect(error.originalError).toBeInstanceOf(Error); + }); + + it('should handle file access errors gracefully', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + + // Mock file stat to throw access error + const accessError = new Error( + 'Permission denied', + ) as NodeJS.ErrnoException; + accessError.code = 'EACCES'; + mockFs.stat.mockRejectedValue(accessError); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + }); + + it('should handle missing file gracefully', async () => { + const mockClient = createMockQwenClient(); + const validCredentials = createValidCredentials(); + + // Mock file stat to throw file not found error + const notFoundError = new Error( + 'File not found', + ) as NodeJS.ErrnoException; + notFoundError.code = 'ENOENT'; + mockFs.stat.mockRejectedValue(notFoundError); + + // Set valid credentials in cache + const memoryCache = getPrivateProperty<{ + credentials: QwenCredentials | null; + }>(tokenManager, 'memoryCache'); + memoryCache.credentials = validCredentials; + + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result).toEqual(validCredentials); + }); + + it('should handle lock timeout scenarios', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + + // Configure shorter timeouts for testing + tokenManager.setLockConfig({ + maxAttempts: 3, + attemptInterval: 50, + }); + + // Mock stat for file check to pass (no file initially) + mockFs.stat.mockRejectedValueOnce( + Object.assign(new Error('ENOENT'), { code: 'ENOENT' }), + ); + + // Mock writeFile to always throw EEXIST for lock file writes (flag: 'wx') + // but succeed for regular file writes + const lockError = new Error('File exists') as NodeJS.ErrnoException; + lockError.code = 'EEXIST'; + + mockFs.writeFile.mockImplementation((path, data, options) => { + if (typeof options === 'object' && options?.flag === 'wx') { + return Promise.reject(lockError); + } + return Promise.resolve(undefined); + }); + + // Mock stat to return recent lock file (not stale) when checking lock age + mockFs.stat.mockResolvedValue({ mtimeMs: Date.now() } as Stats); + + // Mock unlink to simulate lock file removal attempts + mockFs.unlink.mockResolvedValue(undefined); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + }, 500); // 500ms timeout for lock test (3 attempts × 50ms = ~150ms + buffer) + + it('should handle refresh response without access token', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const invalidResponse = { + token_type: 'Bearer', + expires_in: 3600, + // access_token is missing, so we use undefined explicitly + access_token: undefined, + } as Partial; + + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(invalidResponse); + + // Mock stat for file check to pass (no file initially) + mockFs.stat.mockRejectedValueOnce( + Object.assign(new Error('ENOENT'), { code: 'ENOENT' }), + ); + + // Mock file operations for lock acquisition + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + // Clear cache to force refresh + tokenManager.clearCache(); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow(TokenManagerError); + + await expect( + tokenManager.getValidCredentials(mockClient), + ).rejects.toThrow('no token returned'); + }); + }); + + describe('File System Operations', () => { + it('should handle file reload failures gracefully', async () => { + const mockClient = createMockQwenClient(); + + // Mock successful refresh for when cache is cleared + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(createSuccessfulRefreshResponse()); + + // Mock file operations + mockFs.stat + .mockResolvedValueOnce({ mtimeMs: 2000 } as Stats) // For checkAndReloadIfNeeded + .mockResolvedValue({ mtimeMs: 1000 } as Stats); // For later operations + mockFs.readFile.mockRejectedValue(new Error('Read failed')); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + // Set initial cache state to trigger reload + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ fileModTime: number }>( + tokenManager, + 'memoryCache', + ); + memoryCache.fileModTime = 1000; + + // Should not throw error, should refresh and get new credentials + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result).toBeDefined(); + expect(result.access_token).toBe('fresh_access_token'); + }); + + it('should handle invalid JSON in credentials file', async () => { + const mockClient = createMockQwenClient(); + + // Mock successful refresh for when cache is cleared + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(createSuccessfulRefreshResponse()); + + // Mock file operations with invalid JSON + mockFs.stat + .mockResolvedValueOnce({ mtimeMs: 2000 } as Stats) // For checkAndReloadIfNeeded + .mockResolvedValue({ mtimeMs: 1000 } as Stats); // For later operations + mockFs.readFile.mockResolvedValue('invalid json content'); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + // Set initial cache state to trigger reload + tokenManager.clearCache(); + const memoryCache = getPrivateProperty<{ fileModTime: number }>( + tokenManager, + 'memoryCache', + ); + memoryCache.fileModTime = 1000; + + // Should handle JSON parse error gracefully, then refresh and get new credentials + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result).toBeDefined(); + expect(result.access_token).toBe('fresh_access_token'); + }); + + it('should handle directory creation during save', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const refreshResponse = createSuccessfulRefreshResponse(); + + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(refreshResponse); + + // Mock file operations + mockFs.stat.mockResolvedValue({ mtimeMs: 1000 } as Stats); + mockFs.writeFile.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + await tokenManager.getValidCredentials(mockClient); + + expect(mockFs.mkdir).toHaveBeenCalledWith(expect.any(String), { + recursive: true, + mode: 0o700, + }); + expect(mockFs.writeFile).toHaveBeenCalled(); + }); + }); + + describe('Lock File Management', () => { + it('should clean up lock file during process cleanup', () => { + // Create a new instance to trigger cleanup handler registration + SharedTokenManager.getInstance(); + + // Access the private cleanup method for testing + const cleanupHandlers = process.listeners('exit'); + const cleanup = cleanupHandlers[cleanupHandlers.length - 1] as () => void; + + // Should not throw when lock file doesn't exist + expect(() => cleanup()).not.toThrow(); + expect(mockUnlinkSync).toHaveBeenCalled(); + }); + + it('should handle stale lock cleanup', async () => { + const mockClient = createMockQwenClient(createExpiredCredentials()); + const refreshResponse = createSuccessfulRefreshResponse(); + + mockClient.refreshAccessToken = vi + .fn() + .mockResolvedValue(refreshResponse); + + // First writeFile call throws EEXIST (lock exists) + // Second writeFile call succeeds (after stale lock cleanup) + const lockError = new Error('File exists') as NodeJS.ErrnoException; + lockError.code = 'EEXIST'; + mockFs.writeFile + .mockRejectedValueOnce(lockError) + .mockResolvedValue(undefined); + + // Mock stat to return stale lock (old timestamp) + mockFs.stat + .mockResolvedValueOnce({ mtimeMs: Date.now() - 20000 } as Stats) // Stale lock + .mockResolvedValueOnce({ mtimeMs: 1000 } as Stats); // Credentials file + + // Mock unlink to succeed + mockFs.unlink.mockResolvedValue(undefined); + mockFs.mkdir.mockResolvedValue(undefined); + + const result = await tokenManager.getValidCredentials(mockClient); + + expect(result.access_token).toBe(refreshResponse.access_token); + expect(mockFs.unlink).toHaveBeenCalled(); // Stale lock removed + }); + }); +}); diff --git a/packages/core/src/qwen/sharedTokenManager.ts b/packages/core/src/qwen/sharedTokenManager.ts new file mode 100644 index 000000000..3c950cd62 --- /dev/null +++ b/packages/core/src/qwen/sharedTokenManager.ts @@ -0,0 +1,662 @@ +/** + * @license + * Copyright 2025 Qwen + * SPDX-License-Identifier: Apache-2.0 + */ + +import path from 'node:path'; +import { promises as fs, unlinkSync } from 'node:fs'; +import * as os from 'os'; +import { randomUUID } from 'node:crypto'; + +import { + IQwenOAuth2Client, + type QwenCredentials, + type TokenRefreshData, + type ErrorData, + isErrorResponse, +} from './qwenOAuth2.js'; + +// File System Configuration +const QWEN_DIR = '.qwen'; +const QWEN_CREDENTIAL_FILENAME = 'oauth_creds.json'; +const QWEN_LOCK_FILENAME = 'oauth_creds.lock'; + +// Token and Cache Configuration +const TOKEN_REFRESH_BUFFER_MS = 30 * 1000; // 30 seconds +const LOCK_TIMEOUT_MS = 10000; // 10 seconds lock timeout +const CACHE_CHECK_INTERVAL_MS = 1000; // 1 second cache check interval + +// Lock acquisition configuration (can be overridden for testing) +interface LockConfig { + maxAttempts: number; + attemptInterval: number; +} + +const DEFAULT_LOCK_CONFIG: LockConfig = { + maxAttempts: 50, + attemptInterval: 200, +}; + +/** + * Token manager error types for better error classification + */ +export enum TokenError { + REFRESH_FAILED = 'REFRESH_FAILED', + NO_REFRESH_TOKEN = 'NO_REFRESH_TOKEN', + LOCK_TIMEOUT = 'LOCK_TIMEOUT', + FILE_ACCESS_ERROR = 'FILE_ACCESS_ERROR', + NETWORK_ERROR = 'NETWORK_ERROR', +} + +/** + * Custom error class for token manager operations + */ +export class TokenManagerError extends Error { + constructor( + public type: TokenError, + message: string, + public originalError?: unknown, + ) { + super(message); + this.name = 'TokenManagerError'; + } +} + +/** + * Interface for the memory cache state + */ +interface MemoryCache { + credentials: QwenCredentials | null; + fileModTime: number; + lastCheck: number; +} + +/** + * Validates that the given data is a valid QwenCredentials object + * + * @param data - The data to validate + * @returns The validated credentials object + * @throws Error if the data is invalid + */ +function validateCredentials(data: unknown): QwenCredentials { + if (!data || typeof data !== 'object') { + throw new Error('Invalid credentials format'); + } + + const creds = data as Partial; + const requiredFields = [ + 'access_token', + 'refresh_token', + 'token_type', + ] as const; + + // Check required string fields + for (const field of requiredFields) { + if (!creds[field] || typeof creds[field] !== 'string') { + throw new Error(`Invalid credentials: missing ${field}`); + } + } + + // Check expiry_date + if (!creds.expiry_date || typeof creds.expiry_date !== 'number') { + throw new Error('Invalid credentials: missing expiry_date'); + } + + return creds as QwenCredentials; +} + +/** + * Manages OAuth tokens across multiple processes using file-based caching and locking + */ +export class SharedTokenManager { + private static instance: SharedTokenManager | null = null; + + /** + * In-memory cache for credentials and file state tracking + */ + private memoryCache: MemoryCache = { + credentials: null, + fileModTime: 0, + lastCheck: 0, + }; + + /** + * Promise tracking any ongoing token refresh operation + */ + private refreshPromise: Promise | null = null; + + /** + * Whether cleanup handlers have been registered + */ + private cleanupHandlersRegistered = false; + + /** + * Reference to cleanup functions for proper removal + */ + private cleanupFunction: (() => void) | null = null; + + /** + * Lock configuration for testing purposes + */ + private lockConfig: LockConfig = DEFAULT_LOCK_CONFIG; + + /** + * Private constructor for singleton pattern + */ + private constructor() { + this.registerCleanupHandlers(); + } + + /** + * Get the singleton instance + * @returns The shared token manager instance + */ + static getInstance(): SharedTokenManager { + if (!SharedTokenManager.instance) { + SharedTokenManager.instance = new SharedTokenManager(); + } + return SharedTokenManager.instance; + } + + /** + * Set up handlers to clean up lock files when the process exits + */ + private registerCleanupHandlers(): void { + if (this.cleanupHandlersRegistered) { + return; + } + + this.cleanupFunction = () => { + try { + const lockPath = this.getLockFilePath(); + // Use synchronous unlink for process exit handlers + unlinkSync(lockPath); + } catch (_error) { + // Ignore cleanup errors - lock file might not exist or already be cleaned up + } + }; + + process.on('exit', this.cleanupFunction); + process.on('SIGINT', this.cleanupFunction); + process.on('SIGTERM', this.cleanupFunction); + process.on('uncaughtException', this.cleanupFunction); + process.on('unhandledRejection', this.cleanupFunction); + + this.cleanupHandlersRegistered = true; + } + + /** + * Get valid OAuth credentials, refreshing them if necessary + * + * @param qwenClient - The OAuth2 client instance + * @param forceRefresh - If true, refresh token even if current one is still valid + * @returns Promise resolving to valid credentials + * @throws TokenManagerError if unable to obtain valid credentials + */ + async getValidCredentials( + qwenClient: IQwenOAuth2Client, + forceRefresh = false, + ): Promise { + try { + // Check if credentials file has been updated by other sessions + await this.checkAndReloadIfNeeded(); + + // Return valid cached credentials if available (unless force refresh is requested) + if ( + !forceRefresh && + this.memoryCache.credentials && + this.isTokenValid(this.memoryCache.credentials) + ) { + return this.memoryCache.credentials; + } + + // If refresh is already in progress, wait for it to complete + if (this.refreshPromise) { + return this.refreshPromise; + } + + // Start new refresh operation with distributed locking + this.refreshPromise = this.performTokenRefresh(qwenClient, forceRefresh); + + try { + const credentials = await this.refreshPromise; + return credentials; + } catch (error) { + // Ensure refreshPromise is cleared on error before re-throwing + this.refreshPromise = null; + throw error; + } finally { + this.refreshPromise = null; + } + } catch (error) { + // Convert generic errors to TokenManagerError for better error handling + if (error instanceof TokenManagerError) { + throw error; + } + + throw new TokenManagerError( + TokenError.REFRESH_FAILED, + `Failed to get valid credentials: ${error instanceof Error ? error.message : String(error)}`, + error, + ); + } + } + + /** + * Check if the credentials file was updated by another process and reload if so + */ + private async checkAndReloadIfNeeded(): Promise { + const now = Date.now(); + + // Limit check frequency to avoid excessive disk I/O + if (now - this.memoryCache.lastCheck < CACHE_CHECK_INTERVAL_MS) { + return; + } + + this.memoryCache.lastCheck = now; + + try { + const filePath = this.getCredentialFilePath(); + const stats = await fs.stat(filePath); + const fileModTime = stats.mtimeMs; + + // Reload credentials if file has been modified since last cache + if (fileModTime > this.memoryCache.fileModTime) { + await this.reloadCredentialsFromFile(); + this.memoryCache.fileModTime = fileModTime; + } + } catch (error) { + // Handle file access errors + if ( + error instanceof Error && + 'code' in error && + error.code !== 'ENOENT' + ) { + // Clear cache for non-missing file errors + this.memoryCache.credentials = null; + this.memoryCache.fileModTime = 0; + + throw new TokenManagerError( + TokenError.FILE_ACCESS_ERROR, + `Failed to access credentials file: ${error.message}`, + error, + ); + } + + // For missing files (ENOENT), just reset file modification time + // but keep existing valid credentials in memory if they exist + this.memoryCache.fileModTime = 0; + } + } + + /** + * Load credentials from the file system into memory cache + */ + private async reloadCredentialsFromFile(): Promise { + try { + const filePath = this.getCredentialFilePath(); + const content = await fs.readFile(filePath, 'utf-8'); + const parsedData = JSON.parse(content); + const credentials = validateCredentials(parsedData); + this.memoryCache.credentials = credentials; + } catch (error) { + // Log validation errors for debugging but don't throw + if ( + error instanceof Error && + error.message.includes('Invalid credentials') + ) { + console.warn(`Failed to validate credentials file: ${error.message}`); + } + this.memoryCache.credentials = null; + } + } + + /** + * Refresh the OAuth token using file locking to prevent concurrent refreshes + * + * @param qwenClient - The OAuth2 client instance + * @param forceRefresh - If true, skip checking if token is already valid after getting lock + * @returns Promise resolving to refreshed credentials + * @throws TokenManagerError if refresh fails or lock cannot be acquired + */ + private async performTokenRefresh( + qwenClient: IQwenOAuth2Client, + forceRefresh = false, + ): Promise { + const lockPath = this.getLockFilePath(); + + try { + // Check if we have a refresh token before attempting refresh + const currentCredentials = qwenClient.getCredentials(); + if (!currentCredentials.refresh_token) { + throw new TokenManagerError( + TokenError.NO_REFRESH_TOKEN, + 'No refresh token available for token refresh', + ); + } + + // Acquire distributed file lock + await this.acquireLock(lockPath); + + // Double-check if another process already refreshed the token (unless force refresh is requested) + await this.checkAndReloadIfNeeded(); + + // Use refreshed credentials if they're now valid (unless force refresh is requested) + if ( + !forceRefresh && + this.memoryCache.credentials && + this.isTokenValid(this.memoryCache.credentials) + ) { + qwenClient.setCredentials(this.memoryCache.credentials); + return this.memoryCache.credentials; + } + + // Perform the actual token refresh + const response = await qwenClient.refreshAccessToken(); + + if (!response || isErrorResponse(response)) { + const errorData = response as ErrorData; + throw new TokenManagerError( + TokenError.REFRESH_FAILED, + `Token refresh failed: ${errorData?.error || 'Unknown error'} - ${errorData?.error_description || 'No details provided'}`, + ); + } + + const tokenData = response as TokenRefreshData; + + if (!tokenData.access_token) { + throw new TokenManagerError( + TokenError.REFRESH_FAILED, + 'Failed to refresh access token: no token returned', + ); + } + + // Create updated credentials object + const credentials: QwenCredentials = { + access_token: tokenData.access_token, + token_type: tokenData.token_type, + refresh_token: + tokenData.refresh_token || currentCredentials.refresh_token, + resource_url: tokenData.resource_url, + expiry_date: Date.now() + tokenData.expires_in * 1000, + }; + + // Update memory cache and client credentials + this.memoryCache.credentials = credentials; + qwenClient.setCredentials(credentials); + + // Persist to file and update modification time + await this.saveCredentialsToFile(credentials); + + return credentials; + } catch (error) { + if (error instanceof TokenManagerError) { + throw error; + } + + // Handle network-related errors + if ( + error instanceof Error && + (error.message.includes('fetch') || + error.message.includes('network') || + error.message.includes('timeout')) + ) { + throw new TokenManagerError( + TokenError.NETWORK_ERROR, + `Network error during token refresh: ${error.message}`, + error, + ); + } + + throw new TokenManagerError( + TokenError.REFRESH_FAILED, + `Unexpected error during token refresh: ${error instanceof Error ? error.message : String(error)}`, + error, + ); + } finally { + // Always release the file lock + await this.releaseLock(lockPath); + } + } + + /** + * Save credentials to file and update the cached file modification time + * + * @param credentials - The credentials to save + */ + private async saveCredentialsToFile( + credentials: QwenCredentials, + ): Promise { + const filePath = this.getCredentialFilePath(); + const dirPath = path.dirname(filePath); + + // Create directory with restricted permissions + try { + await fs.mkdir(dirPath, { recursive: true, mode: 0o700 }); + } catch (error) { + throw new TokenManagerError( + TokenError.FILE_ACCESS_ERROR, + `Failed to create credentials directory: ${error instanceof Error ? error.message : String(error)}`, + error, + ); + } + + const credString = JSON.stringify(credentials, null, 2); + + try { + // Write file with restricted permissions (owner read/write only) + await fs.writeFile(filePath, credString, { mode: 0o600 }); + } catch (error) { + throw new TokenManagerError( + TokenError.FILE_ACCESS_ERROR, + `Failed to write credentials file: ${error instanceof Error ? error.message : String(error)}`, + error, + ); + } + + // Update cached file modification time to avoid unnecessary reloads + try { + const stats = await fs.stat(filePath); + this.memoryCache.fileModTime = stats.mtimeMs; + } catch (error) { + // Non-fatal error, just log it + console.warn( + `Failed to update file modification time: ${error instanceof Error ? error.message : String(error)}`, + ); + } + } + + /** + * Check if the token is valid and not expired + * + * @param credentials - The credentials to validate + * @returns true if token is valid and not expired, false otherwise + */ + private isTokenValid(credentials: QwenCredentials): boolean { + if (!credentials.expiry_date || !credentials.access_token) { + return false; + } + return Date.now() < credentials.expiry_date - TOKEN_REFRESH_BUFFER_MS; + } + + /** + * Get the full path to the credentials file + * + * @returns The absolute path to the credentials file + */ + private getCredentialFilePath(): string { + return path.join(os.homedir(), QWEN_DIR, QWEN_CREDENTIAL_FILENAME); + } + + /** + * Get the full path to the lock file + * + * @returns The absolute path to the lock file + */ + private getLockFilePath(): string { + return path.join(os.homedir(), QWEN_DIR, QWEN_LOCK_FILENAME); + } + + /** + * Acquire a file lock to prevent other processes from refreshing tokens simultaneously + * + * @param lockPath - Path to the lock file + * @throws TokenManagerError if lock cannot be acquired within timeout period + */ + private async acquireLock(lockPath: string): Promise { + const { maxAttempts, attemptInterval } = this.lockConfig; + const lockId = randomUUID(); // Use random UUID instead of PID for security + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + try { + // Attempt to create lock file atomically (exclusive mode) + await fs.writeFile(lockPath, lockId, { flag: 'wx' }); + return; // Successfully acquired lock + } catch (error: unknown) { + if ((error as NodeJS.ErrnoException).code === 'EEXIST') { + // Lock file already exists, check if it's stale + try { + const stats = await fs.stat(lockPath); + const lockAge = Date.now() - stats.mtimeMs; + + // Remove stale locks that exceed timeout + if (lockAge > LOCK_TIMEOUT_MS) { + try { + await fs.unlink(lockPath); + console.warn( + `Removed stale lock file: ${lockPath} (age: ${lockAge}ms)`, + ); + continue; // Retry lock acquisition + } catch (unlinkError) { + // Log the error but continue trying - another process might have removed it + console.warn( + `Failed to remove stale lock file ${lockPath}: ${unlinkError instanceof Error ? unlinkError.message : String(unlinkError)}`, + ); + // Still continue - the lock might have been removed by another process + } + } + } catch (statError) { + // Can't stat lock file, it might have been removed, continue trying + console.warn( + `Failed to stat lock file ${lockPath}: ${statError instanceof Error ? statError.message : String(statError)}`, + ); + } + + // Wait before retrying + await new Promise((resolve) => setTimeout(resolve, attemptInterval)); + } else { + throw new TokenManagerError( + TokenError.FILE_ACCESS_ERROR, + `Failed to create lock file: ${error instanceof Error ? error.message : String(error)}`, + error, + ); + } + } + } + + throw new TokenManagerError( + TokenError.LOCK_TIMEOUT, + 'Failed to acquire file lock for token refresh: timeout exceeded', + ); + } + + /** + * Release the file lock + * + * @param lockPath - Path to the lock file + */ + private async releaseLock(lockPath: string): Promise { + try { + await fs.unlink(lockPath); + } catch (error) { + // Lock file might already be removed by another process or timeout cleanup + // This is not an error condition, but log for debugging + if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { + console.warn( + `Failed to release lock file ${lockPath}: ${error instanceof Error ? error.message : String(error)}`, + ); + } + } + } + + /** + * Clear all cached data and reset the manager to initial state + */ + clearCache(): void { + this.memoryCache = { + credentials: null, + fileModTime: 0, + lastCheck: 0, + }; + this.refreshPromise = null; + } + + /** + * Get the current cached credentials (may be expired) + * + * @returns The currently cached credentials or null + */ + getCurrentCredentials(): QwenCredentials | null { + return this.memoryCache.credentials; + } + + /** + * Check if there's an ongoing refresh operation + * + * @returns true if refresh is in progress, false otherwise + */ + isRefreshInProgress(): boolean { + return this.refreshPromise !== null; + } + + /** + * Set lock configuration for testing purposes + * @param config - Lock configuration + */ + setLockConfig(config: Partial): void { + this.lockConfig = { ...DEFAULT_LOCK_CONFIG, ...config }; + } + + /** + * Clean up event listeners (primarily for testing) + */ + cleanup(): void { + if (this.cleanupFunction && this.cleanupHandlersRegistered) { + this.cleanupFunction(); + + process.removeListener('exit', this.cleanupFunction); + process.removeListener('SIGINT', this.cleanupFunction); + process.removeListener('SIGTERM', this.cleanupFunction); + process.removeListener('uncaughtException', this.cleanupFunction); + process.removeListener('unhandledRejection', this.cleanupFunction); + + this.cleanupHandlersRegistered = false; + this.cleanupFunction = null; + } + } + + /** + * Get a summary of the current state for debugging + * + * @returns Object containing current state information + */ + getDebugInfo(): { + hasCredentials: boolean; + credentialsExpired: boolean; + isRefreshing: boolean; + cacheAge: number; + } { + const hasCredentials = !!this.memoryCache.credentials; + const credentialsExpired = hasCredentials + ? !this.isTokenValid(this.memoryCache.credentials!) + : false; + + return { + hasCredentials, + credentialsExpired, + isRefreshing: this.isRefreshInProgress(), + cacheAge: Date.now() - this.memoryCache.lastCheck, + }; + } +} From 347e606366b818e1fd574f07974535d067c347ce Mon Sep 17 00:00:00 2001 From: Mingholy Date: Wed, 27 Aug 2025 15:23:21 +0800 Subject: [PATCH 13/16] fix: ambiguous literals (#461) --- docs/cli/configuration.md | 2 +- docs/integration-tests.md | 2 +- packages/cli/src/config/config.ts | 2 +- packages/cli/src/ui/IdeIntegrationNudge.tsx | 2 +- packages/cli/src/ui/commands/ideCommand.ts | 2 +- packages/cli/src/ui/commands/mcpCommand.test.ts | 2 +- packages/cli/src/ui/commands/mcpCommand.ts | 3 ++- scripts/build_sandbox.js | 2 +- scripts/create_alias.sh | 8 ++++---- 9 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index 5e8c90029..78147c3f2 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -438,7 +438,7 @@ Arguments passed directly when running the CLI can override other configurations - `auto_edit`: Automatically approve edit tools (replace, write_file) while prompting for others - `yolo`: Automatically approve all tool calls (equivalent to `--yolo`) - Cannot be used together with `--yolo`. Use `--approval-mode=yolo` instead of `--yolo` for the new unified approach. - - Example: `gemini --approval-mode auto_edit` + - Example: `qwen --approval-mode auto_edit` - **`--telemetry`**: - Enables [telemetry](../telemetry.md). - **`--telemetry-target`**: diff --git a/docs/integration-tests.md b/docs/integration-tests.md index 90e65a9a2..00c91fe16 100644 --- a/docs/integration-tests.md +++ b/docs/integration-tests.md @@ -89,7 +89,7 @@ The verbose output is formatted to clearly identify the source of the logs: ``` --- TEST: : --- -... output from the gemini command ... +... output from the qwen command ... --- END TEST: : --- ``` diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index d929747ed..4431d7f83 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -218,7 +218,7 @@ export async function parseArguments(): Promise { .option('proxy', { type: 'string', description: - 'Proxy for gemini client, like schema://user:password@host:port', + 'Proxy for qwen client, like schema://user:password@host:port', }) .option('include-directories', { type: 'array', diff --git a/packages/cli/src/ui/IdeIntegrationNudge.tsx b/packages/cli/src/ui/IdeIntegrationNudge.tsx index 7f42c6004..dd58957ab 100644 --- a/packages/cli/src/ui/IdeIntegrationNudge.tsx +++ b/packages/cli/src/ui/IdeIntegrationNudge.tsx @@ -88,7 +88,7 @@ export function IdeIntegrationNudge({ {'> '} - {`Do you want to connect ${ideName ?? 'your'} editor to Gemini CLI?`} + {`Do you want to connect ${ideName ?? 'your'} editor to Qwen Code?`} {installText} diff --git a/packages/cli/src/ui/commands/ideCommand.ts b/packages/cli/src/ui/commands/ideCommand.ts index e18ab12dd..5172f1483 100644 --- a/packages/cli/src/ui/commands/ideCommand.ts +++ b/packages/cli/src/ui/commands/ideCommand.ts @@ -130,7 +130,7 @@ export const ideCommand = (config: Config | null): SlashCommand | null => { ({ type: 'message', messageType: 'error', - content: `IDE integration is not supported in your current environment. To use this feature, run Gemini CLI in one of these supported IDEs: ${Object.values( + content: `IDE integration is not supported in your current environment. To use this feature, run Qwen Code in one of these supported IDEs: ${Object.values( DetectedIde, ) .map((ide) => getIdeInfo(ide).displayName) diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 42ee36123..7c2dc48fc 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -146,7 +146,7 @@ describe('mcpCommand', () => { type: 'message', messageType: 'info', content: - 'No MCP servers configured. Please view MCP documentation in your browser: https://goo.gle/gemini-cli-docs-mcp or use the cli /docs command', + 'No MCP servers configured. Please view MCP documentation in your browser: https://qwenlm.github.io/qwen-code-docs/en/tools/mcp-server/#how-to-set-up-your-mcp-server or use the cli /docs command', }); }); }); diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index db537b7b5..3c3fb4deb 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -58,7 +58,8 @@ const getMcpStatus = async ( const blockedMcpServers = config.getBlockedMcpServers() || []; if (serverNames.length === 0 && blockedMcpServers.length === 0) { - const docsUrl = 'https://goo.gle/gemini-cli-docs-mcp'; + const docsUrl = + 'https://qwenlm.github.io/qwen-code-docs/en/tools/mcp-server/#how-to-set-up-your-mcp-server'; return { type: 'message', messageType: 'info', diff --git a/scripts/build_sandbox.js b/scripts/build_sandbox.js index 5cb4cd612..eb24de26b 100644 --- a/scripts/build_sandbox.js +++ b/scripts/build_sandbox.js @@ -136,7 +136,7 @@ function buildImage(imageName, dockerfile) { if (isWindows) { // PowerShell doesn't support <() process substitution. // Create a temporary auth file that we will clean up after. - tempAuthFile = join(os.tmpdir(), `gemini-auth-${Date.now()}.json`); + tempAuthFile = join(os.tmpdir(), `qwen-auth-${Date.now()}.json`); writeFileSync(tempAuthFile, '{}'); buildCommandArgs = `--authfile="${tempAuthFile}"`; } else { diff --git a/scripts/create_alias.sh b/scripts/create_alias.sh index ecb01bb32..0a6b8363a 100755 --- a/scripts/create_alias.sh +++ b/scripts/create_alias.sh @@ -5,7 +5,7 @@ set -euo pipefail # Determine the project directory PROJECT_DIR=$(cd "$(dirname "$0")/.." && pwd) -ALIAS_COMMAND="alias gemini='node "${PROJECT_DIR}/scripts/start.js"'" +ALIAS_COMMAND="alias qwen='node "${PROJECT_DIR}/scripts/start.js"'" # Detect shell and set config file path if [[ "${SHELL}" == *"/bash" ]]; then @@ -22,8 +22,8 @@ echo " ${ALIAS_COMMAND}" echo "" # Check if the alias already exists -if grep -q "alias gemini=" "${CONFIG_FILE}"; then - echo "A 'gemini' alias already exists in ${CONFIG_FILE}. No changes were made." +if grep -q "alias qwen=" "${CONFIG_FILE}"; then + echo "A 'qwen' alias already exists in ${CONFIG_FILE}. No changes were made." exit 0 fi @@ -33,7 +33,7 @@ if [[ "${REPLY}" =~ ^[Yy]$ ]]; then echo "${ALIAS_COMMAND}" >> "${CONFIG_FILE}" echo "" echo "Alias added to ${CONFIG_FILE}." - echo "Please run 'source ${CONFIG_FILE}' or open a new terminal to use the 'gemini' command." + echo "Please run 'source ${CONFIG_FILE}' or open a new terminal to use the 'qwen' command." else echo "Aborted. No changes were made." fi From 600c58bbcb6b857a980ef574fbbdd26a576525ed Mon Sep 17 00:00:00 2001 From: tanzhenxin Date: Wed, 27 Aug 2025 17:32:57 +0800 Subject: [PATCH 14/16] =?UTF-8?q?=F0=9F=94=A7=20Miscellaneous=20Improvemen?= =?UTF-8?q?ts=20and=20Refactoring=20(#466)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gemini-automated-issue-triage.yml | 5 +- CONTRIBUTING.md | 2 +- ROADMAP.md => ROADMAP.gemini.md | 0 packages/cli/src/config/config.ts | 1 + .../cli/src/ui/commands/directoryCommand.tsx | 55 ++- .../components/ContextSummaryDisplay.test.tsx | 6 +- packages/core/src/config/config.ts | 10 + .../src/core/openaiContentGenerator.test.ts | 6 +- .../core/src/core/openaiContentGenerator.ts | 38 +- .../telemetry/integration.test.circular.ts | 45 +- packages/core/src/telemetry/loggers.ts | 6 +- .../src/telemetry/qwen-logger/event-types.ts | 1 + .../telemetry/qwen-logger/qwen-logger.test.ts | 407 ++++++++++++++++++ .../src/telemetry/qwen-logger/qwen-logger.ts | 223 ++++++++-- packages/core/src/utils/bfsFileSearch.test.ts | 6 +- packages/core/src/utils/memoryDiscovery.ts | 25 +- 16 files changed, 755 insertions(+), 81 deletions(-) rename ROADMAP.md => ROADMAP.gemini.md (100%) create mode 100644 packages/core/src/telemetry/qwen-logger/qwen-logger.test.ts diff --git a/.github/workflows/gemini-automated-issue-triage.yml b/.github/workflows/gemini-automated-issue-triage.yml index 96d71b7b1..3471a47b1 100644 --- a/.github/workflows/gemini-automated-issue-triage.yml +++ b/.github/workflows/gemini-automated-issue-triage.yml @@ -52,10 +52,7 @@ jobs: { "maxSessionTurns": 25, "coreTools": [ - "run_shell_command(echo)", - "run_shell_command(gh label list)", - "run_shell_command(gh issue edit)", - "run_shell_command(gh issue list)" + "run_shell_command" ], "sandbox": false } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6c934f235..40c91a8f7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -209,7 +209,7 @@ npm run lint ### Coding Conventions - Please adhere to the coding style, patterns, and conventions used throughout the existing codebase. -- Consult [GEMINI.md](https://github.com/google-gemini/gemini-cli/blob/main/GEMINI.md) (typically found in the project root) for specific instructions related to AI-assisted development, including conventions for React, comments, and Git usage. +- Consult [QWEN.md](https://github.com/QwenLM/qwen-code/blob/main/QWEN.md) (typically found in the project root) for specific instructions related to AI-assisted development, including conventions for React, comments, and Git usage. - **Imports:** Pay special attention to import paths. The project uses ESLint to enforce restrictions on relative imports between packages. ### Project Structure diff --git a/ROADMAP.md b/ROADMAP.gemini.md similarity index 100% rename from ROADMAP.md rename to ROADMAP.gemini.md diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 4431d7f83..a16ceb0d9 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -577,6 +577,7 @@ export async function loadCliConfig( 'SYSTEM_TEMPLATE:{"name":"qwen3_coder","params":{"is_git_repository":{RUNTIME_VARS_IS_GIT_REPO},"sandbox":"{RUNTIME_VARS_SANDBOX}"}}', }, ]) as ConfigParameters['systemPromptMappings'], + authType: settings.selectedAuthType, contentGenerator: settings.contentGenerator, cliVersion, tavilyApiKey: diff --git a/packages/cli/src/ui/commands/directoryCommand.tsx b/packages/cli/src/ui/commands/directoryCommand.tsx index 7aa9e289d..5de976b54 100644 --- a/packages/cli/src/ui/commands/directoryCommand.tsx +++ b/packages/cli/src/ui/commands/directoryCommand.tsx @@ -91,35 +91,34 @@ export const directoryCommand: SlashCommand = { } } - try { - if (config.shouldLoadMemoryFromIncludeDirectories()) { - const { memoryContent, fileCount } = - await loadServerHierarchicalMemory( - config.getWorkingDir(), - [ - ...config.getWorkspaceContext().getDirectories(), - ...pathsToAdd, - ], - config.getDebugMode(), - config.getFileService(), - config.getExtensionContextFilePaths(), - context.services.settings.merged.memoryImportFormat || 'tree', // Use setting or default to 'tree' - config.getFileFilteringOptions(), - context.services.settings.merged.memoryDiscoveryMaxDirs, - ); - config.setUserMemory(memoryContent); - config.setGeminiMdFileCount(fileCount); - context.ui.setGeminiMdFileCount(fileCount); + if (added.length > 0) { + try { + if (config.shouldLoadMemoryFromIncludeDirectories()) { + const { memoryContent, fileCount } = + await loadServerHierarchicalMemory( + config.getWorkingDir(), + [...config.getWorkspaceContext().getDirectories()], + config.getDebugMode(), + config.getFileService(), + config.getExtensionContextFilePaths(), + context.services.settings.merged.memoryImportFormat || 'tree', // Use setting or default to 'tree' + config.getFileFilteringOptions(), + context.services.settings.merged.memoryDiscoveryMaxDirs, + ); + config.setUserMemory(memoryContent); + config.setGeminiMdFileCount(fileCount); + context.ui.setGeminiMdFileCount(fileCount); + } + addItem( + { + type: MessageType.INFO, + text: `Successfully added memory files from the following directories if there are:\n- ${added.join('\n- ')}`, + }, + Date.now(), + ); + } catch (error) { + errors.push(`Error refreshing memory: ${(error as Error).message}`); } - addItem( - { - type: MessageType.INFO, - text: `Successfully added GEMINI.md files from the following directories if there are:\n- ${added.join('\n- ')}`, - }, - Date.now(), - ); - } catch (error) { - errors.push(`Error refreshing memory: ${(error as Error).message}`); } if (added.length > 0) { diff --git a/packages/cli/src/ui/components/ContextSummaryDisplay.test.tsx b/packages/cli/src/ui/components/ContextSummaryDisplay.test.tsx index d70bb4ca5..13a9673d9 100644 --- a/packages/cli/src/ui/components/ContextSummaryDisplay.test.tsx +++ b/packages/cli/src/ui/components/ContextSummaryDisplay.test.tsx @@ -27,7 +27,7 @@ const renderWithWidth = ( describe('', () => { const baseProps = { geminiMdFileCount: 1, - contextFileNames: ['GEMINI.md'], + contextFileNames: ['QWEN.md'], mcpServers: { 'test-server': { command: 'test' } }, showToolDescriptions: false, ideContext: { @@ -41,7 +41,7 @@ describe('', () => { const { lastFrame } = renderWithWidth(120, baseProps); const output = lastFrame(); expect(output).toContain( - 'Using: 1 open file (ctrl+e to view) | 1 GEMINI.md file | 1 MCP server (ctrl+t to view)', + 'Using: 1 open file (ctrl+e to view) | 1 QWEN.md file | 1 MCP server (ctrl+t to view)', ); // Check for absence of newlines expect(output.includes('\n')).toBe(false); @@ -53,7 +53,7 @@ describe('', () => { const expectedLines = [ 'Using:', ' - 1 open file (ctrl+e to view)', - ' - 1 GEMINI.md file', + ' - 1 QWEN.md file', ' - 1 MCP server (ctrl+t to view)', ]; const actualLines = output.split('\n'); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index d09c24e67..45d39b439 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -208,6 +208,7 @@ export interface ConfigParameters { modelNames: string[]; template: string; }>; + authType?: AuthType; contentGenerator?: { timeout?: number; maxRetries?: number; @@ -288,6 +289,7 @@ export class Config { private readonly summarizeToolOutput: | Record | undefined; + private authType?: AuthType; private readonly enableOpenAILogging: boolean; private readonly contentGenerator?: { timeout?: number; @@ -368,6 +370,7 @@ export class Config { this.ideMode = params.ideMode ?? false; this.ideClient = IdeClient.getInstance(); this.systemPromptMappings = params.systemPromptMappings; + this.authType = params.authType; this.enableOpenAILogging = params.enableOpenAILogging ?? false; this.contentGenerator = params.contentGenerator; this.cliVersion = params.cliVersion; @@ -451,6 +454,8 @@ export class Config { // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; + + this.authType = authMethod; } getSessionId(): string { @@ -545,6 +550,7 @@ export class Config { getDebugMode(): boolean { return this.debugMode; } + getQuestion(): string | undefined { return this.question; } @@ -763,6 +769,10 @@ export class Config { } } + getAuthType(): AuthType | undefined { + return this.authType; + } + getEnableOpenAILogging(): boolean { return this.enableOpenAILogging; } diff --git a/packages/core/src/core/openaiContentGenerator.test.ts b/packages/core/src/core/openaiContentGenerator.test.ts index 8d03f0ae0..b20c9dc25 100644 --- a/packages/core/src/core/openaiContentGenerator.test.ts +++ b/packages/core/src/core/openaiContentGenerator.test.ts @@ -3410,7 +3410,10 @@ describe('OpenAIContentGenerator', () => { model: 'qwen-turbo', }; - await dashscopeGenerator.generateContent(request, 'dashscope-prompt-id'); + await dashscopeGenerator.generateContentStream( + request, + 'dashscope-prompt-id', + ); // Should include cache control in last message expect(mockOpenAIClient.chat.completions.create).toHaveBeenCalledWith( @@ -3422,7 +3425,6 @@ describe('OpenAIContentGenerator', () => { expect.objectContaining({ type: 'text', text: 'Hello, how are you?', - cache_control: { type: 'ephemeral' }, }), ]), }), diff --git a/packages/core/src/core/openaiContentGenerator.ts b/packages/core/src/core/openaiContentGenerator.ts index be681b0a5..94616a2c9 100644 --- a/packages/core/src/core/openaiContentGenerator.ts +++ b/packages/core/src/core/openaiContentGenerator.ts @@ -130,6 +130,7 @@ export class OpenAIContentGenerator implements ContentGenerator { ? { 'X-DashScope-CacheControl': 'enable', 'X-DashScope-UserAgent': userAgent, + 'X-DashScope-AuthType': contentGeneratorConfig.authType, } : {}), }; @@ -235,8 +236,18 @@ export class OpenAIContentGenerator implements ContentGenerator { private async buildCreateParams( request: GenerateContentParameters, userPromptId: string, + streaming: boolean = false, ): Promise[0]> { - const messages = this.convertToOpenAIFormat(request); + let messages = this.convertToOpenAIFormat(request); + + // Add cache control to system and last messages for DashScope providers + // Only add cache control to system message for non-streaming requests + if (this.isDashScopeProvider()) { + messages = this.addDashScopeCacheControl( + messages, + streaming ? 'both' : 'system', + ); + } // Build sampling parameters with clear priority: // 1. Request-level parameters (highest priority) @@ -259,6 +270,11 @@ export class OpenAIContentGenerator implements ContentGenerator { ); } + if (streaming) { + createParams.stream = true; + createParams.stream_options = { include_usage: true }; + } + return createParams; } @@ -267,7 +283,11 @@ export class OpenAIContentGenerator implements ContentGenerator { userPromptId: string, ): Promise { const startTime = Date.now(); - const createParams = await this.buildCreateParams(request, userPromptId); + const createParams = await this.buildCreateParams( + request, + userPromptId, + false, + ); try { const completion = (await this.client.chat.completions.create( @@ -358,10 +378,11 @@ export class OpenAIContentGenerator implements ContentGenerator { userPromptId: string, ): Promise> { const startTime = Date.now(); - const createParams = await this.buildCreateParams(request, userPromptId); - - createParams.stream = true; - createParams.stream_options = { include_usage: true }; + const createParams = await this.buildCreateParams( + request, + userPromptId, + true, + ); try { const stream = (await this.client.chat.completions.create( @@ -942,14 +963,13 @@ export class OpenAIContentGenerator implements ContentGenerator { const mergedMessages = this.mergeConsecutiveAssistantMessages(cleanedMessages); - // Add cache control to system and last messages for DashScope providers - return this.addCacheControlFlag(mergedMessages, 'both'); + return mergedMessages; } /** * Add cache control flag to specified message(s) for DashScope providers */ - private addCacheControlFlag( + private addDashScopeCacheControl( messages: OpenAI.Chat.ChatCompletionMessageParam[], target: 'system' | 'last' | 'both' = 'both', ): OpenAI.Chat.ChatCompletionMessageParam[] { diff --git a/packages/core/src/telemetry/integration.test.circular.ts b/packages/core/src/telemetry/integration.test.circular.ts index 614f5e02f..17beccd5f 100644 --- a/packages/core/src/telemetry/integration.test.circular.ts +++ b/packages/core/src/telemetry/integration.test.circular.ts @@ -8,12 +8,23 @@ * Integration test to verify circular reference handling with proxy agents */ -import { describe, it, expect } from 'vitest'; +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { QwenLogger } from './qwen-logger/qwen-logger.js'; import { RumEvent } from './qwen-logger/event-types.js'; import { Config } from '../config/config.js'; describe('Circular Reference Integration Test', () => { + beforeEach(() => { + // Clear singleton instance before each test + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (QwenLogger as any).instance = undefined; + }); + + afterEach(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (QwenLogger as any).instance = undefined; + }); + it('should handle HttpsProxyAgent-like circular references in qwen logging', () => { // Create a mock config with proxy const mockConfig = { @@ -64,4 +75,36 @@ describe('Circular Reference Integration Test', () => { logger?.enqueueLogEvent(problematicEvent); }).not.toThrow(); }); + + it('should handle event overflow without memory leaks', () => { + const mockConfig = { + getTelemetryEnabled: () => true, + getUsageStatisticsEnabled: () => true, + getSessionId: () => 'test-session', + getDebugMode: () => true, + } as unknown as Config; + + const logger = QwenLogger.getInstance(mockConfig); + + // Add more events than the maximum capacity + for (let i = 0; i < 1100; i++) { + logger?.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: `overflow-test-${i}`, + }); + } + + // Logger should still be functional + expect(logger).toBeDefined(); + expect(() => { + logger?.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'final-test', + }); + }).not.toThrow(); + }); }); diff --git a/packages/core/src/telemetry/loggers.ts b/packages/core/src/telemetry/loggers.ts index c887f1644..7c2f25ae5 100644 --- a/packages/core/src/telemetry/loggers.ts +++ b/packages/core/src/telemetry/loggers.ts @@ -8,7 +8,6 @@ import { LogAttributes, LogRecord, logs } from '@opentelemetry/api-logs'; import { SemanticAttributes } from '@opentelemetry/semantic-conventions'; import { Config } from '../config/config.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; -import { ClearcutLogger } from './clearcut-logger/clearcut-logger.js'; import { EVENT_API_ERROR, EVENT_API_REQUEST, @@ -150,7 +149,7 @@ export function logToolCall(config: Config, event: ToolCallEvent): void { } export function logApiRequest(config: Config, event: ApiRequestEvent): void { - QwenLogger.getInstance(config)?.logApiRequestEvent(event); + // QwenLogger.getInstance(config)?.logApiRequestEvent(event); if (!isTelemetrySdkInitialized()) return; const attributes: LogAttributes = { @@ -364,6 +363,7 @@ export function logIdeConnection( config: Config, event: IdeConnectionEvent, ): void { + QwenLogger.getInstance(config)?.logIdeConnectionEvent(event); if (!isTelemetrySdkInitialized()) return; const attributes: LogAttributes = { @@ -384,7 +384,7 @@ export function logKittySequenceOverflow( config: Config, event: KittySequenceOverflowEvent, ): void { - ClearcutLogger.getInstance(config)?.logKittySequenceOverflowEvent(event); + QwenLogger.getInstance(config)?.logKittySequenceOverflowEvent(event); if (!isTelemetrySdkInitialized()) return; const attributes: LogAttributes = { ...getCommonAttributes(config), diff --git a/packages/core/src/telemetry/qwen-logger/event-types.ts b/packages/core/src/telemetry/qwen-logger/event-types.ts index 1549d2ba8..f81fb7121 100644 --- a/packages/core/src/telemetry/qwen-logger/event-types.ts +++ b/packages/core/src/telemetry/qwen-logger/event-types.ts @@ -79,5 +79,6 @@ export interface RumPayload { session: RumSession; view: RumView; events: RumEvent[]; + properties?: Record; _v: string; } diff --git a/packages/core/src/telemetry/qwen-logger/qwen-logger.test.ts b/packages/core/src/telemetry/qwen-logger/qwen-logger.test.ts new file mode 100644 index 000000000..7ae66ebce --- /dev/null +++ b/packages/core/src/telemetry/qwen-logger/qwen-logger.test.ts @@ -0,0 +1,407 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + afterAll, +} from 'vitest'; +import { QwenLogger, TEST_ONLY } from './qwen-logger.js'; +import { Config } from '../../config/config.js'; +import { + StartSessionEvent, + EndSessionEvent, + IdeConnectionEvent, + KittySequenceOverflowEvent, + IdeConnectionType, +} from '../types.js'; +import { RumEvent } from './event-types.js'; + +// Mock dependencies +vi.mock('../../utils/user_id.js', () => ({ + getInstallationId: vi.fn(() => 'test-installation-id'), +})); + +vi.mock('../../utils/safeJsonStringify.js', () => ({ + safeJsonStringify: vi.fn((obj) => JSON.stringify(obj)), +})); + +// Mock https module +vi.mock('https', () => ({ + request: vi.fn(), +})); + +const makeFakeConfig = (overrides: Partial = {}): Config => { + const defaults = { + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getSessionId: () => 'test-session-id', + getCliVersion: () => '1.0.0', + getProxy: () => undefined, + getContentGeneratorConfig: () => ({ authType: 'test-auth' }), + getMcpServers: () => ({}), + getModel: () => 'test-model', + getEmbeddingModel: () => 'test-embedding', + getSandbox: () => false, + getCoreTools: () => [], + getApprovalMode: () => 'auto', + getTelemetryEnabled: () => true, + getTelemetryLogPromptsEnabled: () => false, + getFileFilteringRespectGitIgnore: () => true, + ...overrides, + }; + return defaults as Config; +}; + +describe('QwenLogger', () => { + let mockConfig: Config; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date('2025-01-01T12:00:00.000Z')); + mockConfig = makeFakeConfig(); + // Clear singleton instance + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (QwenLogger as any).instance = undefined; + }); + + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + afterAll(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (QwenLogger as any).instance = undefined; + }); + + describe('getInstance', () => { + it('returns undefined when usage statistics are disabled', () => { + const config = makeFakeConfig({ getUsageStatisticsEnabled: () => false }); + const logger = QwenLogger.getInstance(config); + expect(logger).toBeUndefined(); + }); + + it('returns an instance when usage statistics are enabled', () => { + const logger = QwenLogger.getInstance(mockConfig); + expect(logger).toBeInstanceOf(QwenLogger); + }); + + it('is a singleton', () => { + const logger1 = QwenLogger.getInstance(mockConfig); + const logger2 = QwenLogger.getInstance(mockConfig); + expect(logger1).toBe(logger2); + }); + }); + + describe('event queue management', () => { + it('should handle event overflow gracefully', () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + // Fill the queue beyond capacity + for (let i = 0; i < TEST_ONLY.MAX_EVENTS + 10; i++) { + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: `test-event-${i}`, + }); + } + + // Should have logged debug messages about dropping events + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'QwenLogger: Dropped old event to prevent memory leak', + ), + ); + }); + + it('should handle enqueue errors gracefully', () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + // Mock the events deque to throw an error + const originalPush = logger['events'].push; + logger['events'].push = vi.fn(() => { + throw new Error('Test error'); + }); + + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'test-event', + }); + + expect(consoleSpy).toHaveBeenCalledWith( + 'QwenLogger: Failed to enqueue log event.', + expect.any(Error), + ); + + // Restore original method + logger['events'].push = originalPush; + }); + }); + + describe('concurrent flush protection', () => { + it('should handle concurrent flush requests', () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + // Manually set the flush in progress flag to simulate concurrent access + logger['isFlushInProgress'] = true; + + // Try to flush while another flush is in progress + const result = logger.flushToRum(); + + // Should have logged about pending flush + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'QwenLogger: Flush already in progress, marking pending flush', + ), + ); + + // Should return a resolved promise + expect(result).toBeInstanceOf(Promise); + + // Reset the flag + logger['isFlushInProgress'] = false; + }); + }); + + describe('failed event retry mechanism', () => { + it('should requeue failed events with size limits', () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + const failedEvents: RumEvent[] = []; + for (let i = 0; i < TEST_ONLY.MAX_RETRY_EVENTS + 50; i++) { + failedEvents.push({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: `failed-event-${i}`, + }); + } + + // Call the private method using bracket notation + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (logger as any).requeueFailedEvents(failedEvents); + + // Should have logged about dropping events due to retry limit + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('QwenLogger: Re-queued'), + ); + }); + + it('should handle empty retry queue gracefully', () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + // Fill the queue to capacity first + for (let i = 0; i < TEST_ONLY.MAX_EVENTS; i++) { + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: `event-${i}`, + }); + } + + // Try to requeue when no space is available + const failedEvents: RumEvent[] = [ + { + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'failed-event', + }, + ]; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (logger as any).requeueFailedEvents(failedEvents); + + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('QwenLogger: No events re-queued'), + ); + }); + }); + + describe('event handlers', () => { + it('should log IDE connection events', () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const enqueueSpy = vi.spyOn(logger, 'enqueueLogEvent'); + + const event = new IdeConnectionEvent(IdeConnectionType.SESSION); + + logger.logIdeConnectionEvent(event); + + expect(enqueueSpy).toHaveBeenCalledWith( + expect.objectContaining({ + event_type: 'action', + type: 'connection', + name: 'ide_connection', + snapshots: JSON.stringify({ + connection_type: IdeConnectionType.SESSION, + }), + }), + ); + }); + + it('should log Kitty sequence overflow events', () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const enqueueSpy = vi.spyOn(logger, 'enqueueLogEvent'); + + const event = new KittySequenceOverflowEvent(1024, 'truncated...'); + + logger.logKittySequenceOverflowEvent(event); + + expect(enqueueSpy).toHaveBeenCalledWith( + expect.objectContaining({ + event_type: 'exception', + type: 'overflow', + name: 'kitty_sequence_overflow', + subtype: 'kitty_sequence_overflow', + snapshots: JSON.stringify({ + sequence_length: 1024, + truncated_sequence: 'truncated...', + }), + }), + ); + }); + + it('should flush start session events immediately', async () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const flushSpy = vi.spyOn(logger, 'flushToRum').mockResolvedValue({}); + + const testConfig = makeFakeConfig({ + getModel: () => 'test-model', + getEmbeddingModel: () => 'test-embedding', + }); + const event = new StartSessionEvent(testConfig); + + logger.logStartSessionEvent(event); + + expect(flushSpy).toHaveBeenCalled(); + }); + + it('should flush end session events immediately', async () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const flushSpy = vi.spyOn(logger, 'flushToRum').mockResolvedValue({}); + + const event = new EndSessionEvent(mockConfig); + + logger.logEndSessionEvent(event); + + expect(flushSpy).toHaveBeenCalled(); + }); + }); + + describe('flush timing', () => { + it('should not flush if interval has not passed', () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const flushSpy = vi.spyOn(logger, 'flushToRum'); + + // Add an event and try to flush immediately + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'test-event', + }); + + logger.flushIfNeeded(); + + expect(flushSpy).not.toHaveBeenCalled(); + }); + + it('should flush when interval has passed', () => { + const logger = QwenLogger.getInstance(mockConfig)!; + const flushSpy = vi.spyOn(logger, 'flushToRum').mockResolvedValue({}); + + // Add an event + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'test-event', + }); + + // Advance time beyond flush interval + vi.advanceTimersByTime(TEST_ONLY.FLUSH_INTERVAL_MS + 1000); + + logger.flushIfNeeded(); + + expect(flushSpy).toHaveBeenCalled(); + }); + }); + + describe('error handling', () => { + it('should handle flush errors gracefully with debug mode', async () => { + const debugConfig = makeFakeConfig({ getDebugMode: () => true }); + const logger = QwenLogger.getInstance(debugConfig)!; + const consoleSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => {}); + + // Add an event first + logger.enqueueLogEvent({ + timestamp: Date.now(), + event_type: 'action', + type: 'test', + name: 'test-event', + }); + + // Mock flushToRum to throw an error + const originalFlush = logger.flushToRum.bind(logger); + logger.flushToRum = vi.fn().mockRejectedValue(new Error('Network error')); + + // Advance time to trigger flush + vi.advanceTimersByTime(TEST_ONLY.FLUSH_INTERVAL_MS + 1000); + + logger.flushIfNeeded(); + + // Wait for async operations + await vi.runAllTimersAsync(); + + expect(consoleSpy).toHaveBeenCalledWith( + 'Error flushing to RUM:', + expect.any(Error), + ); + + // Restore original method + logger.flushToRum = originalFlush; + }); + }); + + describe('constants export', () => { + it('should export test constants', () => { + expect(TEST_ONLY.MAX_EVENTS).toBe(1000); + expect(TEST_ONLY.MAX_RETRY_EVENTS).toBe(100); + expect(TEST_ONLY.FLUSH_INTERVAL_MS).toBe(60000); + }); + }); +}); diff --git a/packages/core/src/telemetry/qwen-logger/qwen-logger.ts b/packages/core/src/telemetry/qwen-logger/qwen-logger.ts index 6e84fe5ac..2b3d5fb7d 100644 --- a/packages/core/src/telemetry/qwen-logger/qwen-logger.ts +++ b/packages/core/src/telemetry/qwen-logger/qwen-logger.ts @@ -7,7 +7,6 @@ import { Buffer } from 'buffer'; import * as https from 'https'; import { HttpsProxyAgent } from 'https-proxy-agent'; -import { randomUUID } from 'crypto'; import { StartSessionEvent, @@ -22,6 +21,8 @@ import { NextSpeakerCheckEvent, SlashCommandEvent, MalformedJsonResponseEvent, + IdeConnectionEvent, + KittySequenceOverflowEvent, } from '../types.js'; import { RumEvent, @@ -31,12 +32,12 @@ import { RumExceptionEvent, RumPayload, } from './event-types.js'; -// Removed unused EventMetadataKey import import { Config } from '../../config/config.js'; import { safeJsonStringify } from '../../utils/safeJsonStringify.js'; -// Removed unused import import { HttpError, retryWithBackoff } from '../../utils/retry.js'; import { getInstallationId } from '../../utils/user_id.js'; +import { FixedDeque } from 'mnemonist'; +import { AuthType } from '../../core/contentGenerator.js'; // Usage statistics collection endpoint const USAGE_STATS_HOSTNAME = 'gb4w8c3ygj-default-sea.rum.aliyuncs.com'; @@ -44,6 +45,23 @@ const USAGE_STATS_PATH = '/'; const RUN_APP_ID = 'gb4w8c3ygj@851d5d500f08f92'; +/** + * Interval in which buffered events are sent to RUM. + */ +const FLUSH_INTERVAL_MS = 1000 * 60; + +/** + * Maximum amount of events to keep in memory. Events added after this amount + * are dropped until the next flush to RUM, which happens periodically as + * defined by {@link FLUSH_INTERVAL_MS}. + */ +const MAX_EVENTS = 1000; + +/** + * Maximum events to retry after a failed RUM flush + */ +const MAX_RETRY_EVENTS = 100; + export interface LogResponse { nextRequestWaitMs?: number; } @@ -53,23 +71,42 @@ export interface LogResponse { export class QwenLogger { private static instance: QwenLogger; private config?: Config; - private readonly events: RumEvent[] = []; - private last_flush_time: number = Date.now(); - private flush_interval_ms: number = 1000 * 60; // Wait at least a minute before flushing events. + + /** + * Queue of pending events that need to be flushed to the server. New events + * are added to this queue and then flushed on demand (via `flushToRum`) + */ + private readonly events: FixedDeque; + + /** + * The last time that the events were successfully flushed to the server. + */ + private lastFlushTime: number = Date.now(); + private userId: string; private sessionId: string; - private viewId: string; + + /** + * The value is true when there is a pending flush happening. This prevents + * concurrent flush operations. + */ private isFlushInProgress: boolean = false; + + /** + * This value is true when a flush was requested during an ongoing flush. + */ + private pendingFlush: boolean = false; + private isShutdown: boolean = false; private constructor(config?: Config) { this.config = config; + this.events = new FixedDeque(Array, MAX_EVENTS); this.userId = this.generateUserId(); this.sessionId = typeof this.config?.getSessionId === 'function' ? this.config.getSessionId() : ''; - this.viewId = randomUUID(); } private generateUserId(): string { @@ -92,7 +129,26 @@ export class QwenLogger { } enqueueLogEvent(event: RumEvent): void { - this.events.push(event); + try { + // Manually handle overflow for FixedDeque, which throws when full. + const wasAtCapacity = this.events.size >= MAX_EVENTS; + + if (wasAtCapacity) { + this.events.shift(); // Evict oldest element to make space. + } + + this.events.push(event); + + if (wasAtCapacity && this.config?.getDebugMode()) { + console.debug( + `QwenLogger: Dropped old event to prevent memory leak (queue size: ${this.events.size})`, + ); + } + } catch (error) { + if (this.config?.getDebugMode()) { + console.error('QwenLogger: Failed to enqueue log event.', error); + } + } } createRumEvent( @@ -143,6 +199,7 @@ export class QwenLogger { } async createRumPayload(): Promise { + const authType = this.config?.getAuthType(); const version = this.config?.getCliVersion() || 'unknown'; return { @@ -159,40 +216,59 @@ export class QwenLogger { id: this.sessionId, }, view: { - id: this.viewId, + id: this.sessionId, name: 'qwen-code-cli', }, - events: [...this.events], + + events: this.events.toArray() as RumEvent[], + properties: { + auth_type: authType, + model: this.config?.getModel(), + base_url: + authType === AuthType.USE_OPENAI ? process.env.OPENAI_BASE_URL : '', + }, _v: `qwen-code@${version}`, }; } flushIfNeeded(): void { - if (Date.now() - this.last_flush_time < this.flush_interval_ms) { - return; - } - - // Prevent concurrent flush operations - if (this.isFlushInProgress) { + if (Date.now() - this.lastFlushTime < FLUSH_INTERVAL_MS) { return; } this.flushToRum().catch((error) => { - console.debug('Error flushing to RUM:', error); + if (this.config?.getDebugMode()) { + console.debug('Error flushing to RUM:', error); + } }); } async flushToRum(): Promise { + if (this.isFlushInProgress) { + if (this.config?.getDebugMode()) { + console.debug( + 'QwenLogger: Flush already in progress, marking pending flush.', + ); + } + this.pendingFlush = true; + return Promise.resolve({}); + } + this.isFlushInProgress = true; + if (this.config?.getDebugMode()) { console.log('Flushing log events to RUM.'); } - if (this.events.length === 0) { + if (this.events.size === 0) { + this.isFlushInProgress = false; return {}; } - this.isFlushInProgress = true; + const eventsToSend = this.events.toArray() as RumEvent[]; + this.events.clear(); const rumPayload = await this.createRumPayload(); + // Override events with the ones we're sending + rumPayload.events = eventsToSend; const flushFn = () => new Promise((resolve, reject) => { const body = safeJsonStringify(rumPayload); @@ -246,16 +322,29 @@ export class QwenLogger { }, }); - this.events.splice(0, this.events.length); - this.last_flush_time = Date.now(); + this.lastFlushTime = Date.now(); return {}; } catch (error) { if (this.config?.getDebugMode()) { console.error('RUM flush failed after multiple retries.', error); } + + // Re-queue failed events for retry + this.requeueFailedEvents(eventsToSend); return {}; } finally { this.isFlushInProgress = false; + + // If a flush was requested while we were flushing, flush again + if (this.pendingFlush) { + this.pendingFlush = false; + // Fire and forget the pending flush + this.flushToRum().catch((error) => { + if (this.config?.getDebugMode()) { + console.debug('Error in pending flush to RUM:', error); + } + }); + } } } @@ -282,7 +371,9 @@ export class QwenLogger { // Flush start event immediately this.enqueueLogEvent(applicationEvent); this.flushToRum().catch((error: unknown) => { - console.debug('Error flushing to RUM:', error); + if (this.config?.getDebugMode()) { + console.debug('Error flushing to RUM:', error); + } }); } @@ -451,13 +542,41 @@ export class QwenLogger { this.flushIfNeeded(); } + logIdeConnectionEvent(event: IdeConnectionEvent): void { + const rumEvent = this.createActionEvent('connection', 'ide_connection', { + snapshots: JSON.stringify({ connection_type: event.connection_type }), + }); + + this.enqueueLogEvent(rumEvent); + this.flushIfNeeded(); + } + + logKittySequenceOverflowEvent(event: KittySequenceOverflowEvent): void { + const rumEvent = this.createExceptionEvent( + 'overflow', + 'kitty_sequence_overflow', + { + subtype: 'kitty_sequence_overflow', + snapshots: JSON.stringify({ + sequence_length: event.sequence_length, + truncated_sequence: event.truncated_sequence, + }), + }, + ); + + this.enqueueLogEvent(rumEvent); + this.flushIfNeeded(); + } + logEndSessionEvent(_event: EndSessionEvent): void { const applicationEvent = this.createViewEvent('session', 'session_end', {}); // Flush immediately on session end. this.enqueueLogEvent(applicationEvent); this.flushToRum().catch((error: unknown) => { - console.debug('Error flushing to RUM:', error); + if (this.config?.getDebugMode()) { + console.debug('Error flushing to RUM:', error); + } }); } @@ -480,4 +599,60 @@ export class QwenLogger { const event = new EndSessionEvent(this.config); this.logEndSessionEvent(event); } + + private requeueFailedEvents(eventsToSend: RumEvent[]): void { + // Add the events back to the front of the queue to be retried, but limit retry queue size + const eventsToRetry = eventsToSend.slice(-MAX_RETRY_EVENTS); // Keep only the most recent events + + // Log a warning if we're dropping events + if (eventsToSend.length > MAX_RETRY_EVENTS && this.config?.getDebugMode()) { + console.warn( + `QwenLogger: Dropping ${ + eventsToSend.length - MAX_RETRY_EVENTS + } events due to retry queue limit. Total events: ${ + eventsToSend.length + }, keeping: ${MAX_RETRY_EVENTS}`, + ); + } + + // Determine how many events can be re-queued + const availableSpace = MAX_EVENTS - this.events.size; + const numEventsToRequeue = Math.min(eventsToRetry.length, availableSpace); + + if (numEventsToRequeue === 0) { + if (this.config?.getDebugMode()) { + console.debug( + `QwenLogger: No events re-queued (queue size: ${this.events.size})`, + ); + } + return; + } + + // Get the most recent events to re-queue + const eventsToRequeue = eventsToRetry.slice( + eventsToRetry.length - numEventsToRequeue, + ); + + // Prepend events to the front of the deque to be retried first. + // We iterate backwards to maintain the original order of the failed events. + for (let i = eventsToRequeue.length - 1; i >= 0; i--) { + this.events.unshift(eventsToRequeue[i]); + } + // Clear any potential overflow + while (this.events.size > MAX_EVENTS) { + this.events.pop(); + } + + if (this.config?.getDebugMode()) { + console.debug( + `QwenLogger: Re-queued ${numEventsToRequeue} events for retry (queue size: ${this.events.size})`, + ); + } + } } + +export const TEST_ONLY = { + MAX_RETRY_EVENTS, + MAX_EVENTS, + FLUSH_INTERVAL_MS, +}; diff --git a/packages/core/src/utils/bfsFileSearch.test.ts b/packages/core/src/utils/bfsFileSearch.test.ts index f9d76e386..b47c41721 100644 --- a/packages/core/src/utils/bfsFileSearch.test.ts +++ b/packages/core/src/utils/bfsFileSearch.test.ts @@ -210,16 +210,16 @@ describe('bfsFileSearch', () => { for (let i = 0; i < numTargetDirs; i++) { // Add target files in some directories fileCreationPromises.push( - createTestFile('content', `dir${i}`, 'GEMINI.md'), + createTestFile('content', `dir${i}`, 'QWEN.md'), ); fileCreationPromises.push( - createTestFile('content', `dir${i}`, 'subdir1', 'GEMINI.md'), + createTestFile('content', `dir${i}`, 'subdir1', 'QWEN.md'), ); } const expectedFiles = await Promise.all(fileCreationPromises); const result = await bfsFileSearch(testRootDir, { - fileName: 'GEMINI.md', + fileName: 'QWEN.md', // Provide a generous maxDirs limit to ensure it doesn't prematurely stop // in this large test case. Total dirs created is 200. maxDirs: 250, diff --git a/packages/core/src/utils/memoryDiscovery.ts b/packages/core/src/utils/memoryDiscovery.ts index 0a2989a9a..7433b868e 100644 --- a/packages/core/src/utils/memoryDiscovery.ts +++ b/packages/core/src/utils/memoryDiscovery.ts @@ -143,9 +143,28 @@ async function getGeminiMdFilePathsInternalForEachDir( // It's okay if it's not found. } - // FIX: Only perform the workspace search (upward and downward scans) - // if a valid currentWorkingDirectory is provided. - if (dir) { + // Handle the case where we're in the home directory (dir is empty string or home path) + const resolvedDir = dir ? path.resolve(dir) : resolvedHome; + const isHomeDirectory = resolvedDir === resolvedHome; + + if (isHomeDirectory) { + // For home directory, only check for QWEN.md directly in the home directory + const homeContextPath = path.join(resolvedHome, geminiMdFilename); + try { + await fs.access(homeContextPath, fsSync.constants.R_OK); + if (homeContextPath !== globalMemoryPath) { + allPaths.add(homeContextPath); + if (debugMode) + logger.debug( + `Found readable home ${geminiMdFilename}: ${homeContextPath}`, + ); + } + } catch { + // Not found, which is okay + } + } else if (dir) { + // FIX: Only perform the workspace search (upward and downward scans) + // if a valid currentWorkingDirectory is provided and it's not the home directory. const resolvedCwd = path.resolve(dir); if (debugMode) logger.debug( From 4463107af30393c3664a33d38806d725bba7418b Mon Sep 17 00:00:00 2001 From: "mingholy.lmh" Date: Wed, 27 Aug 2025 18:25:44 +0800 Subject: [PATCH 15/16] chore: bump version to 0.0.9 --- CHANGELOG.md | 22 ++++++++++++++++++++++ docs/deployment.md | 2 +- package-lock.json | 12 ++++++------ package.json | 4 ++-- packages/cli/package.json | 4 ++-- packages/core/package.json | 2 +- packages/test-utils/package.json | 2 +- packages/vscode-ide-companion/package.json | 2 +- 8 files changed, 36 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0b33eda1..11523eb59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## 0.0.9 + +- Synced upstream `gemini-cli` to v0.1.21. +- Fixed token synchronization among multiple Qwen sessions. +- Improved tool execution with early stop on invalid tool calls. +- Added explicit `is_background` parameter for shell tool. +- Enhanced memory management with sub-commands to switch between project and global memory operations. +- Renamed `GEMINI_DIR` to `QWEN_DIR` for better branding consistency. +- Added support for Qwen Markdown selection. +- Fixed parallel tool usage and improved tool reliability. +- Upgraded integration tests to use Vitest framework. +- Enhanced VS Code IDE integration with launch configurations. +- Added terminal setup command for Shift+Enter and Ctrl+Enter support. +- Fixed GitHub Workflows configuration issues. +- Improved settings directory and command descriptions. +- Fixed locale handling in yargs configuration. +- Added support for `trustedFolders.json` configuration file. +- Enhanced cross-platform compatibility for sandbox build scripts. +- Improved error handling and fixed ambiguous literals. +- Updated documentation links and added IDE integration documentation. +- Miscellaneous improvements and bug fixes. + ## 0.0.8 - Synced upstream `gemini-cli` to v0.1.19. diff --git a/docs/deployment.md b/docs/deployment.md index 1154605d9..28d7a2209 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -41,7 +41,7 @@ For security and isolation, Qwen Code can be run inside a container. This is the You can run the published sandbox image directly. This is useful for environments where you only have Docker and want to run the CLI. ```bash # Run the published sandbox image - docker run --rm -it ghcr.io/qwenlm/qwen-code:0.0.8 + docker run --rm -it ghcr.io/qwenlm/qwen-code:0.0.9 ``` - **Using the `--sandbox` flag:** If you have Qwen Code installed locally (using the standard installation described above), you can instruct it to run inside the sandbox container. diff --git a/package-lock.json b/package-lock.json index 7c14c7d97..3d83ebc97 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@qwen-code/qwen-code", - "version": "0.0.8", + "version": "0.0.9", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@qwen-code/qwen-code", - "version": "0.0.8", + "version": "0.0.9", "workspaces": [ "packages/*" ], @@ -12336,7 +12336,7 @@ }, "packages/cli": { "name": "@qwen-code/qwen-code", - "version": "0.0.8", + "version": "0.0.9", "dependencies": { "@google/genai": "1.9.0", "@iarna/toml": "^2.2.5", @@ -12520,7 +12520,7 @@ }, "packages/core": { "name": "@qwen-code/qwen-code-core", - "version": "0.0.8", + "version": "0.0.9", "dependencies": { "@google/genai": "1.13.0", "@modelcontextprotocol/sdk": "^1.11.0", @@ -12671,7 +12671,7 @@ }, "packages/test-utils": { "name": "@qwen-code/qwen-code-test-utils", - "version": "0.0.8", + "version": "0.0.9", "license": "Apache-2.0", "devDependencies": { "typescript": "^5.3.3" @@ -12682,7 +12682,7 @@ }, "packages/vscode-ide-companion": { "name": "qwen-code-vscode-ide-companion", - "version": "0.0.8", + "version": "0.0.9", "license": "LICENSE", "dependencies": { "@modelcontextprotocol/sdk": "^1.15.1", diff --git a/package.json b/package.json index a87c8e7e7..e90a2c3e7 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@qwen-code/qwen-code", - "version": "0.0.8", + "version": "0.0.9", "engines": { "node": ">=20.0.0" }, @@ -13,7 +13,7 @@ "url": "git+https://github.com/QwenLM/qwen-code.git" }, "config": { - "sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.8" + "sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.9" }, "scripts": { "start": "node scripts/start.js", diff --git a/packages/cli/package.json b/packages/cli/package.json index 1edd2e978..23a9ddb86 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -1,6 +1,6 @@ { "name": "@qwen-code/qwen-code", - "version": "0.0.8", + "version": "0.0.9", "description": "Qwen Code", "repository": { "type": "git", @@ -25,7 +25,7 @@ "dist" ], "config": { - "sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.8" + "sandboxImageUri": "ghcr.io/qwenlm/qwen-code:0.0.9" }, "dependencies": { "@google/genai": "1.9.0", diff --git a/packages/core/package.json b/packages/core/package.json index 3415d747f..d3dc28465 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@qwen-code/qwen-code-core", - "version": "0.0.8", + "version": "0.0.9", "description": "Qwen Code Core", "repository": { "type": "git", diff --git a/packages/test-utils/package.json b/packages/test-utils/package.json index d8afeae19..8f8dbf2ae 100644 --- a/packages/test-utils/package.json +++ b/packages/test-utils/package.json @@ -1,6 +1,6 @@ { "name": "@qwen-code/qwen-code-test-utils", - "version": "0.0.8", + "version": "0.0.9", "private": true, "main": "src/index.ts", "license": "Apache-2.0", diff --git a/packages/vscode-ide-companion/package.json b/packages/vscode-ide-companion/package.json index 3fbb43a37..8d826d28f 100644 --- a/packages/vscode-ide-companion/package.json +++ b/packages/vscode-ide-companion/package.json @@ -2,7 +2,7 @@ "name": "qwen-code-vscode-ide-companion", "displayName": "Qwen Code Companion", "description": "Enable Qwen Code with direct access to your VS Code workspace.", - "version": "0.0.8", + "version": "0.0.9", "publisher": "qwenlm", "icon": "assets/icon.png", "repository": { From d6da3d444828d127f0c47b9361b347413dfd60e8 Mon Sep 17 00:00:00 2001 From: ajiwo Date: Mon, 25 Aug 2025 13:32:57 +0000 Subject: [PATCH 16/16] feat(tools): Include the new content after edits Introduces a new configuration setting, `readAfterEdit`, which is enabled by default. When this setting is active, the `edit` tool will automatically append the full content of a file to its response message (`llmContent`) after a successful modification or creation. This provides the AI with immediate context of the changes, improving its awareness of the file's current state and reducing the need for a subsequent `read_file` call. Co-authored-by: Qwen-Coder --- docs/cli/configuration.md | 4 + docs/tools/file-system.md | 1 + packages/cli/src/config/config.ts | 1 + packages/cli/src/config/settingsSchema.ts | 10 + packages/core/src/config/config.ts | 7 + packages/core/src/tools/edit.test.ts | 287 ++++++++++++++++++++++ packages/core/src/tools/edit.ts | 7 +- 7 files changed, 316 insertions(+), 1 deletion(-) diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index 78147c3f2..044267e51 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -272,6 +272,10 @@ In addition to a project settings file, a project's `.qwen` directory can contai - **Description:** API key for Tavily web search service. Required to enable the `web_search` tool functionality. If not configured, the web search tool will be disabled and skipped. - **Default:** `undefined` (web search disabled) - **Example:** `"tavilyApiKey": "tvly-your-api-key-here"` +- **`readAfterEdit`** (boolean): + - **Description:** Automatically read file content after editing to provide context to the AI. When enabled, the content of a file is included in the LLM response after successful edit operations, enhancing the AI's awareness of the changes made. + - **Default:** `true` + - **Example:** `"readAfterEdit": false` - **`chatCompression`** (object): - **Description:** Controls the settings for chat history compression, both automatic and when manually invoked through the /compress command. diff --git a/docs/tools/file-system.md b/docs/tools/file-system.md index 45c1eaa7b..0181614c8 100644 --- a/docs/tools/file-system.md +++ b/docs/tools/file-system.md @@ -167,6 +167,7 @@ search_file_content(pattern="function", include="*.js", maxResults=10) - `old_string` is found multiple times, and the self-correction mechanism cannot resolve it to a single, unambiguous match. - **Output (`llmContent`):** - On success: `Successfully modified file: /path/to/file.txt (1 replacements).` or `Created new file: /path/to/new_file.txt with provided content.` + - When the `readAfterEdit` configuration is enabled (default), the updated file content is also included in the response to provide context to the AI. - On failure: An error message explaining the reason (e.g., `Failed to edit, 0 occurrences found...`, `Failed to edit, expected 1 occurrences but found 2...`). - **Confirmation:** Yes. Shows a diff of the proposed changes and asks for user approval before writing to the file. diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index a16ceb0d9..808f4355c 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -585,6 +585,7 @@ export async function loadCliConfig( chatCompression: settings.chatCompression, folderTrustFeature, folderTrust, + readAfterEdit: settings.readAfterEdit ?? true, interactive, trustedFolder, }); diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 4a21ebe5b..bf3a31975 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -540,6 +540,16 @@ export const SETTINGS_SCHEMA = { description: 'The API key for the Tavily API.', showInDialog: false, }, + readAfterEdit: { + type: 'boolean', + label: 'Read After Edit', + category: 'Tools', + requiresRestart: false, + default: true, + description: + 'Automatically read file content after editing to provide context to the AI.', + showInDialog: true, + }, } as const; type InferSettings = { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 45d39b439..156cdd750 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -223,6 +223,7 @@ export interface ConfigParameters { chatCompression?: ChatCompressionSettings; interactive?: boolean; trustedFolder?: boolean; + readAfterEdit?: boolean; } export class Config { @@ -303,6 +304,7 @@ export class Config { private readonly chatCompression: ChatCompressionSettings | undefined; private readonly interactive: boolean; private readonly trustedFolder: boolean | undefined; + private readonly readAfterEdit: boolean; private initialized: boolean = false; constructor(params: ConfigParameters) { @@ -380,6 +382,7 @@ export class Config { this.chatCompression = params.chatCompression; this.interactive = params.interactive ?? false; this.trustedFolder = params.trustedFolder; + this.readAfterEdit = params.readAfterEdit ?? true; // Web search this.tavilyApiKey = params.tavilyApiKey; @@ -813,6 +816,10 @@ export class Config { return this.interactive; } + getReadAfterEdit(): boolean { + return this.readAfterEdit; + } + async getGitService(): Promise { if (!this.gitService) { this.gitService = new GitService(this.targetDir); diff --git a/packages/core/src/tools/edit.test.ts b/packages/core/src/tools/edit.test.ts index b2e31fdda..4c5058c84 100644 --- a/packages/core/src/tools/edit.test.ts +++ b/packages/core/src/tools/edit.test.ts @@ -81,6 +81,7 @@ describe('EditTool', () => { getGeminiMdFileCount: () => 0, setGeminiMdFileCount: vi.fn(), getToolRegistry: () => ({}) as any, // Minimal mock for ToolRegistry + getReadAfterEdit: () => vi.fn().mockReturnValue(true), } as unknown as Config; // Reset mocks before each test @@ -847,3 +848,289 @@ describe('EditTool', () => { }); }); }); + +describe('EditTool - readAfterEdit', () => { + let tool: EditTool; + let tempDir: string; + let rootDir: string; + let mockConfig: Config; + let geminiClient: any; + + beforeEach(() => { + vi.restoreAllMocks(); + tempDir = fs.mkdtempSync( + path.join(os.tmpdir(), 'edit-tool-readafteredit-test-'), + ); + rootDir = path.join(tempDir, 'root'); + fs.mkdirSync(rootDir); + + geminiClient = { + generateJson: mockGenerateJson, + }; + + mockConfig = { + getGeminiClient: vi.fn().mockReturnValue(geminiClient), + getTargetDir: () => rootDir, + getApprovalMode: vi.fn(), + getWorkspaceContext: () => createMockWorkspaceContext(rootDir), + getReadAfterEdit: vi.fn().mockReturnValue(true), // Default to true for these tests + } as unknown as Config; + + (mockConfig.getApprovalMode as Mock).mockClear(); + (mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.DEFAULT); + + mockEnsureCorrectEdit.mockReset(); + mockEnsureCorrectEdit.mockImplementation( + async (_, currentContent, params) => { + let occurrences = 0; + if (params.old_string && currentContent) { + let index = currentContent.indexOf(params.old_string); + while (index !== -1) { + occurrences++; + index = currentContent.indexOf(params.old_string, index + 1); + } + } else if (params.old_string === '') { + occurrences = 0; + } + return Promise.resolve({ params, occurrences }); + }, + ); + + mockGenerateJson.mockReset(); + mockGenerateJson.mockImplementation(async () => Promise.resolve({})); + + tool = new EditTool(mockConfig); + }); + + afterEach(() => { + fs.rmSync(tempDir, { recursive: true, force: true }); + }); + + describe('readAfterEdit enabled', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(true); + }); + + it('should include file content in llmContent after successful edit', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is the original content.'; + const newContent = 'This is the modified content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'original', + new_string: 'modified', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).toContain(newContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent); + }); + + it('should include file content in llmContent when creating a new file', async () => { + const newFileName = 'new_file.txt'; + const newFilePath = path.join(rootDir, newFileName); + const fileContent = 'Content for the new file.'; + + const params: EditToolParams = { + file_path: newFilePath, + old_string: '', + new_string: fileContent, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Created new file/); + expect(result.llmContent).toContain(fileContent); + expect(fs.existsSync(newFilePath)).toBe(true); + expect(fs.readFileSync(newFilePath, 'utf8')).toBe(fileContent); + }); + + it('should include file content in llmContent when replacing multiple occurrences', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'old text old text old text'; + const expectedContent = 'new text new text new text'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + expected_replacements: 3, + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).toContain(expectedContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(expectedContent); + }); + + it('should include file content even when user modified the new_string', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is some old text.'; + const newContent = 'This is some new text.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + modified_by_user: true, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch( + /User modified the `new_string` content/, + ); + expect(result.llmContent).toContain(newContent); + }); + }); + + describe('readAfterEdit disabled', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(false); + }); + + it('should NOT include file content in llmContent after successful edit when disabled', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'This is the original content.'; + const newContent = 'This is the modified content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'original', + new_string: 'modified', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).not.toContain(newContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(newContent); + }); + + it('should NOT include file content when creating a new file and feature is disabled', async () => { + const newFileName = 'new_file.txt'; + const newFilePath = path.join(rootDir, newFileName); + const fileContent = 'Content for the new file.'; + + const params: EditToolParams = { + file_path: newFilePath, + old_string: '', + new_string: fileContent, + }; + + (mockConfig.getApprovalMode as Mock).mockReturnValueOnce( + ApprovalMode.AUTO_EDIT, + ); + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Created new file/); + expect(result.llmContent).not.toContain(fileContent); + expect(fs.existsSync(newFilePath)).toBe(true); + expect(fs.readFileSync(newFilePath, 'utf8')).toBe(fileContent); + }); + + it('should NOT include file content when replacing multiple occurrences and feature is disabled', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'old text old text old text'; + const expectedContent = 'new text new text new text'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + expected_replacements: 3, + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/Successfully modified file/); + expect(result.llmContent).not.toContain(expectedContent); + expect(fs.readFileSync(filePath, 'utf8')).toBe(expectedContent); + }); + }); + + describe('Error cases with readAfterEdit', () => { + beforeEach(() => { + (mockConfig.getReadAfterEdit as Mock).mockReturnValue(true); + }); + + it('should not include file content in llmContent when edit fails', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const initialContent = 'Some content.'; + + fs.writeFileSync(filePath, initialContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: 'nonexistent', + new_string: 'replacement', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch( + /0 occurrences found for old_string in/, + ); + expect(result.llmContent).not.toContain(initialContent); // Should not include file content on error + expect(fs.readFileSync(filePath, 'utf8')).toBe(initialContent); // File should be unchanged + }); + + it('should not include file content in llmContent when file already exists during creation', async () => { + const testFile = 'test.txt'; + const filePath = path.join(rootDir, testFile); + const existingContent = 'Existing content'; + + fs.writeFileSync(filePath, existingContent, 'utf8'); + + const params: EditToolParams = { + file_path: filePath, + old_string: '', + new_string: 'new content', + }; + + const invocation = tool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(result.llmContent).toMatch(/File already exists, cannot create/); + expect(result.llmContent).not.toContain(existingContent); // Should not include file content on error + expect(fs.readFileSync(filePath, 'utf8')).toBe(existingContent); // File should be unchanged + }); + }); +}); diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 8d90dfe45..bbbe34de2 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -384,8 +384,13 @@ class EditToolInvocation implements ToolInvocation { ); } + let llmContent = llmSuccessMessageParts.join(' '); + if (this.config.getReadAfterEdit()) { + llmContent += `\n${editData.newContent}`; + } + return { - llmContent: llmSuccessMessageParts.join(' '), + llmContent, returnDisplay: displayResult, }; } catch (error) {