Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
<SharedGUID>b591423d-f92d-4e00-b0eb-615c9853506c</SharedGUID>
</PropertyGroup>
<PropertyGroup Label="Configuration">
<Import_RootNamespace>InterfaceStubGenerator.Shared</Import_RootNamespace>
<Import_RootNamespace>Refit.Generator</Import_RootNamespace>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(MSBuildThisFileDirectory)RefitClientModel.cs" />
<Compile Include="$(MSBuildThisFileDirectory)InterfaceStubGenerator.cs" />
<Compile Include="$(MSBuildThisFileDirectory)ITypeSymbolExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)RefitMetadata.cs" />
</ItemGroup>
</Project>
134 changes: 27 additions & 107 deletions InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ ImmutableArray<InterfaceDeclarationSyntax> candidateInterfaces
return;
}

var refitMetadata = new RefitMetadata(disposableInterfaceSymbol, httpMethodBaseAttributeSymbol);

// Check the candidates and keep the ones we're actually interested in

#pragma warning disable RS1024 // Compare symbols correctly
Expand All @@ -134,7 +136,7 @@ ImmutableArray<InterfaceDeclarationSyntax> candidateInterfaces
{
// Get the symbol being declared by the method
var methodSymbol = model.GetDeclaredSymbol(method);
if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol))
if (refitMetadata.IsRefitMethod(methodSymbol))
{
var isAnnotated =
compilation.Options.NullableContextOptions
Expand Down Expand Up @@ -170,9 +172,9 @@ ImmutableArray<InterfaceDeclarationSyntax> candidateInterfaces
continue;

// The interface has no refit methods, but its base interfaces might
var hasDerivedRefit = ifaceSymbol
.AllInterfaces.SelectMany(i => i.GetMembers().OfType<IMethodSymbol>())
.Any(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol));
var hasDerivedRefit = ifaceSymbol.AllInterfaces
.SelectMany(i => i.GetMembers().OfType<IMethodSymbol>())
.Any(refitMetadata.IsRefitMethod);

if (hasDerivedRefit)
{
Expand Down Expand Up @@ -276,20 +278,15 @@ public static void Initialize()
// each group is keyed by the Interface INamedTypeSymbol and contains the members
// with a refit attribute on them. Types may contain other members, without the attribute, which we'll
// need to check for and error out on

var classSource = ProcessInterface(
context,
reportDiagnostic,
group.Key,
group.Value,
preserveAttributeSymbol,
disposableInterfaceSymbol,
httpMethodBaseAttributeSymbol,
supportsNullable,
interfaceToNullableEnabledMap[group.Key]
);

var keyName = group.Key.Name;
var model = new RefitClientModel(group.Key, group.Value, refitMetadata);
var classSource = ProcessInterface(context,
reportDiagnostic,
model,
preserveAttributeSymbol,
supportsNullable,
interfaceToNullableEnabledMap[model.RefitInterface]);

var keyName = model.FileName;
int value;
while (keyCount.TryGetValue(keyName, out value))
{
Expand All @@ -304,39 +301,14 @@ public static void Initialize()
static string ProcessInterface<TContext>(
TContext context,
Action<TContext, Diagnostic> reportDiagnostic,
INamedTypeSymbol interfaceSymbol,
List<IMethodSymbol> refitMethods,
RefitClientModel interfaceModel,
ISymbol preserveAttributeSymbol,
ISymbol disposableInterfaceSymbol,
INamedTypeSymbol httpMethodBaseAttributeSymbol,
bool supportsNullable,
bool nullableEnabled
)
{
// Get the class name with the type parameters, then remove the namespace
var className = interfaceSymbol.ToDisplayString();
var lastDot = className.LastIndexOf('.');
if (lastDot > 0)
{
className = className.Substring(lastDot + 1);
}
var classDeclaration = $"{interfaceSymbol.ContainingType?.Name}{className}";

// Get the class name itself
var classSuffix = $"{interfaceSymbol.ContainingType?.Name}{interfaceSymbol.Name}";
var ns = interfaceSymbol.ContainingNamespace?.ToDisplayString();

// if it's the global namespace, our lookup rules say it should be the same as the class name
if (
interfaceSymbol.ContainingNamespace != null
&& interfaceSymbol.ContainingNamespace.IsGlobalNamespace
)
{
ns = string.Empty;
}

// Remove dots
ns = ns!.Replace(".", "");
INamedTypeSymbol interfaceSymbol = interfaceModel.RefitInterface;
List<IMethodSymbol> refitMethods = interfaceModel.RefitMethods;

// See what the nullable context is

Expand Down Expand Up @@ -371,58 +343,22 @@ partial class Generated
[{preserveAttributeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}]
[global::System.Reflection.Obfuscation(Exclude=true)]
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
partial class {ns}{classDeclaration}
: {interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}{GenerateConstraints(interfaceSymbol.TypeParameters, false)}
partial class {interfaceModel.NamespacePrefix}{interfaceModel.ClassDeclaration}
: {interfaceModel.BaseInterfaceDeclaration}{GenerateConstraints(interfaceSymbol.TypeParameters, false)}

{{
/// <inheritdoc />
public global::System.Net.Http.HttpClient Client {{ get; }}
readonly global::Refit.IRequestBuilder requestBuilder;

/// <inheritdoc />
public {ns}{classSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder)
public {interfaceModel.NamespacePrefix}{interfaceModel.ClassSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder)
{{
Client = client;
this.requestBuilder = requestBuilder;
}}

"
);
// Get any other methods on the refit interfaces. We'll need to generate something for them and warn
var nonRefitMethods = interfaceSymbol
.GetMembers()
.OfType<IMethodSymbol>()
.Except(refitMethods, SymbolEqualityComparer.Default)
.Cast<IMethodSymbol>()
.ToList();

// get methods for all inherited
var derivedMethods = interfaceSymbol
.AllInterfaces.SelectMany(i => i.GetMembers().OfType<IMethodSymbol>())
.ToList();

// Look for disposable
var disposeMethod = derivedMethods.Find(
m =>
m.ContainingType?.Equals(
disposableInterfaceSymbol,
SymbolEqualityComparer.Default
) == true
);
if (disposeMethod != null)
{
//remove it from the derived methods list so we don't process it with the rest
derivedMethods.Remove(disposeMethod);
}

// Pull out the refit methods from the derived types
var derivedRefitMethods = derivedMethods
.Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol))
.ToList();
var derivedNonRefitMethods = derivedMethods
.Except(derivedMethods, SymbolEqualityComparer.Default)
.Cast<IMethodSymbol>()
.ToList();
");

var memberNames = new HashSet<string>(interfaceSymbol.GetMembers().Select(x => x.Name));

Expand All @@ -432,30 +368,22 @@ partial class {ns}{classDeclaration}
ProcessRefitMethod(source, method, true, memberNames);
}

foreach (var method in refitMethods.Concat(derivedRefitMethods))
foreach (var method in interfaceModel.AllRefitMethods)
{
ProcessRefitMethod(source, method, false, memberNames);
}

// Handle non-refit Methods that aren't static or properties or have a method body
foreach (var method in nonRefitMethods.Concat(derivedNonRefitMethods))
foreach (var method in interfaceModel.NonRefitMethods)
{
if (
method.IsStatic
|| method.MethodKind == MethodKind.PropertyGet
|| method.MethodKind == MethodKind.PropertySet
|| !method.IsAbstract
) // If an interface method has a body, it won't be abstract
continue;

ProcessNonRefitMethod(context, reportDiagnostic, source, method);
}

// Handle Dispose
if (disposeMethod != null)
if (interfaceModel.DisposeMethod != null)
{
ProcessDisposableMethod(source, disposeMethod);
}
ProcessDisposableMethod(source, interfaceModel.DisposeMethod);
}

source.Append(
@"
Expand Down Expand Up @@ -779,14 +707,6 @@ static string UniqueName(string name, HashSet<string> methodNames)
return candidateName;
}

static bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttibute)
{
return methodSymbol
?.GetAttributes()
.Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttibute) == true)
== true;
}

#if ROSLYN_4

/// <summary>
Expand Down
113 changes: 113 additions & 0 deletions InterfaceStubGenerator.Shared/RefitClientModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using System.Collections.Generic;
using System.Linq;

using Microsoft.CodeAnalysis;

namespace Refit.Generator;

internal class RefitClientModel
{
readonly RefitMetadata refitMetadata;

public RefitClientModel(INamedTypeSymbol refitInterface, List<IMethodSymbol> refitMethods, RefitMetadata refitMetadata)
{
RefitInterface = refitInterface;
RefitMethods = refitMethods;
this.refitMetadata = refitMetadata;

// Get any other methods on the refit interfaces. We'll need to generate something for them and warn
var nonRefitMethods = refitInterface
.GetMembers()
.OfType<IMethodSymbol>()
.Except(refitMethods, SymbolEqualityComparer.Default)
.Cast<IMethodSymbol>()
.ToList();

// get methods for all inherited
var derivedMethods = refitInterface
.AllInterfaces.SelectMany(i => i.GetMembers().OfType<IMethodSymbol>())
.ToList();

// Look for disposable
DisposeMethod = derivedMethods.Find(
m =>
m.ContainingType?.Equals(
refitMetadata.DisposableInterfaceSymbol,
SymbolEqualityComparer.Default
) == true
);
if (DisposeMethod != null)
{
//remove it from the derived methods list so we don't process it with the rest
derivedMethods.Remove(DisposeMethod);
}

// Pull out the refit methods from the derived types
var derivedRefitMethods = derivedMethods.Where(refitMetadata.IsRefitMethod).ToList();
var derivedNonRefitMethods = derivedMethods.Except(derivedMethods, SymbolEqualityComparer.Default).Cast<IMethodSymbol>().ToList();

AllRefitMethods = refitMethods.Concat(derivedRefitMethods);
NonRefitMethods = nonRefitMethods.Concat(derivedNonRefitMethods)
.Where(static method =>
{
return !(method.IsStatic ||
method.MethodKind == MethodKind.PropertyGet ||
method.MethodKind == MethodKind.PropertySet ||
!method.IsAbstract);
});
}

public INamedTypeSymbol RefitInterface { get; }
public List<IMethodSymbol> RefitMethods { get; }
public IEnumerable<IMethodSymbol> AllRefitMethods { get; }
public IEnumerable<IMethodSymbol> NonRefitMethods { get; }

public string FileName => RefitInterface.Name;

public string ClassDeclaration
{
get
{
// Get the class name with the type parameters, then remove the namespace
var className = RefitInterface.ToDisplayString();
var lastDot = className.LastIndexOf('.');
if (lastDot > 0)
{
className = className.Substring(lastDot + 1);
}
var classDeclaration = $"{RefitInterface.ContainingType?.Name}{className}";
return classDeclaration;
}
}

public string ClassSuffix
{
get
{
// Get the class name itself
var classSuffix = $"{RefitInterface.ContainingType?.Name}{RefitInterface.Name}";
return classSuffix;
}
}

public string NamespacePrefix
{
get
{
var ns = RefitInterface.ContainingNamespace?.ToDisplayString();

// if it's the global namespace, our lookup rules say it should be the same as the class name
if (RefitInterface.ContainingNamespace != null && RefitInterface.ContainingNamespace.IsGlobalNamespace)
{
return string.Empty;
}

// Remove dots
ns = ns!.Replace(".", "");
return ns;
}
}
public string BaseInterfaceDeclaration => $"{RefitInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}";

public IMethodSymbol DisposeMethod { get; }
}
22 changes: 22 additions & 0 deletions InterfaceStubGenerator.Shared/RefitMetadata.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Linq;

using Microsoft.CodeAnalysis;

namespace Refit.Generator;

internal class RefitMetadata
{
public RefitMetadata(INamedTypeSymbol? disposableInterfaceSymbol, INamedTypeSymbol httpMethodBaseAttributeSymbol)
{
DisposableInterfaceSymbol = disposableInterfaceSymbol;
HttpMethodBaseAttributeSymbol = httpMethodBaseAttributeSymbol;
}

public INamedTypeSymbol? DisposableInterfaceSymbol { get; }
public INamedTypeSymbol HttpMethodBaseAttributeSymbol { get; }

public bool IsRefitMethod(IMethodSymbol? methodSymbol)
{
return methodSymbol?.GetAttributes().Any(ad => ad.AttributeClass?.InheritsFromOrEquals(HttpMethodBaseAttributeSymbol) == true) == true;
}
}