Skip to content

Commit 6c15e61

Browse files
niklub, nik, and makseq
authored and committed
fix: BROS-446: Add recursive scan to storages (#8506) (#8532)
Co-authored-by: nik <nik@heartex.net>
Co-authored-by: niklub <niklub@users.noreply.github.com>
Co-authored-by: makseq <makseq@gmail.com>
1 parent 0a920e3 commit 6c15e61

2 files changed

Lines changed: 16 additions & 69 deletions

File tree

label_studio/io_storages/tests/test_multitask_import.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ def test_import_multiple_tasks_s3(project, common_task_data):
117117

118118
def test_import_multiple_tasks_gcs(project, common_task_data):
119119
# initialize mock with sample data
120-
with gcs_client_mock():
120+
with gcs_client_mock(sample_blob_names=['test.json']):
121121
_test_storage_import(
122122
project,
123123
GCSImportStorageFactory,

label_studio/tests/utils.py

Lines changed: 15 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
22
"""
3-
import logging
43
import os.path
54
import re
65
import tempfile
76
from contextlib import contextmanager
8-
from copy import deepcopy
97
from functools import wraps
108
from pathlib import Path
119
from types import SimpleNamespace
@@ -31,7 +29,6 @@
3129
from businesses.models import BillingPlan, Business
3230
except ImportError:
3331
BillingPlan = Business = None
34-
logger = logging.getLogger(__name__)
3532

3633

3734
@contextmanager
@@ -86,34 +83,13 @@ def email_mock():
8683

8784

8885
@contextmanager
89-
def gcs_client_mock():
90-
# be careful, this is a global contextmanager (sample_blob_names)
91-
# and will affect all tests because it will be applied to all tests that use gcs_client
92-
# it may lead to flaky tests if the sample blob names are not deterministic
93-
86+
def gcs_client_mock(sample_blob_names=None):
9487
from collections import namedtuple
9588

9689
from google.cloud import storage as google_storage
9790

98-
def get_sample_blob_names_for_bucket(bucket_name):
99-
# Bucket-specific logic to avoid test bleed
100-
if bucket_name in ['pytest-recursive-scan-bucket']:
101-
result = ['dataset/', 'dataset/a.json', 'dataset/sub/b.json', 'other/c.json']
102-
logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (recursive scan bucket)')
103-
return result
104-
elif bucket_name.startswith('multitask_'):
105-
result = ['test.json']
106-
logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (multitask)')
107-
return result
108-
elif bucket_name.startswith('test-gs-bucket'):
109-
# Force deterministic samples for standard GCS test buckets - never use closure variable
110-
result = ['abc', 'def', 'ghi']
111-
logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (test-gs-bucket prefix)')
112-
return result
113-
else:
114-
result = ['abc', 'def', 'ghi']
115-
logger.info(f'get_sample_blob_names_for_bucket({bucket_name}) -> {result} (default)')
116-
return result
91+
File = namedtuple('File', ['name'])
92+
sample_blob_names = sample_blob_names or ['abc', 'def', 'ghi']
11793

11894
class DummyGCSBlob:
11995
def __init__(self, bucket_name, key, is_json, is_multitask):
@@ -138,40 +114,28 @@ def __init__(self, bucket_name, key, is_json, is_multitask):
138114
def download_as_string(self):
139115
data = f'test_blob_{self.key}'
140116
if self.is_json:
141-
payload = json.dumps(self.sample_json_contents)
142-
logger.info(
143-
f'DummyGCSBlob.download_as_string bucket={self.bucket_name} key={self.key} json=True bytes={len(payload)}'
144-
)
145-
return payload
146-
logger.info(f'DummyGCSBlob.download_as_string bucket={self.bucket_name} key={self.key} json=False')
117+
return json.dumps(self.sample_json_contents)
147118
return data
148119

149120
def upload_from_string(self, string):
150121
print(f'String {string} uploaded to bucket {self.bucket_name}')
151122

152123
def generate_signed_url(self, **kwargs):
153-
url = f'https://storage.googleapis.com/{self.bucket_name}/{self.key}'
154-
logger.info(f'DummyGCSBlob.generate_signed_url url={url}')
155-
return url
124+
return f'https://storage.googleapis.com/{self.bucket_name}/{self.key}'
156125

157126
def download_as_bytes(self):
158-
b = self.download_as_string().encode('utf-8')
159-
logger.info(f'DummyGCSBlob.download_as_bytes bucket={self.bucket_name} key={self.key} size={len(b)}')
160-
return b
127+
return self.download_as_string().encode('utf-8')
161128

162129
class DummyGCSBucket:
163130
def __init__(self, bucket_name, is_json, is_multitask):
164131
self.name = bucket_name
165132
self.is_json = is_json
166133
self.is_multitask = is_multitask
167-
# Use bucket-specific sample names
168-
self.sample_blob_names = get_sample_blob_names_for_bucket(bucket_name)
134+
# Share the outer sample names for bucket-scoped listing
135+
self.sample_blob_names = sample_blob_names
169136

170137
def list_blobs(self, prefix, **kwargs):
171-
File = namedtuple('File', ['name'])
172-
173138
if 'fake' in prefix:
174-
logger.info(f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} -> [] (fake)')
175139
return []
176140

177141
# Handle delimiter for non-recursive listing (only direct children)
@@ -189,31 +153,25 @@ def list_blobs(self, prefix, **kwargs):
189153
else:
190154
# Root-level: only keys without delimiter are direct children
191155
filtered_names = [name for name in self.sample_blob_names if delimiter not in name]
192-
logger.info(
193-
f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} delimiter={delimiter} -> {filtered_names}'
194-
)
195156
return [File(name) for name in filtered_names]
196-
result = [name for name in self.sample_blob_names if prefix is None or name.startswith(prefix)]
197-
logger.info(f'DummyGCSBucket.list_blobs bucket={self.name} prefix={prefix} -> {result}')
198-
return [File(name) for name in result]
157+
return [File(name) for name in self.sample_blob_names if prefix is None or name.startswith(prefix)]
199158

200159
def blob(self, key):
201-
logger.info(f'DummyGCSBucket.blob bucket={self.name} key={key}')
202160
return DummyGCSBlob(self.name, key, self.is_json, self.is_multitask)
203161

204162
class DummyGCSClient:
163+
def __init__(self, sample_json_contents=None):
164+
self.sample_blob_names = sample_blob_names
165+
205166
def get_bucket(self, bucket_name):
206167
is_json = bucket_name.endswith('_JSON')
207168
is_multitask = bucket_name.startswith('multitask_')
208-
logger.info(
209-
f'DummyGCSClient.get_bucket bucket={bucket_name} is_json={is_json} is_multitask={is_multitask}'
210-
)
211169
return DummyGCSBucket(bucket_name, is_json, is_multitask)
212170

213171
def list_blobs(self, bucket_name, prefix, delimiter=None):
214172
is_json = bucket_name.endswith('_JSON')
215173
is_multitask = bucket_name.startswith('multitask_')
216-
sample_blob_names = get_sample_blob_names_for_bucket(bucket_name)
174+
sample_blob_names = ['test.json'] if is_multitask else self.sample_blob_names
217175

218176
# Handle delimiter for non-recursive listing (only direct children)
219177
if delimiter:
@@ -229,30 +187,20 @@ def list_blobs(self, bucket_name, prefix, delimiter=None):
229187
else:
230188
# Root-level: only keys without delimiter are direct children
231189
filtered_names = [name for name in sample_blob_names if delimiter not in name]
232-
logger.info(
233-
f'DummyGCSClient.list_blobs bucket={bucket_name} prefix={prefix} delimiter={delimiter} -> {filtered_names}'
234-
)
235190
return [DummyGCSBlob(bucket_name, name, is_json, is_multitask) for name in filtered_names]
236191

237-
result = [name for name in sample_blob_names if prefix is None or name.startswith(prefix)]
238-
logger.info(f'DummyGCSClient.list_blobs bucket={bucket_name} prefix={prefix} -> {result}')
239192
return [
240193
DummyGCSBlob(bucket_name, name, is_json, is_multitask)
241194
for name in sample_blob_names
242195
if prefix is None or name.startswith(prefix)
243196
]
244197

245198
with mock.patch.object(google_storage, 'Client', return_value=DummyGCSClient()):
246-
logger.info('gcs_client_mock installed')
247199
yield google_storage
248200

249201

250202
@contextmanager
251203
def azure_client_mock(sample_json_contents=None, sample_blob_names=None):
252-
# be careful, this is a global contextmanager (sample_json_contents, sample_blob_names)
253-
# and will affect all tests because it will be applied to all tests that use azure_client
254-
# and it may lead to flaky tests if the sample blob names are not deterministic
255-
256204
from collections import namedtuple
257205

258206
from io_storages.azure_blob import models
@@ -289,13 +237,12 @@ def content_as_bytes(self):
289237
class DummyAzureContainer:
290238
def __init__(self, container_name, **kwargs):
291239
self.name = container_name
292-
self.sample_blob_names = deepcopy(sample_blob_names)
293240

294241
def list_blobs(self, name_starts_with):
295-
return [File(name) for name in self.sample_blob_names]
242+
return [File(name) for name in sample_blob_names]
296243

297244
def walk_blobs(self, name_starts_with, delimiter):
298-
return [File(name) for name in self.sample_blob_names]
245+
return [File(name) for name in sample_blob_names]
299246

300247
def get_blob_client(self, key):
301248
return DummyAzureBlob(self.name, key)

0 commit comments

Comments
 (0)