diff --git a/docs/api/phenotypes/further_value_filter_phenotype.md b/docs/api/phenotypes/further_value_filter_phenotype.md new file mode 100644 index 00000000..0014fb8a --- /dev/null +++ b/docs/api/phenotypes/further_value_filter_phenotype.md @@ -0,0 +1,3 @@ +# FurtherValueFilterPhenotype + +::: phenex.phenotypes.further_value_filter_phenotype diff --git a/mkdocs.yml b/mkdocs.yml index 96d81b72..0ab1a5ae 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -20,6 +20,7 @@ nav: - EventPhenotype: api/phenotypes/event_phenotype.md - CodelistPhenotype: api/phenotypes/codelist_phenotype.md - MeasurementPhenotype: api/phenotypes/measurement_phenotype.md + - FurtherValueFilterPhenotype: api/phenotypes/further_value_filter_phenotype.md - AgePhenotype: api/phenotypes/age_phenotype.md - SexPhenotype: api/phenotypes/sex_phenotype.md - DeathPhenotype: api/phenotypes/death_phenotype.md diff --git a/phenex/__init__.py b/phenex/__init__.py index c4bdbb80..685c99c3 100644 --- a/phenex/__init__.py +++ b/phenex/__init__.py @@ -15,6 +15,7 @@ CODETYPE_INFO, MeasurementPhenotype, MeasurementChangePhenotype, + FurtherValueFilterPhenotype, AgePhenotype, SexPhenotype, BinPhenotype, @@ -134,6 +135,7 @@ "CODETYPE_INFO", "MeasurementPhenotype", "MeasurementChangePhenotype", + "FurtherValueFilterPhenotype", "AgePhenotype", "SexPhenotype", "BinPhenotype", diff --git a/phenex/aggregators/aggregator.py b/phenex/aggregators/aggregator.py index 2a0a8109..bf39f2e0 100644 --- a/phenex/aggregators/aggregator.py +++ b/phenex/aggregators/aggregator.py @@ -79,10 +79,14 @@ def __init__( self.preserve_nulls = preserve_nulls def aggregate(self, input_table: Table): + # Ensure INDEX_DATE is in the aggregation index if the table has it + agg_index = list(self.aggregation_index) + if "INDEX_DATE" in input_table.columns and "INDEX_DATE" not in agg_index: + agg_index.append("INDEX_DATE") # Define the window specification partition_cols = [ getattr(input_table, col) if isinstance(col, str) else col - for col in self.aggregation_index + for col in agg_index ] window_spec = ibis.window( group_by=partition_cols, order_by=input_table[self.event_date_column] @@ -124,7 +128,7 @@ def aggregate(self, input_table: Table): # Apply the distinct reduction if required if self.reduce: - selected_columns = self.aggregation_index + [self.event_date_column] + selected_columns = list(agg_index) + [self.event_date_column] input_table = input_table.select(selected_columns).distinct() input_table = input_table.mutate(VALUE=ibis.null().cast("int32")) @@ -255,10 +259,12 @@ def __init__( # otherwise, row count is preserved def aggregate(self, input_table: Table): + # Ensure INDEX_DATE is in the aggregation index if the table has it + agg_index = list(self.aggregation_index) + if "INDEX_DATE" in input_table.columns and "INDEX_DATE" not in agg_index: + agg_index.append("INDEX_DATE") # Get the aggregation index columns - _aggregation_index_cols = [ - getattr(input_table, col) for col in self.aggregation_index - ] + _aggregation_index_cols = [getattr(input_table, col) for col in agg_index] # Determine the aggregation column if self.aggregation_column is None: diff --git a/phenex/core/cohort.py b/phenex/core/cohort.py index efc7b083..1c623a39 100644 --- a/phenex/core/cohort.py +++ b/phenex/core/cohort.py @@ -67,10 +67,37 @@ def __init__( description: Optional[str] = None, database: Optional[Database] = None, custom_reporters: Optional[List] = None, + return_index: str = "first", + max_index_dates: Optional[int] = None, ): self.name = name self.description = description self.database = database + self.return_index = return_index + self.max_index_dates = max_index_dates + + assert return_index in ( + "first", + "last", + "all", + ), f"return_index must be 'first', 'last', or 'all', got '{return_index}'" + if max_index_dates is not None: + assert ( + isinstance(max_index_dates, int) and max_index_dates > 0 + ), f"max_index_dates must be a positive integer, got {max_index_dates}" + + # When return_index requires multiple candidate dates, auto-set entry criterion + if return_index in ("last", "all"): + if ( + hasattr(entry_criterion, "return_date") + and entry_criterion.return_date != "all" + ): + logger.info( + f"Cohort '{name}': return_index='{return_index}' requires entry criterion " + f"return_date='all'. Auto-setting from '{entry_criterion.return_date}'." + ) + entry_criterion.return_date = "all" + self.table = None # Will be set during execution to index table self.subset_tables_entry = None # Will be set during execution self.subset_tables_index = None # Will be set during execution @@ -340,6 +367,8 @@ def build_stages(self, tables: Dict[str, PhenexTable]): entry_phenotype=self.entry_criterion, inclusion_table_node=self.inclusions_table_node, exclusion_table_node=self.exclusions_table_node, + return_index=self.return_index, + max_index_dates=self.max_index_dates, ) index_nodes.append(self.index_table_node) diff --git a/phenex/core/exclusions_table_node.py b/phenex/core/exclusions_table_node.py index 6f29b843..b5186976 100644 --- a/phenex/core/exclusions_table_node.py +++ b/phenex/core/exclusions_table_node.py @@ -20,17 +20,27 @@ def __init__( self.index_phenotype = index_phenotype def _execute(self, tables: Dict[str, Table]): - exclusions_table = self.index_phenotype.table.select(["PERSON_ID"]) + # Build base from entry criterion; add INDEX_DATE if ALL phenotype tables have it + if self.phenotypes and all( + "INDEX_DATE" in pt.table.columns for pt in self.phenotypes + ): + join_keys = ["PERSON_ID", "INDEX_DATE"] + exclusions_table = self.index_phenotype.table.mutate( + INDEX_DATE=self.index_phenotype.table.EVENT_DATE + ).select(["PERSON_ID", "INDEX_DATE"]) + else: + join_keys = ["PERSON_ID"] + exclusions_table = self.index_phenotype.table.select(["PERSON_ID"]) for pt in self.phenotypes: - pt_table = pt.table.select(["PERSON_ID", "BOOLEAN"]).rename( + pt_table = pt.table.select([*join_keys, "BOOLEAN"]).rename( **{ f"{pt.name}_BOOLEAN": "BOOLEAN", } ) - exclusions_table = exclusions_table.left_join(pt_table, ["PERSON_ID"]) - columns = exclusions_table.columns - columns.remove("PERSON_ID_right") + exclusions_table = exclusions_table.left_join(pt_table, join_keys) + drop_cols = [f"{k}_right" for k in join_keys] + columns = [c for c in exclusions_table.columns if c not in drop_cols] exclusions_table = exclusions_table.select(columns) # fill all nones with False diff --git a/phenex/core/inclusions_table_node.py b/phenex/core/inclusions_table_node.py index e30b90e8..1d5df887 100644 --- a/phenex/core/inclusions_table_node.py +++ b/phenex/core/inclusions_table_node.py @@ -20,17 +20,28 @@ def __init__( self.index_phenotype = index_phenotype def _execute(self, tables: Dict[str, Table]): - inclusions_table = self.index_phenotype.table.select(["PERSON_ID"]) + # Build base from entry criterion; add INDEX_DATE if ALL phenotype tables have it + if self.phenotypes and all( + "INDEX_DATE" in pt.table.columns for pt in self.phenotypes + ): + join_keys = ["PERSON_ID", "INDEX_DATE"] + # Derive (PERSON_ID, INDEX_DATE) pairs from entry criterion + inclusions_table = self.index_phenotype.table.mutate( + INDEX_DATE=self.index_phenotype.table.EVENT_DATE + ).select(["PERSON_ID", "INDEX_DATE"]) + else: + join_keys = ["PERSON_ID"] + inclusions_table = self.index_phenotype.table.select(["PERSON_ID"]) for pt in self.phenotypes: - pt_table = pt.table.select(["PERSON_ID", "BOOLEAN"]).rename( + pt_table = pt.table.select([*join_keys, "BOOLEAN"]).rename( **{ f"{pt.name}_BOOLEAN": "BOOLEAN", } ) - inclusions_table = inclusions_table.left_join(pt_table, ["PERSON_ID"]) - columns = inclusions_table.columns - columns.remove("PERSON_ID_right") + inclusions_table = inclusions_table.left_join(pt_table, join_keys) + drop_cols = [f"{k}_right" for k in join_keys] + columns = [c for c in inclusions_table.columns if c not in drop_cols] inclusions_table = inclusions_table.select(columns) # fill all nones with False diff --git a/phenex/core/index_phenotype.py b/phenex/core/index_phenotype.py index d7442afd..dbcbdcba 100644 --- a/phenex/core/index_phenotype.py +++ b/phenex/core/index_phenotype.py @@ -1,12 +1,25 @@ -from typing import Dict +from typing import Dict, Optional from ibis.expr.types.relations import Table +import ibis from phenex.node import Node from phenex.phenotypes.phenotype import Phenotype +from phenex.util import create_logger + +logger = create_logger(__name__) class IndexPhenotype(Phenotype): """ - Compute the index table form the individual inclusions / exclusions phenotypes. + Compute the index table from the individual inclusions / exclusions phenotypes. + + Parameters: + return_index: Controls how multiple candidate index dates per patient are handled + after inclusion/exclusion filtering: + - "first": keep earliest passing INDEX_DATE per patient (default) + - "last": keep latest passing INDEX_DATE per patient + - "all": keep all passing INDEX_DATEs + max_index_dates: When set, cap the number of candidate entry dates per patient + to at most this many (keeping the earliest N) before applying inclusion/exclusion. """ def __init__( @@ -15,6 +28,8 @@ def __init__( entry_phenotype: Phenotype, inclusion_table_node: Node, exclusion_table_node: Node, + return_index: str = "first", + max_index_dates: Optional[int] = None, ): super(IndexPhenotype, self).__init__(name=name) self.add_children(entry_phenotype) @@ -26,20 +41,62 @@ def __init__( self.entry_phenotype = entry_phenotype self.inclusion_table_node = inclusion_table_node self.exclusion_table_node = exclusion_table_node + self.return_index = return_index + self.max_index_dates = max_index_dates def _execute(self, tables: Dict[str, Table]): index_table = self.entry_phenotype.table.mutate(INDEX_DATE="EVENT_DATE") + # Apply max_index_dates cap: keep only the N earliest candidate dates per patient + if self.max_index_dates is not None: + w = ibis.window(group_by="PERSON_ID", order_by="INDEX_DATE") + index_table = index_table.mutate(_rn=ibis.row_number().over(w)) + n_before = index_table.count() + index_table = index_table.filter(index_table._rn < self.max_index_dates) + index_table = index_table.drop("_rn") + logger.info( + f"IndexPhenotype '{self.name}': applied max_index_dates={self.max_index_dates}" + ) + if self.inclusion_table_node: + inc_keys = ["PERSON_ID"] + ( + ["INDEX_DATE"] + if "INDEX_DATE" in self.inclusion_table_node.table.columns + else [] + ) include = self.inclusion_table_node.table.filter( self.inclusion_table_node.table["BOOLEAN"] == True - ).select(["PERSON_ID"]) - index_table = index_table.inner_join(include, ["PERSON_ID"]) + ).select(inc_keys) + index_table = index_table.inner_join(include, inc_keys) if self.exclusion_table_node: + exc_keys = ["PERSON_ID"] + ( + ["INDEX_DATE"] + if "INDEX_DATE" in self.exclusion_table_node.table.columns + else [] + ) exclude = self.exclusion_table_node.table.filter( self.exclusion_table_node.table["BOOLEAN"] == False - ).select(["PERSON_ID"]) - index_table = index_table.inner_join(exclude, ["PERSON_ID"]) + ).select(exc_keys) + index_table = index_table.inner_join(exclude, exc_keys) + + # Apply return_index selection after inclusion/exclusion filtering + if self.return_index == "first": + w = ibis.window(group_by="PERSON_ID", order_by="INDEX_DATE") + index_table = index_table.mutate(_rn=ibis.row_number().over(w)) + index_table = index_table.filter(index_table._rn == 0).drop("_rn") + elif self.return_index == "last": + w = ibis.window(group_by="PERSON_ID", order_by=ibis.desc("INDEX_DATE")) + index_table = index_table.mutate(_rn=ibis.row_number().over(w)) + index_table = index_table.filter(index_table._rn == 0).drop("_rn") + # "all": keep everything + + # Deduplicate to at most one row per (PERSON_ID, INDEX_DATE) + dedup_keys = ["PERSON_ID"] + ( + ["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else [] + ) + w = ibis.window(group_by=dedup_keys, order_by="EVENT_DATE") + index_table = index_table.mutate(_dedup_rn=ibis.row_number().over(w)) + index_table = index_table.filter(index_table._dedup_rn == 0).drop("_dedup_rn") return index_table diff --git a/phenex/core/subcohort.py b/phenex/core/subcohort.py index 06b43a4c..42abd17d 100644 --- a/phenex/core/subcohort.py +++ b/phenex/core/subcohort.py @@ -26,7 +26,13 @@ def __init__(self, phenotype: Phenotype, index_patient_ids): @property def table(self): - return self._phenotype.table.semi_join(self._index_patient_ids, "PERSON_ID") + join_keys = [ + k + for k in ["PERSON_ID", "INDEX_DATE"] + if k in self._phenotype.table.columns + and k in self._index_patient_ids.columns + ] + return self._phenotype.table.semi_join(self._index_patient_ids, join_keys) @property def children(self): @@ -59,8 +65,11 @@ def __init__( outcome_sections: dict = None, ): self.index_table = index_table + _id_cols = ["PERSON_ID"] + ( + ["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else [] + ) index_patient_ids = index_table.filter(index_table.BOOLEAN == True).select( - "PERSON_ID" + *_id_cols ) self.characteristics = [ _FilteredPhenotypeView(p, index_patient_ids) @@ -79,8 +88,9 @@ def __init__( self.subset_tables_index = {} for domain, ptable in parent_subset.items(): if "PERSON_ID" in ptable._table.columns: + _sj_keys = [k for k in _id_cols if k in ptable._table.columns] self.subset_tables_index[domain] = type(ptable)( - ptable._table.semi_join(index_patient_ids, "PERSON_ID") + ptable._table.semi_join(index_patient_ids, _sj_keys) ) else: self.subset_tables_index[domain] = ptable @@ -276,18 +286,19 @@ def execute( # apply only the additional criteria. # ------------------------------------------------------------------ index_table = self.cohort.index_table + _ij_keys = ["PERSON_ID"] + ( + ["INDEX_DATE"] if "INDEX_DATE" in index_table.columns else [] + ) for inclusion in self.additional_inclusions: include_pids = inclusion.table.filter( inclusion.table["BOOLEAN"] == True - ).select("PERSON_ID") - index_table = index_table.inner_join(include_pids, "PERSON_ID") + ).select(*_ij_keys) + index_table = index_table.inner_join(include_pids, _ij_keys) for exclusion in self.additional_exclusions: - exclude_pids = exclusion.table.select("PERSON_ID") - index_table = index_table.filter( - ~index_table["PERSON_ID"].isin(exclude_pids["PERSON_ID"]) - ) + exclude_pids = exclusion.table.select(*_ij_keys) + index_table = index_table.anti_join(exclude_pids, _ij_keys) self.table = index_table diff --git a/phenex/filters/relative_time_range_filter.py b/phenex/filters/relative_time_range_filter.py index 91bdccde..430b36db 100644 --- a/phenex/filters/relative_time_range_filter.py +++ b/phenex/filters/relative_time_range_filter.py @@ -3,6 +3,7 @@ # from phenex.phenotypes.phenotype import Phenotype from phenex.filters.filter import Filter from phenex.filters.value_filter import ValueFilter +from phenex.phenotypes.functions import _get_join_keys from phenex.tables import EventTable, is_phenex_phenotype_table from phenex.filters.value import * @@ -64,8 +65,10 @@ def _filter(self, table: EventTable): else: anchor_table = self.anchor_phenotype.table reference_column = anchor_table.EVENT_DATE - # Note that joins can change column names if the tables have name collisions! - table = table.join(anchor_table, "PERSON_ID") + join_keys = [ + k for k in _get_join_keys(table) if k in anchor_table.columns + ] + table = table.join(anchor_table, join_keys) else: assert ( "INDEX_DATE" in table.columns diff --git a/phenex/phenotypes/__init__.py b/phenex/phenotypes/__init__.py index b78db53a..ef5145b9 100644 --- a/phenex/phenotypes/__init__.py +++ b/phenex/phenotypes/__init__.py @@ -8,6 +8,7 @@ from .event_count_phenotype import EventCountPhenotype from .measurement_phenotype import MeasurementPhenotype from .measurement_change_phenotype import MeasurementChangePhenotype +from .further_value_filter_phenotype import FurtherValueFilterPhenotype from .death_phenotype import DeathPhenotype from .categorical_phenotype import CategoricalPhenotype from .time_range_count_phenotype import TimeRangeCountPhenotype diff --git a/phenex/phenotypes/age_phenotype.py b/phenex/phenotypes/age_phenotype.py index 0d496e96..6bcc3d4b 100644 --- a/phenex/phenotypes/age_phenotype.py +++ b/phenex/phenotypes/age_phenotype.py @@ -4,6 +4,7 @@ from ibis.expr.types.relations import Table from phenex.phenotypes.phenotype import Phenotype +from phenex.phenotypes.functions import _get_join_keys from phenex.filters import ValueFilter, Value from phenex.tables import PhenotypeTable, is_phenex_person_table from phenex.filters.relative_time_range_filter import RelativeTimeRangeFilter @@ -128,8 +129,10 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: else: anchor_table = self.anchor_phenotype.table reference_column = anchor_table.EVENT_DATE - # Note that joins can change column names if the tables have name collisions! - table = table.join(anchor_table, "PERSON_ID") + join_keys = [ + k for k in _get_join_keys(table) if k in anchor_table.columns + ] + table = table.join(anchor_table, join_keys) else: assert ( "INDEX_DATE" in table.columns diff --git a/phenex/phenotypes/computation_graph_phenotypes.py b/phenex/phenotypes/computation_graph_phenotypes.py index a3478b4f..ec66f5ec 100644 --- a/phenex/phenotypes/computation_graph_phenotypes.py +++ b/phenex/phenotypes/computation_graph_phenotypes.py @@ -4,7 +4,7 @@ import ibis from phenex.tables import PhenotypeTable, PHENOTYPE_TABLE_COLUMNS from phenex.phenotypes.phenotype import Phenotype, ComputationGraph -from phenex.phenotypes.functions import hstack +from phenex.phenotypes.functions import hstack, _get_join_keys from phenex.phenotypes.functions import select_phenotype_columns from phenex.aggregators import First, Last @@ -70,7 +70,13 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: Returns: PhenotypeTable: The resulting phenotype table containing the required columns. """ - joined_table = hstack(self.children, tables["PERSON"].select("PERSON_ID")) + join_table = tables.get("PERSON") + if join_table is not None: + person_cols = ["PERSON_ID"] + if "INDEX_DATE" in join_table.columns: + person_cols.append("INDEX_DATE") + join_table = join_table.select(person_cols) + joined_table = hstack(self.children, join_table) if self.populate == "value" and self.operate_on == "boolean": for child in self.children: @@ -128,7 +134,7 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: if "BOOLEAN" not in schema.names: joined_table = joined_table.mutate(BOOLEAN=ibis.null().cast("boolean")) - return joined_table.distinct() + return select_phenotype_columns(joined_table).distinct() def _return_all_dates(self, table, date_columns): """ @@ -200,9 +206,15 @@ def _perform_date_selection(self, code_table): return code_table if self.return_date == "first": - aggregator = First(reduce=False, preserve_nulls=True) + agg_index = _get_join_keys(code_table) + aggregator = First( + reduce=False, preserve_nulls=True, aggregation_index=agg_index + ) elif self.return_date == "last": - aggregator = Last(reduce=False, preserve_nulls=True) + agg_index = _get_join_keys(code_table) + aggregator = Last( + reduce=False, preserve_nulls=True, aggregation_index=agg_index + ) elif self.return_date == "nearest": # Note: Nearest is not currently implemented in the aggregators # This would need to be added to the aggregator module @@ -409,7 +421,13 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: Returns: PhenotypeTable: The resulting phenotype table containing the required columns. """ - joined_table = hstack(self.children, tables["PERSON"].select("PERSON_ID")) + join_table = tables.get("PERSON") + if join_table is not None: + person_cols = ["PERSON_ID"] + if "INDEX_DATE" in join_table.columns: + person_cols.append("INDEX_DATE") + join_table = join_table.select(person_cols) + joined_table = hstack(self.children, join_table) # Convert boolean columns to integers for arithmetic operations if needed if self.populate == "value" and self.operate_on == "boolean": for child in self.children: diff --git a/phenex/phenotypes/death_phenotype.py b/phenex/phenotypes/death_phenotype.py index 803cbcae..2ff23564 100644 --- a/phenex/phenotypes/death_phenotype.py +++ b/phenex/phenotypes/death_phenotype.py @@ -84,4 +84,7 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: death_table = death_table.mutate(BOOLEAN=True) death_table = death_table.mutate(EVENT_DATE=death_table.DATE_OF_DEATH) - return death_table.select(["PERSON_ID", "EVENT_DATE", "VALUE", "BOOLEAN"]) + cols = ["PERSON_ID", "EVENT_DATE", "VALUE", "BOOLEAN"] + if "INDEX_DATE" in death_table.columns: + cols.append("INDEX_DATE") + return death_table.select(cols) diff --git a/phenex/phenotypes/event_count_phenotype.py b/phenex/phenotypes/event_count_phenotype.py index d53462f1..bc99453e 100644 --- a/phenex/phenotypes/event_count_phenotype.py +++ b/phenex/phenotypes/event_count_phenotype.py @@ -1,6 +1,7 @@ from ibis import _ from phenex.phenotypes.phenotype import Phenotype +from phenex.phenotypes.functions import _get_join_keys from phenex.filters.relative_time_range_filter import RelativeTimeRangeFilter from phenex.filters import DateFilter, ValueFilter from phenex.tables import is_phenex_code_table, PHENOTYPE_TABLE_COLUMNS, PhenotypeTable @@ -91,19 +92,21 @@ def _execute(self, tables) -> PhenotypeTable: table = self.phenotype.table # Select only distinct dates: - table = table.select(["PERSON_ID", "EVENT_DATE"]).distinct() + group_keys = _get_join_keys(table) + table = table.select([*group_keys, "EVENT_DATE"]).distinct() - # Count occurrences per PERSON_ID - occurrence_counts_table = table.group_by("PERSON_ID").aggregate(VALUE=_.count()) + # Count occurrences per (PERSON_ID, INDEX_DATE) + occurrence_counts_table = table.group_by(group_keys).aggregate(VALUE=_.count()) table, occurrence_counts_table = self._perform_value_filtering( table, occurrence_counts_table ) table = self._perform_relative_time_range_filtering(table) table = self._perform_date_selection(table) + select_cols = [*group_keys, "EVENT_DATE", "VALUE"] table = table.left_join( - occurrence_counts_table.select("PERSON_ID", "VALUE"), - table.PERSON_ID == occurrence_counts_table.PERSON_ID, - ).select("PERSON_ID", "EVENT_DATE", "VALUE") + occurrence_counts_table.select(*group_keys, "VALUE"), + group_keys, + ).select(select_cols) table = table.mutate(BOOLEAN=True).distinct() return table @@ -111,10 +114,12 @@ def _execute(self, tables) -> PhenotypeTable: def _perform_value_filtering(self, table, occurrence_counts_table): if self.value_filter is not None: occurrence_counts_table = self.value_filter.filter(occurrence_counts_table) + join_keys = _get_join_keys(table) + select_cols = [*join_keys, "EVENT_DATE", "VALUE"] table = table.right_join( occurrence_counts_table, - table.PERSON_ID == occurrence_counts_table.PERSON_ID, - ).select(["PERSON_ID", "EVENT_DATE", "VALUE"]) + join_keys, + ).select(select_cols) return table, occurrence_counts_table def _perform_relative_time_range_filtering(self, table): @@ -124,38 +129,50 @@ def _perform_relative_time_range_filtering(self, table): # Self join and rename event_date columns; # the first dates will be called INDEX_DATE # the second dates will be called EVENT_DATE - first_table = table.select( - "PERSON_ID", - table.EVENT_DATE.name("INDEX_DATE"), - ) - second_table = table.select( - "PERSON_ID", - table.EVENT_DATE.name("EVENT_DATE"), - ) - table = first_table.join( - second_table, first_table.PERSON_ID == second_table.PERSON_ID - ) + has_index = "INDEX_DATE" in table.columns + # Preserve original INDEX_DATE by renaming it temporarily + if has_index: + table = table.rename({"_ORIG_INDEX_DATE": "INDEX_DATE"}) + first_cols = ["PERSON_ID", table.EVENT_DATE.name("INDEX_DATE")] + second_cols = ["PERSON_ID", table.EVENT_DATE.name("EVENT_DATE")] + if has_index: + first_cols.append("_ORIG_INDEX_DATE") + second_cols.append("_ORIG_INDEX_DATE") + first_table = table.select(*first_cols) + second_table = table.select(*second_cols) + join_pred = [first_table.PERSON_ID == second_table.PERSON_ID] + if has_index: + join_pred.append( + first_table._ORIG_INDEX_DATE == second_table._ORIG_INDEX_DATE + ) + table = first_table.join(second_table, join_pred) table = table.filter(table.INDEX_DATE <= table.EVENT_DATE) # perform relative time range filtering; the first date is the anchor ('index_date') table = self.relative_time_range.filter(table) + select_base = ["PERSON_ID"] + if has_index: + select_base.append("_ORIG_INDEX_DATE") if self.component_date_select == "first": - table = table.select("PERSON_ID", "INDEX_DATE").rename( + table = table.select(*select_base, "INDEX_DATE").rename( {"EVENT_DATE": "INDEX_DATE"} ) elif self.component_date_select == "second": - table = table.select("PERSON_ID", "EVENT_DATE") + table = table.select(*select_base, "EVENT_DATE") + if has_index: + table = table.rename({"INDEX_DATE": "_ORIG_INDEX_DATE"}) return table def _perform_date_selection(self, table, reduce=True): if self.return_date is None or self.return_date == "all": return table + agg_index = _get_join_keys(table) if self.return_date == "first": - aggregator = First(reduce=reduce) + aggregator = First(reduce=reduce, aggregation_index=agg_index) elif self.return_date == "last": - aggregator = Last(reduce=reduce) + aggregator = Last(reduce=reduce, aggregation_index=agg_index) else: raise ValueError(f"Unknown return_date: {self.return_date}") table = aggregator.aggregate(table) - return table.select("PERSON_ID", "EVENT_DATE") + return table.select([*agg_index, "EVENT_DATE"]) diff --git a/phenex/phenotypes/event_phenotype.py b/phenex/phenotypes/event_phenotype.py index 13ddfead..b696d4af 100644 --- a/phenex/phenotypes/event_phenotype.py +++ b/phenex/phenotypes/event_phenotype.py @@ -5,7 +5,7 @@ from phenex.filters.categorical_filter import CategoricalFilter from phenex.aggregators import First, Last from phenex.tables import is_phenex_code_table, PhenotypeTable -from phenex.phenotypes.functions import select_phenotype_columns +from phenex.phenotypes.functions import select_phenotype_columns, _get_join_keys class EventPhenotype(Phenotype): @@ -101,10 +101,12 @@ def _perform_date_selection(self, code_table): reduce = self.return_value != "all" + agg_index = _get_join_keys(code_table) + if self.return_date == "first": - aggregator = First(reduce=reduce) + aggregator = First(reduce=reduce, aggregation_index=agg_index) elif self.return_date == "last": - aggregator = Last(reduce=reduce) + aggregator = Last(reduce=reduce, aggregation_index=agg_index) elif self.return_date == "nearest": raise NotImplementedError("Nearest aggregation not yet implemented") else: diff --git a/phenex/phenotypes/functions.py b/phenex/phenotypes/functions.py index b7994048..9411930f 100644 --- a/phenex/phenotypes/functions.py +++ b/phenex/phenotypes/functions.py @@ -8,6 +8,13 @@ logger = create_logger(__name__) +def _get_join_keys(table=None): + """Return the join keys. Always (PERSON_ID, INDEX_DATE) unless table lacks INDEX_DATE.""" + if table is not None and "INDEX_DATE" not in table.columns: + return ["PERSON_ID"] + return ["PERSON_ID", "INDEX_DATE"] + + def attach_anchor_and_get_reference_date(table, anchor_phenotype=None): # Unwrap PhenexTable so all joins are done at the raw ibis level. # PhenexTable.join would re-wrap the join result through __init__ which @@ -27,17 +34,24 @@ def attach_anchor_and_get_reference_date(table, anchor_phenotype=None): # skip the join and reuse the existing column. ref_col_name = f"_ref_date_{anchor_phenotype.name}" if ref_col_name not in raw.columns: - anchor_slim = anchor_table.select( - anchor_table.PERSON_ID, - anchor_table.EVENT_DATE.name(ref_col_name), + anchor_cols = [anchor_table.PERSON_ID] + has_index = ( + "INDEX_DATE" in raw.columns and "INDEX_DATE" in anchor_table.columns ) - # Join at raw ibis level using an explicit predicate so we can - # drop the duplicate PERSON_ID_right immediately afterwards. + if has_index: + anchor_cols.append(anchor_table.INDEX_DATE) + anchor_cols.append(anchor_table.EVENT_DATE.name(ref_col_name)) + anchor_slim = anchor_table.select(anchor_cols) + # Build join predicate + pred = raw.PERSON_ID == anchor_slim.PERSON_ID + if has_index: + pred = pred & (raw.INDEX_DATE == anchor_slim.INDEX_DATE) + keep_cols = [c for c in raw.columns] + [ref_col_name] raw = raw.join( anchor_slim, - raw.PERSON_ID == anchor_slim.PERSON_ID, + pred, how="left", - ).select([c for c in raw.columns] + [ref_col_name]) + ).select(keep_cols) reference_column = raw[ref_col_name] else: assert ( @@ -50,52 +64,58 @@ def attach_anchor_and_get_reference_date(table, anchor_phenotype=None): def hstack(phenotypes: List["Phenotype"], join_table: Table = None) -> Table: """ - Horizontally stacks multiple PhenotypeTable objects into a single table. The PERSON_ID columns are used to join the tables together. The resulting table will have three columns per phenotype: BOOLEAN, EVENT_DATE, and VALUE. The columns will be contain the phenotype name as a prefix. - # TODO: Add a test for this function. - Args: - phenotypes (List[Phenotype]): A list of Phenotype objects to stack. + Horizontally stacks multiple PhenotypeTable objects into a single table. + Joins on (PERSON_ID, INDEX_DATE) when INDEX_DATE is present, otherwise PERSON_ID only. """ - # TODO decide if phenotypes should be returning a phenextable t0 = datetime.now() if isinstance(join_table, PhenexTable): join_table = join_table.table + # Detect the broadest join keys from the join_table (if provided) or phenotypes + join_keys = _get_join_keys(phenotypes[0].table) + if join_table is None: - # UNION all phenotype PERSON_IDs as the base so every patient appears exactly once. - # We then LEFT JOIN each phenotype against this base. This replaces the previous - # chained FULL OUTER JOIN approach which inserted a .mutate()/.select() between - # every join, breaking ibis's JoinChain into N nested subqueries. logger.info( f"hstack: building flat LEFT JOIN chain for {len(phenotypes)} phenotypes (UNION base)" ) - person_id_tables = [ - pt.namespaced_table.select("PERSON_ID") for pt in phenotypes + key_tables = [ + pt.namespaced_table.select( + [k for k in join_keys if k in pt.namespaced_table.columns] + ) + for pt in phenotypes ] - join_table = ibis.union(*person_id_tables, distinct=True) + join_table = ibis.union(*key_tables, distinct=True) else: + # When a join_table is provided (e.g. PERSON table), restrict join keys + # to columns available in the join_table + join_keys = [k for k in join_keys if k in join_table.columns] logger.info( f"hstack: building flat LEFT JOIN chain for {len(phenotypes)} phenotypes" ) - # Chain all LEFT JOINs WITHOUT any intermediate .select()/.mutate() between iterations. - # Consecutive .join() calls on an ibis relation accumulate into a single JoinChain that - # compiles to a flat multi-way JOIN in SQL, rather than N layers of nested subqueries. - # Since join_table always has all PERSON_IDs (UNION base or caller-supplied), its - # PERSON_ID is always non-null so COALESCE is unnecessary. for pt in phenotypes: - join_table = join_table.join(pt.namespaced_table, "PERSON_ID", how="left") + # Per-phenotype join keys: fall back to PERSON_ID only when the + # phenotype table lacks a key column (e.g. no INDEX_DATE). + pt_join_keys = [k for k in join_keys if k in pt.namespaced_table.columns] + join_table = join_table.join(pt.namespaced_table, pt_join_keys, how="left") - # Remove all PERSON_ID_right* columns (key duplicates from each join's right side). - # Phenotype-specific columns are named {NAME}_BOOLEAN / _EVENT_DATE / _VALUE and are unaffected. + # Remove duplicate key columns from right side of joins columns = [ c for c in join_table.columns - if c == "PERSON_ID" or not c.startswith("PERSON_ID") + if c in join_keys + or (not c.startswith("PERSON_ID") and not c.startswith("INDEX_DATE")) ] - join_table = join_table.select(columns) + # Deduplicate (join_keys appear once) + seen = set() + unique_columns = [] + for c in columns: + if c not in seen: + seen.add(c) + unique_columns.append(c) + join_table = join_table.select(unique_columns) # Apply fill_null for all boolean columns in a single .mutate() call - # (one SQL projection instead of N separate projections). null_fills = {} for pt in phenotypes: bool_col_name = f"{pt.name}_BOOLEAN" @@ -117,26 +137,8 @@ def hstack(phenotypes: List["Phenotype"], join_table: Table = None) -> Table: def hstack_boolean(phenotypes: List["Phenotype"], join_table: Table = None) -> Table: """ Efficiently stacks only the BOOLEAN column from multiple phenotypes into a wide table - using UNION ALL + GROUP BY with filtered aggregation instead of N sequential JOINs. - - This produces SQL of the form: - SELECT PERSON_ID, - MAX(CASE WHEN _PHENOTYPE = 'pheno1' THEN BOOLEAN END) AS pheno1_BOOLEAN, - ... - FROM (SELECT PERSON_ID, BOOLEAN, 'pheno1' AS _PHENOTYPE FROM t1 - UNION ALL - SELECT PERSON_ID, BOOLEAN, 'pheno2' AS _PHENOTYPE FROM t2 ...) - GROUP BY PERSON_ID - - This is O(n) with a single data shuffle, compared to O(n * k) for k sequential joins. - - Args: - phenotypes: A list of Phenotype objects to stack. - join_table: Optional base table. If provided, the result is LEFT JOINed onto it - to preserve all rows (e.g. all patients in the index table). - - Returns: - Table with PERSON_ID and {name}_BOOLEAN columns for each phenotype. + using UNION ALL + GROUP BY with filtered aggregation. + Groups by (PERSON_ID, INDEX_DATE) when INDEX_DATE is present, otherwise PERSON_ID only. """ t0 = datetime.now() logger.info( @@ -146,10 +148,13 @@ def hstack_boolean(phenotypes: List["Phenotype"], join_table: Table = None) -> T if isinstance(join_table, PhenexTable): join_table = join_table.table - # Step 1: UNION ALL — stack each phenotype's (PERSON_ID, BOOLEAN) with a tag column + # Detect join keys from the first phenotype's table + join_keys = _get_join_keys(phenotypes[0].table) + + # Step 1: UNION ALL — stack each phenotype's (keys, BOOLEAN) with a tag column unioned_tables = [] for pt in phenotypes: - t = pt.table.select("PERSON_ID", "BOOLEAN").mutate( + t = pt.table.select(*join_keys, "BOOLEAN").mutate( _PHENOTYPE=ibis.literal(pt.name) ) unioned_tables.append(t) @@ -166,18 +171,25 @@ def hstack_boolean(phenotypes: List["Phenotype"], join_table: Table = None) -> T agg_exprs[col_name] = bool_col.max(where=is_match).fill_null(0.0) else: agg_exprs[col_name] = bool_col.max(where=is_match).fill_null(False) - wide_table = long_table.group_by("PERSON_ID").aggregate(**agg_exprs) + wide_table = long_table.group_by(join_keys).aggregate(**agg_exprs) # Step 3: If a base table was provided, LEFT JOIN to preserve all its rows if join_table is not None: - result = join_table.join(wide_table, "PERSON_ID", how="left") - # Drop duplicated PERSON_ID_right if present + result = join_table.join(wide_table, join_keys, how="left") + # Drop duplicated key columns from right side columns = [ c for c in result.columns - if c == "PERSON_ID" or not c.startswith("PERSON_ID") + if c in join_keys + or (not c.startswith("PERSON_ID") and not c.startswith("INDEX_DATE")) ] - result = result.select(columns) + seen = set() + unique_columns = [] + for c in columns: + if c not in seen: + seen.add(c) + unique_columns.append(c) + result = result.select(unique_columns) else: result = wide_table @@ -194,25 +206,7 @@ def hstack_pivot( """ Efficiently stacks BOOLEAN, EVENT_DATE, and VALUE from multiple phenotypes into a wide table using UNION ALL + GROUP BY with filtered aggregation. - - This is the full-column equivalent of hstack_boolean. VALUE is cast to string - before the UNION ALL so the schema is consistent across all phenotypes (numeric, - categorical, and null-typed values are all compatible). BOOLEAN and EVENT_DATE - are kept as-is. - - Args: - phenotypes: A list of Phenotype objects to stack. - join_table: Optional base table whose rows are all preserved via LEFT JOIN. - date_agg: Aggregation to apply to EVENT_DATE when a child phenotype has - multiple rows per person (e.g. return_date='all'). Use ``"min"`` when - the parent wants the first (earliest) date, and ``"max"`` (default) - when it wants the last (latest) date. The parent's own - ``ibis.least`` / ``ibis.greatest`` call then operates correctly on - per-child min/max dates. - - Returns: - Table with PERSON_ID and {name}_BOOLEAN / {name}_EVENT_DATE / {name}_VALUE - columns for each phenotype. Each {name}_VALUE retains its original type. + Groups by (PERSON_ID, INDEX_DATE) when INDEX_DATE is present, otherwise PERSON_ID only. """ t0 = datetime.now() logger.info( @@ -222,14 +216,14 @@ def hstack_pivot( if isinstance(join_table, PhenexTable): join_table = join_table.table - # Step 1: Build each per-phenotype slice as select → mutate so that every - # column reference inside .mutate() is bound to the SAME relation produced - # by .select(). Referencing pt.table.VALUE inside pt.table.select() risks - # ibis treating them as two distinct relation nodes and emitting a cross join. + # Detect join keys from the first phenotype's table + join_keys = _get_join_keys(phenotypes[0].table) + + # Step 1: Build per-phenotype slices original_value_types = {pt.name: pt.table.schema()["VALUE"] for pt in phenotypes} parts = [] for pt in phenotypes: - base = pt.table.select("PERSON_ID", "BOOLEAN", "EVENT_DATE", "VALUE") + base = pt.table.select(*join_keys, "BOOLEAN", "EVENT_DATE", "VALUE") t = base.mutate( EVENT_DATE=base.EVENT_DATE.cast("date"), VALUE=base.VALUE.cast("str"), @@ -237,11 +231,9 @@ def hstack_pivot( ) parts.append(t) - # UNION ALL — keep every row (distinct=False); the GROUP BY handles deduplication long_table = ibis.union(*parts, distinct=False) - # Step 2: Pivot via GROUP BY + filtered MAX — one aggregation pass, no joins. - # Cast each VALUE column back to its original type after the pivot. + # Step 2: Pivot via GROUP BY + filtered MAX bool_col = long_table.BOOLEAN date_col = long_table.EVENT_DATE val_col = long_table.VALUE @@ -258,20 +250,25 @@ def hstack_pivot( original_type = original_value_types[pt.name] agg_exprs[f"{pt.name}_VALUE"] = val_col.max(where=is_match).cast(original_type) - wide_table = long_table.group_by("PERSON_ID").aggregate(**agg_exprs) + wide_table = long_table.group_by(join_keys).aggregate(**agg_exprs) # Step 3: If a base table was provided, LEFT JOIN to preserve all its rows if join_table is not None: - join_table = join_table.select( - "PERSON_ID" - ).distinct() # Ensure join_table has only one PERSON_ID column - result = join_table.join(wide_table, "PERSON_ID", how="left") + join_table = join_table.select(join_keys).distinct() + result = join_table.join(wide_table, join_keys, how="left") columns = [ c for c in result.columns - if c == "PERSON_ID" or not c.startswith("PERSON_ID") + if c in join_keys + or (not c.startswith("PERSON_ID") and not c.startswith("INDEX_DATE")) ] - result = result.select(columns) + seen = set() + unique_columns = [] + for c in columns: + if c not in seen: + seen.add(c) + unique_columns.append(c) + result = result.select(unique_columns) else: logger.debug("No join table provided to hstack") result = wide_table @@ -297,4 +294,8 @@ def select_phenotype_columns( table = table.mutate(VALUE=fill_value) if "BOOLEAN" not in table.columns: table = table.mutate(BOOLEAN=fill_boolean) - return table.select([table.PERSON_ID, table.BOOLEAN, table.EVENT_DATE, table.VALUE]) + cols = [table.PERSON_ID] + if "INDEX_DATE" in table.columns: + cols.append(table.INDEX_DATE) + cols.extend([table.BOOLEAN, table.EVENT_DATE, table.VALUE]) + return table.select(cols) diff --git a/phenex/phenotypes/further_value_filter_phenotype.py b/phenex/phenotypes/further_value_filter_phenotype.py new file mode 100644 index 00000000..2a9dbb14 --- /dev/null +++ b/phenex/phenotypes/further_value_filter_phenotype.py @@ -0,0 +1,124 @@ +from typing import Union, List, Optional +from phenex.phenotypes.phenotype import Phenotype +from phenex.filters.relative_time_range_filter import RelativeTimeRangeFilter +from phenex.filters.date_filter import DateFilter +from phenex.aggregators import First, Last +from phenex.tables import PhenotypeTable +from phenex.phenotypes.functions import select_phenotype_columns +from phenex.util import create_logger + +logger = create_logger(__name__) + + +class FurtherValueFilterPhenotype(Phenotype): + """ + FurtherValueFilterPhenotype takes the output of an existing phenotype and applies + additional value filtering, value aggregation, and time-based filtering on top of it. + + This is useful when you want to chain filtering operations, e.g. first identify + measurements matching a codelist and value range with a MeasurementPhenotype, then + further filter those results by a different value range or time window. + + Parameters: + phenotype: The source phenotype whose output table will be further filtered. + This phenotype is added as a child dependency and must execute first. + value_filter (ValueFilter): A ValueFilter to apply to the source phenotype's + output values. Applied after value_aggregation. + value_aggregation (ValueAggregator): A ValueAggregator to apply to the source + phenotype's output values. Applied before value_filter. + date_range (DateFilter): A date range filter to apply. + relative_time_range (RelativeTimeRangeFilter): A relative time range filter + or list of filters to apply. + return_date (str): Specifies whether to return the 'first', 'last', or 'all' + event date(s). Default is 'all'. + """ + + output_display_type = "value" + + def __init__( + self, + phenotype: "Phenotype", + value_filter: Optional["ValueFilter"] = None, + value_aggregation: Optional["ValueAggregator"] = None, + date_range: Optional[DateFilter] = None, + relative_time_range: Optional[ + Union[RelativeTimeRangeFilter, List[RelativeTimeRangeFilter]] + ] = None, + return_date: str = "all", + **kwargs, + ): + super(FurtherValueFilterPhenotype, self).__init__(**kwargs) + + if not isinstance(phenotype, Phenotype): + raise TypeError( + f"'phenotype' must be a Phenotype instance, got {type(phenotype).__name__}." + ) + + self.source_phenotype = phenotype + self.add_children(phenotype) + + self.value_filter = value_filter + self.value_aggregation = value_aggregation + self.date_range = date_range + self.return_date = return_date + + assert self.return_date in [ + "first", + "last", + "nearest", + "all", + ], f"Unknown return_date: {return_date}" + + if isinstance(relative_time_range, RelativeTimeRangeFilter): + relative_time_range = [relative_time_range] + self.relative_time_range = relative_time_range + + if self.relative_time_range is not None: + for rtr in self.relative_time_range: + if rtr.anchor_phenotype is not None: + if not any(c is rtr.anchor_phenotype for c in self.children): + self.add_children(rtr.anchor_phenotype) + + def _execute(self, tables) -> PhenotypeTable: + table = self.source_phenotype.table + table = self._perform_time_filtering(table) + table = self._perform_date_selection(table) + table = self._perform_value_aggregation(table) + table = self._perform_value_filtering(table) + table = select_phenotype_columns(table) + return table + + def _perform_time_filtering(self, table): + if self.date_range is not None: + table = self.date_range.filter(table) + if self.relative_time_range is not None: + for rtr in self.relative_time_range: + table = rtr.filter(table) + return table + + def _perform_date_selection(self, table): + if self.return_date is None or self.return_date == "all": + return table + + reduce = False + + if self.return_date == "first": + aggregator = First(reduce=reduce) + elif self.return_date == "last": + aggregator = Last(reduce=reduce) + elif self.return_date == "nearest": + raise NotImplementedError("Nearest aggregation not yet implemented") + else: + raise ValueError(f"Unknown return_date: {self.return_date}") + + return aggregator.aggregate(table) + + def _perform_value_aggregation(self, table): + if self.value_aggregation is not None: + table = self.value_aggregation.aggregate(table) + return table + + def _perform_value_filtering(self, table): + if self.value_filter is not None: + table = self.value_filter.filter(table) + return table diff --git a/phenex/phenotypes/measurement_change_phenotype.py b/phenex/phenotypes/measurement_change_phenotype.py index 1c95f0eb..d0b9743f 100644 --- a/phenex/phenotypes/measurement_change_phenotype.py +++ b/phenex/phenotypes/measurement_change_phenotype.py @@ -2,7 +2,8 @@ from phenex.phenotypes import MeasurementPhenotype, Phenotype from phenex.filters.value import Value, GreaterThanOrEqualTo from phenex.filters.value_filter import ValueFilter -from phenex.tables import PHENOTYPE_TABLE_COLUMNS, PhenotypeTable +from phenex.tables import PhenotypeTable +from phenex.phenotypes.functions import select_phenotype_columns from phenex.aggregators.aggregator import First, Last, ValueAggregator, DailyMedian from ibis import _ @@ -81,13 +82,18 @@ def _execute(self, tables) -> PhenotypeTable: phenotype_table_1 = self.phenotype.table phenotype_table_2 = self.phenotype.table.view() # Create a self-join to compare each measurement with every other measurement + join_predicates = [ + phenotype_table_1.PERSON_ID == phenotype_table_2.PERSON_ID, + (phenotype_table_1.EVENT_DATE != phenotype_table_2.EVENT_DATE) + | (phenotype_table_1.VALUE != phenotype_table_2.VALUE), + ] + if "INDEX_DATE" in phenotype_table_1.columns: + join_predicates.append( + phenotype_table_1.INDEX_DATE == phenotype_table_2.INDEX_DATE + ) joined_table = phenotype_table_1.join( phenotype_table_2, - [ - phenotype_table_1.PERSON_ID == phenotype_table_2.PERSON_ID, - (phenotype_table_1.EVENT_DATE != phenotype_table_2.EVENT_DATE) - | (phenotype_table_1.VALUE != phenotype_table_2.VALUE), - ], + join_predicates, lname="{name}_1", rname="{name}_2", ).filter(_.EVENT_DATE_1 < _.EVENT_DATE_2) @@ -140,9 +146,12 @@ def _execute(self, tables) -> PhenotypeTable: ) # Select the required columns - filtered_table = filtered_table.mutate( + mutate_kwargs = dict( PERSON_ID="PERSON_ID_1", VALUE="VALUE_CHANGE", BOOLEAN=True ) + if "INDEX_DATE_1" in filtered_table.columns: + mutate_kwargs["INDEX_DATE"] = "INDEX_DATE_1" + filtered_table = filtered_table.mutate(**mutate_kwargs) # Handle the return_date attribute for each PERSON_ID using window functions if self.return_date == "first": @@ -153,10 +162,6 @@ def _execute(self, tables) -> PhenotypeTable: if self.return_value is not None: filtered_table = self.return_value.aggregate(filtered_table) - filtered_table = ( - filtered_table.mutate(BOOLEAN=True) - .select(PHENOTYPE_TABLE_COLUMNS) - .distinct() - ) - - return filtered_table + filtered_table = filtered_table.mutate(BOOLEAN=True) + filtered_table = select_phenotype_columns(filtered_table) + return filtered_table.distinct() diff --git a/phenex/phenotypes/phenotype.py b/phenex/phenotypes/phenotype.py index 9cff310c..e1df5a3d 100644 --- a/phenex/phenotypes/phenotype.py +++ b/phenex/phenotypes/phenotype.py @@ -4,6 +4,7 @@ from phenex.tables import ( PhenotypeTable, PHENOTYPE_TABLE_COLUMNS, + PHENOTYPE_TABLE_COLUMNS_WITH_INDEX, is_phenex_phenotype_table, ) from phenex.util import create_logger @@ -53,7 +54,12 @@ def _perform_final_processing(self, table: Table) -> Table: f"Phenotype {self.name} must return columns {PHENOTYPE_TABLE_COLUMNS}. Found {table.columns}." ) - self.table = table.select(PHENOTYPE_TABLE_COLUMNS) + # INDEX_DATE is optional; include in output only if present in the input table + if "INDEX_DATE" in table.columns: + self.table = table.select(PHENOTYPE_TABLE_COLUMNS_WITH_INDEX) + else: + self.table = table.select(PHENOTYPE_TABLE_COLUMNS) + # for some reason, having NULL datatype screws up writing the table to disk; here we make explicit cast if type(self.table.schema()["VALUE"]) == ibis.expr.datatypes.core.Null: self.table = self.table.cast({"VALUE": "float64"}) @@ -78,6 +84,8 @@ def namespaced_table(self) -> Table: f"{self.name}_EVENT_DATE": "EVENT_DATE", f"{self.name}_VALUE": "VALUE", } + if "INDEX_DATE" in self.table.columns: + new_column_names["INDEX_DATE"] = "INDEX_DATE" return self.table.rename(new_column_names) def _execute(self, tables: Dict[str, Table]): diff --git a/phenex/phenotypes/time_range_count_phenotype.py b/phenex/phenotypes/time_range_count_phenotype.py index 0191e9e6..25b0f1a7 100644 --- a/phenex/phenotypes/time_range_count_phenotype.py +++ b/phenex/phenotypes/time_range_count_phenotype.py @@ -10,6 +10,7 @@ from phenex.tables import is_phenex_code_table, PHENOTYPE_TABLE_COLUMNS, PhenotypeTable from phenex.phenotypes.functions import ( select_phenotype_columns, + _get_join_keys, ) from ibis.expr.types.relations import Table from ibis import _ @@ -148,8 +149,9 @@ def _perform_time_filtering(self, table): def _perform_count_aggregation(self, table): """Count the number of distinct time periods per person.""" - table = table.select(["PERSON_ID", "START_DATE", "END_DATE"]).distinct() - return table.group_by("PERSON_ID").aggregate(VALUE=_.count()) + group_keys = _get_join_keys(table) + table = table.select([*group_keys, "START_DATE", "END_DATE"]).distinct() + return table.group_by(group_keys).aggregate(VALUE=_.count()) def _perform_value_filtering(self, table): """Filter persons by period count using value_filter.""" @@ -161,8 +163,11 @@ def _perform_zero_fill(self, table, tables): """Left-join against the PERSON table to include persons with 0 periods (only when no value_filter is set).""" if self.value_filter is not None or "PERSON" not in tables: return table - persons = tables["PERSON"].select("PERSON_ID").distinct() - table = persons.join( - table, persons.PERSON_ID == table.PERSON_ID, how="left" - ).drop("PERSON_ID_right") + join_keys = _get_join_keys(table) + persons = ( + tables["PERSON"] + .select([c for c in join_keys if c in tables["PERSON"].columns]) + .distinct() + ) + table = persons.join(table, _get_join_keys(persons), how="left") return table.mutate(VALUE=table.VALUE.fillna(0)) diff --git a/phenex/phenotypes/time_range_day_count_phenotype.py b/phenex/phenotypes/time_range_day_count_phenotype.py index 50d145f0..7f868091 100644 --- a/phenex/phenotypes/time_range_day_count_phenotype.py +++ b/phenex/phenotypes/time_range_day_count_phenotype.py @@ -10,6 +10,7 @@ from phenex.tables import is_phenex_code_table, PHENOTYPE_TABLE_COLUMNS, PhenotypeTable from phenex.phenotypes.functions import ( select_phenotype_columns, + _get_join_keys, ) from ibis.expr.types.relations import Table from ibis import _ @@ -150,7 +151,10 @@ def _perform_time_filtering(self, table): def _perform_day_count_aggregation(self, table): """Count the total number of days across all distinct time periods per person.""" - table = table.select(["PERSON_ID", "START_DATE", "END_DATE"]).distinct() + cols = ["PERSON_ID", "START_DATE", "END_DATE"] + if "INDEX_DATE" in table.columns: + cols.append("INDEX_DATE") + table = table.select(cols).distinct() table = table.mutate( START_DATE=table.START_DATE.cast("date"), END_DATE=table.END_DATE.cast("date"), @@ -158,7 +162,9 @@ def _perform_day_count_aggregation(self, table): table = table.mutate( DAYS_IN_RANGE=table.END_DATE.delta(table.START_DATE, "day") + 1 ) - return table.group_by("PERSON_ID").aggregate(VALUE=_.DAYS_IN_RANGE.sum()) + return table.group_by(_get_join_keys(table)).aggregate( + VALUE=_.DAYS_IN_RANGE.sum() + ) def _perform_value_filtering(self, table): """Filter persons by total day count using value_filter.""" @@ -170,8 +176,11 @@ def _perform_zero_fill(self, table, tables): """Left-join against the PERSON table to include persons with 0 days (only when no value_filter is set).""" if self.value_filter is not None or "PERSON" not in tables: return table - persons = tables["PERSON"].select("PERSON_ID").distinct() - table = persons.join( - table, persons.PERSON_ID == table.PERSON_ID, how="left" - ).drop("PERSON_ID_right") + join_keys = _get_join_keys(table) + persons = ( + tables["PERSON"] + .select([c for c in join_keys if c in tables["PERSON"].columns]) + .distinct() + ) + table = persons.join(table, _get_join_keys(persons), how="left") return table.mutate(VALUE=table.VALUE.fillna(0)) diff --git a/phenex/phenotypes/time_range_days_to_next_range_phenotype.py b/phenex/phenotypes/time_range_days_to_next_range_phenotype.py index 415464bf..d1740866 100644 --- a/phenex/phenotypes/time_range_days_to_next_range_phenotype.py +++ b/phenex/phenotypes/time_range_days_to_next_range_phenotype.py @@ -2,7 +2,10 @@ from phenex.phenotypes.phenotype import Phenotype from phenex.filters import ValueFilter, RelativeTimeRangeFilter from phenex.tables import PhenotypeTable -from phenex.phenotypes.functions import attach_anchor_and_get_reference_date +from phenex.phenotypes.functions import ( + attach_anchor_and_get_reference_date, + _get_join_keys, +) import ibis from ibis import _ from ibis.expr.types.relations import Table @@ -89,8 +92,8 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: NEIGHBOR_START_DATE="START_DATE", NEIGHBOR_END_DATE="END_DATE" ) - # Join anchored_table with neighbor_table on PERSON_ID - joined = anchored_table.join(neighbor_table, "PERSON_ID") + # Join anchored_table with neighbor_table on (PERSON_ID, INDEX_DATE) + joined = anchored_table.join(neighbor_table, _get_join_keys(anchored_table)) # 3. Filter and Calculate Gap based on 'when' if when == "before": @@ -112,7 +115,9 @@ def _execute(self, tables: Dict[str, Table]) -> PhenotypeTable: # 4. Remove all time_ranges except the closest one (min value/gap) # We find the min VALUE for each anchor range - joined = joined.group_by(["PERSON_ID", group_key]).mutate(min_val=_.VALUE.min()) + joined = joined.group_by([*_get_join_keys(joined), group_key]).mutate( + min_val=_.VALUE.min() + ) joined = joined.filter(joined.VALUE == joined.min_val).drop("min_val") # 5. Apply Value Filter diff --git a/phenex/phenotypes/user_defined_phenotype.py b/phenex/phenotypes/user_defined_phenotype.py index 8813f675..c1c3a019 100644 --- a/phenex/phenotypes/user_defined_phenotype.py +++ b/phenex/phenotypes/user_defined_phenotype.py @@ -8,6 +8,7 @@ from phenex.filters.relative_time_range_filter import RelativeTimeRangeFilter from phenex.filters import DateFilter, ValueFilter from phenex.tables import is_phenex_code_table, PHENOTYPE_TABLE_COLUMNS, PhenotypeTable +from phenex.phenotypes.functions import select_phenotype_columns from phenex.aggregators import First, Last from phenex.util import create_logger @@ -71,6 +72,17 @@ def __init__( def _execute(self, tables) -> PhenotypeTable: table = function(tables) + # Propagate INDEX_DATE from PERSON table when function output lacks it + if ( + "INDEX_DATE" not in table.columns + and "PERSON" in tables + and "INDEX_DATE" in tables["PERSON"].columns + ): + person_index = ( + tables["PERSON"].select("PERSON_ID", "INDEX_DATE").distinct() + ) + table = table.join(person_index, "PERSON_ID") + if "BOOLEAN" not in table.columns: table = table.mutate(BOOLEAN=True).distinct() else: @@ -81,7 +93,7 @@ def _execute(self, tables) -> PhenotypeTable: if "VALUE" not in table.columns: table = table.mutate(VALUE=ibis.null().cast("int32")) - return table + return select_phenotype_columns(table) # Set output_display_type = as a class variable based on returns_value parameter _UserDefinedPhenotype.output_display_type = "value" if returns_value else "boolean" diff --git a/phenex/tables.py b/phenex/tables.py index 2d813827..163ed176 100644 --- a/phenex/tables.py +++ b/phenex/tables.py @@ -571,3 +571,10 @@ def is_phenex_index_table(table: PhenexTable) -> bool: PHENOTYPE_TABLE_COLUMNS = ["PERSON_ID", "BOOLEAN", "EVENT_DATE", "VALUE"] +PHENOTYPE_TABLE_COLUMNS_WITH_INDEX = [ + "PERSON_ID", + "INDEX_DATE", + "BOOLEAN", + "EVENT_DATE", + "VALUE", +] diff --git a/phenex/test/cohort/test_cohort_multi_index.py b/phenex/test/cohort/test_cohort_multi_index.py new file mode 100644 index 00000000..72d19f1d --- /dev/null +++ b/phenex/test/cohort/test_cohort_multi_index.py @@ -0,0 +1,262 @@ +""" +Tests for Cohort with return_index="first", "last", and "all". + +Scenario +-------- +Three patients, each with multiple entry events (code "d1") at different dates. +One exclusion criterion (code "e1", before index) selectively removes some +index dates but not others. + +Input data +~~~~~~~~~~ +Entry events (DRUG_EXPOSURE, code "d1"): + P1: 2020-01-01 (x2), 2020-07-01, 2021-01-01 + P2: 2020-03-01, 2020-09-01 + P3: 2020-05-01 + +Exclusion events (CONDITION_OCCURRENCE, code "e1"): + P1: 2020-04-01 → before 2020-07-01 ✓, before 2021-01-01 ✓, NOT before 2020-01-01 + P3: 2020-03-01 → before 2020-05-01 ✓ + +Surviving index dates after exclusion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + P1: 2020-01-01 + P2: 2020-03-01, 2020-09-01 + P3: (none) + +Expected index tables +~~~~~~~~~~~~~~~~~~~~~ + return_index="first": P1 @ 2020-01-01, P2 @ 2020-03-01 + return_index="last": P1 @ 2020-01-01, P2 @ 2020-09-01 + return_index="all": P1 @ 2020-01-01, P2 @ 2020-03-01, P2 @ 2020-09-01 +""" + +import datetime +import pandas as pd +from phenex.ibis_connect import DuckDBConnector +from phenex.test.cohort_test_generator import CohortTestGenerator +from phenex.codelists import Codelist +from phenex.core import Cohort +from phenex.phenotypes import CodelistPhenotype +from phenex.filters import RelativeTimeRangeFilter, GreaterThanOrEqualTo +from phenex.test.cohort.test_mappings import ( + PersonTableForTests, + DrugExposureTableForTests, + ConditionOccurenceTableForTests, +) + + +# --------------------------------------------------------------------------- +# Shared data setup +# --------------------------------------------------------------------------- + +ENTRY_DATE_1 = datetime.date(2020, 1, 1) +ENTRY_DATE_2 = datetime.date(2020, 7, 1) +ENTRY_DATE_3 = datetime.date(2021, 1, 1) +ENTRY_DATE_4 = datetime.date(2020, 3, 1) +ENTRY_DATE_5 = datetime.date(2020, 9, 1) +ENTRY_DATE_6 = datetime.date(2020, 5, 1) + +EXCLUSION_DATE_P1 = datetime.date(2020, 4, 1) +EXCLUSION_DATE_P3 = datetime.date(2020, 3, 1) + + +def _build_mapped_tables(con): + """Create shared input tables for all three tests.""" + df_person = pd.DataFrame( + { + "PATID": ["P1", "P2", "P3"], + "YOB": [1980, 1980, 1980], + "GENDER": [1, 2, 1], + "ACCEPTABLE": [1, 1, 1], + } + ) + person_table = PersonTableForTests( + con.dest_connection.create_table( + "PERSON", + df_person, + schema={"PATID": str, "YOB": int, "GENDER": int, "ACCEPTABLE": int}, + ) + ) + + # P1 has a duplicate d1 on ENTRY_DATE_1 to verify dedup to one row per date + df_drug = pd.DataFrame( + { + "PATID": ["P1", "P1", "P1", "P1", "P2", "P2", "P3"], + "PRODCODEID": ["d1"] * 7, + "ISSUEDATE": [ + ENTRY_DATE_1, + ENTRY_DATE_1, + ENTRY_DATE_2, + ENTRY_DATE_3, + ENTRY_DATE_4, + ENTRY_DATE_5, + ENTRY_DATE_6, + ], + } + ) + drug_table = DrugExposureTableForTests( + con.dest_connection.create_table( + "DRUG_EXPOSURE", + df_drug, + schema={"PATID": str, "PRODCODEID": str, "ISSUEDATE": datetime.date}, + ) + ) + + df_condition = pd.DataFrame( + { + "PATID": ["P1", "P3"], + "MEDCODEID": ["e1", "e1"], + "OBSDATE": [EXCLUSION_DATE_P1, EXCLUSION_DATE_P3], + } + ) + condition_table = ConditionOccurenceTableForTests( + con.dest_connection.create_table( + "CONDITION_OCCURRENCE", + df_condition, + schema={"PATID": str, "MEDCODEID": str, "OBSDATE": datetime.date}, + ) + ) + + return { + "PERSON": person_table, + "DRUG_EXPOSURE": drug_table, + "CONDITION_OCCURRENCE": condition_table, + } + + +def _make_exclusion(): + return CodelistPhenotype( + name="prior_event", + codelist=Codelist(["e1"]).copy(use_code_type=False), + domain="CONDITION_OCCURRENCE", + relative_time_range=RelativeTimeRangeFilter( + when="before", + min_days=GreaterThanOrEqualTo(0), + ), + ) + + +# --------------------------------------------------------------------------- +# return_index = "first" +# --------------------------------------------------------------------------- + + +class MultiIndexFirstTestGenerator(CohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_cohort_multi_index_first", + entry_criterion=entry, + exclusions=[_make_exclusion()], + return_index="first", + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4], + } + ) + return {"index": df} + + +# --------------------------------------------------------------------------- +# return_index = "last" +# --------------------------------------------------------------------------- + + +class MultiIndexLastTestGenerator(CohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_cohort_multi_index_last", + entry_criterion=entry, + exclusions=[_make_exclusion()], + return_index="last", + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_5], + } + ) + return {"index": df} + + +# --------------------------------------------------------------------------- +# return_index = "all" +# --------------------------------------------------------------------------- + + +class MultiIndexAllTestGenerator(CohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_cohort_multi_index_all", + entry_criterion=entry, + exclusions=[_make_exclusion()], + return_index="all", + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4, ENTRY_DATE_5], + } + ) + return {"index": df} + + +# --------------------------------------------------------------------------- +# pytest entry points +# --------------------------------------------------------------------------- + + +def test_cohort_multi_index_first(): + g = MultiIndexFirstTestGenerator() + g.run_tests() + + +def test_cohort_multi_index_last(): + g = MultiIndexLastTestGenerator() + g.run_tests() + + +def test_cohort_multi_index_all(): + g = MultiIndexAllTestGenerator() + g.run_tests() diff --git a/phenex/test/cohort/test_subcohort_multi_index.py b/phenex/test/cohort/test_subcohort_multi_index.py new file mode 100644 index 00000000..6c885ae7 --- /dev/null +++ b/phenex/test/cohort/test_subcohort_multi_index.py @@ -0,0 +1,348 @@ +""" +Tests for Subcohort with multi-index dates (return_index="first", "last", "all"). + +Scenario +-------- +Three patients, each with multiple entry events (code "d1") at different dates. +The parent cohort applies an exclusion (code "e1", before index) that removes +some index dates. The subcohort applies an additional exclusion (code "s1", +before index) that further removes specific index dates. + +Input data +~~~~~~~~~~ +Entry events (DRUG_EXPOSURE, code "d1"): + P1: 2020-01-01, 2020-07-01, 2021-01-01 + P2: 2020-03-01, 2020-09-01 + P3: 2020-05-01 + +Cohort exclusion events (CONDITION_OCCURRENCE, code "e1"): + P1: 2020-04-01 → before 2020-07-01 ✓, before 2021-01-01 ✓, NOT before 2020-01-01 + P3: 2020-03-01 → before 2020-05-01 ✓ + +Surviving index dates after cohort exclusion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + P1: 2020-01-01 + P2: 2020-03-01, 2020-09-01 + P3: (none) + +Subcohort additional exclusion (DRUG_EXPOSURE, code "s1"): + P2: 2020-06-01 → before 2020-09-01 ✓, NOT before 2020-03-01 + +Surviving index dates after subcohort exclusion +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + P1: 2020-01-01 + P2: 2020-03-01 (2020-09-01 removed by s1) + +Expected index tables +~~~~~~~~~~~~~~~~~~~~~ + return_index="first": + Cohort: P1 @ 2020-01-01, P2 @ 2020-03-01 + Subcohort: P1 @ 2020-01-01, P2 @ 2020-03-01 (s1 not before Mar → P2 stays) + + return_index="last": + Cohort: P1 @ 2020-01-01, P2 @ 2020-09-01 + Subcohort: P1 @ 2020-01-01 (s1 before Sep → P2 removed) + + return_index="all": + Cohort: P1 @ 2020-01-01, P2 @ 2020-03-01, P2 @ 2020-09-01 + Subcohort: P1 @ 2020-01-01, P2 @ 2020-03-01 (P2@Sep removed by s1) +""" + +import datetime +import pandas as pd +from phenex.ibis_connect import DuckDBConnector +from phenex.test.cohort.test_subcohort import SubcohortTestGenerator +from phenex.codelists import Codelist +from phenex.core import Cohort, Subcohort +from phenex.phenotypes import CodelistPhenotype +from phenex.filters import RelativeTimeRangeFilter, GreaterThanOrEqualTo +from phenex.test.cohort.test_mappings import ( + PersonTableForTests, + DrugExposureTableForTests, + ConditionOccurenceTableForTests, +) + + +# --------------------------------------------------------------------------- +# Shared constants +# --------------------------------------------------------------------------- + +ENTRY_DATE_1 = datetime.date(2020, 1, 1) +ENTRY_DATE_2 = datetime.date(2020, 7, 1) +ENTRY_DATE_3 = datetime.date(2021, 1, 1) +ENTRY_DATE_4 = datetime.date(2020, 3, 1) +ENTRY_DATE_5 = datetime.date(2020, 9, 1) +ENTRY_DATE_6 = datetime.date(2020, 5, 1) + +EXCLUSION_DATE_P1 = datetime.date(2020, 4, 1) +EXCLUSION_DATE_P3 = datetime.date(2020, 3, 1) + +SUBCOHORT_EXCL_DATE_P2 = datetime.date(2020, 6, 1) + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _build_mapped_tables(con): + """Create input tables shared by all three test modes.""" + df_person = pd.DataFrame( + { + "PATID": ["P1", "P2", "P3"], + "YOB": [1980, 1980, 1980], + "GENDER": [1, 2, 1], + "ACCEPTABLE": [1, 1, 1], + } + ) + person_table = PersonTableForTests( + con.dest_connection.create_table( + "PERSON", + df_person, + schema={"PATID": str, "YOB": int, "GENDER": int, "ACCEPTABLE": int}, + ) + ) + + df_drug = pd.DataFrame( + { + "PATID": ["P1", "P1", "P1", "P2", "P2", "P3", "P2"], + "PRODCODEID": ["d1", "d1", "d1", "d1", "d1", "d1", "s1"], + "ISSUEDATE": [ + ENTRY_DATE_1, + ENTRY_DATE_2, + ENTRY_DATE_3, + ENTRY_DATE_4, + ENTRY_DATE_5, + ENTRY_DATE_6, + SUBCOHORT_EXCL_DATE_P2, + ], + } + ) + drug_table = DrugExposureTableForTests( + con.dest_connection.create_table( + "DRUG_EXPOSURE", + df_drug, + schema={"PATID": str, "PRODCODEID": str, "ISSUEDATE": datetime.date}, + ) + ) + + df_condition = pd.DataFrame( + { + "PATID": ["P1", "P3"], + "MEDCODEID": ["e1", "e1"], + "OBSDATE": [EXCLUSION_DATE_P1, EXCLUSION_DATE_P3], + } + ) + condition_table = ConditionOccurenceTableForTests( + con.dest_connection.create_table( + "CONDITION_OCCURRENCE", + df_condition, + schema={"PATID": str, "MEDCODEID": str, "OBSDATE": datetime.date}, + ) + ) + + return { + "PERSON": person_table, + "DRUG_EXPOSURE": drug_table, + "CONDITION_OCCURRENCE": condition_table, + } + + +def _make_cohort_exclusion(): + return CodelistPhenotype( + name="prior_event", + codelist=Codelist(["e1"]).copy(use_code_type=False), + domain="CONDITION_OCCURRENCE", + relative_time_range=RelativeTimeRangeFilter( + when="before", + min_days=GreaterThanOrEqualTo(0), + ), + ) + + +def _make_subcohort_exclusion(): + return CodelistPhenotype( + name="subcohort_excl_s1", + codelist=Codelist(["s1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + relative_time_range=RelativeTimeRangeFilter( + when="before", + min_days=GreaterThanOrEqualTo(0), + ), + ) + + +# --------------------------------------------------------------------------- +# return_index = "first" +# --------------------------------------------------------------------------- + + +class MultiIndexFirstSubcohortTestGenerator(SubcohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_subcohort_multi_index_first", + entry_criterion=entry, + exclusions=[_make_cohort_exclusion()], + return_index="first", + ) + + def define_subcohort(self): + return Subcohort( + name="subcohort", + cohort=self.cohort, + exclusions=[_make_subcohort_exclusion()], + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4], + } + ) + return {"index": df} + + def define_expected_subcohort_output(self): + # s1@June NOT before March → P2 stays + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4], + } + ) + return {"subcohort_index": df} + + +# --------------------------------------------------------------------------- +# return_index = "last" +# --------------------------------------------------------------------------- + + +class MultiIndexLastSubcohortTestGenerator(SubcohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_subcohort_multi_index_last", + entry_criterion=entry, + exclusions=[_make_cohort_exclusion()], + return_index="last", + ) + + def define_subcohort(self): + return Subcohort( + name="subcohort", + cohort=self.cohort, + exclusions=[_make_subcohort_exclusion()], + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_5], + } + ) + return {"index": df} + + def define_expected_subcohort_output(self): + # s1@June IS before September → P2 removed + df = pd.DataFrame( + { + "PERSON_ID": ["P1"], + "EVENT_DATE": [ENTRY_DATE_1], + } + ) + return {"subcohort_index": df} + + +# --------------------------------------------------------------------------- +# return_index = "all" +# --------------------------------------------------------------------------- + + +class MultiIndexAllSubcohortTestGenerator(SubcohortTestGenerator): + test_date = True + + def define_cohort(self): + entry = CodelistPhenotype( + return_date="all", + codelist=Codelist(["d1"]).copy(use_code_type=False), + domain="DRUG_EXPOSURE", + ) + return Cohort( + name="test_subcohort_multi_index_all", + entry_criterion=entry, + exclusions=[_make_cohort_exclusion()], + return_index="all", + ) + + def define_subcohort(self): + return Subcohort( + name="subcohort", + cohort=self.cohort, + exclusions=[_make_subcohort_exclusion()], + ) + + def define_mapped_tables(self): + self.con = DuckDBConnector() + return _build_mapped_tables(self.con) + + def define_expected_output(self): + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4, ENTRY_DATE_5], + } + ) + return {"index": df} + + def define_expected_subcohort_output(self): + # P2@Sep removed by s1; P2@Mar stays + df = pd.DataFrame( + { + "PERSON_ID": ["P1", "P2"], + "EVENT_DATE": [ENTRY_DATE_1, ENTRY_DATE_4], + } + ) + return {"subcohort_index": df} + + +# --------------------------------------------------------------------------- +# pytest entry points +# --------------------------------------------------------------------------- + + +def test_subcohort_multi_index_first(): + g = MultiIndexFirstSubcohortTestGenerator() + g.run_tests() + + +def test_subcohort_multi_index_last(): + g = MultiIndexLastSubcohortTestGenerator() + g.run_tests() + + +def test_subcohort_multi_index_all(): + g = MultiIndexAllSubcohortTestGenerator() + g.run_tests() diff --git a/phenex/test/phenotypes/multi_index/__init__.py b/phenex/test/phenotypes/multi_index/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/phenex/test/phenotypes/multi_index/test_age_phenotype.py b/phenex/test/phenotypes/multi_index/test_age_phenotype.py new file mode 100644 index 00000000..60f7c850 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_age_phenotype.py @@ -0,0 +1,70 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_age_phenotype import AgePhenotypeTestGenerator + + +class MultiIndexAgePhenotypeTestGenerator(MultiIndexMixin, AgePhenotypeTestGenerator): + name_space = "mi_agpt" + _index_date = datetime.date(2022, 1, 1) + _shift = datetime.timedelta(days=730) + + def define_input_tables(self): + tables = AgePhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = AgePhenotypeTestGenerator.define_phenotype_tests(self) + age_shift = self.shift.days // 365 + index_date_2 = self._index_date + self.shift + + for info in tests: + orig_persons = list(info["persons"]) + orig_values = list(info["values"]) + + # At index_date_2, person P{x} has age x + age_shift. + # Recompute which persons pass the value filter at the shifted age. + shifted_persons = [] + shifted_values = [] + for x in range(self.n_persons): + age = x + age_shift + if self._check_filter(info.get("value_filter"), age): + shifted_persons.append(f"P{x}") + shifted_values.append(age) + + info["persons"] = orig_persons + shifted_persons + info["values"] = orig_values + shifted_values + info["index_dates"] = [self._index_date] * len(orig_persons) + [ + index_date_2 + ] * len(shifted_persons) + + return tests + + @staticmethod + def _check_filter(vf, age): + if vf is None: + return True + if vf.min_value is not None: + op = vf.min_value.operator + val = vf.min_value.value + if op == ">=" and age < val: + return False + if op == ">" and age <= val: + return False + if vf.max_value is not None: + op = vf.max_value.operator + val = vf.max_value.value + if op == "<=" and age > val: + return False + if op == "<" and age >= val: + return False + return True + + +def test_multiindex_age_phenotype(): + tg = MultiIndexAgePhenotypeTestGenerator() + tg.run_tests() + + +if __name__ == "__main__": + test_multiindex_age_phenotype() diff --git a/phenex/test/phenotypes/multi_index/test_arithmetic_phenotype.py b/phenex/test/phenotypes/multi_index/test_arithmetic_phenotype.py new file mode 100644 index 00000000..d2718157 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_arithmetic_phenotype.py @@ -0,0 +1,34 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_arithmetic_phenotype import ( + ArithmeticPhenotypeArithmeticPhenotypeTestGenerator, +) + + +class MultiIndexArithmeticPhenotypeTestGenerator( + MultiIndexMixin, ArithmeticPhenotypeArithmeticPhenotypeTestGenerator +): + name_space = "mi_arpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ( + ArithmeticPhenotypeArithmeticPhenotypeTestGenerator.define_input_tables( + self + ) + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = ( + ArithmeticPhenotypeArithmeticPhenotypeTestGenerator.define_phenotype_tests( + self + ) + ) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_arithmetic_phenotype(): + tg = MultiIndexArithmeticPhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_bin_phenotype.py b/phenex/test/phenotypes/multi_index/test_bin_phenotype.py new file mode 100644 index 00000000..8e30a88f --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_bin_phenotype.py @@ -0,0 +1,24 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_bin_phenotype import BinnedAgePhenotypeTestGenerator + + +class MultiIndexBinnedAgePhenotypeTestGenerator( + MultiIndexMixin, BinnedAgePhenotypeTestGenerator +): + name_space = "mi_bnpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = BinnedAgePhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = BinnedAgePhenotypeTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_binned_age_phenotype(): + tg = MultiIndexBinnedAgePhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_categoric_phenotype.py b/phenex/test/phenotypes/multi_index/test_categoric_phenotype.py new file mode 100644 index 00000000..7121951d --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_categoric_phenotype.py @@ -0,0 +1,44 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_categoric_phenotype import ( + CategoricalPhenotypeWithDateTestGenerator, +) + + +class MultiIndexCategoricalPhenotypeWithDateTestGenerator( + MultiIndexMixin, CategoricalPhenotypeWithDateTestGenerator +): + name_space = "mi_ctpt_date" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = CategoricalPhenotypeWithDateTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CategoricalPhenotypeWithDateTestGenerator.define_phenotype_tests(self) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # With 90-day shift (→ 2022-04-01), P3's event (2022-01-03) is now + # before the shifted index, so it passes "before" filters. + # For "after" filter (c3), P3's event is now before shifted → excluded. + shifted_persons = { + "single_flag": ["P0", "P1", "P2", "P3"], + "two_categorical_filter_or": ["P0", "P1", "P2", "P3", "P6", "P7"], + "two_categorical_filter_and": [], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +def test_multiindex_categorical_phenotype_with_date(): + tg = MultiIndexCategoricalPhenotypeWithDateTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_codelist_phenotype_autojoin.py b/phenex/test/phenotypes/multi_index/test_codelist_phenotype_autojoin.py new file mode 100644 index 00000000..ee0dad04 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_codelist_phenotype_autojoin.py @@ -0,0 +1,51 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_codelist_phenotype_autojoin import ( + CodelistPhenotypeAutojoinTimeRangeTestGenerator, +) + + +class MultiIndexCodelistAutojoinTimeRangeTestGenerator( + MultiIndexMixin, CodelistPhenotypeAutojoinTimeRangeTestGenerator +): + name_space = "mi_clpt_autojoin_timerange" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = CodelistPhenotypeAutojoinTimeRangeTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CodelistPhenotypeAutojoinTimeRangeTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # With a 90-day shift, relative distances change so different + # persons pass each filter at the shifted INDEX_DATE. + shifted_persons = { + "max_days_leq_180": ["P1", "P2", "P6", "P7", "P8", "P10", "P11"], + "max_days_lt_180": ["P2", "P6", "P7", "P8", "P10", "P11"], + "min_days_geq_90_max_days_leq_180": ["P1", "P2", "P6", "P7"], + "after_max_days_leq_180": ["P9", "P10", "P12", "P13", "P14"], + "after_max_days_g_90_max_days_leq_180": ["P12"], + "range_min_gn90_max_g90": ["P8", "P9", "P10", "P11", "P14"], + "range_min_gn90_max_ge180": ["P8", "P9", "P10", "P11", "P12", "P13", "P14"], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +def test_multiindex_codelist_autojoin_time_range(): + tg = MultiIndexCodelistAutojoinTimeRangeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_codelist_phenotype_multiindex.py b/phenex/test/phenotypes/multi_index/test_codelist_phenotype_multiindex.py new file mode 100644 index 00000000..053a63a7 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_codelist_phenotype_multiindex.py @@ -0,0 +1,222 @@ +""" +Multi-index date variants of codelist phenotype tests. + +Each test duplicates input data with a second INDEX_DATE (shifted by 90 days), +verifying that phenotype logic correctly partitions results by (PERSON_ID, INDEX_DATE). +Only includes tests whose input data contained an INDEX_DATE column. +""" + +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_codelist_phenotype import ( + CodelistPhenotypeRelativeTimeRangeFilterTestGenerator, + CodelistPhenotypeAnchorPhenotypeRelativeTimeRangeFilterTestGenerator, + CodelistPhenotypeReturnDateFilterTestGenerator, +) + + +# ── Relative time-range filter (INDEX_DATE as anchor) ──────────────────── + + +class MultiIndexRelativeTimeRangeFilterTestGenerator( + MultiIndexMixin, CodelistPhenotypeRelativeTimeRangeFilterTestGenerator +): + name_space = "mi_clpt_timerangefilter" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ( + CodelistPhenotypeRelativeTimeRangeFilterTestGenerator.define_input_tables( + self + ) + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CodelistPhenotypeRelativeTimeRangeFilterTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # With a 90-day shift, relative distances change so different + # persons pass each filter at the shifted INDEX_DATE. + shifted_persons = { + "max_days_leq_180": ["P1", "P2", "P6", "P7", "P8", "P10", "P11"], + "max_days_lt_180": ["P2", "P6", "P7", "P8", "P10", "P11"], + "min_days_geq_90_max_days_leq_180": ["P1", "P2", "P6", "P7"], + "after_max_days_leq_180": ["P9", "P10", "P12", "P13", "P14"], + "after_max_days_g_90_max_days_leq_180": ["P12"], + "range_min_gn90_max_g90": ["P8", "P9", "P10", "P11", "P14"], + "range_min_gn90_max_ge180": ["P8", "P9", "P10", "P11", "P12", "P13", "P14"], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +# ── Anchor phenotype relative time-range filter ────────────────────────── + + +class MultiIndexAnchorPhenotypeTestGenerator( + MultiIndexMixin, + CodelistPhenotypeAnchorPhenotypeRelativeTimeRangeFilterTestGenerator, +): + name_space = "mi_clpt_anchor_phenotype" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = CodelistPhenotypeAnchorPhenotypeRelativeTimeRangeFilterTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CodelistPhenotypeAnchorPhenotypeRelativeTimeRangeFilterTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted INDEX_DATE (2022-04-01): + # - phenotypeindex1 (>0, ≤90 before): matches P10-P14 (anchor=2022-01-01) + # - phenotypeindex2 (≥0, ≤180 before): matches P5-P9, P10-P14, P15-P19 + # Then dependent phenotypes filter c2 events relative to anchors. + shifted_persons = { + "p1": ["P10"], + "p2": ["P6", "P7", "P11", "P12", "P16", "P17"], + "p4": ["P6", "P7", "P8", "P11", "P12", "P13", "P16", "P17", "P18"], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +# ── Return-date filter ─────────────────────────────────────────────────── + + +class MultiIndexReturnDateTestGenerator( + MultiIndexMixin, CodelistPhenotypeReturnDateFilterTestGenerator +): + name_space = "mi_clpt_return_date" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = CodelistPhenotypeReturnDateFilterTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CodelistPhenotypeReturnDateFilterTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted INDEX_DATE (2022-04-01), c1 events for P0 have new + # relative distances. Compute correct expected output per test. + for info in tests: + orig_persons = list(info["persons"]) + orig_dates = list(info["dates"]) + n = len(orig_persons) + name = info["name"] + + if name == "returndate": + # return_date="all", no filter → all 6 c1 events at both INDEX_DATEs + info["persons"] = orig_persons + orig_persons + info["dates"] = orig_dates + orig_dates + info["index_dates"] = [idx1] * n + [idx2] * n + + elif name == "l90": + # return_date="all", before, max_days < 90 + # Shifted: events[6]=2022-03-31 (1d), events[7]=2022-04-01 (0d) + sp = ["P0", "P0"] + sd = [self.event_dates[6], self.event_dates[7]] + info["persons"] = orig_persons + sp + info["dates"] = orig_dates + sd + info["index_dates"] = [idx1] * n + [idx2] * len(sp) + + elif name == "leq90": + # return_date="all", before, max_days ≤ 90 + # Shifted: events[6] (1d), events[7] (0d) + sp = ["P0", "P0"] + sd = [self.event_dates[6], self.event_dates[7]] + info["persons"] = orig_persons + sp + info["dates"] = orig_dates + sd + info["index_dates"] = [idx1] * n + [idx2] * len(sp) + + elif name == "first_preindex": + # return_date="first", no filter → earliest c1 = events[0] + info["persons"] = orig_persons + ["P0"] + info["dates"] = orig_dates + [self.event_dates[0]] + info["index_dates"] = [idx1] * n + [idx2] + + elif name == "last_preindex": + # return_date="last", before → events[7] (on shifted index) + info["persons"] = orig_persons + ["P0"] + info["dates"] = orig_dates + [self.event_dates[7]] + info["index_dates"] = [idx1] * n + [idx2] + + elif name == "first_leq90": + # return_date="first", before, max_days ≤ 90 → events[6] + info["persons"] = orig_persons + ["P0"] + info["dates"] = orig_dates + [self.event_dates[6]] + info["index_dates"] = [idx1] * n + [idx2] + + elif name == "last_postindex": + # return_date="last", after → events[8] + info["persons"] = orig_persons + ["P0"] + info["dates"] = orig_dates + [self.event_dates[8]] + info["index_dates"] = [idx1] * n + [idx2] + + elif name == "first_postindex": + # return_date="first", after → events[7] (on shifted index) + info["persons"] = orig_persons + ["P0"] + info["dates"] = orig_dates + [self.event_dates[7]] + info["index_dates"] = [idx1] * n + [idx2] + + elif name == "postindex_leq90": + # return_date="all", after, max_days ≤ 90 → events[7], events[8] + sp = ["P0", "P0"] + sd = [self.event_dates[7], self.event_dates[8]] + info["persons"] = orig_persons + sp + info["dates"] = orig_dates + sd + info["index_dates"] = [idx1] * n + [idx2] * len(sp) + + return tests + + +# ── Test functions ──────────────────────────────────────────────────────── + + +def test_multiindex_relative_time_range_filter(): + tg = MultiIndexRelativeTimeRangeFilterTestGenerator() + tg.run_tests() + + +def test_multiindex_anchor_phenotype(): + tg = MultiIndexAnchorPhenotypeTestGenerator() + tg.run_tests() + + +def test_multiindex_return_date(): + tg = MultiIndexReturnDateTestGenerator() + tg.run_tests() + + +if __name__ == "__main__": + test_multiindex_relative_time_range_filter() + test_multiindex_anchor_phenotype() + test_multiindex_return_date() diff --git a/phenex/test/phenotypes/multi_index/test_death_phenotype.py b/phenex/test/phenotypes/multi_index/test_death_phenotype.py new file mode 100644 index 00000000..ce53c8d0 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_death_phenotype.py @@ -0,0 +1,182 @@ +import datetime + +import pandas as pd + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_death_phenotype import ( + DeathPhenotypeTestGenerator, + DeathPhenotypeDateRangeTestGenerator, +) + + +class MultiIndexDeathPhenotypeTestGenerator( + MultiIndexMixin, DeathPhenotypeTestGenerator +): + name_space = "mi_dthpt" + _index_date = datetime.date(2022, 1, 1) + _shift = datetime.timedelta(days=730) + + def _duplicate_input_tables(self, input_tables): + """Shift INDEX_DATE backward instead of forward for death phenotype.""" + result = [] + for info in input_tables: + df = info["df"].copy() + if "INDEX_DATE" not in df.columns: + result.append(info) + continue + df2 = df.copy() + df2["INDEX_DATE"] = pd.to_datetime(df2["INDEX_DATE"]) - self.shift + combined = pd.concat([df, df2], ignore_index=True) + result.append({**info, "df": combined}) + return result + + def define_input_tables(self): + tables = DeathPhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = DeathPhenotypeTestGenerator.define_phenotype_tests(self) + index_date_1 = self._index_date + index_date_2 = datetime.date( + *(pd.Timestamp(index_date_1) - self.shift).timetuple()[:3] + ) + + # Death dates (same for both INDEX_DATEs): + # P0: None, P1: idx, P2: idx-20d, P3: idx-40d, P4: idx-60d, + # P5: idx+0d, P6: idx+20d, P7: idx+40d, P8: idx+60d + # At index_date_2 (~2020-01-02) ALL deaths are >670 days in the future. + # Only unbounded "after" filters match; all "before" and bounded filters + # produce no matches at index_date_2. + death_dates = list(self.input_table["DATE_OF_DEATH"].values) + + for info in tests: + orig_persons = list(info["persons"]) + orig_dates = list(info["dates"]) + + rtr = info.get("time_range_filter") + shifted_persons = [] + shifted_dates = [] + + if rtr is not None and rtr.when == "after": + has_min_days = rtr.min_days is not None + has_max_days = rtr.max_days is not None + for i in range(self.n_persons): + dd = death_dates[i] + if dd is None or pd.isna(dd): + continue + dd_date = pd.Timestamp(dd) + idx2_ts = pd.Timestamp(index_date_2) + days_diff = (dd_date - idx2_ts).days + if days_diff < 0: + continue + if has_min_days: + op = rtr.min_days.operator + val = rtr.min_days.value + if op == ">" and days_diff <= val: + continue + if op == ">=" and days_diff < val: + continue + if has_max_days: + op = rtr.max_days.operator + val = rtr.max_days.value + if op == "<=" and days_diff > val: + continue + if op == "<" and days_diff >= val: + continue + shifted_persons.append(f"P{i}") + shifted_dates.append(dd) + + info["persons"] = orig_persons + shifted_persons + info["dates"] = orig_dates + shifted_dates + info["index_dates"] = [index_date_1] * len(orig_persons) + [ + index_date_2 + ] * len(shifted_persons) + + return tests + + +class MultiIndexDeathPhenotypeDateRangeTestGenerator( + MultiIndexMixin, DeathPhenotypeDateRangeTestGenerator +): + name_space = "mi_dthpt_daterange" + _index_date = datetime.date(2022, 1, 1) + _shift = datetime.timedelta(days=730) + + def _duplicate_input_tables(self, input_tables): + """Shift INDEX_DATE backward instead of forward for death phenotype.""" + result = [] + for info in input_tables: + df = info["df"].copy() + if "INDEX_DATE" not in df.columns: + result.append(info) + continue + df2 = df.copy() + df2["INDEX_DATE"] = pd.to_datetime(df2["INDEX_DATE"]) - self.shift + combined = pd.concat([df, df2], ignore_index=True) + result.append({**info, "df": combined}) + return result + + def define_input_tables(self): + tables = DeathPhenotypeDateRangeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = DeathPhenotypeDateRangeTestGenerator.define_phenotype_tests(self) + index_date_1 = self._index_date + index_date_2 = datetime.date( + *(pd.Timestamp(index_date_1) - self.shift).timetuple()[:3] + ) + death_dates = list(self.input_table["DATE_OF_DEATH"].values) + + # At index_date_2 (~2020-01-02) all deaths are far in the future. + # date_range is absolute (2021-12-01..2022-01-31) so it still applies. + # "after" filter: deaths in date_range AND after index_date_2 → P1,P2,P3 + # "before" filter: no deaths are before index_date_2 → none + for info in tests: + orig_persons = list(info["persons"]) + orig_dates = list(info["dates"]) + orig_values = list(info["values"]) + + phenotype = info["phenotype"] + rtr = phenotype.relative_time_range + if isinstance(rtr, list): + rtr = rtr[0] + + shifted_persons = [] + shifted_dates = [] + shifted_values = [] + + if rtr.when == "after": + # Deaths in date_range AND after index_date_2 + dt_start = pd.Timestamp("2021-12-01") + dt_end = pd.Timestamp("2022-01-31") + idx2_ts = pd.Timestamp(index_date_2) + for i in range(len(death_dates)): + dd = death_dates[i] + if dd is None or pd.isna(dd): + continue + dd_ts = pd.Timestamp(dd) + if dd_ts >= dt_start and dd_ts <= dt_end and dd_ts >= idx2_ts: + days_diff = (dd_ts - idx2_ts).days + shifted_persons.append(f"P{i}") + shifted_dates.append(dd) + shifted_values.append(days_diff) + + info["persons"] = orig_persons + shifted_persons + info["dates"] = orig_dates + shifted_dates + info["values"] = orig_values + shifted_values + info["index_dates"] = [index_date_1] * len(orig_persons) + [ + index_date_2 + ] * len(shifted_persons) + + return tests + + +def test_multiindex_death_phenotype(): + tg = MultiIndexDeathPhenotypeTestGenerator() + tg.run_tests() + + +def test_multiindex_death_phenotype_date_range(): + tg = MultiIndexDeathPhenotypeDateRangeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_event_count_phenotype.py b/phenex/test/phenotypes/multi_index/test_event_count_phenotype.py new file mode 100644 index 00000000..64de797f --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_event_count_phenotype.py @@ -0,0 +1,22 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_event_count_phenotype import EventCountTestGenerator + + +class MultiIndexEventCountTestGenerator(MultiIndexMixin, EventCountTestGenerator): + name_space = "mi_ecpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = EventCountTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = EventCountTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_event_count(): + tg = MultiIndexEventCountTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_event_phenotype.py b/phenex/test/phenotypes/multi_index/test_event_phenotype.py new file mode 100644 index 00000000..8a00bd28 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_event_phenotype.py @@ -0,0 +1,82 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_event_phenotype import ( + EventPhenotypeBasicTestGenerator, + EventPhenotypeRelativeTimeRangeFilterTestGenerator, +) + + +class MultiIndexEventPhenotypeBasicTestGenerator( + MultiIndexMixin, EventPhenotypeBasicTestGenerator +): + name_space = "mi_evpt_basic" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = EventPhenotypeBasicTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = EventPhenotypeBasicTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexEventPhenotypeRelativeTimeRangeFilterTestGenerator( + MultiIndexMixin, EventPhenotypeRelativeTimeRangeFilterTestGenerator +): + name_space = "mi_evpt_timerange" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = EventPhenotypeRelativeTimeRangeFilterTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = ( + EventPhenotypeRelativeTimeRangeFilterTestGenerator.define_phenotype_tests( + self + ) + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # With a 90-day shift, the relative distances change so different + # persons pass each filter at the shifted INDEX_DATE. + shifted_persons = { + "max_days_leq_180": ["P1", "P2", "P6", "P7", "P8", "P10", "P11"], + "max_days_lt_180": ["P2", "P6", "P7", "P8", "P10", "P11"], + "min_days_geq_90_max_days_leq_180": ["P1", "P2", "P6", "P7"], + "after_max_days_leq_180": ["P9", "P10", "P12", "P13", "P14"], + "after_min_gt_90_max_leq_180": ["P12"], + "range_min_gn90_max_l90": ["P8", "P9", "P10", "P11", "P14"], + "range_min_gn90_max_leq180": [ + "P8", + "P9", + "P10", + "P11", + "P12", + "P13", + "P14", + ], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +def test_multiindex_event_phenotype_basic(): + tg = MultiIndexEventPhenotypeBasicTestGenerator() + tg.run_tests() + + +def test_multiindex_event_phenotype_relative_time_range(): + tg = MultiIndexEventPhenotypeRelativeTimeRangeFilterTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_further_value_filter_phenotype.py b/phenex/test/phenotypes/multi_index/test_further_value_filter_phenotype.py new file mode 100644 index 00000000..df079623 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_further_value_filter_phenotype.py @@ -0,0 +1,114 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_further_value_filter_phenotype import ( + FurtherValueFilterBasicTestGenerator, + FurtherValueFilterAggregationTestGenerator, + FurtherValueFilterDateRangeTestGenerator, + FurtherValueFilterRelativeTimeRangeTestGenerator, + FurtherValueFilterReturnDateTestGenerator, +) + + +class MultiIndexFurtherValueFilterBasicTestGenerator( + MultiIndexMixin, FurtherValueFilterBasicTestGenerator +): + name_space = "mi_fvf_basic" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = FurtherValueFilterBasicTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = FurtherValueFilterBasicTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexFurtherValueFilterAggregationTestGenerator( + MultiIndexMixin, FurtherValueFilterAggregationTestGenerator +): + name_space = "mi_fvf_aggregation" + _index_date = datetime.date(2022, 1, 3) + + def define_input_tables(self): + tables = FurtherValueFilterAggregationTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = FurtherValueFilterAggregationTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexFurtherValueFilterDateRangeTestGenerator( + MultiIndexMixin, FurtherValueFilterDateRangeTestGenerator +): + name_space = "mi_fvf_daterange" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = FurtherValueFilterDateRangeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = FurtherValueFilterDateRangeTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexFurtherValueFilterRelativeTimeRangeTestGenerator( + MultiIndexMixin, FurtherValueFilterRelativeTimeRangeTestGenerator +): + name_space = "mi_fvf_relativetimerange" + _index_date = datetime.date(2022, 6, 1) + + def define_input_tables(self): + tables = FurtherValueFilterRelativeTimeRangeTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = FurtherValueFilterRelativeTimeRangeTestGenerator.define_phenotype_tests( + self + ) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexFurtherValueFilterReturnDateTestGenerator( + MultiIndexMixin, FurtherValueFilterReturnDateTestGenerator +): + name_space = "mi_fvf_returndate" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = FurtherValueFilterReturnDateTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = FurtherValueFilterReturnDateTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_further_value_filter_basic(): + tg = MultiIndexFurtherValueFilterBasicTestGenerator() + tg.run_tests() + + +def test_multiindex_further_value_filter_aggregation(): + tg = MultiIndexFurtherValueFilterAggregationTestGenerator() + tg.run_tests() + + +def test_multiindex_further_value_filter_date_range(): + tg = MultiIndexFurtherValueFilterDateRangeTestGenerator() + tg.run_tests() + + +def test_multiindex_further_value_filter_relative_time_range(): + tg = MultiIndexFurtherValueFilterRelativeTimeRangeTestGenerator() + tg.run_tests() + + +def test_multiindex_further_value_filter_return_date(): + tg = MultiIndexFurtherValueFilterReturnDateTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_logic_phenotype.py b/phenex/test/phenotypes/multi_index/test_logic_phenotype.py new file mode 100644 index 00000000..7caab54c --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_logic_phenotype.py @@ -0,0 +1,55 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_logic_phenotype import ( + LogicPhenotypeValueTestGenerator, + LogicPhenotypeMixedComponentValueTypesTestGenerator, +) + + +class MultiIndexLogicPhenotypeValueTestGenerator( + MultiIndexMixin, LogicPhenotypeValueTestGenerator +): + name_space = "mi_lgpt_value" + _index_date = datetime.date(2020, 1, 1) + + def define_input_tables(self): + tables = LogicPhenotypeValueTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = LogicPhenotypeValueTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexLogicPhenotypeMixedComponentValueTypesTestGenerator( + MultiIndexMixin, LogicPhenotypeMixedComponentValueTypesTestGenerator +): + name_space = "mi_lgpt_mixed" + _index_date = datetime.date(2020, 1, 1) + + def define_input_tables(self): + tables = ( + LogicPhenotypeMixedComponentValueTypesTestGenerator.define_input_tables( + self + ) + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = ( + LogicPhenotypeMixedComponentValueTypesTestGenerator.define_phenotype_tests( + self + ) + ) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_logic_phenotype_value(): + tg = MultiIndexLogicPhenotypeValueTestGenerator() + tg.run_tests() + + +def test_multiindex_logic_phenotype_mixed_value_types(): + tg = MultiIndexLogicPhenotypeMixedComponentValueTypesTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_logic_phenotype_complex_entry.py b/phenex/test/phenotypes/multi_index/test_logic_phenotype_complex_entry.py new file mode 100644 index 00000000..6b143a94 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_logic_phenotype_complex_entry.py @@ -0,0 +1,26 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_logic_phenotype_complex_entry import ( + LogicPhenotypeComplexEntryTestGenerator, +) + + +class MultiIndexLogicPhenotypeComplexEntryTestGenerator( + MultiIndexMixin, LogicPhenotypeComplexEntryTestGenerator +): + name_space = "mi_lgpt_complex_entry" + _index_date = datetime.date(2022, 9, 20) + + def define_input_tables(self): + tables = LogicPhenotypeComplexEntryTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = LogicPhenotypeComplexEntryTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_logic_phenotype_complex_entry(): + tg = MultiIndexLogicPhenotypeComplexEntryTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_measurement_change_phenotype.py b/phenex/test/phenotypes/multi_index/test_measurement_change_phenotype.py new file mode 100644 index 00000000..1c22da58 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_measurement_change_phenotype.py @@ -0,0 +1,76 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_measurement_change_phenotype import ( + MeasurementChangeIncreaseDecreasePhenotypeTestGenerator, + MeasurementChangePhenotypeRelativeTimeRangeTestGenerator, +) + + +class MultiIndexMeasurementChangeIncreaseDecreaseTestGenerator( + MultiIndexMixin, MeasurementChangeIncreaseDecreasePhenotypeTestGenerator +): + name_space = "mi_mcpt_increasedecrease" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ( + MeasurementChangeIncreaseDecreasePhenotypeTestGenerator.define_input_tables( + self + ) + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = MeasurementChangeIncreaseDecreasePhenotypeTestGenerator.define_phenotype_tests( + self + ) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexMeasurementChangeRelativeTimeRangeTestGenerator( + MultiIndexMixin, MeasurementChangePhenotypeRelativeTimeRangeTestGenerator +): + name_space = "mi_mcpt_relativetimerange" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = MeasurementChangePhenotypeRelativeTimeRangeTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = MeasurementChangePhenotypeRelativeTimeRangeTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # With 90-day shift (→ 2022-04-01), all events (Dec 27 – Jan 9) are + # before the shifted index. Post-index tests find nothing; pre-index + # tests gain the persons whose events were originally after index. + shifted_persons = { + "mmcpt": [], # post-index: no events after shifted + "mmcpt_2": [], # post-index: no events after shifted + "mmcpt_3": ["P0", "P3"], # pre-index: P0 (1d apart) + P3 (original) + "mmcpt_4": ["P0", "P1", "P3", "P4", "P6"], # pre-index: ≤2d apart + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +def test_multiindex_measurement_change_increase_decrease(): + tg = MultiIndexMeasurementChangeIncreaseDecreaseTestGenerator() + tg.run_tests() + + +def test_multiindex_measurement_change_relative_time_range(): + tg = MultiIndexMeasurementChangeRelativeTimeRangeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_measurement_phenotype.py b/phenex/test/phenotypes/multi_index/test_measurement_phenotype.py new file mode 100644 index 00000000..705c3e78 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_measurement_phenotype.py @@ -0,0 +1,30 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_measurement_phenotype import ( + MeasurementPhenotypeRelativeTimeRangeFilterTestGenerator, +) + + +class MultiIndexMeasurementPhenotypeRelativeTimeRangeTestGenerator( + MultiIndexMixin, MeasurementPhenotypeRelativeTimeRangeFilterTestGenerator +): + name_space = "mi_mmpt_relativetimerange" + _index_date = datetime.date(2022, 1, 2) + + def define_input_tables(self): + tables = MeasurementPhenotypeRelativeTimeRangeFilterTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = MeasurementPhenotypeRelativeTimeRangeFilterTestGenerator.define_phenotype_tests( + self + ) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_measurement_phenotype_relative_time_range(): + tg = MultiIndexMeasurementPhenotypeRelativeTimeRangeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_relative_time_range_filter.py b/phenex/test/phenotypes/multi_index/test_relative_time_range_filter.py new file mode 100644 index 00000000..8b194d4b --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_relative_time_range_filter.py @@ -0,0 +1,51 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_relative_time_range_filter import ( + CodelistPhenotypeRelativeTimeRangeFilterTestGenerator, +) + + +class MultiIndexRelativeTimeRangeFilterTestGenerator( + MultiIndexMixin, CodelistPhenotypeRelativeTimeRangeFilterTestGenerator +): + name_space = "mi_rtrf" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ( + CodelistPhenotypeRelativeTimeRangeFilterTestGenerator.define_input_tables( + self + ) + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = CodelistPhenotypeRelativeTimeRangeFilterTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + shifted_persons = { + "max_days_leq_180": ["P1", "P2", "P6", "P7", "P8", "P10", "P11"], + "max_days_lt_180": ["P2", "P6", "P7", "P8", "P10", "P11"], + "min_days_geq_90_max_days_leq_180": ["P1", "P2", "P6", "P7"], + "after_max_days_leq_180": ["P9", "P10", "P12", "P13", "P14"], + "after_max_days_g_90_max_days_leq_180": ["P12"], + "range_min_gn90_max_g90": ["P8", "P9", "P10", "P11", "P14"], + "range_min_gn90_max_ge180": ["P8", "P9", "P10", "P11", "P12", "P13", "P14"], + } + + for test in tests: + orig = list(test["persons"]) + shifted = shifted_persons[test["name"]] + test["persons"] = orig + shifted + test["index_dates"] = [idx1] * len(orig) + [idx2] * len(shifted) + + return tests + + +def test_multiindex_relative_time_range_filter(): + tg = MultiIndexRelativeTimeRangeFilterTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_score_phenotype.py b/phenex/test/phenotypes/multi_index/test_score_phenotype.py new file mode 100644 index 00000000..464783f2 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_score_phenotype.py @@ -0,0 +1,24 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_score_phenotype import ScorePhenotypeTestGenerator + + +class MultiIndexScorePhenotypeTestGenerator( + MultiIndexMixin, ScorePhenotypeTestGenerator +): + name_space = "mi_scpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ScorePhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = ScorePhenotypeTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_score_phenotype(): + tg = MultiIndexScorePhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_time_range_count_phenotype.py b/phenex/test/phenotypes/multi_index/test_time_range_count_phenotype.py new file mode 100644 index 00000000..678bc2ec --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_time_range_count_phenotype.py @@ -0,0 +1,49 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_time_range_count_phenotype import ( + TimeRangeCountPhenotypeTestGenerator, +) + + +class MultiIndexTimeRangeCountPhenotypeTestGenerator( + MultiIndexMixin, TimeRangeCountPhenotypeTestGenerator +): + name_space = "mi_trcpt" + _index_date = datetime.date(2020, 5, 15) + + def define_input_tables(self): + tables = TimeRangeCountPhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangeCountPhenotypeTestGenerator.define_phenotype_tests(self) + idx1 = self._index_date + idx2 = idx1 + self._shift + + # At shifted index (2020-08-13), p1-p4 are all "before", only p5 is "after" (19 days after). + shifted = { + "count_all_visits": {"persons": ["P1", "P2"], "values": [5, 1]}, + "count_visits_after_index": {"persons": ["P1", "P2"], "values": [1, 0]}, + "count_visits_before_index": {"persons": ["P1", "P2"], "values": [4, 1]}, + "max_days_set": {"persons": ["P1", "P2"], "values": [1, 0]}, + "min_days_set": {"persons": ["P1", "P2"], "values": [0, 0]}, + "min_days_set_with_value_filter": {"persons": [], "values": []}, + "value_filter": {"persons": ["P1"], "values": [5]}, + "date_range_filter": {"persons": ["P1", "P2"], "values": [3, 0]}, + } + + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + s = shifted[test["name"]] + test["persons"] = orig_p + s["persons"] + test["values"] = orig_v + s["values"] + test["index_dates"] = [idx1] * len(orig_p) + [idx2] * len(s["persons"]) + + return tests + + +def test_multiindex_time_range_count_phenotype(): + tg = MultiIndexTimeRangeCountPhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_time_range_day_count_phenotype.py b/phenex/test/phenotypes/multi_index/test_time_range_day_count_phenotype.py new file mode 100644 index 00000000..70b46abb --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_time_range_day_count_phenotype.py @@ -0,0 +1,57 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_time_range_day_count_phenotype import ( + TimeRangeDayCountPhenotypeTestGenerator, +) + + +class MultiIndexTimeRangeDayCountPhenotypeTestGenerator( + MultiIndexMixin, TimeRangeDayCountPhenotypeTestGenerator +): + name_space = "mi_trdcpt" + _index_date = datetime.date(2020, 5, 15) + + def define_input_tables(self): + tables = TimeRangeDayCountPhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangeDayCountPhenotypeTestGenerator.define_phenotype_tests(self) + idx1 = self._index_date + idx2 = idx1 + self._shift + + # At shifted index (2020-08-13), p1-p4 are entirely before shifted, + # only p5 (Sep 1-30) is after (starts 19 days after shifted). + # Day-level clipping applies for min/max_days constraints. + shifted = { + "count_all_days": {"persons": ["P1", "P2"], "values": [150, 30]}, + "count_days_before_index": {"persons": ["P1", "P2"], "values": [120, 30]}, + "count_days_after_index": {"persons": ["P1", "P2"], "values": [30, 0]}, + "count_days_after_min30": {"persons": ["P1", "P2"], "values": [19, 0]}, + "count_days_after_max90": {"persons": ["P1", "P2"], "values": [30, 0]}, + "count_days_after_30to90": {"persons": ["P1", "P2"], "values": [19, 0]}, + "count_days_min100": {"persons": ["P1"], "values": [150]}, + "count_days_before_min30": {"persons": ["P1", "P2"], "values": [104, 30]}, + "date_range_max_end_date": {"persons": ["P1", "P2"], "values": [45, 30]}, + "date_range_min_start_date": {"persons": ["P1", "P2"], "values": [46, 0]}, + "date_range_combined_start_and_end": { + "persons": ["P1", "P2"], + "values": [61, 0], + }, + } + + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + s = shifted[test["name"]] + test["persons"] = orig_p + s["persons"] + test["values"] = orig_v + s["values"] + test["index_dates"] = [idx1] * len(orig_p) + [idx2] * len(s["persons"]) + + return tests + + +def test_multiindex_time_range_day_count_phenotype(): + tg = MultiIndexTimeRangeDayCountPhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_time_range_days_to_next_range_phenotype.py b/phenex/test/phenotypes/multi_index/test_time_range_days_to_next_range_phenotype.py new file mode 100644 index 00000000..d61d8e79 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_time_range_days_to_next_range_phenotype.py @@ -0,0 +1,44 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_time_range_days_to_next_range_phenotype import ( + TimeRangeDaysToNextRangePhenotypeTestGenerator, +) + + +class MultiIndexTimeRangeDaysToNextRangePhenotypeTestGenerator( + MultiIndexMixin, TimeRangeDaysToNextRangePhenotypeTestGenerator +): + name_space = "mi_trdtnrp" + _index_date = datetime.date(2022, 1, 15) + + def define_input_tables(self): + tables = TimeRangeDaysToNextRangePhenotypeTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangeDaysToNextRangePhenotypeTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + + # At shifted index (2022-04-15), no visit ranges contain the shifted + # index date, so no anchor range is found and all tests produce empty + # results at the shifted index. + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + orig_d = list(test["dates"]) + test["persons"] = orig_p + test["values"] = orig_v + test["dates"] = orig_d + test["index_dates"] = [idx1] * len(orig_p) + + return tests + + +def test_multiindex_time_range_days_to_next_range_phenotype(): + tg = MultiIndexTimeRangeDaysToNextRangePhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_time_range_phenotype.py b/phenex/test/phenotypes/multi_index/test_time_range_phenotype.py new file mode 100644 index 00000000..650e03e1 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_time_range_phenotype.py @@ -0,0 +1,271 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_time_range_phenotype import ( + TimeRangePhenotypeTestGenerator, + ContinuousCoverageReturnLastPhenotypeTestGenerator, + TimeRangePhenotypeWithDateRangeBeforeAllExcludedTestGenerator, + TimeRangePhenotypeWithDateRangeBeforeReducedDaysTestGenerator, + TimeRangePhenotypeWithDateRangeAfterAllExcludedTestGenerator, + TimeRangePhenotypeWithDateRangeAfterReducedDaysTestGenerator, +) + + +class MultiIndexTimeRangePhenotypeTestGenerator( + MultiIndexMixin, TimeRangePhenotypeTestGenerator +): + name_space = "mi_ccpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = TimeRangePhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangePhenotypeTestGenerator.define_phenotype_tests(self) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted index (2022-04-01), different observation periods contain + # the index and have different coverage days before it. + shifted = { + "coverage_min_geq_90": { + "persons": ["P15", "P19", "P20", "P22", "P23"], + "values": [180, 179, 90, 90, 90], + }, + "coverage_min_gt_90": { + "persons": ["P15", "P19"], + "values": [180, 179], + }, + } + + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + s = shifted[test["name"]] + test["persons"] = orig_p + s["persons"] + test["values"] = orig_v + s["values"] + test["index_dates"] = [idx1] * len(orig_p) + [idx2] * len(s["persons"]) + + return tests + + +class MultiIndexContinuousCoverageReturnLastTestGenerator( + MultiIndexMixin, ContinuousCoverageReturnLastPhenotypeTestGenerator +): + name_space = "mi_ccpt_returnlast" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = ContinuousCoverageReturnLastPhenotypeTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = ( + ContinuousCoverageReturnLastPhenotypeTestGenerator.define_phenotype_tests( + self + ) + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + def end_date_for(person_id): + return self.df_input[self.df_input["PERSON_ID"] == person_id][ + "END_DATE" + ].values[0] + + # At shifted index (2022-04-01), only periods extending ≥90 days + # after Apr 1 qualify. + shifted = { + "coverage_min_geq_90": { + "persons": ["P23", "P27"], + "values": [90, 91], + "dates": [end_date_for("P23"), end_date_for("P27")], + }, + "coverage_min_gt_90": { + "persons": ["P27"], + "values": [91], + "dates": [end_date_for("P27")], + }, + } + + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + orig_d = list(test["dates"]) + s = shifted[test["name"]] + test["persons"] = orig_p + s["persons"] + test["values"] = orig_v + s["values"] + test["dates"] = orig_d + s["dates"] + test["index_dates"] = [idx1] * len(orig_p) + [idx2] * len(s["persons"]) + + return tests + + +class MultiIndexTimeRangeBeforeAllExcludedTestGenerator( + MultiIndexMixin, TimeRangePhenotypeWithDateRangeBeforeAllExcludedTestGenerator +): + name_space = "mi_ccpt_daterange_before_all_excluded" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = TimeRangePhenotypeWithDateRangeBeforeAllExcludedTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangePhenotypeWithDateRangeBeforeAllExcludedTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted index (2022-04-01), min_date clips start to 2022-01-02. + # Periods containing Apr 1 with clipped start ≤ Apr 1 all qualify. + shifted_persons = [ + "P15", + "P19", + "P20", + "P22", + "P23", + "P24", + "P25", + "P26", + "P27", + ] + + for test in tests: + orig_p = list(test["persons"]) + n = len(orig_p) + test["persons"] = orig_p + shifted_persons + test["index_dates"] = [idx1] * n + [idx2] * len(shifted_persons) + + return tests + + +class MultiIndexTimeRangeBeforeReducedDaysTestGenerator( + MultiIndexMixin, TimeRangePhenotypeWithDateRangeBeforeReducedDaysTestGenerator +): + name_space = "mi_ccpt_daterange_before_reduced_days" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = TimeRangePhenotypeWithDateRangeBeforeReducedDaysTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangePhenotypeWithDateRangeBeforeReducedDaysTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted index (2022-04-01), min_date clips start to max(orig, 2021-11-01). + # P15/P19 clipped to Nov 1 → 151 days. P20-P23 start Jan 1 → 90 days. + # P24-P27 start Jan 2 → 89 days. + shifted_persons = [ + "P15", + "P19", + "P20", + "P22", + "P23", + "P24", + "P25", + "P26", + "P27", + ] + shifted_values = [151, 151, 90, 90, 90, 89, 89, 89, 89] + + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + n = len(orig_p) + test["persons"] = orig_p + shifted_persons + test["values"] = orig_v + shifted_values + test["index_dates"] = [idx1] * n + [idx2] * len(shifted_persons) + + return tests + + +class MultiIndexTimeRangeAfterAllExcludedTestGenerator( + MultiIndexMixin, TimeRangePhenotypeWithDateRangeAfterAllExcludedTestGenerator +): + name_space = "mi_ccpt_daterange_after_all_excluded" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = TimeRangePhenotypeWithDateRangeAfterAllExcludedTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangePhenotypeWithDateRangeAfterAllExcludedTestGenerator.define_phenotype_tests( + self + ) + return self._duplicate_expected(tests, self._index_date) + + +class MultiIndexTimeRangeAfterReducedDaysTestGenerator( + MultiIndexMixin, TimeRangePhenotypeWithDateRangeAfterReducedDaysTestGenerator +): + name_space = "mi_ccpt_daterange_after_reduced_days" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = TimeRangePhenotypeWithDateRangeAfterReducedDaysTestGenerator.define_input_tables( + self + ) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = TimeRangePhenotypeWithDateRangeAfterReducedDaysTestGenerator.define_phenotype_tests( + self + ) + idx1 = self._index_date + idx2 = self._index_date + self.shift + + # At shifted index (2022-04-01), max_date clips end to 2022-02-15. + # Clipped end < shifted index → no periods contain shifted index → empty. + for test in tests: + orig_p = list(test["persons"]) + orig_v = list(test["values"]) + n = len(orig_p) + test["index_dates"] = [idx1] * n + + return tests + + +def test_multiindex_time_range_phenotype(): + tg = MultiIndexTimeRangePhenotypeTestGenerator() + tg.run_tests() + + +def test_multiindex_continuous_coverage_return_last(): + tg = MultiIndexContinuousCoverageReturnLastTestGenerator() + tg.run_tests() + + +def test_multiindex_time_range_before_all_excluded(): + tg = MultiIndexTimeRangeBeforeAllExcludedTestGenerator() + tg.run_tests() + + +def test_multiindex_time_range_before_reduced_days(): + tg = MultiIndexTimeRangeBeforeReducedDaysTestGenerator() + tg.run_tests() + + +def test_multiindex_time_range_after_all_excluded(): + tg = MultiIndexTimeRangeAfterAllExcludedTestGenerator() + tg.run_tests() + + +def test_multiindex_time_range_after_reduced_days(): + tg = MultiIndexTimeRangeAfterReducedDaysTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_user_defined_phenotype.py b/phenex/test/phenotypes/multi_index/test_user_defined_phenotype.py new file mode 100644 index 00000000..d519565f --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_user_defined_phenotype.py @@ -0,0 +1,26 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_user_defined_phenotype import ( + UserDefinedPhenotypeTestGenerator, +) + + +class MultiIndexUserDefinedPhenotypeTestGenerator( + MultiIndexMixin, UserDefinedPhenotypeTestGenerator +): + name_space = "mi_udpt" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = UserDefinedPhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = UserDefinedPhenotypeTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_user_defined_phenotype(): + tg = MultiIndexUserDefinedPhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index/test_within_same_encounter_phenotype.py b/phenex/test/phenotypes/multi_index/test_within_same_encounter_phenotype.py new file mode 100644 index 00000000..b3eca671 --- /dev/null +++ b/phenex/test/phenotypes/multi_index/test_within_same_encounter_phenotype.py @@ -0,0 +1,26 @@ +import datetime + +from phenex.test.phenotypes.multi_index_mixin import MultiIndexMixin +from phenex.test.phenotypes.test_within_same_encounter_phenotype import ( + WithinSameEncounterPhenotypeTestGenerator, +) + + +class MultiIndexWithinSameEncounterPhenotypeTestGenerator( + MultiIndexMixin, WithinSameEncounterPhenotypeTestGenerator +): + name_space = "mi_wsept" + _index_date = datetime.date(2022, 1, 1) + + def define_input_tables(self): + tables = WithinSameEncounterPhenotypeTestGenerator.define_input_tables(self) + return self._duplicate_input_tables(tables) + + def define_phenotype_tests(self): + tests = WithinSameEncounterPhenotypeTestGenerator.define_phenotype_tests(self) + return self._duplicate_expected(tests, self._index_date) + + +def test_multiindex_within_same_encounter_phenotype(): + tg = MultiIndexWithinSameEncounterPhenotypeTestGenerator() + tg.run_tests() diff --git a/phenex/test/phenotypes/multi_index_mixin.py b/phenex/test/phenotypes/multi_index_mixin.py new file mode 100644 index 00000000..6fe4dcc6 --- /dev/null +++ b/phenex/test/phenotypes/multi_index_mixin.py @@ -0,0 +1,131 @@ +""" +MultiIndexMixin – duplicates test input data with a second INDEX_DATE +and adapts the test runner to verify results per (PERSON_ID, INDEX_DATE). + +Each test class can set ``_shift`` to control the INDEX_DATE offset. +The module-level ``SHIFT`` (90 days) is the default. +""" + +import datetime +import os + +import pandas as pd +import ibis + +from phenex.test.util.check_equality import check_equality + + +SHIFT = datetime.timedelta(days=90) + + +class MultiIndexMixin: + """Mixin that duplicates input data with a second INDEX_DATE and adapts + the test runner to verify results per (PERSON_ID, INDEX_DATE). + + Set ``_shift`` on the subclass to override the default 90-day shift. + """ + + _shift = SHIFT # subclasses may override + + @property + def shift(self): + return self._shift + + def _duplicate_input_tables(self, input_tables): + """Duplicate all rows with a second INDEX_DATE shifted by self.shift. + + Tables without INDEX_DATE get the column added (using self._index_date) + before duplication so that every domain table carries the index. + """ + result = [] + for info in input_tables: + df = info["df"].copy() + if "INDEX_DATE" not in df.columns: + df["INDEX_DATE"] = self._index_date + df2 = df.copy() + df2["INDEX_DATE"] = pd.to_datetime(df2["INDEX_DATE"]) + self.shift + combined = pd.concat([df, df2], ignore_index=True) + result.append({**info, "df": combined}) + return result + + def _duplicate_expected(self, test_infos, index_date): + """Duplicate expected persons / dates / values for the second INDEX_DATE.""" + for info in test_infos: + n = len(info["persons"]) + info["index_dates"] = [index_date] * n + [index_date + self.shift] * n + info["persons"] = list(info["persons"]) * 2 + if info.get("dates") is not None: + info["dates"] = list(info["dates"]) * 2 + if info.get("values") is not None: + info["values"] = list(info["values"]) * 2 + return test_infos + + def _run_tests(self): + """Override: includes INDEX_DATE in expected output and join predicates.""" + + def df_from_test_info(test_info): + df = pd.DataFrame() + df["PERSON_ID"] = test_info["persons"] + df["INDEX_DATE"] = test_info["index_dates"] + df["boolean"] = True + if test_info.get("dates") is not None: + df["EVENT_DATE"] = test_info["dates"] + else: + df["EVENT_DATE"] = None + if test_info.get("values") is not None: + df["VALUE"] = test_info["values"] + else: + df["VALUE"] = None + return df + + self.test_infos = self.define_phenotype_tests() + + for test_info in self.test_infos: + df = df_from_test_info(test_info) + filename = self.name_output_file(test_info) + ".csv" + path = os.path.join(self.dirpaths["expected"], filename) + df.sort_values(by=["PERSON_ID", "INDEX_DATE"]).to_csv( + path, index=False, date_format=self.date_format + ) + + result_table = test_info["phenotype"].execute(self.domains) + + if self.verbose: + ibis.options.interactive = True + print(f"Running test: {test_info['name']}") + print(f"Expected:\n{df}") + print(f"Result:\n{result_table.to_pandas()}") + + path = os.path.join(self.dirpaths["result"], filename) + result_table.to_pandas().sort_values(by=["PERSON_ID", "INDEX_DATE"]).to_csv( + path, index=False, date_format=self.date_format + ) + + schema = {} + for col in df.columns: + if "date" in col.lower(): + schema[col] = datetime.date + elif "value" in col.lower(): + schema[col] = self.value_datatype + elif "boolean" in col.lower(): + schema[col] = bool + else: + schema[col] = str + + expected_output_table = self.con.create_table( + self.name_output_file(test_info), df, schema=schema + ) + + join_on = ["PERSON_ID", "INDEX_DATE"] + if self.test_values: + join_on.append("VALUE") + if self.test_date: + join_on.append("EVENT_DATE") + check_equality( + result_table, + expected_output_table, + test_name=test_info["name"], + test_values=self.test_values, + test_date=self.test_date, + join_on=join_on, + ) diff --git a/phenex/test/phenotypes/test_arithmetic_phenotype.py b/phenex/test/phenotypes/test_arithmetic_phenotype.py index 76486b85..e35b4da5 100644 --- a/phenex/test/phenotypes/test_arithmetic_phenotype.py +++ b/phenex/test/phenotypes/test_arithmetic_phenotype.py @@ -50,9 +50,12 @@ def define_input_tables(self): event_date_columnname="EVENT_DATE", ) df["VALUE"] = range(df.shape[0]) + index_date = datetime.date(2022, 1, 1) + df["INDEX_DATE"] = index_date df_person = pd.DataFrame() df_person["PERSON_ID"] = list(df["PERSON_ID"].unique()) + df_person["INDEX_DATE"] = index_date return [ {"name": "measurement", "df": df}, @@ -185,9 +188,13 @@ def define_input_tables(self): ["c1"] * n_p1_c1 + ["c2"] * n_p1_c2 + ["c1"] * n_p2_c1 + ["c2"] * n_p2_c2 ) df["CODE_TYPE"] = ["ICD10CM"] * df.shape[0] + index_date = datetime.date(2022, 1, 1) + df["EVENT_DATE"] = index_date + df["INDEX_DATE"] = index_date df_person = pd.DataFrame() df_person["PERSON_ID"] = list(df["PERSON_ID"].unique()) + df_person["INDEX_DATE"] = index_date return [ {"name": "condition_occurrence", "df": df}, @@ -247,9 +254,12 @@ def define_input_tables(self): event_date_columnname="EVENT_DATE", ) df["VALUE"] = range(df.shape[0]) + index_date = datetime.date(2022, 1, 1) + df["INDEX_DATE"] = index_date df_person = pd.DataFrame() df_person["PERSON_ID"] = list(df["PERSON_ID"].unique()) + df_person["INDEX_DATE"] = index_date return [ {"name": "measurement", "df": df}, diff --git a/phenex/test/phenotypes/test_further_value_filter_phenotype.py b/phenex/test/phenotypes/test_further_value_filter_phenotype.py new file mode 100644 index 00000000..4fc8b19d --- /dev/null +++ b/phenex/test/phenotypes/test_further_value_filter_phenotype.py @@ -0,0 +1,444 @@ +import datetime, os +import pandas as pd +import copy + +from phenex.filters.value import ( + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, +) +from phenex.phenotypes.measurement_phenotype import MeasurementPhenotype +from phenex.phenotypes.further_value_filter_phenotype import FurtherValueFilterPhenotype +from phenex.codelists import LocalCSVCodelistFactory +from phenex.filters.value_filter import ValueFilter +from phenex.filters.date_filter import DateFilter, After, Before +from phenex.aggregators import * +from phenex.filters.relative_time_range_filter import RelativeTimeRangeFilter +from phenex.test.phenotype_test_generator import PhenotypeTestGenerator + + +class FurtherValueFilterBasicTestGenerator(PhenotypeTestGenerator): + """Test basic value filtering on the output of a MeasurementPhenotype.""" + + name_space = "fvf_basic" + test_values = True + + def define_input_tables(self): + df = pd.DataFrame() + N = 10 + df["VALUE"] = list(range(N)) + df["PERSON_ID"] = [f"P{x}" for x in range(N)] + df["CODE"] = "c1" + df["CODE_TYPE"] = "ICD10CM" + df["EVENT_DATE"] = None + + return [ + { + "name": "MEASUREMENT", + "df": df, + } + ] + + def define_phenotype_tests(self): + codelist_factory = LocalCSVCodelistFactory( + path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") + ) + source_phenotype = MeasurementPhenotype( + name="leq9", + codelist=codelist_factory.get_codelist("c1"), + domain="MEASUREMENT", + value_filter=ValueFilter(max_value=LessThanOrEqualTo(9)), + ) + + c1 = { + "name": "further_filter_l2", + "persons": [f"P{x}" for x in range(2)], + "values": [x for x in range(2)], + "phenotype": FurtherValueFilterPhenotype( + name="further_filter_l2", + phenotype=source_phenotype, + value_filter=ValueFilter(max_value=LessThan(2)), + ), + } + + c2 = { + "name": "further_filter_geq3_leq6", + "persons": [f"P{x}" for x in range(3, 7)], + "values": [x for x in range(3, 7)], + "phenotype": FurtherValueFilterPhenotype( + name="further_filter_geq3_leq6", + phenotype=source_phenotype, + value_filter=ValueFilter( + min_value=GreaterThanOrEqualTo(3), + max_value=LessThanOrEqualTo(6), + ), + ), + } + + test_infos = [c1, c2] + return test_infos + + +class FurtherValueFilterAggregationTestGenerator(PhenotypeTestGenerator): + """Test value aggregation on the output of a MeasurementPhenotype.""" + + name_space = "fvf_aggregation" + test_values = True + + def define_input_tables(self): + def create_copy_with_changes(df, lab, date): + _df = copy.deepcopy(df) + _df["VALUE"] = lab + _df["EVENT_DATE"] = date + return _df + + d1 = datetime.date(2022, 1, 1) + d2 = datetime.date(2022, 1, 2) + + df_base = pd.DataFrame() + self.N = 3 + df_base["PERSON_ID"] = [f"P{x}" for x in range(self.N)] + df_base["CODE"] = "c1" + df_base["CODE_TYPE"] = "ICD10CM" + df_base["VALUE"] = 2 + df_base["EVENT_DATE"] = d1 + + # d1: values 2, 4 d2: values 6, 8 + df_d1_2 = create_copy_with_changes(df_base, 4, d1) + df_d2_1 = create_copy_with_changes(df_base, 6, d2) + df_d2_2 = create_copy_with_changes(df_base, 8, d2) + + df_final = pd.concat([df_base, df_d1_2, df_d2_1, df_d2_2]) + df_final["INDEX_DATE"] = datetime.date(2022, 1, 3) + + return [ + { + "name": "MEASUREMENT", + "df": df_final, + } + ] + + def define_phenotype_tests(self): + codelist_factory = LocalCSVCodelistFactory( + path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") + ) + source_phenotype = MeasurementPhenotype( + name="all_values", + codelist=codelist_factory.get_codelist("c1"), + domain="MEASUREMENT", + ) + + # Mean of all values: (2+4+6+8)/4 = 5.0 + c_mean = { + "name": "further_mean", + "persons": [f"P{x}" for x in range(self.N)], + "values": [5.0] * self.N, + "phenotype": FurtherValueFilterPhenotype( + name="further_mean", + phenotype=source_phenotype, + value_aggregation=Mean(), + ), + } + + # Max of all values: 8 + c_max = { + "name": "further_max", + "persons": [f"P{x}" for x in range(self.N)], + "values": [8] * self.N, + "phenotype": FurtherValueFilterPhenotype( + name="further_max", + phenotype=source_phenotype, + value_aggregation=Max(), + ), + } + + # Min of all values: 2 + c_min = { + "name": "further_min", + "persons": [f"P{x}" for x in range(self.N)], + "values": [2] * self.N, + "phenotype": FurtherValueFilterPhenotype( + name="further_min", + phenotype=source_phenotype, + value_aggregation=Min(), + ), + } + + # DailyMean then filter > 5: d1 mean=3, d2 mean=7 → only d2 mean=7 passes + c_daily_mean_then_filter = { + "name": "further_daily_mean_gt5", + "persons": [f"P{x}" for x in range(self.N)], + "values": [7.0] * self.N, + "phenotype": FurtherValueFilterPhenotype( + name="further_daily_mean_gt5", + phenotype=source_phenotype, + value_aggregation=DailyMean(), + value_filter=ValueFilter(min_value=GreaterThan(5)), + ), + } + + test_infos = [c_mean, c_max, c_min, c_daily_mean_then_filter] + return test_infos + + +class FurtherValueFilterDateRangeTestGenerator(PhenotypeTestGenerator): + """Test date_range filtering on the output of a MeasurementPhenotype.""" + + name_space = "fvf_daterange" + test_values = True + + def define_input_tables(self): + d1 = datetime.date(2022, 1, 1) + d2 = datetime.date(2022, 6, 1) + d3 = datetime.date(2022, 12, 1) + + df = pd.DataFrame() + N = 3 + df["PERSON_ID"] = [f"P{x}" for x in range(N)] + df["CODE"] = "c1" + df["CODE_TYPE"] = "ICD10CM" + df["VALUE"] = [10, 20, 30] + df["EVENT_DATE"] = [d1, d2, d3] + + return [ + { + "name": "MEASUREMENT", + "df": df, + } + ] + + def define_phenotype_tests(self): + codelist_factory = LocalCSVCodelistFactory( + path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") + ) + source_phenotype = MeasurementPhenotype( + name="all_measurements", + codelist=codelist_factory.get_codelist("c1"), + domain="MEASUREMENT", + ) + + # Filter to only events after 2022-03-01 → P1 (June) and P2 (Dec) + c1 = { + "name": "after_march", + "persons": ["P1", "P2"], + "values": [20, 30], + "phenotype": FurtherValueFilterPhenotype( + name="after_march", + phenotype=source_phenotype, + date_range=DateFilter(min_date=After("2022-03-01")), + ), + } + + # Filter to only events before 2022-07-01 → P0 (Jan) and P1 (June) + c2 = { + "name": "before_july", + "persons": ["P0", "P1"], + "values": [10, 20], + "phenotype": FurtherValueFilterPhenotype( + name="before_july", + phenotype=source_phenotype, + date_range=DateFilter(max_date=Before("2022-07-01")), + ), + } + + # Date range + value filter: after March AND value > 25 → only P2 (Dec, value=30) + c3 = { + "name": "after_march_gt25", + "persons": ["P2"], + "values": [30], + "phenotype": FurtherValueFilterPhenotype( + name="after_march_gt25", + phenotype=source_phenotype, + date_range=DateFilter(min_date=After("2022-03-01")), + value_filter=ValueFilter(min_value=GreaterThan(25)), + ), + } + + test_infos = [c1, c2, c3] + return test_infos + + +class FurtherValueFilterRelativeTimeRangeTestGenerator(PhenotypeTestGenerator): + """Test relative_time_range filtering on the output of a MeasurementPhenotype.""" + + name_space = "fvf_relativetimerange" + test_values = True + + def define_input_tables(self): + d_pre = datetime.date(2022, 1, 1) + d_post = datetime.date(2022, 12, 1) + + df = pd.DataFrame() + N = 4 + # Each person has a pre and post measurement + df["PERSON_ID"] = [f"P{x}" for x in range(N)] + [f"P{x}" for x in range(N)] + df["CODE"] = "c1" + df["CODE_TYPE"] = "ICD10CM" + df["VALUE"] = list(range(1, N + 1)) + list(range(11, N + 11)) + df["EVENT_DATE"] = [d_pre] * N + [d_post] * N + df["INDEX_DATE"] = datetime.date(2022, 6, 1) + + return [ + { + "name": "MEASUREMENT", + "df": df, + } + ] + + def define_phenotype_tests(self): + codelist_factory = LocalCSVCodelistFactory( + path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") + ) + source_phenotype = MeasurementPhenotype( + name="all_values", + codelist=codelist_factory.get_codelist("c1"), + domain="MEASUREMENT", + ) + + # Only post-index values (after index date) + c1 = { + "name": "post_index_only", + "persons": [f"P{x}" for x in range(4)], + "values": [11, 12, 13, 14], + "phenotype": FurtherValueFilterPhenotype( + name="post_index_only", + phenotype=source_phenotype, + relative_time_range=RelativeTimeRangeFilter( + min_days=GreaterThanOrEqualTo(0), when="after" + ), + ), + } + + # Only pre-index values (before index date) + c2 = { + "name": "pre_index_only", + "persons": [f"P{x}" for x in range(4)], + "values": [1, 2, 3, 4], + "phenotype": FurtherValueFilterPhenotype( + name="pre_index_only", + phenotype=source_phenotype, + relative_time_range=RelativeTimeRangeFilter( + min_days=GreaterThanOrEqualTo(0), when="before" + ), + ), + } + + # Post-index + value filter > 12 + c3 = { + "name": "post_index_gt12", + "persons": ["P2", "P3"], + "values": [13, 14], + "phenotype": FurtherValueFilterPhenotype( + name="post_index_gt12", + phenotype=source_phenotype, + relative_time_range=RelativeTimeRangeFilter( + min_days=GreaterThanOrEqualTo(0), when="after" + ), + value_filter=ValueFilter(min_value=GreaterThan(12)), + ), + } + + test_infos = [c1, c2, c3] + return test_infos + + +class FurtherValueFilterReturnDateTestGenerator(PhenotypeTestGenerator): + """Test return_date selection on the output of a MeasurementPhenotype.""" + + name_space = "fvf_returndate" + test_values = True + test_date = True + + def define_input_tables(self): + d1 = datetime.date(2022, 1, 1) + d2 = datetime.date(2022, 6, 1) + d3 = datetime.date(2022, 12, 1) + + df = pd.DataFrame() + N = 3 + # Each person has 3 measurements on 3 different dates + df["PERSON_ID"] = [f"P{x}" for x in range(N)] * 3 + df["CODE"] = "c1" + df["CODE_TYPE"] = "ICD10CM" + df["VALUE"] = [10, 20, 30] + [40, 50, 60] + [70, 80, 90] + df["EVENT_DATE"] = [d1] * N + [d2] * N + [d3] * N + + return [ + { + "name": "MEASUREMENT", + "df": df, + } + ] + + def define_phenotype_tests(self): + codelist_factory = LocalCSVCodelistFactory( + path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") + ) + source_phenotype = MeasurementPhenotype( + name="all_values", + codelist=codelist_factory.get_codelist("c1"), + domain="MEASUREMENT", + ) + + # return_date='first' → earliest date for each person + c_first = { + "name": "return_first", + "persons": [f"P{x}" for x in range(3)], + "dates": [datetime.date(2022, 1, 1)] * 3, + "values": [10, 20, 30], + "phenotype": FurtherValueFilterPhenotype( + name="return_first", + phenotype=source_phenotype, + return_date="first", + ), + } + + # return_date='last' → latest date for each person + c_last = { + "name": "return_last", + "persons": [f"P{x}" for x in range(3)], + "dates": [datetime.date(2022, 12, 1)] * 3, + "values": [70, 80, 90], + "phenotype": FurtherValueFilterPhenotype( + name="return_last", + phenotype=source_phenotype, + return_date="last", + ), + } + + test_infos = [c_first, c_last] + return test_infos + + +def test_further_value_filter_basic(): + spg = FurtherValueFilterBasicTestGenerator() + spg.run_tests() + + +def test_further_value_filter_aggregation(): + spg = FurtherValueFilterAggregationTestGenerator() + spg.run_tests() + + +def test_further_value_filter_date_range(): + spg = FurtherValueFilterDateRangeTestGenerator() + spg.run_tests() + + +def test_further_value_filter_relative_time_range(): + spg = FurtherValueFilterRelativeTimeRangeTestGenerator() + spg.run_tests() + + +def test_further_value_filter_return_date(): + spg = FurtherValueFilterReturnDateTestGenerator() + spg.run_tests() + + +if __name__ == "__main__": + test_further_value_filter_basic() + test_further_value_filter_aggregation() + test_further_value_filter_date_range() + test_further_value_filter_relative_time_range() + test_further_value_filter_return_date() diff --git a/phenex/test/phenotypes/test_measurement_phenotype.py b/phenex/test/phenotypes/test_measurement_phenotype.py index 9d7e0394..84ba0852 100644 --- a/phenex/test/phenotypes/test_measurement_phenotype.py +++ b/phenex/test/phenotypes/test_measurement_phenotype.py @@ -875,57 +875,6 @@ def define_phenotype_tests(self): return test_infos -class MeasurementPhenotypeFurtherFilterTestGenerator(PhenotypeTestGenerator): - name_space = "mmpt_furtherfilter" - test_values = True - - def define_input_tables(self): - df = pd.DataFrame() - N = 10 - df["VALUE"] = list(range(N)) - df["PERSON_ID"] = [f"P{x}" for x in range(N)] - df["CODE"] = "c1" - df["CODE_TYPE"] = "ICD10CM" - df["EVENT_DATE"] = None - df["flag"] = ["inpatient"] * 5 + [""] * (10 - 5) - - return [ - { - "name": "MEASUREMENT", - "df": df, - } - ] - - def define_phenotype_tests(self): - codelist_factory = LocalCSVCodelistFactory( - path=os.path.join(os.path.dirname(__file__), "../util/dummy/codelists.csv") - ) - phenotype_to_filter_further = MeasurementPhenotype( - name="leq9", - codelist=codelist_factory.get_codelist("c1"), - domain="MEASUREMENT", - value_filter=ValueFilter(value=9, operator="<="), - ) - - c2 = { - "name": "further_filter_l2", - "persons": [f"P{x}" for x in range(2)], - "values": [x for x in range(2)], - "phenotype": MeasurementPhenotype( - name="further_filter_l2", - value_filter=ValueFilter(value=2, operator="<"), - further_value_filter_phenotype=phenotype_to_filter_further, - ), - } - - test_infos = [c2] - for test_info in test_infos: - test_info["refactor"] = True # TODO remove once refactored - test_info["extra_tests"] = ["unique"] - - return test_infos - - def test_measurement_phenotype(): spg = MeasurementPhenotypeValueFilterTestGenerator() spg.run_tests() diff --git a/phenex/test/phenotypes/test_score_phenotype.py b/phenex/test/phenotypes/test_score_phenotype.py index b81192ca..6b9af5ae 100644 --- a/phenex/test/phenotypes/test_score_phenotype.py +++ b/phenex/test/phenotypes/test_score_phenotype.py @@ -26,8 +26,11 @@ def define_input_tables(self): event_date_columnname="EVENT_DATE", ) + index_date = datetime.date(2022, 1, 1) + df["INDEX_DATE"] = index_date df_person = pd.DataFrame() df_person["PERSON_ID"] = df["PERSON_ID"].unique() + df_person["INDEX_DATE"] = index_date return [ { "name": "CONDITION_OCCURRENCE", @@ -143,8 +146,11 @@ def define_input_tables(self): event_date_columnname="EVENT_DATE", ) + index_date = datetime.date(2022, 1, 1) + df["INDEX_DATE"] = index_date df_person = pd.DataFrame() df_person["PERSON_ID"] = df["PERSON_ID"].unique() + df_person["INDEX_DATE"] = index_date return [ { "name": "CONDITION_OCCURRENCE", diff --git a/phenex/test/phenotypes/test_within_same_encounter_phenotype.py b/phenex/test/phenotypes/test_within_same_encounter_phenotype.py index 53ca48a4..31aa2b2e 100644 --- a/phenex/test/phenotypes/test_within_same_encounter_phenotype.py +++ b/phenex/test/phenotypes/test_within_same_encounter_phenotype.py @@ -40,6 +40,7 @@ def define_input_tables(self): None, "v1", ] + df_proc["INDEX_DATE"] = datetime.date(2022, 1, 1) df_proc["EVENT_DATE"] = [ index_date - one_day, index_date - one_day,