diff --git a/.gitignore b/.gitignore index 46689285..54cf4888 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,6 @@ phenex/test/phenotypes/artifacts/ phenex/test/cohort/artifacts/ phenex/test/serialization/artifacts r-package/man/ + +# MCP codelist data (may contain sensitive medical codes) +mcp/codelists/ diff --git a/mcp/.env.example b/mcp/.env.example new file mode 100644 index 00000000..cb762ca8 --- /dev/null +++ b/mcp/.env.example @@ -0,0 +1,32 @@ +# Snowflake Configuration +SNOWFLAKE_USER=your-snowflake-username +SNOWFLAKE_PASSWORD=your-snowflake-password +SNOWFLAKE_ACCOUNT=your-snowflake-account +SNOWFLAKE_WAREHOUSE=your-snowflake-warehouse +SNOWFLAKE_ROLE=your-snowflake-role + +# Optional: default database/schema for Snowflake explorer +# SNOWFLAKE_DATABASE=your-database +# SNOWFLAKE_SCHEMA=your-schema + +# PhenEx codelist directory — path to folder containing codelist CSV/Excel files +# PHENEX_CODELIST_DIR=/path/to/codelists +# Column name overrides (defaults: code, codelist, code_type) +# PHENEX_CODELIST_CODE_COLUMN=code +# PHENEX_CODELIST_NAME_COLUMN=codelist +# PHENEX_CODELIST_CODE_TYPE_COLUMN=code_type + +# PhenEx cohort execution — source data location +# SNOWFLAKE_SOURCE_DATABASE=OPTUM_CLAIMS +# SNOWFLAKE_SOURCE_SCHEMA=OMOP_CDM + +# PhenEx cohort execution — destination database (schema is auto-generated) +# SNOWFLAKE_DEST_DATABASE=ANALYTICS + +# Server transport (stdio, streamable-http, sse) +# MCP_TRANSPORT=stdio +# MCP_HOST=0.0.0.0 +# MCP_PORT=9000 + +# Logging level +# LOG_LEVEL=INFO diff --git a/mcp/Dockerfile b/mcp/Dockerfile new file mode 100644 index 00000000..5526aa60 --- /dev/null +++ b/mcp/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install Node.js for MCP Inspector +RUN apt-get update && apt-get install -y --no-install-recommends curl && \ + curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \ + apt-get install -y --no-install-recommends nodejs && \ + apt-get clean && rm -rf 
/var/lib/apt/lists/* + +# Install Python dependencies (includes phenex from PyPI) +COPY requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r /app/requirements.txt + +# Pre-install MCP Inspector +RUN npm install -g @modelcontextprotocol/inspector + +# Copy MCP server code +COPY . /app/mcp/ + +ENV MCP_TRANSPORT=stdio + +EXPOSE 6277 6274 + +CMD ["npx", "@modelcontextprotocol/inspector", "python", "/app/mcp/server.py"] diff --git a/mcp/README.md b/mcp/README.md new file mode 100644 index 00000000..8320dbe9 --- /dev/null +++ b/mcp/README.md @@ -0,0 +1,202 @@ +# PhenEx Cohort Builder MCP Server + +An MCP (Model Context Protocol) server that exposes PhenEx cohort building and Snowflake data exploration functionality to AI assistants. + +## What's Included + +| Tool | Description | +| ---------------------------------- | -------------------------------------------------------------- | +| `phenex_list_available_phenotypes` | List all PhenEx phenotype classes with descriptions | +| `phenex_get_phenotype_spec` | Get detailed spec/docs for a phenotype class (or `"Codelist"`) | +| `phenex_list_available_codelists` | List codelists from configured CSV/Excel directory | +| `phenex_get_codelist` | Get the full contents of a specific codelist by name | +| `phenex_validate_phenotype` | Validate a single phenotype definition compiles correctly | +| `phenex_validate_cohort` | Validate a cohort definition JSON without executing | +| `phenex_execute_cohort` | Validate and optionally execute a cohort against Snowflake | +| `snowflake_list_databases` | List/search Snowflake databases | +| `snowflake_list_schemas` | List schemas inside a database | +| `snowflake_list_tables` | List tables inside a schema | +| `snowflake_get_table_columns` | Get column definitions for a table | +| `snowflake_preview_table` | Preview sample rows | +| `snowflake_select_rows` | Query rows with optional WHERE filter | +| `snowflake_get_distinct_values` | Get distinct values from a column | 
+| `snowflake_count_rows` | Count rows with optional filter | + +## Prerequisites + +- Python 3.12+ +- Snowflake credentials (for data exploration and cohort execution) +- Codelists (optional but highly recommended) + +## Setup + +```bash +git clone https://github.com/Bayer-Group/PhenEx phenex-mcp +cd phenex-mcp/mcp +``` + +### 1. Create a virtual environment and install dependencies + +```bash +python -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +pip install -r requirements.txt +``` + +### 2. Configure Snowflake credentials + +```bash +cp .env.example .env +# Edit .env with your configuration details +``` + +### 3. Test the server + +**All commands must be run from the mcp/ directory** + +```bash +# Interactive testing with MCP Inspector (run from the mcp/ directory) +npx @modelcontextprotocol/inspector bash start.sh +``` + +This opens a web UI where you can browse tools, call them with sample inputs, and see responses — no AI client needed. + +## Running + +**Always run from the mcp/ directory.** + +### stdio (default — for Claude Desktop, VS Code, etc.) + +```bash +bash start.sh +``` + +### HTTP (for remote / multi-client access) + +```bash +bash start_http.sh +# or with custom port: +MCP_PORT=8080 bash start_http.sh +``` + +### Docker + +Build and run the MCP server as a Docker container. The build context is the **mcp/** directory. + +```bash +# Build (from mcp/) +cd mcp +docker build -t phenex-mcp . 
+ +# Run — Inspector UI on http://localhost:6274 +docker run --rm -p 6274:6274 -p 6277:6277 \ + --env-file .env \ + phenex-mcp +``` + +To mount a local codelists directory into the container: + +```bash +docker run --rm -p 6274:6274 -p 6277:6277 \ + --env-file .env \ + -v /path/to/codelists:/codelists \ + -e PHENEX_CODELIST_DIR=/codelists \ + phenex-mcp +``` + +## Client Configuration + +### LLM Instructions + +The file `mcp/llm-instructions.md` contains detailed guidance for AI assistants on how to use the PhenEx tools effectively. Copy its contents into your client's system prompt or custom instructions: + +- **Claude Desktop** — paste into your Project's custom instructions +- **VS Code (Copilot)** — copy to `.github/copilot-instructions.md` in your workspace +- **Cursor** — paste into `.cursor/rules` + +### Claude Desktop + +Edit `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows): + +```json +{ + "mcpServers": { + "phenex": { + "command": "bash", + "args": ["/absolute/path/to/PhenEx/mcp/start.sh"], + "env": {} + } + } +} +``` + +### VS Code (Copilot) + +Add to your `.vscode/mcp.json` or workspace settings: + +```json +{ + "servers": { + "phenex": { + "command": "bash", + "args": ["${workspaceFolder}/mcp/start.sh"], + "env": {} + } + } +} +``` + +### Cursor + +Add to `.cursor/mcp.json`: + +```json +{ + "mcpServers": { + "phenex": { + "command": "bash", + "args": ["/absolute/path/to/PhenEx/mcp/start.sh"], + "env": {} + } + } +} +``` + +## File Structure + +``` +mcp/ +├── server.py # Main FastMCP server with tool registrations +├── phenotype_registry.py # PhenEx phenotype class registry +├── codelist_store.py # Load and serve codelists from CSV/Excel files +├── snowflake_explorer.py # Snowflake data warehouse utilities +├── cohort_tools.py # Cohort validation, translation, execution +├── llm-instructions.md # Instructions for LLMs using this server +├── mcp.json # Example MCP client 
config +├── start.sh # Launch script (stdio) +├── start_http.sh # Launch script (HTTP) +├── .env.example # Environment variable template +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +## Environment Variables + +| Variable | Required | Description | +| ---------------------------------- | ------------------ | ----------------------------------------------- | +| `SNOWFLAKE_USER` | Yes (for SF tools) | Snowflake username | +| `SNOWFLAKE_PASSWORD` | Yes (for SF tools) | Snowflake password | +| `SNOWFLAKE_ACCOUNT` | Yes (for SF tools) | Snowflake account identifier | +| `SNOWFLAKE_WAREHOUSE` | Yes (for SF tools) | Snowflake warehouse name | +| `SNOWFLAKE_ROLE` | Yes (for SF tools) | Snowflake role | +| `SNOWFLAKE_SOURCE_DATABASE` | For execution | Source database for cohort execution | +| `SNOWFLAKE_SOURCE_SCHEMA` | For execution | Source schema for cohort execution | +| `SNOWFLAKE_DEST_DATABASE` | For execution | Destination database (schema is auto-generated) | +| `PHENEX_CODELIST_DIR` | For codelist tools | Directory containing codelist CSV/Excel files | +| `PHENEX_CODELIST_CODE_COLUMN` | No | Code column name (default `code`) | +| `PHENEX_CODELIST_NAME_COLUMN` | No | Codelist name column (default `codelist`) | +| `PHENEX_CODELIST_CODE_TYPE_COLUMN` | No | Code type column (default `code_type`) | +| `MCP_TRANSPORT` | No | `stdio` (default), `streamable-http`, or `sse` | +| `MCP_HOST` | No | HTTP host (default `0.0.0.0`) | +| `MCP_PORT` | No | HTTP port (default `6277`) | +| `LOG_LEVEL` | No | Logging level (default `INFO`) | diff --git a/mcp/code_generator.py b/mcp/code_generator.py new file mode 100644 index 00000000..9d3e0ca1 --- /dev/null +++ b/mcp/code_generator.py @@ -0,0 +1,365 @@ +""" +Generate idiomatic Python code from any PhenEx definition dict. + +The JSON dict is first compiled via from_dict() to validate it, then +to_dict() is called on the compiled object to get the canonical form. 
+That canonical dict is walked recursively to emit Python constructor calls. + +Works for any PhenEx class: Cohort, individual phenotypes, filters, +codelists, etc. +""" + +from typing import Dict, Any, List, Set + +# Maps class_name → module path for imports. +_IMPORT_MAP = { + # Phenotypes + "CodelistPhenotype": "phenex.phenotypes", + "AgePhenotype": "phenex.phenotypes", + "SexPhenotype": "phenex.phenotypes", + "MeasurementPhenotype": "phenex.phenotypes", + "MeasurementChangePhenotype": "phenex.phenotypes", + "EventCountPhenotype": "phenex.phenotypes", + "TimeRangePhenotype": "phenex.phenotypes", + "TimeRangeCountPhenotype": "phenex.phenotypes", + "TimeRangeDayCountPhenotype": "phenex.phenotypes", + "TimeRangeDaysToNextRange": "phenex.phenotypes", + "DeathPhenotype": "phenex.phenotypes", + "CategoricalPhenotype": "phenex.phenotypes", + "BinPhenotype": "phenex.phenotypes", + "ScorePhenotype": "phenex.phenotypes", + "ArithmeticPhenotype": "phenex.phenotypes", + "LogicPhenotype": "phenex.phenotypes", + "WithinSameEncounterPhenotype": "phenex.phenotypes", + # Filters + "RelativeTimeRangeFilter": "phenex.filters", + "ValueFilter": "phenex.filters", + "DateFilter": "phenex.filters", + "CategoricalFilter": "phenex.filters", + "CodelistFilter": "phenex.filters", + "TimeRangeFilter": "phenex.filters", + "Before": "phenex.filters", + "BeforeOrOn": "phenex.filters", + "After": "phenex.filters", + "AfterOrOn": "phenex.filters", + "Date": "phenex.filters", + "GreaterThan": "phenex.filters", + "GreaterThanOrEqualTo": "phenex.filters", + "LessThan": "phenex.filters", + "LessThanOrEqualTo": "phenex.filters", + "EqualTo": "phenex.filters", + "Value": "phenex.filters", + # Codelists + "Codelist": "phenex.codelists", + # Core + "Cohort": "phenex.core", + "Subcohort": "phenex.core", + "Study": "phenex.core", +} + +# Classes whose child phenotypes should be extracted as named variables +# for readability (rather than inlined in the constructor). 
+_COHORT_LIKE = {"Cohort", "Subcohort", "Study"} + +# Slots on Cohort/Subcohort that hold phenotype objects or lists of them. +_PHENOTYPE_SLOTS = ( + "entry_criterion", + "inclusions", + "exclusions", + "characteristics", + "outcomes", +) + + +# --------------------------------------------------------------------------- +# Low-level emitters +# --------------------------------------------------------------------------- + + +def _is_phenex_object(value: Any) -> bool: + """True if value is a dict with a class_name key (i.e. a serialized PhenEx object).""" + return isinstance(value, dict) and "class_name" in value + + +def _emit_value(value: Any, imports: Set[str], indent: int) -> str: + """Convert a value to its Python source representation.""" + pad = " " * indent + + if _is_phenex_object(value): + return _emit_constructor(value, imports, indent) + + if isinstance(value, dict) and "__datetime__" in value: + imports.add("from datetime import date") + return f"date.fromisoformat({value['__datetime__']!r})" + + if isinstance(value, dict): + if not value: + return "{}" + items = [] + for k, v in value.items(): + key_repr = repr(k) + val_repr = _emit_value(v, imports, indent + 1) + items.append(f"{key_repr}: {val_repr}") + one_line = "{" + ", ".join(items) + "}" + if len(one_line) < 80: + return one_line + inner_pad = " " * (indent + 1) + lines = [f"{inner_pad}{item}," for item in items] + return "{\n" + "\n".join(lines) + "\n" + pad + "}" + + if isinstance(value, list): + if not value: + return "[]" + rendered = [_emit_value(v, imports, indent + 1) for v in value] + if all(isinstance(v, (str, int, float, bool, type(None))) for v in value): + one_line = "[" + ", ".join(rendered) + "]" + if len(one_line) < 80: + return one_line + inner_pad = " " * (indent + 1) + lines = [f"{inner_pad}{r}," for r in rendered] + return "[\n" + "\n".join(lines) + "\n" + pad + "]" + + return repr(value) + + +def _emit_constructor(obj_dict: Dict[str, Any], imports: Set[str], indent: int) -> 
str: + """Emit ClassName(param=value, ...), skipping None and empty-list params.""" + class_name = obj_dict["class_name"] + pad = " " * indent + inner_pad = " " * (indent + 1) + + module = _IMPORT_MAP.get(class_name) + if module: + imports.add(f"from {module} import {class_name}") + + params = [] + for key, value in obj_dict.items(): + if key == "class_name": + continue + if value is None: + continue + if isinstance(value, list) and len(value) == 0: + continue + val_str = _emit_value(value, imports, indent + 1) + params.append((key, val_str)) + + if not params: + return f"{class_name}()" + + if len(params) == 1: + k, v = params[0] + one_line = f"{class_name}({k}={v})" + if len(one_line) < 80: + return one_line + + lines = [f"{inner_pad}{k}={v}," for k, v in params] + return f"{class_name}(\n" + "\n".join(lines) + "\n" + pad + ")" + + +def _build_variable_name(obj_dict: Dict[str, Any]) -> str: + """Derive a Python variable name from a PhenEx object dict.""" + name = obj_dict.get("name") + if name: + clean = name.lower().replace(" ", "_").replace("-", "_") + clean = "".join(c for c in clean if c.isalnum() or c == "_") + return clean + class_name = obj_dict.get("class_name", "obj") + return class_name[0].lower() + class_name[1:] + + +# --------------------------------------------------------------------------- +# High-level generators +# --------------------------------------------------------------------------- + + +def _generate_cohort_python(canonical: Dict[str, Any]) -> str: + """ + Generate Python for a Cohort-like object, extracting child phenotypes as + named variables for readability. 
+ """ + imports: Set[str] = set() + lines: List[str] = [] + used_vars: Set[str] = set() + + class_name = canonical.get("class_name", "Cohort") + + def _unique_var(base: str) -> str: + """Ensure no duplicate variable names.""" + candidate = base + counter = 2 + while candidate in used_vars: + candidate = f"{base}_{counter}" + counter += 1 + used_vars.add(candidate) + return candidate + + # Extract phenotype slots as named variables + slot_var_names: Dict[str, Any] = {} # slot -> var_name or [var_names] + + for slot in _PHENOTYPE_SLOTS: + value = canonical.get(slot) + if value is None: + continue + + if _is_phenex_object(value): + var = _unique_var(_build_variable_name(value)) + lines.append(f"{var} = {_emit_constructor(value, imports, 0)}") + lines.append("") + slot_var_names[slot] = var + + elif isinstance(value, list) and value: + var_names = [] + for item in value: + if _is_phenex_object(item): + var = _unique_var(_build_variable_name(item)) + lines.append(f"{var} = {_emit_constructor(item, imports, 0)}") + lines.append("") + var_names.append(var) + if var_names: + slot_var_names[slot] = var_names + + # Build Cohort constructor + module = _IMPORT_MAP.get(class_name) + if module: + imports.add(f"from {module} import {class_name}") + + cohort_params: List[str] = [] + cohort_name = canonical.get("name") + if cohort_name: + cohort_params.append(f" name={cohort_name!r},") + + for slot in _PHENOTYPE_SLOTS: + if slot not in slot_var_names: + continue + ref = slot_var_names[slot] + if isinstance(ref, str): + cohort_params.append(f" {slot}={ref},") + elif isinstance(ref, list): + list_str = "[" + ", ".join(ref) + "]" + cohort_params.append(f" {slot}={list_str},") + + desc = canonical.get("description") + if desc: + cohort_params.append(f" description={desc!r},") + + cohort_var = _unique_var(_build_variable_name(canonical)) + lines.append(f"{cohort_var} = {class_name}(") + lines.extend(cohort_params) + lines.append(")") + + import_block = "\n".join(sorted(imports)) + 
return import_block + "\n\n\n" + "\n".join(lines) + "\n" + + +def _generate_simple_python(canonical: Dict[str, Any]) -> str: + """Generate Python for any non-Cohort PhenEx object (single assignment).""" + imports: Set[str] = set() + var = _build_variable_name(canonical) + constructor = _emit_constructor(canonical, imports, 0) + body = f"{var} = {constructor}\n" + import_block = "\n".join(sorted(imports)) + return import_block + "\n\n\n" + body + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def generate_python(definition: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate idiomatic Python from any PhenEx definition dict. + + Accepts any PhenEx expression: Cohort, phenotype, filter, codelist, etc. + The dict is first compiled with from_dict() to validate correctness, then + to_dict() is called on the compiled object to get the canonical form which + is walked to emit Python. + + Returns a dict with 'success', 'code' (the Python string), and 'error'. + """ + # --- Input validation --- + if not isinstance(definition, dict): + return { + "success": False, + "error": f"Expected a dictionary, got {type(definition).__name__}.", + } + + class_name = definition.get("class_name") or definition.get("type") + if not class_name: + return { + "success": False, + "error": ( + "Definition must have a 'class_name' (or 'type') field. " + "Call phenex_list_classes() to see valid class names." + ), + } + + if class_name not in _IMPORT_MAP: + return { + "success": False, + "error": ( + f"Unknown class '{class_name}'. " + f"Call phenex_list_classes() to see valid class names." + ), + } + + try: + from phenex.util.serialization.from_dict import from_dict + from phenex.util.serialization.to_dict import to_dict + except ImportError: + return { + "success": False, + "error": "PhenEx library not available. 
Install with: pip install phenex", + } + + from cohort_tools import ( + _prepare_cohort_for_compilation, + translate_phenotype_to_native, + ) + + # --- Prepare (type->class_name, codelist resolution, filter wrapping) --- + try: + if class_name in _COHORT_LIKE: + prepared = _prepare_cohort_for_compilation(definition) + elif class_name in _IMPORT_MAP and _IMPORT_MAP[class_name] in ( + "phenex.phenotypes", + ): + # Only phenotypes need the full translation (codelist wrapping, etc.) + prepared = translate_phenotype_to_native(definition) + else: + # Filters, codelists, etc. — just normalize type->class_name + prepared = definition.copy() + if "type" in prepared and "class_name" not in prepared: + prepared["class_name"] = prepared.pop("type") + except Exception as e: + return { + "success": False, + "error": f"Failed to prepare definition: {type(e).__name__}: {e}", + } + + # --- Compile to validate --- + try: + compiled = from_dict(prepared) + except Exception as e: + return { + "success": False, + "error": ( + f"Definition failed to compile: {type(e).__name__}: {e}. " + f"Fix errors first (use phenex_validate_phenotype or " + f"phenex_validate_cohort), then try again." + ), + } + + # --- Serialize to canonical dict --- + canonical = to_dict(compiled) + + # --- Generate Python --- + if canonical.get("class_name") in _COHORT_LIKE: + code = _generate_cohort_python(canonical) + else: + code = _generate_simple_python(canonical) + + return { + "success": True, + "code": code, + } diff --git a/mcp/codelist_store.py b/mcp/codelist_store.py new file mode 100644 index 00000000..eeb0f1b0 --- /dev/null +++ b/mcp/codelist_store.py @@ -0,0 +1,230 @@ +""" +Codelist store: load and serve codelists from a directory of CSV/Excel files. + +Expects one or more CSV/Excel files in the configured directory. Each file +must contain columns for code, codelist name, and code type (column names +are configurable via environment variables). 
+ +The store is loaded lazily on first access and cached for the process lifetime. +""" + +import os +from pathlib import Path +from typing import Dict, List, Any, Optional + +from phenex.codelists.factory import LocalFileCodelistFactory +from phenex.codelists import Codelist + + +# --------------------------------------------------------------------------- +# Module-level cache +# --------------------------------------------------------------------------- +_factories: Optional[List[LocalFileCodelistFactory]] = None + + +def _get_codelist_dir() -> Path: + """Return the configured codelist directory, raising if not set or missing.""" + codelist_dir = os.getenv("PHENEX_CODELIST_DIR") + if not codelist_dir: + raise ValueError( + "PHENEX_CODELIST_DIR environment variable is not set. " + "Set it to a directory containing codelist CSV/Excel files." + ) + p = Path(codelist_dir) + if not p.is_dir(): + raise ValueError( + f"PHENEX_CODELIST_DIR '{codelist_dir}' is not a directory or does not exist." 
+ ) + return p + + +def _get_column_config() -> dict: + """Return column name overrides from env vars (with sensible defaults).""" + return { + "name_code_column": os.getenv("PHENEX_CODELIST_CODE_COLUMN", "code"), + "name_codelist_column": os.getenv("PHENEX_CODELIST_NAME_COLUMN", "codelist"), + "name_code_type_column": os.getenv( + "PHENEX_CODELIST_CODE_TYPE_COLUMN", "code_type" + ), + } + + +def _load_factories() -> List[LocalFileCodelistFactory]: + """Scan the codelist directory and create a factory per file.""" + global _factories + if _factories is not None: + return _factories + + codelist_dir = _get_codelist_dir() + col_cfg = _get_column_config() + + extensions = {".csv", ".xlsx"} + files = sorted( + f + for f in codelist_dir.iterdir() + if f.suffix.lower() in extensions and not f.name.startswith(".") + ) + + if not files: + raise ValueError(f"No CSV or Excel files found in '{codelist_dir}'.") + + factories = [] + for f in files: + factories.append( + LocalFileCodelistFactory( + path=str(f), + **col_cfg, + ) + ) + + _factories = factories + return _factories + + +def _build_index() -> Dict[str, LocalFileCodelistFactory]: + """Build a {codelist_name: factory} index across all files.""" + factories = _load_factories() + index: Dict[str, LocalFileCodelistFactory] = {} + for factory in factories: + for name in factory.get_codelists(): + index[name] = factory # last-one-wins if duplicates across files + return index + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +MAX_SAMPLE_CODES = 10 +MAX_RESULTS = 25 + + +def _summarize_codelist(name: str, factory: LocalFileCodelistFactory) -> Dict[str, Any]: + """Build a summary dict for a single codelist.""" + codelist: Codelist = factory.get_codelist(name) + cl_dict = codelist.codelist + + code_types = [ct for ct in cl_dict.keys() if ct is not None] or [None] + total_codes = sum(len(codes) for 
codes in cl_dict.values()) + + sample = [] + for ct, codes in cl_dict.items(): + for code in codes: + if len(sample) >= MAX_SAMPLE_CODES: + break + sample.append({"code": str(code), "code_type": ct}) + if len(sample) >= MAX_SAMPLE_CODES: + break + + return { + "name": name, + "code_types": code_types, + "total_codes": total_codes, + "sample_codes": sample, + } + + +def find_codelists( + name_pattern: Optional[str] = None, + code_type_pattern: Optional[str] = None, +) -> Dict[str, Any]: + """ + Search codelists by name and/or code type using regex patterns. + + Returns dict with 'codelists' (list of summaries), 'count', and 'truncated'. + """ + import re + + index = _build_index() + + # Compile patterns (case-insensitive) + name_re = re.compile(name_pattern, re.IGNORECASE) if name_pattern else None + code_type_re = ( + re.compile(code_type_pattern, re.IGNORECASE) if code_type_pattern else None + ) + + matched = [] + for name in sorted(index.keys()): + # Filter by name + if name_re and not name_re.search(name): + continue + + # Filter by code type (need to load the codelist to check) + if code_type_re: + factory = index[name] + codelist: Codelist = factory.get_codelist(name) + code_types = [ct for ct in codelist.codelist.keys() if ct is not None] + if not any(code_type_re.search(ct) for ct in code_types): + continue + + matched.append(name) + + total_matched = len(matched) + truncated = total_matched > MAX_RESULTS + matched = matched[:MAX_RESULTS] + + summaries = [_summarize_codelist(n, index[n]) for n in matched] + + return { + "codelists": summaries, + "count": total_matched, + "returned": len(summaries), + "truncated": truncated, + } + + +def list_available_codelists() -> Dict[str, Any]: + """ + List all codelists found in the configured directory. + + Returns dict with 'codelists' (list of summaries) and 'count'. + Each summary includes name, code_types, total_codes, and a sample of codes. 
+ """ + index = _build_index() + summaries = [ + _summarize_codelist(name, factory) for name, factory in sorted(index.items()) + ] + return {"codelists": summaries, "count": len(summaries)} + + +def get_codelist(name: str) -> Dict[str, Any]: + """ + Return the full contents of a single codelist by name. + + Returns dict with name, code_types, total_codes, and the full codelist dict. + """ + index = _build_index() + + if name not in index: + available = sorted(index.keys()) + import difflib + + close = difflib.get_close_matches(name, available, n=3, cutoff=0.4) + hint = f" Did you mean: {', '.join(close)}?" if close else "" + return { + "error": ( + f"Codelist '{name}' not found.{hint} " + f"Call phenex_find_codelists() to search available codelist names." + ), + "available_codelists": available, + } + + factory = index[name] + codelist: Codelist = factory.get_codelist(name) + cl_dict = codelist.codelist + + code_types = [ct for ct in cl_dict.keys() if ct is not None] or [None] + total_codes = sum(len(codes) for codes in cl_dict.values()) + + # Convert keys to strings for JSON safety (None -> "null") + serialized = {} + for ct, codes in cl_dict.items(): + key = ct if ct is not None else "null" + serialized[key] = [str(c) for c in codes] + + return { + "name": name, + "code_types": code_types, + "total_codes": total_codes, + "codelist": serialized, + } diff --git a/mcp/cohort_tools.py b/mcp/cohort_tools.py new file mode 100644 index 00000000..f813cfcb --- /dev/null +++ b/mcp/cohort_tools.py @@ -0,0 +1,703 @@ +""" +Cohort validation, translation, and execution helpers for PhenEx. 
+""" + +import re +import os +import traceback +from typing import Dict, Any, List, Optional + +# Known phenotype class names for error guidance +KNOWN_PHENOTYPE_TYPES = [ + "CodelistPhenotype", + "AgePhenotype", + "SexPhenotype", + "MeasurementPhenotype", + "MeasurementChangePhenotype", + "EventCountPhenotype", + "TimeRangePhenotype", + "TimeRangeCountPhenotype", + "TimeRangeDayCountPhenotype", + "TimeRangeDaysToNextRange", + "DeathPhenotype", + "CategoricalPhenotype", + "BinPhenotype", + "ScorePhenotype", + "ArithmeticPhenotype", + "LogicPhenotype", + "WithinSameEncounterPhenotype", +] + +# Common required fields per phenotype type +REQUIRED_FIELDS_BY_TYPE = { + "CodelistPhenotype": ["domain", "codelist"], + "MeasurementPhenotype": ["domain", "codelist"], + "AgePhenotype": ["name"], + "SexPhenotype": ["name"], + "DeathPhenotype": ["name"], + "EventCountPhenotype": ["name", "input_phenotype"], + "ScorePhenotype": ["name", "expression"], + "ArithmeticPhenotype": ["name", "expression"], + "LogicPhenotype": ["name", "expression"], + "BinPhenotype": ["name", "input_phenotype"], +} + + +def _get_close_matches(name: str, candidates: List[str], n: int = 3) -> List[str]: + """Return candidate strings that are close matches to name (case-insensitive).""" + import difflib + + # Try case-insensitive matching + lower_map = {c.lower(): c for c in candidates} + matches = difflib.get_close_matches(name.lower(), lower_map.keys(), n=n, cutoff=0.5) + return [lower_map[m] for m in matches] + + +def _diagnose_compilation_error( + error: Exception, pheno_type: str, definition: Dict[str, Any] +) -> str: + """Produce an actionable remediation hint from a compilation exception.""" + err_str = str(error) + err_type = type(error).__name__ + + # Unknown class_name in from_dict + if err_type == "KeyError" and "class_name" not in definition: + return ( + f"The class '{err_str}' is not recognized by PhenEx. " + f"Call phenex_list_classes() to see valid class names." 
+ ) + + # Missing required parameter + if "required" in err_str.lower() or "missing" in err_str.lower(): + required = REQUIRED_FIELDS_BY_TYPE.get(pheno_type, []) + missing = [f for f in required if f not in definition] + if missing: + return ( + f"Missing required field(s): {missing}. " + f"Call phenex_inspect_class('{pheno_type}') to see all required parameters." + ) + return ( + f"{err_type}: {err_str}. " + f"Call phenex_inspect_class('{pheno_type}') to see all required parameters and their types." + ) + + # Type errors (e.g. passing string where int expected) + if err_type == "TypeError": + return ( + f"Type mismatch: {err_str}. " + f"Call phenex_inspect_class('{pheno_type}') to check the expected types for each parameter." + ) + + # Assertion errors (e.g. invalid domain, bad operator) + if err_type == "AssertionError": + return ( + f"Validation failed: {err_str}. " + f"Check that domain, operator, and enum-like fields use exact expected values. " + f"Call phenex_inspect_class('{pheno_type}') for valid options." + ) + + # Value errors + if err_type == "ValueError": + return ( + f"Invalid value: {err_str}. " + f"Call phenex_inspect_class('{pheno_type}') to review accepted values for each parameter." + ) + + # Generic fallback — still actionable + return ( + f"{err_type}: {err_str}. " + f"Call phenex_inspect_class('{pheno_type}') to review the full specification and examples." + ) + + +def _resolve_codelist_reference(name: str) -> Dict[str, Any]: + """ + Look up a codelist by name from the codelist store and return it as a + native from_dict()-compatible Codelist dict. + + Raises ValueError if the codelist is not found. 
def translate_phenotype_to_native(pheno: Dict[str, Any]) -> Dict[str, Any]:
    """
    Translate a single phenotype from simplified tool format to PhenEx native from_dict() format.

    The input is never modified: the top-level dict and every nested dict that
    gets rewritten (inline codelist, relative_time_range, date_range,
    categorical_filter) are copied first. This matters because the same
    definition is translated more than once (validation, then execution), and
    mutating it in place would make the second pass see different settings
    than the first.

    Simplified format (what the tool docs show):
        {"type": "CodelistPhenotype", "name": "af", "domain": "CONDITION_OCCURRENCE_SOURCE",
         "codelist": {"ICD10CM": ["I48.0", "I48.1"]}, "remove_punctuation": true, "return_date": "first"}

    Codelist by reference (name from codelist store):
        {"type": "CodelistPhenotype", "name": "af", "domain": "CONDITION_OCCURRENCE_SOURCE",
         "codelist": "atrial_fibrillation", "return_date": "first"}

    PhenEx native format (what from_dict() expects):
        {"class_name": "CodelistPhenotype", "name": "af", "domain": "CONDITION_OCCURRENCE_SOURCE",
         "codelist": {"class_name": "Codelist", "name": "af_codes", "codelist": {"ICD10CM": [...]},
                      "use_code_type": true, "remove_punctuation": true},
         "return_date": "first"}

    Args:
        pheno: Phenotype definition in simplified or already-native format.

    Returns:
        A new dict in native format.

    Raises:
        ValueError: If "codelist" is a string that cannot be resolved from the
            codelist store (raised by _resolve_codelist_reference).
    """
    native = pheno.copy()

    # Convert 'type' -> 'class_name'
    if "type" in native and "class_name" not in native:
        native["class_name"] = native.pop("type")

    codelist = native.get("codelist")

    if isinstance(codelist, str):
        # Codelist by reference (string = name in codelist store)
        resolved = _resolve_codelist_reference(codelist)
        # Allow phenotype-level overrides for remove_punctuation / use_code_type
        if "remove_punctuation" in native:
            resolved["remove_punctuation"] = native.pop("remove_punctuation")
        if "use_code_type" in native:
            resolved["use_code_type"] = native.pop("use_code_type")
        native["codelist"] = resolved

    elif isinstance(codelist, dict) and "class_name" not in codelist:
        # Flat codelist dict -> wrapped Codelist object. Work on a copy so the
        # pops below do not strip settings out of the caller's dict.
        codes = dict(codelist)
        # These belong on the Codelist, not the phenotype. They are removed
        # from BOTH locations; the codelist-level value wins when both exist.
        phenotype_remove_punctuation = native.pop("remove_punctuation", False)
        phenotype_use_code_type = native.pop("use_code_type", True)
        remove_punctuation = codes.pop(
            "remove_punctuation", phenotype_remove_punctuation
        )
        use_code_type = codes.pop("use_code_type", phenotype_use_code_type)
        native["codelist"] = {
            "class_name": "Codelist",
            "name": native.get("name", "codelist") + "_codes",
            "codelist": codes,
            "use_code_type": use_code_type,
            "remove_punctuation": remove_punctuation,
        }

    # Filter-like slots that need a class_name so from_dict() can dispatch them
    filter_defaults = (
        ("relative_time_range", "RelativeTimeRangeFilter"),
        ("date_range", "DateFilter"),
        ("categorical_filter", "CategoricalFilter"),
    )
    for key, default_class in filter_defaults:
        value = native.get(key)
        if not isinstance(value, dict):
            continue
        wrapped = dict(value)  # copy: never write into the caller's filter dict
        wrapped.setdefault("class_name", default_class)
        # The anchor of a relative time range is itself a phenotype
        if key == "relative_time_range" and isinstance(
            wrapped.get("anchor_phenotype"), dict
        ):
            wrapped["anchor_phenotype"] = translate_phenotype_to_native(
                wrapped["anchor_phenotype"]
            )
        native[key] = wrapped

    return native
def validate_phenotype(phenotype_definition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate a single phenotype definition by attempting to compile it via from_dict().

    Checks run in order and stop at the first failing stage:
      1. the definition is a dict,
      2. it declares a 'type' (or 'class_name'),
      3. all fields listed in REQUIRED_FIELDS_BY_TYPE for that type are present,
      4. the translated definition compiles with from_dict().

    An unrecognized type only produces a warning; compilation is still attempted.

    Args:
        phenotype_definition: Phenotype in simplified ('type') or native
            ('class_name') format.

    Returns:
        A dict with 'valid', 'errors' and 'warnings'. Once a type has been
        identified it also carries 'phenotype_name', 'phenotype_type' and
        'message', plus 'compiled_class' on success.
    """
    errors = []
    warnings = []

    if not isinstance(phenotype_definition, dict):
        return {
            "valid": False,
            "errors": ["phenotype_definition must be a dictionary"],
            "warnings": [],
        }

    # Must have a type/class_name
    pheno_type = phenotype_definition.get("type") or phenotype_definition.get(
        "class_name"
    )
    if not pheno_type:
        return {
            "valid": False,
            "errors": [
                "phenotype_definition must have a 'type' field. "
                "Set 'type' to a phenotype class name such as 'CodelistPhenotype'. "
                "Call phenex_list_classes() to see all valid type names."
            ],
            "warnings": [],
        }

    # Resolve the display name before any result that reports it is built;
    # the missing-fields result below needs it as much as the compile results.
    name = phenotype_definition.get("name", "unnamed")

    # Check for unknown phenotype type early
    if pheno_type not in KNOWN_PHENOTYPE_TYPES:
        close = _get_close_matches(pheno_type, KNOWN_PHENOTYPE_TYPES)
        hint = f" Did you mean: {', '.join(close)}?" if close else ""
        warnings.append(
            f"Unrecognized phenotype type '{pheno_type}'.{hint} "
            f"Call phenex_list_classes() to see all valid type names."
        )

    # Check for commonly missing required fields before compilation
    required = REQUIRED_FIELDS_BY_TYPE.get(pheno_type, [])
    missing = [f for f in required if f not in phenotype_definition]
    if missing:
        errors.append(
            f"Missing required field(s) for {pheno_type}: {missing}. "
            f"Call phenex_inspect_class('{pheno_type}') to see all required parameters."
        )
        return {
            "valid": False,
            "errors": errors,
            "warnings": warnings,
            "phenotype_name": name,
            "phenotype_type": pheno_type,
            "message": f"Phenotype '{name}' ({pheno_type}) is missing required fields: {missing}",
        }

    try:
        from phenex.util.serialization.from_dict import from_dict

        native = translate_phenotype_to_native(phenotype_definition.copy())
        compiled = from_dict(native)

        return {
            "valid": True,
            "errors": [],
            "warnings": warnings,
            "phenotype_name": name,
            "phenotype_type": pheno_type,
            "compiled_class": type(compiled).__name__,
            "message": f"Phenotype '{name}' ({pheno_type}) compiles successfully",
        }
    except Exception as e:
        # Any failure (import, translation, compilation) becomes an actionable
        # error message instead of propagating to the tool caller.
        remediation = _diagnose_compilation_error(e, pheno_type, phenotype_definition)
        return {
            "valid": False,
            "errors": [remediation],
            "warnings": warnings,
            "phenotype_name": name,
            "phenotype_type": pheno_type,
            "message": f"Phenotype '{name}' ({pheno_type}) failed to compile",
        }
" + f"Must start with a letter and contain only alphanumeric characters and underscores. " + f"Example: 'af_cohort_v1'. This name is used to create the output schema PHENEX_AI__{cohort_name.upper()}." + ) + + target_schema = f"PHENEX_AI__{cohort_name.upper()}" + + # 2. Validate cohort_definition structure + if not isinstance(cohort_definition, dict): + errors.append("cohort_definition must be a dictionary") + return { + "valid": False, + "errors": errors, + "warnings": warnings, + "cohort_name": cohort_name, + "target_schema": target_schema, + } + + # 3. Validate required fields + if "name" not in cohort_definition: + errors.append( + "Cohort definition missing required field: 'name'. " + "Call phenex_inspect_class('Cohort') to see the full specification." + ) + + if "entry_criterion" not in cohort_definition: + errors.append( + "Cohort definition missing required field: 'entry_criterion'. " + "This must be a phenotype dict that defines the index date. " + "Call phenex_inspect_class('Cohort') to see the full specification." + ) + elif not isinstance(cohort_definition["entry_criterion"], dict): + errors.append( + f"'entry_criterion' must be a phenotype dictionary, " + f"not {type(cohort_definition['entry_criterion']).__name__}." + ) + else: + entry = cohort_definition["entry_criterion"] + entry_type = entry.get("type") or entry.get("class_name") + if entry_type: + phenotypes_used.append(entry_type) + else: + errors.append( + "'entry_criterion' phenotype must have a 'type' (or 'class_name') field. " + "Call phenex_list_classes() to see valid type names." + ) + + # Warn about legacy flat-list format + if "phenotypes" in cohort_definition: + errors.append( + "Found 'phenotypes' key — this flat-list format is no longer supported. " + "Use the native Cohort structure with 'entry_criterion', 'inclusions', " + "'exclusions', 'characteristics', and 'outcomes' as separate keys. " + "Call phenex_inspect_class('Cohort') to see the expected structure." + ) + + # 4. 
Validate optional list fields + for field in ("inclusions", "exclusions", "characteristics", "outcomes"): + if field in cohort_definition and cohort_definition[field] is not None: + val = cohort_definition[field] + if not isinstance(val, list): + errors.append( + f"'{field}' must be a list of phenotype dictionaries, " + f"not {type(val).__name__}." + ) + continue + for i, pheno in enumerate(val): + if not isinstance(pheno, dict): + errors.append( + f"'{field}[{i}]' is a {type(pheno).__name__}, not a dictionary." + ) + continue + pheno_type = pheno.get("type") or pheno.get("class_name") + if pheno_type: + phenotypes_used.append(pheno_type) + if pheno_type not in KNOWN_PHENOTYPE_TYPES: + close = _get_close_matches( + pheno_type, KNOWN_PHENOTYPE_TYPES + ) + hint = ( + f" Did you mean: {', '.join(close)}?" if close else "" + ) + warnings.append( + f"'{field}[{i}]': unrecognized type '{pheno_type}'.{hint} " + f"Call phenex_list_classes() for valid types." + ) + + # Count phenotypes + phenotype_count = 1 if "entry_criterion" in cohort_definition else 0 + for field in ("inclusions", "exclusions", "characteristics", "outcomes"): + items = cohort_definition.get(field) + if isinstance(items, list): + phenotype_count += len(items) + + # Additional validation + if cohort_definition.get("name") and cohort_definition["name"] != cohort_name: + warnings.append( + f"cohort_definition.name ('{cohort_definition['name']}') doesn't match " + f"cohort_name parameter ('{cohort_name}')" + ) + + # Deep validation: attempt actual from_dict() compilation + if len(errors) == 0: + try: + from phenex.util.serialization.from_dict import from_dict + + native_def = _prepare_cohort_for_compilation(cohort_definition.copy()) + _compiled = from_dict(native_def) + except KeyError as key_err: + errors.append( + f"Unknown class name '{key_err}' encountered during compilation. " + f"Check that all 'type' values are valid PhenEx class names. " + f"Call phenex_list_classes() for valid names." 
def execute_cohort(
    cohort_definition: Dict[str, Any],
    cohort_name: str,
    validate_only: bool = True,
    SNOWFLAKE_SOURCE_DATABASE: Optional[str] = None,
    SNOWFLAKE_SOURCE_SCHEMA: Optional[str] = None,
    SNOWFLAKE_DEST_DATABASE: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Execute a cohort definition using PhenEx against Snowflake.

    Compiles a cohort definition dict into PhenEx code via from_dict()
    and executes it. Results are written to PHENEX_AI__{COHORT_NAME}.

    The call proceeds in stages and returns at the first one that fails:
    validation, database configuration, then (unless validate_only) execution.
    Failures are reported in the returned dict rather than raised.

    Args:
        cohort_definition: Cohort definition in the tool's dict format.
        cohort_name: Name used to derive the destination schema.
        validate_only: When True (default), stop after validation and
            configuration checks without touching the database.
        SNOWFLAKE_SOURCE_DATABASE: Source database; falls back to the
            environment variable of the same name.
        SNOWFLAKE_SOURCE_SCHEMA: Source schema; falls back to the
            environment variable of the same name.
        SNOWFLAKE_DEST_DATABASE: Destination database; falls back to the
            environment variable of the same name.

    Returns:
        A dict with 'success', 'validated', 'executed' flags plus status
        details; on execution failure it includes 'error', 'error_trace'
        and the 'execution_log' collected up to the failure.
    """
    # Step 1: Validate
    validation = validate_cohort(cohort_definition, cohort_name)

    if not validation["valid"]:
        return {
            "success": False,
            "validated": False,
            "executed": False,
            "cohort_name": cohort_name,
            "target_schema": validation["target_schema"],
            "validation_errors": validation["errors"],
            "validation_warnings": validation.get("warnings", []),
            "execution_status": "Validation failed - cannot execute",
            "error": f"Validation failed with {len(validation['errors'])} error(s)",
        }

    # Step 2: Parse database configuration (explicit arguments win over env vars)
    source_database = SNOWFLAKE_SOURCE_DATABASE or os.getenv(
        "SNOWFLAKE_SOURCE_DATABASE"
    )
    source_schema = SNOWFLAKE_SOURCE_SCHEMA or os.getenv("SNOWFLAKE_SOURCE_SCHEMA")
    dest_database = SNOWFLAKE_DEST_DATABASE or os.getenv("SNOWFLAKE_DEST_DATABASE")
    dest_schema = f"PHENEX_AI__{cohort_name.upper()}"

    # Fields shared by every response from here on, in the order callers
    # have always received them (they sit between the flags and the status).
    context = {
        "cohort_name": cohort_name,
        "SNOWFLAKE_SOURCE_DATABASE": source_database,
        "SNOWFLAKE_SOURCE_SCHEMA": source_schema,
        "SNOWFLAKE_DEST_DATABASE": dest_database,
        "SNOWFLAKE_DEST_SCHEMA": dest_schema,
        "validation_errors": [],
        "validation_warnings": validation["warnings"],
    }

    config_errors = []
    if not source_database:
        config_errors.append(
            "SNOWFLAKE_SOURCE_DATABASE not set. Pass it as a parameter or set the "
            "SNOWFLAKE_SOURCE_DATABASE environment variable (e.g. 'MY_DATABASE')."
        )
    if not source_schema:
        config_errors.append(
            "SNOWFLAKE_SOURCE_SCHEMA not set. Pass it as a parameter or set the "
            "SNOWFLAKE_SOURCE_SCHEMA environment variable (e.g. 'OMOP_CDM'). "
            "Use snowflake_list_schemas() to browse available schemas."
        )
    if not dest_database:
        config_errors.append(
            "SNOWFLAKE_DEST_DATABASE not set. Pass it as a parameter or set the "
            "SNOWFLAKE_DEST_DATABASE environment variable. Results will be written to "
            f"{dest_schema} schema in this database."
        )

    if config_errors:
        return {
            "success": False,
            "validated": True,
            "executed": False,
            **context,
            "execution_status": "Configuration incomplete",
            "error": f"Database configuration errors: {'; '.join(config_errors)}",
        }

    if validate_only:
        return {
            "success": True,
            "validated": True,
            "executed": False,
            **context,
            "execution_status": "Validated successfully - use validate_only=False to execute",
            "phenotypes_used": validation["phenotypes_used"],
            "phenotype_count": validation["phenotype_count"],
            "message": f"Cohort will read from {source_database}.{source_schema} and write to {dest_database}.{dest_schema}",
        }

    # Step 3: Execute
    # The log is created before the try block so the error handler can always
    # report it, even when the very first statement (an import) fails.
    execution_log = []
    try:
        from phenex.util.serialization.from_dict import from_dict
        from phenex.ibis_connect import SnowflakeConnector
        from phenex.mappers import OMOPDomains

        # 3a: Prepare cohort for compilation (type->class_name, codelist resolution)
        native_definition = _prepare_cohort_for_compilation(cohort_definition)
        execution_log.append("Prepared cohort definition for compilation")

        # 3b: Create SnowflakeConnector
        execution_log.append("Creating Snowflake connector...")

        source_db_qualified = f"{source_database}.{source_schema}"
        dest_db_qualified = f"{dest_database}.{dest_schema}"

        connector = SnowflakeConnector(
            SNOWFLAKE_SOURCE_DATABASE=source_db_qualified,
            SNOWFLAKE_DEST_DATABASE=dest_db_qualified,
        )
        execution_log.append(
            f"Connected to Snowflake: {source_db_qualified} -> {dest_db_qualified}"
        )

        # 3c: Compile cohort from dict
        execution_log.append("Compiling cohort definition...")
        try:
            cohort = from_dict(native_definition)
            execution_log.append(
                f"Compiled cohort: {getattr(cohort, 'name', type(cohort).__name__)}"
            )
        except Exception as from_dict_error:
            # Record the compile failure in the log, then let the outer
            # handler build the error response.
            execution_log.append(
                f"from_dict() failed: {type(from_dict_error).__name__}: {str(from_dict_error)}"
            )
            raise

        # 3d: Get source tables using OMOP mapper
        execution_log.append("Loading source tables...")
        tables = OMOPDomains.get_mapped_tables(connector)
        execution_log.append(f"Loaded {len(tables)} domain tables")

        # 3e: Execute cohort
        execution_log.append("Executing cohort...")
        cohort.execute(tables=tables, con=connector, overwrite=True)
        execution_log.append("Cohort execution complete!")

        # 3f: Get results
        index_table = cohort.index_table
        patient_count = int(index_table.count().execute())

        tables_created = [
            f"{dest_database}.{dest_schema}.COHORT",
            f"{dest_database}.{dest_schema}.INCLUSIONS",
            f"{dest_database}.{dest_schema}.EXCLUSIONS",
        ]

        if cohort.characteristics:
            tables_created.append(f"{dest_database}.{dest_schema}.CHARACTERISTICS")
        if cohort.outcomes:
            tables_created.append(f"{dest_database}.{dest_schema}.OUTCOMES")

        return {
            "success": True,
            "validated": True,
            "executed": True,
            **context,
            "execution_status": "Execution successful",
            "patient_count": patient_count,
            "tables_created": tables_created,
            "execution_log": execution_log,
            "message": f"Successfully created cohort with {patient_count} patients in {dest_database}.{dest_schema}",
        }

    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        error_trace = traceback.format_exc()
        return {
            "success": False,
            "validated": True,
            "executed": False,
            **context,
            "execution_status": "Execution failed",
            "error": str(e),
            "error_trace": error_trace,
            "execution_log": execution_log,
            "message": f"Cohort execution failed: {str(e)}",
        }
+ +## User Preferences: Python Code, Not JSON + +**Users are data analysts, clinical researchers and epidemiologists.** They think in terms of PhenEx Python code, not JSON dictionaries. When presenting cohort definitions to the user: + +- **Show PhenEx Python syntax** — `CodelistPhenotype(...)`, `Cohort(...)`, `Codelist(...)`, etc. +- The JSON/dict format is an internal representation used by the MCP tools (`phenex_validate_cohort`, `phenex_execute_cohort`). Users should not need to read or write it directly. +- **Use `phenex_generate_python`** to convert a validated cohort definition into a clean Python script. This guarantees the Python shown to the user matches exactly what was validated — never paraphrase the JSON into Python manually. +- If generating a cohort definition file, produce a `.py` file using the output of `phenex_generate_python`. + +Example of what users expect to see: + +```python +from phenex.phenotypes import CodelistPhenotype, AgePhenotype, Cohort +from phenex.codelists import Codelist + +af_codes = Codelist( + name="atrial_fibrillation", + codelist={"ICD10CM": ["I48.0", "I48.1", "I48.2", "I48.91"]}, + use_code_type=False, + remove_punctuation=True, +) + +entry = CodelistPhenotype( + name="atrial_fibrillation", + codelist=af_codes, + domain="CONDITION_OCCURRENCE_SOURCE", + return_date="first", +) + +cohort = Cohort( + name="afib_cohort", + entry_criterion=entry, +) +``` + +## What You Can Do + +### 1. Explore Phenotype Types + +Use `phenex_list_classes` to see all available phenotype, filter, and codelist classes. +Use `phenex_inspect_class` to get detailed constructor parameters and examples for a specific class. + +### 2. Explore Data (Snowflake) + +The `snowflake_*` tools let you browse a Snowflake warehouse. 
These are Snowflake-specific but the patterns apply to any backend: + +- `snowflake_list_databases` — find available databases +- `snowflake_list_schemas` — list schemas inside a database +- `snowflake_list_tables` — list tables inside a schema +- `snowflake_get_table_columns` — see column names and types +- `snowflake_preview_table` — sample rows from a table +- `snowflake_select_rows` — query with optional WHERE filter +- `snowflake_get_distinct_values` — discover unique values in a column +- `snowflake_count_rows` — count rows (optionally filtered) + +**Hierarchy**: Account → Database → Schema → Table. Always provide the database when working with schemas or tables. + +### 3. Define & Validate Cohorts + +Internally, cohorts are passed to the MCP tools as JSON dictionaries matching the `Cohort` class constructor: + +```json +{ + "name": "my_cohort", + "entry_criterion": { + "type": "CodelistPhenotype", + "name": "index_event", + "domain": "CONDITION_OCCURRENCE_SOURCE", + "codelist": { "ICD10CM": ["I48.0", "I48.1"] }, + "return_date": "first" + }, + "inclusions": [ + { + "type": "AgePhenotype", + "name": "age_18_plus", + "min_age": { "type": "GreaterThanOrEqualTo", "value": 18 } + } + ], + "exclusions": [], + "characteristics": [ + { "type": "AgePhenotype", "name": "age" }, + { "type": "SexPhenotype", "name": "sex" } + ], + "outcomes": [] +} +``` + +Codelists can be **inline dicts** or passed **by reference** as a string name from the codelist store +(use `phenex_find_codelists` to discover available names): + +```json +"codelist": "atrial_fibrillation" +``` + +When `"codelist"` is a string, it is resolved from the codelist store at validation/execution time. + +**But present this to the user as Python code** (see User Preferences above). Only use JSON when calling the MCP tools. + +Use `phenex_validate_cohort` to check a definition before execution. +Use `phenex_execute_cohort` with `validate_only=True` first, then `validate_only=False` to run. 
+ +**Safety**: Results are always written to `PHENEX_AI__{COHORT_NAME}` to prevent accidental overwrites. + +## Key Concepts + +### Database Backend (Snowflake) + +This MCP server uses `SnowflakeConnector` from `phenex.ibis_connect` for both data exploration (`snowflake_*` tools) and cohort execution (`phenex_execute_cohort`). It requires `SNOWFLAKE_USER`, `SNOWFLAKE_ACCOUNT`, `SNOWFLAKE_WAREHOUSE`, `SNOWFLAKE_ROLE`, and auth credentials. + +Note: The PhenEx library also supports DuckDB and PostgreSQL backends, but those are not yet wired into this MCP server. Cohort definitions (phenotypes, codelists, etc.) are backend-agnostic — the same definition works against any backend. + +### Codelist & code_type + +- A **Codelist** maps code types (vocabularies like ICD10CM, CPT4, RxNorm) to lists of codes. +- Codelists can be provided **inline** as a dict (`"codelist": {"ICD10CM": [...]}`) or **by reference** as a string (`"codelist": "my_codelist_name"`). By-reference codelists are resolved from the codelist store at validation/execution time — use `phenex_find_codelists` to discover available names. +- `use_code_type`: set `True` when the domain table has a CODE_TYPE column; `False` when it doesn't (common with `_SOURCE` domains). +- `remove_punctuation`: set `True` when codelist codes contain dots (e.g., `I48.0`) but the database stores them without (`I480`). +- **Always inspect the target table** with `snowflake_get_table_columns` and `snowflake_get_distinct_values` before choosing these settings. + +### Domains + +Common OMOP domains: + +- `CONDITION_OCCURRENCE` / `CONDITION_OCCURRENCE_SOURCE` — diagnoses +- `DRUG_EXPOSURE` / `DRUG_EXPOSURE_SOURCE` — medications +- `PROCEDURE_OCCURRENCE` / `PROCEDURE_OCCURRENCE_SOURCE` — procedures +- `MEASUREMENT` — lab values, vitals +- `PERSON` — demographics +- `DEATH` — mortality + +### Time Ranges + +Phenotypes can have `relative_time_range` to restrict events to a window relative to an anchor (e.g., the index date).
Negative days = before, positive = after. + +## Workflow + +1. **Explore data** — find the right database, schema, and tables +2. **Inspect tables** — check column names, code formats, code types +3. **List classes** — see what building blocks are available +4. **Build cohort** — construct the JSON dict matching the Cohort class structure +5. **Validate** — run `phenex_validate_cohort` to catch errors +6. **Generate Python** — call `phenex_generate_python` to get the equivalent Python code; show it to the user for review +7. **Execute** — run `phenex_execute_cohort` with `validate_only=False` +8. **Save .py file** — save the `phenex_generate_python` output as a standalone Python script the user can keep and re-run + +## Important Notes + +- Always verify `use_code_type` and `remove_punctuation` by inspecting real data. +- The JSON format is an internal representation — always use `phenex_generate_python` to produce the Python shown to users. Never manually translate JSON to Python. +- Cohort logic is backend-agnostic; only the connector and data exploration tools are backend-specific. +- When in doubt, validate first. Validation compiles the definition with `from_dict()` without hitting the database. diff --git a/mcp/mcp.json b/mcp/mcp.json new file mode 100644 index 00000000..cd8df375 --- /dev/null +++ b/mcp/mcp.json @@ -0,0 +1,9 @@ +{ + "servers": { + "phenex": { + "command": "bash", + "args": ["./mcp/start.sh"], + "env": {} + } + } +} diff --git a/mcp/phenotype_registry.py b/mcp/phenotype_registry.py new file mode 100644 index 00000000..99552491 --- /dev/null +++ b/mcp/phenotype_registry.py @@ -0,0 +1,531 @@ +""" +Registry of available PhenEx phenotypes with their specifications. + +All information is derived from actual class docstrings and signatures — +nothing is hardcoded. 
+""" + +import inspect +import re +from typing import Dict, List, Any, Optional + +# Import phenotype classes from PhenEx +try: + from phenex.phenotypes import ( + CodelistPhenotype, + AgePhenotype, + BinPhenotype, + SexPhenotype, + EventCountPhenotype, + MeasurementPhenotype, + MeasurementChangePhenotype, + DeathPhenotype, + CategoricalPhenotype, + TimeRangeCountPhenotype, + TimeRangeDayCountPhenotype, + TimeRangeDaysToNextRange, + TimeRangePhenotype, + UserDefinedPhenotype, + ScorePhenotype, + ArithmeticPhenotype, + LogicPhenotype, + WithinSameEncounterPhenotype, + ) + from phenex.filters import ( + RelativeTimeRangeFilter, + CategoricalFilter, + ValueFilter, + ) + from phenex.filters.date_filter import ( + DateFilter, + Date, + After, + AfterOrOn, + Before, + BeforeOrOn, + ) + from phenex.filters.value import ( + Value, + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + ) + from phenex.codelists import Codelist + from phenex.core.cohort import Cohort + from phenex.core.subcohort import Subcohort + from phenex.core.study import Study + from phenex.reporting import ( + Table1, + Table2, + Waterfall, + InExCounts, + TimeToEvent, + CohortExplorer, + ReportDrafter, + TreatmentPatternAnalysisSankeyReporter, + ) + from phenex.reporting.protocol_drafter import ProtocolDrafter + + PHENEX_AVAILABLE = True +except ImportError: + PHENEX_AVAILABLE = False + + +# Ordered list of phenotype classes to expose +PHENOTYPE_CLASSES: List = ( + [ + CodelistPhenotype, + AgePhenotype, + SexPhenotype, + MeasurementPhenotype, + MeasurementChangePhenotype, + EventCountPhenotype, + TimeRangePhenotype, + TimeRangeCountPhenotype, + TimeRangeDayCountPhenotype, + TimeRangeDaysToNextRange, + DeathPhenotype, + CategoricalPhenotype, + BinPhenotype, + ScorePhenotype, + ArithmeticPhenotype, + LogicPhenotype, + WithinSameEncounterPhenotype, + ] + if PHENEX_AVAILABLE + else [] +) + +# Filter and value classes to expose via phenex_list_classes +FILTER_CLASSES: List = 
def _get_supporting_class_spec(cls, _seen: Optional[set] = None) -> Dict[str, Any]:
    """Build a dict-schema spec for a supporting class (filter, value, etc.).

    Recursively inlines nested specs on parameters that reference other
    supporting classes, so the full construction tree is self-contained.

    Args:
        cls: The supporting class to describe.
        _seen: Names of supporting classes already inlined during this
            traversal. Shared across the recursion so each class is expanded
            at most once, which also stops reference cycles.

    Returns:
        A dict with 'class_name', 'description' and 'parameters'; parameters
        that reference other supporting classes gain a 'nested_specs' entry,
        and classes with a known sample construction gain 'dict_format'.
    """
    if _seen is None:
        _seen = set()
    sections = _extract_docstring_sections(cls)
    params = _extract_parameters(cls)

    # For classes with to_dict, show an example of the serialized form.
    # This is best-effort decoration: if the sample cannot be built (or the
    # PhenEx names are unavailable) the spec is simply returned without it.
    example_dict = None
    try:
        if cls in (
            GreaterThan,
            LessThan,
            GreaterThanOrEqualTo,
            LessThanOrEqualTo,
            EqualTo,
        ):
            example_dict = cls(0).to_dict()
        elif cls in (After, AfterOrOn, Before, BeforeOrOn):
            example_dict = cls("2020-01-01").to_dict()
        elif cls is Value:
            example_dict = Value(">", 0).to_dict()
    except Exception:
        pass

    # Inline nested specs on each parameter that references another supporting class
    for param_info in params.values():
        type_str = param_info.get("type", "")
        param_nested = {}
        for class_name, supporting_cls in SUPPORTING_CLASSES.items():
            if class_name in _seen:
                continue
            # Match whole identifiers only: a plain substring test would treat
            # "ValueFilter" as a reference to "Value" (and "CodelistPhenotype"
            # as "Codelist"), inlining the wrong spec and marking it as seen.
            if not re.search(rf"\b{re.escape(class_name)}\b", type_str):
                continue
            _seen.add(class_name)
            param_nested[class_name] = _get_supporting_class_spec(
                supporting_cls, _seen
            )
        if param_nested:
            param_info["nested_specs"] = param_nested

    spec = {
        "class_name": cls.__name__,
        "description": sections["description"],
        "parameters": params,
    }
    if example_dict is not None:
        spec["dict_format"] = example_dict
    return spec
def _clean_type_str(annotation) -> str:
    """Turn a type annotation into a readable string.

    Args:
        annotation: A parameter annotation as reported by inspect.signature()
            (a class, a typing construct, a string, or Parameter.empty).

    Returns:
        "Any" when the parameter has no annotation; otherwise str(annotation)
        with known module prefixes, the "<class '...'>" / "<enum '...'>"
        wrapper and any " at 0x..." memory address removed.
    """
    # Parameter.empty is a sentinel, so compare by identity; "==" would invoke
    # whatever __eq__ the annotation object happens to define.
    if annotation is inspect.Parameter.empty:
        return "Any"
    s = str(annotation)
    # Clean up verbose module paths
    for prefix in [
        "typing.",
        "phenex.phenotypes.phenotype.",
        "phenex.codelists.codelists.",
        "phenex.filters.relative_time_range_filter.",
        "phenex.filters.categorical_filter.",
        "phenex.filters.value_filter.",
        "phenex.filters.date_filter.",
        "phenex.filters.",
    ]:
        s = s.replace(prefix, "")
    # A bare class annotation stringifies as "<class 'int'>"; keep just the name
    s = re.sub(r"<(?:class|enum) '([^']+)'>", r"\1", s)
    # Remove memory addresses (e.g. "DateFilter at 0x1198fa020")
    s = re.sub(r"\s+at\s+0x[0-9a-fA-F]+", "", s)
    return s
Attributes:, Methods:, Examples:, Example) + if ( + stripped + and not stripped.startswith(" ") + and stripped.endswith(":") + and stripped != "Parameters:" + ): + break + if stripped.startswith("Example"): + break + + # Check if this is a new parameter line: "param_name: description" + # or "param_name (type): description" + if ":" in stripped and not stripped.startswith(" "): + # Could be a continuation line if deeply indented, but at the + # Parameters level it should be a param definition + pass + + # Detect param lines: they are indented at the first level under Parameters: + # and have the form "name: description" or "name (type): description" + if line and not line.startswith(" ") and ":" in stripped: + # Save previous param + if current_param is not None: + descriptions[current_param] = " ".join(current_desc_lines).strip() + + # Parse "param_name: description" or "param_name (type): description" + colon_idx = stripped.index(":") + param_part = stripped[:colon_idx].strip() + desc_part = stripped[colon_idx + 1 :].strip() + + # Strip type annotation if present: "param_name (type)" -> "param_name" + if "(" in param_part: + param_part = param_part[: param_part.index("(")].strip() + + current_param = param_part + current_desc_lines = [desc_part] if desc_part else [] + elif current_param is not None and stripped: + # Continuation line for the current parameter + current_desc_lines.append(stripped) + elif current_param is not None and not stripped: + # Blank line — end of this param's description if next line is a new section + # But could also be a paragraph break within a param description. + # We'll let the section-header check above handle termination. 
def _extract_docstring_sections(cls) -> Dict[str, str]:
    """Split an object's docstring into a summary, the raw text, and the example tail.

    Returns:
        ``{"description", "full"}`` when there is no docstring, otherwise
        ``{"description", "full", "example"}`` where:
        - description: the leading lines (stripped, space-joined) up to the
          first blank line or a line starting with ``Parameters:``,
          ``Attributes:`` or ``Example``;
        - full: the cleaned docstring as returned by ``inspect.getdoc``;
        - example: everything from the first line starting with ``Example``
          to the end, or ``""`` if there is no such line.
    """
    doc_text = inspect.getdoc(cls) or ""
    if not doc_text:
        return {"description": "No documentation available.", "full": doc_text}

    doc_lines = doc_text.split("\n")
    section_markers = ("Parameters:", "Attributes:", "Example")

    # Summary: consume lines until a blank line or a known section header.
    summary_parts = []
    for text in (raw_line.strip() for raw_line in doc_lines):
        if not text or text.startswith(section_markers):
            break
        summary_parts.append(text)

    # Example: index of the first line that opens an "Example..." section.
    first_example = next(
        (
            idx
            for idx, raw_line in enumerate(doc_lines)
            if raw_line.strip().startswith("Example")
        ),
        None,
    )
    example_text = (
        "" if first_example is None else "\n".join(doc_lines[first_example:])
    )

    return {
        "description": " ".join(summary_parts),
        "full": doc_text,
        "example": example_text,
    }
def get_spec(class_name: str) -> Dict[str, Any]:
    """
    Get the specification for one exposed PhenEx class.

    The class is looked up by name among the phenotype, filter, "other" and
    reporter class lists. Everything returned is derived from the class
    itself (docstring and constructor signature).

    Args:
        class_name: Exact ``__name__`` of the class to describe.

    Returns:
        On success, a dict with:
        - name: the requested class name
        - description: leading docstring paragraph
        - parameters: constructor parameters; a parameter whose type string
          mentions a supporting class gets that class's spec inlined under
          ``nested_specs`` (each supporting class is inlined at most once)
        - example: the docstring's example section, or "No examples found."
        On failure, a dict with ``error`` (and ``available_classes`` plus a
        "Did you mean" hint when the name is unknown).
    """
    if not PHENEX_AVAILABLE:
        return {
            "error": "PhenEx library not available. Install with: pip install phenex"
        }

    # Build a combined map of all exposed classes. Later groups win on a name
    # clash, matching the original update order.
    cls_map = {}
    for group in (PHENOTYPE_CLASSES, FILTER_CLASSES, OTHER_CLASSES, REPORTER_CLASSES):
        cls_map.update({cls.__name__: cls for cls in group})

    if class_name not in cls_map:
        import difflib

        available = sorted(cls_map.keys())
        close = difflib.get_close_matches(class_name, available, n=3, cutoff=0.4)
        hint = f" Did you mean: {', '.join(close)}?" if close else ""
        return {
            "error": (
                f"Unknown class: '{class_name}'.{hint} "
                f"Call phenex_list_classes() to see all valid class names."
            ),
            "available_classes": available,
        }

    cls = cls_map[class_name]
    sections = _extract_docstring_sections(cls)
    params = _extract_parameters(cls)

    # Inline nested specs directly into each parameter that references a
    # supporting class; ``seen`` is shared so a class is expanded only once.
    seen = set()
    for param_info in params.values():
        type_str = param_info.get("type", "")
        param_nested = {}
        for sc_name, supporting_cls in SUPPORTING_CLASSES.items():
            if sc_name in type_str and sc_name not in seen:
                seen.add(sc_name)
                param_nested[sc_name] = _get_supporting_class_spec(supporting_cls, seen)
        if param_nested:
            param_info["nested_specs"] = param_nested

    return {
        "name": class_name,
        "description": sections["description"],
        "parameters": params,
        # The "example" key may be missing (no docstring) or present but empty
        # (docstring without an example section); fall back in both cases.
        "example": sections.get("example") or "No examples found.",
    }
_mcp_dir = str(Path(__file__).resolve().parent)
if _mcp_dir not in sys.path:
    sys.path.insert(0, _mcp_dir)

from typing import Dict, List, Optional, Any
from dotenv import load_dotenv
from fastmcp import FastMCP

# Load environment variables from the .env file in this directory BEFORE any
# setting is read, so that a LOG_LEVEL defined only in .env takes effect.
# override=False keeps values already present in the process environment.
load_dotenv(Path(__file__).resolve().parent / ".env", override=False)

# Configure logging: the root logger stays at WARNING, while this server's own
# logger follows LOG_LEVEL.
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s [%(name)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("phenex-mcp")
# getattr() can resolve to a non-level attribute of the logging module (e.g. a
# function) for an arbitrary LOG_LEVEL string; only accept integer levels.
_resolved_level = getattr(logging, log_level, None)
logger.setLevel(_resolved_level if isinstance(_resolved_level, int) else logging.INFO)

from phenotype_registry import (
    get_available_classes,
    get_spec,
)
import snowflake_explorer as sf_explorer
from cohort_tools import validate_phenotype, validate_cohort, execute_cohort
from code_generator import generate_python
import codelist_store

# Initialize FastMCP server
mcp = FastMCP("PhenEx Cohort Builder")


# ============================================================
# CONFIGURATION
# ============================================================

# Environment variables to expose in get_config, grouped by purpose.
# Values marked as secret are masked in the output.
_CONFIG_KEYS = [
    # (env_var, display_name, secret)
    ("SNOWFLAKE_USER", "Snowflake user", False),
    ("SNOWFLAKE_ACCOUNT", "Snowflake account", False),
    ("SNOWFLAKE_WAREHOUSE", "Snowflake warehouse", False),
    ("SNOWFLAKE_ROLE", "Snowflake role", False),
    ("SNOWFLAKE_PASSWORD", "Snowflake password", True),
    ("SNOWFLAKE_SOURCE_DATABASE", "Source database", False),
    ("SNOWFLAKE_SOURCE_SCHEMA", "Source schema", False),
    ("SNOWFLAKE_DEST_DATABASE", "Destination database", False),
    ("PHENEX_CODELIST_DIR", "Codelist directory", False),
    ("PHENEX_CODELIST_CODE_COLUMN", "Codelist code column", False),
    ("PHENEX_CODELIST_NAME_COLUMN", "Codelist name column", False),
    ("PHENEX_CODELIST_CODE_TYPE_COLUMN", "Codelist code type column", False),
    ("MCP_TRANSPORT", "MCP transport", False),
    ("MCP_HOST", "MCP host", False),
    ("MCP_PORT", "MCP port", False),
    ("LOG_LEVEL", "Log level", False),
]


@mcp.tool()
def phenex_get_config() -> Dict[str, Any]:
    """
    Return the current server configuration (environment variables).

    Sensitive values (passwords, tokens) are masked. Use this to verify
    which Snowflake account, database, schema, and codelist directory are
    configured before running queries or executing cohorts.

    Returns:
        Dictionary with:
        - config (dict): Key-value pairs of all configured settings
        - missing (list): Environment variables that are not set
    """
    settings: Dict[str, Any] = {}
    unset: List[str] = []
    # The display name (second tuple field) is not part of the output.
    for key, _label, is_secret in _CONFIG_KEYS:
        current = os.getenv(key)
        if current is None:
            unset.append(key)
            continue
        # Secrets are reported as present but never echoed back.
        settings[key] = "****" if is_secret else current
    return {"config": settings, "missing": unset}
+ + Mappers convert from a source data format (e.g. OMOP CDM) to PhenEx's + internal column model. The AI workflow is: + + 1. Call this tool to see which mapper families are available. + 2. Pick the mapper that matches the target database's format. + 3. For each phenotype, pick the right **domain** within that mapper + (e.g. CONDITION_OCCURRENCE_SOURCE for ICD source codes). + 4. Check has_code_type — if False, set use_code_type=False on the codelist. + + Returns: + Dictionary with: + - mappers (dict): Each key is a mapper family name, value has: + * domains (dict): domain_name → {source_table, column_mapping, + table_type, has_code_type, note?} + """ + try: + import inspect + import phenex.mappers as mappers_module + from phenex.mappers import DomainsDictionary + + # Discover all module-level DomainsDictionary instances + result = {} + for attr_name in dir(mappers_module): + obj = getattr(mappers_module, attr_name) + if not isinstance(obj, DomainsDictionary): + continue + + # Derive a friendly mapper family name from the variable name + # e.g. 
"OMOPDomains" → "OMOP" + family = attr_name.replace("Domains", "").replace("domains", "") + if not family: + family = attr_name + + domains = {} + for domain_name, mapper_cls in obj.domains_dict.items(): + info = { + "source_table": mapper_cls.NAME_TABLE, + "column_mapping": dict(mapper_cls.DEFAULT_MAPPING), + "table_type": ( + mapper_cls.__bases__[0].__name__ + if mapper_cls.__bases__ + else mapper_cls.__name__ + ), + } + if ( + "CODE" in mapper_cls.DEFAULT_MAPPING + and "CODE_TYPE" not in mapper_cls.DEFAULT_MAPPING + ): + info["has_code_type"] = False + info["note"] = ( + "No CODE_TYPE column — use use_code_type=False in your codelist" + ) + elif "CODE_TYPE" in mapper_cls.DEFAULT_MAPPING: + info["has_code_type"] = True + domains[domain_name] = info + + result[family] = {"variable": attr_name, "domains": domains} + + if not result: + return { + "success": False, + "error": "No DomainsDictionary instances found in phenex.mappers", + } + + return {"success": True, "mappers": result} + except Exception as e: + return {"success": False, "error": str(e)} + + +@mcp.tool() +def phenex_list_classes(category: str = "") -> Dict[str, Any]: + """ + List all available PhenEx classes grouped by category: phenotypes, filters, reporters, and other. + + PhenEx (Phenotype Extractor) provides pre-built classes for clinical data extraction. + Use this tool first to discover what building blocks are available for cohort building. + + Args: + category: Optional category to filter by. One of: "phenotypes", "filters", + "reporters", "other". If empty, returns all categories. + + Returns: + Dictionary containing: + - phenotypes (list): Phenotype classes for identifying patient characteristics and events + - filters (list): Filter and value classes for restricting events within phenotypes + - reporters (list): Reporter classes for generating analysis outputs (Table1, Waterfall, etc.) 
@mcp.tool()
def phenex_inspect_class(class_name: str) -> Dict[str, Any]:
    """
    Get detailed specification and usage examples for a PhenEx class.

    Use this after phenex_list_classes() to drill into a specific class and see
    its constructor parameters, types, defaults, and code examples.

    Args:
        class_name: Name of the class to inspect.
            Must exactly match a name from phenex_list_classes().
            Examples: "CodelistPhenotype", "RelativeTimeRangeFilter",
            "ValueFilter", "Codelist", "AgePhenotype",
            "GreaterThan", "After"

    Returns:
        Dictionary containing:
        - success (bool): Whether the spec was retrieved successfully
        - name (str): Class name
        - description (str): Brief description
        - parameters (dict): All constructor parameters with types, required flags, defaults
        - example (str): Example usage code
    """
    try:
        details = get_spec(class_name)
        # get_spec signals failure with an "error" key; mirror it in "success"
        # while passing the rest of the payload through unchanged.
        return {"success": "error" not in details, **details}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
@mcp.tool()
def phenex_get_codelist(name: str) -> Dict[str, Any]:
    """
    Get the full contents of a specific codelist by name.

    Args:
        name: Exact codelist name as shown by phenex_find_codelists().

    Returns:
        Dictionary containing:
        - success (bool): Whether the codelist was found
        - name (str): Codelist name
        - code_types (list): Vocabularies present
        - total_codes (int): Total number of codes
        - codelist (dict): Full codelist — keys are code types, values are lists of codes
        - error (str): Error message if codelist not found
        - available_codelists (list): All available names (if not found)
    """
    try:
        lookup = codelist_store.get_codelist(name)
        # The store reports a miss via an "error" key rather than raising.
        return {"success": "error" not in lookup, **lookup}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
Compiles the dict with from_dict() to verify correctness, + then emits clean, idiomatic Python code that constructs the same object. + + Use this tool: + - After building and validating a definition to show the user the + equivalent Python code for review + - To produce a .py artifact the user can save and re-run independently + - To guarantee that the Python shown matches exactly what was validated + + Args: + definition: Any PhenEx definition dict with a 'class_name' (or 'type') + field. Examples: + - A full Cohort dict with entry_criterion, inclusions, etc. + - A single phenotype: {"type": "CodelistPhenotype", ...} + - A filter: {"type": "RelativeTimeRangeFilter", ...} + - A codelist: {"class_name": "Codelist", ...} + + Returns: + Dictionary containing: + - success (bool): Whether code generation succeeded + - code (str): The generated Python script with imports and + constructor calls. Ready to save as a .py file. + - error (str): Error message if generation failed. Fix errors with + phenex_validate_phenotype / phenex_validate_cohort + first, then retry. 
@mcp.tool()
def phenex_validate_phenotype(
    phenotype_definition: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Validate a single phenotype definition by compiling it.

    Use this to check that an individual phenotype is correctly defined before
    assembling it into a full cohort. The tool translates the simplified format
    to PhenEx native format and attempts to compile it with from_dict().

    Args:
        phenotype_definition: Dictionary defining a single phenotype, e.g.:

            Inline codelist:
            {
                "type": "CodelistPhenotype",
                "name": "atrial_fibrillation",
                "domain": "CONDITION_OCCURRENCE_SOURCE",
                "codelist": {"ICD10CM": ["I48.0", "I48.1", "I48.2", "I48.91"]},
                "use_code_type": false,
                "remove_punctuation": true,
                "return_date": "first"
            }

            Codelist by reference (name from codelist store):
            {
                "type": "CodelistPhenotype",
                "name": "atrial_fibrillation",
                "domain": "CONDITION_OCCURRENCE_SOURCE",
                "codelist": "atrial_fibrillation",
                "return_date": "first"
            }

    Returns:
        Dictionary containing:
        - valid (bool): Whether the phenotype compiles successfully
        - errors (list): Compilation error messages (empty if valid)
        - warnings (list): Non-fatal warnings
        - phenotype_name (str): Name from the definition
        - phenotype_type (str): Class name (e.g., "CodelistPhenotype")
        - compiled_class (str): Actual Python class name after compilation (if valid)
        - message (str): Human-readable summary
    """
    # Pure pass-through: all validation lives in cohort_tools.validate_phenotype.
    report = validate_phenotype(phenotype_definition)
    return report
@mcp.tool()
def phenex_execute_cohort(
    cohort_definition: Dict[str, Any],
    cohort_name: Optional[str] = None,
    validate_only: bool = True,
    SNOWFLAKE_SOURCE_DATABASE: Optional[str] = None,
    SNOWFLAKE_SOURCE_SCHEMA: Optional[str] = None,
    SNOWFLAKE_DEST_DATABASE: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Execute a cohort definition using PhenEx against Snowflake.

    Compiles a cohort definition (dict) into PhenEx code via from_dict()
    and executes it against Snowflake. Results are written to a namespaced schema
    for safety: PHENEX_AI__{COHORT_NAME}.

    Args:
        cohort_definition: Dict with cohort specification matching the Cohort class:
            - name (str): Cohort name
            - entry_criterion (dict): Phenotype defining the index date
            - inclusions (list[dict]): Inclusion phenotypes
            - exclusions (list[dict]): Exclusion phenotypes
            - characteristics (list[dict]): Baseline characteristics
            - outcomes (list[dict]): Outcome phenotypes
            - description (str): Optional description

        cohort_name: Optional name override. If not provided, will use cohort_definition["name"]

        validate_only: If True (default), only validates the cohort without executing.
            Set to False to actually execute against Snowflake.

        SNOWFLAKE_SOURCE_DATABASE: Source database name (e.g., "OPTUM_CLAIMS").
            Falls back to SNOWFLAKE_SOURCE_DATABASE env var if not provided.

        SNOWFLAKE_SOURCE_SCHEMA: Source schema name (e.g., "OMOP_CDM").
            Falls back to SNOWFLAKE_SOURCE_SCHEMA env var if not provided.

        SNOWFLAKE_DEST_DATABASE: Destination database name (e.g., "ANALYTICS").
            Falls back to SNOWFLAKE_DEST_DATABASE env var if not provided.

    Returns:
        Dict with execution results including:
        - success, validated, executed booleans
        - cohort_name, source/dest database info
        - validation_errors, validation_warnings
        - patient_count (if executed)
        - tables_created (if executed)
    """
    # Name precedence: explicit override, then the definition's own name, then
    # a placeholder. ``or`` is used at each step so that a "name" key that is
    # present but None/empty still falls through to the placeholder instead of
    # being forwarded as-is.
    name = cohort_name or cohort_definition.get("name") or "unnamed_cohort"
    return execute_cohort(
        cohort_definition=cohort_definition,
        cohort_name=name,
        validate_only=validate_only,
        SNOWFLAKE_SOURCE_DATABASE=SNOWFLAKE_SOURCE_DATABASE,
        SNOWFLAKE_SOURCE_SCHEMA=SNOWFLAKE_SOURCE_SCHEMA,
        SNOWFLAKE_DEST_DATABASE=SNOWFLAKE_DEST_DATABASE,
    )
@mcp.tool()
def snowflake_list_tables(
    schema: str,
    database: str,
    pattern: Optional[str] = None,
    limit: int = 100,
) -> Dict[str, Any]:
    """
    List tables and views within a schema.

    Args:
        schema: Which schema to look inside for tables (required)
        database: Which database contains this schema (required)
        pattern: Optional SQL LIKE pattern to filter table names.
        limit: Maximum number of tables to return (default 100)

    Returns:
        Dictionary with success, tables list, count, limit, schema.

    Common OMOP CDM tables: PERSON, CONDITION_OCCURRENCE, DRUG_EXPOSURE,
    PROCEDURE_OCCURRENCE, MEASUREMENT, OBSERVATION, VISIT_OCCURRENCE
    """
    try:
        found = sf_explorer.list_tables(schema, database, pattern, limit)
        payload = {
            "success": True,
            "tables": found,
            "count": len(found),
            "limit": limit,
            "schema": schema,
        }
        # Echo the filter back only when one was actually supplied.
        return {**payload, "pattern": pattern} if pattern else payload
    except Exception as exc:
        # Failure keeps the same keys as success, with empty results.
        return dict(
            success=False,
            error=str(exc),
            tables=[],
            count=0,
            limit=limit,
            schema=schema,
        )
@mcp.tool()
def snowflake_preview_table(
    table: str,
    schema: str,
    database: str,
    limit: int = 10,
) -> Dict[str, Any]:
    """
    Preview sample rows from a table.

    Args:
        table: Table name (required)
        schema: Schema containing the table (required)
        database: Database name (required)
        limit: Maximum number of rows to return (default 10, max 1000)

    Returns:
        Dictionary with success, columns, rows, row_count, limit.
    """
    # Both outcomes are tagged with the table/schema that was asked for; these
    # keys come last so they take precedence over same-named explorer keys.
    location = {"table": table, "schema": schema}
    try:
        preview = sf_explorer.preview_table(table, schema, database, limit)
        return {"success": True, **preview, **location}
    except Exception as exc:
        return {
            "success": False,
            "error": str(exc),
            "columns": [],
            "rows": [],
            "row_count": 0,
            "limit": limit,
            **location,
        }
@mcp.tool()
def snowflake_get_distinct_values(
    database: str,
    schema: str,
    table: str,
    column: str,
    where: Optional[str] = None,
    limit: int = 100,
) -> Dict[str, Any]:
    """
    Get distinct values from a column.

    Use this to discover what codes, vocabularies, or unique values exist in a column.

    Args:
        database: Database name (required)
        schema: Schema name (required)
        table: Table name (required)
        column: Column name to get distinct values from (required)
        where: Optional WHERE clause without 'WHERE' keyword
        limit: Maximum distinct values to return (default 100, max 1000)

    Returns:
        Dictionary with column, values, count, limit. On failure the
        dictionary also carries an error message and values is empty.
    """
    lookup = dict(
        database=database,
        schema=schema,
        table=table,
        column=column,
        where=where,
        limit=limit,
    )
    try:
        # The explorer's result is passed through untouched.
        return sf_explorer.get_distinct_values(**lookup)
    except Exception as exc:
        return dict(error=str(exc), column=column, values=[], count=0, limit=limit)
@mcp.prompt()
def explore_phenotypes():
    """Explore available PhenEx phenotype types for cohort building."""
    # The prompt is assembled line by line; the empty entry produces the
    # blank line between the opening sentence and the numbered steps.
    prompt_lines = (
        "I'd like to explore what phenotype types are available in PhenEx for building clinical cohorts.",
        "",
        "Please:",
        "1. List all available phenotype types",
        "2. For each type, show me the use cases",
        "3. Then get detailed specifications for CodelistPhenotype and MeasurementPhenotype",
    )
    return "\n".join(prompt_lines)
def _sanitize_where(where: str) -> str:
    """
    Check a caller-supplied WHERE clause before it is spliced into a query.

    This is a deny-list check, not a SQL parser: it rejects statement
    separators, comment markers and a fixed set of statement keywords, and
    otherwise returns the clause untouched.

    Args:
        where: The WHERE clause to check (without 'WHERE' keyword)

    Returns:
        The clause unchanged when it passes every check. Empty or None
        input is returned as-is.

    Raises:
        ValueError: If the clause contains a semicolon, a SQL comment
            marker, or one of the blocked statement keywords
    """
    if not where:
        return where

    if ";" in where:
        raise ValueError("WHERE clause cannot contain semicolons")

    if any(marker in where for marker in ("--", "/*", "*/")):
        raise ValueError("WHERE clause cannot contain SQL comments")

    dangerous_keywords = (
        "DROP",
        "DELETE",
        "INSERT",
        "UPDATE",
        "ALTER",
        "CREATE",
        "TRUNCATE",
        "GRANT",
        "REVOKE",
        "EXECUTE",
        "EXEC",
    )
    where_upper = where.upper()
    for keyword in dangerous_keywords:
        # Match on word boundaries rather than on surrounding spaces: a
        # keyword delimited by a tab, newline or parenthesis must be caught
        # too, while identifiers that merely contain a keyword (CREATE_DATE,
        # UPDATED_AT) must still be allowed.
        if re.search(rf"\b{keyword}\b", where_upper):
            raise ValueError(f"WHERE clause cannot contain {keyword} statement")

    return where
def list_schemas(
    database: str, pattern: Optional[str] = None, limit: int = 100
) -> List[Dict[str, Any]]:
    """
    List the schemas of a Snowflake database, optionally filtered by pattern.

    Args:
        database: Database to inspect. Must be a plain unquoted identifier
            or a well-formed double-quoted identifier.
        pattern: Optional SQL LIKE pattern applied to schema names.
        limit: Maximum number of schemas to return; the result list is
            truncated client-side.

    Returns:
        One dictionary per schema with name, database, owner and created_on.

    Raises:
        ValueError: If ``database`` is not a valid identifier, or
            ``pattern`` is rejected by ``_sanitize_pattern``.
    """
    # The database name is spliced into a USE DATABASE statement, so anything
    # other than a single identifier (e.g. "db; DROP ...") must be refused.
    # Inputs are validated before a connection is opened.
    if not re.fullmatch(r'[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+"', database or ""):
        raise ValueError(f"Invalid database name: {database!r}")
    sanitized_pattern = _sanitize_pattern(pattern) if pattern else None

    from snowflake.connector import DictCursor

    conn = _get_connection()
    try:
        cursor = conn.cursor(DictCursor)
        cursor.execute(f"USE DATABASE {database}")

        if sanitized_pattern:
            cursor.execute(f"SHOW SCHEMAS LIKE '{sanitized_pattern}'")
        else:
            cursor.execute("SHOW SCHEMAS")
        results = cursor.fetchall()[:limit]

        return [
            {
                "name": row["name"],
                "database": row["database_name"],
                "owner": row.get("owner", ""),
                "created_on": str(row.get("created_on", "")),
            }
            for row in results
        ]
    finally:
        conn.close()
def preview_table(
    table: str, schema: str, database: str, limit: int = 10
) -> Dict[str, Any]:
    """
    Fetch a sample of rows from a table, rendering every value as text.

    The row limit is capped at 1000. Column names come from DESCRIBE TABLE,
    and each non-null cell is converted with ``str`` while NULLs stay None.

    Returns:
        Dictionary with columns, rows, row_count and the effective limit.
    """
    from snowflake.connector import DictCursor

    limit = 1000 if limit > 1000 else limit

    def as_text(value):
        # Keep NULLs as None; stringify everything else.
        return None if value is None else str(value)

    conn = _get_connection()
    try:
        cursor = conn.cursor(DictCursor)
        for statement in (
            f"USE DATABASE {database}",
            f"USE SCHEMA {schema}",
            f"DESCRIBE TABLE {table}",
        ):
            cursor.execute(statement)
        # The cursor now holds the DESCRIBE TABLE result.
        columns = [described["name"] for described in cursor.fetchall()]

        cursor.execute(f"SELECT * FROM {table} LIMIT {limit}")
        formatted_rows = [
            {col: as_text(record.get(col)) for col in columns}
            for record in cursor.fetchall()
        ]

        return {
            "columns": columns,
            "rows": formatted_rows,
            "row_count": len(formatted_rows),
            "limit": limit,
        }
    finally:
        conn.close()
def get_distinct_values(
    database: str,
    schema: str,
    table: str,
    column: str,
    where: Optional[str] = None,
    limit: int = 100,
) -> Dict[str, Any]:
    """
    Return the distinct values of one column of a table.

    The limit is clamped to the range 1..1000. An optional WHERE clause is
    passed through ``_sanitize_where`` before being appended to the query.

    Returns:
        Dictionary with column, values, count and the effective limit.
    """
    from snowflake.connector import DictCursor

    limit = 1000 if limit > 1000 else (1 if limit < 1 else limit)

    conn = _get_connection()
    try:
        cursor = conn.cursor(DictCursor)

        # Assemble the statement from its clauses, then join with spaces.
        clauses = [
            f'SELECT DISTINCT "{column}" FROM "{database}"."{schema}"."{table}"'
        ]
        if where:
            clauses.append(f"WHERE {_sanitize_where(where)}")
        clauses.append(f"LIMIT {limit}")

        cursor.execute(" ".join(clauses))

        # Rows lacking the requested key are skipped rather than raising.
        values = [
            record[column] for record in cursor.fetchall() if column in record
        ]

        return {
            "column": column,
            "values": values,
            "count": len(values),
            "limit": limit,
        }
    finally:
        conn.close()
#!/usr/bin/env bash
# Launch the PhenEx Cohort Builder MCP server (stdio transport)
#
# This script does not set MCP_TRANSPORT itself; the transport is whatever
# the environment (or the .env file loaded below) provides.
# Any arguments given to this script are forwarded to server.py.

# Abort on errors, unset variables and failures inside pipelines.
set -euo pipefail

# Absolute path of the directory holding this script, and of its parent.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"

# Activate the project virtualenv
# (skipped silently when no .venv exists in the parent directory)
if [ -f "$REPO_ROOT/.venv/bin/activate" ]; then
  source "$REPO_ROOT/.venv/bin/activate"
fi

# Load .env if present
# `set -a` marks every variable assigned while sourcing for export, so the
# values reach the Python process; `set +a` switches that back off.
if [ -f "$SCRIPT_DIR/.env" ]; then
  set -a
  source "$SCRIPT_DIR/.env"
  set +a
fi

# Replace this shell with the server process.
exec python "$SCRIPT_DIR/server.py" "$@"
&& pwd)" + +# Activate the project virtualenv +if [ -f "$REPO_ROOT/.venv/bin/activate" ]; then + source "$REPO_ROOT/.venv/bin/activate" +fi + +# Load .env if present +if [ -f "$SCRIPT_DIR/.env" ]; then + set -a + source "$SCRIPT_DIR/.env" + set +a +fi + +export MCP_TRANSPORT="${MCP_TRANSPORT:-streamable-http}" +export MCP_HOST="${MCP_HOST:-0.0.0.0}" +export MCP_PORT="${MCP_PORT:-9000}" + +echo "Starting PhenEx MCP server on ${MCP_HOST}:${MCP_PORT} (${MCP_TRANSPORT})" +exec python "$SCRIPT_DIR/server.py" "$@" diff --git a/phenex/core/cohort.py b/phenex/core/cohort.py index 4d160f28..86088388 100644 --- a/phenex/core/cohort.py +++ b/phenex/core/cohort.py @@ -28,7 +28,7 @@ class Cohort: """ - The Cohort computes a cohort of individuals based on specified entry criteria, inclusions, exclusions, and computes baseline characteristics and outcomes from the extracted index dates. + Use Cohort to define and execute a patient cohort with entry criteria, inclusion/exclusion rules, baseline characteristics, and outcomes. It takes an entry criterion phenotype (which defines the index date), applies inclusion and exclusion phenotypes, then computes characteristics and outcome phenotypes for the patients passing the inclusion/exclusion rules. Parameters: name: A descriptive name for the cohort. diff --git a/phenex/core/study.py b/phenex/core/study.py index 35f9896c..210aa68b 100644 --- a/phenex/core/study.py +++ b/phenex/core/study.py @@ -15,9 +15,7 @@ class Study: """ - Orchestrates the execution of multiple cohorts and aggregates their reports. - - A Study manages the execution of one or more cohorts, automatically generating standardized reports (Waterfall, Table1) for each cohort and concatenating them into a single multi-sheet Excel file for easy comparison. Each execution creates a timestamped directory containing individual cohort outputs and a combined study results file. + Use Study to execute multiple cohorts together and produce a combined report. 
It runs each cohort, generates standardized reports (Waterfall, Table1), and concatenates them into a single multi-sheet Excel file in a timestamped output directory. Parameters: path: Base directory where study outputs will be saved. A subdirectory with the study name will be created if it doesn't exist. diff --git a/phenex/core/subcohort.py b/phenex/core/subcohort.py index e762ad7e..14d4c6ca 100644 --- a/phenex/core/subcohort.py +++ b/phenex/core/subcohort.py @@ -147,15 +147,7 @@ def to_json(self, path: str): class Subcohort(Cohort): """ - A Subcohort derives from a parent cohort and applies additional inclusion / - exclusion criteria. The subcohort inherits the entry criterion, inclusions, - exclusions, and outcomes from the parent cohort but can add additional - filtering criteria and outcomes. - - Like ``Cohort``, a ``Subcohort`` exposes a ``table1`` property that reports - baseline characteristics for the subcohort population. The characteristics - are taken from the parent cohort and their data are subset to the patients - that satisfy the subcohort's criteria. + Use Subcohort to derive a sub-population from an existing parent cohort by applying additional inclusion/exclusion criteria. It inherits the entry criterion, inclusions, exclusions, and characteristics from the parent cohort but lets you add further phenotype-based restrictions and additional outcomes. Parameters: name: A descriptive name for the subcohort. 
diff --git a/phenex/derived_tables/combine_overlapping_periods.py b/phenex/derived_tables/combine_overlapping_periods.py index c0f289ae..0cda6f93 100644 --- a/phenex/derived_tables/combine_overlapping_periods.py +++ b/phenex/derived_tables/combine_overlapping_periods.py @@ -18,7 +18,7 @@ def __init__( self, domain: str, categorical_filter: Optional["CategoricalFilter"] = None, - **kwargs + **kwargs, ): self.domain = domain self.categorical_filter = categorical_filter diff --git a/phenex/filters/categorical_filter.py b/phenex/filters/categorical_filter.py index 5420316d..bb96c75b 100644 --- a/phenex/filters/categorical_filter.py +++ b/phenex/filters/categorical_filter.py @@ -4,7 +4,7 @@ class CategoricalFilter(Filter): """ - This class filters events in an EventTable based on specified categorical values. + Use CategoricalFilter to restrict events or patients based on discrete categorical values in a column (e.g. "inpatient visits only", "female patients", "ICD-10 codes in primary diagnosis position"). Supports inclusion (isin), exclusion (notin), null checks (isnull, notnull). Attributes: column_name (str): The name of the column to filter by. diff --git a/phenex/filters/date_filter.py b/phenex/filters/date_filter.py index 66d5e94f..3b65ca06 100644 --- a/phenex/filters/date_filter.py +++ b/phenex/filters/date_filter.py @@ -21,36 +21,28 @@ def __init__(self, operator: str, value: Union[date, str], date_format="%Y-%m-%d class Before(Date): - """ - Represents a threshold where a date must be strictly before the specified value. - """ + """Use Before to specify that a date must be strictly before a given date (e.g. Before("2023-01-01") means < 2023-01-01). Pass to DateFilter as max_date to exclude the boundary date.""" def __init__(self, value: Union[date, str], **kwargs): super(Before, self).__init__("<", value) class BeforeOrOn(Date): - """ - Represents a threshold where a date must be on or before the specified value. 
- """ + """Use BeforeOrOn to specify that a date must be on or before a given date (e.g. BeforeOrOn("2023-12-31") means <= 2023-12-31). Pass to DateFilter as max_date to include the boundary date.""" def __init__(self, value: Union[date, str], **kwargs): super(BeforeOrOn, self).__init__("<=", value) class After(Date): - """ - Represents a threshold where a date must be strictly after the specified value. - """ + """Use After to specify that a date must be strictly after a given date (e.g. After("2020-01-01") means > 2020-01-01). Pass to DateFilter as min_date to exclude the boundary date.""" def __init__(self, value: Union[date, str], **kwargs): super(After, self).__init__(">", value) class AfterOrOn(Date): - """ - Represents a threshold where a date must be on or after the specified value. - """ + """Use AfterOrOn to specify that a date must be on or after a given date (e.g. AfterOrOn("2020-01-01") means >= 2020-01-01). Pass to DateFilter as min_date to include the boundary date.""" def __init__(self, value: Union[date, str], **kwargs): super(AfterOrOn, self).__init__(">=", value) @@ -62,12 +54,42 @@ def DateFilter( column_name: str = "EVENT_DATE", ): """ - DateFilter is a specialized ValueFilter for handling date-based filtering. + Use DateFilter to restrict events to an absolute date range (e.g. "events after 2020-01-01", "events between 2019 and 2023"). Specify min_date and/or max_date using After, AfterOrOn, Before, or BeforeOrOn. Parameters: min_date: The minimum date condition. Recommended to pass either After or AfterOrOn. max_date: The maximum date condition. Recommended to pass either Before or BeforeOrOn. column_name: The name of the column to apply the filter on. Defaults to "EVENT_DATE". 
+ + Examples: + + Example: Events occurring after January 1, 2020 + ```python + from phenex.filters.date_filter import DateFilter, After + + date_filter = DateFilter( + min_date=After("2020-01-01") + ) + ``` + + Example: Events between 2019 and 2023 (inclusive) + ```python + from phenex.filters.date_filter import DateFilter, AfterOrOn, BeforeOrOn + + date_filter = DateFilter( + min_date=AfterOrOn("2019-01-01"), + max_date=BeforeOrOn("2023-12-31") + ) + ``` + + Example: Events strictly before a cutoff date + ```python + from phenex.filters.date_filter import DateFilter, Before + + date_filter = DateFilter( + max_date=Before("2022-06-01") + ) + ``` """ # For some reason, implementing DateFilter as a subclass of ValueFilter messes up the serialization. So instead we implement DateFilter as a function that looks like a class and just returns a ValueFilter instance. return ValueFilter(min_value=min_date, max_value=max_date, column_name=column_name) diff --git a/phenex/filters/relative_time_range_filter.py b/phenex/filters/relative_time_range_filter.py index 91bdccde..94948be1 100644 --- a/phenex/filters/relative_time_range_filter.py +++ b/phenex/filters/relative_time_range_filter.py @@ -9,7 +9,7 @@ class RelativeTimeRangeFilter(Filter): """ - This class filters events in an EventTable based on a specified time range relative to an anchor date. The anchor date can either be provided by an anchor phenotype or by an 'INDEX_DATE' column in the EventTable. + Use RelativeTimeRangeFilter to restrict events to a time window relative to an anchor date (e.g. "within 365 days before index date", "any time after index"). The anchor is either a specified phenotype's event date or the INDEX_DATE column if no anchor phenotype is provided (the latter is only possible in the context of a cohort which has defined an entry phenotype; when in doubt, specify the anchor phenotype explicitly). Parameters: min_days: Minimum number of days from the anchor date to filter events. 
diff --git a/phenex/filters/value.py b/phenex/filters/value.py index 35938306..02cac2ca 100644 --- a/phenex/filters/value.py +++ b/phenex/filters/value.py @@ -55,25 +55,35 @@ def to_dict(self): class GreaterThan(Value): + """Use GreaterThan to specify a strict lower bound on a numeric value (e.g. GreaterThan(0) means > 0). Use with ValueFilter or RelativeTimeRangeFilter to exclude the boundary value.""" + def __init__(self, value: int, **kwargs): super(GreaterThan, self).__init__(">", value) class GreaterThanOrEqualTo(Value): + """Use GreaterThanOrEqualTo to specify an inclusive lower bound on a numeric value (e.g. GreaterThanOrEqualTo(18) means >= 18). Use with ValueFilter or RelativeTimeRangeFilter to include the boundary value.""" + def __init__(self, value: int, **kwargs): super(GreaterThanOrEqualTo, self).__init__(">=", value) class LessThan(Value): + """Use LessThan to specify a strict upper bound on a numeric value (e.g. LessThan(365) means < 365). Use with ValueFilter or RelativeTimeRangeFilter to exclude the boundary value.""" + def __init__(self, value: int, **kwargs): super(LessThan, self).__init__("<", value) class LessThanOrEqualTo(Value): + """Use LessThanOrEqualTo to specify an inclusive upper bound on a numeric value (e.g. LessThanOrEqualTo(65) means <= 65). Use with ValueFilter or RelativeTimeRangeFilter to include the boundary value.""" + def __init__(self, value: int, **kwargs): super(LessThanOrEqualTo, self).__init__("<=", value) class EqualTo(Value): + """Use EqualTo to specify an exact numeric match (e.g. EqualTo(1) means = 1). 
Use with ValueFilter when the value must be exactly a specific number.""" + def __init__(self, value: int, **kwargs): super(EqualTo, self).__init__("=", value) diff --git a/phenex/filters/value_filter.py b/phenex/filters/value_filter.py index a683147c..4d121e9c 100644 --- a/phenex/filters/value_filter.py +++ b/phenex/filters/value_filter.py @@ -12,7 +12,7 @@ class ValueFilter(Filter): """ - ValueFilter filters events in an PhenexTable based on a specified value range. + Use ValueFilter to restrict events based on a numeric value range (e.g. "HbA1c > 7.0", "age between 18 and 65", "BMI ≥ 30"). Specify min_value and/or max_value using Value subclasses (GreaterThan, LessThanOrEqualTo, etc.) to define the boundaries. Parameters: min_value: Minimum value required to pass through the filter. @@ -21,6 +21,40 @@ class ValueFilter(Filter): Methods: filter: Filters the given PhenexTable based on the range of values specified by the min_value and max_value attributes. See Filter. + + Examples: + + Example: Filter for HbA1c values above 7.0 + ```python + from phenex.filters import ValueFilter + from phenex.filters.value import GreaterThan + + hba1c_filter = ValueFilter( + min_value=GreaterThan(7.0) + ) + ``` + + Example: Filter for age between 18 and 65 + ```python + from phenex.filters import ValueFilter + from phenex.filters.value import GreaterThanOrEqualTo, LessThanOrEqualTo + + age_filter = ValueFilter( + min_value=GreaterThanOrEqualTo(18), + max_value=LessThanOrEqualTo(65) + ) + ``` + + Example: Filter for BMI >= 30 on a custom column + ```python + from phenex.filters import ValueFilter + from phenex.filters.value import GreaterThanOrEqualTo + + bmi_filter = ValueFilter( + min_value=GreaterThanOrEqualTo(30), + column_name="BMI" + ) + ``` """ def __init__( diff --git a/phenex/phenotypes/age_phenotype.py b/phenex/phenotypes/age_phenotype.py index 0d496e96..564e2183 100644 --- a/phenex/phenotypes/age_phenotype.py +++ b/phenex/phenotypes/age_phenotype.py @@ -14,9 +14,11 @@ 
class AgePhenotype(Phenotype): """ - AgePhenotype is a class that represents an age-based phenotype. It calculates the age of individuals - based on their date of birth and an optional anchor phenotype. The age is computed in years and can - be filtered within a specified range. + Use AgePhenotype to compute patient age (in years) at a given reference date or to include/exclude patients based on age criteria (e.g. "patients aged 18-65 at index date"). Age is calculated from date of birth to the index date (or a custom anchor phenotype if supplied). + + For patients passing all filters, this phenotype returns: + DATE: The patient's date of birth. + VALUE: Age in years at the anchor date (index date if an anchor is not explicitly supplied). Parameters: name: Name of the phenotype, default is 'age'. diff --git a/phenex/phenotypes/bin_phenotype.py b/phenex/phenotypes/bin_phenotype.py index 31255928..90020489 100644 --- a/phenex/phenotypes/bin_phenotype.py +++ b/phenex/phenotypes/bin_phenotype.py @@ -16,14 +16,11 @@ class BinPhenotype(Phenotype): """ - BinPhenotype converts values into categorical bin labels. Supports both continuous numeric binning and discrete value mapping. + Use BinPhenotype to convert numeric or discrete values into categorical groups for reporting (e.g. age groups "18-30", "31-45", "46-65" or BMI categories "normal", "overweight", "obese"). Takes another phenotype as input and maps its VALUE column into named bins. Use bins for numeric ranges or value_mapping for discrete-to-category mapping. - For continuous values: Takes a phenotype that returns numeric values (like age, measurements, etc.) and converts the VALUE column into bin labels like "[10-20)", "[20-30)", etc. - - For discrete values: Takes a phenotype that returns discrete values (like codes from CodelistPhenotype) and maps them to categorical labels using a bin mapping dictionary.
- - DATE: The event date selected from the input phenotype - VALUE: A categorical variable representing the bin label + This phenotype returns: + DATE: The event date from the input phenotype. + VALUE: A label representing the bin. Parameters: name: The name of the phenotype. diff --git a/phenex/phenotypes/categorical_phenotype.py b/phenex/phenotypes/categorical_phenotype.py index 3ec3e337..68d4f5ce 100644 --- a/phenex/phenotypes/categorical_phenotype.py +++ b/phenex/phenotypes/categorical_phenotype.py @@ -35,14 +35,11 @@ def check_categorical_filters_share_same_domain(filter, domain): class CategoricalPhenotype(Phenotype): """ - CategoricalPhenotype is used for discrete entities such for sex, race, or ethnicity, diagnosis position, or encounter type. CategoricalPhenotypes are especially helpful as a baseline characteristic from PERSON like tables to identify demographic information. - - CategoricalPhenotype can be used to filter patients by a category, or to pull relevant categorical information. - - - DATE: Often null; only populated if the categorical value is associated with a date e.g.a categorical phenotype identifying all inpatient encounters in an event table - VALUE: The identified category from the source column. + Use CategoricalPhenotype to extract or filter by discrete categorical values such as race, ethnicity, encounter type, or diagnosis position. Most commonly used as a baseline characteristic to capture demographic information from PERSON-like tables, or to filter events by category (e.g. "inpatient encounters only"). For sex specifically, use SexPhenotype instead (a convenience subclass). + This phenotype returns: + DATE: Often null; only populated if the categorical value is associated with a date (e.g. inpatient encounters). + VALUE: The identified category from the source column. Parameters: name: Name of the phenotype. 
diff --git a/phenex/phenotypes/codelist_phenotype.py b/phenex/phenotypes/codelist_phenotype.py index f8851a97..ff9df710 100644 --- a/phenex/phenotypes/codelist_phenotype.py +++ b/phenex/phenotypes/codelist_phenotype.py @@ -10,8 +10,11 @@ class CodelistPhenotype(EventPhenotype): """ - CodelistPhenotype extracts patients from a CodeTable based on a specified codelist and - other optional filters such as date range, relative time range and categorical filters. + Use CodelistPhenotype to identify patients who have a specific diagnosis, procedure, or drug exposure based on medical codes (e.g. ICD-10, CPT, RxNorm). This is the most commonly used phenotype — use it whenever the clinical concept you need can be defined by a set of codes in a codelist. Supports filtering by date range, relative time range, and categorical filters (e.g. inpatient / outpatient). + + For patients passing all filters, this phenotype returns: + DATE: The date of the matching event (first, last, nearest, or all depending on return_date). + VALUE: The matched medical code if return_value='all'; otherwise not populated. Parameters: domain: The domain of the phenotype. diff --git a/phenex/phenotypes/computation_graph_phenotypes.py b/phenex/phenotypes/computation_graph_phenotypes.py index a3478b4f..339b9738 100644 --- a/phenex/phenotypes/computation_graph_phenotypes.py +++ b/phenex/phenotypes/computation_graph_phenotypes.py @@ -220,9 +220,11 @@ def _perform_value_filtering(self, table: Table) -> Table: class ScorePhenotype(ComputationGraphPhenotype): """ - ScorePhenotype is a CompositePhenotype that performs arithmetic operations using the **boolean** column of its component phenotypes and populations the **value** column. It should be used for calculating medical scores such as CHADSVASC, HASBLED, etc. + Use ScorePhenotype to calculate clinical risk scores that are sums of yes/no criteria (e.g. CHA₂DS₂-VASc, HAS-BLED, Charlson Comorbidity Index). 
Each component phenotype contributes its boolean (present/absent) value, optionally multiplied by a weight. The expression defines the arithmetic: e.g. 2*age_gt_75 + hypertension + diabetes. - --> See the comparison table of CompositePhenotype classes + This phenotype returns: + DATE: The date from return_date (first, last, or from a specified phenotype). + VALUE: The computed score (sum of weighted boolean components). Parameters: expression: The arithmetic expression to be evaluated composed of phenotypes combined by python arithmetic operations. @@ -283,8 +285,11 @@ def repr_short(self, level=0): class ArithmeticPhenotype(ComputationGraphPhenotype): """ - ArithmeticPhenotype is a composite phenotype that performs arithmetic operations using the **value** column of its component phenotypes and populations the **value** column. It should be used for calculating values such as BMI, GFR or converting units. - --> See the comparison table of CompositePhenotype classes + Use ArithmeticPhenotype to compute derived numeric values from other phenotypes' values (e.g. BMI = weight / height², eGFR from creatinine, unit conversions). Each component phenotype contributes its numeric VALUE column, and the expression defines the arithmetic. + + This phenotype returns: + DATE: The date from return_date (first, last, or from a specified phenotype). + VALUE: The computed numeric result of the arithmetic expression. Parameters: expression: The arithmetic expression to be evaluated composed of phenotypes combined by python arithmetic operations. @@ -335,9 +340,11 @@ def __init__( class LogicPhenotype(ComputationGraphPhenotype): """ - LogicPhenotype is a composite phenotype that performs boolean operations using the **boolean** column of its component phenotypes and populations the **boolean** column of the resulting phenotype table. It should be used in any instance where multiple phenotypes are logically combined, for example, does a patient have diabetes AND hypertension, etc. 
+ Use LogicPhenotype to combine multiple phenotypes with boolean logic (AND, OR, NOT). Use it when inclusion/exclusion criteria require compound conditions, e.g. "patients with diabetes AND hypertension", "patients with condition A OR condition B", "patients with diagnosis but NOT on treatment". - --> See the comparison table of CompositePhenotype classes + This phenotype returns: + DATE: The date from return_date (first, last, or from a specified phenotype). + VALUE: Not populated. The BOOLEAN column indicates the result of the logical expression. Parameters: expression: The logical expression to be evaluated composed of phenotypes combined by python arithmetic operations. diff --git a/phenex/phenotypes/death_phenotype.py b/phenex/phenotypes/death_phenotype.py index 6da38027..38d6cf10 100644 --- a/phenex/phenotypes/death_phenotype.py +++ b/phenex/phenotypes/death_phenotype.py @@ -11,8 +11,11 @@ class DeathPhenotype(Phenotype): """ - DeathPhenotype is a class that represents a death-based phenotype. It filters individuals - who have died and returns their date of death. + Use DeathPhenotype to identify patients who have died, as an outcome (e.g. "all-cause mortality") or as an inclusion/exclusion criterion. Returns the date of death. Can be combined with relative_time_range to restrict to deaths within a specific window (e.g. "death within 30 days of index"). + + This phenotype returns: + DATE: Date of death. + VALUE: Not populated (null). Parameters: name: Name of the phenotype, default is 'death'. 
@@ -35,7 +38,7 @@ def __init__( relative_time_range: Union[ RelativeTimeRangeFilter, List[RelativeTimeRangeFilter] ] = None, - **kwargs + **kwargs, ): super(DeathPhenotype, self).__init__(name=name, **kwargs) self.domain = domain diff --git a/phenex/phenotypes/event_count_phenotype.py b/phenex/phenotypes/event_count_phenotype.py index d53462f1..2a22447a 100644 --- a/phenex/phenotypes/event_count_phenotype.py +++ b/phenex/phenotypes/event_count_phenotype.py @@ -13,15 +13,11 @@ class EventCountPhenotype(Phenotype): """ - EventCountPhenotype counts the number of events that occur on distinct days. It is additionally able to filter patients based on: - 1. the number of distinct days an event occurred, by setting value_filter - 2. the number of days between pairs of events + Use EventCountPhenotype when you need to count how many times an event occurred (e.g. "≥2 diagnoses on distinct days", "≥3 prescriptions within 90 days") or to require a minimum number of events for inclusion. Takes another phenotype as input (which must have return_date='all') and counts events on distinct days. Use value_filter to set minimum/maximum event count thresholds. Use relative_time_range to constrain the number of days between event pairs. - EventCountPhenotype is a composite phenotype, meaning that it does not directly operate on source data and takes a phenotype as an argument. The phenotype passed to EventCountPhenotype must have return_date set to 'all' (if return_date on the provided phenotype is set to `first` or `last`, there will only be one event per patient...) - - - DATE: The event date selected based on `component_date_select` and `return_date` parameters. `return_date` returns multiple rows per patient for all events that fulfill criteria. `return_date` first is the first fulfilling event date, last the last. If component_date_select = 'first' the returned date is a pair of events, if component_date_select = 'second' we return the second of a pair of events.
- VALUE: The number of days that the phenotype of interest has occurred i.e. if 4, that means the phenotype has occurred on 4 distinct days. + This phenotype returns: + DATE: The event date selected based on component_date_select and return_date parameters. + VALUE: The number of distinct days on which the event occurred. Parameters: name: The name of the phenotype. diff --git a/phenex/phenotypes/measurement_change_phenotype.py b/phenex/phenotypes/measurement_change_phenotype.py index 1c95f0eb..d2346652 100644 --- a/phenex/phenotypes/measurement_change_phenotype.py +++ b/phenex/phenotypes/measurement_change_phenotype.py @@ -10,7 +10,11 @@ class MeasurementChangePhenotype(Phenotype): """ - MeasurementChangePhenotype looks for changes in the value of a MeasurementPhenotype within a certain time period. Returns EVENT_DATE as either the date of the first or second MeasurementPhenotype event and VALUE as the observed change in the underlying MeasurementPhenotype's VALUE. + Use MeasurementChangePhenotype when you need to detect a change in a lab value or vital sign over time (e.g. "HbA1c decrease of ≥1% within 6 months", "weight gain of >5kg"). Takes a MeasurementPhenotype as input and compares pairs of measurements, returning the magnitude of change. Use direction to specify increase or decrease, and min_change/max_change to set thresholds. + + This phenotype returns: + DATE: The date of the first or second measurement event (based on component_date_select). + VALUE: The magnitude of change in the measurement value between the two events. Parameters: name: The name of the phenotype. diff --git a/phenex/phenotypes/measurement_phenotype.py b/phenex/phenotypes/measurement_phenotype.py index 29e11e3d..b38b0763 100644 --- a/phenex/phenotypes/measurement_phenotype.py +++ b/phenex/phenotypes/measurement_phenotype.py @@ -11,14 +11,16 @@ class MeasurementPhenotype(CodelistPhenotype): """ - # What is MeasurementPhenotype for? 
- The MeasurementPhenotype is for manipulating numerical data found in RWD data sources e.g. laboratory or observation results. These tables often contain numerical values (height, weight, blood pressure, lab results). As an event-based table, each row records a single measurement value for a single patient with a date. All numerical values are in a 'value' column. A medical code indicates the type of numerical measurement and the units of measurement are in an additional column. + Use MeasurementPhenotype when you need to work with numeric lab values, vitals, or observation results (e.g. HbA1c > 7%, systolic blood pressure, BMI, eGFR). It identifies measurements by medical code (like CodelistPhenotype) and additionally lets you filter by value range (value_filter) and return the numeric value — either a single value (first, last, nearest) or an aggregation (mean, median, max, min). Use this whenever the clinical concept involves a numeric threshold or you need the measurement value itself. - MeasurementPhenotype is a subclass of CodelistPhenotype, inheriting all of its functionality to identify patients by single or sets of medical codes (e.g. 'test type') within a specified time period. It can also : + MeasurementPhenotype is a subclass of CodelistPhenotype, inheriting all of its code-matching and time-filtering functionality. It additionally supports: - - identify patients with a measurement value within a value range and - - return a measurement value, either all measurements values within filter - criteria or perform simple aggregations (mean, median, max, min). + - filtering patients by measurement value within a range (value_filter) + - returning measurement values with optional aggregation (mean, median, max, min) + + This phenotype returns: + DATE: The date of the matching measurement event. + VALUE: The numeric measurement value (optionally aggregated via mean, median, max, min). 
# Example data: diff --git a/phenex/phenotypes/sex_phenotype.py b/phenex/phenotypes/sex_phenotype.py index 3a30c6db..6b2fab4a 100644 --- a/phenex/phenotypes/sex_phenotype.py +++ b/phenex/phenotypes/sex_phenotype.py @@ -5,7 +5,11 @@ class SexPhenotype(CategoricalPhenotype): """ - SexPhenotype represents a sex-based phenotype. It returns the sex of individuals in the VALUE column and optionally filters based on identified sex. DATE is not defined for SexPhenotype. + Use SexPhenotype to retrieve patient sex as a baseline characteristic or to filter a cohort by sex (e.g. "female patients only"). Returns the sex value from the PERSON table. Use categorical_filter to restrict to specific sex values. + + This phenotype returns: + DATE: Not populated (null). + VALUE: The sex value from the PERSON table. Parameters: name: Name of the phenotype, default is 'SEX'. @@ -42,7 +46,7 @@ def __init__( name: str = "SEX", domain: str = "PERSON", categorical_filter: "CategoricalFilter" = None, - **kwargs + **kwargs, ): if categorical_filter is None: categorical_filter = CategoricalFilter(column_name="SEX") diff --git a/phenex/phenotypes/time_range_count_phenotype.py b/phenex/phenotypes/time_range_count_phenotype.py index bad1b9af..05b2f21b 100644 --- a/phenex/phenotypes/time_range_count_phenotype.py +++ b/phenex/phenotypes/time_range_count_phenotype.py @@ -18,16 +18,11 @@ class TimeRangeCountPhenotype(Phenotype): """ - TimeRangeCountPhenotype works with time range tables i.e. the input table must have a START_DATE and END_DATE column (in addition to PERSON_ID). It counts the number of distinct time ranges for each person, either total or within a specified date range (relative or absolute). If no relative_time_range defined, it returns the number of time periods per person. 
If relative_time_range is defined, it counts the number of time periods before or after (depending on when keyword argument of relative_time_range), NOT including the time period defined by the relative_time_range anchor. + Use TimeRangeCountPhenotype to count the number of distinct episodes or periods for each patient from a table with START_DATE and END_DATE columns. Common use cases: count hospitalizations in the post-index period, count drug exposure episodes, or require a minimum number of coverage periods. Use value_filter to set thresholds (e.g. "≥2 hospitalizations"). - If min_days or max_days of the relative_time_range are defined, the entire time period must be included in the relative time range i.e. if before, the start date of all time periods must be contained within the time range. - - This can be used : - - given an admission discharge table, to count the number of hospitalizations that occurred e.g. in the post index period - - given a drug exposure table, to count the number of times a person has taken a medication - - DATE: Date is always null - VALUE: Number of distinct time periods in the specified time range. + This phenotype returns: + DATE: Not populated (null). + VALUE: Number of distinct time periods in the specified time range. Parameters: domain: The domain of the phenotype. diff --git a/phenex/phenotypes/time_range_day_count_phenotype.py b/phenex/phenotypes/time_range_day_count_phenotype.py index a18d051f..42b6b6be 100644 --- a/phenex/phenotypes/time_range_day_count_phenotype.py +++ b/phenex/phenotypes/time_range_day_count_phenotype.py @@ -18,14 +18,11 @@ class TimeRangeDayCountPhenotype(Phenotype): """ - TimeRangeDayCountPhenotype works with time range tables i.e. the input table must have a START_DATE and END_DATE column (in addition to PERSON_ID). It counts the **total number of days** within time ranges for each person, either total or within a specified date range (relative or absolute (TODO)). 
If no relative_time_range is defined, it returns the total number of days across all time periods per person. If relative_time_range is defined, it counts the number of days before or after (depending on when keyword argument of relative_time_range), INCLUDING the time period that contains the anchor date. + Use TimeRangeDayCountPhenotype to count the total number of days across time range episodes for each patient (e.g. "total days hospitalized in the year after index", "total days of drug exposure"). Works with tables that have START_DATE and END_DATE columns. Unlike TimeRangeCountPhenotype which counts episodes, this counts the sum of days across all episodes. - This can be used : - - given an admission discharge table, to count the total number of days hospitalized e.g. in the post index period - - given a drug exposure table, to count the total number of days of drug exposure - - DATE: Date is always null - VALUE: Total number of days across all time periods in the specified time range. + This phenotype returns: + DATE: Not populated (null). + VALUE: Total number of days across all time periods in the specified time range. Parameters: domain: The domain of the phenotype. diff --git a/phenex/phenotypes/time_range_days_to_next_range_phenotype.py b/phenex/phenotypes/time_range_days_to_next_range_phenotype.py index 415464bf..38be0389 100644 --- a/phenex/phenotypes/time_range_days_to_next_range_phenotype.py +++ b/phenex/phenotypes/time_range_days_to_next_range_phenotype.py @@ -10,18 +10,11 @@ class TimeRangeDaysToNextRange(Phenotype): """ - TimeRangeDaysToNextRange identifies the time range that contains the anchor phenotype, - then finds the adjacent time range (next or previous) and counts the days difference between them. + Use TimeRangeDaysToNextRange to measure the gap (in days) between consecutive time range episodes (e.g. "days from hospital discharge to next readmission", "gap between coverage periods"). 
Finds the time range containing the anchor date, then measures the gap to the next (or previous) adjacent time range. - If relative_time_range.when is 'after' (default): - Finds the next consecutive time range. - VALUE: Days difference between the end of anchored time range and start of next time range. - EVENT_DATE: The start date of the next consecutive time range. - - If relative_time_range.when is 'before': - Finds the previous consecutive time range. - VALUE: Days difference between the start of the anchored time range and the end of the previous time range. - EVENT_DATE: The end date of the previous consecutive time range. + This phenotype returns: + DATE: If relative_time_range.when='after' (default), the start date of the next time range. If relative_time_range.when='before', the end date of the previous time range. + VALUE: Days between the time range containing the anchor date and the adjacent time range (end of the anchored range to start of the next, or end of the previous range to start of the anchored range). Example: Count number days to next hospitalization after index date hospitalization @@ -45,7 +38,7 @@ def __init__( name: Optional[str] = None, relative_time_range: Optional[RelativeTimeRangeFilter] = None, value_filter: Optional[ValueFilter] = None, - **kwargs + **kwargs, ): super().__init__(name=name, **kwargs) self.domain = domain diff --git a/phenex/phenotypes/time_range_phenotype.py b/phenex/phenotypes/time_range_phenotype.py index 3f7a0b62..10fe1575 100644 --- a/phenex/phenotypes/time_range_phenotype.py +++ b/phenex/phenotypes/time_range_phenotype.py @@ -10,14 +10,11 @@ class TimeRangePhenotype(Phenotype): """ - As the name implies, TimeRangePhenotype is designed for working with time ranges. If the input data has a start and an end date, use TimeRangePhenotype to identify other events (or patients) that occur within this time range. The most common use case of this is working with 'health insurance coverage' data i.e. on 'OBSERVATION_PERIOD' table. These tables have one or many rows per patient with the start of coverage and end of coverage i.e. domains compatible with TimeRangePhenotype require a START_DATE and an END_DATE column.
At it's simplest, TimeRangePhenotype identifies patients who have their INDEX_DATE (or other anchor date of interest) within this time range. Additionally, a minimum or maximum number of days from the anchor date to the beginning/end of the time range can be defined. The returned Phenotype has the following interpretation: + Use TimeRangePhenotype to work with data that has start and end dates — most commonly health insurance coverage (OBSERVATION_PERIOD). The two primary use cases are: (1) require minimum continuous enrollment/coverage before or after index date (e.g. "1 year of continuous coverage prior to index"), and (2) determine the date of loss to follow-up (right censoring). The input domain must have START_DATE and END_DATE columns. - DATE: If relative_time_range.when='before', then DATE is the beginning of the coverage period containing the anchor phenotype. If relative_time_range.when='after', then DATE is the end of the coverage period containing the anchor date. - VALUE: Coverage (in days) relative to the anchor date. By convention, always non-negative. - - There are two primary use cases for TimeRangePhenotype: - 1. Identify patients with some minimum duration of coverage prior to anchor_phenotype date e.g. "identify patients with 1 year of continuous coverage prior to index date" - 2. Determine the date of loss to followup (right censoring) i.e. the duration of coverage after the anchor_phenotype event + This phenotype returns: + DATE: If relative_time_range.when='before', the start of the coverage period containing the anchor date. If relative_time_range.when='after', the end of that period. + VALUE: Coverage in days relative to the anchor date (always non-negative). ## Data for TimeRangePhenotype This phenotype requires a table with PersonID and a coverage start date and end date. Depending on the datasource used, this information is a separate ObservationPeriod table or found in the PersonTable. Use an PhenexObservationPeriodTable to map required coverage start and end date columns.
For tables with overlapping time ranges, use the CombineOverlappingPeriods derived table to combine time ranges into a single time range. @@ -75,7 +72,7 @@ def __init__( date_range: Optional[DateFilter] = None, relative_time_range: Optional["RelativeTimeRangeFilter"] = None, allow_null_end_date: bool = True, - **kwargs + **kwargs, ): super(TimeRangePhenotype, self).__init__(name=name, **kwargs) self.domain = domain diff --git a/phenex/phenotypes/user_defined_phenotype.py b/phenex/phenotypes/user_defined_phenotype.py index 8813f675..a5213a28 100644 --- a/phenex/phenotypes/user_defined_phenotype.py +++ b/phenex/phenotypes/user_defined_phenotype.py @@ -21,17 +21,15 @@ def UserDefinedPhenotype( returns_value: bool = False, ): """ - UserDefinedPhenotype allows users of PhenEx to implement custom functionality within a single phenotype. To use, the user must pass a function that returns an ibis table. This means that the function must - 1. return an ibis table - 2. There are a minimum of one column : PERSON_ID. If no other columns are returned, it is assumed that all person_ids in the PERSON_ID column fulfill the UserDefinedPhenotype - 3. If additional columns are returned, they must be named BOOLEAN, EVENT_DATE, and VALUE. The BOOLEAN column indicates whether the person_id fulfills the UserDefinedPhenotype; patients with BOOLEAN = False will be removed. The EVENT_DATE column contains the date of the event, and the VALUE column contains a numeric value associated with the event. Any other columns are ignored. + Use UserDefinedPhenotype as an escape hatch when no built-in phenotype covers your use case. Pass a custom function that receives the mapped tables and returns an ibis table with at minimum a PERSON_ID column. Two main scenarios: (1) hybrid workflows where cohort extraction was done outside PhenEx (e.g. in R or SQL) and you want to inject those results as an entry criterion, and (2) complex custom event logic that cannot be expressed with built-in phenotypes. 
- UserDefinedPhenotype is especially useful for two use cases : - 1. Hybrid workflows: If you have performed cohort extraction outside of PhenEx (e.g. in R, SQL) but would like to use PhenEx to calculate baseline characteristics and outcomes, we can set the entry criterion to a UserDefinedPhenotype and read a dataframe of PERSON_IDS and INDEX_DATES. In this way, PhenEx flexibly allows us to use multiple tools in our analysis. - 2. Custom event definitions: If you need to define events based on complex logic that is not easily expressed using the built-in PhenEx functionality, you can use UserDefinedPhenotype to implement this logic in a custom function. + The function must return an ibis table with: + 1. PERSON_ID column (required). If no other columns, all person_ids are assumed to fulfill the phenotype. + 2. Optional BOOLEAN, EVENT_DATE, VALUE columns. BOOLEAN=False patients are excluded. Any other columns are ignored. - DATE: custom, as defined by user - VALUE: custom, as defined by user + This phenotype returns: + DATE: Custom, as defined by the user function. + VALUE: Custom, as defined by the user function. Parameters: name: The name of the phenotype. diff --git a/phenex/phenotypes/within_same_encounter_phenotype.py b/phenex/phenotypes/within_same_encounter_phenotype.py index cde0692a..be7bced2 100644 --- a/phenex/phenotypes/within_same_encounter_phenotype.py +++ b/phenex/phenotypes/within_same_encounter_phenotype.py @@ -4,7 +4,11 @@ class WithinSameEncounterPhenotype(Phenotype): """ - WithinSameEncounterPhenotype is a phenotype that filters a target phenotype based on the occurrence of an anchor phenotype within the same encounter. This phenotype can only be used with CodelistPhenotypes and MeasurementPhenotypes as anchor/phenotype. + Use WithinSameEncounterPhenotype when two events must co-occur within the same encounter/visit (e.g. 
"diagnosis of sepsis during a hospitalization where a blood culture was performed", "lab test during the same visit as a procedure"). Links a target phenotype to an anchor phenotype via a shared encounter identifier. Only works with CodelistPhenotype and MeasurementPhenotype inputs. + + This phenotype returns: + DATE: The event date from the target phenotype. + VALUE: The value from the target phenotype (if any). Parameters: name: The name of the phenotype. Optional. If not passed, name will be derived from the name of the codelist. diff --git a/phenex/reporting/cohort_explorer.py b/phenex/reporting/cohort_explorer.py index 85d3fe4e..8a216ec0 100644 --- a/phenex/reporting/cohort_explorer.py +++ b/phenex/reporting/cohort_explorer.py @@ -41,23 +41,28 @@ class CohortExplorer(Reporter): """ - Interactive dashboard for exploring cohort phenotypes and their distributions. - - This reporter creates an interactive Bokeh dashboard that allows users to: - - Select different phenotypes from cohort.phenotypes - - Explore VALUE column distributions with histograms - - View timeline patterns when EVENT_DATE is available - - Compare raw vs standardized values across phenotypes - - Examine event frequency per patient - - The implementation follows the working callback example pattern to ensure - JavaScript callbacks function properly in both Jupyter and exported HTML. + Use CohortExplorer to create an interactive Bokeh dashboard for exploring cohort phenotype distributions. It lets you browse phenotypes, view VALUE histograms, examine EVENT_DATE timelines, compare raw vs standardized values, and inspect per-patient event frequencies. 
Parameters: title: Dashboard title width: Dashboard width in pixels height: Plot height in pixels decimal_places: Number of decimal places for display (inherited) + + Examples: + + Example: Explore a cohort interactively + ```python + from phenex.reporting import CohortExplorer + + explorer = CohortExplorer(title="My Cohort") + cohort = Cohort( + ..., + custom_reporters=[explorer] + ) + cohort.execute(tables) + explorer.to_html("explorer.html") # open in browser + ``` """ def __init__( diff --git a/phenex/reporting/counts.py b/phenex/reporting/counts.py index 9fc040dc..d93b00b2 100644 --- a/phenex/reporting/counts.py +++ b/phenex/reporting/counts.py @@ -5,8 +5,17 @@ class InExCounts(Reporter): """ - Get counts of inclusion and exclusion criteria + Use InExCounts to get raw counts for each inclusion and exclusion criterion in a cohort. It produces a simple table with the number of patients satisfying each criterion independently. + Examples: + + Example: Get inclusion/exclusion counts + ```python + from phenex.reporting import InExCounts + + counts = InExCounts() + df = counts.execute(cohort) + ``` """ def execute(self, cohort: "Cohort") -> pd.DataFrame: diff --git a/phenex/reporting/report_drafter.py b/phenex/reporting/report_drafter.py index 86828ad5..d999447e 100644 --- a/phenex/reporting/report_drafter.py +++ b/phenex/reporting/report_drafter.py @@ -44,22 +44,13 @@ class ReportDrafter(Reporter): """ - The ReportDrafter creates comprehensive draft study reports including: - - Cohort definition description (entry, inclusion, exclusion criteria) - - Data analysis description and date ranges - - Waterfall table showing patient attrition - - Study variables (characteristics and outcomes) - - Table 1 (baseline characteristics) - - Table 2 (outcomes analysis) - - AI-generated descriptive text and figure captions (when AI is enabled) + Use ReportDrafter to generate a comprehensive draft study report from a cohort, including cohort definition, waterfall table, baseline 
characteristics (Table 1), outcomes analysis (Table 2), and optionally AI-generated descriptive text. Reports are exported in editable Markdown and Word formats and require human review before use. **IMPORTANT: Human-in-the-Loop Required** - The ReportDrafter generates DRAFT reports that require human review and editing before use. Reports are exported in editable formats (Markdown and Word) specifically to enable human oversight and refinement. AI-generated content should be verified for: - - Clinical accuracy and appropriateness - - Study-specific context and nuances - - Compliance with institutional guidelines - - Proper medical terminology and phrasing + The ReportDrafter generates DRAFT reports that require human review and editing. + AI-generated content should be verified for clinical accuracy, study-specific + context, and compliance with institutional guidelines. **Never use generated reports without thorough human review and approval.** diff --git a/phenex/reporting/table1.py b/phenex/reporting/table1.py index ab15b56a..9d3da446 100644 --- a/phenex/reporting/table1.py +++ b/phenex/reporting/table1.py @@ -34,9 +34,7 @@ def __getattr__(self, name: str): class Table1(Reporter): """ - Table1 is a common term used in epidemiology to describe a table that shows an overview of the baseline characteristics of a cohort. It contains the counts and percentages of the cohort that have each characteristic, for both boolean and value characteristics. In addition, summary statistics are provided for value characteristics (mean, std, median, min, max). - - Table1 by default reports on all phenotypes in the cohort's characteristics, but a custom list of phenotypes can be provided to the execute() method. When using the default cohort.characteristics, the section structure defined on the cohort is preserved in the Table1 output for better organization and display. 
+ Use Table1 to generate a baseline characteristics summary table for a cohort, showing counts, percentages, and summary statistics (mean, std, median, min, max) for each characteristic. It reports on all phenotypes in the cohort's characteristics by default, preserving any section structure defined on the cohort. Parameters: decimal_places: Number of decimal places to round to. Default: 1 @@ -44,6 +42,26 @@ class Table1(Reporter): (child) phenotypes are expanded inline beneath each parent phenotype, indented according to their nesting depth. ``None`` (default) disables expansion. Set to a large number (e.g. 100) to include all levels. + + Examples: + + Example: Access Table1 after cohort execution + ```python + from phenex.reporting import Table1 + + # Table1 is generated automatically during cohort.execute() + df = cohort.table1 # DataFrame with characteristics summary + cohort.write_reports_to_excel("./output") # writes table1.xlsx + ``` + + Example: Use as a custom reporter with component phenotype expansion + ```python + table1 = Table1(include_component_phenotypes_level=100) + cohort = Cohort( + ..., + custom_reporters=[table1] + ) + ``` """ def __init__(self, include_component_phenotypes_level=None, **kwargs): diff --git a/phenex/reporting/table2.py b/phenex/reporting/table2.py index e88076f3..74c4026d 100644 --- a/phenex/reporting/table2.py +++ b/phenex/reporting/table2.py @@ -13,15 +13,7 @@ class Table2(Reporter): """ - Table2 generates outcome incidence rates and event counts for a cohort at specified time points. - - For each outcome, reports: - - N events in the cohort - - N censored patients (patients whose follow-up was cut short) - - Time under risk in patient-years (accounting for censoring) - - Incidence rate per 100 patient-years - - Time under risk accounts for censoring from competing events (e.g., death) and administrative censoring at end of study period. 
+ Use Table2 to generate outcome incidence rates and event counts for a cohort at specified time points. It reports N events, N censored patients, time under risk in patient-years, and incidence rate per 100 patient-years, accounting for competing-event and administrative censoring. Parameters: time_points: List of days from index to evaluate outcomes (e.g., [90, 365]) diff --git a/phenex/reporting/time_to_event.py b/phenex/reporting/time_to_event.py index ecaae18a..a72555bb 100644 --- a/phenex/reporting/time_to_event.py +++ b/phenex/reporting/time_to_event.py @@ -20,20 +20,7 @@ class TimeToEvent(Reporter): """ - Perform a time to event analysis using Kaplan-Meier estimation. - - This reporter generates: - 1. A private patient-level time-to-event table (_tte_table) for intermediate processing - 2. Aggregated survival/risk data in self.df combining results from all outcomes - 3. Kaplan-Meier survival curves - - The patient-level table (_tte_table) contains one row per patient with: - - Index date for each patient - - Event dates for all outcomes (NULL if did not occur) - - Event dates for all right censoring events (NULL if did not occur) - - End of study period date (if provided) - - Days from index to each event - - Indicator variables for whether the first event was the outcome of interest + Use TimeToEvent to perform Kaplan-Meier survival analysis on cohort outcomes. It generates patient-level time-to-event data, aggregated survival/risk tables, and Kaplan-Meier survival curves accounting for right-censoring events and administrative censoring. The aggregated output (self.df) contains survival function estimates and event counts for each outcome, suitable for reporting and visualization. @@ -43,6 +30,24 @@ class TimeToEvent(Reporter): Suggested are death and end of followup. end_of_study_period: A datetime defining the end of study period. decimal_places: Number of decimal places for rounding survival probabilities. 
Default: 4 + + Examples: + + Example: Basic time-to-event analysis + ```python + from phenex.reporting import TimeToEvent + + tte = TimeToEvent( + right_censor_phenotypes=[death_phenotype], + end_of_study_period=datetime(2023, 12, 31) + ) + cohort = Cohort( + ..., + custom_reporters=[tte] + ) + cohort.execute(tables) + df = tte.df # aggregated survival data + ``` """ def __init__( diff --git a/phenex/reporting/treatment_pattern_analysis_sankey.py b/phenex/reporting/treatment_pattern_analysis_sankey.py index 2be98d27..f452bbb3 100644 --- a/phenex/reporting/treatment_pattern_analysis_sankey.py +++ b/phenex/reporting/treatment_pattern_analysis_sankey.py @@ -163,18 +163,7 @@ def build(self): class TreatmentPatternAnalysisSankeyReporter(_TreatmentPatternAnalysisMixin, Reporter): """ - Reporter that produces a d3-sankey diagram showing patient flow between - treatment regimen combinations across consecutive time periods. - - The reporter automatically discovers every TreatmentPatternAnalysis group - present in ``cohort.characteristics`` and ``cohort.outcomes`` by reading the - ``_tpa_name`` / ``_tpa_period_num`` / ``_tpa_period_label`` attributes set by - :class:`~phenex.phenotypes.factory.TreatmentPatternAnalysis`, with a regex - fallback for manually named phenotypes. - - One :class:`SankeyGenerator` is created per group. Each generator fetches - patient IDs from the already-executed phenotype tables and computes cross-period - flows by set intersection. + Use TreatmentPatternAnalysisSankeyReporter to produce a d3-sankey diagram showing patient flow between treatment regimen combinations across consecutive time periods. It automatically discovers TreatmentPatternAnalysis groups from cohort characteristics/outcomes and computes cross-period flows by set intersection. 
Outputs ------- diff --git a/phenex/reporting/waterfall.py b/phenex/reporting/waterfall.py index ebbc62a5..8d0a48d1 100644 --- a/phenex/reporting/waterfall.py +++ b/phenex/reporting/waterfall.py @@ -9,17 +9,44 @@ class Waterfall(Reporter): """ - A waterfall diagram, also known as an attrition table, shows how inclusion/exclusion criteria contribute to a final population size. Each inclusion/exclusion criteria is a row in the table, and the number of patients remaining after applying that criteria are shown on that row. + Use Waterfall to generate an attrition table showing how each inclusion/exclusion criterion contributes to the final cohort size. Each row shows a criterion with the absolute count fulfilling it, the cumulative remaining count, the percentage retained, and the delta from the previous step. | Column name | Description | | --- | --- | - | Type | The type of the phenotype, either entry, inclusion or exclusion | - | Name | The name of entry, inclusion or exclusion criteria | - | N | The absolute number of patients that fulfill that phenotype. For the entry criterium this is the absolute number in the dataset. For inclusion/exclusion criteria this is the number of patients that fulfill the entry criterium AND the phenotype and that row. | - | Remaining | The number of patients remaining in the cohort after sequentially applying the inclusion/exclusion criteria in the order that they are listed in this table. | - | % | The percentage of patients who fulfill the entry criterion who are remaining in the cohort after application of the phenotype on that row | - | Delta | The change in number of patients that occurs by applying the phenotype on that row. 
| - + | Type | The type of the phenotype: entry, inclusion, or exclusion | + | Name | The name of the entry, inclusion, or exclusion criterion | + | N | The absolute number of patients that fulfill that phenotype | + | Remaining | The number of patients remaining after sequentially applying criteria | + | % | The percentage of entry patients remaining after applying this criterion | + | Delta | The change in number of patients caused by applying this criterion | + + Parameters: + decimal_places: Number of decimal places to round to. Default: 1 + include_component_phenotypes_level: When set to an integer, component + (child) phenotypes are expanded inline beneath each parent. None + (default) disables expansion. + + Examples: + + Example: Generate waterfall for a cohort + ```python + from phenex.reporting import Waterfall + + waterfall = Waterfall() + # Waterfall is generated automatically during cohort.execute() + # Access via: + df = cohort.waterfall # DataFrame with attrition data + cohort.write_reports_to_excel("./output") # writes waterfall.xlsx + ``` + + Example: Use as a custom reporter with component phenotype expansion + ```python + waterfall = Waterfall(include_component_phenotypes_level=100) + cohort = Cohort( + ..., + custom_reporters=[waterfall] + ) + ``` """ def __init__(