Skip to content

Commit 3686850

Browse files
authored
Use Marshal.InitHandle API to avoid memory leak when OOM happens (#1613)
* Use `Marshal.InitHandle` API to avoid memory leak when OOM happens * Use pattern matching * Fix obtaining `IntPtr` from handles baked by non-`IntPtr` values * Use `InitHandle` for non-owning handles when possible * Simplify
1 parent e1462ff commit 3686850

6 files changed

Lines changed: 154 additions & 32 deletions

File tree

src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ internal static SyntaxToken Token(SyntaxKind kind)
7474

7575
internal static BlockSyntax Block(params StatementSyntax[] statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace);
7676

77+
internal static BlockSyntax Block(IEnumerable<StatementSyntax> statements) => SyntaxFactory.Block(OpenBrace, List(statements), CloseBrace);
78+
7779
internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression);
7880

7981
internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList<ExpressionSyntax> incrementors, StatementSyntax statement)
@@ -100,10 +102,12 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla
100102

101103
internal static DeclarationExpressionSyntax DeclarationExpression(TypeSyntax type, VariableDesignationSyntax designation) => SyntaxFactory.DeclarationExpression(type, designation);
102104

103-
internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier) => SyntaxFactory.VariableDeclarator(identifier);
105+
internal static VariableDeclaratorSyntax VariableDeclarator(SyntaxToken identifier, EqualsValueClauseSyntax? initializer = null) => SyntaxFactory.VariableDeclarator(identifier, argumentList: null, initializer: initializer);
104106

105107
internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)));
106108

109+
internal static VariableDeclarationSyntax VariableDeclaration(TypeSyntax type, params VariableDeclaratorSyntax[] variables) => SyntaxFactory.VariableDeclaration(type.WithTrailingTrivia(TriviaList(Space)), SeparatedList(variables));
110+
107111
internal static SizeOfExpressionSyntax SizeOfExpression(TypeSyntax type) => SyntaxFactory.SizeOfExpression(Token(SyntaxKind.SizeOfKeyword), Token(SyntaxKind.OpenParenToken), type, Token(SyntaxKind.CloseParenToken));
108112

109113
internal static MemberAccessExpressionSyntax MemberAccessExpression(SyntaxKind kind, ExpressionSyntax expression, SimpleNameSyntax name) => SyntaxFactory.MemberAccessExpression(kind, expression, Token(GetMemberAccessExpressionOperatorTokenKind(kind)), name);
@@ -190,7 +194,7 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla
190194

191195
internal static InitializerExpressionSyntax InitializerExpression(SyntaxKind kind, SeparatedSyntaxList<ExpressionSyntax> expressions) => SyntaxFactory.InitializerExpression(kind, OpenBrace, expressions, CloseBrace);
192196

193-
internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(), null);
197+
internal static ObjectCreationExpressionSyntax ObjectCreationExpression(TypeSyntax type, SeparatedSyntaxList<ArgumentSyntax> arguments = default) => SyntaxFactory.ObjectCreationExpression(Token(TriviaList(), SyntaxKind.NewKeyword, TriviaList(Space)), type, ArgumentList(arguments), null);
194198

195199
internal static ArrayCreationExpressionSyntax ArrayCreationExpression(ArrayTypeSyntax type, InitializerExpressionSyntax? initializer = null) => SyntaxFactory.ArrayCreationExpression(Token(SyntaxKind.NewKeyword), type, initializer);
196200

@@ -295,7 +299,7 @@ internal static SyntaxList<TNode> SingletonList<TNode>(TNode node)
295299

296300
internal static AttributeArgumentListSyntax AttributeArgumentList(SeparatedSyntaxList<AttributeArgumentSyntax> arguments = default) => SyntaxFactory.AttributeArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));
297301

298-
internal static AttributeListSyntax AttributeList() => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, SeparatedList<AttributeSyntax>(), TokenWithLineFeed(SyntaxKind.CloseBracketToken));
302+
internal static AttributeListSyntax AttributeList(params SeparatedSyntaxList<AttributeSyntax> attributes) => SyntaxFactory.AttributeList(Token(SyntaxKind.OpenBracketToken), null, attributes, TokenWithLineFeed(SyntaxKind.CloseBracketToken));
299303

300304
internal static SyntaxList<TNode> List<TNode>()
301305
where TNode : SyntaxNode => SyntaxFactory.List<TNode>();
@@ -305,7 +309,7 @@ internal static SyntaxList<TNode> List<TNode>(IEnumerable<TNode> nodes)
305309

306310
internal static ParameterListSyntax ParameterList() => SyntaxFactory.ParameterList(Token(SyntaxKind.OpenParenToken), SeparatedList<ParameterSyntax>(), Token(SyntaxKind.CloseParenToken));
307311

308-
internal static ArgumentListSyntax ArgumentList(SeparatedSyntaxList<ArgumentSyntax> arguments = default) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));
312+
internal static ArgumentListSyntax ArgumentList(params SeparatedSyntaxList<ArgumentSyntax> arguments) => SyntaxFactory.ArgumentList(Token(SyntaxKind.OpenParenToken), arguments, Token(SyntaxKind.CloseParenToken));
309313

310314
internal static AssignmentExpressionSyntax AssignmentExpression(SyntaxKind kind, ExpressionSyntax left, ExpressionSyntax right) => SyntaxFactory.AssignmentExpression(kind, left, Token(GetAssignmentExpressionOperatorTokenKind(kind)).WithLeadingTrivia(Space), right);
311315

src/Microsoft.Windows.CsWin32/Generator.Features.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public partial class Generator
2222
private readonly bool unscopedRefAttributePredefined;
2323
private readonly bool canUseComVariant;
2424
private readonly bool canUseMemberFunctionCallingConvention;
25+
private readonly bool canUseMarshalInitHandle;
2526
private readonly INamedTypeSymbol? runtimeFeatureClass;
2627
private readonly bool generateSupportedOSPlatformAttributes;
2728
private readonly bool generateSupportedOSPlatformAttributesOnInterfaces; // only supported on net6.0 (https://github.com/dotnet/runtime/pull/48838)

src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -378,19 +378,54 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
378378
.WithModifiers(TokenList(TokenWithSpace(SyntaxKind.OutKeyword)));
379379

380380
// HANDLE SomeLocal;
381-
leadingStatements.Add(LocalDeclarationStatement(VariableDeclaration(pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type).AddVariables(
382-
VariableDeclarator(typeDefHandleName.Identifier))));
381+
leadingStatements.Add(
382+
LocalDeclarationStatement(
383+
VariableDeclaration(
384+
pointedElementInfo.ToTypeSyntax(parameterTypeSyntaxSettings, GeneratingElement.FriendlyOverload, null).Type,
385+
VariableDeclarator(typeDefHandleName.Identifier))));
386+
387+
ArgumentSyntax ownsHandleArgument = Argument(
388+
NameColon(IdentifierName("ownsHandle")),
389+
refKindKeyword: default,
390+
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression));
391+
392+
if (this.canUseMarshalInitHandle)
393+
{
394+
// Some = new SafeHandle(default, ownsHandle: true);
395+
leadingStatements.Add(
396+
ExpressionStatement(
397+
AssignmentExpression(
398+
SyntaxKind.SimpleAssignmentExpression,
399+
origName,
400+
ObjectCreationExpression(safeHandleType, [Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)), ownsHandleArgument]))));
401+
402+
// Marshal.InitHandle(Some, SomeLocal);
403+
trailingStatements.Add(
404+
ExpressionStatement(
405+
InvocationExpression(
406+
MemberAccessExpression(
407+
SyntaxKind.SimpleMemberAccessExpression,
408+
IdentifierName(nameof(Marshal)),
409+
IdentifierName("InitHandle")),
410+
ArgumentList(
411+
[
412+
Argument(origName),
413+
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
414+
]))));
415+
}
416+
else
417+
{
418+
// Some = new SafeHandle(SomeLocal, ownsHandle: true);
419+
trailingStatements.Add(ExpressionStatement(AssignmentExpression(
420+
SyntaxKind.SimpleAssignmentExpression,
421+
origName,
422+
ObjectCreationExpression(safeHandleType).AddArgumentListArguments(
423+
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
424+
ownsHandleArgument))));
425+
}
383426

384427
// Argument: &SomeLocal
385428
arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, typeDefHandleName));
386-
387-
// Some = new SafeHandle(SomeLocal, ownsHandle: true);
388-
trailingStatements.Add(ExpressionStatement(AssignmentExpression(
389-
SyntaxKind.SimpleAssignmentExpression,
390-
origName,
391-
ObjectCreationExpression(safeHandleType).AddArgumentListArguments(
392-
Argument(this.GetIntPtrFromTypeDef(typeDefHandleName, pointedElementInfo)),
393-
Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle")))))));
394429
}
395430
}
396431
else if (this.options.UseSafeHandles && isIn && !isOut && !isReleaseMethod && parameterTypeInfo is HandleTypeHandleInfo parameterHandleTypeInfo && this.TryGetHandleReleaseMethod(parameterHandleTypeInfo.Handle, paramAttributes, out string? releaseMethod) && !this.Reader.StringComparer.Equals(methodDefinition.Name, releaseMethod)
@@ -1108,7 +1143,46 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
11081143
&& returnTypeHandleInfo.Generator.TryGetHandleReleaseMethod(returnTypeHandleInfo.Handle, returnTypeAttributes, out string? returnReleaseMethod)
11091144
? this.RequestSafeHandle(returnReleaseMethod) : null;
11101145

1111-
if ((returnSafeHandleType is object || minorSignatureChange) && !signatureChanged)
1146+
IdentifierNameSyntax resultLocal = IdentifierName("__result");
1147+
1148+
if (this.canUseMarshalInitHandle && returnSafeHandleType is not null)
1149+
{
1150+
IdentifierNameSyntax resultSafeHandleLocal = IdentifierName("__resultSafeHandle");
1151+
1152+
// SafeHandle __resultSafeHandle = new SafeHandle(default, ownsHandle: true);
1153+
leadingStatements.Add(
1154+
LocalDeclarationStatement(
1155+
VariableDeclaration(
1156+
returnSafeHandleType,
1157+
VariableDeclarator(
1158+
resultSafeHandleLocal.Identifier,
1159+
EqualsValueClause(
1160+
ObjectCreationExpression(
1161+
returnSafeHandleType,
1162+
[
1163+
Argument(LiteralExpression(SyntaxKind.DefaultLiteralExpression)),
1164+
Argument(
1165+
NameColon(IdentifierName("ownsHandle")),
1166+
refKindKeyword: default,
1167+
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression))
1168+
]))))));
1169+
1170+
// Marshal.InitHandle(__resultSafeHandle, __result);
1171+
trailingStatements.Add(
1172+
ExpressionStatement(
1173+
InvocationExpression(
1174+
MemberAccessExpression(
1175+
SyntaxKind.SimpleMemberAccessExpression,
1176+
IdentifierName(nameof(Marshal)),
1177+
IdentifierName("InitHandle")),
1178+
ArgumentList(
1179+
[
1180+
Argument(resultSafeHandleLocal),
1181+
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
1182+
]))));
1183+
}
1184+
1185+
if ((returnSafeHandleType is not null || minorSignatureChange) && !signatureChanged)
11121186
{
11131187
// The parameter types are all the same, but we need a friendly overload with a different return type.
11141188
// Our only choice is to rename the friendly overload.
@@ -1145,20 +1219,33 @@ bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
11451219
})
11461220
.WithArgumentList(FixTrivia(ArgumentList().AddArguments(arguments.ToArray())));
11471221
bool hasVoidReturn = externMethodReturnType is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.VoidKeyword } };
1148-
BlockSyntax? body = Block().AddStatements(leadingStatements.ToArray());
1149-
IdentifierNameSyntax resultLocal = IdentifierName("__result");
1150-
if (returnSafeHandleType is object)
1222+
BlockSyntax? body = Block(leadingStatements);
1223+
if (returnSafeHandleType is not null)
11511224
{
1152-
//// HANDLE result = invocation();
1225+
// HANDLE result = invocation();
11531226
body = body.AddStatements(LocalDeclarationStatement(VariableDeclaration(externMethodReturnType)
11541227
.AddVariables(VariableDeclarator(resultLocal.Identifier).WithInitializer(EqualsValueClause(externInvocation)))));
11551228

11561229
body = body.AddStatements(trailingStatements.ToArray());
11571230

1158-
//// return new SafeHandle(result, ownsHandle: true);
1159-
body = body.AddStatements(ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments(
1160-
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
1161-
Argument(LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression)).WithNameColon(NameColon(IdentifierName("ownsHandle"))))));
1231+
ReturnStatementSyntax returnStatement;
1232+
if (this.canUseMarshalInitHandle)
1233+
{
1234+
// return __resultSafeHandle;
1235+
returnStatement = ReturnStatement(IdentifierName("__resultSafeHandle"));
1236+
}
1237+
else
1238+
{
1239+
// return new SafeHandle(result, ownsHandle: true);
1240+
returnStatement = ReturnStatement(ObjectCreationExpression(returnSafeHandleType).AddArgumentListArguments(
1241+
Argument(this.GetIntPtrFromTypeDef(resultLocal, originalSignature.ReturnType)),
1242+
Argument(
1243+
NameColon(IdentifierName("ownsHandle")),
1244+
refKindKeyword: default,
1245+
LiteralExpression(doNotRelease ? SyntaxKind.FalseLiteralExpression : SyntaxKind.TrueLiteralExpression))));
1246+
}
1247+
1248+
body = body.AddStatements(returnStatement);
11621249
}
11631250
else if (hasVoidReturn)
11641251
{

src/Microsoft.Windows.CsWin32/Generator.Handle.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,29 +123,32 @@ public partial class Generator
123123
VariableDeclarator(invalidValueFieldName.Identifier).WithInitializer(EqualsValueClause(invalidHandleIntPtr))))
124124
.AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword)));
125125

126+
SyntaxToken visibilityModifier = TokenWithSpace(this.Visibility);
127+
126128
// public SafeHandle() : base(INVALID_HANDLE_VALUE, true)
127129
members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier)
128-
.AddModifiers(TokenWithSpace(this.Visibility))
130+
.AddModifiers(visibilityModifier)
129131
.WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments(
130132
Argument(invalidValueFieldName),
131133
Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)))))
132134
.WithBody(Block()));
133135

134136
// public SafeHandle(IntPtr preexistingHandle, bool ownsHandle = true) : base(INVALID_HANDLE_VALUE, ownsHandle) { this.SetHandle(preexistingHandle); }
135-
const string preexistingHandleName = "preexistingHandle";
136-
const string ownsHandleName = "ownsHandle";
137+
IdentifierNameSyntax preexistingHandleName = IdentifierName("preexistingHandle");
138+
IdentifierNameSyntax ownsHandleName = IdentifierName("ownsHandle");
137139
members.Add(ConstructorDeclaration(safeHandleTypeIdentifier.Identifier)
138-
.AddModifiers(TokenWithSpace(this.Visibility))
140+
.AddModifiers(visibilityModifier)
139141
.AddParameterListParameters(
140-
Parameter(Identifier(preexistingHandleName)).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))),
141-
Parameter(Identifier(ownsHandleName)).WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)))
142+
Parameter(preexistingHandleName.Identifier).WithType(IntPtrTypeSyntax.WithTrailingTrivia(TriviaList(Space))),
143+
Parameter(ownsHandleName.Identifier)
144+
.WithType(PredefinedType(TokenWithSpace(SyntaxKind.BoolKeyword)))
142145
.WithDefault(EqualsValueClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))))
143146
.WithInitializer(ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList().AddArguments(
144147
Argument(invalidValueFieldName),
145-
Argument(IdentifierName(ownsHandleName)))))
148+
Argument(ownsHandleName))))
146149
.WithBody(Block().AddStatements(
147150
ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ThisExpression(), IdentifierName("SetHandle")))
148-
.WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName(preexistingHandleName)))))))));
151+
.WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(preexistingHandleName))))))));
149152

150153
// public override bool IsInvalid => this.handle.ToInt64() == 0 || this.handle.ToInt64() == -1;
151154
ExpressionSyntax thisHandleToInt64 = InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, thisHandle, IdentifierName(nameof(IntPtr.ToInt64))), ArgumentList());
@@ -290,7 +293,7 @@ public partial class Generator
290293
IEnumerable<TypeSyntax> xmlDocParameterTypes = releaseMethodSignature.ParameterTypes.Select(p => p.ToTypeSyntax(this.externSignatureTypeSettings, GeneratingElement.HelperClassMember, default).Type);
291294

292295
ClassDeclarationSyntax safeHandleDeclaration = ClassDeclaration(Identifier(safeHandleClassName))
293-
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.PartialKeyword))
296+
.AddModifiers(visibilityModifier, TokenWithSpace(SyntaxKind.PartialKeyword))
294297
.WithBaseList(BaseList(SingletonSeparatedList<BaseTypeSyntax>(SimpleBaseType(SafeHandleTypeSyntax))))
295298
.AddMembers(members.ToArray())
296299
.AddAttributeLists(AttributeList().AddAttributes(GeneratedCodeAttribute))

0 commit comments

Comments
 (0)