From 8423f2155fe385b1c9b13d26fafb9e42389f50c5 Mon Sep 17 00:00:00 2001 From: Ahmed Afifi Date: Thu, 3 Jul 2025 04:01:01 +0300 Subject: [PATCH] Fix DispatchProxy Generator --- .../Reflection/DispatchProxyGenerator.cs | 47 ++++++-- .../tests/DispatchProxyTests.cs | 113 ++++++++++++++++++ ...stem.Reflection.DispatchProxy.Tests.csproj | 4 +- 3 files changed, 153 insertions(+), 11 deletions(-) diff --git a/src/libraries/System.Reflection.DispatchProxy/src/System/Reflection/DispatchProxyGenerator.cs b/src/libraries/System.Reflection.DispatchProxy/src/System/Reflection/DispatchProxyGenerator.cs index ac765e8be34ffd..57259a3db2b36d 100644 --- a/src/libraries/System.Reflection.DispatchProxy/src/System/Reflection/DispatchProxyGenerator.cs +++ b/src/libraries/System.Reflection.DispatchProxy/src/System/Reflection/DispatchProxyGenerator.cs @@ -45,9 +45,15 @@ internal static class DispatchProxyGenerator // It is the first field in the class and the first ctor parameter. private const int MethodInfosFieldAndCtorParameterIndex = 0; - // We group AssemblyBuilders by the ALC of the base type's assembly. - // This allows us to granularly unload generated proxy types. - private static readonly ConditionalWeakTable s_alcProxyAssemblyMap = new(); + // We group AssemblyBuilders by two levels of AssemblyLoadContext: + // 1. The ALC of the interface type's assembly (outer key). + // 2. The ALC of the base type's assembly (inner key). + // This ensures that proxy types are generated in the correct scope and that + // they can be properly unloaded when either AssemblyLoadContext is collectible. + private static readonly ConditionalWeakTable> s_alcProxyAssemblyMap = new(); + + private static ProxyAssembly? s_defaultProxyAssembly; + private static readonly MethodInfo s_dispatchProxyInvokeMethod = typeof(DispatchProxy).GetMethod("Invoke", BindingFlags.NonPublic | BindingFlags.Instance)!; private static readonly MethodInfo s_getTypeFromHandleMethod = typeof(Type).GetMethod("GetTypeFromHandle", new Type[] { typeof(RuntimeTypeHandle) })!; private static readonly MethodInfo s_makeGenericMethodMethod = GetGenericMethodMethodInfo(); @@ -68,12 +74,33 @@ internal static object CreateProxyInstance( Debug.Assert(baseType != null); Debug.Assert(interfaceType != null); - AssemblyLoadContext? alc = AssemblyLoadContext.GetLoadContext(baseType.Assembly); - Debug.Assert(alc != null); + AssemblyLoadContext? alcBaseType = AssemblyLoadContext.GetLoadContext(baseType.Assembly); + Debug.Assert(alcBaseType != null); + + AssemblyLoadContext? alcInterfaceType = AssemblyLoadContext.GetLoadContext(interfaceType.Assembly); + Debug.Assert(alcInterfaceType != null); + + ProxyAssembly proxyAssembly; + using (alcInterfaceType.EnterContextualReflection()) + { + if (alcBaseType == AssemblyLoadContext.Default && alcInterfaceType == AssemblyLoadContext.Default) + { + if (s_defaultProxyAssembly == null) + { + Interlocked.CompareExchange(ref s_defaultProxyAssembly, new ProxyAssembly(AssemblyLoadContext.Default), null); + } - ProxyAssembly proxyAssembly = s_alcProxyAssemblyMap.GetOrAdd(alc, static x => new ProxyAssembly(x)); - GeneratedTypeInfo proxiedType = proxyAssembly.GetProxyType(baseType, interfaceType, interfaceParameter, proxyParameter); - return Activator.CreateInstance(proxiedType.GeneratedType, new object[] { proxiedType.MethodInfos })!; + proxyAssembly = s_defaultProxyAssembly; + } + else + { + var secondLevelMap = s_alcProxyAssemblyMap.GetValue(alcInterfaceType, static x => new ConditionalWeakTable()); + proxyAssembly = secondLevelMap.GetValue(alcBaseType, static x => new ProxyAssembly(x)); + } + + GeneratedTypeInfo proxiedType = proxyAssembly.GetProxyType(baseType, interfaceType, interfaceParameter, proxyParameter); + return Activator.CreateInstance(proxiedType.GeneratedType, new object[] { proxiedType.MethodInfos })!; + } } private sealed class GeneratedTypeInfo @@ -119,6 +146,8 @@ private sealed class ProxyAssembly public ProxyAssembly(AssemblyLoadContext alc) { string name; + var currentAlc = AssemblyLoadContext.CurrentContextualReflectionContext ?? AssemblyLoadContext.Default; + if (alc == AssemblyLoadContext.Default) { name = "ProxyBuilder"; @@ -130,7 +159,7 @@ public ProxyAssembly(AssemblyLoadContext alc) } AssemblyBuilderAccess builderAccess = - alc.IsCollectible ? AssemblyBuilderAccess.RunAndCollect : AssemblyBuilderAccess.Run; + alc.IsCollectible || currentAlc.IsCollectible ? AssemblyBuilderAccess.RunAndCollect : AssemblyBuilderAccess.Run; _ab = AssemblyBuilder.DefineDynamicAssembly(new AssemblyName(name), builderAccess); _mb = _ab.DefineDynamicModule("testmod"); } diff --git a/src/libraries/System.Reflection.DispatchProxy/tests/DispatchProxyTests.cs b/src/libraries/System.Reflection.DispatchProxy/tests/DispatchProxyTests.cs index 4cc5a964235102..290356218e7710 100644 --- a/src/libraries/System.Reflection.DispatchProxy/tests/DispatchProxyTests.cs +++ b/src/libraries/System.Reflection.DispatchProxy/tests/DispatchProxyTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Reflection; using System.Reflection.Emit; @@ -856,6 +857,73 @@ public static void Test_Multiple_AssemblyLoadContextsWithBadName() Assert.True((bool)method.Invoke(null, null)); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public static void Verify_Correct_Interface_Using_Cached_ALCs(bool collectable) + { + var asmBytes = EmitITestInterface(); + + // Write the compiled assembly to a temporary file + var tempAssemblyPath = Path.Combine(Path.GetTempPath(), $"SharedInterface_{Guid.NewGuid():N}.dll"); + File.WriteAllBytes(tempAssemblyPath, asmBytes); + + var alc1 = new AssemblyLoadContext("alc1", collectable); + var alc2 = new AssemblyLoadContext("alc2", collectable); + + try + { + var a1 = alc1.LoadFromAssemblyPath(tempAssemblyPath); + var a2 = alc2.LoadFromAssemblyPath(tempAssemblyPath); + + var interface1 = a1.GetType("Shared.ITest") ?? throw new Exception("interface1 not found"); + var interface2 = a2.GetType("Shared.ITest") ?? throw new Exception("interface2 not found"); + + Assert.NotEqual(interface1, interface2); // types from different ALCs must not be equal + + // 1) Create a proxy for interface1. This populates DispatchProxy's internal cache. + var proxy1 = DispatchProxy.Create(interface1, typeof(ForwardingDispatchProxy)); + Assert.NotNull(proxy1); + + // 2) Now create a proxy for interface2 WITHOUT entering contextual reflection. + // According to the bug, this can produce a proxy that implements interface1 + // (from the other ALC) rather than interface2. + var proxy2 = DispatchProxy.Create(interface2, typeof(ForwardingDispatchProxy)); + Assert.NotNull(proxy2); + + // Collect interfaces implemented by proxy2's runtime type. + var implemented = proxy2.GetType().GetInterfaces(); + + // Assert: proxy should not be castable across ALC boundaries + Assert.Throws(() => + { + var _ = Convert.ChangeType(proxy1, interface2); + }); + + // We expect the created proxy to implement interface2. On affected runtimes it will + // implement interface1 instead and this assertion will fail. + Assert.Contains(interface2, implemented); + + // For additional clarity, make the negative assertion that it should not be the other. + Assert.DoesNotContain(interface1, implemented.Where(t => t.Assembly == interface1.Assembly && t.FullName == interface1.FullName)); + } + finally + { + try + { + if (File.Exists(tempAssemblyPath)) + File.Delete(tempAssemblyPath); + + if (collectable) + { + alc1.Unload(); + alc2.Unload(); + } + } + catch { } + } + } + internal static bool Demo() { TestType_IHelloService proxy = DispatchProxy.Create(); @@ -872,5 +940,50 @@ private static TInterface CreateHelper(bool useGenericCreate return (TInterface)DispatchProxy.Create(typeof(TInterface), typeof(TProxy)); } + + // Compile a small assembly in-memory that contains a single public interface Shared.ITest + // with two abstract methods: void DoSomething() and int GetValue() + private static byte[] EmitITestInterface() + { + // Define a new assembly + var assemblyName = new AssemblyName("SharedInterfaceAssembly"); + var pab = new PersistedAssemblyBuilder(assemblyName, typeof(object).Assembly); + + // Define a dynamic module + var moduleBuilder = pab.DefineDynamicModule("MainModule"); + + // Define public interface Shared.ITest + var tb = moduleBuilder.DefineType("Shared.ITest", TypeAttributes.Public | TypeAttributes.Interface | TypeAttributes.Abstract); + + // Add methods + tb.DefineMethod("DoSomething", MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual, typeof(void), Type.EmptyTypes); + tb.DefineMethod("GetValue", MethodAttributes.Public | MethodAttributes.Abstract | MethodAttributes.Virtual, typeof(int), Type.EmptyTypes); + + // Finalize the interface + tb.CreateType(); + + // Save into memory as a portable executable (DLL) + using var peStream = new MemoryStream(); + pab.Save(peStream); + return peStream.ToArray(); + } + + // ForwardingDispatchProxy is a custom proxy base for our tests. + public class ForwardingDispatchProxy : DispatchProxy + { + protected override object? Invoke(MethodInfo? targetMethod, object?[]? args) + { + if (targetMethod == null) + throw new ArgumentNullException(nameof(targetMethod)); + + if (targetMethod.ReturnType == typeof(void)) + return null!; + + if (targetMethod.ReturnType.IsValueType) + return Activator.CreateInstance(targetMethod.ReturnType)!; + + return null; + } + } } } diff --git a/src/libraries/System.Reflection.DispatchProxy/tests/System.Reflection.DispatchProxy.Tests.csproj b/src/libraries/System.Reflection.DispatchProxy/tests/System.Reflection.DispatchProxy.Tests.csproj index 3603f510b4d79f..e88b21866d2c63 100644 --- a/src/libraries/System.Reflection.DispatchProxy/tests/System.Reflection.DispatchProxy.Tests.csproj +++ b/src/libraries/System.Reflection.DispatchProxy/tests/System.Reflection.DispatchProxy.Tests.csproj @@ -1,5 +1,5 @@ - + $(NetCoreAppCurrent) @@ -12,5 +12,5 @@ - + \ No newline at end of file