Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 88 additions & 11 deletions common/api.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,69 @@
import jwt
import os
from datetime import UTC, datetime, timedelta

import jwt
import requests
from datetime import datetime, timedelta, UTC
from requests.adapters import HTTPAdapter, Retry

from openpilot.common.hardware.hw import Paths
from openpilot.common.version import get_version

API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com')
TOOLS_USER_AGENT = 'OpenpilotTools'

# name: jwt signature algorithm
KEYS = {"id_rsa": "RS256",
"id_ecdsa": "ES256"}


class Api:
def __init__(self, dongle_id):
def __init__(self, dongle_id=None, token=None, session=None, user_agent=None, retry=False):
self.dongle_id = dongle_id
self.jwt_algorithm, self.private_key, _ = get_key_pair()
self.token = token
self.user_agent = user_agent or default_user_agent()
self.session = session
self.jwt_algorithm, self.private_key, _ = get_key_pair() if dongle_id is not None else (None, None, None)

if retry and session is None:
self.session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
self.session.mount('https://', HTTPAdapter(max_retries=retries))

@classmethod
def from_user_token(cls, token=None, session=None):
return cls(token=token, session=session, user_agent=TOOLS_USER_AGENT, retry=True)

def get(self, *args, **kwargs):
return self.request('GET', *args, **kwargs)

def post(self, *args, **kwargs):
return self.request('POST', *args, **kwargs)

def request(self, method, endpoint, timeout=None, access_token=None, **params):
return api_get(endpoint, method=method, timeout=timeout, access_token=access_token, **params)
def request(self, method, endpoint, timeout=None, access_token=None, session=None, **params):
return self._send(method, endpoint, timeout=timeout, access_token=access_token, session=session, params=params)

def get_json(self, *args, **kwargs):
return self.request_json('GET', *args, **kwargs)

def post_json(self, *args, **kwargs):
return self.request_json('POST', *args, **kwargs)

def request_json(self, method, endpoint, **kwargs):
with self._send(method, endpoint, **kwargs) as resp:
return parse_api_response(resp)

def _send(self, method, endpoint, timeout=None, access_token=None, session=None, params=None, **kwargs):
session = self.session if session is None else session
access_token = self.token if access_token is None else access_token
return request_api(method, endpoint, session=session, timeout=timeout,
access_token=access_token, user_agent=self.user_agent, params=params, **kwargs)

def get_token(self, payload_extra=None, expiry_hours=1):
if self.dongle_id is None:
raise ValueError("dongle_id is required to generate an API token")
if self.private_key is None:
raise ValueError("private key is required to generate an API token")

now = datetime.now(UTC).replace(tzinfo=None)
payload = {
'identity': self.dongle_id,
Expand All @@ -42,16 +79,56 @@ def get_token(self, payload_extra=None, expiry_hours=1):
return token


def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params):
headers = {}
class _APIException(Exception):
def __init__(self, message, status_code=None):
super().__init__(message)
self.status_code = status_code


class APIError(_APIException):
pass


class UnauthorizedError(_APIException):
pass


def _api_url(endpoint):
return API_HOST + "/" + endpoint.lstrip("/")


def default_user_agent():
return "openpilot-" + get_version()


def api_headers(access_token=None, user_agent=None):
headers = {'User-Agent': user_agent or default_user_agent()}
if access_token is not None:
headers['Authorization'] = "JWT " + access_token
return headers

headers['User-Agent'] = "openpilot-" + get_version()

# TODO: add session to Api
def request_api(method, endpoint, session=None, timeout=None, access_token=None, user_agent=None, params=None, headers=None, **kwargs):
request_headers = api_headers(access_token, user_agent)
if headers is not None:
request_headers.update(headers)
req = requests if session is None else session
return req.request(method, API_HOST + "/" + endpoint, timeout=timeout, headers=headers, params=params)
return req.request(method, _api_url(endpoint), timeout=timeout, headers=request_headers, params=params, **kwargs)


def parse_api_response(resp):
resp_json = resp.json()
if isinstance(resp_json, dict) and resp_json.get('error'):
if resp.status_code in [401, 403]:
raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py', resp.status_code)

raise APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])), resp.status_code)
return resp_json


def api_get(endpoint, method='GET', timeout=None, access_token=None, session=None, **params):
return request_api(method, endpoint, session=session, timeout=timeout,
access_token=access_token, params=params)


def get_key_pair() -> tuple[str, str, str] | tuple[None, None, None]:
Expand Down
40 changes: 0 additions & 40 deletions tools/lib/api.py

This file was deleted.

6 changes: 3 additions & 3 deletions tools/lib/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import Any
from urllib.parse import parse_qs, urlencode

from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError
from openpilot.common.api import APIError, Api, UnauthorizedError
from openpilot.tools.lib.auth_config import set_token, get_token

PORT = 3000
Expand Down Expand Up @@ -115,7 +115,7 @@ def login(method):
break

try:
auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']})
auth_resp = Api.from_user_token().post_json('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']})
set_token(auth_resp['access_token'])
except APIError as e:
print(f'Authentication Error: {e}', file=sys.stderr)
Expand All @@ -137,7 +137,7 @@ def login(method):
login(args.method)

try:
me = CommaApi(token=get_token()).get('/v1/me')
me = Api.from_user_token(get_token()).get_json('/v1/me')
print("Authenticated!")
pprint.pprint(me)
except UnauthorizedError:
Expand Down
6 changes: 3 additions & 3 deletions tools/lib/bootlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re

from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.common.api import Api
from openpilot.tools.lib.helpers import RE


Expand Down Expand Up @@ -52,6 +52,6 @@ def get_bootlog_from_id(bootlog_id: str) -> Bootlog | None:
return None

def get_bootlogs(dongle_id: str) -> list[Bootlog]:
api = CommaApi(get_token())
r = api.get(f'v1/devices/{dongle_id}/bootlogs')
api = Api.from_user_token(get_token())
r = api.get_json(f'v1/devices/{dongle_id}/bootlogs')
return [Bootlog(b) for b in r]
12 changes: 6 additions & 6 deletions tools/lib/file_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import shutil

from openpilot.common.hardware.hw import Paths
from openpilot.tools.lib.api import CommaApi, UnauthorizedError, APIError
from openpilot.common.api import APIError, Api, UnauthorizedError
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.url_file import URLFile


def api_call(func):
"""Run an API call, outputting JSON result or error to stdout."""
try:
result = func(CommaApi(get_token()))
result = func(Api.from_user_token(get_token()))
json.dump(result, sys.stdout)
except UnauthorizedError:
json.dump({"error": "unauthorized"}, sys.stdout)
Expand All @@ -45,7 +45,7 @@ def cache_file_path(url):


def cmd_route_files(args):
api_call(lambda api: api.get(f"v1/route/{args.route}/files"))
api_call(lambda api: api.get_json(f"v1/route/{args.route}/files"))


def cmd_download(args):
Expand Down Expand Up @@ -110,19 +110,19 @@ def cmd_download(args):


def cmd_devices(args):
api_call(lambda api: api.get("v1/me/devices/"))
api_call(lambda api: api.get_json("v1/me/devices/"))


def cmd_device_routes(args):
def fetch(api):
if args.preserved:
return api.get(f"v1/devices/{args.dongle_id}/routes/preserved")
return api.get_json(f"v1/devices/{args.dongle_id}/routes/preserved")
params = {}
if args.start is not None:
params['start'] = args.start
if args.end is not None:
params['end'] = args.end
return api.get(f"v1/devices/{args.dongle_id}/routes_segments", params=params)
return api.get_json(f"v1/devices/{args.dongle_id}/routes_segments", params=params)
api_call(fetch)


Expand Down
14 changes: 7 additions & 7 deletions tools/lib/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from itertools import chain

from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import APIError, CommaApi
from openpilot.common.api import APIError, Api
from openpilot.tools.lib.helpers import RE


Expand Down Expand Up @@ -65,8 +65,8 @@ def qcamera_paths(self):

# TODO: refactor this, it's super repetitive
def _get_segments_remote(self):
api = CommaApi(get_token())
route_files = api.get('v1/route/' + self.name.canonical_name + '/files')
api = Api.from_user_token(get_token())
route_files = api.get_json('v1/route/' + self.name.canonical_name + '/files')
self.files = list(chain.from_iterable(route_files.values()))

segments = {}
Expand Down Expand Up @@ -182,8 +182,8 @@ def name(self):
@staticmethod
@cache
def _get_route_metadata(route_name: str):
api = CommaApi(get_token())
return api.get(f'v1/route/{route_name}')
api = Api.from_user_token(get_token())
return api.get_json(f'v1/route/{route_name}')

@property
def url(self):
Expand Down Expand Up @@ -308,8 +308,8 @@ def from_azure_prefix(prefix):
@cache
def get_max_seg_number_cached(sr: 'SegmentRange') -> int:
try:
api = CommaApi(get_token())
max_seg_number = api.get("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"]
api = Api.from_user_token(get_token())
max_seg_number = api.get_json("/v1/route/" + sr.route_name.replace("/", "|"))["maxqlog"]
assert isinstance(max_seg_number, int)
return max_seg_number
except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions tools/scripts/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from openpilot.common.basedir import BASEDIR
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.common.api import Api

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A helper for connecting to devices over the comma prime SSH proxy.\
Expand All @@ -18,7 +18,7 @@
parser.add_argument("--debug", help="enable debug output", action="store_true")
args = parser.parse_args()

r = CommaApi(get_token()).get("v1/me/devices")
r = Api.from_user_token(get_token()).get_json("v1/me/devices")
devices = {x['dongle_id']: x['alias'] for x in r}

if not re.match("[0-9a-zA-Z]{16}", args.device):
Expand Down
Loading