diff --git a/core/src/duckdb/creator.rs b/core/src/duckdb/creator.rs index 12b0a9a5..6fba730b 100644 --- a/core/src/duckdb/creator.rs +++ b/core/src/duckdb/creator.rs @@ -232,6 +232,24 @@ impl TableDefinition { .map(|(name, time_created)| (RelationName(name), time_created)) .collect()) } + + /// Resolve the actual table name for DML operations (DELETE, UPDATE). + /// + /// If the table is backed by a view over an internal + /// `__data_*` table, returns the latest internal table name. + /// Otherwise returns the base table definition name. + /// + /// # Errors + /// + /// Returns an error if the internal tables cannot be listed. + pub fn resolve_dml_table_name(&self, tx: &Transaction<'_>) -> super::Result<String> { + let internal_tables = self.list_internal_tables(tx)?; + if let Some((latest_internal_table_name, _)) = internal_tables.last() { + Ok(latest_internal_table_name.to_string()) + } else { + Ok(self.name.to_string()) + } + } } /// A table creator, which is used to create, delete, and manage tables based on a `TableDefinition`. diff --git a/core/src/duckdb/write.rs b/core/src/duckdb/write.rs index fd1da6e0..667e6184 100644 --- a/core/src/duckdb/write.rs +++ b/core/src/duckdb/write.rs @@ -260,12 +260,12 @@ impl TableProvider for DuckDBTableWriter { } else { Some(filters_to_sql(&filters, Some(expr::Engine::DuckDB))?) 
}; - let table_name = self.table_definition.name().to_string(); + let table_definition = Arc::clone(&self.table_definition); let pool = Arc::clone(&self.pool); Ok(Arc::new(DeletionExec::new(Arc::new(DuckDBDeletionSink { pool, - table_name, + table_definition, sql_where, })))) } @@ -281,26 +281,26 @@ impl TableProvider for DuckDBTableWriter { } let set_clause = assignments_to_sql(&assignments, Some(expr::Engine::DuckDB))?; - let table_name = self.table_definition.name().to_string(); - let pool = Arc::clone(&self.pool); - - let sql = if filters.is_empty() { - format!(r#"UPDATE "{table_name}" SET {set_clause}"#) + let sql_where = if filters.is_empty() { + None } else { - let sql_where = filters_to_sql(&filters, Some(expr::Engine::DuckDB))?; - format!(r#"UPDATE "{table_name}" SET {set_clause} WHERE {sql_where}"#) + Some(filters_to_sql(&filters, Some(expr::Engine::DuckDB))?) }; + let table_definition = Arc::clone(&self.table_definition); + let pool = Arc::clone(&self.pool); Ok(Arc::new(UpdateExec::new(Arc::new(DuckDBUpdateSink { pool, - sql, + table_definition, + set_clause, + sql_where, })))) } } struct DuckDBDeletionSink { pool: Arc<DuckDbConnectionPool>, - table_name: String, + table_definition: Arc<TableDefinition>, sql_where: Option<String>, } @@ -308,7 +308,7 @@ struct DuckDBDeletionSink { impl DeletionSink for DuckDBDeletionSink { async fn delete_from(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> { let pool = Arc::clone(&self.pool); - let table_name = self.table_name.clone(); + let table_definition = Arc::clone(&self.table_definition); let sql_where = self.sql_where.clone(); tokio::task::spawn_blocking( @@ -317,6 +317,8 @@ impl DeletionSink for DuckDBDeletionSink { let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?; let tx = duckdb_conn.conn.transaction()?; + let table_name = table_definition.resolve_dml_table_name(&tx)?; + let delete_sql = if let Some(sql_where) = &sql_where { format!(r#"DELETE FROM "{table_name}" WHERE {sql_where}"#) } else { @@ -335,14 +337,18 @@ struct DuckDBUpdateSink { 
pool: Arc<DuckDbConnectionPool>, - sql: String, + table_definition: Arc<TableDefinition>, + set_clause: String, + sql_where: Option<String>, } #[async_trait] impl UpdateSink for DuckDBUpdateSink { async fn execute_update(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> { let pool = Arc::clone(&self.pool); - let sql = self.sql.clone(); + let table_definition = Arc::clone(&self.table_definition); + let set_clause = self.set_clause.clone(); + let sql_where = self.sql_where.clone(); tokio::task::spawn_blocking( move || -> Result<u64, Box<dyn std::error::Error + Send + Sync>> { @@ -350,6 +356,13 @@ impl UpdateSink for DuckDBUpdateSink { let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn)?; let tx = duckdb_conn.conn.transaction()?; + let table_name = table_definition.resolve_dml_table_name(&tx)?; + + let sql = if let Some(sql_where) = &sql_where { + format!(r#"UPDATE "{table_name}" SET {set_clause} WHERE {sql_where}"#) + } else { + format!(r#"UPDATE "{table_name}" SET {set_clause}"#) + }; let count = tx.execute(&sql, [])?; tx.commit()?; @@ -1871,4 +1884,129 @@ mod test { assert_eq!(name, "all"); } } + + /// Helper: set up a DuckDB table via Overwrite (which creates a view over + /// an internal `__data_*` table), then return `(DuckDBTableWriter, pool)`. + async fn setup_writer_with_overwrite_data( + ids: Vec<i64>, + names: Vec<&str>, + ) -> (DuckDBTableWriter, Arc<DuckDbConnectionPool>) { + let pool = get_mem_duckdb(); + let table_definition = get_basic_table_definition(); + + // Insert seed data via DuckDBDataSink with Overwrite mode. + // This creates an internal __data_* table and a view with the table definition name. 
+ let schema = table_definition.schema(); + let duckdb_sink = DuckDBDataSink::new( + Arc::clone(&pool), + Arc::clone(&table_definition), + InsertOp::Overwrite, + None, + Arc::clone(&schema), + ); + let batches = vec![RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(names)), + ], + ) + .expect("should create a record batch")]; + let stream = Box::pin( + MemoryStream::try_new(batches, Arc::clone(&schema), None).expect("to get stream"), + ); + Arc::new(duckdb_sink) + .write_all(stream, &Arc::new(TaskContext::default())) + .await + .expect("to write all"); + + let mem_table: Arc<dyn TableProvider> = Arc::new( + datafusion::datasource::MemTable::try_new(schema, vec![vec![]]) + .expect("to create mem table"), + ); + let writer = DuckDBTableWriterBuilder::new() + .with_read_provider(mem_table) + .with_pool(Arc::clone(&pool)) + .with_table_definition((*table_definition).clone()) + .build() + .expect("to build writer"); + + (writer, pool) + } + + #[tokio::test] + async fn test_delete_from_view_backed_table_with_filter() { + let _guard = init_tracing(None); + let (writer, pool) = + setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await; + + let ctx = datafusion::prelude::SessionContext::new(); + + // DELETE WHERE id = 2 + let filters = + vec![datafusion::logical_expr::col("id").eq(datafusion::logical_expr::lit(2i64))]; + let plan = writer + .delete_from(&ctx.state(), filters) + .await + .expect("delete_from should succeed"); + + let count = extract_count(plan).await; + assert_eq!(count, 1, "should have deleted exactly 1 row"); + + // Verify remaining rows via the view + let (ids, names) = query_all_rows(&pool); + assert_eq!(ids.len(), 2, "should have 2 rows remaining"); + assert_eq!(ids, vec![1, 3]); + assert_eq!(names, vec!["a", "c"]); + } + + #[tokio::test] + async fn test_delete_from_view_backed_table_empty_filters() { + let _guard = init_tracing(None); + let (writer, pool) = + 
setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await; + + let ctx = datafusion::prelude::SessionContext::new(); + let plan = writer + .delete_from(&ctx.state(), vec![]) + .await + .expect("delete_from should succeed"); + + let count = extract_count(plan).await; + assert_eq!(count, 3, "should have deleted all 3 rows"); + + let (ids, _) = query_all_rows(&pool); + assert!(ids.is_empty(), "table should be empty after delete-all"); + } + + #[tokio::test] + async fn test_update_view_backed_table_with_filter() { + let _guard = init_tracing(None); + let (writer, pool) = + setup_writer_with_overwrite_data(vec![1, 2, 3], vec!["a", "b", "c"]).await; + + let ctx = datafusion::prelude::SessionContext::new(); + + let assignments = vec![("name".to_string(), datafusion::logical_expr::lit("updated"))]; + let filters = + vec![datafusion::logical_expr::col("id").eq(datafusion::logical_expr::lit(2i64))]; + let plan = writer + .update(&ctx.state(), assignments, filters) + .await + .expect("update should succeed"); + + let count = extract_count(plan).await; + assert_eq!(count, 1, "should have updated exactly 1 row"); + + let (ids, names) = query_all_rows(&pool); + assert_eq!(ids.len(), 3, "should still have 3 rows"); + for (id, name) in ids.iter().zip(names.iter()) { + match *id { + 1 => assert_eq!(name, "a"), + 2 => assert_eq!(name, "updated"), + 3 => assert_eq!(name, "c"), + other => panic!("unexpected id {other}"), + } + } + } }