Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pyiceberg import __version__
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
from pyiceberg.catalog.rest.credentials_provider import REFRESH_CREDENTIALS_ENABLED, VendedCredentialsProvider
from pyiceberg.catalog.rest.response import _handle_non_200_response
from pyiceberg.catalog.rest.scan_planning import (
FetchScanTasksRequest,
Expand Down Expand Up @@ -484,7 +485,10 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | Non
merged_properties = {**self.properties, **properties}
if self._auth_manager:
merged_properties[AUTH_MANAGER] = self._auth_manager
return load_file_io(merged_properties, location)
file_io = load_file_io(merged_properties, location)
if property_as_bool(merged_properties, REFRESH_CREDENTIALS_ENABLED, False):
file_io.set_credentials_provider(VendedCredentialsProvider(self._session, merged_properties))
return file_io

@override
def supports_server_side_planning(self) -> bool:
Expand Down
126 changes: 126 additions & 0 deletions pyiceberg/catalog/rest/credentials_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
#
from datetime import datetime

from pydantic import Field
from requests import HTTPError, Session

from pyiceberg.catalog import URI
from pyiceberg.catalog.rest.response import _handle_non_200_response
from pyiceberg.catalog.rest.scan_planning import StorageCredential
from pyiceberg.exceptions import ValidationError, ValidationException
from pyiceberg.io import (
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN,
S3_ACCESS_KEY_ID,
S3_SECRET_ACCESS_KEY,
S3_SESSION_TOKEN,
)
from pyiceberg.typedef import IcebergBaseModel, Properties
from pyiceberg.utils.properties import get_first_property_value

S3_SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms"
CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint"
REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled"


class LoadCredentialsResponse(IcebergBaseModel):
credentials: list[StorageCredential] = Field(alias="storage-credentials")


class VendedCredentialsProvider:
_session: Session
_properties: Properties

def __init__(self, session: Session, properties: Properties):
self._session = session
self._properties = properties

def _extract_s3_credentials_from(self, props: Properties) -> tuple[str | None, str | None, str | None, str | None]:
"""Extract only S3 credentials from properties."""
access_key = get_first_property_value(props, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID)
secret_key = get_first_property_value(props, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY)
session_token = get_first_property_value(props, S3_SESSION_TOKEN, AWS_SESSION_TOKEN)
expiry = get_first_property_value(props, S3_SESSION_TOKEN_EXPIRES_AT_MS)

return access_key, secret_key, session_token, expiry

def _to_credentials_property_map(
self, access_key: str | None, secret_key: str | None, session_token: str | None, expiry: str | None
) -> Properties:
return {
S3_ACCESS_KEY_ID: access_key,
S3_SECRET_ACCESS_KEY: secret_key,
S3_SESSION_TOKEN: session_token,
S3_SESSION_TOKEN_EXPIRES_AT_MS: expiry,
}

def needs_refresh(self) -> bool:
"""Return True if the S3 session token expires within 300s."""
expiry = get_first_property_value(self._properties, S3_SESSION_TOKEN_EXPIRES_AT_MS)
if expiry is None:
return False
expires_at = datetime.fromtimestamp(int(expiry) / 1000)
seconds_remaining = (expires_at - datetime.now()).total_seconds()
return seconds_remaining < 300

def _build_refresh_endpoint(self) -> str:
"""Build credential refresh endpoint from properties."""
catalog_uri = get_first_property_value(self._properties, URI)
credentials_path = get_first_property_value(self._properties, CREDENTIALS_ENDPOINT)

if catalog_uri is None:
raise ValidationException("Invalid catalog endpoint: None")

if credentials_path is None:
raise ValidationException("Invalid credentials endpoint: None")

return str(catalog_uri).rstrip("/") + "/" + str(credentials_path).lstrip("/")

def _get_new_credentials(self) -> LoadCredentialsResponse | None:
try:
http_response = self._session.get(self._build_refresh_endpoint())
http_response.raise_for_status()
return LoadCredentialsResponse.model_validate_json(http_response.text)
except HTTPError as exc:
_handle_non_200_response(exc, {})
return None

def get_credentials(self) -> Properties:
"""Retrieve current S3 credentials, refreshing from the endpoint if near expiry."""
access_key, secret_key, session_token, expiry = self._extract_s3_credentials_from(self._properties)

if not self.needs_refresh():
return self._to_credentials_property_map(access_key, secret_key, session_token, expiry)

creds = self._get_new_credentials()

if creds is None:
raise ValidationError("Load credential response is None")
if not creds.credentials:
raise ValueError("Invalid S3 Credentials: empty")
if len(creds.credentials) > 1:
raise ValueError("Invalid S3 Credentials: only one S3 credential should exists")

updated_creds = self._extract_s3_credentials_from(creds.credentials[0].config)
updated_map = self._to_credentials_property_map(*updated_creds)

# Update internal properties with new credentials
self._properties = {**self._properties, **updated_map}

return updated_map
11 changes: 11 additions & 0 deletions pyiceberg/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@
from io import SEEK_SET
from types import TracebackType
from typing import (
TYPE_CHECKING,
Protocol,
runtime_checkable,
)

if TYPE_CHECKING:
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
from urllib.parse import urlparse

from pyiceberg.typedef import EMPTY_DICT, Properties
Expand Down Expand Up @@ -291,6 +295,13 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
FileNotFoundError: When the file at the provided location does not exist.
"""

def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None: # noqa: B027
"""Inject a credentials provider for refreshing vended storage credentials.

Args:
provider (VendedCredentialsProvider): A concrete type of VendedCredentialsProvider (e.g S3VendedCredentialsProvider)
"""


LOCATION = "location"
WAREHOUSE = "warehouse"
Expand Down
21 changes: 18 additions & 3 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from pyiceberg.catalog import TOKEN, URI
from pyiceberg.catalog.rest.auth import AUTH_MANAGER
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
from pyiceberg.exceptions import SignError
from pyiceberg.io import (
ADLS_ACCOUNT_HOST,
Expand Down Expand Up @@ -166,9 +167,12 @@ def _file(_: Properties) -> LocalFileSystem:
return LocalFileSystem(auto_mkdir=True)


def _s3(properties: Properties) -> AbstractFileSystem:
def _s3(properties: Properties, cred_provider: VendedCredentialsProvider | None) -> AbstractFileSystem:
from s3fs import S3FileSystem

if cred_provider is not None and cred_provider.needs_refresh():
properties = {**properties, **cred_provider.get_credentials()}

client_kwargs = {
"endpoint_url": properties.get(S3_ENDPOINT),
"aws_access_key_id": get_first_property_value(properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
Expand Down Expand Up @@ -319,6 +323,7 @@ def _hf(properties: Properties) -> AbstractFileSystem:
}

_ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"})
_S3_SCHEMES = frozenset({"s3", "s3a", "s3n"})


class FsspecInputFile(InputFile):
Expand Down Expand Up @@ -430,8 +435,12 @@ class FsspecFileIO(FileIO):
def __init__(self, properties: Properties):
self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS)
self._thread_locals = threading.local()
self._credentials_provider: VendedCredentialsProvider | None = None
super().__init__(properties=properties)

def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None:
self._credentials_provider = provider

@override
def new_input(self, location: str) -> FsspecInputFile:
"""Get an FsspecInputFile instance to read bytes from the file at the given location.
Expand Down Expand Up @@ -486,9 +495,12 @@ def _get_fs_from_uri(self, uri: "ParseResult") -> AbstractFileSystem:

def get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme, cached per thread."""
if not hasattr(self._thread_locals, "get_fs_cached"):
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)
# If we have available a CredentialProvider and we detect that the tokens need to be refreshed
# then invalidate the cached fileio in order to get a new fileio with the fresh credentials
needs_refresh = self._credentials_provider and self._credentials_provider.needs_refresh()

if not hasattr(self._thread_locals, "get_fs_cached") or needs_refresh:
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)
return self._thread_locals.get_fs_cached(scheme, hostname)

def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
Expand All @@ -499,6 +511,9 @@ def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSyste
if scheme in _ADLS_SCHEMES:
return _adls(self.properties, hostname)

if scheme in _S3_SCHEMES:
return _s3(self.properties, self._credentials_provider)

return self._scheme_to_fs[scheme](self.properties)

def __getstate__(self) -> dict[str, Any]:
Expand Down
17 changes: 15 additions & 2 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string

if TYPE_CHECKING:
from pyiceberg.catalog.rest.credentials_provider import VendedCredentialsProvider
from pyiceberg.table import FileScanTask, WriteTask

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -394,8 +395,20 @@ class PyArrowFileIO(FileIO):

def __init__(self, properties: Properties = EMPTY_DICT):
self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs)
self._credentials_provider: VendedCredentialsProvider | None = None
super().__init__(properties=properties)

def set_credentials_provider(self, provider: VendedCredentialsProvider) -> None:
self._credentials_provider = provider

def _get_fs(self, scheme: str, netloc: str | None) -> FileSystem:
# If we have available a CredentialProvider and we detect that the tokens need to be refreshed
# then invalidate the cached fileio in order to get a new fileio with the fresh credentials
if self._credentials_provider and self._credentials_provider.needs_refresh():
self.properties = {**self.properties, **self._credentials_provider.get_credentials()}
self.fs_by_scheme = lru_cache(self._initialize_fs)
return self.fs_by_scheme(scheme, netloc)

@staticmethod
def parse_location(location: str, properties: Properties = EMPTY_DICT) -> tuple[str, str, str]:
"""Return (scheme, netloc, path) for the given location.
Expand Down Expand Up @@ -628,7 +641,7 @@ def new_input(self, location: str) -> PyArrowFile:
"""
scheme, netloc, path = self.parse_location(location, self.properties)
return PyArrowFile(
fs=self.fs_by_scheme(scheme, netloc),
fs=self._get_fs(scheme, netloc),
location=location,
path=path,
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
Expand All @@ -646,7 +659,7 @@ def new_output(self, location: str) -> PyArrowFile:
"""
scheme, netloc, path = self.parse_location(location, self.properties)
return PyArrowFile(
fs=self.fs_by_scheme(scheme, netloc),
fs=self._get_fs(scheme, netloc),
location=location,
path=path,
buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
Expand Down
Loading