1616
1717using MSTest . Analyzers . Helpers ;
1818
19+ using Polyfills ;
20+
1921namespace MSTest . Analyzers ;
2022
2123/// <summary>
@@ -31,8 +33,8 @@ public sealed class FlowTestContextCancellationTokenFixer : CodeFixProvider
3133
3234 /// <inheritdoc />
3335 public override FixAllProvider GetFixAllProvider ( )
34- // See https://github.com/dotnet/roslyn/blob/main/docs/analyzers/ FixAllProvider.md for more information on Fix All Providers
35- => WellKnownFixAllProviders . BatchFixer ;
36+ // Use custom FixAllProvider to handle adding TestContext property when needed
37+ => FlowTestContextCancellationTokenFixAllProvider . Instance ;
3638
3739 /// <inheritdoc />
3840 public sealed override async Task RegisterCodeFixesAsync ( CodeFixContext context )
@@ -49,37 +51,213 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
4951 }
5052
5153 diagnostic . Properties . TryGetValue ( FlowTestContextCancellationTokenAnalyzer . TestContextMemberNamePropertyKey , out string ? testContextMemberName ) ;
54+ diagnostic . Properties . TryGetValue ( nameof ( FlowTestContextCancellationTokenAnalyzer . TestContextState ) , out string ? testContextState ) ;
5255
5356 // Register a code action that will invoke the fix
5457 context . RegisterCodeFix (
5558 CodeAction . Create (
5659 title : CodeFixResources . PassCancellationTokenFix ,
57- createChangedDocument : c => AddCancellationTokenParameterAsync ( context . Document , invocationExpression , testContextMemberName , c ) ,
58- equivalenceKey : "AddTestContextCancellationToken" ) ,
60+ createChangedDocument : async c =>
61+ {
62+ DocumentEditor editor = await DocumentEditor . CreateAsync ( context . Document , context . CancellationToken ) . ConfigureAwait ( false ) ;
63+ return ApplyFix ( editor , invocationExpression , testContextMemberName , testContextState , adjustedSymbols : null , c ) ;
64+ } ,
65+ equivalenceKey : nameof ( FlowTestContextCancellationTokenFixer ) ) ,
5966 diagnostic ) ;
6067 }
6168
62- private static async Task < Document > AddCancellationTokenParameterAsync (
63- Document document ,
69+ internal static Document ApplyFix (
70+ DocumentEditor editor ,
6471 InvocationExpressionSyntax invocationExpression ,
6572 string ? testContextMemberName ,
73+ string ? testContextState ,
74+ HashSet < ISymbol > ? adjustedSymbols ,
6675 CancellationToken cancellationToken )
6776 {
68- DocumentEditor editor = await DocumentEditor . CreateAsync ( document , cancellationToken ) . ConfigureAwait ( false ) ;
77+ if ( testContextState == nameof ( FlowTestContextCancellationTokenAnalyzer . TestContextState . CouldBeInScopeAsProperty ) )
78+ {
79+ Debug . Assert ( testContextMemberName is null , "TestContext member name should be null when state is CouldBeInScopeAsProperty" ) ;
80+ AddCancellationTokenArgument ( editor , invocationExpression , "TestContext" ) ;
81+ TypeDeclarationSyntax ? containingTypeDeclaration = invocationExpression . FirstAncestorOrSelf < TypeDeclarationSyntax > ( ) ;
82+ if ( containingTypeDeclaration is not null )
83+ {
84+ // adjustedSymbols is null meaning we are only applying a single fix (in that case we add the property).
85+ // If we are in fix all, we then verify if a previous fix has already added the property.
86+ // We only add the property if it wasn't added by a previous fix.
87+ // NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
88+ if ( adjustedSymbols is null ||
89+ editor . SemanticModel . GetDeclaredSymbol ( containingTypeDeclaration , cancellationToken ) is not { } symbol ||
90+ adjustedSymbols . Add ( symbol ) )
91+ {
92+ editor . ReplaceNode ( containingTypeDeclaration , ( containingTypeDeclaration , _ ) => AddTestContextProperty ( ( TypeDeclarationSyntax ) containingTypeDeclaration ) ) ;
93+ }
94+ }
95+ }
96+ else if ( testContextState == nameof ( FlowTestContextCancellationTokenAnalyzer . TestContextState . CouldBeInScopeAsParameter ) )
97+ {
98+ Debug . Assert ( testContextMemberName is null , "TestContext member name should be null when state is CouldBeInScopeAsParameter" ) ;
99+ AddCancellationTokenArgument ( editor , invocationExpression , "testContext" ) ;
100+ MethodDeclarationSyntax ? containingMethodDeclaration = invocationExpression . FirstAncestorOrSelf < MethodDeclarationSyntax > ( ) ;
101+
102+ if ( containingMethodDeclaration is not null )
103+ {
104+ // adjustedSymbols is null meaning we are only applying a single fix (in that case we add the parameter).
105+ // If we are in fix all, we then verify if a previous fix has already added the parameter.
106+ // We only add the parameter if it wasn't added by a previous fix.
107+ // NOTE: We don't expect GetDeclaredSymbol to return null, but if it did (e.g, error scenario), we add the property.
108+ if ( adjustedSymbols is null ||
109+ editor . SemanticModel . GetDeclaredSymbol ( containingMethodDeclaration , cancellationToken ) is not { } symbol ||
110+ adjustedSymbols . Add ( symbol ) )
111+ {
112+ editor . ReplaceNode ( containingMethodDeclaration , ( containingMethodDeclaration , _ ) => AddTestContextParameterToMethod ( ( MethodDeclarationSyntax ) containingMethodDeclaration ) ) ;
113+ }
114+ }
115+ }
116+ else
117+ {
118+ Guard . NotNull ( testContextMemberName ) ;
119+ AddCancellationTokenArgument ( editor , invocationExpression , testContextMemberName ) ;
120+ }
121+
122+ return editor . GetChangedDocument ( ) ;
123+ }
124+
125+ internal static void AddCancellationTokenArgument (
126+ DocumentEditor editor ,
127+ InvocationExpressionSyntax invocationExpression ,
128+ string testContextMemberName )
129+ {
130+ // Find the containing method to determine the context
131+ MethodDeclarationSyntax ? containingMethod = invocationExpression . FirstAncestorOrSelf < MethodDeclarationSyntax > ( ) ;
69132
70133 // Create the TestContext.CancellationTokenSource.Token expression
71134 MemberAccessExpressionSyntax testContextExpression = SyntaxFactory . MemberAccessExpression (
72135 SyntaxKind . SimpleMemberAccessExpression ,
73136 SyntaxFactory . MemberAccessExpression (
74137 SyntaxKind . SimpleMemberAccessExpression ,
75- SyntaxFactory . IdentifierName ( testContextMemberName ?? "testContext" ) ,
138+ SyntaxFactory . IdentifierName ( testContextMemberName ) ,
76139 SyntaxFactory . IdentifierName ( "CancellationTokenSource" ) ) ,
77140 SyntaxFactory . IdentifierName ( "Token" ) ) ;
78141
79- ArgumentListSyntax currentArguments = invocationExpression . ArgumentList ;
80- SeparatedSyntaxList < ArgumentSyntax > newArguments = currentArguments . Arguments . Add ( SyntaxFactory . Argument ( testContextExpression ) ) ;
81- InvocationExpressionSyntax newInvocation = invocationExpression . WithArgumentList ( currentArguments . WithArguments ( newArguments ) ) ;
82- editor . ReplaceNode ( invocationExpression , newInvocation ) ;
83- return editor . GetChangedDocument ( ) ;
142+ editor . ReplaceNode ( invocationExpression , ( node , _ ) =>
143+ {
144+ var invocationExpression = ( InvocationExpressionSyntax ) node ;
145+ ArgumentListSyntax currentArguments = invocationExpression . ArgumentList ;
146+ SeparatedSyntaxList < ArgumentSyntax > newArguments = currentArguments . Arguments . Add ( SyntaxFactory . Argument ( testContextExpression ) ) ;
147+ return invocationExpression . WithArgumentList ( currentArguments . WithArguments ( newArguments ) ) ;
148+ } ) ;
149+ }
150+
151+ internal static MethodDeclarationSyntax AddTestContextParameterToMethod ( MethodDeclarationSyntax method )
152+ {
153+ // Create TestContext parameter
154+ ParameterSyntax testContextParameter = SyntaxFactory . Parameter ( SyntaxFactory . Identifier ( "testContext" ) )
155+ . WithType ( SyntaxFactory . IdentifierName ( "TestContext" ) ) ;
156+
157+ // Add the parameter to the method
158+ SeparatedSyntaxList < ParameterSyntax > updatedParameterList = method . ParameterList . Parameters . Count == 0
159+ ? SyntaxFactory . SingletonSeparatedList ( testContextParameter )
160+ : method . ParameterList . Parameters . Add ( testContextParameter ) ;
161+
162+ return method . WithParameterList ( method . ParameterList . WithParameters ( updatedParameterList ) ) ;
163+ }
164+
165+ internal static TypeDeclarationSyntax AddTestContextProperty ( TypeDeclarationSyntax typeDeclaration )
166+ {
167+ PropertyDeclarationSyntax testContextProperty = SyntaxFactory . PropertyDeclaration (
168+ SyntaxFactory . IdentifierName ( "TestContext" ) ,
169+ "TestContext" )
170+ . WithModifiers ( SyntaxFactory . TokenList ( SyntaxFactory . Token ( SyntaxKind . PublicKeyword ) ) )
171+ . WithAccessorList ( SyntaxFactory . AccessorList (
172+ SyntaxFactory . List ( new [ ]
173+ {
174+ SyntaxFactory . AccessorDeclaration ( SyntaxKind . GetAccessorDeclaration )
175+ . WithSemicolonToken ( SyntaxFactory . Token ( SyntaxKind . SemicolonToken ) ) ,
176+ SyntaxFactory . AccessorDeclaration ( SyntaxKind . SetAccessorDeclaration )
177+ . WithSemicolonToken ( SyntaxFactory . Token ( SyntaxKind . SemicolonToken ) ) ,
178+ } ) ) ) ;
179+
180+ return typeDeclaration . AddMembers ( testContextProperty ) ;
181+ }
182+ }
183+
184+ /// <summary>
185+ /// Custom FixAllProvider for <see cref="FlowTestContextCancellationTokenFixer"/> that can add TestContext property when needed.
186+ /// This ensures that when multiple fixes are applied to the same class, the TestContext property is added only once.
187+ /// </summary>
188+ internal sealed class FlowTestContextCancellationTokenFixAllProvider : FixAllProvider
189+ {
190+ public static readonly FlowTestContextCancellationTokenFixAllProvider Instance = new ( ) ;
191+
192+ private FlowTestContextCancellationTokenFixAllProvider ( )
193+ {
194+ }
195+
196+ public override Task < CodeAction ? > GetFixAsync ( FixAllContext fixAllContext )
197+ => Task . FromResult < CodeAction ? > ( new FixAllCodeAction ( fixAllContext ) ) ;
198+
199+ private sealed class FixAllCodeAction : CodeAction
200+ {
201+ private readonly FixAllContext _fixAllContext ;
202+
203+ public FixAllCodeAction ( FixAllContext fixAllContext )
204+ => _fixAllContext = fixAllContext ;
205+
206+ public override string Title => CodeFixResources . PassCancellationTokenFix ;
207+
208+ public override string ? EquivalenceKey => nameof ( FlowTestContextCancellationTokenFixer ) ;
209+
210+ protected override async Task < Solution ? > GetChangedSolutionAsync ( CancellationToken cancellationToken )
211+ {
212+ FixAllContext fixAllContext = _fixAllContext ;
213+ var editor = new SolutionEditor ( fixAllContext . Solution ) ;
214+ var fixedSymbols = new HashSet < ISymbol > ( SymbolEqualityComparer . Default ) ;
215+
216+ if ( fixAllContext . Scope == FixAllScope . Document )
217+ {
218+ DocumentEditor documentEditor = await editor . GetDocumentEditorAsync ( fixAllContext . Document ! . Id , cancellationToken ) . ConfigureAwait ( false ) ;
219+ foreach ( Diagnostic diagnostic in await fixAllContext . GetDocumentDiagnosticsAsync ( fixAllContext . Document ! ) . ConfigureAwait ( false ) )
220+ {
221+ FixOneDiagnostic ( documentEditor , diagnostic , fixedSymbols , cancellationToken ) ;
222+ }
223+ }
224+ else if ( fixAllContext . Scope == FixAllScope . Project )
225+ {
226+ await FixAllInProjectAsync ( fixAllContext , fixAllContext . Project , editor , fixedSymbols , cancellationToken ) . ConfigureAwait ( false ) ;
227+ }
228+ else if ( fixAllContext . Scope == FixAllScope . Solution )
229+ {
230+ foreach ( Project project in fixAllContext . Solution . Projects )
231+ {
232+ await FixAllInProjectAsync ( fixAllContext , project , editor , fixedSymbols , cancellationToken ) . ConfigureAwait ( false ) ;
233+ }
234+ }
235+
236+ return editor . GetChangedSolution ( ) ;
237+ }
238+
239+ private static async Task FixAllInProjectAsync ( FixAllContext fixAllContext , Project project , SolutionEditor editor , HashSet < ISymbol > fixedSymbols , CancellationToken cancellationToken )
240+ {
241+ foreach ( Diagnostic diagnostic in await fixAllContext . GetAllDiagnosticsAsync ( project ) . ConfigureAwait ( false ) )
242+ {
243+ DocumentId documentId = editor . OriginalSolution . GetDocumentId ( diagnostic . Location . SourceTree ) ! ;
244+ DocumentEditor documentEditor = await editor . GetDocumentEditorAsync ( documentId , cancellationToken ) . ConfigureAwait ( false ) ;
245+ FixOneDiagnostic ( documentEditor , diagnostic , fixedSymbols , cancellationToken ) ;
246+ }
247+ }
248+
249+ private static void FixOneDiagnostic ( DocumentEditor documentEditor , Diagnostic diagnostic , HashSet < ISymbol > fixedSymbols , CancellationToken cancellationToken )
250+ {
251+ SyntaxNode node = documentEditor . OriginalRoot . FindNode ( diagnostic . Location . SourceSpan , getInnermostNodeForTie : true ) ;
252+ if ( node is not InvocationExpressionSyntax invocationExpression )
253+ {
254+ return ;
255+ }
256+
257+ diagnostic . Properties . TryGetValue ( FlowTestContextCancellationTokenAnalyzer . TestContextMemberNamePropertyKey , out string ? testContextMemberName ) ;
258+ diagnostic . Properties . TryGetValue ( nameof ( FlowTestContextCancellationTokenAnalyzer . TestContextState ) , out string ? testContextState ) ;
259+
260+ FlowTestContextCancellationTokenFixer . ApplyFix ( documentEditor , invocationExpression , testContextMemberName , testContextState , fixedSymbols , cancellationToken ) ;
261+ }
84262 }
85263}
0 commit comments