Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eng/packages/General.props
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<PackageVersion Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
<PackageVersion Include="System.ComponentModel.Annotations" Version="5.0.0" />
<PackageVersion Include="System.Memory" Version="4.5.5" />
<PackageVersion Include="System.Numerics.Tensors" Version="$(SystemNumericsTensorsVersion)" />
<PackageVersion Include="System.Private.Uri" Version="4.3.2" />
<PackageVersion Include="System.Runtime.Caching" Version="$(SystemRuntimeCachingVersion)" />
<PackageVersion Include="System.Runtime.CompilerServices.Unsafe" Version="6.1.0" />
Expand Down
2 changes: 1 addition & 1 deletion eng/packages/TestOnly.props
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<PackageVersion Include="Polly.Testing" Version="8.4.2" />
<PackageVersion Include="StrongNamer" Version="0.2.5" />
<PackageVersion Include="System.Configuration.ConfigurationManager" Version="$(SystemConfigurationConfigurationManagerVersion)" />
<PackageVersion Include="System.Numerics.Tensors" Version="$(SystemNumericsTensorsVersion)" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="Verify.Xunit" Version="28.15.0" />
<PackageVersion Include="Xunit.Combinatorial" Version="1.6.24" />
<PackageVersion Include="xunit.extensibility.execution" Version="$(XUnitVersion)" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Represents a strategy capable of selecting a reduced set of tools for a chat request.
/// </summary>
/// <remarks>
/// <para>
/// A tool reduction strategy is invoked prior to sending a request to an underlying <see cref="IChatClient"/>,
/// enabling scenarios where a large tool catalog must be trimmed to fit provider limits or to improve model
/// tool selection quality.
/// </para>
/// <para>
/// The implementation should return a non-<see langword="null"/> enumerable. Returning the original
/// <see cref="ChatOptions.Tools"/> instance indicates no change. Returning a different enumerable indicates
/// the caller may replace the existing tool list.
/// </para>
/// </remarks>
[Experimental("MEAI001")]
public interface IToolReductionStrategy
{
/// <summary>
/// Asynchronously selects the tools that should be included for a specific request.
/// </summary>
/// <param name="messages">The chat messages for the request. This is an <see cref="IEnumerable{T}"/> to avoid premature materialization.</param>
/// <param name="options">
/// The chat options for the request (may be <see langword="null"/>). The candidate tools are those exposed
/// through <see cref="ChatOptions.Tools"/>.
/// </param>
/// <param name="cancellationToken">A token to observe cancellation. Defaults to <see langword="default"/>.</param>
/// <returns>
/// A task whose result is a (possibly reduced) enumerable of <see cref="AITool"/> instances. The result must
/// never be <see langword="null"/>. Returning the same instance referenced by
/// <paramref name="options"/>.<see cref="ChatOptions.Tools"/> signals no change.
/// </returns>
Task<IEnumerable<AITool>> SelectToolsForRequestAsync(
IEnumerable<ChatMessage> messages,
ChatOptions? options,
CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="System.Numerics.Tensors" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Provides <see cref="ChatClientBuilder"/> extensions that insert tool reduction middleware into a chat client pipeline.
/// </summary>
[Experimental("MEAI001")]
public static class ChatClientBuilderToolReductionExtensions
{
    /// <summary>
    /// Registers a tool-reducing chat client in the pipeline, driven by the supplied <paramref name="strategy"/>.
    /// </summary>
    /// <param name="builder">The chat client builder to extend.</param>
    /// <param name="strategy">The strategy that decides which tools are kept for each request.</param>
    /// <returns>The value returned by the builder's <c>Use</c> call, allowing further chaining.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="builder"/> or <paramref name="strategy"/> is <see langword="null"/>.
    /// </exception>
    /// <remarks>
    /// Place this ahead of function invocation middleware in the pipeline so that the underlying provider
    /// only ever sees the reduced tool set.
    /// </remarks>
    public static ChatClientBuilder UseToolReduction(this ChatClientBuilder builder, IToolReductionStrategy strategy)
    {
        // Validate in declaration order; the argument expressions are kept as-is so that the
        // reported parameter names stay "builder" and "strategy".
        ChatClientBuilder pipeline = Throw.IfNull(builder);
        IToolReductionStrategy reducer = Throw.IfNull(strategy);

        return pipeline.Use(innerClient => new ToolReducingChatClient(innerClient, reducer));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;

#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14

/// <summary>
/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context.
/// </summary>
/// <remarks>
/// <para>
/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current
/// conversation content each request. It then selects the top <c>toolLimit</c> tools by similarity.
/// </para>
/// <para>
/// Cached tool embeddings are keyed by <see cref="AITool"/> instance. Once a tool's embedding is cached,
/// later changes to <see cref="ToolEmbeddingTextFactory"/> do not affect that tool, because cache hits
/// bypass the factory.
/// </para>
/// </remarks>
[Experimental("MEAI001")]
public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy
{
    // Keyed by tool instance; entries do not keep the tool alive.
    private readonly ConditionalWeakTable<AITool, Embedding<float>> _toolEmbeddingsCache = new();
    private readonly IEmbeddingGenerator<string, Embedding<float>> _embeddingGenerator;
    private readonly int _toolLimit;

    // Default tool text: "Name\nDescription", falling back to whichever part is non-blank.
    private Func<AITool, string> _toolEmbeddingTextFactory = static t =>
    {
        if (string.IsNullOrWhiteSpace(t.Name))
        {
            return t.Description;
        }

        if (string.IsNullOrWhiteSpace(t.Description))
        {
            return t.Name;
        }

        return t.Name + "\n" + t.Description;
    };

    // Default query text: the non-empty text of every message, joined by newlines.
    private Func<IEnumerable<ChatMessage>, string> _messagesEmbeddingTextFactory = static messages =>
    {
        var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s));
        return string.Join("\n", messageTexts);
    };

    private Func<ReadOnlyMemory<float>, ReadOnlyMemory<float>, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span);

    /// <summary>
    /// Initializes a new instance of the <see cref="EmbeddingToolReductionStrategy"/> class.
    /// </summary>
    /// <param name="embeddingGenerator">Embedding generator used to produce embeddings. Must not be <see langword="null"/>.</param>
    /// <param name="toolLimit">Maximum number of tools to return. Must be greater than zero.</param>
    public EmbeddingToolReductionStrategy(
        IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator,
        int toolLimit)
    {
        _embeddingGenerator = Throw.IfNull(embeddingGenerator);
        _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0);
    }

    /// <summary>
    /// Gets or sets a delegate used to produce the text to embed for a tool.
    /// Defaults to: <c>Name + "\n" + Description</c> (omitting empty parts).
    /// </summary>
    /// <remarks>Setting this property to <see langword="null"/> is rejected.</remarks>
    public Func<AITool, string> ToolEmbeddingTextFactory
    {
        get => _toolEmbeddingTextFactory;
        set => _toolEmbeddingTextFactory = Throw.IfNull(value);
    }

    /// <summary>
    /// Gets or sets the factory function used to generate a single text string from a collection of chat messages for
    /// embedding purposes. Defaults to joining the non-empty text of each message with newlines.
    /// </summary>
    /// <remarks>Setting this property to <see langword="null"/> is rejected.</remarks>
    public Func<IEnumerable<ChatMessage>, string> MessagesEmbeddingTextFactory
    {
        get => _messagesEmbeddingTextFactory;
        set => _messagesEmbeddingTextFactory = Throw.IfNull(value);
    }

    /// <summary>
    /// Gets or sets a similarity function applied to (query, tool) embedding vectors. Defaults to cosine similarity.
    /// </summary>
    /// <remarks>
    /// Higher values are treated as more relevant. Setting this property to <see langword="null"/> is rejected.
    /// </remarks>
    public Func<ReadOnlyMemory<float>, ReadOnlyMemory<float>, float> Similarity
    {
        get => _similarity;
        set => _similarity = Throw.IfNull(value);
    }

    /// <summary>
    /// Gets or sets a value indicating whether tool embeddings are cached. Defaults to <see langword="true"/>.
    /// </summary>
    public bool EnableEmbeddingCaching { get; set; } = true;

    /// <summary>
    /// Gets or sets a value indicating whether to preserve original ordering of selected tools.
    /// If <see langword="false"/> (default), tools are ordered by descending similarity.
    /// If <see langword="true"/>, the top-N tools by similarity are re-emitted in their original order.
    /// </summary>
    public bool PreserveOriginalOrdering { get; set; }

    /// <summary>
    /// Selects up to the configured tool limit of tools from <paramref name="options"/>, ranked by the
    /// similarity between each tool's embedding and the embedding of the conversation text.
    /// </summary>
    /// <param name="messages">The chat messages used to build the query text. Must not be <see langword="null"/>.</param>
    /// <param name="options">The chat options whose <see cref="ChatOptions.Tools"/> are the candidates (may be <see langword="null"/>).</param>
    /// <param name="cancellationToken">A token to observe cancellation; it is passed to the embedding generator.</param>
    /// <returns>
    /// A task whose result is: the original tool list (or an empty list when there is none) if it is empty or
    /// already within the limit; the first N tools if no query text could be built; otherwise the N most
    /// similar tools. Reduced results are materialized snapshots, so enumerating them repeatedly does not
    /// recompute similarities and does not observe later changes to the original list.
    /// </returns>
    public async Task<IEnumerable<AITool>> SelectToolsForRequestAsync(
        IEnumerable<ChatMessage> messages,
        ChatOptions? options,
        CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(messages);

        if (options?.Tools is not { Count: > 0 } tools)
        {
            return options?.Tools ?? [];
        }

        Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero.");

        if (tools.Count <= _toolLimit)
        {
            // No reduction necessary.
            return tools;
        }

        // Build query text from the messages.
        var queryText = MessagesEmbeddingTextFactory(messages);
        if (string.IsNullOrWhiteSpace(queryText))
        {
            // We couldn't build a meaningful query, likely because the message list was empty.
            // We'll just return a truncated list of tools. Snapshot it so the result is not a
            // lazy view over the caller's mutable tool list.
            return tools.Take(_toolLimit).ToArray();
        }

        // Ensure embeddings for any uncached tools are generated in a batch.
        var toolEmbeddings = await GetToolEmbeddingsAsync(tools, cancellationToken).ConfigureAwait(false);

        // Generate the query embedding.
        var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false);
        var queryVector = queryEmbedding.Vector;

        // Compute rankings.
        var ranked = tools
            .Zip(toolEmbeddings, static (tool, embedding) => (Tool: tool, Embedding: embedding))
            .Select((t, i) => (t.Tool, Index: i, Score: Similarity(queryVector, t.Embedding.Vector)))
            .OrderByDescending(t => t.Score)
            .Take(_toolLimit);

        if (PreserveOriginalOrdering)
        {
            ranked = ranked.OrderBy(t => t.Index);
        }

        // Materialize once: a deferred query would re-run the similarity function and the sort on
        // every enumeration, and would fail if the caller mutates the tool list in the meantime.
        return ranked.Select(t => t.Tool).ToArray();
    }

    /// <summary>
    /// Gets one embedding per tool, in the same order as <paramref name="tools"/>, generating embeddings
    /// for all tools (caching disabled) or only for tools missing from the cache (caching enabled).
    /// </summary>
    private async Task<IReadOnlyList<Embedding<float>>> GetToolEmbeddingsAsync(IList<AITool> tools, CancellationToken cancellationToken)
    {
        if (!EnableEmbeddingCaching)
        {
            // Embed all tools in one batch; do not store in cache.
            return await ComputeEmbeddingsAsync(tools.Select(t => ToolEmbeddingTextFactory(t)), expectedCount: tools.Count).ConfigureAwait(false);
        }

        var result = new Embedding<float>[tools.Count];
        var cacheMisses = new List<(AITool Tool, int Index)>(tools.Count);

        for (var i = 0; i < tools.Count; i++)
        {
            if (_toolEmbeddingsCache.TryGetValue(tools[i], out var embedding))
            {
                result[i] = embedding;
            }
            else
            {
                cacheMisses.Add((tools[i], i));
            }
        }

        if (cacheMisses.Count == 0)
        {
            return result;
        }

        var uncachedEmbeddings = await ComputeEmbeddingsAsync(cacheMisses.Select(t => ToolEmbeddingTextFactory(t.Tool)), expectedCount: cacheMisses.Count).ConfigureAwait(false);

        for (var i = 0; i < cacheMisses.Count; i++)
        {
            var (tool, index) = cacheMisses[i];
            var embedding = uncachedEmbeddings[i];

            // Add-if-absent rather than Add: Add throws when the key already exists, which happens
            // when the same tool instance appears more than once in the list or when a concurrent
            // request cached it while we were awaiting. GetValue returns the entry that won.
            result[index] = _toolEmbeddingsCache.GetValue(tool, _ => embedding);
        }

        return result;

        // Generates embeddings for the given texts in a single batch and verifies that the generator
        // returned exactly one embedding per input.
        async ValueTask<GeneratedEmbeddings<Embedding<float>>> ComputeEmbeddingsAsync(IEnumerable<string> texts, int expectedCount)
        {
            var embeddings = await _embeddingGenerator.GenerateAsync(texts, cancellationToken: cancellationToken).ConfigureAwait(false);
            if (embeddings.Count != expectedCount)
            {
                Throw.InvalidOperationException($"Expected {expectedCount} embeddings, got {embeddings.Count}.");
            }

            return embeddings;
        }
    }
}
Loading
Loading