Skip to content

Commit 820821f

Browse files
authored
Tokens can be cached beyond the lifetime of the (http) transport. (#834)
1 parent e6db415 commit 820821f

File tree

12 files changed

+538
-373
lines changed

12 files changed

+538
-373
lines changed

src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions
8686
/// </para>
8787
/// </remarks>
8888
public IDictionary<string, string> AdditionalAuthorizationParameters { get; set; } = new Dictionary<string, string>();
89+
90+
/// <summary>
91+
/// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport.
92+
/// If none is provided, tokens will be cached with the transport.
93+
/// </summary>
94+
public ITokenCache? TokenCache { get; set; }
8995
}

src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
namespace ModelContextProtocol.Authentication;
1414

1515
/// <summary>
16-
/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token
17-
/// protection or caching - it acquires a token and server metadata and holds it in memory.
18-
/// This is suitable for demonstration and development purposes.
16+
/// A generic implementation of an OAuth authorization provider.
1917
/// </summary>
2018
internal sealed partial class ClientOAuthProvider
2119
{
@@ -24,6 +22,8 @@ internal sealed partial class ClientOAuthProvider
2422
/// </summary>
2523
private const string BearerScheme = "Bearer";
2624

25+
private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"];
26+
2727
private readonly Uri _serverUrl;
2828
private readonly Uri _redirectUri;
2929
private readonly string[]? _scopes;
@@ -43,7 +43,7 @@ internal sealed partial class ClientOAuthProvider
4343
private string? _clientId;
4444
private string? _clientSecret;
4545

46-
private TokenContainer? _token;
46+
private ITokenCache _tokenCache;
4747
private AuthorizationServerMetadata? _authServerMetadata;
4848

4949
/// <summary>
@@ -57,11 +57,11 @@ internal sealed partial class ClientOAuthProvider
5757
public ClientOAuthProvider(
5858
Uri serverUrl,
5959
ClientOAuthOptions options,
60-
HttpClient? httpClient = null,
60+
HttpClient httpClient,
6161
ILoggerFactory? loggerFactory = null)
6262
{
6363
_serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl));
64-
_httpClient = httpClient ?? new HttpClient();
64+
_httpClient = httpClient;
6565
_logger = (ILogger?)loggerFactory?.CreateLogger<ClientOAuthProvider>() ?? NullLogger.Instance;
6666

6767
if (options is null)
@@ -85,6 +85,7 @@ public ClientOAuthProvider(
8585
_dcrClientUri = options.DynamicClientRegistration?.ClientUri;
8686
_dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken;
8787
_dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate;
88+
_tokenCache = options.TokenCache ?? new InMemoryTokenCache();
8889
}
8990

9091
/// <summary>
@@ -138,20 +139,21 @@ public ClientOAuthProvider(
138139
{
139140
ThrowIfNotBearerScheme(scheme);
140141

142+
var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false);
143+
141144
// Return the token if it's valid
142-
if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
145+
if (tokens is not null && !tokens.IsExpired)
143146
{
144-
return _token.AccessToken;
147+
return tokens.AccessToken;
145148
}
146149

147-
// Try to refresh the token if we have a refresh token
148-
if (_token?.RefreshToken != null && _authServerMetadata != null)
150+
// Try to refresh the access token if it is invalid and we have a refresh token.
151+
if (tokens?.RefreshToken != null && _authServerMetadata != null)
149152
{
150-
var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
151-
if (newToken != null)
153+
var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
154+
if (newTokens is not null)
152155
{
153-
_token = newToken;
154-
return _token.AccessToken;
156+
return newTokens.AccessToken;
155157
}
156158
}
157159

@@ -174,12 +176,7 @@ public async Task HandleUnauthorizedResponseAsync(
174176
HttpResponseMessage response,
175177
CancellationToken cancellationToken = default)
176178
{
177-
// This provider only supports Bearer scheme
178-
if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase))
179-
{
180-
throw new InvalidOperationException("This credential provider only supports the Bearer scheme");
181-
}
182-
179+
ThrowIfNotBearerScheme(scheme);
183180
await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false);
184181
}
185182

@@ -223,26 +220,29 @@ private async Task PerformOAuthAuthorizationAsync(
223220
// Store auth server metadata for future refresh operations
224221
_authServerMetadata = authServerMetadata;
225222

223+
// The existing access token must be invalid to have resulted in a 401 response, but refresh might still work.
224+
if (await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { RefreshToken: {} refreshToken })
225+
{
226+
var refreshedTokens = await RefreshTokenAsync(refreshToken, protectedResourceMetadata.Resource, authServerMetadata, cancellationToken).ConfigureAwait(false);
227+
if (refreshedTokens is not null)
228+
{
229+
// A non-null result indicates the refresh succeeded and the new tokens have been stored.
230+
return;
231+
}
232+
}
233+
226234
// Perform dynamic client registration if needed
227235
if (string.IsNullOrEmpty(_clientId))
228236
{
229237
await PerformDynamicClientRegistrationAsync(authServerMetadata, cancellationToken).ConfigureAwait(false);
230238
}
231239

232240
// Perform the OAuth flow
233-
var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);
234-
235-
if (token is null)
236-
{
237-
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token.");
238-
}
241+
await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);
239242

240-
_token = token;
241243
LogOAuthAuthorizationCompleted();
242244
}
243245

244-
private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"];
245-
246246
private async Task<AuthorizationServerMetadata> GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken)
247247
{
248248
if (authServerUri.OriginalString.Length == 0 ||
@@ -298,7 +298,7 @@ private async Task<AuthorizationServerMetadata> GetAuthServerMetadataAsync(Uri a
298298
throw new McpException($"Failed to find .well-known/openid-configuration or .well-known/oauth-authorization-server metadata for authorization server: '{authServerUri}'");
299299
}
300300

301-
private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken)
301+
private async Task<TokenContainer?> RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken)
302302
{
303303
var requestContent = new FormUrlEncodedContent(new Dictionary<string, string>
304304
{
@@ -314,10 +314,17 @@ private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri re
314314
Content = requestContent
315315
};
316316

317-
return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false);
317+
using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
318+
319+
if (!httpResponse.IsSuccessStatusCode)
320+
{
321+
return null;
322+
}
323+
324+
return await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
318325
}
319326

320-
private async Task<TokenContainer?> InitiateAuthorizationCodeFlowAsync(
327+
private async Task InitiateAuthorizationCodeFlowAsync(
321328
ProtectedResourceMetadata protectedResourceMetadata,
322329
AuthorizationServerMetadata authServerMetadata,
323330
CancellationToken cancellationToken)
@@ -330,10 +337,10 @@ private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri re
330337

331338
if (string.IsNullOrEmpty(authCode))
332339
{
333-
return null;
340+
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty authorization code.");
334341
}
335342

336-
return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false);
343+
await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false);
337344
}
338345

339346
private Uri BuildAuthorizationUrl(
@@ -377,7 +384,7 @@ private Uri BuildAuthorizationUrl(
377384
return uriBuilder.Uri;
378385
}
379386

380-
private async Task<TokenContainer> ExchangeCodeForTokenAsync(
387+
private async Task ExchangeCodeForTokenAsync(
381388
ProtectedResourceMetadata protectedResourceMetadata,
382389
AuthorizationServerMetadata authServerMetadata,
383390
string authorizationCode,
@@ -400,24 +407,39 @@ private async Task<TokenContainer> ExchangeCodeForTokenAsync(
400407
Content = requestContent
401408
};
402409

403-
return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false);
404-
}
405-
406-
private async Task<TokenContainer> FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken)
407-
{
408410
using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
409411
httpResponse.EnsureSuccessStatusCode();
412+
await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
413+
}
410414

411-
using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
412-
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false);
415+
private async Task<TokenContainer> HandleSuccessfulTokenResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
416+
{
417+
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
418+
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false);
413419

414420
if (tokenResponse is null)
415421
{
416-
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response.");
422+
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an empty response.");
423+
}
424+
425+
if (tokenResponse.TokenType is null || !string.Equals(tokenResponse.TokenType, BearerScheme, StringComparison.OrdinalIgnoreCase))
426+
{
427+
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an unsupported token type: '{tokenResponse.TokenType ?? "<null>"}'. Only 'Bearer' tokens are supported.");
417428
}
418429

419-
tokenResponse.ObtainedAt = DateTimeOffset.UtcNow;
420-
return tokenResponse;
430+
TokenContainer tokens = new()
431+
{
432+
AccessToken = tokenResponse.AccessToken,
433+
RefreshToken = tokenResponse.RefreshToken,
434+
ExpiresIn = tokenResponse.ExpiresIn,
435+
TokenType = tokenResponse.TokenType,
436+
Scope = tokenResponse.Scope,
437+
ObtainedAt = DateTimeOffset.UtcNow,
438+
};
439+
440+
await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false);
441+
442+
return tokens;
421443
}
422444

423445
/// <summary>
@@ -581,7 +603,7 @@ private async Task<ProtectedResourceMetadata> ExtractProtectedResourceMetadata(H
581603
string? resourceMetadataUrl = null;
582604
foreach (var header in response.Headers.WwwAuthenticate)
583605
{
584-
if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter))
606+
if (string.Equals(header.Scheme, BearerScheme, StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter))
585607
{
586608
resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata");
587609
if (resourceMetadataUrl != null)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
namespace ModelContextProtocol.Authentication;
2+
3+
/// <summary>
4+
/// Allows the client to cache access tokens beyond the lifetime of the transport.
5+
/// </summary>
6+
public interface ITokenCache
7+
{
8+
/// <summary>
9+
/// Cache the token. After a new access token is acquired, this method is invoked to store it.
10+
/// </summary>
11+
ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken);
12+
13+
/// <summary>
14+
/// Get the cached token. This method is invoked for every request.
15+
/// </summary>
16+
ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken);
17+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
namespace ModelContextProtocol.Authentication;
3+
4+
/// <summary>
5+
/// Caches the token in-memory within this instance.
6+
/// </summary>
7+
internal class InMemoryTokenCache : ITokenCache
8+
{
9+
private TokenContainer? _tokens;
10+
11+
/// <summary>
12+
/// Cache the token.
13+
/// </summary>
14+
public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken)
15+
{
16+
_tokens = tokens;
17+
return default;
18+
}
19+
20+
/// <summary>
21+
/// Get the cached token.
22+
/// </summary>
23+
public ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken)
24+
{
25+
return new ValueTask<TokenContainer?>(_tokens);
26+
}
27+
}
Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,39 @@
1-
using System.Text.Json.Serialization;
2-
31
namespace ModelContextProtocol.Authentication;
42

53
/// <summary>
6-
/// Represents a token response from the OAuth server.
4+
/// Represents a cacheable combination of tokens ready to be used for authentication.
75
/// </summary>
8-
internal sealed class TokenContainer
6+
public sealed class TokenContainer
97
{
8+
/// <summary>
9+
/// Gets or sets the token type (typically "Bearer").
10+
/// </summary>
11+
public required string TokenType { get; set; }
12+
1013
/// <summary>
1114
/// Gets or sets the access token.
1215
/// </summary>
13-
[JsonPropertyName("access_token")]
14-
public string AccessToken { get; set; } = string.Empty;
16+
public required string AccessToken { get; set; }
1517

1618
/// <summary>
1719
/// Gets or sets the refresh token.
1820
/// </summary>
19-
[JsonPropertyName("refresh_token")]
2021
public string? RefreshToken { get; set; }
2122

2223
/// <summary>
2324
/// Gets or sets the number of seconds until the access token expires.
2425
/// </summary>
25-
[JsonPropertyName("expires_in")]
26-
public int ExpiresIn { get; set; }
27-
28-
/// <summary>
29-
/// Gets or sets the extended expiration time in seconds.
30-
/// </summary>
31-
[JsonPropertyName("ext_expires_in")]
32-
public int ExtExpiresIn { get; set; }
33-
34-
/// <summary>
35-
/// Gets or sets the token type (typically "Bearer").
36-
/// </summary>
37-
[JsonPropertyName("token_type")]
38-
public string TokenType { get; set; } = string.Empty;
26+
public int? ExpiresIn { get; set; }
3927

4028
/// <summary>
4129
/// Gets or sets the scope of the access token.
4230
/// </summary>
43-
[JsonPropertyName("scope")]
44-
public string Scope { get; set; } = string.Empty;
31+
public string? Scope { get; set; }
4532

4633
/// <summary>
4734
/// Gets or sets the timestamp when the token was obtained.
4835
/// </summary>
49-
[JsonIgnore]
50-
public DateTimeOffset ObtainedAt { get; set; }
36+
public required DateTimeOffset ObtainedAt { get; set; }
5137

52-
/// <summary>
53-
/// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn.
54-
/// </summary>
55-
[JsonIgnore]
56-
public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn);
38+
internal bool IsExpired => ExpiresIn is not null && DateTimeOffset.UtcNow >= ObtainedAt.AddSeconds(ExpiresIn.Value);
5739
}

0 commit comments

Comments
 (0)