Skip to content

Commit 080e007

Browse files
youssef-backport-botCopilotEvangelinkYoussef1313
authored
Fix codefix of analyzer for flowing cancellation token by @Copilot in #6239 (backport to rel/3.10) (#6259)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Evangelink <11340282+Evangelink@users.noreply.github.com> Co-authored-by: Youssef1313 <youssefvictor00@gmail.com>
1 parent 9777164 commit 080e007

3 files changed

Lines changed: 424 additions & 33 deletions

File tree

src/Analyzers/MSTest.Analyzers.CodeFixes/FlowTestContextCancellationTokenFixer.cs

Lines changed: 191 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
using MSTest.Analyzers.Helpers;
1818

19+
using Polyfills;
20+
1921
namespace MSTest.Analyzers;
2022

2123
/// <summary>
@@ -31,8 +33,8 @@ public sealed class FlowTestContextCancellationTokenFixer : CodeFixProvider
3133

3234
/// <inheritdoc />
3335
public override FixAllProvider GetFixAllProvider()
34-
// See https://github.com/dotnet/roslyn/blob/main/docs/analyzers/FixAllProvider.md for more information on Fix All Providers
35-
=> WellKnownFixAllProviders.BatchFixer;
36+
// Use custom FixAllProvider to handle adding TestContext property when needed
37+
=> FlowTestContextCancellationTokenFixAllProvider.Instance;
3638

3739
/// <inheritdoc />
3840
public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
@@ -49,37 +51,213 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
4951
}
5052

5153
diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName);
54+
diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState);
5255

5356
// Register a code action that will invoke the fix
5457
context.RegisterCodeFix(
5558
CodeAction.Create(
5659
title: CodeFixResources.PassCancellationTokenFix,
57-
createChangedDocument: c => AddCancellationTokenParameterAsync(context.Document, invocationExpression, testContextMemberName, c),
58-
equivalenceKey: "AddTestContextCancellationToken"),
60+
createChangedDocument: async c =>
61+
{
62+
DocumentEditor editor = await DocumentEditor.CreateAsync(context.Document, context.CancellationToken).ConfigureAwait(false);
63+
return ApplyFix(editor, invocationExpression, testContextMemberName, testContextState, adjustedSymbols: null, c);
64+
},
65+
equivalenceKey: nameof(FlowTestContextCancellationTokenFixer)),
5966
diagnostic);
6067
}
6168

62-
private static async Task<Document> AddCancellationTokenParameterAsync(
63-
Document document,
69+
internal static Document ApplyFix(
70+
DocumentEditor editor,
6471
InvocationExpressionSyntax invocationExpression,
6572
string? testContextMemberName,
73+
string? testContextState,
74+
HashSet<ISymbol>? adjustedSymbols,
6675
CancellationToken cancellationToken)
6776
{
68-
DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
77+
if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsProperty))
78+
{
79+
Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsProperty");
80+
AddCancellationTokenArgument(editor, invocationExpression, "TestContext");
81+
TypeDeclarationSyntax? containingTypeDeclaration = invocationExpression.FirstAncestorOrSelf<TypeDeclarationSyntax>();
82+
if (containingTypeDeclaration is not null)
83+
{
84+
// adjustedSymbols is null meaning we are only applying a single fix (in that case we add the property).
85+
// If we are in fix all, we then verify if a previous fix has already added the property.
86+
// We only add the property if it wasn't added by a previous fix.
87+
// NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
88+
if (adjustedSymbols is null ||
89+
editor.SemanticModel.GetDeclaredSymbol(containingTypeDeclaration, cancellationToken) is not { } symbol ||
90+
adjustedSymbols.Add(symbol))
91+
{
92+
editor.ReplaceNode(containingTypeDeclaration, (containingTypeDeclaration, _) => AddTestContextProperty((TypeDeclarationSyntax)containingTypeDeclaration));
93+
}
94+
}
95+
}
96+
else if (testContextState == nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState.CouldBeInScopeAsParameter))
97+
{
98+
Debug.Assert(testContextMemberName is null, "TestContext member name should be null when state is CouldBeInScopeAsParameter");
99+
AddCancellationTokenArgument(editor, invocationExpression, "testContext");
100+
MethodDeclarationSyntax? containingMethodDeclaration = invocationExpression.FirstAncestorOrSelf<MethodDeclarationSyntax>();
101+
102+
if (containingMethodDeclaration is not null)
103+
{
104+
// adjustedSymbols is null meaning we are only applying a single fix (in that case we add the parameter).
105+
// If we are in fix all, we then verify if a previous fix has already added the parameter.
106+
// We only add the parameter if it wasn't added by a previous fix.
107+
// NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
108+
if (adjustedSymbols is null ||
109+
editor.SemanticModel.GetDeclaredSymbol(containingMethodDeclaration, cancellationToken) is not { } symbol ||
110+
adjustedSymbols.Add(symbol))
111+
{
112+
editor.ReplaceNode(containingMethodDeclaration, (containingMethodDeclaration, _) => AddTestContextParameterToMethod((MethodDeclarationSyntax)containingMethodDeclaration));
113+
}
114+
}
115+
}
116+
else
117+
{
118+
Guard.NotNull(testContextMemberName);
119+
AddCancellationTokenArgument(editor, invocationExpression, testContextMemberName);
120+
}
121+
122+
return editor.GetChangedDocument();
123+
}
124+
125+
internal static void AddCancellationTokenArgument(
126+
DocumentEditor editor,
127+
InvocationExpressionSyntax invocationExpression,
128+
string testContextMemberName)
129+
{
130+
// Find the containing method to determine the context
131+
MethodDeclarationSyntax? containingMethod = invocationExpression.FirstAncestorOrSelf<MethodDeclarationSyntax>();
69132

70133
// Create the TestContext.CancellationTokenSource.Token expression
71134
MemberAccessExpressionSyntax testContextExpression = SyntaxFactory.MemberAccessExpression(
72135
SyntaxKind.SimpleMemberAccessExpression,
73136
SyntaxFactory.MemberAccessExpression(
74137
SyntaxKind.SimpleMemberAccessExpression,
75-
SyntaxFactory.IdentifierName(testContextMemberName ?? "testContext"),
138+
SyntaxFactory.IdentifierName(testContextMemberName),
76139
SyntaxFactory.IdentifierName("CancellationTokenSource")),
77140
SyntaxFactory.IdentifierName("Token"));
78141

79-
ArgumentListSyntax currentArguments = invocationExpression.ArgumentList;
80-
SeparatedSyntaxList<ArgumentSyntax> newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression));
81-
InvocationExpressionSyntax newInvocation = invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments));
82-
editor.ReplaceNode(invocationExpression, newInvocation);
83-
return editor.GetChangedDocument();
142+
editor.ReplaceNode(invocationExpression, (node, _) =>
143+
{
144+
var invocationExpression = (InvocationExpressionSyntax)node;
145+
ArgumentListSyntax currentArguments = invocationExpression.ArgumentList;
146+
SeparatedSyntaxList<ArgumentSyntax> newArguments = currentArguments.Arguments.Add(SyntaxFactory.Argument(testContextExpression));
147+
return invocationExpression.WithArgumentList(currentArguments.WithArguments(newArguments));
148+
});
149+
}
150+
151+
internal static MethodDeclarationSyntax AddTestContextParameterToMethod(MethodDeclarationSyntax method)
152+
{
153+
// Create TestContext parameter
154+
ParameterSyntax testContextParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier("testContext"))
155+
.WithType(SyntaxFactory.IdentifierName("TestContext"));
156+
157+
// Add the parameter to the method
158+
SeparatedSyntaxList<ParameterSyntax> updatedParameterList = method.ParameterList.Parameters.Count == 0
159+
? SyntaxFactory.SingletonSeparatedList(testContextParameter)
160+
: method.ParameterList.Parameters.Add(testContextParameter);
161+
162+
return method.WithParameterList(method.ParameterList.WithParameters(updatedParameterList));
163+
}
164+
165+
internal static TypeDeclarationSyntax AddTestContextProperty(TypeDeclarationSyntax typeDeclaration)
166+
{
167+
PropertyDeclarationSyntax testContextProperty = SyntaxFactory.PropertyDeclaration(
168+
SyntaxFactory.IdentifierName("TestContext"),
169+
"TestContext")
170+
.WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword)))
171+
.WithAccessorList(SyntaxFactory.AccessorList(
172+
SyntaxFactory.List(new[]
173+
{
174+
SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
175+
.WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)),
176+
SyntaxFactory.AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
177+
.WithSemicolonToken(SyntaxFactory.Token(SyntaxKind.SemicolonToken)),
178+
})));
179+
180+
return typeDeclaration.AddMembers(testContextProperty);
181+
}
182+
}
183+
184+
/// <summary>
185+
/// Custom FixAllProvider for <see cref="FlowTestContextCancellationTokenFixer"/> that can add TestContext property when needed.
186+
/// This ensures that when multiple fixes are applied to the same class, the TestContext property is added only once.
187+
/// </summary>
188+
internal sealed class FlowTestContextCancellationTokenFixAllProvider : FixAllProvider
189+
{
190+
public static readonly FlowTestContextCancellationTokenFixAllProvider Instance = new();
191+
192+
private FlowTestContextCancellationTokenFixAllProvider()
193+
{
194+
}
195+
196+
public override Task<CodeAction?> GetFixAsync(FixAllContext fixAllContext)
197+
=> Task.FromResult<CodeAction?>(new FixAllCodeAction(fixAllContext));
198+
199+
private sealed class FixAllCodeAction : CodeAction
200+
{
201+
private readonly FixAllContext _fixAllContext;
202+
203+
public FixAllCodeAction(FixAllContext fixAllContext)
204+
=> _fixAllContext = fixAllContext;
205+
206+
public override string Title => CodeFixResources.PassCancellationTokenFix;
207+
208+
public override string? EquivalenceKey => nameof(FlowTestContextCancellationTokenFixer);
209+
210+
protected override async Task<Solution?> GetChangedSolutionAsync(CancellationToken cancellationToken)
211+
{
212+
FixAllContext fixAllContext = _fixAllContext;
213+
var editor = new SolutionEditor(fixAllContext.Solution);
214+
var fixedSymbols = new HashSet<ISymbol>(SymbolEqualityComparer.Default);
215+
216+
if (fixAllContext.Scope == FixAllScope.Document)
217+
{
218+
DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(fixAllContext.Document!.Id, cancellationToken).ConfigureAwait(false);
219+
foreach (Diagnostic diagnostic in await fixAllContext.GetDocumentDiagnosticsAsync(fixAllContext.Document!).ConfigureAwait(false))
220+
{
221+
FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken);
222+
}
223+
}
224+
else if (fixAllContext.Scope == FixAllScope.Project)
225+
{
226+
await FixAllInProjectAsync(fixAllContext, fixAllContext.Project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false);
227+
}
228+
else if (fixAllContext.Scope == FixAllScope.Solution)
229+
{
230+
foreach (Project project in fixAllContext.Solution.Projects)
231+
{
232+
await FixAllInProjectAsync(fixAllContext, project, editor, fixedSymbols, cancellationToken).ConfigureAwait(false);
233+
}
234+
}
235+
236+
return editor.GetChangedSolution();
237+
}
238+
239+
private static async Task FixAllInProjectAsync(FixAllContext fixAllContext, Project project, SolutionEditor editor, HashSet<ISymbol> fixedSymbols, CancellationToken cancellationToken)
240+
{
241+
foreach (Diagnostic diagnostic in await fixAllContext.GetAllDiagnosticsAsync(project).ConfigureAwait(false))
242+
{
243+
DocumentId documentId = editor.OriginalSolution.GetDocumentId(diagnostic.Location.SourceTree)!;
244+
DocumentEditor documentEditor = await editor.GetDocumentEditorAsync(documentId, cancellationToken).ConfigureAwait(false);
245+
FixOneDiagnostic(documentEditor, diagnostic, fixedSymbols, cancellationToken);
246+
}
247+
}
248+
249+
private static void FixOneDiagnostic(DocumentEditor documentEditor, Diagnostic diagnostic, HashSet<ISymbol> fixedSymbols, CancellationToken cancellationToken)
250+
{
251+
SyntaxNode node = documentEditor.OriginalRoot.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true);
252+
if (node is not InvocationExpressionSyntax invocationExpression)
253+
{
254+
return;
255+
}
256+
257+
diagnostic.Properties.TryGetValue(FlowTestContextCancellationTokenAnalyzer.TestContextMemberNamePropertyKey, out string? testContextMemberName);
258+
diagnostic.Properties.TryGetValue(nameof(FlowTestContextCancellationTokenAnalyzer.TestContextState), out string? testContextState);
259+
260+
FlowTestContextCancellationTokenFixer.ApplyFix(documentEditor, invocationExpression, testContextMemberName, testContextState, fixedSymbols, cancellationToken);
261+
}
84262
}
85263
}

src/Analyzers/MSTest.Analyzers/FlowTestContextCancellationTokenAnalyzer.cs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ private static void AnalyzeInvocation(
7474
IMethodSymbol method = invocationOperation.TargetMethod;
7575

7676
// Check if we're in a context where a TestContext is already available or could be made available.
77-
if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope))
77+
if (!HasOrCouldHaveTestContextInScope(context.ContainingSymbol, testContextSymbol, classCleanupAttributeSymbol, assemblyCleanupAttributeSymbol, testMethodAttributeSymbol, out string? testContextMemberNameInScope, out TestContextState? testContextState))
7878
{
7979
return;
8080
}
@@ -93,13 +93,7 @@ private static void AnalyzeInvocation(
9393
invocationOperation.Arguments.FirstOrDefault(arg => SymbolEqualityComparer.Default.Equals(arg.Parameter, cancellationTokenParameter))?.ArgumentKind != ArgumentKind.Explicit)
9494
{
9595
// The called method has an optional CancellationToken parameter, but it was not explicitly provided.
96-
ImmutableDictionary<string, string?> properties = ImmutableDictionary<string, string?>.Empty;
97-
if (testContextMemberNameInScope is not null)
98-
{
99-
properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope);
100-
}
101-
102-
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope)));
96+
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState)));
10397
return;
10498
}
10599

@@ -108,16 +102,15 @@ private static void AnalyzeInvocation(
108102
if (cancellationTokenParameter is null &&
109103
HasOverloadWithCancellationToken(method, cancellationTokenSymbol))
110104
{
111-
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope)));
105+
context.ReportDiagnostic(invocationOperation.Syntax.CreateDiagnostic(FlowTestContextCancellationTokenRule, properties: GetPropertiesBag(testContextMemberNameInScope, testContextState)));
112106
}
113107

114-
static ImmutableDictionary<string, string?> GetPropertiesBag(string? testContextMemberNameInScope)
108+
static ImmutableDictionary<string, string?> GetPropertiesBag(string? testContextMemberNameInScope, TestContextState? testContextState)
115109
{
116110
ImmutableDictionary<string, string?> properties = ImmutableDictionary<string, string?>.Empty;
117-
if (testContextMemberNameInScope is not null)
118-
{
119-
properties = properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope);
120-
}
111+
properties = testContextMemberNameInScope is not null
112+
? properties.Add(TestContextMemberNamePropertyKey, testContextMemberNameInScope)
113+
: properties.Add(nameof(TestContextState), testContextState.ToString());
121114

122115
return properties;
123116
}
@@ -151,9 +144,11 @@ private static bool HasOrCouldHaveTestContextInScope(
151144
INamedTypeSymbol classCleanupAttributeSymbol,
152145
INamedTypeSymbol assemblyCleanupAttributeSymbol,
153146
INamedTypeSymbol testMethodAttributeSymbol,
154-
out string? testContextMemberNameInScope)
147+
out string? testContextMemberNameInScope,
148+
[NotNullWhen(true)] out TestContextState? testContextState)
155149
{
156150
testContextMemberNameInScope = null;
151+
testContextState = null;
157152

158153
if (containingSymbol is not IMethodSymbol method)
159154
{
@@ -164,6 +159,7 @@ private static bool HasOrCouldHaveTestContextInScope(
164159
if (method.Parameters.FirstOrDefault(p => testContextSymbol.Equals(p.Type, SymbolEqualityComparer.Default)) is { } testContextParameter)
165160
{
166161
testContextMemberNameInScope = testContextParameter.Name;
162+
testContextState = TestContextState.InScope;
167163
return true;
168164
}
169165

@@ -178,6 +174,7 @@ private static bool HasOrCouldHaveTestContextInScope(
178174
testContextMemberNameInScope = testContextMember.Name.StartsWith('<') && testContextMember.Name.EndsWith(">P", StringComparison.Ordinal)
179175
? testContextMember.Name.Substring(1, testContextMember.Name.Length - 3)
180176
: testContextMember.Name;
177+
testContextState = TestContextState.InScope;
181178
return true;
182179
}
183180

@@ -191,11 +188,13 @@ private static bool HasOrCouldHaveTestContextInScope(
191188
(classCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default) ||
192189
assemblyCleanupAttributeSymbol.Equals(attribute.AttributeClass, SymbolEqualityComparer.Default)))
193190
{
191+
testContextState = TestContextState.CouldBeInScopeAsParameter;
194192
return true;
195193
}
196194

197195
if (attribute.AttributeClass?.Inherits(testMethodAttributeSymbol) == true)
198196
{
197+
testContextState = TestContextState.CouldBeInScopeAsProperty;
199198
return true;
200199
}
201200
}
@@ -228,4 +227,11 @@ private static bool IsCompatibleOverloadWithCancellationToken(IMethodSymbol orig
228227
IParameterSymbol lastParam = candidateParams[candidateParams.Length - 1];
229228
return SymbolEqualityComparer.Default.Equals(lastParam.Type, cancellationTokenSymbol);
230229
}
230+
231+
internal enum TestContextState
232+
{
233+
InScope,
234+
CouldBeInScopeAsParameter,
235+
CouldBeInScopeAsProperty,
236+
}
231237
}

0 commit comments

Comments
 (0)