diff --git a/crates/alertd/src/alert.rs b/crates/alertd/src/alert.rs index bdc3986b..5146cebf 100644 --- a/crates/alertd/src/alert.rs +++ b/crates/alertd/src/alert.rs @@ -222,21 +222,40 @@ impl AlertDefinition { return Ok(ControlFlow::Break(())); } TicketSource::Sql { sql, numerical } => { - let client = pool + // Hard cap on rows pulled from the database, bounded at the portal so + // excess rows never reach memory. Numerical thresholds and the `rows` + // template context therefore operate on at most ROW_LIMIT rows. + const ROW_LIMIT: i32 = 100; + + let mut client = pool .get() .await .map_err(|e| miette!("getting connection from pool: {e}"))?; - let statement = client.prepare(sql).await.into_diagnostic()?; + let transaction = client.transaction().await.into_diagnostic()?; + let statement = transaction.prepare(sql).await.into_diagnostic()?; let interval = bestool_postgres::pg_interval::Interval(self.interval_duration); let all_params: Vec<&(dyn ToSql + Sync)> = vec![¬_before, &interval]; - let rows = client - .query(&statement, &all_params[..statement.params().len()]) + let portal = transaction + .bind(&statement, &all_params[..statement.params().len()]) + .await + .into_diagnostic() + .wrap_err("binding query")?; + let rows = transaction + .query_portal(&portal, ROW_LIMIT) .await .into_diagnostic() .wrap_err("querying database")?; + if rows.len() == ROW_LIMIT as usize { + warn!( + ?self.file, + limit = ROW_LIMIT, + "alert SQL result capped at the row limit; excess rows were not loaded" + ); + } + if rows.is_empty() { debug!(?self.file, "no rows returned, skipping"); return Ok(ControlFlow::Break(())); @@ -885,4 +904,32 @@ send: // Match is case-sensitive. assert!(!server_kind_matches(Some("Central"), Some("central"))); } + + #[tokio::test] + async fn sql_source_caps_rows_at_the_limit() { + let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set for tests"); + let pool = bestool_postgres::pool::create_pool(&db_url, "bestool-alertd-test") + .await + .unwrap(); + + let alert = AlertDefinition { + file: "row-limit.yml".into(), + interval_duration: Duration::from_secs(60), + source: TicketSource::Sql { + sql: "SELECT n FROM generate_series(1, 1000) AS n".into(), + numerical: Vec::new(), + }, + ..Default::default() + }; + + let mut context = TeraCtx::new(); + let flow = alert + .read_sources(&pool, Timestamp::now(), &mut context, false) + .await + .unwrap(); + + assert!(matches!(flow, ControlFlow::Continue(()))); + let rows = context.get("rows").expect("rows inserted into context"); + assert_eq!(rows.as_array().map(|a| a.len()), Some(100)); + } }