From 07eb0fb3af32ae4c2a435ce0c593c8493a610295 Mon Sep 17 00:00:00 2001 From: Adeeb Shihadeh Date: Sun, 21 Jun 2026 13:55:14 -0700 Subject: [PATCH 1/3] merge apis --- common/api.py | 68 +++++++++++++++++++++++++++---- common/tests/test_api.py | 78 ++++++++++++++++++++++++++++++++++++ tools/lib/api.py | 40 ------------------ tools/lib/auth.py | 2 +- tools/lib/bootlog.py | 2 +- tools/lib/file_downloader.py | 2 +- tools/lib/route.py | 2 +- tools/scripts/ssh.py | 2 +- 8 files changed, 143 insertions(+), 53 deletions(-) create mode 100644 common/tests/test_api.py delete mode 100644 tools/lib/api.py diff --git a/common/api.py b/common/api.py index c97f56c4b5cb71..8fc7ed29904dbc 100644 --- a/common/api.py +++ b/common/api.py @@ -1,11 +1,15 @@ -import jwt import os +from datetime import UTC, datetime, timedelta + +import jwt import requests -from datetime import datetime, timedelta, UTC +from requests.adapters import HTTPAdapter, Retry + from openpilot.common.hardware.hw import Paths from openpilot.common.version import get_version API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com') +TOOLS_USER_AGENT = 'OpenpilotTools' # name: jwt signature algorithm KEYS = {"id_rsa": "RS256", @@ -13,8 +17,9 @@ class Api: - def __init__(self, dongle_id): + def __init__(self, dongle_id, session=None): self.dongle_id = dongle_id + self.session = session self.jwt_algorithm, self.private_key, _ = get_key_pair() def get(self, *args, **kwargs): @@ -23,8 +28,10 @@ def get(self, *args, **kwargs): def post(self, *args, **kwargs): return self.request('POST', *args, **kwargs) - def request(self, method, endpoint, timeout=None, access_token=None, **params): - return api_get(endpoint, method=method, timeout=timeout, access_token=access_token, **params) + def request(self, method, endpoint, timeout=None, access_token=None, session=None, **params): + session = self.session if session is None else session + return api_get(endpoint, method=method, timeout=timeout, access_token=access_token, + session=session, **params) def get_token(self, payload_extra=None, expiry_hours=1): now = datetime.now(UTC).replace(tzinfo=None) @@ -42,6 +49,53 @@ def get_token(self, payload_extra=None, expiry_hours=1): return token +class CommaApi: + def __init__(self, token=None, session=None): + self.session = session or requests.Session() + self.session.headers['User-Agent'] = TOOLS_USER_AGENT + if token: + self.session.headers['Authorization'] = 'JWT ' + token + + if session is None: + retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) + self.session.mount('https://', HTTPAdapter(max_retries=retries)) + + def request(self, method, endpoint, **kwargs): + with _request(method, endpoint, session=self.session, **kwargs) as resp: + resp_json = resp.json() + if isinstance(resp_json, dict) and resp_json.get('error'): + if resp.status_code in [401, 403]: + raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py', resp.status_code) + + raise APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])), resp.status_code) + return resp_json + + def get(self, endpoint, **kwargs): + return self.request('GET', endpoint, **kwargs) + + def post(self, endpoint, **kwargs): + return self.request('POST', endpoint, **kwargs) + + +class APIError(Exception): + def __init__(self, message, status_code=None): + super().__init__(message) + self.status_code = status_code + + +class UnauthorizedError(APIError): + pass + + +def _api_url(endpoint): + return API_HOST + "/" + endpoint.lstrip("/") + + +def _request(method, endpoint, session=None, **kwargs): + req = requests if session is None else session + return req.request(method, _api_url(endpoint), **kwargs) + + def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params): headers = {} if access_token is not None: @@ -49,9 +103,7 @@ def api_get(endpoint, method='GET', timeout=None, access_token=None, session=Non headers['User-Agent'] = "openpilot-" + get_version() - # TODO: add session to Api - req = requests if session is None else session - return req.request(method, API_HOST + "/" + endpoint, timeout=timeout, headers=headers, params=params) + return _request(method, endpoint, session=session, timeout=timeout, headers=headers, params=params) def get_key_pair() -> tuple[str, str, str] | tuple[None, None, None]: diff --git a/common/tests/test_api.py b/common/tests/test_api.py new file mode 100644 index 00000000000000..752d1820cf16ec --- /dev/null +++ b/common/tests/test_api.py @@ -0,0 +1,78 @@ +import pytest + +from openpilot.common import api as api_module +from openpilot.common.api import APIError, CommaApi, UnauthorizedError, api_get + + +class FakeResponse: + def __init__(self, payload, status_code=200): + self.payload = payload + self.status_code = status_code + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def json(self): + return self.payload + + +class FakeSession: + def __init__(self, response): + self.headers = {} + self.response = response + self.requests = [] + self.mounts = [] + + def mount(self, prefix, adapter): + self.mounts.append((prefix, adapter)) + + def request(self, method, url, **kwargs): + self.requests.append((method, url, kwargs)) + return self.response + + +def test_api_get_returns_raw_response(monkeypatch): + monkeypatch.setattr(api_module, "API_HOST", "https://api.example") + monkeypatch.setattr(api_module, "get_version", lambda: "test") + response = object() + session = FakeSession(response) + + assert api_get("/v1/me", access_token="token", session=session, timeout=5, foo="bar") is response + assert session.requests == [( + "GET", + "https://api.example/v1/me", + {"timeout": 5, "headers": {"Authorization": "JWT token", "User-Agent": "openpilot-test"}, "params": {"foo": "bar"}}, + )] + + +def test_comma_api_returns_json(monkeypatch): + monkeypatch.setattr(api_module, "API_HOST", "https://api.example") + session = FakeSession(FakeResponse({"ok": True})) + + assert CommaApi("token", session=session).get("/v1/me", timeout=5) == {"ok": True} + assert session.headers == {"User-Agent": "OpenpilotTools", "Authorization": "JWT token"} + assert session.requests == [("GET", "https://api.example/v1/me", {"timeout": 5})] + + +def test_comma_api_raises_api_error(): + session = FakeSession(FakeResponse({"error": "missing", "description": "not found"}, status_code=404)) + + with pytest.raises(APIError) as exc: + CommaApi(session=session).get("v1/route/missing") + + assert str(exc.value) == "404:not found" + assert exc.value.status_code == 404 + + +@pytest.mark.parametrize("status_code", [401, 403]) +def test_comma_api_raises_unauthorized(status_code): + session = FakeSession(FakeResponse({"error": "auth"}, status_code=status_code)) + + with pytest.raises(UnauthorizedError) as exc: + CommaApi(session=session).get("v1/me") + + assert str(exc.value) == "Unauthorized. Authenticate with tools/lib/auth.py" + assert exc.value.status_code == status_code diff --git a/tools/lib/api.py b/tools/lib/api.py deleted file mode 100644 index f84fe758695fc3..00000000000000 --- a/tools/lib/api.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -import requests -from requests.adapters import HTTPAdapter, Retry -API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com') - -# TODO: this should be merged into common.api - -class CommaApi: - def __init__(self, token=None): - self.session = requests.Session() - self.session.headers['User-agent'] = 'OpenpilotTools' - if token: - self.session.headers['Authorization'] = 'JWT ' + token - - retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) - self.session.mount('https://', HTTPAdapter(max_retries=retries)) - - def request(self, method, endpoint, **kwargs): - with self.session.request(method, API_HOST + '/' + endpoint, **kwargs) as resp: - resp_json = resp.json() - if isinstance(resp_json, dict) and resp_json.get('error'): - if resp.status_code in [401, 403]: - raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py') - - e = APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error']))) - e.status_code = resp.status_code - raise e - return resp_json - - def get(self, endpoint, **kwargs): - return self.request('GET', endpoint, **kwargs) - - def post(self, endpoint, **kwargs): - return self.request('POST', endpoint, **kwargs) - -class APIError(Exception): - pass - -class UnauthorizedError(Exception): - pass diff --git a/tools/lib/auth.py b/tools/lib/auth.py index 5988397d0a94e2..8b36e64837a89a 100755 --- a/tools/lib/auth.py +++ b/tools/lib/auth.py @@ -29,7 +29,7 @@ from typing import Any from urllib.parse import parse_qs, urlencode -from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError +from openpilot.common.api import APIError, CommaApi, UnauthorizedError from openpilot.tools.lib.auth_config import set_token, get_token PORT = 3000 diff --git a/tools/lib/bootlog.py b/tools/lib/bootlog.py index 208ddc19c228fb..32dcb96784d866 100644 --- a/tools/lib/bootlog.py +++ b/tools/lib/bootlog.py @@ -2,7 +2,7 @@ import re from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import CommaApi +from openpilot.common.api import CommaApi from openpilot.tools.lib.helpers import RE diff --git a/tools/lib/file_downloader.py b/tools/lib/file_downloader.py index 68061b201eea80..86fa590bd6fd95 100755 --- a/tools/lib/file_downloader.py +++ b/tools/lib/file_downloader.py @@ -18,7 +18,7 @@ import shutil from openpilot.common.hardware.hw import Paths -from openpilot.tools.lib.api import CommaApi, UnauthorizedError, APIError +from openpilot.common.api import APIError, CommaApi, UnauthorizedError from openpilot.tools.lib.auth_config import get_token from openpilot.tools.lib.url_file import URLFile diff --git a/tools/lib/route.py b/tools/lib/route.py index 98334a06c86a6d..4710987633a020 100644 --- a/tools/lib/route.py +++ b/tools/lib/route.py @@ -7,7 +7,7 @@ from itertools import chain from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import APIError, CommaApi +from openpilot.common.api import APIError, CommaApi from openpilot.tools.lib.helpers import RE diff --git a/tools/scripts/ssh.py b/tools/scripts/ssh.py index 33eac4081d1869..b59b3dcdd573d9 100755 --- a/tools/scripts/ssh.py +++ b/tools/scripts/ssh.py @@ -6,7 +6,7 @@ from openpilot.common.basedir import BASEDIR from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import CommaApi +from openpilot.common.api import CommaApi if __name__ == "__main__": parser = argparse.ArgumentParser(description="A helper for connecting to devices over the comma prime SSH proxy.\ From 9201244327a1a72c97c433acf344a07888aec26b Mon Sep 17 00:00:00 2001 From: Adeeb Shihadeh Date: Sun, 21 Jun 2026 14:00:21 -0700 Subject: [PATCH 2/3] cleanup --- common/api.py | 106 ++++++++++++++++++++++------------- common/tests/test_api.py | 78 -------------------------- tools/lib/auth.py | 6 +- tools/lib/bootlog.py | 6 +- tools/lib/file_downloader.py | 12 ++-- tools/lib/route.py | 14 ++--- tools/scripts/ssh.py | 4 +- 7 files changed, 89 insertions(+), 137 deletions(-) delete mode 100644 common/tests/test_api.py diff --git a/common/api.py b/common/api.py index 8fc7ed29904dbc..a4c980fbc6a4d5 100644 --- a/common/api.py +++ b/common/api.py @@ -17,10 +17,21 @@ class Api: - def __init__(self, dongle_id, session=None): + def __init__(self, dongle_id=None, token=None, session=None, user_agent=None, retry=False): self.dongle_id = dongle_id + self.token = token + self.user_agent = user_agent or default_user_agent() self.session = session - self.jwt_algorithm, self.private_key, _ = get_key_pair() + self.jwt_algorithm, self.private_key, _ = get_key_pair() if dongle_id is not None else (None, None, None) + + if retry and session is None: + self.session = requests.Session() + retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) + self.session.mount('https://', HTTPAdapter(max_retries=retries)) + + @classmethod + def from_user_token(cls, token=None, session=None): + return cls(token=token, session=session, user_agent=TOOLS_USER_AGENT, retry=True) def get(self, *args, **kwargs): return self.request('GET', *args, **kwargs) @@ -30,10 +41,31 @@ def post(self, *args, **kwargs): def request(self, method, endpoint, timeout=None, access_token=None, session=None, **params): session = self.session if session is None else session - return api_get(endpoint, method=method, timeout=timeout, access_token=access_token, - session=session, **params) + access_token = self.token if access_token is None else access_token + return request_api(method, endpoint, session=session, timeout=timeout, + access_token=access_token, user_agent=self.user_agent, params=params) + + def get_json(self, *args, **kwargs): + return self.request_json('GET', *args, **kwargs) + + def post_json(self, *args, **kwargs): + return self.request_json('POST', *args, **kwargs) + + def request_json(self, method, endpoint, timeout=None, access_token=None, session=None, params=None, **kwargs): + session = self.session if session is None else session + access_token = self.token if access_token is None else access_token + with request_api(method, endpoint, session=session, timeout=timeout, access_token=access_token, + user_agent=self.user_agent, params=params, **kwargs) as resp: + return parse_api_response(resp) def get_token(self, payload_extra=None, expiry_hours=1): + if self.dongle_id is None: + raise ValueError("dongle_id is required to generate an API token") + if self.jwt_algorithm is None or self.private_key is None: + self.jwt_algorithm, self.private_key, _ = get_key_pair() + if self.jwt_algorithm is None or self.private_key is None: + raise ValueError("private key is required to generate an API token") + now = datetime.now(UTC).replace(tzinfo=None) payload = { 'identity': self.dongle_id, @@ -49,42 +81,16 @@ def get_token(self, payload_extra=None, expiry_hours=1): return token -class CommaApi: - def __init__(self, token=None, session=None): - self.session = session or requests.Session() - self.session.headers['User-Agent'] = TOOLS_USER_AGENT - if token: - self.session.headers['Authorization'] = 'JWT ' + token - - if session is None: - retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) - self.session.mount('https://', HTTPAdapter(max_retries=retries)) - - def request(self, method, endpoint, **kwargs): - with _request(method, endpoint, session=self.session, **kwargs) as resp: - resp_json = resp.json() - if isinstance(resp_json, dict) and resp_json.get('error'): - if resp.status_code in [401, 403]: - raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py', resp.status_code) - - raise APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])), resp.status_code) - return resp_json - - def get(self, endpoint, **kwargs): - return self.request('GET', endpoint, **kwargs) - - def post(self, endpoint, **kwargs): - return self.request('POST', endpoint, **kwargs) - - class APIError(Exception): def __init__(self, message, status_code=None): super().__init__(message) self.status_code = status_code -class UnauthorizedError(APIError): - pass +class UnauthorizedError(Exception): + def __init__(self, message, status_code=None): + super().__init__(message) + self.status_code = status_code def _api_url(endpoint): @@ -96,14 +102,38 @@ def _request(method, endpoint, session=None, **kwargs): return req.request(method, _api_url(endpoint), **kwargs) -def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params): - headers = {} +def default_user_agent(): + return "openpilot-" + get_version() + + +def api_headers(access_token=None, user_agent=None): + headers = {'User-Agent': user_agent or default_user_agent()} if access_token is not None: headers['Authorization'] = "JWT " + access_token + return headers + + +def request_api(method, endpoint, session=None, timeout=None, access_token=None, user_agent=None, params=None, headers=None, **kwargs): + request_headers = api_headers(access_token, user_agent) + if headers is not None: + request_headers.update(headers) + return _request(method, endpoint, session=session, timeout=timeout, + headers=request_headers, params=params, **kwargs) + - headers['User-Agent'] = "openpilot-" + get_version() +def parse_api_response(resp): + resp_json = resp.json() + if isinstance(resp_json, dict) and resp_json.get('error'): + if resp.status_code in [401, 403]: + raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py', resp.status_code) - return _request(method, endpoint, session=session, timeout=timeout, headers=headers, params=params) + raise APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])), resp.status_code) + return resp_json + + +def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params): + return request_api(method, endpoint, session=session, timeout=timeout, + access_token=access_token, params=params) def get_key_pair() -> tuple[str, str, str] | tuple[None, None, None]: diff --git a/common/tests/test_api.py b/common/tests/test_api.py deleted file mode 100644 index 752d1820cf16ec..00000000000000 --- a/common/tests/test_api.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest - -from openpilot.common import api as api_module -from openpilot.common.api import APIError, CommaApi, UnauthorizedError, api_get - - -class FakeResponse: - def __init__(self, payload, status_code=200): - self.payload = payload - self.status_code = status_code - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def json(self): - return self.payload - - -class FakeSession: - def __init__(self, response): - self.headers = {} - self.response = response - self.requests = [] - self.mounts = [] - - def mount(self, prefix, adapter): - self.mounts.append((prefix, adapter)) - - def request(self, method, url, **kwargs): - self.requests.append((method, url, kwargs)) - return self.response - - -def test_api_get_returns_raw_response(monkeypatch): - monkeypatch.setattr(api_module, "API_HOST", "https://api.example") - monkeypatch.setattr(api_module, "get_version", lambda: "test") - response = object() - session = FakeSession(response) - - assert api_get("/v1/me", access_token="token", session=session, timeout=5, foo="bar") is response - assert session.requests == [( - "GET", - "https://api.example/v1/me", - {"timeout": 5, "headers": {"Authorization": "JWT token", "User-Agent": "openpilot-test"}, "params": {"foo": "bar"}}, - )] - - -def test_comma_api_returns_json(monkeypatch): - monkeypatch.setattr(api_module, "API_HOST", "https://api.example") - session = FakeSession(FakeResponse({"ok": True})) - - assert CommaApi("token", session=session).get("/v1/me", timeout=5) == {"ok": True} - assert session.headers == {"User-Agent": "OpenpilotTools", "Authorization": "JWT token"} - assert session.requests == [("GET", "https://api.example/v1/me", {"timeout": 5})] - - -def test_comma_api_raises_api_error(): - session = FakeSession(FakeResponse({"error": "missing", "description": "not found"}, status_code=404)) - - with pytest.raises(APIError) as exc: - CommaApi(session=session).get("v1/route/missing") - - assert str(exc.value) == "404:not found" - assert exc.value.status_code == 404 - - -@pytest.mark.parametrize("status_code", [401, 403]) -def test_comma_api_raises_unauthorized(status_code): - session = FakeSession(FakeResponse({"error": "auth"}, status_code=status_code)) - - with pytest.raises(UnauthorizedError) as exc: - CommaApi(session=session).get("v1/me") - - assert str(exc.value) == "Unauthorized. Authenticate with tools/lib/auth.py" - assert exc.value.status_code == status_code diff --git a/tools/lib/auth.py b/tools/lib/auth.py index 8b36e64837a89a..24bb9eb8f2738c 100755 --- a/tools/lib/auth.py +++ b/tools/lib/auth.py @@ -29,7 +29,7 @@ from typing import Any from urllib.parse import parse_qs, urlencode -from openpilot.common.api import APIError, CommaApi, UnauthorizedError +from openpilot.common.api import APIError, Api, UnauthorizedError from openpilot.tools.lib.auth_config import set_token, get_token PORT = 3000 @@ -115,7 +115,7 @@ def login(method): break try: - auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']}) + auth_resp = Api.from_user_token().post_json('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']}) set_token(auth_resp['access_token']) except APIError as e: print(f'Authentication Error: {e}', file=sys.stderr) @@ -137,7 +137,7 @@ def login(method): login(args.method) try: - me = CommaApi(token=get_token()).get('/v1/me') + me = Api.from_user_token(get_token()).get_json('/v1/me') print("Authenticated!") pprint.pprint(me) except UnauthorizedError: diff --git a/tools/lib/bootlog.py b/tools/lib/bootlog.py index 32dcb96784d866..5987b62915456e 100644 --- a/tools/lib/bootlog.py +++ b/tools/lib/bootlog.py @@ -2,7 +2,7 @@ import re from openpilot.tools.lib.auth_config import get_token -from openpilot.common.api import CommaApi +from openpilot.common.api import Api from openpilot.tools.lib.helpers import RE @@ -52,6 +52,6 @@ def get_bootlog_from_id(bootlog_id: str) -> Bootlog | None: return None def get_bootlogs(dongle_id: str) -> list[Bootlog]: - api = CommaApi(get_token()) - r = api.get(f'v1/devices/{dongle_id}/bootlogs') + api = Api.from_user_token(get_token()) + r = api.get_json(f'v1/devices/{dongle_id}/bootlogs') return [Bootlog(b) for b in r] diff --git a/tools/lib/file_downloader.py b/tools/lib/file_downloader.py index 86fa590bd6fd95..27774602b84133 100755 --- a/tools/lib/file_downloader.py +++ b/tools/lib/file_downloader.py @@ -18,7 +18,7 @@ import shutil from openpilot.common.hardware.hw import Paths -from openpilot.common.api import APIError, CommaApi, UnauthorizedError +from openpilot.common.api import APIError, Api, UnauthorizedError from openpilot.tools.lib.auth_config import get_token from openpilot.tools.lib.url_file import URLFile @@ -26,7 +26,7 @@ def api_call(func): """Run an API call, outputting JSON result or error to stdout.""" try: - result = func(CommaApi(get_token())) + result = func(Api.from_user_token(get_token())) json.dump(result, sys.stdout) except UnauthorizedError: json.dump({"error": "unauthorized"}, sys.stdout) @@ -45,7 +45,7 @@ def cache_file_path(url): def cmd_route_files(args): - api_call(lambda api: api.get(f"v1/route/{args.route}/files")) + api_call(lambda api: api.get_json(f"v1/route/{args.route}/files")) def cmd_download(args): @@ -110,19 +110,19 @@ def cmd_download(args): def cmd_devices(args): - api_call(lambda api: api.get("v1/me/devices/")) + api_call(lambda api: api.get_json("v1/me/devices/")) def cmd_device_routes(args): def fetch(api): if args.preserved: - return api.get(f"v1/devices/{args.dongle_id}/routes/preserved") + return api.get_json(f"v1/devices/{args.dongle_id}/routes/preserved") params = {} if args.start is not None: params['start'] = args.start if args.end is not None: params['end'] = args.end - return api.get(f"v1/devices/{args.dongle_id}/routes_segments", params=params) + return api.get_json(f"v1/devices/{args.dongle_id}/routes_segments", params=params) api_call(fetch) diff --git a/tools/lib/route.py b/tools/lib/route.py index 4710987633a020..bd84971b0ceb6e 100644 --- a/tools/lib/route.py +++ b/tools/lib/route.py @@ -7,7 +7,7 @@ from itertools import chain from openpilot.tools.lib.auth_config import get_token -from openpilot.common.api import APIError, CommaApi +from openpilot.common.api import APIError, Api from openpilot.tools.lib.helpers import RE @@ -65,8 +65,8 @@ def qcamera_paths(self): # TODO: refactor this, it's super repetitive def _get_segments_remote(self): - api = CommaApi(get_token()) - route_files = api.get('v1/route/' + self.name.canonical_name + '/files') + api = Api.from_user_token(get_token()) + route_files = api.get_json('v1/route/' + self.name.canonical_name + '/files') self.files = list(chain.from_iterable(route_files.values())) segments = {} @@ -182,8 +182,8 @@ def name(self): @staticmethod @cache def _get_route_metadata(route_name: str): - api = CommaApi(get_token()) - return api.get(f'v1/route/{route_name}') + api = Api.from_user_token(get_token()) + return api.get_json(f'v1/route/{route_name}') @property def url(self): @@ -308,8 +308,8 @@ def from_azure_prefix(prefix): @cache def get_max_seg_number_cached(sr: 'SegmentRange') -> int: try: - api = CommaApi(get_token()) - max_seg_number = api.get("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"] + api = Api.from_user_token(get_token()) + max_seg_number = api.get_json("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"] assert isinstance(max_seg_number, int) return max_seg_number except Exception as e: diff --git a/tools/scripts/ssh.py b/tools/scripts/ssh.py index b59b3dcdd573d9..b9c919e9c9e4cc 100755 --- a/tools/scripts/ssh.py +++ b/tools/scripts/ssh.py @@ -6,7 +6,7 @@ from openpilot.common.basedir import BASEDIR from openpilot.tools.lib.auth_config import get_token -from openpilot.common.api import CommaApi +from openpilot.common.api import Api if __name__ == "__main__": parser = argparse.ArgumentParser(description="A helper for connecting to devices over the comma prime SSH proxy.\ @@ -18,7 +18,7 @@ parser.add_argument("--debug", help="enable debug output", action="store_true") args = parser.parse_args() - r = CommaApi(get_token()).get("v1/me/devices") + r = Api.from_user_token(get_token()).get_json("v1/me/devices") devices = {x['dongle_id']: x['alias'] for x in r} if not re.match("[0-9a-zA-Z]{16}", args.device): From d8bd40d694f006446208b51a8987473ba5205095 Mon Sep 17 00:00:00 2001 From: Adeeb Shihadeh Date: Sun, 21 Jun 2026 14:52:54 -0700 Subject: [PATCH 3/3] cleanup --- common/api.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/common/api.py b/common/api.py index a4c980fbc6a4d5..814979f46e5021 100644 --- a/common/api.py +++ b/common/api.py @@ -40,10 +40,7 @@ def post(self, *args, **kwargs): return self.request('POST', *args, **kwargs) def request(self, method, endpoint, timeout=None, access_token=None, session=None, **params): - session = self.session if session is None else session - access_token = self.token if access_token is None else access_token - return request_api(method, endpoint, session=session, timeout=timeout, - access_token=access_token, user_agent=self.user_agent, params=params) + return self._send(method, endpoint, timeout=timeout, access_token=access_token, session=session, params=params) def get_json(self, *args, **kwargs): return self.request_json('GET', *args, **kwargs) @@ -51,19 +48,20 @@ def get_json(self, *args, **kwargs): def post_json(self, *args, **kwargs): return self.request_json('POST', *args, **kwargs) - def request_json(self, method, endpoint, timeout=None, access_token=None, session=None, params=None, **kwargs): + def request_json(self, method, endpoint, **kwargs): + with self._send(method, endpoint, **kwargs) as resp: + return parse_api_response(resp) + + def _send(self, method, endpoint, timeout=None, access_token=None, session=None, params=None, **kwargs): session = self.session if session is None else session access_token = self.token if access_token is None else access_token - with request_api(method, endpoint, session=session, timeout=timeout, access_token=access_token, - user_agent=self.user_agent, params=params, **kwargs) as resp: - return parse_api_response(resp) + return request_api(method, endpoint, session=session, timeout=timeout, + access_token=access_token, user_agent=self.user_agent, params=params, **kwargs) def get_token(self, payload_extra=None, expiry_hours=1): if self.dongle_id is None: raise ValueError("dongle_id is required to generate an API token") - if self.jwt_algorithm is None or self.private_key is None: - self.jwt_algorithm, self.private_key, _ = get_key_pair() - if self.jwt_algorithm is None or self.private_key is None: + if self.private_key is None: raise ValueError("private key is required to generate an API token") now = datetime.now(UTC).replace(tzinfo=None) @@ -81,25 +79,22 @@ def get_token(self, payload_extra=None, expiry_hours=1): return token -class APIError(Exception): +class _APIException(Exception): def __init__(self, message, status_code=None): super().__init__(message) self.status_code = status_code -class UnauthorizedError(Exception): - def __init__(self, message, status_code=None): - super().__init__(message) - self.status_code = status_code +class APIError(_APIException): + pass -def _api_url(endpoint): - return API_HOST + "/" + endpoint.lstrip("/") +class UnauthorizedError(_APIException): + pass -def _request(method, endpoint, session=None, **kwargs): - req = requests if session is None else session - return req.request(method, _api_url(endpoint), **kwargs) +def _api_url(endpoint): + return API_HOST + "/" + endpoint.lstrip("/") def default_user_agent(): @@ -117,8 +112,8 @@ def request_api(method, endpoint, session=None, timeout=None, access_token=None, request_headers = api_headers(access_token, user_agent) if headers is not None: request_headers.update(headers) - return _request(method, endpoint, session=session, timeout=timeout, - headers=request_headers, params=params, **kwargs) + req = requests if session is None else session + return req.request(method, _api_url(endpoint), timeout=timeout, headers=request_headers, params=params, **kwargs) def parse_api_response(resp):