Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/promptfoo/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import platform
import sys
import threading
import uuid
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -82,9 +83,10 @@ def _write_global_config(config: dict[str, Any]) -> None:
pass # Silently fail - telemetry should never break the CLI


def _get_user_id() -> str:
def _get_user_id(config: dict[str, Any] | None = None) -> str:
"""Get or create a unique user ID stored in the global config."""
config = _read_global_config()
if config is None:
config = _read_global_config()
user_id = config.get("id")

if not user_id:
Expand All @@ -95,9 +97,10 @@ def _get_user_id() -> str:
return user_id


def _get_user_email() -> str | None:
def _get_user_email(config: dict[str, Any] | None = None) -> str | None:
"""Get the user email from the global config if set."""
config = _read_global_config()
if config is None:
config = _read_global_config()
account = config.get("account", {})
return account.get("email") if isinstance(account, dict) else None

Expand Down Expand Up @@ -127,8 +130,9 @@ def _ensure_initialized(self) -> None:
return

try:
self._user_id = _get_user_id()
self._email = _get_user_email()
config = _read_global_config()
self._user_id = _get_user_id(config)
self._email = _get_user_email(config)
self._client = Posthog(
project_api_key=_POSTHOG_KEY,
host=_POSTHOG_HOST,
Expand Down Expand Up @@ -182,15 +186,20 @@ def shutdown(self) -> None:

# Global singleton instance
_telemetry: _Telemetry | None = None
_telemetry_lock = threading.Lock()


def _get_telemetry() -> _Telemetry:
"""Get the global telemetry instance."""
global _telemetry
if _telemetry is None:
_telemetry = _Telemetry()
atexit.register(_telemetry.shutdown)
return _telemetry
if _telemetry is not None:
return _telemetry

with _telemetry_lock:
if _telemetry is None:
_telemetry = _Telemetry()
atexit.register(_telemetry.shutdown)
return _telemetry


def record_wrapper_used(method: str) -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def test_read_probe_file_returns_none_when_missing(self, tmp_path: Path) -> None
"""Missing probe files return None."""
assert _read_probe_file(tmp_path / "missing") is None

def test_read_probe_file_returns_content_when_readable(self, tmp_path: Path) -> None:
"""Readable probe files return their text content."""
probe_file = tmp_path / "probe"
probe_file.write_text("value")

assert _read_probe_file(probe_file) == "value"

def test_read_probe_file_returns_none_when_unreadable(self, tmp_path: Path) -> None:
"""Unreadable probe files return None instead of raising."""
probe_file = tmp_path / "probe"
Expand Down Expand Up @@ -245,6 +252,7 @@ def test_detect_kubernetes_from_env(self, monkeypatch: pytest.MonkeyPatch) -> No
mock_path.return_value.exists.return_value = False

is_docker, is_k8s = _detect_container()
assert is_docker is False
assert is_k8s is True

def test_detect_container_returns_tuple(self) -> None:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import os
import threading
from pathlib import Path
from unittest import mock

Expand All @@ -17,11 +18,13 @@
from promptfoo.telemetry import (
_get_config_dir,
_get_env_bool,
_get_telemetry,
_get_user_email,
_get_user_id,
_is_ci,
_read_global_config,
_Telemetry,
_telemetry_lock,
_write_global_config,
record_wrapper_used,
)
Expand Down Expand Up @@ -268,6 +271,23 @@ def test_record_initializes_client(self, tmp_path: Path) -> None:
assert telemetry._client is mock_client
mock_client.capture.assert_called_once()

def test_initialization_reads_global_config_once(self) -> None:
"""Initialization shares one config read across user identity lookups."""
config = {"id": "test-user-id", "account": {"email": "test@example.com"}}

with (
mock.patch.dict(os.environ, {}, clear=True),
mock.patch("promptfoo.telemetry._read_global_config", return_value=config) as mock_read_config,
mock.patch("promptfoo.telemetry.Posthog") as mock_posthog,
):
telemetry = _Telemetry()
telemetry._ensure_initialized()

mock_read_config.assert_called_once_with()
assert telemetry._user_id == "test-user-id"
assert telemetry._email == "test@example.com"
mock_posthog.assert_called_once()

def test_record_enriches_properties(self, tmp_path: Path) -> None:
"""Test record adds enriched properties."""
config_file = tmp_path / "promptfoo.yaml"
Expand Down Expand Up @@ -471,3 +491,38 @@ def test_record_wrapper_used_disabled(self, monkeypatch: pytest.MonkeyPatch) ->
with mock.patch("promptfoo.telemetry._telemetry", None):
# Should not raise or make any calls
record_wrapper_used("global")

def test_get_telemetry_guards_singleton_initialization_with_lock(self) -> None:
"""Singleton construction waits on its lock and registers shutdown once."""
started = threading.Event()
finished = threading.Event()
instance = mock.Mock(spec=_Telemetry)
results: list[_Telemetry] = []

def initialize() -> None:
started.set()
results.append(_get_telemetry())
finished.set()

with (
mock.patch("promptfoo.telemetry._telemetry", None),
mock.patch("promptfoo.telemetry._Telemetry", return_value=instance) as mock_telemetry,
mock.patch("promptfoo.telemetry.atexit.register") as mock_register,
):
_telemetry_lock.acquire()
try:
worker = threading.Thread(target=initialize)
worker.start()
assert started.wait(timeout=1)
assert finished.wait(timeout=0.05) is False
mock_telemetry.assert_not_called()
finally:
_telemetry_lock.release()

worker.join(timeout=1)
assert worker.is_alive() is False

assert results == [instance]
assert _get_telemetry() is instance
mock_telemetry.assert_called_once_with()
mock_register.assert_called_once_with(instance.shutdown)
Loading