Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api/phenotypes/further_value_filter_phenotype.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# FurtherValueFilterPhenotype

::: phenex.phenotypes.further_value_filter_phenotype
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ nav:
- EventPhenotype: api/phenotypes/event_phenotype.md
- CodelistPhenotype: api/phenotypes/codelist_phenotype.md
- MeasurementPhenotype: api/phenotypes/measurement_phenotype.md
- FurtherValueFilterPhenotype: api/phenotypes/further_value_filter_phenotype.md
- AgePhenotype: api/phenotypes/age_phenotype.md
- SexPhenotype: api/phenotypes/sex_phenotype.md
- DeathPhenotype: api/phenotypes/death_phenotype.md
Expand Down
2 changes: 2 additions & 0 deletions phenex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CODETYPE_INFO,
MeasurementPhenotype,
MeasurementChangePhenotype,
FurtherValueFilterPhenotype,
AgePhenotype,
SexPhenotype,
BinPhenotype,
Expand Down Expand Up @@ -134,6 +135,7 @@
"CODETYPE_INFO",
"MeasurementPhenotype",
"MeasurementChangePhenotype",
"FurtherValueFilterPhenotype",
"AgePhenotype",
"SexPhenotype",
"BinPhenotype",
Expand Down
16 changes: 11 additions & 5 deletions phenex/aggregators/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ def __init__(
self.preserve_nulls = preserve_nulls

def aggregate(self, input_table: Table):
# Ensure INDEX_DATE is in the aggregation index if the table has it
agg_index = list(self.aggregation_index)
if "INDEX_DATE" in input_table.columns and "INDEX_DATE" not in agg_index:
agg_index.append("INDEX_DATE")
# Define the window specification
partition_cols = [
getattr(input_table, col) if isinstance(col, str) else col
for col in self.aggregation_index
for col in agg_index
]
window_spec = ibis.window(
group_by=partition_cols, order_by=input_table[self.event_date_column]
Expand Down Expand Up @@ -124,7 +128,7 @@ def aggregate(self, input_table: Table):

# Apply the distinct reduction if required
if self.reduce:
selected_columns = self.aggregation_index + [self.event_date_column]
selected_columns = list(agg_index) + [self.event_date_column]
input_table = input_table.select(selected_columns).distinct()
input_table = input_table.mutate(VALUE=ibis.null().cast("int32"))

Expand Down Expand Up @@ -255,10 +259,12 @@ def __init__(
# otherwise, row count is preserved

def aggregate(self, input_table: Table):
# Ensure INDEX_DATE is in the aggregation index if the table has it
agg_index = list(self.aggregation_index)
if "INDEX_DATE" in input_table.columns and "INDEX_DATE" not in agg_index:
agg_index.append("INDEX_DATE")
# Get the aggregation index columns
_aggregation_index_cols = [
getattr(input_table, col) for col in self.aggregation_index
]
_aggregation_index_cols = [getattr(input_table, col) for col in agg_index]

# Determine the aggregation column
if self.aggregation_column is None:
Expand Down
29 changes: 29 additions & 0 deletions phenex/core/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,37 @@ def __init__(
description: Optional[str] = None,
database: Optional[Database] = None,
custom_reporters: Optional[List] = None,
return_index: str = "first",
max_index_dates: Optional[int] = None,
):
self.name = name
self.description = description
self.database = database
self.return_index = return_index
self.max_index_dates = max_index_dates

assert return_index in (
"first",
"last",
"all",
), f"return_index must be 'first', 'last', or 'all', got '{return_index}'"
if max_index_dates is not None:
assert (
isinstance(max_index_dates, int) and max_index_dates > 0
), f"max_index_dates must be a positive integer, got {max_index_dates}"

# When return_index requires multiple candidate dates, auto-set entry criterion
if return_index in ("last", "all"):
if (
hasattr(entry_criterion, "return_date")
and entry_criterion.return_date != "all"
):
logger.info(
f"Cohort '{name}': return_index='{return_index}' requires entry criterion "
f"return_date='all'. Auto-setting from '{entry_criterion.return_date}'."
)
entry_criterion.return_date = "all"

self.table = None # Will be set during execution to index table
self.subset_tables_entry = None # Will be set during execution
self.subset_tables_index = None # Will be set during execution
Expand Down Expand Up @@ -340,6 +367,8 @@ def build_stages(self, tables: Dict[str, PhenexTable]):
entry_phenotype=self.entry_criterion,
inclusion_table_node=self.inclusions_table_node,
exclusion_table_node=self.exclusions_table_node,
return_index=self.return_index,
max_index_dates=self.max_index_dates,
)
index_nodes.append(self.index_table_node)

Expand Down
20 changes: 15 additions & 5 deletions phenex/core/exclusions_table_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@ def __init__(
self.index_phenotype = index_phenotype

def _execute(self, tables: Dict[str, Table]):
exclusions_table = self.index_phenotype.table.select(["PERSON_ID"])
# Build base from entry criterion; add INDEX_DATE if ALL phenotype tables have it
if self.phenotypes and all(
"INDEX_DATE" in pt.table.columns for pt in self.phenotypes
):
join_keys = ["PERSON_ID", "INDEX_DATE"]
exclusions_table = self.index_phenotype.table.mutate(
INDEX_DATE=self.index_phenotype.table.EVENT_DATE
).select(["PERSON_ID", "INDEX_DATE"])
else:
join_keys = ["PERSON_ID"]
exclusions_table = self.index_phenotype.table.select(["PERSON_ID"])

for pt in self.phenotypes:
pt_table = pt.table.select(["PERSON_ID", "BOOLEAN"]).rename(
pt_table = pt.table.select([*join_keys, "BOOLEAN"]).rename(
**{
f"{pt.name}_BOOLEAN": "BOOLEAN",
}
)
exclusions_table = exclusions_table.left_join(pt_table, ["PERSON_ID"])
columns = exclusions_table.columns
columns.remove("PERSON_ID_right")
exclusions_table = exclusions_table.left_join(pt_table, join_keys)
drop_cols = [f"{k}_right" for k in join_keys]
columns = [c for c in exclusions_table.columns if c not in drop_cols]
exclusions_table = exclusions_table.select(columns)

# fill all nones with False
Expand Down
21 changes: 16 additions & 5 deletions phenex/core/inclusions_table_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,28 @@ def __init__(
self.index_phenotype = index_phenotype

def _execute(self, tables: Dict[str, Table]):
inclusions_table = self.index_phenotype.table.select(["PERSON_ID"])
# Build base from entry criterion; add INDEX_DATE if ALL phenotype tables have it
if self.phenotypes and all(
"INDEX_DATE" in pt.table.columns for pt in self.phenotypes
):
join_keys = ["PERSON_ID", "INDEX_DATE"]
# Derive (PERSON_ID, INDEX_DATE) pairs from entry criterion
inclusions_table = self.index_phenotype.table.mutate(
INDEX_DATE=self.index_phenotype.table.EVENT_DATE
).select(["PERSON_ID", "INDEX_DATE"])
else:
join_keys = ["PERSON_ID"]
inclusions_table = self.index_phenotype.table.select(["PERSON_ID"])

for pt in self.phenotypes:
pt_table = pt.table.select(["PERSON_ID", "BOOLEAN"]).rename(
pt_table = pt.table.select([*join_keys, "BOOLEAN"]).rename(
**{
f"{pt.name}_BOOLEAN": "BOOLEAN",
}
)
inclusions_table = inclusions_table.left_join(pt_table, ["PERSON_ID"])
columns = inclusions_table.columns
columns.remove("PERSON_ID_right")
inclusions_table = inclusions_table.left_join(pt_table, join_keys)
drop_cols = [f"{k}_right" for k in join_keys]
columns = [c for c in inclusions_table.columns if c not in drop_cols]
inclusions_table = inclusions_table.select(columns)

# fill all nones with False
Expand Down
69 changes: 63 additions & 6 deletions phenex/core/index_phenotype.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from typing import Dict
from typing import Dict, Optional
from ibis.expr.types.relations import Table
import ibis
from phenex.node import Node
from phenex.phenotypes.phenotype import Phenotype
from phenex.util import create_logger

logger = create_logger(__name__)


class IndexPhenotype(Phenotype):
"""
Compute the index table form the individual inclusions / exclusions phenotypes.
Compute the index table from the individual inclusions / exclusions phenotypes.

Parameters:
return_index: Controls how multiple candidate index dates per patient are handled
after inclusion/exclusion filtering:
- "first": keep earliest passing INDEX_DATE per patient (default)
- "last": keep latest passing INDEX_DATE per patient
- "all": keep all passing INDEX_DATEs
max_index_dates: When set, cap the number of candidate entry dates per patient
to at most this many (keeping the earliest N) before applying inclusion/exclusion.
"""

def __init__(
Expand All @@ -15,6 +28,8 @@ def __init__(
entry_phenotype: Phenotype,
inclusion_table_node: Node,
exclusion_table_node: Node,
return_index: str = "first",
max_index_dates: Optional[int] = None,
):
super(IndexPhenotype, self).__init__(name=name)
self.add_children(entry_phenotype)
Expand All @@ -26,20 +41,62 @@ def __init__(
self.entry_phenotype = entry_phenotype
self.inclusion_table_node = inclusion_table_node
self.exclusion_table_node = exclusion_table_node
self.return_index = return_index
self.max_index_dates = max_index_dates

def _execute(self, tables: Dict[str, Table]):
index_table = self.entry_phenotype.table.mutate(INDEX_DATE="EVENT_DATE")

# Apply max_index_dates cap: keep only the N earliest candidate dates per patient
if self.max_index_dates is not None:
w = ibis.window(group_by="PERSON_ID", order_by="INDEX_DATE")
index_table = index_table.mutate(_rn=ibis.row_number().over(w))
n_before = index_table.count()
index_table = index_table.filter(index_table._rn < self.max_index_dates)
index_table = index_table.drop("_rn")
logger.info(
f"IndexPhenotype '{self.name}': applied max_index_dates={self.max_index_dates}"
)

if self.inclusion_table_node:
inc_keys = ["PERSON_ID"] + (
["INDEX_DATE"]
if "INDEX_DATE" in self.inclusion_table_node.table.columns
else []
)
include = self.inclusion_table_node.table.filter(
self.inclusion_table_node.table["BOOLEAN"] == True
).select(["PERSON_ID"])
index_table = index_table.inner_join(include, ["PERSON_ID"])
).select(inc_keys)
index_table = index_table.inner_join(include, inc_keys)

if self.exclusion_table_node:
exc_keys = ["PERSON_ID"] + (
["INDEX_DATE"]
if "INDEX_DATE" in self.exclusion_table_node.table.columns
else []
)
exclude = self.exclusion_table_node.table.filter(
self.exclusion_table_node.table["BOOLEAN"] == False
).select(["PERSON_ID"])
index_table = index_table.inner_join(exclude, ["PERSON_ID"])
).select(exc_keys)
index_table = index_table.inner_join(exclude, exc_keys)

# Apply return_index selection after inclusion/exclusion filtering
if self.return_index == "first":
w = ibis.window(group_by="PERSON_ID", order_by="INDEX_DATE")
index_table = index_table.mutate(_rn=ibis.row_number().over(w))
index_table = index_table.filter(index_table._rn == 0).drop("_rn")
elif self.return_index == "last":
w = ibis.window(group_by="PERSON_ID", order_by=ibis.desc("INDEX_DATE"))
index_table = index_table.mutate(_rn=ibis.row_number().over(w))
index_table = index_table.filter(index_table._rn == 0).drop("_rn")
# "all": keep everything

# Deduplicate to at most one row per (PERSON_ID, INDEX_DATE)
dedup_keys = ["PERSON_ID"] + (
["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else []
)
w = ibis.window(group_by=dedup_keys, order_by="EVENT_DATE")
index_table = index_table.mutate(_dedup_rn=ibis.row_number().over(w))
index_table = index_table.filter(index_table._dedup_rn == 0).drop("_dedup_rn")

return index_table
29 changes: 20 additions & 9 deletions phenex/core/subcohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ def __init__(self, phenotype: Phenotype, index_patient_ids):

@property
def table(self):
return self._phenotype.table.semi_join(self._index_patient_ids, "PERSON_ID")
join_keys = [
k
for k in ["PERSON_ID", "INDEX_DATE"]
if k in self._phenotype.table.columns
and k in self._index_patient_ids.columns
]
return self._phenotype.table.semi_join(self._index_patient_ids, join_keys)

@property
def children(self):
Expand Down Expand Up @@ -59,8 +65,11 @@ def __init__(
outcome_sections: dict = None,
):
self.index_table = index_table
_id_cols = ["PERSON_ID"] + (
["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else []
)
index_patient_ids = index_table.filter(index_table.BOOLEAN == True).select(
"PERSON_ID"
*_id_cols
)
self.characteristics = [
_FilteredPhenotypeView(p, index_patient_ids)
Expand All @@ -79,8 +88,9 @@ def __init__(
self.subset_tables_index = {}
for domain, ptable in parent_subset.items():
if "PERSON_ID" in ptable._table.columns:
_sj_keys = [k for k in _id_cols if k in ptable._table.columns]
self.subset_tables_index[domain] = type(ptable)(
ptable._table.semi_join(index_patient_ids, "PERSON_ID")
ptable._table.semi_join(index_patient_ids, _sj_keys)
)
else:
self.subset_tables_index[domain] = ptable
Expand Down Expand Up @@ -276,18 +286,19 @@ def execute(
# apply only the additional criteria.
# ------------------------------------------------------------------
index_table = self.cohort.index_table
_ij_keys = ["PERSON_ID"] + (
["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else []
)

for inclusion in self.additional_inclusions:
include_pids = inclusion.table.filter(
inclusion.table["BOOLEAN"] == True
).select("PERSON_ID")
index_table = index_table.inner_join(include_pids, "PERSON_ID")
).select(*_ij_keys)
index_table = index_table.inner_join(include_pids, _ij_keys)

for exclusion in self.additional_exclusions:
exclude_pids = exclusion.table.select("PERSON_ID")
index_table = index_table.filter(
~index_table["PERSON_ID"].isin(exclude_pids["PERSON_ID"])
)
exclude_pids = exclusion.table.select(*_ij_keys)
index_table = index_table.anti_join(exclude_pids, _ij_keys)

self.table = index_table

Expand Down
7 changes: 5 additions & 2 deletions phenex/filters/relative_time_range_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# from phenex.phenotypes.phenotype import Phenotype
from phenex.filters.filter import Filter
from phenex.filters.value_filter import ValueFilter
from phenex.phenotypes.functions import _get_join_keys
from phenex.tables import EventTable, is_phenex_phenotype_table
from phenex.filters.value import *

Expand Down Expand Up @@ -64,8 +65,10 @@ def _filter(self, table: EventTable):
else:
anchor_table = self.anchor_phenotype.table
reference_column = anchor_table.EVENT_DATE
# Note that joins can change column names if the tables have name collisions!
table = table.join(anchor_table, "PERSON_ID")
join_keys = [
k for k in _get_join_keys(table) if k in anchor_table.columns
]
table = table.join(anchor_table, join_keys)
else:
assert (
"INDEX_DATE" in table.columns
Expand Down
1 change: 1 addition & 0 deletions phenex/phenotypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .event_count_phenotype import EventCountPhenotype
from .measurement_phenotype import MeasurementPhenotype
from .measurement_change_phenotype import MeasurementChangePhenotype
from .further_value_filter_phenotype import FurtherValueFilterPhenotype
from .death_phenotype import DeathPhenotype
from .categorical_phenotype import CategoricalPhenotype
from .time_range_count_phenotype import TimeRangeCountPhenotype
Expand Down
Loading
Loading