Skip to content
5 changes: 5 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def load_routing_configuration(file: Path = _config_file) -> TomlTable:
return typing.cast("TomlTable", _load_configuration(file)["routing"])


@functools.cache
def load_run_configuration(file: Path = _config_file) -> TomlTable:
return typing.cast("TomlTable", _load_configuration(file).get("run", {}))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When resolving merge conflicts, please make this a DevelopmentConfiguration option for now.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure.



@functools.cache
def load_database_configuration(file: Path = _config_file) -> TomlTable:
configuration = _load_configuration(file)
Expand Down
3 changes: 3 additions & 0 deletions src/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ database="openml"
[routing]
minio_url="http://minio:9000/"
server_url="http://php-api:80/"

[run]
evaluation_engine_ids = [1]
2 changes: 1 addition & 1 deletion src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
SELECT *, uploadDate as upload_date
SELECT *, uploadDate as upload_date, fullName AS full_name
FROM implementation
WHERE id = :flow_id
""",
Expand Down
117 changes: 116 additions & 1 deletion src/database/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Row, text
from sqlalchemy import Row, bindparam, text
from sqlalchemy.ext.asyncio import AsyncConnection


Expand All @@ -22,6 +22,121 @@ async def exist(id_: int, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())


async def get(run_id: int, expdb: AsyncConnection) -> Row | None:
"""Fetch the core run row from the `run` table.

Returns the row if found, or None if no run with `run_id` exists.
The `error_message` column is NULL when the run completed without errors.
"""
row = await expdb.execute(
text(
"""
SELECT `rid`, `uploader`, `setup`, `task_id`, `error_message`
FROM `run`
WHERE `rid` = :run_id
""",
),
parameters={"run_id": run_id},
)
return row.one_or_none()


async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]:
"""Fetch all tags associated with a run from the `run_tag` table.

The `id` column in `run_tag` refers to the run ID
"""
rows = await expdb.execute(
text(
"""
SELECT `tag`
FROM `run_tag`
WHERE `id` = :run_id
""",
),
parameters={"run_id": run_id},
)
return [row.tag for row in rows.all()]


async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
"""Fetch the dataset(s) used as input for a run, with name and url.

Joins `input_data` with `dataset` to include the dataset name and ARFF URL.
"""
rows = await expdb.execute(
text(
"""
SELECT `id`.`data` AS `did`, `d`.`name`, `d`.`url`
FROM `input_data` `id`
JOIN `dataset` `d` ON `id`.`data` = `d`.`did`
WHERE `id`.`run` = :run_id
""",
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())


async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
"""Fetch output files attached to a run from the `runfile` table.

Typical entries include the description XML and predictions ARFF.
The `field` column holds the file label (e.g. "description", "predictions").
"""
rows = await expdb.execute(
text(
"""
SELECT `file_id`, `field`
FROM `runfile`
WHERE `source` = :run_id
""",
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())


async def get_evaluations(
run_id: int,
expdb: AsyncConnection,
*,
evaluation_engine_ids: list[int],
) -> list[Row]:
"""Fetch evaluation metric results for a run.

Joins `evaluation` with `math_function` to resolve the metric name
(the `evaluation` table stores only a `function_id`, not the name directly).

Filters by `evaluation_engine_id IN (...)`. The list is configurable
via `config.toml [run] evaluation_engine_ids`.
Dynamic named parameters are used for aiomysql compatibility.
"""
if not evaluation_engine_ids:
return []

query = text(
"""
SELECT `m`.`name`, `e`.`value`, `e`.`array_data`, NULL as `repeat`, NULL as `fold`
FROM `evaluation` `e`
JOIN `math_function` `m` ON `e`.`function_id` = `m`.`id`
WHERE `e`.`source` = :run_id
AND `e`.`evaluation_engine_id` IN :engine_ids
UNION ALL
SELECT `m`.`name`, `ef`.`value`, `ef`.`array_data`, `ef`.`repeat`, `ef`.`fold`
FROM `evaluation_fold` `ef`
JOIN `math_function` `m` ON `ef`.`function_id` = `m`.`id`
WHERE `ef`.`source` = :run_id
AND `ef`.`evaluation_engine_id` IN :engine_ids
""",
).bindparams(bindparam("engine_ids", expanding=True))
rows = await expdb.execute(
query,
parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids},
)
return cast("list[Row]", rows.all())


async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
Expand Down
43 changes: 43 additions & 0 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,49 @@ async def get_task_type(task_type_id: int, expdb: AsyncConnection) -> Row | None
return row.one_or_none()


async def get_task_type_name(task_id: int, expdb: AsyncConnection) -> str | None:
"""Fetch the human-readable task type name for the task associated with a run.

Joins `task` and `task_type` on `ttid` to resolve the name
(e.g. "Supervised Classification").
"""
row = await expdb.execute(
text(
"""
SELECT `tt`.`name`
FROM `task` `t`
JOIN `task_type` `tt` ON `t`.`ttid` = `tt`.`ttid`
WHERE `t`.`task_id` = :task_id
""",
),
parameters={"task_id": task_id},
)
result = row.one_or_none()
return result.name if result else None


async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> str | None:
"""Fetch the evaluation measure configured for a task, if any.

Queries `task_inputs` for the row where `input = 'evaluation_measures'`.
Returns None (not an empty string) when no such row exists, so callers
can treat a falsy result uniformly.
"""
row = await expdb.execute(
text(
"""
SELECT `value`
FROM `task_inputs`
WHERE `task_id` = :task_id
AND `input` = 'evaluation_measures'
""",
),
parameters={"task_id": task_id},
)
result = row.one_or_none()
return result.value if result else None


async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
text(
Expand Down
67 changes: 53 additions & 14 deletions src/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,45 @@ class UserGroup(IntEnum):
READ_ONLY = (3,)


async def get_user_id_for(*, api_key: APIKey, connection: AsyncConnection) -> int | None:
row = await connection.execute(
text(
"""
SELECT *
FROM users
WHERE session_hash = :api_key
""",
),
parameters={"api_key": api_key},
async def get_user(
*,
connection: AsyncConnection,
api_key: APIKey | None = None,
user_id: int | None = None,
) -> "User | None":
"""Fetch the full user by either api_key or user_id."""
if (api_key is None) == (user_id is None):
msg = "Exactly one of api_key or user_id must be provided."
raise ValueError(msg)

if api_key is not None:
query = """
SELECT id, first_name, last_name
FROM users
WHERE session_hash = :api_key
LIMIT 1
"""
else:
query = """
SELECT id, first_name, last_name
FROM users
WHERE id = :user_id
LIMIT 1
"""

result = await connection.execute(
text(query),
parameters={"api_key": api_key, "user_id": user_id},
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
user = row.one_or_none()
return user.id if user else None
row = result.one_or_none()
if row:
return User(
user_id=row.id,
first_name=row.first_name,
last_name=row.last_name,
_database=connection,
)
return None


async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> list[int]:
Expand All @@ -60,12 +86,25 @@ async def get_user_groups_for(*, user_id: int, connection: AsyncConnection) -> l
class User:
user_id: int
_database: AsyncConnection
first_name: str = ""
last_name: str = ""
_groups: list[UserGroup] | None = None

@property
def full_name(self) -> str:
"""Return the combined first and last name."""
return " ".join(part for part in [self.first_name, self.last_name] if part)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return " ".join(part for part in [self.first_name, self.last_name] if part)
return f"{self.first_name} {self.last_name}".strip()

If this setup was to work around None values (as opposed to emtpy strings), then I think it may be better to add a __post_init__ method which sets the names to an empty string if a None value was provided. Otherwise, the current type annotation of str is incorrect (should be str | None - I think the better fix is to make sure the value is always str).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noted.


@classmethod
async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None:
if (user_id := await get_user_id_for(api_key=api_key, connection=user_db)) is not None:
return cls(user_id, _database=user_db)
user = await get_user(api_key=api_key, connection=user_db)
if user is not None:
return cls(
user_id=user.user_id,
first_name=user.first_name,
last_name=user.last_name,
_database=user_db,
)
return None

async def get_groups(self) -> list[UserGroup]:
Expand Down
Loading
Loading