11"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
22"""
3- import logging
43import os .path
54import re
65import tempfile
76from contextlib import contextmanager
8- from copy import deepcopy
97from functools import wraps
108from pathlib import Path
119from types import SimpleNamespace
3129 from businesses .models import BillingPlan , Business
3230except ImportError :
3331 BillingPlan = Business = None
34- logger = logging .getLogger (__name__ )
3532
3633
3734@contextmanager
@@ -86,34 +83,13 @@ def email_mock():
8683
8784
8885@contextmanager
89- def gcs_client_mock ():
90- # be careful, this is a global contextmanager (sample_blob_names)
91- # and will affect all tests because it will be applied to all tests that use gcs_client
92- # it may lead to flaky tests if the sample blob names are not deterministic
93-
86+ def gcs_client_mock (sample_blob_names = None ):
9487 from collections import namedtuple
9588
9689 from google .cloud import storage as google_storage
9790
98- def get_sample_blob_names_for_bucket (bucket_name ):
99- # Bucket-specific logic to avoid test bleed
100- if bucket_name in ['pytest-recursive-scan-bucket' ]:
101- result = ['dataset/' , 'dataset/a.json' , 'dataset/sub/b.json' , 'other/c.json' ]
102- logger .info (f'get_sample_blob_names_for_bucket({ bucket_name } ) -> { result } (recursive scan bucket)' )
103- return result
104- elif bucket_name .startswith ('multitask_' ):
105- result = ['test.json' ]
106- logger .info (f'get_sample_blob_names_for_bucket({ bucket_name } ) -> { result } (multitask)' )
107- return result
108- elif bucket_name .startswith ('test-gs-bucket' ):
109- # Force deterministic samples for standard GCS test buckets - never use closure variable
110- result = ['abc' , 'def' , 'ghi' ]
111- logger .info (f'get_sample_blob_names_for_bucket({ bucket_name } ) -> { result } (test-gs-bucket prefix)' )
112- return result
113- else :
114- result = ['abc' , 'def' , 'ghi' ]
115- logger .info (f'get_sample_blob_names_for_bucket({ bucket_name } ) -> { result } (default)' )
116- return result
91+ File = namedtuple ('File' , ['name' ])
92+ sample_blob_names = sample_blob_names or ['abc' , 'def' , 'ghi' ]
11793
11894 class DummyGCSBlob :
11995 def __init__ (self , bucket_name , key , is_json , is_multitask ):
@@ -138,40 +114,28 @@ def __init__(self, bucket_name, key, is_json, is_multitask):
138114 def download_as_string (self ):
139115 data = f'test_blob_{ self .key } '
140116 if self .is_json :
141- payload = json .dumps (self .sample_json_contents )
142- logger .info (
143- f'DummyGCSBlob.download_as_string bucket={ self .bucket_name } key={ self .key } json=True bytes={ len (payload )} '
144- )
145- return payload
146- logger .info (f'DummyGCSBlob.download_as_string bucket={ self .bucket_name } key={ self .key } json=False' )
117+ return json .dumps (self .sample_json_contents )
147118 return data
148119
149120 def upload_from_string (self , string ):
150121 print (f'String { string } uploaded to bucket { self .bucket_name } ' )
151122
152123 def generate_signed_url (self , ** kwargs ):
153- url = f'https://storage.googleapis.com/{ self .bucket_name } /{ self .key } '
154- logger .info (f'DummyGCSBlob.generate_signed_url url={ url } ' )
155- return url
124+ return f'https://storage.googleapis.com/{ self .bucket_name } /{ self .key } '
156125
157126 def download_as_bytes (self ):
158- b = self .download_as_string ().encode ('utf-8' )
159- logger .info (f'DummyGCSBlob.download_as_bytes bucket={ self .bucket_name } key={ self .key } size={ len (b )} ' )
160- return b
127+ return self .download_as_string ().encode ('utf-8' )
161128
162129 class DummyGCSBucket :
163130 def __init__ (self , bucket_name , is_json , is_multitask ):
164131 self .name = bucket_name
165132 self .is_json = is_json
166133 self .is_multitask = is_multitask
167- # Use bucket-specific sample names
168- self .sample_blob_names = get_sample_blob_names_for_bucket ( bucket_name )
134+ # Share the outer sample names for bucket-scoped listing
135+ self .sample_blob_names = sample_blob_names
169136
170137 def list_blobs (self , prefix , ** kwargs ):
171- File = namedtuple ('File' , ['name' ])
172-
173138 if 'fake' in prefix :
174- logger .info (f'DummyGCSBucket.list_blobs bucket={ self .name } prefix={ prefix } -> [] (fake)' )
175139 return []
176140
177141 # Handle delimiter for non-recursive listing (only direct children)
@@ -189,31 +153,25 @@ def list_blobs(self, prefix, **kwargs):
189153 else :
190154 # Root-level: only keys without delimiter are direct children
191155 filtered_names = [name for name in self .sample_blob_names if delimiter not in name ]
192- logger .info (
193- f'DummyGCSBucket.list_blobs bucket={ self .name } prefix={ prefix } delimiter={ delimiter } -> { filtered_names } '
194- )
195156 return [File (name ) for name in filtered_names ]
196- result = [name for name in self .sample_blob_names if prefix is None or name .startswith (prefix )]
197- logger .info (f'DummyGCSBucket.list_blobs bucket={ self .name } prefix={ prefix } -> { result } ' )
198- return [File (name ) for name in result ]
157+ return [File (name ) for name in self .sample_blob_names if prefix is None or name .startswith (prefix )]
199158
200159 def blob (self , key ):
201- logger .info (f'DummyGCSBucket.blob bucket={ self .name } key={ key } ' )
202160 return DummyGCSBlob (self .name , key , self .is_json , self .is_multitask )
203161
204162 class DummyGCSClient :
163+ def __init__ (self , sample_json_contents = None ):
164+ self .sample_blob_names = sample_blob_names
165+
205166 def get_bucket (self , bucket_name ):
206167 is_json = bucket_name .endswith ('_JSON' )
207168 is_multitask = bucket_name .startswith ('multitask_' )
208- logger .info (
209- f'DummyGCSClient.get_bucket bucket={ bucket_name } is_json={ is_json } is_multitask={ is_multitask } '
210- )
211169 return DummyGCSBucket (bucket_name , is_json , is_multitask )
212170
213171 def list_blobs (self , bucket_name , prefix , delimiter = None ):
214172 is_json = bucket_name .endswith ('_JSON' )
215173 is_multitask = bucket_name .startswith ('multitask_' )
216- sample_blob_names = get_sample_blob_names_for_bucket ( bucket_name )
174+ sample_blob_names = [ 'test.json' ] if is_multitask else self . sample_blob_names
217175
218176 # Handle delimiter for non-recursive listing (only direct children)
219177 if delimiter :
@@ -229,30 +187,20 @@ def list_blobs(self, bucket_name, prefix, delimiter=None):
229187 else :
230188 # Root-level: only keys without delimiter are direct children
231189 filtered_names = [name for name in sample_blob_names if delimiter not in name ]
232- logger .info (
233- f'DummyGCSClient.list_blobs bucket={ bucket_name } prefix={ prefix } delimiter={ delimiter } -> { filtered_names } '
234- )
235190 return [DummyGCSBlob (bucket_name , name , is_json , is_multitask ) for name in filtered_names ]
236191
237- result = [name for name in sample_blob_names if prefix is None or name .startswith (prefix )]
238- logger .info (f'DummyGCSClient.list_blobs bucket={ bucket_name } prefix={ prefix } -> { result } ' )
239192 return [
240193 DummyGCSBlob (bucket_name , name , is_json , is_multitask )
241194 for name in sample_blob_names
242195 if prefix is None or name .startswith (prefix )
243196 ]
244197
245198 with mock .patch .object (google_storage , 'Client' , return_value = DummyGCSClient ()):
246- logger .info ('gcs_client_mock installed' )
247199 yield google_storage
248200
249201
250202@contextmanager
251203def azure_client_mock (sample_json_contents = None , sample_blob_names = None ):
252- # be careful, this is a global contextmanager (sample_json_contents, sample_blob_names)
253- # and will affect all tests because it will be applied to all tests that use azure_client
254- # and it may lead to flaky tests if the sample blob names are not deterministic
255-
256204 from collections import namedtuple
257205
258206 from io_storages .azure_blob import models
@@ -289,13 +237,12 @@ def content_as_bytes(self):
289237 class DummyAzureContainer :
290238 def __init__ (self , container_name , ** kwargs ):
291239 self .name = container_name
292- self .sample_blob_names = deepcopy (sample_blob_names )
293240
294241 def list_blobs (self , name_starts_with ):
295- return [File (name ) for name in self . sample_blob_names ]
242+ return [File (name ) for name in sample_blob_names ]
296243
297244 def walk_blobs (self , name_starts_with , delimiter ):
298- return [File (name ) for name in self . sample_blob_names ]
245+ return [File (name ) for name in sample_blob_names ]
299246
300247 def get_blob_client (self , key ):
301248 return DummyAzureBlob (self .name , key )
0 commit comments