Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public MockServerManagedIdentityProvider(string testName)

public bool EnableMIChecking { get; set; }

public Guid GetServerApplicationId(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
public Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
{
return Guid.Empty;
return Task.FromResult(new ServerApplicationIdentity(Guid.Empty, Guid.Empty));
}

public LocalServerType GetServerType(IEcsManagement ecsManagement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Microsoft.Azure.Commands.StorageSync.Common;
using Microsoft.Azure.Commands.StorageSync.Common.Extensions;
using Microsoft.Azure.Commands.StorageSync.InternalObjects;
using Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity;
using Microsoft.Azure.Management.StorageSync.Models;
using Newtonsoft.Json;
using System;
Expand Down Expand Up @@ -372,17 +373,18 @@ private bool TryCreateDirectory(string monitoringDataPath, out DirectoryInfo dir
return false;
}

public override Guid? GetApplicationIdOrNull()
public override ServerApplicationIdentity GetServerApplicationIdentityOrNull()
{
if(TestName == "TestNewRegisteredServerWithIdentityError")
var testTenantGuid = new Guid("0483643a-cb2f-462a-bc27-1a270e5bdc0a");
if (TestName == "TestNewRegisteredServerWithIdentityError")
{
return null;
}
if(TestName == "TestPatchRegisteredServer")
{
return new Guid("3b990c8b-9607-4c2a-8b04-1d41985facca");
return new ServerApplicationIdentity(new Guid("3b990c8b-9607-4c2a-8b04-1d41985facca"), testTenantGuid);
}
return Guid.NewGuid();
return new ServerApplicationIdentity(Guid.NewGuid(), testTenantGuid);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ public MockSyncServerRegistrationClientBase(IEcsManagement ecsManagementInteropC
/// <summary>
/// This function will return the application id of the server if it is available.
/// </summary>
/// <returns>Application Id or null</returns>
public abstract Guid? GetApplicationIdOrNull();
/// <returns>ServerApplicationIdentity or null</returns>
public abstract ServerApplicationIdentity GetServerApplicationIdentityOrNull();

/// <summary>
/// Validate sync server registration.
Expand Down Expand Up @@ -146,6 +146,7 @@ public void Dispose()
/// 4. Get ClusterInfo
/// 5. Populate RegistrationServerResource
/// </summary>
/// <param name="storageSyncServiceTenantId">Storage Sync Service Tenant Id</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -162,6 +163,7 @@ public void Dispose()
/// </exception>
/// <exception cref="ServerRegistrationException"></exception>
public RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand All @@ -176,7 +178,19 @@ public RegisteredServer Register(
bool assignIdentity)
{
// Get ApplicationId
Guid? applicationId = assignIdentity ? GetApplicationIdOrNull() : null;
ServerApplicationIdentity serverApplicationIdentity = assignIdentity ? GetServerApplicationIdentityOrNull() : null;
// Discover the server type , Get the application id,
Guid? applicationId = serverApplicationIdentity?.ApplicationId;

if (serverApplicationIdentity != null && serverApplicationIdentity.TenantId != Guid.Empty)
{
// Check that tenants match
if (!string.Equals(storageSyncServiceTenantId, serverApplicationIdentity.TenantId.ToString(), StringComparison.OrdinalIgnoreCase))
{
throw new ServerRegistrationException(
$"Cross-tenant registration is not allowed. The server belongs to tenant '{serverApplicationIdentity.TenantId}' but the Storage Sync Service is in tenant '{storageSyncServiceTenantId}'.");
}
}

#pragma warning disable CA1416 // Validate platform compatibility
//RegistryUtility.WriteValue(StorageSyncConstants.ServerAuthRegistryKeyName,
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/StorageSync/StorageSync/ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- Additional information about change #1
-->
## Upcoming Release

* Fixed security bug in checking tenant id for MI server registration
## Version 2.5.1
* Fixed security bug in token acquisition for MI server registration

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using System.Management;
using System.Management.Automation;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Commands.StorageSync.Interop.Clients
{
Expand Down Expand Up @@ -406,13 +407,13 @@ private bool TryCreateDirectory(string monitoringDataPath, out DirectoryInfo dir
/// This function will get the application id of the server if identity is available.
/// </summary>
/// <returns>Application id or null.</returns>
public override Guid? GetApplicationIdOrNull()
public async override Task<ServerApplicationIdentity> GetServerApplicationIdentityOrNull()
{
LocalServerType localServerType = this.ServerManagedIdentityProvider.GetServerType(this.EcsManagementInteropClient);

if(localServerType != LocalServerType.HybridServer)
if (localServerType != LocalServerType.HybridServer)
{
return this.ServerManagedIdentityProvider.GetServerApplicationId(localServerType, throwIfNotFound: true, validateSystemAssignedManagedIdentity: true);
return await this.ServerManagedIdentityProvider.GetServerApplicationIdentityAsync(localServerType, throwIfNotFound: true, validateSystemAssignedManagedIdentity: true);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public abstract ServerRegistrationData Setup(
/// This function will return the application id of the server if it is available.
/// </summary>
/// <returns>Application Id or null</returns>
public abstract Guid? GetApplicationIdOrNull();
public abstract Task<ServerApplicationIdentity> GetServerApplicationIdentityOrNull();

/// <summary>
/// Dispose method for cleaning Interop client object.
Expand All @@ -146,6 +146,7 @@ public void Dispose()
/// 3. Calls RegisterOnline callback to make ARM call (from caller context)
/// 4. Persists registered server resource from cloud to local FileSyncSvc service
/// </summary>
/// <param name="storageSyncServiceTenantId">Storage Sync Service TenantId</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -160,6 +161,7 @@ public void Dispose()
/// <param name="assignIdentity">Assign Identity</param>
/// <returns>Registered Server Resource</returns>
public RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand All @@ -174,7 +176,17 @@ public RegisteredServer Register(
bool assignIdentity)
{
// Discover the server type , Get the application id,
Guid? applicationId = assignIdentity ? GetApplicationIdOrNull() : null;
ServerApplicationIdentity serverApplicationIdentity = assignIdentity ? GetServerApplicationIdentityOrNull().GetAwaiter().GetResult() : null;
Guid? applicationId = serverApplicationIdentity?.ApplicationId;

if (serverApplicationIdentity != null && serverApplicationIdentity.TenantId != Guid.Empty)
{
// Check that tenants match
if (!string.Equals(storageSyncServiceTenantId, serverApplicationIdentity.TenantId.ToString(), StringComparison.OrdinalIgnoreCase))
{
throw new ServerRegistrationException(ServerRegistrationErrorCode.ServerAndSyncServiceTenantMismatched);
}
}

// Set the registry key for ServerAuthType
RegistryUtility.WriteValue(StorageSyncConstants.ServerAuthRegistryKeyName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ public enum ServerRegistrationErrorCode
/// <summary>
/// Monitoring Service Endpoint Invalid or Not Set
/// </summary>
MonitoringServiceEndpointInvalidOrNotSet
MonitoringServiceEndpointInvalidOrNotSet,

/// <summary>
/// Server and Sync Service Tenant Mismatched.
/// </summary>
ServerAndSyncServiceTenantMismatched

}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Commands.StorageSync.Interop.Interfaces;
using Microsoft.Azure.Commands.StorageSync.Interop.Enums;
using System;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
{
Expand All @@ -20,14 +21,11 @@ public interface IServerManagedIdentityProvider
LocalServerType GetServerType(IEcsManagement ecsManagement);

/// <summary>
/// Gets the server's application id by trying to get a token and parsing for the oid
/// We choose to get the applicationId from the token rather than making a Get call on the resource
/// because we don't know the permissions the user has on the resource
/// Gets the server's application identity (application ID and tenant ID) asynchronously by trying to get a token from the Arc/Azure IMDS endpoint and parsing for the oid and tenant ID.
/// </summary>
/// <param name="serverType">ServerType: Hybrid or Azure</param>
/// <param name="throwIfNotFound">Whether to throw an exception if an Application ID is not available</param>
/// <param name="validateSystemAssignedManagedIdentity">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <returns>Server's applicationId if it's available, Guid.Empty otherwise</returns>
Guid GetServerApplicationId(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true);
Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType serverType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public interface ISyncServerRegistration : IDisposable
/// 2. Sets up ServerRegistrationData
/// 3. Calls RegisterOnline callback to make ARM call (from caller context)
/// 4. Persists registered server resource from cloud to local FileSyncSvc service
/// <param name="storageSyncServiceTenantId">Storage Sync Service TenantId</param>
/// <param name="managementEndpointUri">Management endpoint Uri</param>
/// <param name="subscriptionId">Subscription Id</param>
/// <param name="storageSyncServiceName">Storage Sync Service Name</param>
Expand All @@ -47,6 +48,7 @@ public interface ISyncServerRegistration : IDisposable
/// <returns>Registered Server Resource</returns>
/// </summary>
RegisteredServer Register(
string storageSyncServiceTenantId,
Uri managementEndpointUri,
Guid subscriptionId,
string storageSyncServiceName,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;

namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
{
/// <summary>
/// ServerApplicationIdentity represents the server's application identity with application ID and tenant ID.
/// </summary>
public class ServerApplicationIdentity
{
public Guid ApplicationId { get; set; }
public Guid TenantId { get; set; }

public ServerApplicationIdentity(Guid applicationId, Guid tenantId)
{
ApplicationId = applicationId;
TenantId = tenantId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace Microsoft.Azure.Commands.StorageSync.Interop.ManagedIdentity
/// </summary>
public class ServerManagedIdentityProvider : IServerManagedIdentityProvider
{
private const string ResourceStorageUri = "https://storage.azure.com/";

public bool EnableMIChecking { get; set; }

public Action<string, EventLevel> TraceLog { get; private set; }
Expand Down Expand Up @@ -47,19 +49,20 @@ public LocalServerType GetServerType(IEcsManagement ecsManagement)
/// </summary>
/// <param name="localServerType">ServerType: Hybrid or Azure</param>
/// <param name="throwIfNotFound">Whether to throw an exception if an Application ID is not available</param>
/// <param name="validateSAMI">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <param name="validateSystemAssignedManagedIdentity">Whether to validate that the Application Id belongs to a System-Assigned Managed Identity</param>
/// <returns>Server's applicationId if it's available, Guid.Empty otherwise</returns>
public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIfNotFound = true, bool validateSAMI = true)
public async Task<ServerApplicationIdentity> GetServerApplicationIdentityAsync(LocalServerType localServerType, bool throwIfNotFound = true, bool validateSystemAssignedManagedIdentity = true)
{
var applicationId = Guid.Empty;
Guid applicationId = Guid.Empty;
Guid tenantId = Guid.Empty;

if (EnableMIChecking)
{
try
{
if (localServerType == LocalServerType.HybridServer)
{
return applicationId;
return new ServerApplicationIdentity(applicationId, tenantId);
}

// We need to use the https://storage.azure.com resource, as this provides us the x-ms-rid header to use for validation.
Expand All @@ -68,14 +71,12 @@ public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIf
// When we cache token in ServerManagedIdentityTokenProvider, it will use ProtectedMemory to encrypt/decrypt the token,
// and this GetServerApplicationId can be triggered from server registration using PowerShell Core which causes that issue.
// So this is another reason we need to get the token from IMDS endpoint directly via ServerManagedIdentityUtils, not ServerManagedIdentityTokenProvider.
ServerManagedIdentityTokenResponse tokenResponse;

tokenResponse = ServerManagedIdentityUtils.GetManagedIdentityTokenResponseAsync(resource: "https://storage.azure.com/").GetAwaiter().GetResult();

ServerManagedIdentityTokenResponse tokenResponse = await ServerManagedIdentityUtils.GetManagedIdentityTokenResponseAsync(resource: ResourceStorageUri);

var token = tokenResponse.AccessToken;

applicationId = ServerManagedIdentityTokenHelper.GetTokenOid(token);
tenantId = ServerManagedIdentityTokenHelper.GetTokenTenantId(token);
}
catch (Exception)
{
Expand All @@ -90,7 +91,7 @@ public Guid GetServerApplicationId(LocalServerType localServerType, bool throwIf
TraceLog($"{nameof(EnableMIChecking)} is off.", EventLevel.Informational);
}

return applicationId;
return new ServerApplicationIdentity(applicationId, tenantId);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ public static class ServerManagedIdentityTokenHelper
private const string UserAssignedManagedIdentityResourceType = "Microsoft.ManagedIdentity/userAssignedIdentities";

/// <summary>
/// Gets the oid claim from the token payload
/// Gets the Oid claim from the token payload
/// </summary>
/// <param name="token"> the access token </param>
/// <returns> true, oid if successfully parsed, else return false, guid.empty </returns>
/// <param name="token">The access token </param>
/// <returns> The Oid as a Guid if successfully parsed, otherwise throws an exception </returns>
public static Guid GetTokenOid(string token)
{
// try to deserialize the json string to aadtoken object
Expand All @@ -25,6 +25,22 @@ public static Guid GetTokenOid(string token)
// parse the oid string to guid object
return Guid.Parse(aadToken?.Oid);
}
/// <summary>
/// Gets the tenantId claim from the token payload
/// </summary>
/// <param name="token"> The access token </param>
/// <returns> The tenantId as a Guid if successfully parsed, otherwise throws an exception </returns>
public static Guid GetTokenTenantId(string token)
{
// try to deserialize the json string to aadtoken object
var aadToken = TryGetAadTokenFromAccessTokenString(token);

if(!Guid.TryParse(aadToken.TenantId, out Guid tenantId))
{
throw new ArgumentException("Token TenantId is invalid");
}
return tenantId;
}

/// <summary>
/// Try to get the Managed Identity type based on the given token response.
Expand Down Expand Up @@ -106,12 +122,17 @@ public class AadToken

[JsonProperty(PropertyName = ManagedIdentityClaimNames.ManagedIdentityResourceId)]
public string MIResourceId { get; set; }

[JsonProperty(PropertyName = ManagedIdentityClaimNames.TenantId)]
public string TenantId { get; set; }
}

public static class ManagedIdentityClaimNames
{
public const string Oid = "oid";

public const string ManagedIdentityResourceId = "xms_mirid";

public const string TenantId = "tid";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ private static bool IsSecretFilePathValid(string secretFilePath)

// Expected form: %ProgramData%\AzureConnectedMachineAgent\Tokens\<guid>.key
var programData = Environment.GetEnvironmentVariable("ProgramData");

if (string.IsNullOrEmpty(programData))
{
// If ProgramData is not found, try to manually construct it using SystemDrive
Expand Down Expand Up @@ -410,4 +410,4 @@ private static LocalServerType GetLocalServerTypeFromRegistry()
return LocalServerType.HybridServer;
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,7 @@
<data name="AgentMI_ProgramDataNotFoundError" xml:space="preserve">
<value>GetEnvironmentVariable failed to find 'ProgramData'</value>
</data>
<data name="MissingAzureContextTenantId" xml:space="preserve">
<value>The given azure context does not have tenant id.</value>
</data>
</root>
Loading