diff --git a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.SzArray.cs b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.SzArray.cs index 6a9a84bb4..a9ac01849 100644 --- a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.SzArray.cs +++ b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.SzArray.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System; +using System.Collections.Generic; using System.Runtime.InteropServices; using AsmResolver.DotNet; using AsmResolver.DotNet.Code.Cil; @@ -10,6 +12,7 @@ using WindowsRuntime.InteropGenerator.Factories; using WindowsRuntime.InteropGenerator.Generation; using WindowsRuntime.InteropGenerator.Helpers; +using WindowsRuntime.InteropGenerator.Models; using WindowsRuntime.InteropGenerator.References; using WindowsRuntime.InteropGenerator.Resolvers; using static AsmResolver.PE.DotNet.Cil.CilOpCodes; @@ -26,6 +29,12 @@ internal partial class InteropTypeDefinitionBuilder /// public static class SzArray { + /// + /// The thread-local list to build COM interface entries. + /// + [ThreadStatic] + private static List? entriesList; + /// /// Creates a new type definition for the marshaller for some SZ array type. /// @@ -335,6 +344,7 @@ public static void ArrayImpl( /// Creates a new type definition for the implementation of the COM interface entries for some SZ array type. /// /// The for the SZ array type. + /// The vtable types implemented by . /// The instance returned by . /// The 'IID' get method for the 'IReferenceArray`1<T>' interface. /// The instance to use. @@ -342,9 +352,11 @@ public static void ArrayImpl( /// The emit state for this invocation. /// The module that will contain the type being created. /// Whether to use Windows.UI.Xaml projections. + /// The resulting interface entries type. /// The resulting implementation type. public static void InterfaceEntriesImpl( SzArrayTypeSignature arrayType, + TypeSignatureEquatableSet vtableTypes, TypeDefinition implType, MethodDefinition get_IidMethod, InteropDefinitions interopDefinitions, @@ -352,65 +364,55 @@ public static void InterfaceEntriesImpl( InteropGeneratorEmitState emitState, ModuleDefinition module, bool useWindowsUIXamlProjections, + out TypeDefinition interfaceEntriesType, out TypeDefinition interfaceEntriesImplType) { - var listImpl = InteropImplTypeResolver.GetCustomMappedOrManuallyProjectedTypeImpl( - type: interopReferences.IList.ToReferenceTypeSignature(), - interopReferences: interopReferences, - useWindowsUIXamlProjections: useWindowsUIXamlProjections); + // Reuse the same list, to minimize allocations (same as for user-defined types) + List entriesList = SzArray.entriesList ??= []; - var enumerableImpl = InteropImplTypeResolver.GetCustomMappedOrManuallyProjectedTypeImpl( - type: interopReferences.IEnumerable.ToReferenceTypeSignature(), - interopReferences: interopReferences, - useWindowsUIXamlProjections: useWindowsUIXamlProjections); - - var list1Impl = InteropImplTypeResolver.GetGenericInstanceTypeImpl( - type: interopReferences.IList1.MakeGenericReferenceType(arrayType.BaseType), - interopDefinitions: interopDefinitions, - interopReferences: interopReferences, - emitState: emitState); + // It's not guaranteed that the list is empty, so we must always reset it first + entriesList.Clear(); - var enumerable1Impl = InteropImplTypeResolver.GetGenericInstanceTypeImpl( - type: interopReferences.IEnumerable1.MakeGenericReferenceType(arrayType.BaseType), - interopDefinitions: interopDefinitions, - interopReferences: interopReferences, - emitState: emitState); + // Add the entry for the 'IReferenceArray' implementation first + entriesList.Add(InteropInterfaceEntriesResolver.Create(get_IidMethod, implType.GetMethod("get_Vtable"u8))); - var readOnlyList1Impl = InteropImplTypeResolver.GetGenericInstanceTypeImpl( - type: interopReferences.IReadOnlyList1.MakeGenericReferenceType(arrayType.BaseType), + // Add all entries for explicitly implemented interfaces + entriesList.AddRange(InteropInterfaceEntriesResolver.EnumerateMetadataInterfaceEntries( + vtableTypes: vtableTypes, interopDefinitions: interopDefinitions, interopReferences: interopReferences, - emitState: emitState); + emitState: emitState, + module: module, + useWindowsUIXamlProjections: useWindowsUIXamlProjections)); var propertyValueImpl = InteropImplTypeResolver.GetSzArrayTypeImpl(arrayType, interopReferences); + // Add the entry for the 'IPropertyValue' implementation as well + entriesList.Add(InteropInterfaceEntriesResolver.Create(propertyValueImpl.get_IID, propertyValueImpl.get_Vtable)); + + // Add the built-in native interfaces at the end + entriesList.AddRange(InteropInterfaceEntriesResolver.EnumerateNativeInterfaceEntries( + vtableTypes: vtableTypes, + interopReferences: interopReferences)); + + // Get or create the interface entries type for this SZ array type (we reuse them based on number of entries) + interfaceEntriesType = interopDefinitions.SzArrayInterfaceEntries(entriesList.Count); + InteropTypeDefinitionBuilder.InterfaceEntriesImpl( ns: InteropUtf8NameFactory.TypeNamespace(arrayType), name: InteropUtf8NameFactory.TypeName(arrayType, "InterfaceEntriesImpl"), - entriesFieldType: interopDefinitions.IReferenceArrayInterfaceEntries, + entriesFieldType: interfaceEntriesType, interopReferences: interopReferences, module: module, implType: out interfaceEntriesImplType, - implTypes: [ - (get_IidMethod, implType.GetMethod("get_Vtable"u8)), - (listImpl.get_IID, listImpl.get_Vtable), - (enumerableImpl.get_IID, enumerableImpl.get_Vtable), - (list1Impl.get_IID, list1Impl.get_Vtable), - (enumerable1Impl.get_IID, enumerable1Impl.get_Vtable), - (readOnlyList1Impl.get_IID, readOnlyList1Impl.get_Vtable), - (propertyValueImpl.get_IID, propertyValueImpl.get_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IStringable, interopReferences.IStringableImplget_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IWeakReferenceSource, interopReferences.IWeakReferenceSourceImplget_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IMarshal, interopReferences.IMarshalImplget_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IAgileObject, interopReferences.IAgileObjectImplget_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IInspectable, interopReferences.IInspectableImplget_Vtable), - (interopReferences.WellKnownInterfaceIIDsget_IID_IUnknown, interopReferences.IUnknownImplget_Vtable)]); + implTypes: CollectionsMarshal.AsSpan(entriesList)); } /// /// Creates a new type definition for the marshaller attribute for some SZ array type. /// /// The for the SZ array type. + /// The for the interface entries type returned by . /// The instance returned by . /// The instance returned by . /// The 'IID' get method for the 'IReferenceArray`1<T>' interface. @@ -420,6 +422,7 @@ public static void InterfaceEntriesImpl( /// The resulting marshaller type. public static void ComWrappersMarshallerAttribute( SzArrayTypeSignature arrayType, + TypeDefinition arrayInterfaceEntriesType, TypeDefinition arrayInterfaceEntriesImplType, TypeDefinition arrayComWrappersCallbackType, MethodDefinition get_IidMethod, @@ -482,7 +485,7 @@ public static void ComWrappersMarshallerAttribute( CilInstructions = { { Ldarg_1 }, - { CilInstruction.CreateLdcI4(interopDefinitions.IReferenceArrayInterfaceEntries.Fields.Count) }, + { CilInstruction.CreateLdcI4(arrayInterfaceEntriesType.Fields.Count) }, { Stind_I4 }, { Call, arrayInterfaceEntriesImplType.GetMethod("get_Vtables"u8) }, { Ret } diff --git a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.UserDefinedType.cs b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.UserDefinedType.cs index 975326f97..3d993ba4f 100644 --- a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.UserDefinedType.cs +++ b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.UserDefinedType.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using AsmResolver.DotNet; -using AsmResolver.DotNet.Code.Cil; using AsmResolver.DotNet.Signatures; using AsmResolver.PE.DotNet.Cil; using AsmResolver.PE.DotNet.Metadata.Tables; @@ -28,131 +27,62 @@ internal partial class InteropTypeDefinitionBuilder /// public static class UserDefinedType { - /// - /// The number of default, always present COM interface entries. - /// We always append some default slots to all user-defined types. - /// - private const int NumberOfDefaultComInterfaceEntries = 6; - /// /// The thread-local list to build COM interface entries. /// [ThreadStatic] - private static List? entriesList; + private static List? entriesList; /// /// Creates a new type definition for the implementation of the COM interface entries for a user-defined type. /// /// The for the user-defined type. /// The vtable types implemented by . - /// The arguments for this invocation. /// The instance to use. /// The instance to use. /// The emit state for this invocation. /// The module that will contain the type being created. + /// Whether to use Windows.UI.Xaml projections. + /// The resulting interface entries type. /// The resulting implementation type. public static void InterfaceEntriesImpl( TypeSignature userDefinedType, TypeSignatureEquatableSet vtableTypes, - InteropGeneratorArgs args, InteropDefinitions interopDefinitions, InteropReferences interopReferences, InteropGeneratorEmitState emitState, ModuleDefinition module, + bool useWindowsUIXamlProjections, + out TypeDefinition interfaceEntriesType, out TypeDefinition interfaceEntriesImplType) { // Reuse the same list, to minimize allocations (since we always need to build the entries at runtime here) - List entriesList = UserDefinedType.entriesList ??= []; + List entriesList = UserDefinedType.entriesList ??= []; // It's not guaranteed that the list is empty, so we must always reset it first entriesList.Clear(); - // Track whether 'IMarshal' is explicitly implemented (so we'll skip the built-in one) - bool hasUserImplementedIMarshalInterface = false; - - // Append all entries for the type (which we share for all matching user-defined types) - foreach (TypeSignature typeSignature in vtableTypes) - { - // Handle generic types first, and then custom-mapped and manually projected types. - // These require special handling, because their ABI types are in different locations. - if (typeSignature is GenericInstanceTypeSignature genericTypeSignature) - { - (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetGenericInstanceTypeImpl( - type: genericTypeSignature, - interopDefinitions: interopDefinitions, - interopReferences: interopReferences, - emitState: emitState); - - entriesList.Add(new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod)); - } - else if (typeSignature.IsCustomMappedWindowsRuntimeInterfaceType(interopReferences) || typeSignature.IsManuallyProjectedWindowsRuntimeInterfaceType(interopReferences)) - { - (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetCustomMappedOrManuallyProjectedTypeImpl( - type: typeSignature, - interopReferences: interopReferences, - useWindowsUIXamlProjections: args.UseWindowsUIXamlProjections); - - entriesList.Add(new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod)); - } - else - { - // We always need to resolve the user-defined types in all cases below, so just do it once first - TypeDefinition interfaceType = typeSignature.Resolve()!; - - // For '[GeneratedComInterface]', we need to retrieve and use the generated vtable from the COM generators - if (interfaceType.IsGeneratedComInterfaceType) - { - // Ignore interfaces we can't retrieve information for (this should never happen, interfaces are filtered during discovery) - if (!interfaceType.TryGetInterfaceInformationType(interopReferences, out TypeSignature? interfaceInformationType)) - { - continue; - } - - // Get the IID of the interface (same as above, this is pre-validated) - if (!interfaceType.TryGetGuidAttribute(interopReferences, out Guid interfaceId)) - { - continue; - } - - // Track if the current interface is 'IMarshal' - hasUserImplementedIMarshalInterface |= interfaceId == WellKnownInterfaceIIDs.IID_IMarshal; - - // Add the entry from the 'InterfaceInformation' type, which contains the generated info we need - entriesList.Add(new ComInterfaceEntryInfo(interfaceInformationType)); - } - else - { - // This is the common case for all normally projected, non-generic Windows Runtime types - (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetProjectedTypeImpl( - type: interfaceType, - interopReferences: interopReferences); - - entriesList.Add(new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod)); - } - } - } + // Add all entries for explicitly implemented interfaces + entriesList.AddRange(InteropInterfaceEntriesResolver.EnumerateMetadataInterfaceEntries( + vtableTypes: vtableTypes, + interopDefinitions: interopDefinitions, + interopReferences: interopReferences, + emitState: emitState, + module: module, + useWindowsUIXamlProjections: useWindowsUIXamlProjections)); - // Add the default entries after all user implementations - entriesList.AddRange( - new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IStringable, interopReferences.IStringableImplget_Vtable), - new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IWeakReferenceSource, interopReferences.IWeakReferenceSourceImplget_Vtable)); + // Add the built-in native interfaces at the end + entriesList.AddRange(InteropInterfaceEntriesResolver.EnumerateNativeInterfaceEntries( + vtableTypes: vtableTypes, + interopReferences: interopReferences)); - // Add the default 'IMarshal' entry if the user type didn't implement it explicitly - if (!hasUserImplementedIMarshalInterface) - { - entriesList.Add(new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IMarshal, interopReferences.IMarshalImplget_Vtable)); - } - - // Add the default core entries at the end ('IUnknown' in particular must always be the last one) - entriesList.AddRange( - new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IAgileObject, interopReferences.IAgileObjectImplget_Vtable), - new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IInspectable, interopReferences.IInspectableImplget_Vtable), - new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IUnknown, interopReferences.IUnknownImplget_Vtable)); + // Get or create the interface entries type for this user-defined type (we reuse them based on number of entries) + interfaceEntriesType = interopDefinitions.UserDefinedInterfaceEntries(entriesList.Count); InteropTypeDefinitionBuilder.InterfaceEntriesImpl( ns: "WindowsRuntime.Interop.UserDefinedTypes"u8, name: InteropUtf8NameFactory.TypeName(userDefinedType, "InterfaceEntriesImpl"), - entriesFieldType: interopDefinitions.UserDefinedInterfaceEntries(NumberOfDefaultComInterfaceEntries + vtableTypes.Count), + entriesFieldType: interfaceEntriesType, interopReferences: interopReferences, module: module, implType: out interfaceEntriesImplType, @@ -163,15 +93,15 @@ public static void InterfaceEntriesImpl( /// Creates a new type definition for the marshaller attribute of some user-defined type. /// /// The for the user-defined type. - /// The vtable types implemented by . - /// The instance returned by . + /// The for the interface entries type returned by . + /// The for the interface entries implementation type returned by . /// The instance to use. /// The instance to use. /// The module that will contain the type being created. /// The resulting marshaller type. public static void ComWrappersMarshallerAttribute( TypeSignature userDefinedType, - TypeSignatureEquatableSet vtableTypes, + TypeDefinition interfaceEntriesType, TypeDefinition interfaceEntriesImplType, InteropDefinitions interopDefinitions, InteropReferences interopReferences, @@ -195,9 +125,6 @@ public static void ComWrappersMarshallerAttribute( // The 'ComputeVtables' method returns the 'ComWrappers.ComInterfaceEntry*' type PointerTypeSignature computeVtablesReturnType = interopReferences.ComInterfaceEntry.Import(module).MakePointerType(); - // Retrieve the cached COM interface entries type, as we need the number of fields - TypeDefinition interfaceEntriesType = interopDefinitions.UserDefinedInterfaceEntries(NumberOfDefaultComInterfaceEntries + vtableTypes.Count); - // Define the 'ComputeVtables' method as follows: // // public static ComInterfaceEntry* ComputeVtables(out int count) @@ -336,47 +263,5 @@ public static void TypeMapAttributes( interopReferences: interopReferences, module: module); } - - /// - /// An type for Windows Runtime types. - /// - /// The value to get the interface IID. - /// The value to get the interface vtable. - private sealed class WindowsRuntimeInterfaceEntryInfo(IMethodDefOrRef get_IID, IMethodDefOrRef get_Vtable) : InterfaceEntryInfo - { - /// - public override void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) - { - _ = instructions.Add(Call, get_IID.Import(module)); - _ = instructions.Add(Ldobj, interopReferences.Guid.Import(module)); - } - - /// - public override void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) - { - _ = instructions.Add(Call, get_Vtable.Import(module)); - } - } - - /// - /// An type for COM types. - /// - /// The InterfaceInformation type for the current interface. - private sealed class ComInterfaceEntryInfo(TypeSignature interfaceInformationType) : InterfaceEntryInfo - { - /// - public override void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) - { - _ = instructions.Add(Constrained, interfaceInformationType.Import(module).ToTypeDefOrRef()); - _ = instructions.Add(Call, interopReferences.IIUnknownInterfaceTypeget_Iid.Import(module)); - } - - /// - public override void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) - { - _ = instructions.Add(Constrained, interfaceInformationType.Import(module).ToTypeDefOrRef()); - _ = instructions.Add(Call, interopReferences.IIUnknownInterfaceTypeget_ManagedVirtualMethodTable.Import(module)); - } - } } } \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.cs b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.cs index 3870dcb68..3811e8b0f 100644 --- a/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.cs +++ b/src/WinRT.Interop.Generator/Builders/InteropTypeDefinitionBuilder.cs @@ -14,6 +14,7 @@ using WindowsRuntime.InteropGenerator.Generation; using WindowsRuntime.InteropGenerator.Helpers; using WindowsRuntime.InteropGenerator.References; +using WindowsRuntime.InteropGenerator.Resolvers; using static AsmResolver.PE.DotNet.Cil.CilOpCodes; #pragma warning disable IDE0061 @@ -575,7 +576,7 @@ private static void InterfaceEntriesImpl( InteropReferences interopReferences, ModuleDefinition module, out TypeDefinition implType, - params ReadOnlySpan implTypes) + params ReadOnlySpan implTypes) { InterfaceEntriesImpl( ns: ns, @@ -613,6 +614,9 @@ private static void InterfaceEntriesImpl( Action get_Vtable, out TypeDefinition implType) { + // Enforce that we can initialize all interface entries + ArgumentOutOfRangeException.ThrowIfNotEqual(implTypes.Length, entriesFieldType.Fields.Count, nameof(implTypes)); + // We're declaring an 'internal static class' type implType = new TypeDefinition( ns: ns, @@ -956,26 +960,4 @@ public static void TypeMapAttributes( module: module)); } } - - /// - /// A base type to abstract inserting interface entries information into a static constructor. - /// - private abstract class InterfaceEntryInfo - { - /// - /// Loads the IID for the interface onto the evaluation stack. - /// - /// The target . - /// The instance to use. - /// The in use. - public abstract void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module); - - /// - /// Loads the vtable for the interface onto the evaluation stack. - /// - /// The target . - /// The instance to use. - /// The in use. - public abstract void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module); - } } \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs index 08bdc9c5e..b493e4d9b 100644 --- a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs +++ b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.Generics.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System; using AsmResolver.DotNet; using AsmResolver.DotNet.Signatures; using WindowsRuntime.InteropGenerator.Errors; @@ -108,7 +107,7 @@ public static void TryTrackSzArrayType( } // Ignore types that are not fully resolvable (this likely means a .dll is missing) - if (!typeSignature.IsFullyResolvable(out _)) + if (!typeSignature.IsFullyResolvable(out TypeDefinition? typeDefinition)) { // Log a warning the first time we fail to resolve this SZ array in this module if (discoveryState.TrackFailedResolutionType(typeSignature, module)) @@ -119,30 +118,15 @@ public static void TryTrackSzArrayType( return; } - // Ignore array types that are not Windows Runtime types - if (!typeSignature.IsWindowsRuntimeType(interopReferences)) - { - return; - } - - // Track all SZ array types, as we'll need to emit marshalling code for them - discoveryState.TrackSzArrayType(typeSignature); - - // Each SZ array also gets a series of interfaces automatically implemented by the runtime. - // The set is fixed, so we can just hardcode those here to make sure they are also discovered. - // They will all be needed later, because CCWs for array objects will need those vtable slots. - foreach (GenericInstanceTypeSignature interfaceType in (ReadOnlySpan)[ - interopReferences.IList1.MakeGenericReferenceType(typeSignature.BaseType), - interopReferences.IEnumerable1.MakeGenericReferenceType(typeSignature.BaseType), - interopReferences.IReadOnlyList1.MakeGenericReferenceType(typeSignature.BaseType)]) - { - TryTrackWindowsRuntimeGenericInterfaceTypeInstance( - typeSignature: interfaceType, - args: args, - discoveryState: discoveryState, - interopReferences: interopReferences, - module: module); - } + // Track all SZ array types, as we'll need to emit marshalling code for them. + // This is regardless of whether their element type is a Windows Runtime type. + TryTrackExposedSzArrayType( + typeDefinition: typeDefinition, + typeSignature: typeSignature, + args: args, + discoveryState: discoveryState, + interopReferences: interopReferences, + module: module); } /// @@ -191,7 +175,7 @@ private static void TryTrackWindowsRuntimeGenericTypeInstance( module: module); // We also want to crawl base interfaces - foreach (TypeSignature interfaceSignature in typeSignature.EnumerateAllInterfaces()) + foreach (TypeSignature interfaceSignature in typeSignature.EnumerateAllInterfaces(interopReferences)) { // Filter out just constructed generic interfaces, since we only care about those here. // The non-generic ones are only useful when gathering interfaces for user-defined types. diff --git a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs index 3e5c6898d..2ecbb0ab8 100644 --- a/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs +++ b/src/WinRT.Interop.Generator/Discovery/InteropTypeDiscovery.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; using AsmResolver.DotNet; using AsmResolver.DotNet.Signatures; @@ -24,6 +25,11 @@ internal static partial class InteropTypeDiscovery /// private static readonly ConcurrentBag TypeSignatureBuilderPool = []; + /// + /// A pool of instances used to validate duplicate IIDs. + /// + private static readonly ConcurrentBag> IidHashSetPool = []; + /// /// Tries to track a given composable Windows Runtime type. /// @@ -84,7 +90,13 @@ public static void TryTrackExposedUserDefinedType( ModuleDefinition module) { // Ignore types that should explicitly be excluded - if (TypeExclusions.IsExcluded(typeDefinition, interopReferences)) + if (TypeExclusions.IsExcluded(typeSignature, interopReferences)) + { + return; + } + + // Ignore SZ array types, we can't handle them from here and they have dedicated logic below + if (typeSignature is SzArrayTypeSignature) { return; } @@ -150,6 +162,15 @@ public static void TryTrackExposedUserDefinedType( // Since we're reusing the builder for all types, make sure to clear it first interfaces.Clear(); + // Use the same logic to also retrieve the set to use to validate unique IIDs in custom interfaces + if (!IidHashSetPool.TryTake(out HashSet? iids)) + { + iids = []; + } + + // Clear this set as well, since we might've retrieved one from the shared pool + iids.Clear(); + // We want to explicitly track whether the type implements any projected Windows Runtime // interfaces, as we are only interested in such types. We want to also gather all // implemented '[GeneratedComInterface]' interfaces, but if a type only implements @@ -157,17 +178,24 @@ public static void TryTrackExposedUserDefinedType( bool hasAnyProjectedWindowsRuntimeInterfaces = false; // Gather all implemented Windows Runtime interfaces for the current type - foreach (TypeSignature interfaceSignature in typeSignature.EnumerateAllInterfaces()) + foreach (TypeSignature interfaceSignature in typeSignature.EnumerateAllInterfaces(interopReferences)) { // Make sure we can resolve the interface type fully, which we should always be able to do. // This can really only fail for some constructed generics, for invalid type arguments. if (!interfaceSignature.IsFullyResolvable(out TypeDefinition? interfaceDefinition)) { - WellKnownInteropExceptions.InterfaceImplementationTypeNotResolvedWarning(interfaceSignature, typeDefinition).LogOrThrow(args.TreatWarningsAsErrors); + WellKnownInteropExceptions.InterfaceImplementationTypeNotResolvedWarning(interfaceSignature, typeSignature).LogOrThrow(args.TreatWarningsAsErrors); continue; } + // Check if the current interface is a Windows Runtime interface. We compute this here so that later + // we can still include the current interface even if it is an '[exclusiveto]' one, while filtering + // out all other '[exclusiveto]' interfaces that might show up as part of the covariant expansion. + // The reason for this is that we want overridable interfaces to be in the CCW interface entries, + // while we don't want them to appear if they're just a type argument for a generic interface. + bool isInterfaceWindowsRuntime = interfaceSignature.IsWindowsRuntimeType(interopReferences); + // Enumerate both the current interface, as well as all covariant combinations derived from it. // This is because we want entries in the vtable to match what users expect in the .NET world. // For instance, consider this scenario: @@ -187,7 +215,9 @@ public static void TryTrackExposedUserDefinedType( // '[exclusiveto]' interfaces too, which might still show up as part of the covariant // expansion. However, those would then either fail to resolve or just result in // unnecessary binary size increase, since nobody would ever use them from here. - if (covariantInterfaceSignature.IsNotExclusiveToWindowsRuntimeType(interopReferences)) + // We also have an additional check to include overridable interfaces (see notes above). + if (covariantInterfaceSignature.IsNotExclusiveToWindowsRuntimeType(interopReferences) || + (isInterfaceWindowsRuntime && SignatureComparer.IgnoreVersion.Equals(covariantInterfaceSignature, interfaceSignature))) { hasAnyProjectedWindowsRuntimeInterfaces = true; @@ -237,6 +267,12 @@ public static void TryTrackExposedUserDefinedType( continue; } + // Ensure that this is the first interface we see implemented on this type with this IID + if (!iids.Add(iid)) + { + WellKnownInteropExceptions.GeneratedComInterfaceDuplicateIidWarning(interfaceDefinition, typeDefinition, iid).LogOrThrow(args.TreatWarningsAsErrors); + } + // Validate that the current interface isn't trying to implement a reserved interface. // For instance, it's not allowed to try to explicitly implement 'IUnknown' or 'IInspectable'. if (WellKnownInterfaceIIDs.ReservedIIDsMap.TryGetValue(iid, out string? interfaceName)) @@ -257,6 +293,121 @@ public static void TryTrackExposedUserDefinedType( discoveryState.TrackUserDefinedType(typeSignature, interfaces.ToEquatableSet()); } + // Return the builder and set to the pool for reuse + TypeSignatureBuilderPool.Add(interfaces); + IidHashSetPool.Add(iids); + } + + /// + /// Tries to track an exposed SZ array type (which may or may not have a Windows Runtime element type). + /// + /// The for the type to analyze. + /// The for the SZ array type to analyze. + /// The arguments for this invocation. + /// The discovery state for this invocation. + /// The instance to use. + /// The module currently being analyzed. + /// + /// This method expects to either be non-generic, or + /// to have be a fully constructed signature for it. + /// + public static void TryTrackExposedSzArrayType( + TypeDefinition typeDefinition, + SzArrayTypeSignature typeSignature, + InteropGeneratorArgs args, + InteropGeneratorDiscoveryState discoveryState, + InteropReferences interopReferences, + ModuleDefinition module) + { + // Ignore types that should explicitly be excluded + if (TypeExclusions.IsExcluded(typeSignature, interopReferences)) + { + return; + } + + // We'll need to look up attributes and enumerate interfaces across the entire type + // hierarchy for this type, so make sure that we can resolve all types from it first. + if (!typeDefinition.IsTypeHierarchyFullyResolvable(out ITypeDefOrRef? failedResolutionBaseType)) + { + WellKnownInteropExceptions.ArrayTypeElementTypeNotFullyResolvedWarning(failedResolutionBaseType, typeDefinition).LogOrThrow(args.TreatWarningsAsErrors); + + return; + } + + // If the element type is a managed only type, ignore the array type. It is true that the array itself + // would still implement some Windows Runtime interfaces, however we assume that if a user has chosen + // to block marshalling for a given type, it means they also wouldn't want code to handle arrays of it. + if (typeDefinition.IsWindowsRuntimeManagedOnlyType(interopReferences)) + { + return; + } + + // Recursion check (see additional notes above) + if (!discoveryState.TryMarkSzArrayType(typeSignature)) + { + return; + } + + // Get or create a builder (see additional notes above) + if (!TypeSignatureBuilderPool.TryTake(out TypeSignatureEquatableSet.Builder? interfaces)) + { + interfaces = new TypeSignatureEquatableSet.Builder(); + } + + // Make sure to clear the builder first (see additional notes above) + interfaces.Clear(); + + // Gather all implemented Windows Runtime interfaces for the current type. Note that for + // SZ arrays we are guaranteed to have at least some, due to the enumerable interfaces + // that the runtime automatically implements on them. + foreach (TypeSignature interfaceSignature in typeSignature.EnumerateAllInterfaces(interopReferences)) + { + // Validate that we can resolve the interface. In this case we should be pretty confident + // that this won't possibly fail, since we expect to only see well-known interfaces here. + if (!interfaceSignature.IsFullyResolvable(out TypeDefinition? interfaceDefinition)) + { + WellKnownInteropExceptions.InterfaceImplementationTypeNotResolvedWarning(interfaceSignature, typeSignature).LogOrThrow(args.TreatWarningsAsErrors); + + continue; + } + + // Enumerate the current interface and the covariant combinations (see additional notes above) + foreach (TypeSignature covariantInterfaceSignature in WindowsRuntimeTypeAnalyzer.EnumerateCovarianceExpandedInterfaceTypes(interfaceSignature, interopReferences).Concat([interfaceSignature])) + { + // Track all interfaces except '[exclusiveto]' ones (see additional notes above). We don't need to care about + // overridable interfaces here, since those can only apply to classes, and SZ arrays will never have any. + if (covariantInterfaceSignature.IsNotExclusiveToWindowsRuntimeType(interopReferences)) + { + interfaces.Add(covariantInterfaceSignature); + + // Make sure that any discovered interfaces are also tracked (see additional notes above) + if (covariantInterfaceSignature is GenericInstanceTypeSignature constructedSignature) + { + TryTrackWindowsRuntimeGenericInterfaceTypeInstance( + typeSignature: constructedSignature, + args: args, + discoveryState, + interopReferences: interopReferences, + module: module); + } + } + } + } + + // If the array is a valid Windows Runtime type, track is specifically as such. + // This is because in this case we'll require additional, specialized marshalling. + if (typeSignature.IsWindowsRuntimeType(interopReferences)) + { + discoveryState.TrackSzArrayType(typeSignature, interfaces.ToEquatableSet()); + } + else + { + // Track the array as a user-defined type. Note that for SZ arrays that don't have an element type + // that is a Windows Runtime type, they're effectively just like any other normal user-defined type. + // That is, some 'Foo[]' type will behave conceptually the same as some 'List' instantiation. + discoveryState.TrackUserDefinedType(typeSignature, interfaces.ToEquatableSet()); + } + // Return the builder to the pool for reuse TypeSignatureBuilderPool.Add(interfaces); } diff --git a/src/WinRT.Interop.Generator/Errors/WellKnownInteropExceptions.cs b/src/WinRT.Interop.Generator/Errors/WellKnownInteropExceptions.cs index 4ee1d7453..0d151bca6 100644 --- a/src/WinRT.Interop.Generator/Errors/WellKnownInteropExceptions.cs +++ b/src/WinRT.Interop.Generator/Errors/WellKnownInteropExceptions.cs @@ -433,7 +433,7 @@ public static WellKnownInteropException CustomMappedTypeComWrappersMarshallerAtt /// /// Failed to resolve the type of an implemented interface. /// - public static WellKnownInteropWarning InterfaceImplementationTypeNotResolvedWarning(TypeSignature interfaceType, TypeDefinition type) + public static WellKnownInteropWarning InterfaceImplementationTypeNotResolvedWarning(TypeSignature interfaceType, TypeSignature type) { return Warning(49, $"Failed to resolve interface type '{interfaceType}' while processing type '{type}': the interface will not be included in the set of available COM interface entries."); } @@ -698,6 +698,22 @@ public static WellKnownInteropException CustomMappedTypeMethodsTypeResolveError( return Exception(81, $"Failed to resolve the associated 'Methods' type for the custom-mapped type '{type}'."); } + /// + /// Failed to resolve the element type for an array type. + /// + public static WellKnownInteropWarning ArrayTypeElementTypeNotFullyResolvedWarning(ITypeDefOrRef baseType, TypeDefinition type) + { + return Warning(82, $"Failed to resolve the base type '{baseType}' in the type hierarchy for element type '{type}': marshalling code for corresponding SZ arrays will not be generated."); + } + + /// + /// Multiple '[GeneratedComInterface]' types are using the same IID. + /// + public static WellKnownInteropWarning GeneratedComInterfaceDuplicateIidWarning(TypeDefinition interfaceType, TypeDefinition type, Guid iid) + { + return Warning(83, $"Failed to validate the '[GeneratedComInterface]' type '{interfaceType}' on type '{type}', because the type already implements another interface with IID '{iid.ToString().ToUpperInvariant()}': the interface will not be included in the set of available COM interface entries."); + } + /// /// Creates a new exception with the specified id and message. /// diff --git a/src/WinRT.Interop.Generator/Extensions/ITypeDescriptorExtensions.cs b/src/WinRT.Interop.Generator/Extensions/ITypeDescriptorExtensions.cs index 381571969..a2de91dc8 100644 --- a/src/WinRT.Interop.Generator/Extensions/ITypeDescriptorExtensions.cs +++ b/src/WinRT.Interop.Generator/Extensions/ITypeDescriptorExtensions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -39,6 +40,26 @@ public bool IsFullyResolvable([NotNullWhen(true)] out TypeDefinition? definition } } + extension(IEnumerable descriptors) + { + /// + /// Sorts the values of a sequence in ascending order, based on the fully qualified type names of the selected key descriptors. + /// + /// The selection function to retrieve values to use for sorting. + /// An whose elements are sorted. + /// Thrown if is . + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the + /// information that is required to perform the action. The query represented by this method is not executed until the + /// object is enumerated by calling its method. + /// + public IEnumerable OrderByFullyQualifiedTypeName(Func keySelector) + where TKey : class, ITypeDescriptor + { + return descriptors.OrderBy(keySelector, TypeDescriptorComparer.Create()); + } + } + extension(IEnumerable descriptors) where T : class, ITypeDescriptor { @@ -46,7 +67,7 @@ public bool IsFullyResolvable([NotNullWhen(true)] out TypeDefinition? definition /// Sorts the values of a sequence in ascending order, based on their fully qualified type names. /// /// An whose elements are sorted. - /// Thrown if is . + /// Thrown if is . /// /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the /// information that is required to perform the action. The query represented by this method is not executed until the diff --git a/src/WinRT.Interop.Generator/Extensions/TypeSignatureExtensions.cs b/src/WinRT.Interop.Generator/Extensions/TypeSignatureExtensions.cs index 83d20d61b..1f73aeda6 100644 --- a/src/WinRT.Interop.Generator/Extensions/TypeSignatureExtensions.cs +++ b/src/WinRT.Interop.Generator/Extensions/TypeSignatureExtensions.cs @@ -6,6 +6,7 @@ using AsmResolver.DotNet; using AsmResolver.DotNet.Signatures; using AsmResolver.PE.DotNet.Metadata.Tables; +using WindowsRuntime.InteropGenerator.References; namespace WindowsRuntime.InteropGenerator; @@ -68,12 +69,29 @@ public bool IsFullyResolvable([NotNullWhen(true)] out TypeDefinition? definition /// /// Enumerates all interface types implemented by the specified type, including those implemented by base types. /// + /// The instance to use. /// The sequence of interface types implemented by the input type. /// /// This method might return the same interface types multiple times, if implemented by multiple types in the hierarchy. /// - public IEnumerable EnumerateAllInterfaces() + public IEnumerable EnumerateAllInterfaces(InteropReferences interopReferences) { + // Each SZ array also gets a series of interfaces automatically implemented by the runtime. + // The set is fixed, so we can just hardcode those here to make sure they are also discovered. + // The normal logic wouldn't work here, because the base type for arrays is the element type. + if (signature is SzArrayTypeSignature arraySignature) + { + yield return interopReferences.IList.ToReferenceTypeSignature(); + yield return interopReferences.ICollection.ToReferenceTypeSignature(); + yield return interopReferences.IEnumerable.ToReferenceTypeSignature(); + yield return interopReferences.IList1.MakeGenericReferenceType(arraySignature.BaseType); + yield return interopReferences.ICollection1.MakeGenericReferenceType(arraySignature.BaseType); + yield return interopReferences.IEnumerable1.MakeGenericReferenceType(arraySignature.BaseType); + yield return interopReferences.IReadOnlyList1.MakeGenericReferenceType(arraySignature.BaseType); + + yield break; + } + TypeSignature? currentSignature = signature; while (currentSignature is not null) @@ -104,7 +122,7 @@ public IEnumerable EnumerateAllInterfaces() // Also recurse on the base interfaces (no need to instantiate the returned interface type // signatures for base interfaces here: they will be already instantiated when returned). - foreach (TypeSignature baseInterface in interfaceSignature.EnumerateAllInterfaces()) + foreach (TypeSignature baseInterface in interfaceSignature.EnumerateAllInterfaces(interopReferences)) { yield return baseInterface; } @@ -128,9 +146,19 @@ public IEnumerable EnumerateAllInterfaces() /// /// Enumerates all base types of a given type. /// + /// The instance to use. /// The sequence of base types of the input type. - public IEnumerable EnumerateBaseTypes() + public IEnumerable EnumerateBaseTypes(InteropReferences interopReferences) { + // If we see an SZ array, we can directly return 'System.Array' as its only base type. + // We can't rely on the base type from the signature, as it would be the element type. + if (signature is SzArrayTypeSignature) + { + yield return interopReferences.Array.ToReferenceTypeSignature(); + + yield break; + } + TypeSignature? currentSignature = signature; while (currentSignature is not null) diff --git a/src/WinRT.Interop.Generator/Factories/WellKnownTypeDefinitionFactory.cs b/src/WinRT.Interop.Generator/Factories/WellKnownTypeDefinitionFactory.cs index 57f0dd4fe..d94b86a8c 100644 --- a/src/WinRT.Interop.Generator/Factories/WellKnownTypeDefinitionFactory.cs +++ b/src/WinRT.Interop.Generator/Factories/WellKnownTypeDefinitionFactory.cs @@ -7,6 +7,7 @@ using AsmResolver.DotNet.Signatures; using AsmResolver.PE.DotNet.Metadata.Tables; using WindowsRuntime.InteropGenerator.References; +using WindowsRuntime.InteropGenerator.Resolvers; namespace WindowsRuntime.InteropGenerator.Factories; @@ -162,7 +163,7 @@ public static TypeDefinition DelegateVftbl( } /// - /// Creates a new type definition for the vtable of an 'IReference`1<T>' instantiation for some type. + /// Creates a new type definition for the vtable of an 'IReference<T>' instantiation for some type. /// /// The instance to use. /// The module that will contain the type being created. @@ -196,7 +197,7 @@ public static TypeDefinition DelegateReferenceVftbl(InteropReferences interopRef module.CorLibTypeFactory.Void.MakePointerType(), module.CorLibTypeFactory.Void.MakePointerType().MakePointerType()]); - // The vtable layout for 'IReference`1' looks like this: + // The vtable layout for 'IReference' looks like this: // // public delegate* unmanaged[MemberFunction] QueryInterface; // public delegate* unmanaged[MemberFunction] AddRef; @@ -1224,44 +1225,42 @@ public static TypeDefinition ReferenceArrayVftbl(InteropReferences interopRefere } /// - /// Creates a new type definition for COM interface entries for some SZ array type. + /// Creates a new type definition for COM interface entries for a user-defined type. /// + /// The number of COM interface entries to generate in the type. /// The instance to use. /// The module that will contain the type being created. /// The resulting instance. - public static TypeDefinition ReferenceArrayInterfaceEntriesType(InteropReferences interopReferences, ModuleDefinition module) + public static TypeDefinition UserDefinedInterfaceEntriesType(int numberOfEntries, InteropReferences interopReferences, ModuleDefinition module) { TypeDefinition interfaceEntriesType = new( ns: null, - name: ""u8, + name: $"", attributes: TypeAttributes.SequentialLayout | TypeAttributes.Sealed | TypeAttributes.BeforeFieldInit, baseType: interopReferences.ValueType.Import(module)); - // Get the signature for the 'ComInterfaceEntry' type (this is a bit involved, so cache it) + // Get the signature for the 'ComInterfaceEntry' type TypeSignature comInterfaceEntryType = interopReferences.ComInterfaceEntry.Import(module).ToValueTypeSignature(); - // The type layout looks like this: + // Calculate the number of dynamic entries, i.e. the ones from explicitly implemented interfaces + int numberOfDynamicEntries = numberOfEntries - InteropInterfaceEntriesResolver.NumberOfNativeComInterfaceEntries; + + ArgumentOutOfRangeException.ThrowIfLessThan(numberOfDynamicEntries, 1, nameof(numberOfEntries)); + + // Add a field for each interface entry + for (int i = 0; i < numberOfDynamicEntries; i++) + { + interfaceEntriesType.Fields.Add(new FieldDefinition($"InterfaceEntry(Index={i})", FieldAttributes.Public, comInterfaceEntryType)); + } + + // Add the default entries // - // public ComInterfaceEntry IReferenceArray'1; - // public ComInterfaceEntry IBindableVector; - // public ComInterfaceEntry IBindableIterable; - // public ComInterfaceEntry IVector'1; - // public ComInterfaceEntry IIterable'1; - // public ComInterfaceEntry IVectorView'1; - // public ComInterfaceEntry IPropertyValue; // public ComInterfaceEntry IStringable; // public ComInterfaceEntry IWeakReferenceSource; // public ComInterfaceEntry IMarshal; // public ComInterfaceEntry IAgileObject; // public ComInterfaceEntry IInspectable; // public ComInterfaceEntry IUnknown; - interfaceEntriesType.Fields.Add(new FieldDefinition("IReferenceArray'1"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IBindableVector"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IBindableIterable"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IVector'1"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IIterable'1"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IVectorView'1"u8, FieldAttributes.Public, comInterfaceEntryType)); - interfaceEntriesType.Fields.Add(new FieldDefinition("IPropertyValue"u8, FieldAttributes.Public, comInterfaceEntryType)); interfaceEntriesType.Fields.Add(new FieldDefinition("IStringable"u8, FieldAttributes.Public, comInterfaceEntryType)); interfaceEntriesType.Fields.Add(new FieldDefinition("IWeakReferenceSource"u8, FieldAttributes.Public, comInterfaceEntryType)); interfaceEntriesType.Fields.Add(new FieldDefinition("IMarshal"u8, FieldAttributes.Public, comInterfaceEntryType)); @@ -1273,29 +1272,57 @@ public static TypeDefinition ReferenceArrayInterfaceEntriesType(InteropReference } /// - /// Creates a new type definition for COM interface entries for a user-defined type. + /// Creates a new type definition for COM interface entries for an SZ array type. /// /// The number of COM interface entries to generate in the type. /// The instance to use. /// The module that will contain the type being created. /// The resulting instance. - public static TypeDefinition UserDefinedInterfaceEntriesType(int numberOfEntries, InteropReferences interopReferences, ModuleDefinition module) + public static TypeDefinition SzArrayInterfaceEntriesType(int numberOfEntries, InteropReferences interopReferences, ModuleDefinition module) { TypeDefinition interfaceEntriesType = new( ns: null, - name: $"", + name: $"", attributes: TypeAttributes.SequentialLayout | TypeAttributes.Sealed | TypeAttributes.BeforeFieldInit, baseType: interopReferences.ValueType.Import(module)); // Get the signature for the 'ComInterfaceEntry' type TypeSignature comInterfaceEntryType = interopReferences.ComInterfaceEntry.Import(module).ToValueTypeSignature(); - // Add a field for each interface entry - for (int i = 0; i < numberOfEntries; i++) + // Calculate the number of dynamic entries, i.e. the ones from explicitly implemented interfaces. + // This is similar to user-defined types (see above), except we also have two additional entries + // for the 'IReferenceArray' interface, and for the 'IPropertyValue' interface. + int numberOfDynamicEntries = numberOfEntries - InteropInterfaceEntriesResolver.NumberOfNativeComInterfaceEntries - 2; + + ArgumentOutOfRangeException.ThrowIfLessThan(numberOfDynamicEntries, 1, nameof(numberOfEntries)); + + // Add the special entry for the 'IReferenceArray' interface. This always has higher + // priority for CCWs of types that implement some kind of 'IReference*' interface. + interfaceEntriesType.Fields.Add(new FieldDefinition("ArrayReference"u8, FieldAttributes.Public, comInterfaceEntryType)); + + // Add a field for each interface entry (names start at index '1' since 'IReferenceArray' comes before them) + for (int i = 0; i < numberOfDynamicEntries; i++) { - interfaceEntriesType.Fields.Add(new FieldDefinition($"InterfaceEntry(Index={i})", FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition($"InterfaceEntry(Index={i + 1})", FieldAttributes.Public, comInterfaceEntryType)); } + // Add the default entries + // + // public ComInterfaceEntry IPropertyValue; + // public ComInterfaceEntry IStringable; + // public ComInterfaceEntry IWeakReferenceSource; + // public ComInterfaceEntry IMarshal; + // public ComInterfaceEntry IAgileObject; + // public ComInterfaceEntry IInspectable; + // public ComInterfaceEntry IUnknown; + interfaceEntriesType.Fields.Add(new FieldDefinition("IPropertyValue"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IStringable"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IWeakReferenceSource"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IMarshal"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IAgileObject"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IInspectable"u8, FieldAttributes.Public, comInterfaceEntryType)); + interfaceEntriesType.Fields.Add(new FieldDefinition("IUnknown"u8, FieldAttributes.Public, comInterfaceEntryType)); + return interfaceEntriesType; } diff --git a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Discover.cs b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Discover.cs index 53a65fab3..e289bff12 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Discover.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Discover.cs @@ -259,7 +259,7 @@ private static void DiscoverSzArrayTypes( { args.Token.ThrowIfCancellationRequested(); - // Track the SZ array type (if it's not applicable, it will be a no-op) + // Track the SZ array type (both for Windows Runtime types and user-defined types) InteropTypeDiscovery.TryTrackSzArrayType( typeSignature: typeSignature, args: args, diff --git a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs index ff309c5b8..3f3c7b526 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGenerator.Emit.cs @@ -2143,7 +2143,7 @@ private static void DefineSzArrayTypes( InteropReferences interopReferences, ModuleDefinition module) { - foreach (SzArrayTypeSignature typeSignature in discoveryState.SzArrayTypes.OrderByFullyQualifiedTypeName()) + foreach ((SzArrayTypeSignature typeSignature, TypeSignatureEquatableSet vtableTypes) in discoveryState.SzArrayAndVtableTypes.OrderByFullyQualifiedTypeName(static pair => pair.Key)) { args.Token.ThrowIfCancellationRequested(); @@ -2181,6 +2181,7 @@ private static void DefineSzArrayTypes( InteropTypeDefinitionBuilder.SzArray.InterfaceEntriesImpl( arrayType: typeSignature, + vtableTypes: vtableTypes, implType: arrayImplType, get_IidMethod: get_IidMethod, interopDefinitions: interopDefinitions, @@ -2188,10 +2189,12 @@ private static void DefineSzArrayTypes( emitState: emitState, module: module, useWindowsUIXamlProjections: args.UseWindowsUIXamlProjections, + interfaceEntriesType: out TypeDefinition interfaceEntriesType, interfaceEntriesImplType: out TypeDefinition arrayInterfaceEntriesImplType); InteropTypeDefinitionBuilder.SzArray.ComWrappersMarshallerAttribute( arrayType: typeSignature, + arrayInterfaceEntriesType: interfaceEntriesType, arrayInterfaceEntriesImplType: arrayInterfaceEntriesImplType, arrayComWrappersCallbackType: arrayComWrappersCallbackType, get_IidMethod: get_IidMethod, @@ -2417,16 +2420,17 @@ private static void DefineUserDefinedTypes( InteropTypeDefinitionBuilder.UserDefinedType.InterfaceEntriesImpl( userDefinedType: typeSignature, vtableTypes: vtableTypes, - args: args, + useWindowsUIXamlProjections: args.UseWindowsUIXamlProjections, interopDefinitions: interopDefinitions, interopReferences: interopReferences, emitState: emitState, module: module, + interfaceEntriesType: out TypeDefinition interfaceEntriesType, interfaceEntriesImplType: out TypeDefinition interfaceEntriesImplType); InteropTypeDefinitionBuilder.UserDefinedType.ComWrappersMarshallerAttribute( userDefinedType: typeSignature, - vtableTypes: vtableTypes, + interfaceEntriesType: interfaceEntriesType, interfaceEntriesImplType: interfaceEntriesImplType, interopDefinitions: interopDefinitions, interopReferences: interopReferences, @@ -2443,7 +2447,7 @@ private static void DefineUserDefinedTypes( } // Next, we can emit the actual proxy types for each user-defined type exposed as a CCW - foreach ((TypeSignature typeSignature, TypeSignatureEquatableSet vtableTypes) in discoveryState.UserDefinedAndVtableTypes.OrderBy(static pair => pair.Key, TypeDescriptorComparer.Create())) + foreach ((TypeSignature typeSignature, TypeSignatureEquatableSet vtableTypes) in discoveryState.UserDefinedAndVtableTypes.OrderByFullyQualifiedTypeName(static pair => pair.Key)) { args.Token.ThrowIfCancellationRequested(); @@ -2504,7 +2508,6 @@ private static void DefineDefaultImplementationDetailTypes(InteropDefinitions in module.TopLevelTypes.Add(interopDefinitions.IAsyncOperationWithProgressVftbl); module.TopLevelTypes.Add(interopDefinitions.IMapChangedEventArgsVftbl); module.TopLevelTypes.Add(interopDefinitions.IReferenceArrayVftbl); - module.TopLevelTypes.Add(interopDefinitions.IReferenceArrayInterfaceEntries); } catch (Exception e) { @@ -2521,11 +2524,17 @@ private static void DefineDynamicImplementationDetailTypes(InteropDefinitions in { try { - // Also emit all shared COM interface entries types that are programmatically generated + // Emit all shared COM interface entries types that are programmatically generated for user-defined types foreach (TypeDefinition typeDefinition in interopDefinitions.EnumerateUserDefinedInterfaceEntriesTypes().OrderByFullyQualifiedTypeName()) { module.TopLevelTypes.Add(typeDefinition); } + + // Also emit interface entries types for SZ arrays, same as for user-defined types above + foreach (TypeDefinition typeDefinition in interopDefinitions.EnumerateSzArrayInterfaceEntriesTypes().OrderByFullyQualifiedTypeName()) + { + module.TopLevelTypes.Add(typeDefinition); + } } catch (Exception e) { diff --git a/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs b/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs index 846feae00..5fbfdabe2 100644 --- a/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs +++ b/src/WinRT.Interop.Generator/Generation/InteropGeneratorDiscoveryState.cs @@ -63,18 +63,21 @@ internal sealed class InteropGeneratorDiscoveryState /// Backing field for . private readonly ConcurrentDictionary _keyValuePairTypes = new(SignatureComparer.IgnoreVersion); - /// Backing field for . - private readonly ConcurrentDictionary _szArrayTypes = new(SignatureComparer.IgnoreVersion); - /// Backing field to support . private readonly ConcurrentDictionary _markedUserDefinedTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field to support . + private readonly ConcurrentDictionary _markedSzArrayTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field to support . private readonly ConcurrentDictionary _markedWindowsRuntimeGenericInterfaceTypeInstances = new(SignatureComparer.IgnoreVersion); - /// Backing field for . + /// Backing field for . private readonly ConcurrentDictionary _userDefinedTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field for . + private readonly ConcurrentDictionary _szArrayTypes = new(SignatureComparer.IgnoreVersion); + /// Backing field for . /// /// The value is also so we can de-duplicate equivalent sets across different maps (e.g. ). @@ -181,11 +184,6 @@ internal sealed class InteropGeneratorDiscoveryState /// public IReadOnlyCollection KeyValuePairTypes => (IReadOnlyCollection)_keyValuePairTypes.Keys; - /// - /// Gets all SZ array types. - /// - public IReadOnlyCollection SzArrayTypes => (IReadOnlyCollection)_szArrayTypes.Keys; - /// /// Gets all user-defined types. /// @@ -201,6 +199,11 @@ internal sealed class InteropGeneratorDiscoveryState /// public IReadOnlyCollection UserDefinedVtableTypes => (IReadOnlyCollection)_userDefinedVtableTypes.Keys; + /// + /// Gets all SZ array types and their vtable types. + /// + public IReadOnlyDictionary SzArrayAndVtableTypes => _szArrayTypes; + /// /// Gets whether any of the loaded modules reference the WinRT runtime .dll version 2. /// @@ -384,17 +387,6 @@ public void TrackKeyValuePairType(GenericInstanceTypeSignature keyValuePairType) _ = _keyValuePairTypes.TryAdd(keyValuePairType, 0); } - /// - /// Tracks a SZ array type. - /// - /// The SZ array type. - public void TrackSzArrayType(SzArrayTypeSignature szArrayType) - { - ThrowIfReadOnly(); - - _ = _szArrayTypes.TryAdd(szArrayType, 0); - } - /// /// Tries to mark a user-defined type as having been seen the first time, /// and indicating that it's in the process of being processed. @@ -406,13 +398,24 @@ public bool TryMarkUserDefinedType(TypeSignature userDefinedType) return _markedUserDefinedTypes.TryAdd(userDefinedType, 0); } + /// + /// Tries to mark an SZ array type as having been seen the first time, + /// and indicating that it's in the process of being processed. + /// + /// The SZ array type. + /// Whether this was the first time that was seen. + public bool TryMarkSzArrayType(SzArrayTypeSignature arrayType) + { + return _markedSzArrayTypes.TryAdd(arrayType, 0); + } + /// /// Tries to mark a constructed generic Windows Runtime interface type as having been seen the first time, /// and indicating that it's in the process of being processed. /// /// The constructed generic Windows Runtime interface type. /// Whether this was the first time that was seen. - public bool TryMarkWindowsRuntimeGenericInterfaceTypeInstance(TypeSignature interfaceType) + public bool TryMarkWindowsRuntimeGenericInterfaceTypeInstance(GenericInstanceTypeSignature interfaceType) { return _markedWindowsRuntimeGenericInterfaceTypeInstances.TryAdd(interfaceType, 0); } @@ -435,6 +438,18 @@ public void TrackUserDefinedType(TypeSignature userDefinedType, TypeSignatureEqu _ = _userDefinedTypes.TryAdd(userDefinedType, cachedVtableTypes); } + /// + /// Tracks an SZ array type. + /// + /// The SZ array type. + /// The vtable types for . + public void TrackSzArrayType(SzArrayTypeSignature arrayType, TypeSignatureEquatableSet vtableTypes) + { + ThrowIfReadOnly(); + + _ = _szArrayTypes.TryAdd(arrayType, vtableTypes); + } + /// /// Tracks a type that failed resolution. /// diff --git a/src/WinRT.Interop.Generator/Helpers/WindowsRuntimeTypeAnalyzer.cs b/src/WinRT.Interop.Generator/Helpers/WindowsRuntimeTypeAnalyzer.cs index 372d6d703..de6ef1430 100644 --- a/src/WinRT.Interop.Generator/Helpers/WindowsRuntimeTypeAnalyzer.cs +++ b/src/WinRT.Interop.Generator/Helpers/WindowsRuntimeTypeAnalyzer.cs @@ -29,7 +29,7 @@ public static bool TryGetMostDerivedWindowsRuntimeInterfaceType( interfaceType = null; // Go through all implemented interfaces for the user-defined type - foreach (TypeSignature interfaceSignature in type.EnumerateAllInterfaces()) + foreach (TypeSignature interfaceSignature in type.EnumerateAllInterfaces(interopReferences)) { // If the current interface is not a Windows Runtime type, just skip it. // We can only use Windows Runtime interfaces for the runtime class name. @@ -119,7 +119,7 @@ static IEnumerable EnumerateCovariantInterfaceTypesCore( yield return interfaceType; // Next, gather all combinations from interfaces implemented by the element type - foreach (TypeSignature elementInterfaceType in elementType.EnumerateAllInterfaces()) + foreach (TypeSignature elementInterfaceType in elementType.EnumerateAllInterfaces(interopReferences)) { // Construct the generic interface with the current element type yield return genericInterfaceType.MakeGenericReferenceType(elementInterfaceType); @@ -135,7 +135,7 @@ static IEnumerable EnumerateCovariantInterfaceTypesCore( } // Then, also gather all base types for the element type - foreach (TypeSignature baseType in elementType.EnumerateBaseTypes()) + foreach (TypeSignature baseType in elementType.EnumerateBaseTypes(interopReferences)) { yield return genericInterfaceType.MakeGenericReferenceType(baseType); } diff --git a/src/WinRT.Interop.Generator/Models/TypeSignatureEquatableSet.cs b/src/WinRT.Interop.Generator/Models/TypeSignatureEquatableSet.cs index 0aa152dd9..33cedc507 100644 --- a/src/WinRT.Interop.Generator/Models/TypeSignatureEquatableSet.cs +++ b/src/WinRT.Interop.Generator/Models/TypeSignatureEquatableSet.cs @@ -101,7 +101,7 @@ public override int GetHashCode() // And results in different hashcodes for equivalent signatures, which breaks everything. HashCode hashCode = default; - foreach (TypeSignature typeSignature in _set) + foreach (TypeSignature typeSignature in _set.OrderByFullyQualifiedTypeName()) { hashCode.Add(typeSignature, SignatureComparer.IgnoreVersion); } diff --git a/src/WinRT.Interop.Generator/References/InteropDefinitions.cs b/src/WinRT.Interop.Generator/References/InteropDefinitions.cs index c3e34b074..739eb7f96 100644 --- a/src/WinRT.Interop.Generator/References/InteropDefinitions.cs +++ b/src/WinRT.Interop.Generator/References/InteropDefinitions.cs @@ -28,6 +28,11 @@ internal sealed class InteropDefinitions /// private readonly ConcurrentDictionary _userDefinedInterfaceEntries; + /// + /// The map of generated COM interface entries types for SZ array types (with a Windows Runtime type for their element type). + /// + private readonly ConcurrentDictionary _szArrayInterfaceEntries; + /// /// Creates a new instance. /// @@ -38,6 +43,7 @@ public InteropDefinitions(InteropReferences interopReferences, ModuleDefinition _interopReferences = interopReferences; _interopModule = interopModule; _userDefinedInterfaceEntries = []; + _szArrayInterfaceEntries = []; } /// @@ -164,12 +170,7 @@ public InteropDefinitions(InteropReferences interopReferences, ModuleDefinition public TypeDefinition IReferenceArrayVftbl => field ??= WellKnownTypeDefinitionFactory.ReferenceArrayVftbl(_interopReferences, _interopModule); /// - /// Gets the for the IReferenceArrayInterfaceEntries type. - /// - public TypeDefinition IReferenceArrayInterfaceEntries => field ??= WellKnownTypeDefinitionFactory.ReferenceArrayInterfaceEntriesType(_interopReferences, _interopModule); - - /// - /// Enumerates all necessary COM interface entries types. + /// Enumerates all necessary COM interface entries types for user-defined types. /// /// The sequence of all necessary COM interface entries types. /// @@ -191,4 +192,28 @@ public TypeDefinition UserDefinedInterfaceEntries(int numberOfEntries) key: numberOfEntries, valueFactory: numberOfEntries => WellKnownTypeDefinitionFactory.UserDefinedInterfaceEntriesType(numberOfEntries, _interopReferences, _interopModule)); } + + /// + /// Enumerates all necessary COM interface entries types for SZ array types. + /// + /// The sequence of all necessary COM interface entries types. + /// + /// This method must be called after all necessary calls to . + /// + public IEnumerable EnumerateSzArrayInterfaceEntriesTypes() + { + return _szArrayInterfaceEntries.Values; + } + + /// + /// Gets the for the COM interface entries type for SZ array types with the specified number of entries. + /// + /// The number of COM interface entries to generate in the type. + /// The resulting instance. + public TypeDefinition SzArrayInterfaceEntries(int numberOfEntries) + { + return _szArrayInterfaceEntries.GetOrAdd( + key: numberOfEntries, + valueFactory: numberOfEntries => WellKnownTypeDefinitionFactory.SzArrayInterfaceEntriesType(numberOfEntries, _interopReferences, _interopModule)); + } } \ No newline at end of file diff --git a/src/WinRT.Interop.Generator/References/InteropReferences.cs b/src/WinRT.Interop.Generator/References/InteropReferences.cs index 36afe8df7..b069ecd56 100644 --- a/src/WinRT.Interop.Generator/References/InteropReferences.cs +++ b/src/WinRT.Interop.Generator/References/InteropReferences.cs @@ -373,6 +373,11 @@ public InteropReferences( /// public TypeReference IEnumerable1 => field ??= _corLibTypeFactory.CorLibScope.CreateTypeReference("System.Collections.Generic"u8, "IEnumerable`1"u8); + /// + /// Gets the for . + /// + public TypeReference ICollection => field ??= _corLibTypeFactory.CorLibScope.CreateTypeReference("System.Collections"u8, "ICollection"u8); + /// /// Gets the for . /// diff --git a/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntriesResolver.cs b/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntriesResolver.cs new file mode 100644 index 000000000..716dfe71d --- /dev/null +++ b/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntriesResolver.cs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using AsmResolver.DotNet; +using AsmResolver.DotNet.Code.Cil; +using AsmResolver.DotNet.Signatures; +using WindowsRuntime.InteropGenerator.Generation; +using WindowsRuntime.InteropGenerator.Models; +using WindowsRuntime.InteropGenerator.References; +using static AsmResolver.PE.DotNet.Cil.CilOpCodes; + +namespace WindowsRuntime.InteropGenerator.Resolvers; + +/// +/// A resolver for CCW interface entries for some managed type to be exposed to native code. +/// +internal static class InteropInterfaceEntriesResolver +{ + /// + /// The number of default, always present COM interface entries (returned by ). + /// + public const int NumberOfNativeComInterfaceEntries = 6; + + /// + /// Creates an instance with a provided set of methods. + /// + /// The value to get the interface IID. + /// The value to get the interface vtable. + /// The resulting instance. + public static InteropInterfaceEntryInfo Create(IMethodDefOrRef get_IID, IMethodDefOrRef get_Vtable) + { + return new WindowsRuntimeInterfaceEntryInfo(get_IID, get_Vtable); + } + + /// + /// Enumerates all values from a given source set of vtable types. + /// + /// The vtable types to use as source. + /// The instance to use. + /// The instance to use. + /// The emit state for this invocation. + /// The module that will contain the type being created. + /// Whether to use Windows.UI.Xaml projections. + public static IEnumerable EnumerateMetadataInterfaceEntries( + TypeSignatureEquatableSet vtableTypes, + InteropDefinitions interopDefinitions, + InteropReferences interopReferences, + InteropGeneratorEmitState emitState, + ModuleDefinition module, + bool useWindowsUIXamlProjections) + { + // Append all entries for the type (which we share for all matching user-defined types) + foreach (TypeSignature typeSignature in vtableTypes) + { + // Handle generic types first, and then custom-mapped and manually projected types. + // These require special handling, because their ABI types are in different locations. + if (typeSignature is GenericInstanceTypeSignature genericTypeSignature) + { + (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetGenericInstanceTypeImpl( + type: genericTypeSignature, + interopDefinitions: interopDefinitions, + interopReferences: interopReferences, + emitState: emitState); + + yield return new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod); + } + else if (typeSignature.IsCustomMappedWindowsRuntimeInterfaceType(interopReferences) || + typeSignature.IsManuallyProjectedWindowsRuntimeInterfaceType(interopReferences)) + { + (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetCustomMappedOrManuallyProjectedTypeImpl( + type: typeSignature, + interopReferences: interopReferences, + useWindowsUIXamlProjections: useWindowsUIXamlProjections); + + yield return new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod); + } + else + { + // We always need to resolve the user-defined types in all cases below, so just do it once first + TypeDefinition interfaceType = typeSignature.Resolve()!; + + // For '[GeneratedComInterface]', we need to retrieve and use the generated vtable from the COM generators + if (interfaceType.IsGeneratedComInterfaceType) + { + // Ignore interfaces we can't retrieve information for (this should never happen, interfaces are filtered during discovery) + if (!interfaceType.TryGetInterfaceInformationType(interopReferences, out TypeSignature? interfaceInformationType)) + { + continue; + } + + // Get the IID of the interface (same as above, this is pre-validated) + if (!interfaceType.TryGetGuidAttribute(interopReferences, out Guid interfaceId)) + { + continue; + } + + // If we find the special 'IMarshal' interface, ignore it here. We want to use this + // later to replace our built-in 'IMarshal' implementation in its own vtable slot. + if (interfaceId == WellKnownInterfaceIIDs.IID_IMarshal) + { + continue; + } + + yield return new ComInterfaceEntryInfo(interfaceInformationType); + } + else + { + // This is the common case for all normally projected, non-generic Windows Runtime types + (IMethodDefOrRef get_IIDMethod, IMethodDefOrRef get_VtableMethod) = InteropImplTypeResolver.GetProjectedTypeImpl( + type: interfaceType, + interopReferences: interopReferences); + + yield return new WindowsRuntimeInterfaceEntryInfo(get_IIDMethod, get_VtableMethod); + } + } + } + } + + /// + /// Enumerates all values for native interfaces. + /// + /// The vtable types to use as source. + /// The instance to use. + public static IEnumerable EnumerateNativeInterfaceEntries( + TypeSignatureEquatableSet vtableTypes, + InteropReferences interopReferences) + { + // Get the entry info for 'IMarshal', either user-provided or the built-in one + if (!TryGetUserDefinedIMarshalInterfaceImplementation( + vtableTypes: vtableTypes, + interopReferences: interopReferences, + interfaceEntryInfo: out InteropInterfaceEntryInfo? marshalInterfaceEntryInfo)) + { + marshalInterfaceEntryInfo = new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IMarshal, interopReferences.IMarshalImplget_Vtable); + } + + // Prepare the set of all built-in native interface implementations. These always follow the vtable slots for + // user-defined interfaces implemented by exposed types. 'IUnknown' in particular must always be the last one. + yield return new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IStringable, interopReferences.IStringableImplget_Vtable); + yield return new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IWeakReferenceSource, interopReferences.IWeakReferenceSourceImplget_Vtable); + yield return marshalInterfaceEntryInfo; + yield return new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IAgileObject, interopReferences.IAgileObjectImplget_Vtable); + yield return new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IInspectable, interopReferences.IInspectableImplget_Vtable); + yield return new WindowsRuntimeInterfaceEntryInfo(interopReferences.WellKnownInterfaceIIDsget_IID_IUnknown, interopReferences.IUnknownImplget_Vtable); + } + + /// + /// Tries to get the value for a user-defined IMarshal interface. + /// + /// The vtable types to use as source. + /// The instance to use. + /// The resulting value for IMarshal, if found. + /// Whether was found. + private static bool TryGetUserDefinedIMarshalInterfaceImplementation( + TypeSignatureEquatableSet vtableTypes, + InteropReferences interopReferences, + [NotNullWhen(true)] out InteropInterfaceEntryInfo? interfaceEntryInfo) + { + foreach (TypeSignature typeSignature in vtableTypes) + { + // Ignore generic interfaces ('IMarshal' isn't generic) + if (typeSignature is GenericInstanceTypeSignature) + { + continue; + } + + // Ignore all custom-mapped and special interfaces as well + if (typeSignature.IsCustomMappedWindowsRuntimeInterfaceType(interopReferences) || + typeSignature.IsManuallyProjectedWindowsRuntimeInterfaceType(interopReferences)) + { + continue; + } + + // Resolve the user-defined interface type (same as above) + TypeDefinition interfaceType = typeSignature.Resolve()!; + + // We only care about '[GeneratedComInterface]' types + if (!interfaceType.IsGeneratedComInterfaceType) + { + continue; + } + + // Get the IID of the interface (same as above) + if (!interfaceType.TryGetGuidAttribute(interopReferences, out Guid interfaceId)) + { + continue; + } + + // Make sure that this is the 'IMarshal' implementation (we might not find one at all) + if (interfaceId != WellKnownInterfaceIIDs.IID_IMarshal) + { + continue; + } + + // Only get the information type now that we know we do need it + if (!interfaceType.TryGetInterfaceInformationType(interopReferences, out TypeSignature? interfaceInformationType)) + { + continue; + } + + interfaceEntryInfo = new ComInterfaceEntryInfo(interfaceInformationType); + + return true; + } + + interfaceEntryInfo = null; + + return false; + } + + /// + /// An type for Windows Runtime types. + /// + /// The value to get the interface IID. + /// The value to get the interface vtable. + private sealed class WindowsRuntimeInterfaceEntryInfo(IMethodDefOrRef get_IID, IMethodDefOrRef get_Vtable) : InteropInterfaceEntryInfo + { + /// + public override void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) + { + _ = instructions.Add(Call, get_IID.Import(module)); + _ = instructions.Add(Ldobj, interopReferences.Guid.Import(module)); + } + + /// + public override void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) + { + _ = instructions.Add(Call, get_Vtable.Import(module)); + } + } + + /// + /// An type for COM types. + /// + /// The InterfaceInformation type for the current interface. + private sealed class ComInterfaceEntryInfo(TypeSignature interfaceInformationType) : InteropInterfaceEntryInfo + { + /// + public override void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) + { + _ = instructions.Add(Constrained, interfaceInformationType.Import(module).ToTypeDefOrRef()); + _ = instructions.Add(Call, interopReferences.IIUnknownInterfaceTypeget_Iid.Import(module)); + } + + /// + public override void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module) + { + _ = instructions.Add(Constrained, interfaceInformationType.Import(module).ToTypeDefOrRef()); + _ = instructions.Add(Call, interopReferences.IIUnknownInterfaceTypeget_ManagedVirtualMethodTable.Import(module)); + } + } +} diff --git a/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntryInfo.cs b/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntryInfo.cs new file mode 100644 index 000000000..6d1fd393e --- /dev/null +++ b/src/WinRT.Interop.Generator/Resolvers/InteropInterfaceEntryInfo.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using AsmResolver.DotNet; +using AsmResolver.DotNet.Code.Cil; +using WindowsRuntime.InteropGenerator.References; + +namespace WindowsRuntime.InteropGenerator.Resolvers; + +/// +/// A base type to abstract inserting interface entries information into a static constructor. +/// +internal abstract class InteropInterfaceEntryInfo +{ + /// + /// Loads the IID for the interface onto the evaluation stack. + /// + /// The target . + /// The instance to use. + /// The in use. + public abstract void LoadIID(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module); + + /// + /// Loads the vtable for the interface onto the evaluation stack. + /// + /// The target . + /// The instance to use. + /// The in use. + public abstract void LoadVtable(CilInstructionCollection instructions, InteropReferences interopReferences, ModuleDefinition module); +}