diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 256a45ca8..1e1ae37f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,37 @@ jobs: with: args: check --config ci.ruff.toml + type-checking: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install poetry + uses: abatilo/actions-poetry@v2 + - name: Set up a local virtual environment + run: | + poetry config virtualenvs.create true --local + poetry config virtualenvs.in-project true --local + - uses: actions/cache@v3 + name: Define a cache for the virtual environment based on the dependency lock file + with: + path: ./.venv + key: venv-type-check-${{ hashFiles('poetry.lock') }} + - uses: actions/cache@v3 + name: Cache the mypy cache + with: + path: ./.mypy_cache + key: mypy-${{ hashFiles('**/*.py', 'pyproject.toml') }} + restore-keys: | + mypy- + - name: Install dependencies + run: poetry install --only=main,dev + - name: Run mypy type checking + run: poetry run mypy langfuse --no-error-summary + ci: runs-on: ubuntu-latest timeout-minutes: 30 @@ -160,7 +191,7 @@ jobs: all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix runs-on: ubuntu-latest - needs: [ci, linting] + needs: [ci, linting, type-checking] if: always() steps: - name: Successful deploy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89c6e0512..45f88c352 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,12 +2,34 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.2 hooks: - # Run the linter and fix + # Run the linter and fix violations - id: ruff - types_or: [ python, pyi, jupyter ] - args: [ --fix, --config=ci.ruff.toml ] + types_or: [python, pyi, jupyter] + args: [--fix, --config=ci.ruff.toml] # Run the formatter.
- id: ruff-format - types_or: [ python, pyi, jupyter ] + types_or: [python, pyi, jupyter] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + additional_dependencies: + - types-requests + - types-setuptools + - httpx + - pydantic>=1.10.7 + - backoff>=1.10.0 + - openai>=0.27.8 + - wrapt + - packaging>=23.2 + - opentelemetry-api + - opentelemetry-sdk + - opentelemetry-exporter-otlp + - numpy + - langchain>=0.0.309 + - langchain-core + - langgraph + args: [--no-error-summary] + files: ^langfuse/ diff --git a/langfuse/_client/attributes.py b/langfuse/_client/attributes.py index d531d242b..1c22b7518 100644 --- a/langfuse/_client/attributes.py +++ b/langfuse/_client/attributes.py @@ -12,7 +12,7 @@ import json from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from langfuse._utils.serializer import EventSerializer from langfuse.model import PromptClient @@ -68,7 +68,7 @@ def create_trace_attributes( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, -): +) -> dict: attributes = { LangfuseOtelSpanAttributes.TRACE_NAME: name, LangfuseOtelSpanAttributes.TRACE_USER_ID: user_id, @@ -93,7 +93,7 @@ def create_span_attributes( level: Optional[SpanLevel] = None, status_message: Optional[str] = None, version: Optional[str] = None, -): +) -> dict: attributes = { LangfuseOtelSpanAttributes.OBSERVATION_TYPE: "span", LangfuseOtelSpanAttributes.OBSERVATION_LEVEL: level, @@ -122,7 +122,7 @@ def create_generation_attributes( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, -): +) -> dict: attributes = { LangfuseOtelSpanAttributes.OBSERVATION_TYPE: "generation", LangfuseOtelSpanAttributes.OBSERVATION_LEVEL: level, @@ -151,20 +151,20 @@ def create_generation_attributes( return {k: v for k, v in attributes.items() if v is not None} -def _serialize(obj): +def _serialize(obj: Any) -> Optional[str]: return json.dumps(obj, cls=EventSerializer) if obj is not None else None def _flatten_and_serialize_metadata( metadata: Any, type: Literal["observation", "trace"] -): +) -> dict: prefix = ( LangfuseOtelSpanAttributes.OBSERVATION_METADATA if type == "observation" else LangfuseOtelSpanAttributes.TRACE_METADATA ) - metadata_attributes = {} + metadata_attributes: Dict[str, Union[str, int, None]] = {} if not isinstance(metadata, dict): metadata_attributes[prefix] = _serialize(metadata) diff --git a/langfuse/_client/client.py b/langfuse/_client/client.py index 8311d0872..8c165ca49 100644 --- a/langfuse/_client/client.py +++ b/langfuse/_client/client.py @@ -17,6 +17,7 @@ from opentelemetry import trace from opentelemetry import trace as otel_trace_api from opentelemetry.sdk.trace.id_generator import RandomIdGenerator +from opentelemetry.sdk.trace import TracerProvider from opentelemetry.util._decorator import ( _AgnosticContextManager, _agnosticcontextmanager, @@ -146,6 +147,7 @@ class Langfuse: _resources: Optional[LangfuseResourceManager] = None _mask: Optional[MaskFunction] = None + _otel_tracer: otel_trace_api.Tracer def __init__( self, @@ -166,11 +168,15 @@ def __init__( mask: Optional[MaskFunction] = None, blocked_instrumentation_scopes: Optional[List[str]] = None, additional_headers: Optional[Dict[str, str]] = None, - tracer_provider: Optional[otel_trace_api.TracerProvider] = None, + tracer_provider: Optional[TracerProvider] = None, ): - self._host = host or 
os.environ.get(LANGFUSE_HOST, "https://cloud.langfuse.com") - self._environment = environment or os.environ.get(LANGFUSE_TRACING_ENVIRONMENT) - self._project_id = None + self._host = host or cast( + str, os.environ.get(LANGFUSE_HOST, "https://cloud.langfuse.com") + ) + self._environment = environment or cast( + str, os.environ.get(LANGFUSE_TRACING_ENVIRONMENT) + ) + self._project_id: Optional[str] = None sample_rate = sample_rate or float(os.environ.get(LANGFUSE_SAMPLE_RATE, 1.0)) if not 0.0 <= sample_rate <= 1.0: raise ValueError( @@ -236,7 +242,7 @@ def __init__( self._otel_tracer = ( self._resources.tracer - if self._tracing_enabled + if self._tracing_enabled and self._resources.tracer is not None else otel_trace_api.NoOpTracer() ) self.api = self._resources.api @@ -657,9 +663,9 @@ def start_as_current_generation( def _create_span_with_parent_context( self, *, - name, - parent, - remote_parent_span, + name: str, + parent: Optional[otel_trace_api.Span] = None, + remote_parent_span: Optional[otel_trace_api.Span] = None, as_type: Literal["generation", "span"], end_on_exit: Optional[bool] = None, input: Optional[Any] = None, @@ -674,7 +680,7 @@ def _create_span_with_parent_context( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - ): + ) -> Any: parent_span = parent or cast(otel_trace_api.Span, remote_parent_span) with otel_trace_api.use_span(parent_span): @@ -721,7 +727,7 @@ def _start_as_current_otel_span_with_processed_media( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - ): + ) -> Any: with self._otel_tracer.start_as_current_span( name=name, end_on_exit=end_on_exit if end_on_exit is not None else True, @@ -936,7 +942,7 @@ def update_current_trace( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, - ): + ) -> None: """Update the current trace with additional information. This method updates the Langfuse trace that the current span belongs to. 
It's useful for @@ -1054,35 +1060,41 @@ def create_event( ) otel_span.set_attribute(LangfuseOtelSpanAttributes.AS_ROOT, True) - return LangfuseEvent( - otel_span=otel_span, - langfuse_client=self, - environment=self._environment, - input=input, - output=output, - metadata=metadata, - version=version, - level=level, - status_message=status_message, - ).end(end_time=timestamp) + return cast( + LangfuseEvent, + LangfuseEvent( + otel_span=otel_span, + langfuse_client=self, + environment=self._environment, + input=input, + output=output, + metadata=metadata, + version=version, + level=level, + status_message=status_message, + ).end(end_time=timestamp), + ) otel_span = self._otel_tracer.start_span(name=name, start_time=timestamp) - return LangfuseEvent( - otel_span=otel_span, - langfuse_client=self, - environment=self._environment, - input=input, - output=output, - metadata=metadata, - version=version, - level=level, - status_message=status_message, - ).end(end_time=timestamp) + return cast( + LangfuseEvent, + LangfuseEvent( + otel_span=otel_span, + langfuse_client=self, + environment=self._environment, + input=input, + output=output, + metadata=metadata, + version=version, + level=level, + status_message=status_message, + ).end(end_time=timestamp), + ) def _create_remote_parent_span( self, *, trace_id: str, parent_span_id: Optional[str] - ): + ) -> Any: if not self._is_valid_trace_id(trace_id): langfuse_logger.warning( f"Passed trace ID '{trace_id}' is not a valid 32 lowercase hex char Langfuse trace id. Ignoring trace ID." @@ -1109,12 +1121,12 @@ def _create_remote_parent_span( return trace.NonRecordingSpan(span_context) - def _is_valid_trace_id(self, trace_id): + def _is_valid_trace_id(self, trace_id: str) -> bool: pattern = r"^[0-9a-f]{32}$" return bool(re.match(pattern, trace_id)) - def _is_valid_span_id(self, span_id): + def _is_valid_span_id(self, span_id: str) -> bool: pattern = r"^[0-9a-f]{16}$" return bool(re.match(pattern, span_id)) @@ -1216,12 +1228,12 @@ def create_trace_id(*, seed: Optional[str] = None) -> str: return sha256(seed.encode("utf-8")).digest()[:16].hex() - def _get_otel_trace_id(self, otel_span: otel_trace_api.Span): + def _get_otel_trace_id(self, otel_span: otel_trace_api.Span) -> str: span_context = otel_span.get_span_context() return self._format_otel_trace_id(span_context.trace_id) - def _get_otel_span_id(self, otel_span: otel_trace_api.Span): + def _get_otel_span_id(self, otel_span: otel_trace_api.Span) -> str: span_context = otel_span.get_span_context() return self._format_otel_span_id(span_context.span_id) @@ -1352,15 +1364,15 @@ def create_score( try: new_body = ScoreBody( id=score_id, - session_id=session_id, - dataset_run_id=dataset_run_id, - trace_id=trace_id, - observation_id=observation_id, + sessionId=session_id, + datasetRunId=dataset_run_id, + traceId=trace_id, + observationId=observation_id, name=name, value=value, - data_type=data_type, + dataType=data_type, # type: ignore comment=comment, - config_id=config_id, + configId=config_id, environment=self._environment, metadata=metadata, ) @@ -1555,7 +1567,7 @@ def score_current_trace( config_id=config_id, ) - def flush(self): + def flush(self) -> None: """Force flush all pending spans and events to the Langfuse API. This method manually flushes any pending spans, scores, and other events to the @@ -1578,7 +1590,7 @@ def flush(self): if self._resources is not None: self._resources.flush() - def shutdown(self): + def shutdown(self) -> None: """Shut down the Langfuse client and flush all pending data. 
This method cleanly shuts down the Langfuse client, ensuring all pending data @@ -1871,7 +1883,7 @@ def resolve_media_references( resolve_with: Literal["base64_data_uri"], max_depth: int = 10, content_fetch_timeout_seconds: int = 5, - ): + ) -> Any: """Replace media reference strings in an object with base64 data URIs. This method recursively traverses an object (up to max_depth) looking for media reference strings @@ -2054,16 +2066,20 @@ def get_prompt( try: # refresh prompt in background thread, refresh_prompt deduplicates tasks langfuse_logger.debug(f"Refreshing prompt '{cache_key}' in background.") - self._resources.prompt_cache.add_refresh_prompt_task( - cache_key, - lambda: self._fetch_prompt_and_update_cache( + + def refresh_task() -> None: + self._fetch_prompt_and_update_cache( name, version=version, label=label, ttl_seconds=cache_ttl_seconds, max_retries=bounded_max_retries, fetch_timeout_seconds=fetch_timeout_seconds, - ), + ) + + self._resources.prompt_cache.add_refresh_prompt_task( + cache_key, + refresh_task, ) langfuse_logger.debug( f"Returning stale prompt '{cache_key}' from cache." ) @@ -2088,7 +2104,7 @@ def _fetch_prompt_and_update_cache( label: Optional[str] = None, ttl_seconds: Optional[int] = None, max_retries: int, - fetch_timeout_seconds, + fetch_timeout_seconds: Optional[int], ) -> PromptClient: cache_key = PromptCache.generate_cache_key(name, version=version, label=label) langfuse_logger.debug(f"Fetching prompt '{cache_key}' from server...") try: @backoff.on_exception( backoff.constant, Exception, max_tries=max_retries + 1, logger=None ) - def fetch_prompts(): + def fetch_prompts() -> Any: return self.api.prompts.get( self._url_encode(name), version=version, @@ -2112,6 +2128,7 @@ def fetch_prompts(): prompt_response = fetch_prompts() + prompt: PromptClient if prompt_response.type == "chat": prompt = ChatPromptClient(prompt_response) else: @@ -2208,14 +2225,16 @@ def create_prompt( raise ValueError( "For 'chat' type, 'prompt' must be a list of chat messages with role and content attributes." ) - request = CreatePromptRequest_Chat( - name=name, - prompt=cast(Any, prompt), - labels=labels, - tags=tags, - config=config or {}, - commitMessage=commit_message, - type="chat", + request: Union[CreatePromptRequest_Chat, CreatePromptRequest_Text] = ( + CreatePromptRequest_Chat( + name=name, + prompt=cast(Any, prompt), + labels=labels, + tags=tags, + config=config or {}, + commitMessage=commit_message, + type="chat", + ) ) server_prompt = self.api.prompts.create(request=request) @@ -2254,7 +2273,7 @@ def update_prompt( name: str, version: int, new_labels: List[str] = [], - ): + ) -> Any: """Update an existing prompt version in Langfuse. The Langfuse SDK prompt cache is invalidated for all prompts with the specified name. Args: diff --git a/langfuse/_client/datasets.py b/langfuse/_client/datasets.py index 404a3020b..f06570e57 100644 --- a/langfuse/_client/datasets.py +++ b/langfuse/_client/datasets.py @@ -1,6 +1,7 @@ import datetime as dt import logging -from typing import TYPE_CHECKING, Any, List, Optional +from .span import LangfuseSpan +from typing import TYPE_CHECKING, Any, Generator, List, Optional from opentelemetry.util._decorator import _agnosticcontextmanager @@ -91,7 +92,7 @@ def run( run_name: str, run_metadata: Optional[Any] = None, run_description: Optional[str] = None, - ): + ) -> Generator[LangfuseSpan, None, None]: """Create a context manager for the dataset item run that links the execution to a Langfuse trace.
This method is a context manager that creates a trace for the dataset run and yields a span diff --git a/langfuse/_client/get_client.py b/langfuse/_client/get_client.py index f53886e46..98a64fbfe 100644 --- a/langfuse/_client/get_client.py +++ b/langfuse/_client/get_client.py @@ -80,9 +80,11 @@ def get_client(*, public_key: Optional[str] = None) -> Langfuse: else: # Specific key provided, look up existing instance - instance = active_instances.get(public_key, None) + target_instance: Optional[LangfuseResourceManager] = active_instances.get( + public_key, None + ) - if instance is None: + if target_instance is None: # No instance found with this key - client not initialized properly langfuse_logger.warning( f"No Langfuse client with public key {public_key} has been initialized. Skipping tracing for decorated function." ) return Langfuse( tracing_enabled=False, public_key="fake", secret_key="fake" ) + # target_instance is guaranteed to be not None at this point return Langfuse( public_key=public_key, - secret_key=instance.secret_key, - host=instance.host, - tracing_enabled=instance.tracing_enabled, + secret_key=target_instance.secret_key, + host=target_instance.host, + tracing_enabled=target_instance.tracing_enabled, ) diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index e37e91e99..84bb60ebd 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -20,6 +20,7 @@ ) from typing_extensions import ParamSpec +from opentelemetry.util._decorator import _AgnosticContextManager from langfuse._client.environment_variables import ( LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED, @@ -206,9 +207,11 @@ def _async_observe( transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> F: @wraps(func) - async def async_wrapper(*args, **kwargs): - trace_id = kwargs.pop("langfuse_trace_id", None) - parent_observation_id = kwargs.pop("langfuse_parent_observation_id", None) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + trace_id = cast(str, kwargs.pop("langfuse_trace_id", None)) + parent_observation_id = cast( + str, kwargs.pop("langfuse_parent_observation_id", None) + ) trace_context: Optional[TraceContext] = ( { "trace_id": trace_id, @@ -227,9 +230,14 @@ async def async_wrapper(*args, **kwargs): if capture_input else None ) - public_key = kwargs.pop("langfuse_public_key", None) + public_key = cast(str, kwargs.pop("langfuse_public_key", None)) langfuse_client = get_client(public_key=public_key) - context_manager = ( + context_manager: Optional[ + Union[ + _AgnosticContextManager[LangfuseGeneration], + _AgnosticContextManager[LangfuseSpan], + ] + ] = ( ( langfuse_client.start_as_current_generation( name=final_name, @@ -294,7 +302,7 @@ def _sync_observe( transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> F: @wraps(func) - def sync_wrapper(*args, **kwargs): + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: trace_id = kwargs.pop("langfuse_trace_id", None) parent_observation_id = kwargs.pop("langfuse_parent_observation_id", None) trace_context: Optional[TraceContext] = ( { @@ -317,7 +325,12 @@ def sync_wrapper(*args, **kwargs): ) public_key = kwargs.pop("langfuse_public_key", None) langfuse_client = get_client(public_key=public_key) - context_manager = ( + context_manager: Optional[ + Union[ + _AgnosticContextManager[LangfuseGeneration], + _AgnosticContextManager[LangfuseSpan], + ] + ] = ( ( langfuse_client.start_as_current_generation(
name=final_name, @@ -398,7 +411,7 @@ def _wrap_sync_generator_result( langfuse_span_or_generation: Union[LangfuseSpan, LangfuseGeneration], generator: Generator, transform_to_string: Optional[Callable[[Iterable], str]] = None, - ): + ) -> Any: items = [] try: @@ -408,7 +421,7 @@ def _wrap_sync_generator_result( yield item finally: - output = items + output: Any = items if transform_to_string is not None: output = transform_to_string(items) @@ -434,7 +447,7 @@ async def _wrap_async_generator_result( yield item finally: - output = items + output: Any = items if transform_to_string is not None: output = transform_to_string(items) diff --git a/langfuse/_client/resource_manager.py b/langfuse/_client/resource_manager.py index 162c76e62..e0e3cbadc 100644 --- a/langfuse/_client/resource_manager.py +++ b/langfuse/_client/resource_manager.py @@ -25,6 +25,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.sampling import Decision, TraceIdRatioBased +from opentelemetry.trace import Tracer from langfuse._client.attributes import LangfuseOtelSpanAttributes from langfuse._client.constants import LANGFUSE_TRACER_NAME @@ -103,7 +104,14 @@ def __new__( with cls._lock: if public_key not in cls._instances: instance = super(LangfuseResourceManager, cls).__new__(cls) - instance._otel_tracer = None + + # Initialize the tracer (a no-op proxy until the instance is initialized) + instance._otel_tracer = otel_trace_api.get_tracer( + LANGFUSE_TRACER_NAME, + langfuse_version, + attributes={"public_key": public_key}, + ) + instance._initialize_instance( public_key=public_key, secret_key=secret_key, @@ -148,7 +156,7 @@ def _initialize_instance( blocked_instrumentation_scopes: Optional[List[str]] = None, additional_headers: Optional[Dict[str, str]] = None, tracer_provider: Optional[TracerProvider] = None, - ): + ) -> None: self.public_key = public_key self.secret_key = secret_key self.tracing_enabled = tracing_enabled @@ -278,14 +286,14 @@ def _initialize_instance( ) @classmethod - def reset(cls): + def reset(cls) -> None: with cls._lock: for key in cls._instances: cls._instances[key].shutdown() cls._instances.clear() - def add_score_task(self, event: dict, *, force_sample: bool = False): + def add_score_task(self, event: dict, *, force_sample: bool = False) -> None: try: # Sample scores with the same sampler that is used for tracing tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider()) @@ -331,14 +339,14 @@ def add_score_task(self, event: dict, *, force_sample: bool = False): return @property - def tracer(self): + def tracer(self) -> Optional[Tracer]: return self._otel_tracer @staticmethod - def get_current_span(): + def get_current_span() -> Any: return otel_trace_api.get_current_span() - def _stop_and_join_consumer_threads(self): + def _stop_and_join_consumer_threads(self) -> None: """End the consumer threads once the queue is empty.
Blocks execution until finished @@ -377,7 +385,7 @@ def _stop_and_join_consumer_threads(self): f"Shutdown: Score ingestion thread #{score_ingestion_consumer._identifier} successfully terminated" ) - def flush(self): + def flush(self) -> None: tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider()) if isinstance(tracer_provider, otel_trace_api.ProxyTracerProvider): return @@ -391,7 +399,7 @@ def flush(self): self._media_upload_queue.join() langfuse_logger.debug("Successfully flushed media upload queue") - def shutdown(self): + def shutdown(self) -> None: # Unregister the atexit handler first atexit.unregister(self.shutdown) diff --git a/langfuse/_client/span.py b/langfuse/_client/span.py index f61de550c..34aa4f0d1 100644 --- a/langfuse/_client/span.py +++ b/langfuse/_client/span.py @@ -115,7 +115,7 @@ def __init__( ) # Handle media only if span is sampled - if self._otel_span.is_recording: + if self._otel_span.is_recording(): media_processed_input = self._process_media_and_apply_mask( data=input, field="input", span=self._otel_span ) @@ -160,7 +160,7 @@ def __init__( {k: v for k, v in attributes.items() if v is not None} ) - def end(self, *, end_time: Optional[int] = None): + def end(self, *, end_time: Optional[int] = None) -> "LangfuseSpanWrapper": """End the span, marking it as completed. This method ends the wrapped OpenTelemetry span, marking the end of the @@ -186,7 +186,7 @@ def update_trace( metadata: Optional[Any] = None, tags: Optional[List[str]] = None, public: Optional[bool] = None, - ): + ) -> "LangfuseSpanWrapper": """Update the trace that this span belongs to. This method updates trace-level attributes of the trace that this span @@ -205,7 +205,7 @@ def update_trace( public: Whether the trace should be publicly accessible """ if not self._otel_span.is_recording(): - return + return self media_processed_input = self._process_media_and_apply_mask( data=input, field="input", span=self._otel_span @@ -231,6 +231,8 @@ def update_trace( self._otel_span.set_attributes(attributes) + return self + @overload def score( self, @@ -385,7 +387,7 @@ def _set_processed_span_attributes( input: Optional[Any] = None, output: Optional[Any] = None, metadata: Optional[Any] = None, - ): + ) -> None: """Set span attributes after processing media and applying masks. Internal method that processes media in the input, output, and metadata @@ -436,7 +438,7 @@ def _process_media_and_apply_mask( data: Optional[Any] = None, span: otel_trace_api.Span, field: Union[Literal["input"], Literal["output"], Literal["metadata"]], - ): + ) -> Optional[Any]: """Process media in an attribute and apply masking. Internal method that processes any media content in the data and applies @@ -454,7 +456,7 @@ def _process_media_and_apply_mask( data=self._process_media_in_attribute(data=data, field=field) ) - def _mask_attribute(self, *, data): + def _mask_attribute(self, *, data: Any) -> Any: """Apply the configured mask function to data. Internal method that applies the client's configured masking function to @@ -483,7 +485,7 @@ def _process_media_in_attribute( *, data: Optional[Any] = None, field: Union[Literal["input"], Literal["output"], Literal["metadata"]], - ): + ) -> Optional[Any]: """Process any media content in the attribute data. 
Internal method that identifies and processes any media content in the @@ -568,7 +570,7 @@ def update( version: Optional[str] = None, level: Optional[SpanLevel] = None, status_message: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> "LangfuseSpan": """Update this span with new information. @@ -765,7 +767,7 @@ def start_generation( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - ): + ) -> "LangfuseGeneration": """Create a new child generation span. This method creates a new child generation span with this span as the parent. @@ -972,17 +974,20 @@ def create_event( name=name, start_time=timestamp ) - return LangfuseEvent( - otel_span=new_otel_span, - langfuse_client=self._langfuse_client, - input=input, - output=output, - metadata=metadata, - environment=self._environment, - version=version, - level=level, - status_message=status_message, - ).end(end_time=timestamp) + return cast( + "LangfuseEvent", + LangfuseEvent( + otel_span=new_otel_span, + langfuse_client=self._langfuse_client, + input=input, + output=output, + metadata=metadata, + environment=self._environment, + version=version, + level=level, + status_message=status_message, + ).end(end_time=timestamp), + ) class LangfuseGeneration(LangfuseSpanWrapper): @@ -1066,7 +1071,7 @@ def update( usage_details: Optional[Dict[str, int]] = None, cost_details: Optional[Dict[str, float]] = None, prompt: Optional[PromptClient] = None, - **kwargs, + **kwargs: Any, ) -> "LangfuseGeneration": """Update this generation span with new information. diff --git a/langfuse/_client/utils.py b/langfuse/_client/utils.py index 670e40c4b..dac7a3f1b 100644 --- a/langfuse/_client/utils.py +++ b/langfuse/_client/utils.py @@ -11,7 +11,7 @@ from opentelemetry.sdk.trace import ReadableSpan -def span_formatter(span: ReadableSpan): +def span_formatter(span: ReadableSpan) -> str: parent_id = ( otel_trace_api.format_span_id(span.parent.span_id) if span.parent else None ) diff --git a/langfuse/_task_manager/media_manager.py b/langfuse/_task_manager/media_manager.py index 42596e450..13fa5f0c6 100644 --- a/langfuse/_task_manager/media_manager.py +++ b/langfuse/_task_manager/media_manager.py @@ -39,7 +39,7 @@ def __init__( LANGFUSE_MEDIA_UPLOAD_ENABLED, "True" ).lower() not in ("false", "0") - def process_next_media_upload(self): + def process_next_media_upload(self) -> None: try: upload_job = self._queue.get(block=True, timeout=1) self._log.debug( @@ -64,14 +64,14 @@ def _find_and_process_media( trace_id: str, observation_id: Optional[str], field: str, - ): + ) -> Any: if not self._enabled: return data seen = set() max_levels = 10 - def _process_data_recursively(data: Any, level: int): + def _process_data_recursively(data: Any, level: int) -> Any: if id(data) in seen or level > max_levels: return data @@ -170,7 +170,7 @@ def _process_media( trace_id: str, observation_id: Optional[str], field: str, - ): + ) -> None: if ( media._content_length is None or media._content_type is None @@ -217,7 +217,7 @@ def _process_upload_media_job( self, *, data: UploadMediaJob, - ): + ) -> None: upload_url_response = self._request_with_backoff( self._api_client.media.get_upload_url, request=GetMediaUploadUrlRequest( diff --git a/langfuse/_task_manager/media_upload_consumer.py b/langfuse/_task_manager/media_upload_consumer.py index ccfad2c20..182170864 100644 --- a/langfuse/_task_manager/media_upload_consumer.py +++ b/langfuse/_task_manager/media_upload_consumer.py @@ -28,7 +28,7 @@ def
__init__( self._identifier = identifier self._media_manager = media_manager - def run(self): + def run(self) -> None: """Run the media upload consumer.""" self._log.debug( f"Thread: Media upload consumer thread #{self._identifier} started and actively processing queue items" @@ -36,7 +36,7 @@ def run(self): while self.running: self._media_manager.process_next_media_upload() - def pause(self): + def pause(self) -> None: """Pause the media upload consumer.""" self._log.debug( f"Thread: Pausing media upload consumer thread #{self._identifier}" diff --git a/langfuse/_task_manager/score_ingestion_consumer.py b/langfuse/_task_manager/score_ingestion_consumer.py index 9543c12d9..1a5b61f91 100644 --- a/langfuse/_task_manager/score_ingestion_consumer.py +++ b/langfuse/_task_manager/score_ingestion_consumer.py @@ -13,7 +13,7 @@ try: import pydantic.v1 as pydantic except ImportError: - import pydantic + import pydantic # type: ignore from langfuse._utils.parse_error import handle_exception from langfuse._utils.request import APIError, LangfuseClient @@ -61,9 +61,9 @@ def __init__( self._max_retries = max_retries or 3 self._public_key = public_key - def _next(self): + def _next(self) -> list: """Return the next batch of items to upload.""" - events = [] + events: list = [] start_time = time.monotonic() total_size = 0 @@ -119,7 +119,7 @@ def _get_item_size(self, item: Any) -> int: """Return the size of the item in bytes.""" return len(json.dumps(item, cls=EventSerializer).encode()) - def run(self): + def run(self) -> None: """Run the consumer.""" self._log.debug( f"Startup: Score ingestion consumer thread #{self._identifier} started with batch size {self._flush_at} and interval {self._flush_interval}s" @@ -127,7 +127,7 @@ def run(self): while self.running: self.upload() - def upload(self): + def upload(self) -> None: """Upload the next batch of items, return whether successful.""" batch = self._next() if len(batch) == 0: @@ -142,11 +142,11 @@ def upload(self): for _ in batch: self._ingestion_queue.task_done() - def pause(self): + def pause(self) -> None: """Pause the consumer.""" self.running = False - def _upload_batch(self, batch: List[Any]): + def _upload_batch(self, batch: List[Any]) -> None: self._log.debug( f"API: Uploading batch of {len(batch)} score events to Langfuse API" ) @@ -161,7 +161,7 @@ def _upload_batch(self, batch: List[Any]): @backoff.on_exception( backoff.expo, Exception, max_tries=self._max_retries, logger=None ) - def execute_task_with_backoff(batch: List[Any]): + def execute_task_with_backoff(batch: List[Any]) -> None: try: self._client.batch_post(batch=batch, metadata=metadata) except Exception as e: diff --git a/langfuse/_utils/__init__.py b/langfuse/_utils/__init__.py index 036a40be4..e8a02abc1 100644 --- a/langfuse/_utils/__init__.py +++ b/langfuse/_utils/__init__.py @@ -9,13 +9,13 @@ log = logging.getLogger("langfuse") -def _get_timestamp(): +def _get_timestamp() -> datetime: return datetime.now(timezone.utc) def _create_prompt_context( prompt: typing.Optional[PromptClient] = None, -): +) -> typing.Dict[str, typing.Optional[typing.Union[str, int]]]: if prompt is not None and not prompt.is_fallback: return {"prompt_version": prompt.version, "prompt_name": prompt.name} diff --git a/langfuse/_utils/environment.py b/langfuse/_utils/environment.py index bd7d6021d..a696b3a59 100644 --- a/langfuse/_utils/environment.py +++ b/langfuse/_utils/environment.py @@ -1,6 +1,7 @@ """@private""" import os +from typing import Optional common_release_envs = [ # Render @@ -26,7 +27,7 @@ ] -def 
get_common_release_envs(): +def get_common_release_envs() -> Optional[str]: for env in common_release_envs: if env in os.environ: return os.environ[env] diff --git a/langfuse/_utils/error_logging.py b/langfuse/_utils/error_logging.py index ef3507fe1..e5a7fe67c 100644 --- a/langfuse/_utils/error_logging.py +++ b/langfuse/_utils/error_logging.py @@ -1,15 +1,15 @@ import functools import logging -from typing import List, Optional +from typing import Any, Callable, List, Optional logger = logging.getLogger("langfuse") -def catch_and_log_errors(func): +def catch_and_log_errors(func: Callable[..., Any]) -> Callable[..., Any]: """Catch all exceptions and log them. Do NOT re-raise the exception.""" @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: try: return func(*args, **kwargs) except Exception as e: @@ -18,14 +18,17 @@ def wrapper(*args, **kwargs): return wrapper -def auto_decorate_methods_with(decorator, exclude: Optional[List[str]] = []): +def auto_decorate_methods_with( + decorator: Callable[[Callable[..., Any]], Callable[..., Any]], + exclude: Optional[List[str]] = None, +) -> Callable[[type], type]: """Class decorator to automatically apply a given decorator to all methods of a class. """ - def class_decorator(cls): + def class_decorator(cls: type) -> type: for attr_name, attr_value in cls.__dict__.items(): - if attr_name in exclude: + if exclude and attr_name in exclude: continue if callable(attr_value): # Wrap callable attributes (methods) with the decorator diff --git a/langfuse/_utils/parse_error.py b/langfuse/_utils/parse_error.py index 12d891606..cb5749f93 100644 --- a/langfuse/_utils/parse_error.py +++ b/langfuse/_utils/parse_error.py @@ -56,14 +56,14 @@ def generate_error_message_fern(error: Error) -> str: elif isinstance(error, ServiceUnavailableError): return errorResponseByCode.get(503, defaultErrorResponse) elif isinstance(error, ApiError): - status_code = ( - int(error.status_code) - if isinstance(error.status_code, str) - else error.status_code + status_code = error.status_code + return ( + errorResponseByCode.get(status_code, defaultErrorResponse) + if status_code is not None + else defaultErrorResponse ) - return errorResponseByCode.get(status_code, defaultErrorResponse) - else: - return defaultErrorResponse + + return defaultErrorResponse # type: ignore def handle_fern_exception(exception: Error) -> None: diff --git a/langfuse/_utils/prompt_cache.py b/langfuse/_utils/prompt_cache.py index 67611d50d..132dcb410 100644 --- a/langfuse/_utils/prompt_cache.py +++ b/langfuse/_utils/prompt_cache.py @@ -5,7 +5,7 @@ from datetime import datetime from queue import Empty, Queue from threading import Thread -from typing import Dict, List, Optional, Set +from typing import Callable, Dict, List, Optional, Set from langfuse.model import PromptClient @@ -39,7 +39,7 @@ def __init__(self, queue: Queue, identifier: int): self._queue = queue self._identifier = identifier - def run(self): + def run(self) -> None: while self.running: try: task = self._queue.get(timeout=1) @@ -58,7 +58,7 @@ def run(self): except Empty: pass - def pause(self): + def pause(self) -> None: """Pause the consumer.""" self.running = False @@ -83,7 +83,7 @@ def __init__(self, threads: int = 1): atexit.register(self.shutdown) - def add_task(self, key: str, task): + def add_task(self, key: str, task: Callable[[], None]) -> None: if key not in self._processing_keys: self._log.debug(f"Adding prompt cache refresh task for key: {key}") self._processing_keys.add(key) @@ 
-97,8 +97,8 @@ def add_task(self, key: str, task): def active_tasks(self) -> int: return len(self._processing_keys) - def _wrap_task(self, key: str, task): - def wrapped(): + def _wrap_task(self, key: str, task: Callable[[], None]) -> Callable[[], None]: + def wrapped() -> None: self._log.debug(f"Refreshing prompt cache for key: {key}") try: task() @@ -108,7 +108,7 @@ def wrapped(): return wrapped - def shutdown(self): + def shutdown(self) -> None: self._log.debug( f"Shutting down prompt refresh task manager, {len(self._consumers)} consumers,..." ) @@ -146,19 +146,19 @@ def __init__( def get(self, key: str) -> Optional[PromptCacheItem]: return self._cache.get(key, None) - def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]): + def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]) -> None: if ttl_seconds is None: ttl_seconds = DEFAULT_PROMPT_CACHE_TTL_SECONDS self._cache[key] = PromptCacheItem(value, ttl_seconds) - def invalidate(self, prompt_name: str): + def invalidate(self, prompt_name: str) -> None: """Invalidate all cached prompts with the given prompt name.""" for key in list(self._cache): if key.startswith(prompt_name): del self._cache[key] - def add_refresh_prompt_task(self, key: str, fetch_func): + def add_refresh_prompt_task(self, key: str, fetch_func: Callable[[], None]) -> None: self._log.debug(f"Submitting refresh task for key: {key}") self._task_manager.add_task(key, fetch_func) diff --git a/langfuse/_utils/request.py b/langfuse/_utils/request.py index d420a3a13..b106cee2f 100644 --- a/langfuse/_utils/request.py +++ b/langfuse/_utils/request.py @@ -34,7 +34,7 @@ def __init__( self._timeout = timeout self._session = session - def generate_headers(self): + def generate_headers(self) -> dict: return { "Authorization": "Basic " + b64encode( @@ -46,7 +46,7 @@ def generate_headers(self): "x_langfuse_public_key": self._public_key, } - def batch_post(self, **kwargs) -> httpx.Response: + def batch_post(self, **kwargs: Any) -> httpx.Response: """Post the `kwargs` to the batch API endpoint for events""" log = logging.getLogger("langfuse") log.debug("uploading data: %s", kwargs) @@ -56,7 +56,7 @@ def batch_post(self, **kwargs) -> httpx.Response: res, success_message="data uploaded successfully", return_json=False ) - def post(self, **kwargs) -> httpx.Response: + def post(self, **kwargs: Any) -> httpx.Response: """Post the `kwargs` to the API""" log = logging.getLogger("langfuse") url = self._remove_trailing_slash(self._base_url) + "/api/public/ingestion" @@ -125,7 +125,7 @@ def __init__(self, status: Union[int, str], message: str, details: Any = None): self.status = status self.details = details - def __str__(self): + def __str__(self) -> str: msg = "{0} ({1}): {2}" return msg.format(self.message, self.status, self.details) @@ -134,7 +134,7 @@ class APIErrors(Exception): def __init__(self, errors: List[APIError]): self.errors = errors - def __str__(self): + def __str__(self) -> str: errors = ", ".join(str(error) for error in self.errors) return f"[Langfuse] {errors}" diff --git a/langfuse/_utils/serializer.py b/langfuse/_utils/serializer.py index e4b625a6e..9232cc908 100644 --- a/langfuse/_utils/serializer.py +++ b/langfuse/_utils/serializer.py @@ -36,11 +36,11 @@ class Serializable: # type: ignore class EventSerializer(JSONEncoder): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.seen = set() # Track seen objects to detect circular references + self.seen: 
set[int] = set() # Track seen objects to detect circular references - def default(self, obj: Any): + def default(self, obj: Any) -> Any: try: if isinstance(obj, (datetime)): # Timezone-awareness check @@ -83,7 +83,7 @@ def default(self, obj: Any): return type(obj).__name__ if is_dataclass(obj): - return asdict(obj) + return asdict(obj) # type: ignore if isinstance(obj, UUID): return str(obj) @@ -114,7 +114,7 @@ def default(self, obj: Any): return str(obj) # if langchain is not available, the Serializable type is NoneType - if Serializable is not type(None) and isinstance(obj, Serializable): + if Serializable is not type(None) and isinstance(obj, Serializable): # type: ignore return obj.to_json() # 64-bit integers might overflow the JavaScript safe integer range. diff --git a/langfuse/langchain/CallbackHandler.py b/langfuse/langchain/CallbackHandler.py index 9c6a03c78..ed7cfb70a 100644 --- a/langfuse/langchain/CallbackHandler.py +++ b/langfuse/langchain/CallbackHandler.py @@ -14,7 +14,7 @@ f"Could not import langchain. The langchain integration will not work. {e}" ) -from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast +from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Type, Union, cast from uuid import UUID from langfuse._utils import _get_timestamp @@ -63,7 +63,7 @@ def __init__(self, *, public_key: Optional[str] = None) -> None: self.prompt_to_parent_run_map: Dict[UUID, Any] = {} self.updated_completion_start_time_memo: Set[UUID] = set() - self.last_trace_id = None + self.last_trace_id: Optional[str] = None def on_llm_new_token( self, @@ -106,18 +106,18 @@ def get_langchain_run_name( str: The determined name of the Langchain runnable. """ if "name" in kwargs and kwargs["name"] is not None: - return kwargs["name"] + return str(kwargs["name"]) if serialized is None: return "" try: - return serialized["name"] + return str(serialized["name"]) except (KeyError, TypeError): pass try: - return serialized["id"][-1] + return str(serialized["id"][-1]) except (KeyError, TypeError): pass @@ -177,7 +177,10 @@ def on_chain_start( name=span_name, metadata=span_metadata, input=inputs, - level=span_level, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), ) else: self.runs[run_id] = cast( @@ -186,7 +189,10 @@ def on_chain_start( name=span_name, metadata=span_metadata, input=inputs, - level=span_level, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), ) self.last_trace_id = self.runs[run_id].trace_id @@ -197,10 +203,10 @@ def on_chain_start( def _register_langfuse_prompt( self, *, - run_id, + run_id: UUID, parent_run_id: Optional[UUID], metadata: Optional[Dict[str, Any]], - ): + ) -> None: """We need to register any passed Langfuse prompt to the parent_run_id so that we can link following generations with that prompt. If parent_run_id is None, we are at the root of a trace and should not attempt to register the prompt, as there will be no LLM invocation following it. 
@@ -220,7 +226,7 @@ def _register_langfuse_prompt( registered_prompt = self.prompt_to_parent_run_map[parent_run_id] self.prompt_to_parent_run_map[run_id] = registered_prompt - def _deregister_langfuse_prompt(self, run_id: Optional[UUID]): + def _deregister_langfuse_prompt(self, run_id: Optional[UUID]) -> None: if run_id in self.prompt_to_parent_run_map: del self.prompt_to_parent_run_map[run_id] @@ -317,7 +323,9 @@ def on_chain_error( level = "ERROR" self.runs[run_id].update( - level=level, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], level + ), status_message=str(error) if level else None, input=kwargs.get("inputs"), ).end() @@ -451,7 +459,10 @@ def on_retriever_start( name=span_name, metadata=span_metadata, input=query, - level=span_level, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), ) else: self.runs[run_id] = cast( @@ -460,7 +471,10 @@ def on_retriever_start( name=span_name, input=query, metadata=span_metadata, - level=span_level, + level=cast( + Optional[Literal["DEBUG", "DEFAULT", "WARNING", "ERROR"]], + span_level, + ), ) except Exception as e: @@ -542,12 +556,12 @@ def __on_llm_action( self, serialized: Optional[Dict[str, Any]], run_id: UUID, - prompts: List[str], + prompts: List[Any], parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ): + ) -> None: try: tools = kwargs.get("invocation_params", {}).get("tools", None) if tools and isinstance(tools, list): @@ -556,7 +570,11 @@ def __on_llm_action( model_name = self._parse_model_and_log_errors( serialized=serialized, metadata=metadata, kwargs=kwargs ) - registered_prompt = self.prompt_to_parent_run_map.get(parent_run_id, None) + registered_prompt = ( + self.prompt_to_parent_run_map.get(parent_run_id) + if parent_run_id is not None + else None + ) if registered_prompt: self._deregister_langfuse_prompt(parent_run_id) @@ -573,9 +591,9 @@ def __on_llm_action( if parent_run_id is not None and parent_run_id in self.runs: self.runs[run_id] = cast( LangfuseSpan, self.runs[parent_run_id] - ).start_generation(**content) + ).start_generation(**content) # type: ignore else: - self.runs[run_id] = self.client.start_generation(**content) + self.runs[run_id] = self.client.start_generation(**content) # type: ignore self.last_trace_id = self.runs[run_id].trace_id @@ -583,7 +601,7 @@ def __on_llm_action( langfuse_logger.exception(e) @staticmethod - def _parse_model_parameters(kwargs): + def _parse_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]: """Parse the model parameters from the kwargs.""" if kwargs["invocation_params"].get("_type") == "IBM watsonx.ai" and kwargs[ "invocation_params" @@ -615,7 +633,13 @@ def _parse_model_parameters(kwargs): if value is not None } - def _parse_model_and_log_errors(self, *, serialized, metadata, kwargs): + def _parse_model_and_log_errors( + self, + *, + serialized: Optional[Dict[str, Any]], + metadata: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + ) -> Optional[str]: """Parse the model name and log errors if parsing fails.""" try: model_name = _parse_model_name_from_metadata( @@ -629,8 +653,9 @@ def _parse_model_and_log_errors(self, *, serialized, metadata, kwargs): langfuse_logger.exception(e) self._log_model_parse_warning() + return None - def _log_model_parse_warning(self): + def _log_model_parse_warning(self) -> None: if not hasattr(self, "_model_parse_warning_logged"): langfuse_logger.warning( "Langfuse was not able to parse the LLM 
model. The LLM call will be recorded without model name. Please create an issue: https://github.com/langfuse/langfuse/issues/new/choose" @@ -653,26 +678,26 @@ def on_llm_end( if run_id not in self.runs: raise Exception("Run not found, see docs what to do in this case.") else: - generation = response.generations[-1][-1] + response_generation = response.generations[-1][-1] extracted_response = ( - self._convert_message_to_dict(generation.message) - if isinstance(generation, ChatGeneration) - else _extract_raw_response(generation) + self._convert_message_to_dict(response_generation.message) + if isinstance(response_generation, ChatGeneration) + else _extract_raw_response(response_generation) ) llm_usage = _parse_usage(response) # e.g. azure returns the model name in the response model = _parse_model(response) - generation = cast(LangfuseGeneration, self.runs[run_id]) - generation.update( + langfuse_generation = cast(LangfuseGeneration, self.runs[run_id]) + langfuse_generation.update( output=extracted_response, usage=llm_usage, usage_details=llm_usage, input=kwargs.get("inputs"), model=model, ) - generation.end() + langfuse_generation.end() del self.runs[run_id] @@ -745,7 +770,7 @@ def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: message_dict["name"] = message.additional_kwargs["name"] if message.additional_kwargs: - message_dict["additional_kwargs"] = message.additional_kwargs + message_dict["additional_kwargs"] = message.additional_kwargs # type: ignore return message_dict @@ -759,14 +784,14 @@ def _log_debug_event( event_name: str, run_id: UUID, parent_run_id: Optional[UUID] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: langfuse_logger.debug( f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}" ) -def _extract_raw_response(last_response): +def _extract_raw_response(last_response: Any) -> Any: """Extract the response from the last response of the LLM call.""" # We return the text of the response if not empty if last_response.text is not None and last_response.text.strip() != "": @@ -779,11 +804,11 @@ def _extract_raw_response(last_response): return "" -def _flatten_comprehension(matrix): +def _flatten_comprehension(matrix: Any) -> Any: return [item for row in matrix for item in row] -def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): +def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]) -> Any: # maintains a list of key translations. For each key, the usage model is checked # and a new object will be created with the new key if the key exists in the usage model # All non matched keys will remain on the object. 
@@ -907,7 +932,7 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]): return usage_model if usage_model else None -def _parse_usage(response: LLMResult): +def _parse_usage(response: LLMResult) -> Any: # langchain-anthropic uses the usage field llm_usage_keys = ["token_usage", "usage"] llm_usage = None @@ -954,7 +979,7 @@ def _parse_usage(response: LLMResult): return llm_usage -def _parse_model(response: LLMResult): +def _parse_model(response: LLMResult) -> Any: # langchain-anthropic uses the usage field llm_model_keys = ["model_name"] llm_model = None @@ -967,14 +992,14 @@ def _parse_model(response: LLMResult): return llm_model -def _parse_model_name_from_metadata(metadata: Optional[Dict[str, Any]]): +def _parse_model_name_from_metadata(metadata: Optional[Dict[str, Any]]) -> Any: if metadata is None or not isinstance(metadata, dict): return None return metadata.get("ls_model_name", None) -def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]): +def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]) -> Any: if metadata is None or not isinstance(metadata, dict): return metadata diff --git a/langfuse/langchain/utils.py b/langfuse/langchain/utils.py index 76be37100..544bcf957 100644 --- a/langfuse/langchain/utils.py +++ b/langfuse/langchain/utils.py @@ -130,7 +130,7 @@ def _extract_model_name( serialized, kwargs, path, cast(Literal["serialized", "kwargs"], select) ) if model: - return model + return str(model) return None @@ -159,7 +159,7 @@ def _extract_model_from_repr_by_pattern( return None -def _extract_model_with_regex(pattern: str, text: str): +def _extract_model_with_regex(pattern: str, text: str) -> Optional[str]: match = re.search(rf"{pattern}='(.*?)'", text) if match: return match.group(1) @@ -184,7 +184,8 @@ def _extract_model_by_path_for_id( and len(serialized_id) > 0 and serialized_id[-1] == id ): - return _extract_model_by_path(serialized, kwargs, keys, select_from) + result = _extract_model_by_path(serialized, kwargs, keys, select_from) + return str(result) if result is not None else None return None @@ -194,7 +195,7 @@ def _extract_model_by_path( kwargs: dict, keys: List[str], select_from: Literal["serialized", "kwargs"], -): +) -> Optional[str]: if serialized is None and select_from == "serialized": return None @@ -208,4 +209,4 @@ def _extract_model_by_path( if not current_obj: return None - return current_obj if current_obj else None + return str(current_obj) if current_obj else None diff --git a/langfuse/media.py b/langfuse/media.py index e0be5d7c5..6691785af 100644 --- a/langfuse/media.py +++ b/langfuse/media.py @@ -106,11 +106,11 @@ def _read_file(self, file_path: str) -> Optional[bytes]: return None - def _get_media_id(self): + def _get_media_id(self) -> Optional[str]: content_hash = self._content_sha256_hash if content_hash is None: - return + return None # Convert hash to base64Url url_safe_content_hash = content_hash.replace("+", "-").replace("/", "_") @@ -187,7 +187,7 @@ def parse_reference_string(reference_string: str) -> ParsedMediaReference: return ParsedMediaReference( media_id=parsed_data["id"], source=parsed_data["source"], - content_type=parsed_data["type"], + content_type=cast(MediaContentType, parsed_data["type"]), ) def _parse_base64_data_uri( @@ -314,8 +314,11 @@ def traverse(obj: Any, depth: int) -> Any: # Do not replace the reference string if there's an error continue - for ref_str, media_content in reference_string_to_media_content.items(): - result = result.replace(ref_str, media_content) + for ( + 
ref_str, + media_content_str, + ) in reference_string_to_media_content.items(): + result = result.replace(ref_str, media_content_str) return result @@ -336,4 +339,4 @@ def traverse(obj: Any, depth: int) -> Any: return obj - return traverse(obj, 0) + return cast(T, traverse(obj, 0)) diff --git a/langfuse/model.py b/langfuse/model.py index 521f9a82c..d1b5a80cf 100644 --- a/langfuse/model.py +++ b/langfuse/model.py @@ -3,7 +3,6 @@ import re from abc import ABC, abstractmethod from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypedDict, Union -from langfuse.logger import langfuse_logger from langfuse.api.resources.commons.types.dataset import ( Dataset, # noqa: F401 @@ -37,7 +36,8 @@ from langfuse.api.resources.datasets.types.create_dataset_request import ( # noqa: F401 CreateDatasetRequest, ) -from langfuse.api.resources.prompts import ChatMessage, Prompt, Prompt_Chat, Prompt_Text +from langfuse.api.resources.prompts import Prompt, Prompt_Chat, Prompt_Text +from langfuse.logger import langfuse_logger class ModelUsage(TypedDict): @@ -161,7 +161,12 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False): self.is_fallback = is_fallback @abstractmethod - def compile(self, **kwargs) -> Union[str, List[ChatMessage]]: + def compile( + self, **kwargs: Union[str, Any] + ) -> Union[ + str, + Sequence[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict_Placeholder]], + ]: pass @property @@ -170,15 +175,15 @@ def variables(self) -> List[str]: pass @abstractmethod - def __eq__(self, other): + def __eq__(self, other: object) -> bool: pass @abstractmethod - def get_langchain_prompt(self): + def get_langchain_prompt(self) -> Any: pass @staticmethod - def _get_langchain_prompt_string(content: str): + def _get_langchain_prompt_string(content: str) -> str: json_escaped_content = BasePromptClient._escape_json_for_langchain(content) return re.sub(r"{{\s*(\w+)\s*}}", r"{\g<1>}", json_escaped_content) @@ -255,7 +260,7 @@ def __init__(self, prompt: Prompt_Text, is_fallback: bool = False): super().__init__(prompt, is_fallback) self.prompt = prompt.prompt - def compile(self, **kwargs) -> str: + def compile(self, **kwargs: Union[str, Any]) -> str: return TemplateParser.compile_template(self.prompt, kwargs) @property @@ -263,8 +268,8 @@ def variables(self) -> List[str]: """Return all the variable names in the prompt template.""" return TemplateParser.find_variable_names(self.prompt) - def __eq__(self, other): - if isinstance(self, other.__class__): + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): return ( self.name == other.name and self.version == other.version @@ -274,7 +279,7 @@ def __eq__(self, other): return False - def get_langchain_prompt(self, **kwargs) -> str: + def get_langchain_prompt(self, **kwargs: Union[str, Any]) -> str: """Convert Langfuse prompt into string compatible with Langchain PromptTemplate. 
diff --git a/langfuse/model.py b/langfuse/model.py
index 521f9a82c..d1b5a80cf 100644
--- a/langfuse/model.py
+++ b/langfuse/model.py
@@ -3,7 +3,6 @@
 import re
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypedDict, Union
-from langfuse.logger import langfuse_logger
 
 from langfuse.api.resources.commons.types.dataset import (
     Dataset,  # noqa: F401
@@ -37,7 +36,8 @@
 from langfuse.api.resources.datasets.types.create_dataset_request import (  # noqa: F401
     CreateDatasetRequest,
 )
-from langfuse.api.resources.prompts import ChatMessage, Prompt, Prompt_Chat, Prompt_Text
+from langfuse.api.resources.prompts import Prompt, Prompt_Chat, Prompt_Text
+from langfuse.logger import langfuse_logger
 
 
 class ModelUsage(TypedDict):
@@ -161,7 +161,12 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False):
         self.is_fallback = is_fallback
 
     @abstractmethod
-    def compile(self, **kwargs) -> Union[str, List[ChatMessage]]:
+    def compile(
+        self, **kwargs: Union[str, Any]
+    ) -> Union[
+        str,
+        Sequence[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict_Placeholder]],
+    ]:
         pass
 
     @property
@@ -170,15 +175,15 @@ def variables(self) -> List[str]:
         pass
 
     @abstractmethod
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         pass
 
     @abstractmethod
-    def get_langchain_prompt(self):
+    def get_langchain_prompt(self) -> Any:
         pass
 
     @staticmethod
-    def _get_langchain_prompt_string(content: str):
+    def _get_langchain_prompt_string(content: str) -> str:
         json_escaped_content = BasePromptClient._escape_json_for_langchain(content)
 
         return re.sub(r"{{\s*(\w+)\s*}}", r"{\g<1>}", json_escaped_content)
@@ -255,7 +260,7 @@ def __init__(self, prompt: Prompt_Text, is_fallback: bool = False):
         super().__init__(prompt, is_fallback)
         self.prompt = prompt.prompt
 
-    def compile(self, **kwargs) -> str:
+    def compile(self, **kwargs: Union[str, Any]) -> str:
         return TemplateParser.compile_template(self.prompt, kwargs)
 
     @property
@@ -263,8 +268,8 @@ def variables(self) -> List[str]:
         """Return all the variable names in the prompt template."""
         return TemplateParser.find_variable_names(self.prompt)
 
-    def __eq__(self, other):
-        if isinstance(self, other.__class__):
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
             return (
                 self.name == other.name
                 and self.version == other.version
@@ -274,7 +279,7 @@ def __eq__(self, other):
 
         return False
 
-    def get_langchain_prompt(self, **kwargs) -> str:
+    def get_langchain_prompt(self, **kwargs: Union[str, Any]) -> str:
         """Convert Langfuse prompt into string compatible with Langchain PromptTemplate.
 
         This method adapts the mustache-style double curly braces {{variable}} used in Langfuse
@@ -299,7 +304,7 @@ def get_langchain_prompt(self, **kwargs) -> str:
 
 class ChatPromptClient(BasePromptClient):
     def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
         super().__init__(prompt, is_fallback)
-        self.prompt = []
+        self.prompt: List[ChatMessageWithPlaceholdersDict] = []
 
         for p in prompt.prompt:
             # Handle objects with attributes (normal case)
@@ -314,13 +319,14 @@ def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
                 self.prompt.append(
                     ChatMessageWithPlaceholdersDict_Message(
                         type="message",
-                        role=p.role,
-                        content=p.content,
+                        role=p.role,  # type: ignore
+                        content=p.content,  # type: ignore
                     ),
                 )
 
     def compile(
-        self, **kwargs
+        self,
+        **kwargs: Union[str, Any],
     ) -> Sequence[Union[ChatMessageDict, ChatMessageWithPlaceholdersDict_Placeholder]]:
         """Compile the prompt with placeholders and variables.
 
@@ -331,20 +337,23 @@ def compile(
         Returns:
             List of compiled chat messages as plain dictionaries, with unresolved placeholders kept as-is.
         """
-        compiled_messages = []
-        unresolved_placeholders = []
+        compiled_messages: List[
+            Union[ChatMessageDict, ChatMessageWithPlaceholdersDict_Placeholder]
+        ] = []
+        unresolved_placeholders: List[ChatMessageWithPlaceholdersDict_Placeholder] = []
 
         for chat_message in self.prompt:
             if chat_message["type"] == "message":
                 # For regular messages, compile variables and add to output
+                message_obj = chat_message  # type: ignore
                 compiled_messages.append(
-                    {
-                        "role": chat_message["role"],
-                        "content": TemplateParser.compile_template(
-                            chat_message["content"],
+                    ChatMessageDict(
+                        role=message_obj["role"],  # type: ignore
+                        content=TemplateParser.compile_template(
+                            message_obj["content"],  # type: ignore
                             kwargs,
                         ),
-                    },
+                    ),
                 )
             elif chat_message["type"] == "placeholder":
                 placeholder_name = chat_message["name"]
@@ -358,36 +367,42 @@ def compile(
                         and "content" in msg
                     ):
                         compiled_messages.append(
-                            {
-                                "role": msg["role"],
-                                "content": TemplateParser.compile_template(
-                                    msg["content"],
+                            ChatMessageDict(
+                                role=msg["role"],  # type: ignore
+                                content=TemplateParser.compile_template(
+                                    msg["content"],  # type: ignore
                                     kwargs,
                                 ),
-                            },
+                            ),
                         )
                     else:
                         compiled_messages.append(
-                            str(placeholder_value),
+                            ChatMessageDict(
+                                role="NOT_GIVEN",
+                                content=str(placeholder_value),
+                            )
                         )
                         no_role_content_in_placeholder = f"Placeholder '{placeholder_name}' should contain a list of chat messages with 'role' and 'content' fields. Appended as string."
                         langfuse_logger.warning(no_role_content_in_placeholder)
                 else:
                     compiled_messages.append(
-                        str(placeholder_value),
+                        ChatMessageDict(
+                            role="NOT_GIVEN",
+                            content=str(placeholder_value),
+                        ),
                     )
                     placeholder_not_a_list = f"Placeholder '{placeholder_name}' must contain a list of chat messages, got {type(placeholder_value)}"
                     langfuse_logger.warning(placeholder_not_a_list)
             else:
                 # Keep unresolved placeholder in the compiled messages
                 compiled_messages.append(chat_message)
-                unresolved_placeholders.append(placeholder_name)
+                unresolved_placeholders.append(chat_message["name"])  # type: ignore
 
         if unresolved_placeholders:
-            unresolved_placeholders = f"Placeholders {unresolved_placeholders} have not been resolved. Pass them as keyword arguments to compile()."
-            langfuse_logger.warning(unresolved_placeholders)
+            unresolved_placeholders_message = f"Placeholders {unresolved_placeholders} have not been resolved. Pass them as keyword arguments to compile()."
+            langfuse_logger.warning(unresolved_placeholders_message)
 
-        return compiled_messages
+        return compiled_messages  # type: ignore
 
     @property
     def variables(self) -> List[str]:
@@ -401,8 +416,8 @@
         )
 
         return variables
 
-    def __eq__(self, other):
-        if isinstance(self, other.__class__):
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
             return (
                 self.name == other.name
                 and self.version == other.version
@@ -429,7 +444,9 @@ def __eq__(self, other):
 
         return False
 
-    def get_langchain_prompt(self, **kwargs):
+    def get_langchain_prompt(
+        self, **kwargs: Union[str, Any]
+    ) -> List[Union[Tuple[str, str], Any]]:
         """Convert Langfuse prompt into string compatible with Langchain ChatPromptTemplate.
 
         It specifically adapts the mustache-style double curly braces {{variable}} used in Langfuse
@@ -447,12 +464,12 @@ def get_langchain_prompt(self, **kwargs):
             (role, content) tuples for regular messages or MessagesPlaceholder objects for unresolved placeholders.
         """
         compiled_messages = self.compile(**kwargs)
-        langchain_messages = []
+        langchain_messages: List[Union[Tuple[str, str], Any]] = []
 
         for msg in compiled_messages:
-            if "type" in msg and msg["type"] == "placeholder":
+            if isinstance(msg, dict) and "type" in msg and msg["type"] == "placeholder":  # type: ignore
                 # unresolved placeholder -> add LC MessagesPlaceholder
-                placeholder_name = msg["name"]
+                placeholder_name = msg["name"]  # type: ignore
 
                 try:
                     from langchain_core.prompts.chat import MessagesPlaceholder  # noqa: PLC0415, I001
@@ -463,9 +480,13 @@ def get_langchain_prompt(self, **kwargs):
 
                     import_error = "langchain_core is required to use get_langchain_prompt() with unresolved placeholders."
                     raise ImportError(import_error) from e
             else:
-                langchain_messages.append(
-                    (msg["role"], self._get_langchain_prompt_string(msg["content"])),
-                )
+                if isinstance(msg, dict) and "role" in msg and "content" in msg:
+                    langchain_messages.append(
+                        (
+                            msg["role"],  # type: ignore
+                            self._get_langchain_prompt_string(msg["content"]),  # type: ignore
+                        ),
+                    )
 
         return langchain_messages
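# --- Editor's note (illustrative sketch, not part of the diff) ---------------
# Usage sketch for the typed ChatPromptClient.compile() above, assuming a chat
# prompt was created in Langfuse with a {{criticlevel}} variable and a
# "chat_history" placeholder (prompt name and variable names are hypothetical):
#
#     from langfuse import Langfuse
#
#     langfuse = Langfuse()
#     prompt = langfuse.get_prompt("movie-critic-chat", type="chat")
#     messages = prompt.compile(
#         criticlevel="expert",  # fills the {{criticlevel}} variable
#         chat_history=[         # resolves the placeholder to real messages
#             {"role": "user", "content": "Hi"},
#             {"role": "assistant", "content": "Hello!"},
#         ],
#     )
#
# Placeholders left unfilled stay in the returned sequence and now trigger the
# "have not been resolved" warning added in this hunk.
# -----------------------------------------------------------------------------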
diff --git a/langfuse/openai.py b/langfuse/openai.py
index ea0cd186a..8b718f12a 100644
--- a/langfuse/openai.py
+++ b/langfuse/openai.py
@@ -21,8 +21,9 @@
 import types
 from collections import defaultdict
 from dataclasses import dataclass
+from datetime import datetime
 from inspect import isclass
-from typing import Optional, cast
+from typing import Optional, cast, Any
 
 from openai._types import NotGiven
 from packaging.version import Version
@@ -44,10 +45,10 @@
 try:
     from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI  # noqa: F401
 except ImportError:
-    AsyncAzureOpenAI = None
-    AsyncOpenAI = None
-    AzureOpenAI = None
-    OpenAI = None
+    AsyncAzureOpenAI = None  # type: ignore
+    AsyncOpenAI = None  # type: ignore
+    AzureOpenAI = None  # type: ignore
+    OpenAI = None  # type: ignore
 
 
 log = logging.getLogger("langfuse")
@@ -166,14 +167,16 @@ class OpenAiDefinition:
 class OpenAiArgsExtractor:
     def __init__(
         self,
-        metadata=None,
-        name=None,
-        langfuse_prompt=None,  # we cannot use prompt because it's an argument of the old OpenAI completions API
-        langfuse_public_key=None,
-        trace_id=None,
-        parent_observation_id=None,
-        **kwargs,
-    ):
+        metadata: Optional[Any] = None,
+        name: Optional[str] = None,
+        langfuse_prompt: Optional[
+            Any
+        ] = None,  # we cannot use prompt because it's an argument of the old OpenAI completions API
+        langfuse_public_key: Optional[str] = None,
+        trace_id: Optional[str] = None,
+        parent_observation_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
         self.args = {}
         self.args["metadata"] = (
             metadata
@@ -194,10 +197,10 @@ def __init__(
 
         self.kwargs = kwargs
 
-    def get_langfuse_args(self):
+    def get_langfuse_args(self) -> Any:
         return {**self.args, **self.kwargs}
 
-    def get_openai_args(self):
+    def get_openai_args(self) -> Any:
        # If OpenAI model distillation is enabled, we need to add the metadata to the kwargs
        # https://platform.openai.com/docs/guides/distillation
        if self.kwargs.get("store", False):
@@ -212,9 +215,9 @@ def get_openai_args(self):
         return self.kwargs
 
 
-def _langfuse_wrapper(func):
-    def _with_langfuse(open_ai_definitions):
-        def wrapper(wrapped, instance, args, kwargs):
+def _langfuse_wrapper(func: Any) -> Any:
+    def _with_langfuse(open_ai_definitions: Any) -> Any:
+        def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
             return func(open_ai_definitions, wrapped, args, kwargs)
 
         return wrapper
@@ -222,7 +225,7 @@ def wrapper(wrapped, instance, args, kwargs):
     return _with_langfuse
 
 
-def _extract_chat_prompt(kwargs: any):
+def _extract_chat_prompt(kwargs: Any) -> Any:
     """Extracts the user input from prompts. Returns an array of messages or dict with messages and functions"""
     prompt = {}
 
@@ -250,7 +253,7 @@ def _extract_chat_prompt(kwargs: any):
     return [_process_message(message) for message in kwargs.get("messages", [])]
 
 
-def _process_message(message):
+def _process_message(message: Any) -> Any:
     if not isinstance(message, dict):
         return message
 
@@ -287,7 +290,7 @@ def _process_message(message):
     return processed_message
 
 
-def _extract_chat_response(kwargs: any):
+def _extract_chat_response(kwargs: Any) -> Any:
     """Extracts the llm output from the response."""
     response = {
         "role": kwargs.get("role", None),
@@ -320,7 +323,7 @@ def _extract_chat_response(kwargs: any):
     return response
 
 
-def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs):
+def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs: Any) -> Any:
     name = kwargs.get("name", "OpenAI-generation")
 
     if name is None:
@@ -445,13 +448,13 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, kwargs):
 
 
 def _create_langfuse_update(
-    completion,
+    completion: Any,
     generation: LangfuseGeneration,
-    completion_start_time,
-    model=None,
-    usage=None,
-    metadata=None,
-):
+    completion_start_time: Any,
+    model: Optional[str] = None,
+    usage: Optional[Any] = None,
+    metadata: Optional[Any] = None,
+) -> Any:
     update = {
         "output": completion,
         "completion_start_time": completion_start_time,
@@ -468,7 +471,7 @@ def _create_langfuse_update(
     generation.update(**update)
 
 
-def _parse_usage(usage=None):
+def _parse_usage(usage: Optional[Any] = None) -> Any:
     if usage is None:
         return
 
@@ -493,7 +496,7 @@ def _parse_usage(usage=None):
     return usage_dict
 
 
-def _extract_streamed_response_api_response(chunks):
+def _extract_streamed_response_api_response(chunks: Any) -> Any:
     completion, model, usage = None, None, None
     metadata = {}
 
@@ -520,8 +523,8 @@ def _extract_streamed_response_api_response(chunks):
     return (model, completion, usage, metadata)
 
 
-def _extract_streamed_openai_response(resource, chunks):
-    completion = defaultdict(str) if resource.type == "chat" else ""
+def _extract_streamed_openai_response(resource: Any, chunks: Any) -> Any:
+    completion: Any = defaultdict(lambda: None) if resource.type == "chat" else ""
     model, usage = None, None
 
     for chunk in chunks:
@@ -602,7 +605,7 @@ def _extract_streamed_openai_response(resource, chunks):
         if resource.type == "completion":
             completion += choice.get("text", "")
 
-    def get_response_for_chat():
+    def get_response_for_chat() -> Any:
         return (
             completion["content"]
             or (
@@ -633,7 +636,9 @@ def get_response_for_chat():
     )
 
 
-def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response):
+def _get_langfuse_data_from_default_response(
+    resource: OpenAiDefinition, response: Any
+) -> Any:
     if response is None:
         return None, "", None
 
@@ -682,11 +687,11 @@ def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, respons
     return (model, completion, usage)
 
 
-def _is_openai_v1():
+def _is_openai_v1() -> bool:
     return Version(openai.__version__) >= Version("1.0.0")
 
 
-def _is_streaming_response(response):
+def _is_streaming_response(response: Any) -> bool:
     return (
         isinstance(response, types.GeneratorType)
         or isinstance(response, types.AsyncGeneratorType)
@@ -696,7 +701,9 @@ def _is_streaming_response(response):
 
 
 @_langfuse_wrapper
-def _wrap(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):
+def _wrap(
+    open_ai_resource: OpenAiDefinition, wrapped: Any, args: Any, kwargs: Any
+) -> Any:
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
 
     langfuse_args = arg_extractor.get_langfuse_args()
@@ -757,7 +764,9 @@ def _wrap(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):
 
 
 @_langfuse_wrapper
-async def _wrap_async(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs):
+async def _wrap_async(
+    open_ai_resource: OpenAiDefinition, wrapped: Any, args: Any, kwargs: Any
+) -> Any:
     arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
 
     langfuse_args = arg_extractor.get_langfuse_args()
@@ -817,7 +826,7 @@ async def _wrap_async(open_ai_resource: OpenAiDefinition, wrapped, args, kwargs)
         raise ex
 
 
-def register_tracing():
+def register_tracing() -> None:
     resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0
 
     for resource in resources:
@@ -845,18 +854,18 @@ class LangfuseResponseGeneratorSync:
     def __init__(
         self,
         *,
-        resource,
-        response,
-        generation,
-    ):
-        self.items = []
+        resource: Any,
+        response: Any,
+        generation: Any,
+    ) -> None:
+        self.items: list[Any] = []
         self.resource = resource
         self.response = response
         self.generation = generation
-        self.completion_start_time = None
+        self.completion_start_time: Optional[datetime] = None
 
-    def __iter__(self):
+    def __iter__(self) -> Any:
         try:
             for i in self.response:
                 self.items.append(i)
@@ -868,7 +877,7 @@ def __iter__(self):
         finally:
             self._finalize()
 
-    def __next__(self):
+    def __next__(self) -> Any:
         try:
             item = self.response.__next__()
             self.items.append(item)
@@ -883,13 +892,13 @@ def __next__(self):
 
             raise
 
-    def __enter__(self):
+    def __enter__(self) -> Any:
         return self.__iter__()
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         pass
 
-    def _finalize(self):
+    def _finalize(self) -> None:
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
@@ -915,18 +924,18 @@ class LangfuseResponseGeneratorAsync:
     def __init__(
         self,
         *,
-        resource,
-        response,
-        generation,
-    ):
-        self.items = []
+        resource: Any,
+        response: Any,
+        generation: Any,
+    ) -> None:
+        self.items: list[Any] = []
         self.resource = resource
         self.response = response
         self.generation = generation
-        self.completion_start_time = None
+        self.completion_start_time: Optional[datetime] = None
 
-    async def __aiter__(self):
+    async def __aiter__(self) -> Any:
         try:
             async for i in self.response:
                 self.items.append(i)
@@ -938,7 +947,7 @@ async def __aiter__(self):
         finally:
             await self._finalize()
 
-    async def __anext__(self):
+    async def __anext__(self) -> Any:
         try:
             item = await self.response.__anext__()
             self.items.append(item)
@@ -953,13 +962,13 @@ async def __anext__(self):
 
             raise
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> Any:
         return self.__aiter__()
 
-    async def __aexit__(self, exc_type, exc_value, traceback):
+    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         pass
 
-    async def _finalize(self):
+    async def _finalize(self) -> None:
         try:
             model, completion, usage, metadata = (
                 _extract_streamed_response_api_response(self.items)
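# --- Editor's note (illustrative sketch, not part of the diff) ---------------
# The two generator classes above wrap a streamed OpenAI response so that
# iteration behaves normally while chunks are buffered and flushed to the
# generation in _finalize(). Consumer-side sketch using the documented drop-in
# import (model name and prompt are illustrative):
#
#     from langfuse.openai import openai  # Langfuse-instrumented OpenAI module
#
#     client = openai.OpenAI()
#     stream = client.chat.completions.create(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,
#     )
#     for chunk in stream:  # iterates LangfuseResponseGeneratorSync transparently
#         ...
#
# When the loop ends (or the context manager exits), _finalize() parses the
# buffered chunks for model, output, and usage and updates the generation.
# -----------------------------------------------------------------------------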
diff --git a/langfuse/py.typed b/langfuse/py.typed
new file mode 100644
index 000000000..e69de29bb
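# --- Editor's note (illustrative, not part of the diff) ----------------------
# The empty langfuse/py.typed file is the PEP 561 marker telling type checkers
# that the installed package ships inline annotations. Without it, mypy treats
# "import langfuse" as untyped (Any) even with annotations present in the
# source. After this change, downstream projects get real checking, e.g.:
#
#     $ mypy app.py   # langfuse types now resolve instead of degrading to Any
#
# (Any specific error output a user sees will depend on their own code; none is
# shown here to avoid inventing diagnostics.)
# -----------------------------------------------------------------------------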
"mypy-1.16.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ab5eca37b50188163fa7c1b73c685ac66c4e9bdee4a85c9adac0e91d8895e15"}, + {file = "mypy-1.16.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dedb6229b2c9086247e21a83c309754b9058b438704ad2f6807f0d8227f6ebdd"}, + {file = "mypy-1.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:1f0435cf920e287ff68af3d10a118a73f212deb2ce087619eb4e648116d1fe9b"}, + {file = "mypy-1.16.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ddc91eb318c8751c69ddb200a5937f1232ee8efb4e64e9f4bc475a33719de438"}, + {file = "mypy-1.16.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:87ff2c13d58bdc4bbe7dc0dedfe622c0f04e2cb2a492269f3b418df2de05c536"}, + {file = "mypy-1.16.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a7cfb0fe29fe5a9841b7c8ee6dffb52382c45acdf68f032145b75620acfbd6f"}, + {file = "mypy-1.16.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:051e1677689c9d9578b9c7f4d206d763f9bbd95723cd1416fad50db49d52f359"}, + {file = "mypy-1.16.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d5d2309511cc56c021b4b4e462907c2b12f669b2dbeb68300110ec27723971be"}, + {file = "mypy-1.16.1-cp313-cp313-win_amd64.whl", hash = "sha256:4f58ac32771341e38a853c5d0ec0dfe27e18e27da9cdb8bbc882d2249c71a3ee"}, + {file = "mypy-1.16.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7fc688329af6a287567f45cc1cefb9db662defeb14625213a5b7da6e692e2069"}, + {file = "mypy-1.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e198ab3f55924c03ead626ff424cad1732d0d391478dfbf7bb97b34602395da"}, + {file = "mypy-1.16.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09aa4f91ada245f0a45dbc47e548fd94e0dd5a8433e0114917dc3b526912a30c"}, + {file = "mypy-1.16.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13c7cd5b1cb2909aa318a90fd1b7e31f17c50b242953e7dd58345b2a814f6383"}, + {file = "mypy-1.16.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:58e07fb958bc5d752a280da0e890c538f1515b79a65757bbdc54252ba82e0b40"}, + {file = "mypy-1.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:f895078594d918f93337a505f8add9bd654d1a24962b4c6ed9390e12531eb31b"}, + {file = "mypy-1.16.1-py3-none-any.whl", hash = "sha256:5fc2ac4027d0ef28d6ba69a0343737a23c4d1b83672bf38d1fe237bdc0643b37"}, + {file = "mypy-1.16.1.tar.gz", hash = "sha256:6bd00a0a2094841c5e47e7374bb42b83d64c527a502e3334e1173a0c24437bab"}, +] + +[package.dependencies] +mypy_extensions = ">=1.0.0" +pathspec = ">=0.9.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing_extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -3432,6 +3486,17 @@ files = [ [package.extras] dev = ["jinja2"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "pdoc" version = "14.6.0" @@ -5564,4 +5629,4 @@ openai = ["openai"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "d0353d0579f11dc634c6da45be88e85310fd9f1172500b9670b2cbc428370eb9" +content-hash = "acdcf5642aba80585f46a67189b0ae2931af42d63551ce3061c09353d6dbf230" diff --git a/pyproject.toml b/pyproject.toml index ee21a4a78..2b7fc5215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ pytest-asyncio = ">=0.21.1,<0.24.0" pytest-httpserver = "^1.0.8" boto3 = "^1.28.59" ruff = ">=0.1.8,<0.6.0" +mypy = "^1.0.0" langchain-mistralai = ">=0.0.1,<0.3" google-cloud-aiplatform = "^1.38.1" cohere = ">=4.46,<6.0" @@ -72,5 +73,61 @@ log_cli = true [tool.poetry_bumpversion.file."langfuse/version.py"] +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = false +warn_no_return = true +warn_unreachable = true +strict_equality = true +show_error_codes = true + +# Performance optimizations for CI +cache_dir = ".mypy_cache" +sqlite_cache = true +incremental = true +show_column_numbers = true + +[[tool.mypy.overrides]] +module = [ + "langchain.*", + "openai.*", + "chromadb.*", + "tiktoken.*", + "google.*", + "anthropic.*", + "cohere.*", + "dashscope.*", + "pymongo.*", + "bson.*", + "boto3.*", + "llama_index.*", + "respx.*", + "bs4.*", + "lark.*", + "huggingface_hub.*", + "backoff.*", + "wrapt.*", + "packaging.*", + "requests.*", + "opentelemetry.*" +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = [ + "langfuse.api.resources.*", + "langfuse.api.core.*", + "langfuse.api.client" +] +ignore_errors = true + [tool.poetry.scripts] release = "scripts.release:main"
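# --- Editor's note (illustrative, not part of the diff) ----------------------
# With the [tool.mypy] block above, a local run mirrors the CI type-check job:
#
#     $ poetry run mypy langfuse
#
# The first [[tool.mypy.overrides]] block only silences *missing stubs* for the
# listed third-party modules (ignore_missing_imports); calls into them are
# still checked against whatever types mypy can infer. The second block goes
# further and disables checking entirely for the generated API client
# (ignore_errors), so the strict flags apply only to hand-written SDK code.
# -----------------------------------------------------------------------------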