Skip to content

Commit 287d840

Browse files
authored
Handle CoCreateable classes in ComSourceGenerators mode (#1502)
* Handle CoCreateable classes
1 parent 0ca3106 commit 287d840

10 files changed

Lines changed: 217 additions & 9 deletions

File tree

src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ internal static SyntaxToken Token(SyntaxKind kind)
7676

7777
internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression);
7878

79+
internal static CollectionExpressionSyntax CollectionExpression(SeparatedSyntaxList<CollectionElementSyntax> elements = default) => SyntaxFactory.CollectionExpression(elements);
80+
81+
internal static ExpressionElementSyntax ExpressionElement(ExpressionSyntax expression) => SyntaxFactory.ExpressionElement(expression);
82+
7983
internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList<ExpressionSyntax> incrementors, StatementSyntax statement)
8084
{
8185
SyntaxToken semicolonToken = SyntaxFactory.Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(Space));
@@ -321,6 +325,8 @@ internal static SyntaxList<TNode> List<TNode>(IEnumerable<TNode> nodes)
321325

322326
internal static TypeConstraintSyntax TypeConstraint(TypeSyntax type) => SyntaxFactory.TypeConstraint(type);
323327

328+
internal static ClassOrStructConstraintSyntax ClassOrStructConstraint(SyntaxKind kind) => SyntaxFactory.ClassOrStructConstraint(kind);
329+
324330
internal static TypeParameterConstraintClauseSyntax TypeParameterConstraintClause(IdentifierNameSyntax name, SeparatedSyntaxList<TypeParameterConstraintSyntax> constraints) => SyntaxFactory.TypeParameterConstraintClause(TokenWithSpace(SyntaxKind.WhereKeyword), name, TokenWithSpaces(SyntaxKind.ColonToken), constraints);
325331

326332
internal static FieldDeclarationSyntax FieldDeclaration(VariableDeclarationSyntax declaration) => SyntaxFactory.FieldDeclaration(default, default, declaration, Semicolon);

src/Microsoft.Windows.CsWin32/Generator.Com.cs

Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ private static bool GenerateCcwFor(MetadataReader reader, StringHandle typeName,
6868
return true;
6969
}
7070

71+
private static StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
72+
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName),
73+
ArgumentList()));
74+
7175
/// <summary>
7276
/// Generates a type to represent a COM interface.
7377
/// </summary>
@@ -327,10 +331,6 @@ FunctionPointerParameterSyntax ToFunctionPointerParameter(ParameterSyntax p)
327331
if (methodDefinition.Generator.TryGetPropertyAccessorInfo(methodDefinition, originalIfaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) &&
328332
declaredProperties.Contains(propertyName.Identifier.ValueText))
329333
{
330-
StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
331-
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName),
332-
ArgumentList()));
333-
334334
BlockSyntax? body;
335335
switch (accessorKind)
336336
{
@@ -1307,15 +1307,124 @@ private bool TryDeclareCOMGuidInterfaceIfNecessary()
13071307
/// Creates an empty class that when instantiated, creates a cocreatable Windows object
13081308
/// that may implement a number of interfaces at runtime, discoverable only by documentation.
13091309
/// </summary>
1310-
private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef)
1310+
private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef, Context context)
13111311
{
1312+
bool canUseComImport = context.AllowMarshaling && !this.useSourceGenerators;
1313+
13121314
IdentifierNameSyntax name = IdentifierName(this.Reader.GetString(typeDef.Name));
13131315
Guid guid = this.FindGuidFromAttribute(typeDef) ?? throw new ArgumentException("Type does not have a GuidAttribute.");
13141316
SyntaxTokenList classModifiers = TokenList(TokenWithSpace(this.Visibility));
13151317
classModifiers = classModifiers.Add(TokenWithSpace(SyntaxKind.PartialKeyword));
13161318
ClassDeclarationSyntax result = ClassDeclaration(name.Identifier)
13171319
.WithModifiers(classModifiers)
1318-
.AddAttributeLists(AttributeList().AddAttributes(GUID(guid), ComImportAttributeSyntax));
1320+
.AddAttributeLists(AttributeList().AddAttributes(GUID(guid)).AddAttributes(canUseComImport ? [ComImportAttributeSyntax] : []));
1321+
1322+
if (!canUseComImport && !this.Options.ComInterop.UseIntPtrForComOutPointers)
1323+
{
1324+
string obsoleteMessage = context.AllowMarshaling
1325+
? $"COM source generators do not support direct instantiation of co-creatable classes. Use {name.Identifier}.CreateInstance<T> instead."
1326+
: $"Marshaling is disabled, so direct instantiation of co-creatable classes is not supported. Use {name.Identifier}.CreateInstance<T> instead.";
1327+
1328+
// Generate a private readonly field for the Guid
1329+
// private static readonly Guid CLSID_Foo = new Guid(...);
1330+
SyntaxToken clsidFieldName = Identifier($"CLSID_{name.Identifier}");
1331+
FieldDeclarationSyntax clsidField = FieldDeclaration(
1332+
VariableDeclaration(IdentifierName(nameof(Guid)))
1333+
.AddVariables(VariableDeclarator(clsidFieldName).WithInitializer(EqualsValueClause(GuidValue(guid)))))
1334+
.AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword));
1335+
result = result.AddMembers(clsidField);
1336+
1337+
// If using source generators or marshalling is disabled, generate a constructor with obsolete attribute like this:
1338+
// [Obsolete("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance<T> method instead.")]
1339+
// public Foo() { throw new NotSupportedException("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance<T> method instead."); }
1340+
AttributeSyntax obsoleteAttribute =
1341+
Attribute(IdentifierName(nameof(ObsoleteAttribute)))
1342+
.AddArgumentListArguments(
1343+
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage))));
1344+
ConstructorDeclarationSyntax constructor = ConstructorDeclaration(name.Identifier)
1345+
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword))
1346+
.AddAttributeLists(AttributeList().AddAttributes(obsoleteAttribute))
1347+
.WithBody(
1348+
Block(
1349+
ThrowStatement(
1350+
ObjectCreationExpression(IdentifierName(nameof(NotSupportedException)))
1351+
.WithArgumentList(
1352+
ArgumentList().AddArguments(
1353+
Argument(
1354+
LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage))))))));
1355+
result = result.AddMembers(constructor);
1356+
1357+
this.MainGenerator.TryGenerateExternMethod("CoCreateInstance", out IReadOnlyCollection<string> preciseApi);
1358+
this.MainGenerator.TryGenerateConstant("CLSCTX", out preciseApi);
1359+
1360+
if (context.AllowMarshaling)
1361+
{
1362+
// Then add the CreateInstance<T> method:
1363+
// public static T CreateInstance<T>() where T : class
1364+
// {
1365+
// PInvoke.CoCreateInstance<T>(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out T ret).ThrowOnFailure();
1366+
// return ret;
1367+
// }
1368+
TypeParameterSyntax typeParameter = TypeParameter(Identifier("T"));
1369+
GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T"));
1370+
MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName("T"), genericName.Identifier)
1371+
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword))
1372+
.AddTypeParameterListParameters(typeParameter)
1373+
.AddConstraintClauses(
1374+
TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList<TypeParameterConstraintSyntax>(ClassOrStructConstraint(SyntaxKind.ClassConstraint))))
1375+
.WithBody(
1376+
Block(
1377+
ThrowOnHRFailure(
1378+
InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T"))))
1379+
.WithArgumentList(
1380+
ArgumentList().AddArguments(
1381+
Argument(IdentifierName(clsidFieldName)),
1382+
Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
1383+
Argument(
1384+
MemberAccessExpression(
1385+
SyntaxKind.SimpleMemberAccessExpression,
1386+
QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")),
1387+
IdentifierName("CLSCTX_SERVER"))),
1388+
Argument(DeclarationExpression(IdentifierName("T").WithTrailingTrivia(Space), SingleVariableDesignation(Identifier("ret")))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))),
1389+
ReturnStatement(IdentifierName("ret"))));
1390+
result = result.AddMembers(createInstanceMethod);
1391+
}
1392+
else
1393+
{
1394+
// Then add a CreateInstance<T> method that looks like this:
1395+
// public static HRESULT CreateInstance<T>(out T* instance) where T : unmanaged
1396+
// {
1397+
// return PInvoke.CoCreateInstance<T>(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out instance);
1398+
// }
1399+
TypeParameterSyntax typeParameter = TypeParameter(Identifier("T"));
1400+
GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T"));
1401+
MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName($"{this.Win32NamespacePrefix}.Foundation.HRESULT"), genericName.Identifier)
1402+
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword))
1403+
.AddTypeParameterListParameters(typeParameter)
1404+
.AddConstraintClauses(
1405+
TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList<TypeParameterConstraintSyntax>(TypeConstraint(IdentifierName("unmanaged")))))
1406+
.WithParameterList(
1407+
ParameterList().AddParameters(
1408+
Parameter(Identifier("instance"))
1409+
.WithType(PointerType(IdentifierName("T")))
1410+
.WithModifiers(TokenList(Token(SyntaxKind.OutKeyword)))))
1411+
.WithBody(
1412+
Block(
1413+
ReturnStatement(
1414+
InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T"))))
1415+
.WithArgumentList(
1416+
ArgumentList().AddArguments(
1417+
Argument(IdentifierName(clsidFieldName)),
1418+
Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
1419+
Argument(
1420+
MemberAccessExpression(
1421+
SyntaxKind.SimpleMemberAccessExpression,
1422+
QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")),
1423+
IdentifierName("CLSCTX_SERVER"))),
1424+
Argument(IdentifierName("instance")).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)))))));
1425+
result = result.AddMembers(createInstanceMethod);
1426+
}
1427+
}
13191428

13201429
result = this.AddApiDocumentation(name.Identifier.ValueText, result);
13211430
return result;

src/Microsoft.Windows.CsWin32/Generator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1441,7 +1441,7 @@ private IReadOnlyList<ISymbol> FindTypeSymbolsIfAlreadyAvailable(string fullyQua
14411441
}
14421442
else if (this.IsEmptyStructWithGuid(typeDef))
14431443
{
1444-
typeDeclaration = this.DeclareCocreatableClass(typeDef);
1444+
typeDeclaration = this.DeclareCocreatableClass(typeDef, context);
14451445
}
14461446
else
14471447
{

src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,34 @@ internal static ObjectCreationExpressionSyntax GuidValue(CustomAttribute guidAtt
425425
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k))));
426426
}
427427

428+
internal static ObjectCreationExpressionSyntax GuidValue(Guid guid)
429+
{
430+
byte[] bytes = guid.ToByteArray();
431+
uint a = BitConverter.ToUInt32(bytes, 0);
432+
ushort b = BitConverter.ToUInt16(bytes, 4);
433+
ushort c = BitConverter.ToUInt16(bytes, 6);
434+
byte d = bytes[8];
435+
byte e = bytes[9];
436+
byte f = bytes[10];
437+
byte g = bytes[11];
438+
byte h = bytes[12];
439+
byte i = bytes[13];
440+
byte j = bytes[14];
441+
byte k = bytes[15];
442+
return ObjectCreationExpression(GuidTypeSyntax).AddArgumentListArguments(
443+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(a), a))),
444+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(b), b))),
445+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(c), c))),
446+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(d), d))),
447+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(e), e))),
448+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(f), f))),
449+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(g), g))),
450+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(h), h))),
451+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(i), i))),
452+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(j), j))),
453+
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k))));
454+
}
455+
428456
internal static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(
429457
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64()))));
430458

test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ public async Task TestPlatformCaseSensitivity(string platform)
100100
await this.InvokeGeneratorAndCompile($"{nameof(this.TestPlatformCaseSensitivity)}_{platform}");
101101
}
102102

103+
[Fact]
104+
public async Task TestGenerateCoCreateableClass()
105+
{
106+
this.nativeMethods.Add("ShellLink");
107+
await this.InvokeGeneratorAndCompileFromFact();
108+
109+
var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink"));
110+
111+
// Check that it does not have the ComImport attribute.
112+
Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport")));
113+
114+
// Check that it contains a CreateInstance method
115+
Assert.Contains(shellLinkType.DescendantNodes().OfType<MethodDeclarationSyntax>(), method => method.Identifier.Text == "CreateInstance");
116+
}
117+
103118
[Theory]
104119
[InlineData("IMFMediaKeySession", "get_KeySystem", "winmdroot.Foundation.BSTR* keySystem")]
105120
[InlineData("AddPrinterW", "AddPrinter", "winmdroot.Foundation.PWSTR pName, uint Level, Span<byte> pPrinter")]

test/GenerationSandbox.BuildTask.Tests/COMTests.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma warning disable IDE0005
55
#pragma warning disable SA1201, SA1512, SA1005, SA1507, SA1515, SA1403, SA1402, SA1411, SA1300, SA1313, SA1134, SA1307, SA1308
66

7+
using System.ComponentModel;
78
using System.Net.NetworkInformation;
89
using System.Runtime.InteropServices;
910
using System.Runtime.InteropServices.Marshalling;
@@ -17,6 +18,7 @@
1718
using Windows.Win32.Graphics.Direct3D11;
1819
using Windows.Win32.System.Com;
1920
using Windows.Win32.System.WinRT.Composition;
21+
using Windows.Win32.UI.Shell;
2022

2123
[Trait("WindowsOnly", "true")]
2224
public partial class COMTests
@@ -70,4 +72,12 @@ public async Task CanInteropWithICompositorInterop()
7072
Assert.Skip("Skipping due to UnauthorizedAccessException.");
7173
}
7274
}
75+
76+
[Fact]
77+
public void CocreatableClassesWithImplicitInterfaces()
78+
{
79+
var shellLinkW = ShellLink.CreateInstance<IShellLinkW>();
80+
var persistFile = (IPersistFile)shellLinkW;
81+
Assert.NotNull(persistFile);
82+
}
7383
}

test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,7 @@ WINTRUST_DATA
3434
WINTRUST_FILE_INFO
3535
WinVerifyTrust
3636
WM_HOTKEY
37-
WNDCLASSW
37+
WNDCLASSW
38+
ShellLink
39+
CoCreateInstance
40+
IShellLinkW

test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

4-
#pragma warning disable IDE0005
4+
#pragma warning disable IDE0005,SA1202
55

6+
using System.Runtime.InteropServices;
67
using Windows.Win32;
78
using Windows.Win32.System.Com;
9+
using Windows.Win32.UI.Shell;
810

911
public class COMTests
1012
{
@@ -19,5 +21,19 @@ public void COMStaticGuid()
1921
private static Guid GetGuid<T>()
2022
where T : IComIID
2123
=> T.Guid;
24+
25+
[Trait("WindowsOnly", "true")]
26+
[Fact]
27+
public unsafe void CocreatableClassesWithImplicitInterfaces()
28+
{
29+
Assert.SkipUnless(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "Test calls Windows-specific APIs");
30+
31+
ShellLink.CreateInstance(out IShellLinkW* shellLinkWPtr).ThrowOnFailure();
32+
shellLinkWPtr->QueryInterface(typeof(IPersistFile).GUID, out void* ppv).ThrowOnFailure();
33+
IPersistFile* persistFilePtr = (IPersistFile*)ppv;
34+
Assert.NotNull(persistFilePtr);
35+
persistFilePtr->Release();
36+
shellLinkWPtr->Release();
37+
}
2238
#endif
2339
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
IEventSubscription
22
IPersistFile
33
IStream
4+
ShellLink
5+
IShellLinkW

test/Microsoft.Windows.CsWin32.Tests/COMTests.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,4 +487,23 @@ public void IUnknown_Derived_QueryInterfaceGenericHelper()
487487
this.FindGeneratedMethod("QueryInterface"),
488488
m => m.Parent is StructDeclarationSyntax { Identifier.Text: "ITypeLib" } && m.TypeParameterList?.Parameters.Count == 1);
489489
}
490+
491+
[Theory, PairwiseData]
492+
public void TestGenerateCoCreateableClass(bool useIntPtrForComOutPtr)
493+
{
494+
this.generator = this.CreateGenerator(new GeneratorOptions { AllowMarshaling = false, ComInterop = new GeneratorOptions.ComInteropOptions { UseIntPtrForComOutPointers = useIntPtrForComOutPtr } });
495+
496+
this.GenerateApi("ShellLink");
497+
498+
var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink"));
499+
500+
// Check that it does not have the ComImport attribute.
501+
Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport")));
502+
503+
if (!useIntPtrForComOutPtr)
504+
{
505+
// Check that it contains a CreateInstance method
506+
Assert.Contains(shellLinkType.DescendantNodes().OfType<MethodDeclarationSyntax>(), method => method.Identifier.Text == "CreateInstance");
507+
}
508+
}
490509
}

0 commit comments

Comments
 (0)