diff --git a/common/api.py b/common/api.py index c97f56c4b5cb71..814979f46e5021 100644 --- a/common/api.py +++ b/common/api.py @@ -1,11 +1,15 @@ -import jwt import os +from datetime import UTC, datetime, timedelta + +import jwt import requests -from datetime import datetime, timedelta, UTC +from requests.adapters import HTTPAdapter, Retry + from openpilot.common.hardware.hw import Paths from openpilot.common.version import get_version API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com') +TOOLS_USER_AGENT = 'OpenpilotTools' # name: jwt signature algorithm KEYS = {"id_rsa": "RS256", @@ -13,9 +17,21 @@ class Api: - def __init__(self, dongle_id): + def __init__(self, dongle_id=None, token=None, session=None, user_agent=None, retry=False): self.dongle_id = dongle_id - self.jwt_algorithm, self.private_key, _ = get_key_pair() + self.token = token + self.user_agent = user_agent or default_user_agent() + self.session = session + self.jwt_algorithm, self.private_key, _ = get_key_pair() if dongle_id is not None else (None, None, None) + + if retry and session is None: + self.session = requests.Session() + retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) + self.session.mount('https://', HTTPAdapter(max_retries=retries)) + + @classmethod + def from_user_token(cls, token=None, session=None): + return cls(token=token, session=session, user_agent=TOOLS_USER_AGENT, retry=True) def get(self, *args, **kwargs): return self.request('GET', *args, **kwargs) @@ -23,10 +39,31 @@ def get(self, *args, **kwargs): def post(self, *args, **kwargs): return self.request('POST', *args, **kwargs) - def request(self, method, endpoint, timeout=None, access_token=None, **params): - return api_get(endpoint, method=method, timeout=timeout, access_token=access_token, **params) + def request(self, method, endpoint, timeout=None, access_token=None, session=None, **params): + return self._send(method, endpoint, timeout=timeout, access_token=access_token, session=session, params=params) + + def get_json(self, *args, **kwargs): + return self.request_json('GET', *args, **kwargs) + + def post_json(self, *args, **kwargs): + return self.request_json('POST', *args, **kwargs) + + def request_json(self, method, endpoint, **kwargs): + with self._send(method, endpoint, **kwargs) as resp: + return parse_api_response(resp) + + def _send(self, method, endpoint, timeout=None, access_token=None, session=None, params=None, **kwargs): + session = self.session if session is None else session + access_token = self.token if access_token is None else access_token + return request_api(method, endpoint, session=session, timeout=timeout, + access_token=access_token, user_agent=self.user_agent, params=params, **kwargs) def get_token(self, payload_extra=None, expiry_hours=1): + if self.dongle_id is None: + raise ValueError("dongle_id is required to generate an API token") + if self.private_key is None: + raise ValueError("private key is required to generate an API token") + now = datetime.now(UTC).replace(tzinfo=None) payload = { 'identity': self.dongle_id, @@ -42,16 +79,56 @@ def get_token(self, payload_extra=None, expiry_hours=1): return token -def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params): - headers = {} +class _APIException(Exception): + def __init__(self, message, status_code=None): + super().__init__(message) + self.status_code = status_code + + +class APIError(_APIException): + pass + + +class UnauthorizedError(_APIException): + pass + + +def _api_url(endpoint): + return API_HOST + "/" + endpoint.lstrip("/") + + +def default_user_agent(): + return "openpilot-" + get_version() + + +def api_headers(access_token=None, user_agent=None): + headers = {'User-Agent': user_agent or default_user_agent()} if access_token is not None: headers['Authorization'] = "JWT " + access_token + return headers - headers['User-Agent'] = "openpilot-" + get_version() - # TODO: add session to Api +def request_api(method, endpoint, session=None, timeout=None, access_token=None, user_agent=None, params=None, headers=None, **kwargs): + request_headers = api_headers(access_token, user_agent) + if headers is not None: + request_headers.update(headers) req = requests if session is None else session - return req.request(method, API_HOST + "/" + endpoint, timeout=timeout, headers=headers, params=params) + return req.request(method, _api_url(endpoint), timeout=timeout, headers=request_headers, params=params, **kwargs) + + +def parse_api_response(resp): + resp_json = resp.json() + if isinstance(resp_json, dict) and resp_json.get('error'): + if resp.status_code in [401, 403]: + raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py', resp.status_code) + + raise APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])), resp.status_code) + return resp_json + + +def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params): + return request_api(method, endpoint, session=session, timeout=timeout, + access_token=access_token, params=params) def get_key_pair() -> tuple[str, str, str] | tuple[None, None, None]: diff --git a/tools/lib/api.py b/tools/lib/api.py deleted file mode 100644 index f84fe758695fc3..00000000000000 --- a/tools/lib/api.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -import requests -from requests.adapters import HTTPAdapter, Retry -API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com') - -# TODO: this should be merged into common.api - -class CommaApi: - def __init__(self, token=None): - self.session = requests.Session() - self.session.headers['User-agent'] = 'OpenpilotTools' - if token: - self.session.headers['Authorization'] = 'JWT ' + token - - retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) - self.session.mount('https://', HTTPAdapter(max_retries=retries)) - - def request(self, method, endpoint, **kwargs): - with self.session.request(method, API_HOST + '/' + endpoint, **kwargs) as resp: - resp_json = resp.json() - if isinstance(resp_json, dict) and resp_json.get('error'): - if resp.status_code in [401, 403]: - raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py') - - e = APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error']))) - e.status_code = resp.status_code - raise e - return resp_json - - def get(self, endpoint, **kwargs): - return self.request('GET', endpoint, **kwargs) - - def post(self, endpoint, **kwargs): - return self.request('POST', endpoint, **kwargs) - -class APIError(Exception): - pass - -class UnauthorizedError(Exception): - pass diff --git a/tools/lib/auth.py b/tools/lib/auth.py index 5988397d0a94e2..24bb9eb8f2738c 100755 --- a/tools/lib/auth.py +++ b/tools/lib/auth.py @@ -29,7 +29,7 @@ from typing import Any from urllib.parse import parse_qs, urlencode -from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError +from openpilot.common.api import APIError, Api, UnauthorizedError from openpilot.tools.lib.auth_config import set_token, get_token PORT = 3000 @@ -115,7 +115,7 @@ def login(method): break try: - auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']}) + auth_resp = Api.from_user_token().post_json('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']}) set_token(auth_resp['access_token']) except APIError as e: print(f'Authentication Error: {e}', file=sys.stderr) @@ -137,7 +137,7 @@ def login(method): login(args.method) try: - me = CommaApi(token=get_token()).get('/v1/me') + me = Api.from_user_token(get_token()).get_json('/v1/me') print("Authenticated!") pprint.pprint(me) except UnauthorizedError: diff --git a/tools/lib/bootlog.py b/tools/lib/bootlog.py index 208ddc19c228fb..5987b62915456e 100644 --- a/tools/lib/bootlog.py +++ b/tools/lib/bootlog.py @@ -2,7 +2,7 @@ import re from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import CommaApi +from openpilot.common.api import Api from openpilot.tools.lib.helpers import RE @@ -52,6 +52,6 @@ def get_bootlog_from_id(bootlog_id: str) -> Bootlog | None: return None def get_bootlogs(dongle_id: str) -> list[Bootlog]: - api = CommaApi(get_token()) - r = api.get(f'v1/devices/{dongle_id}/bootlogs') + api = Api.from_user_token(get_token()) + r = api.get_json(f'v1/devices/{dongle_id}/bootlogs') return [Bootlog(b) for b in r] diff --git a/tools/lib/file_downloader.py b/tools/lib/file_downloader.py index 68061b201eea80..27774602b84133 100755 --- a/tools/lib/file_downloader.py +++ b/tools/lib/file_downloader.py @@ -18,7 +18,7 @@ import shutil from openpilot.common.hardware.hw import Paths -from openpilot.tools.lib.api import CommaApi, UnauthorizedError, APIError +from openpilot.common.api import APIError, Api, UnauthorizedError from openpilot.tools.lib.auth_config import get_token from openpilot.tools.lib.url_file import URLFile @@ -26,7 +26,7 @@ def api_call(func): """Run an API call, outputting JSON result or error to stdout.""" try: - result = func(CommaApi(get_token())) + result = func(Api.from_user_token(get_token())) json.dump(result, sys.stdout) except UnauthorizedError: json.dump({"error": "unauthorized"}, sys.stdout) @@ -45,7 +45,7 @@ def cache_file_path(url): def cmd_route_files(args): - api_call(lambda api: api.get(f"v1/route/{args.route}/files")) + api_call(lambda api: api.get_json(f"v1/route/{args.route}/files")) def cmd_download(args): @@ -110,19 +110,19 @@ def cmd_download(args): def cmd_devices(args): - api_call(lambda api: api.get("v1/me/devices/")) + api_call(lambda api: api.get_json("v1/me/devices/")) def cmd_device_routes(args): def fetch(api): if args.preserved: - return api.get(f"v1/devices/{args.dongle_id}/routes/preserved") + return api.get_json(f"v1/devices/{args.dongle_id}/routes/preserved") params = {} if args.start is not None: params['start'] = args.start if args.end is not None: params['end'] = args.end - return api.get(f"v1/devices/{args.dongle_id}/routes_segments", params=params) + return api.get_json(f"v1/devices/{args.dongle_id}/routes_segments", params=params) api_call(fetch) diff --git a/tools/lib/route.py b/tools/lib/route.py index 98334a06c86a6d..bd84971b0ceb6e 100644 --- a/tools/lib/route.py +++ b/tools/lib/route.py @@ -7,7 +7,7 @@ from itertools import chain from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import APIError, CommaApi +from openpilot.common.api import APIError, Api from openpilot.tools.lib.helpers import RE @@ -65,8 +65,8 @@ def qcamera_paths(self): # TODO: refactor this, it's super repetitive def _get_segments_remote(self): - api = CommaApi(get_token()) - route_files = api.get('v1/route/' + self.name.canonical_name + '/files') + api = Api.from_user_token(get_token()) + route_files = api.get_json('v1/route/' + self.name.canonical_name + '/files') self.files = list(chain.from_iterable(route_files.values())) segments = {} @@ -182,8 +182,8 @@ def name(self): @staticmethod @cache def _get_route_metadata(route_name: str): - api = CommaApi(get_token()) - return api.get(f'v1/route/{route_name}') + api = Api.from_user_token(get_token()) + return api.get_json(f'v1/route/{route_name}') @property def url(self): @@ -308,8 +308,8 @@ def from_azure_prefix(prefix): @cache def get_max_seg_number_cached(sr: 'SegmentRange') -> int: try: - api = CommaApi(get_token()) - max_seg_number = api.get("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"] + api = Api.from_user_token(get_token()) + max_seg_number = api.get_json("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"] assert isinstance(max_seg_number, int) return max_seg_number except Exception as e: diff --git a/tools/scripts/ssh.py b/tools/scripts/ssh.py index 33eac4081d1869..b9c919e9c9e4cc 100755 --- a/tools/scripts/ssh.py +++ b/tools/scripts/ssh.py @@ -6,7 +6,7 @@ from openpilot.common.basedir import BASEDIR from openpilot.tools.lib.auth_config import get_token -from openpilot.tools.lib.api import CommaApi +from openpilot.common.api import Api if __name__ == "__main__": parser = argparse.ArgumentParser(description="A helper for connecting to devices over the comma prime SSH proxy.\ @@ -18,7 +18,7 @@ parser.add_argument("--debug", help="enable debug output", action="store_true") args = parser.parse_args() - r = CommaApi(get_token()).get("v1/me/devices") + r = Api.from_user_token(get_token()).get_json("v1/me/devices") devices = {x['dongle_id']: x['alias'] for x in r} if not re.match("[0-9a-zA-Z]{16}", args.device):