diff --git a/README.md b/README.md index 5a69385..7407e89 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,9 @@ See the [documentation](https://docs.rs/orc-rust/latest/orc_rust/) for examples ## Supported features -This crate currently only supports reading ORC files into Arrow arrays. Write support is planned -(see [Roadmap](#roadmap)). The below features listed relate only to reading ORC files. -At this time, we aim to support the [ORCv1](https://orc.apache.org/specification/ORCv1/) specification only. +This crate supports reading ORC files into Arrow arrays and writing flat Arrow +`RecordBatch`es to ORC files. At this time, we aim to support the +[ORCv1](https://orc.apache.org/specification/ORCv1/) specification only. - Read synchronously & asynchronously (using Tokio) - All compression types (Zlib, Snappy, Lzo, Lz4, Zstd) @@ -22,6 +22,8 @@ At this time, we aim to support the [ORCv1](https://orc.apache.org/specification - All encodings - Rudimentary support for retrieving statistics - Retrieving user metadata into Arrow schema metadata +- Write Arrow arrays synchronously, with an async writer API for async sinks +- Writer compression, row group indexes, bloom filters, and column statistics ## Roadmap @@ -32,9 +34,8 @@ The following lists the rough roadmap for features to be implemented, from highe - Performance enhancements - Predicate pushdown -- Row indices -- Bloom filters -- Write from Arrow arrays +- Complete row index seek positions +- Nested type write support - Encryption A non-Arrow API interface is not planned at the moment. Feel free to raise an issue if there is such @@ -120,4 +121,3 @@ To regenerate/update the [proto.rs](src/proto.rs) file, execute the [regen.sh](r ```shell ./regen.sh ``` - diff --git a/src/arrow_writer.rs b/src/arrow_writer.rs index e322493..6bb2d4e 100644 --- a/src/arrow_writer.rs +++ b/src/arrow_writer.rs @@ -25,9 +25,11 @@ use prost::Message; use snafu::{ensure, ResultExt}; use crate::{ + compression::{compress_stream, WriterCompression, DEFAULT_COMPRESSION_BLOCK_SIZE}, error::{IoSnafu, Result, UnexpectedSnafu}, memory::EstimateMemory, proto, + writer::index::{root_column_statistics, ColumnStatsBuilder}, writer::stripe::{StripeInformation, StripeWriter}, }; @@ -38,6 +40,10 @@ pub struct ArrowWriterBuilder { schema: SchemaRef, batch_size: usize, stripe_byte_size: usize, + compression: WriterCompression, + compression_block_size: usize, + row_index_stride: Option, + bloom_filters: bool, } impl ArrowWriterBuilder { @@ -50,6 +56,10 @@ impl ArrowWriterBuilder { batch_size: 1024, // 64 MiB stripe_byte_size: 64 * 1024 * 1024, + compression: WriterCompression::None, + compression_block_size: DEFAULT_COMPRESSION_BLOCK_SIZE as usize, + row_index_stride: None, + bloom_filters: false, } } @@ -66,17 +76,83 @@ impl ArrowWriterBuilder { self } + /// Compress ORC streams and metadata with the provided writer codec. + pub fn with_compression(mut self, compression: WriterCompression) -> Self { + self.compression = compression; + self + } + + /// Enable ORC `ZLIB` compression. + /// + /// ORC does not have a separate protobuf value for `GZIP`; common "gzip" + /// writer options map to ORC `ZLIB`, so this is a naming convenience. + pub fn with_gzip_compression(mut self) -> Self { + self.compression = WriterCompression::Zlib; + self + } + + /// The maximum uncompressed size of each ORC compression block. + pub fn with_compression_block_size(mut self, compression_block_size: usize) -> Self { + self.compression_block_size = compression_block_size; + self + } + + /// Enable writer row indexes with `rows_per_group` rows per row group. + pub fn with_row_index_stride(mut self, rows_per_group: usize) -> Self { + self.row_index_stride = Some(rows_per_group); + self + } + + /// Enable writer bloom filter streams for supported primitive columns. + pub fn with_bloom_filters(mut self) -> Self { + self.bloom_filters = true; + self + } + /// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into /// an ORC file. pub fn try_build(mut self) -> Result> { + ensure!( + self.compression_block_size > 0, + UnexpectedSnafu { + msg: "compression block size must be greater than zero" + } + ); + ensure!( + self.row_index_stride.map_or(true, |stride| stride > 0), + UnexpectedSnafu { + msg: "row index stride must be greater than zero" + } + ); + ensure!( + !self.bloom_filters || self.row_index_stride.is_some(), + UnexpectedSnafu { + msg: "bloom filters require row indexes to define row group boundaries" + } + ); + serialize_schema(&self.schema)?; + // Required magic "ORC" bytes at start of file self.writer.write_all(b"ORC").context(IoSnafu)?; - let writer = StripeWriter::new(self.writer, &self.schema); + let writer = StripeWriter::new( + self.writer, + &self.schema, + self.compression, + self.compression_block_size, + self.row_index_stride, + self.bloom_filters, + ); + let file_stats_builders = schema_field_stats_builders(&self.schema); Ok(ArrowWriter { writer, schema: self.schema, batch_size: self.batch_size, stripe_byte_size: self.stripe_byte_size, + compression: self.compression, + compression_block_size: self.compression_block_size, + row_index_stride: self.row_index_stride, + file_stats_builders, + total_rows_written: 0, written_stripes: vec![], // Accounting for the 3 magic bytes above total_bytes_written: 3, @@ -92,6 +168,11 @@ pub struct ArrowWriter { schema: SchemaRef, batch_size: usize, stripe_byte_size: usize, + compression: WriterCompression, + compression_block_size: usize, + row_index_stride: Option, + file_stats_builders: Vec, + total_rows_written: u64, written_stripes: Vec, /// Used to keep track of progress in file so far (instead of needing Seek on the writer) total_bytes_written: u64, @@ -108,6 +189,16 @@ impl ArrowWriter { } ); + self.total_rows_written += batch.num_rows() as u64; + for ((array, field), stats_builder) in batch + .columns() + .iter() + .zip(self.schema.fields().iter()) + .zip(self.file_stats_builders.iter_mut()) + { + stats_builder.update_array(field.data_type(), array); + } + for offset in (0..batch.num_rows()).step_by(self.batch_size) { let length = self.batch_size.min(batch.num_rows() - offset); let batch = batch.slice(offset, length); @@ -133,29 +224,81 @@ impl ArrowWriter { /// Flush the current stripe if it is still in progress, and write the tail /// metadata and close the writer. - pub fn close(mut self) -> Result<()> { + pub fn close(self) -> Result<()> { + self.finish().map(|_| ()) + } + + /// Flush buffered data, write file tail metadata, and return the inner writer. + pub fn finish(mut self) -> Result { // Flush in-progress stripe if self.writer.row_count > 0 { self.flush_stripe()?; } - let footer = serialize_footer(&self.written_stripes, &self.schema); + + let metadata = serialize_metadata(&self.written_stripes); + let metadata = metadata.encode_to_vec(); + let metadata = compress_stream( + bytes::Bytes::from(metadata), + self.compression, + self.compression_block_size, + )?; + let metadata_length = metadata.len() as u64; + + let file_statistics = self.file_column_statistics(); + let footer = serialize_footer( + &self.written_stripes, + &self.schema, + self.row_index_stride, + file_statistics, + )?; let footer = footer.encode_to_vec(); - let postscript = serialize_postscript(footer.len() as u64); + let footer = compress_stream( + bytes::Bytes::from(footer), + self.compression, + self.compression_block_size, + )?; + let postscript = serialize_postscript( + footer.len() as u64, + metadata_length, + self.compression, + self.compression_block_size, + ); let postscript = postscript.encode_to_vec(); let postscript_len = postscript.len() as u8; let mut writer = self.writer.finish(); + writer.write_all(&metadata).context(IoSnafu)?; writer.write_all(&footer).context(IoSnafu)?; writer.write_all(&postscript).context(IoSnafu)?; // Postscript length as last byte writer.write_all(&[postscript_len]).context(IoSnafu)?; - // TODO: return file metadata - Ok(()) + Ok(writer) + } + + fn file_column_statistics(&self) -> Vec { + let mut statistics = Vec::with_capacity(self.file_stats_builders.len() + 1); + statistics.push(root_column_statistics(self.total_rows_written)); + statistics.extend( + self.schema + .fields() + .iter() + .zip(self.file_stats_builders.iter()) + .map(|(field, builder)| builder.finish(field.data_type())), + ); + statistics } } -fn serialize_schema(schema: &SchemaRef) -> Vec { +fn schema_field_stats_builders(schema: &SchemaRef) -> Vec { + schema + .fields() + .iter() + .map(|field| ColumnStatsBuilder::new(field.data_type())) + .collect() +} + +fn serialize_schema(schema: &SchemaRef) -> Result> { let mut types = vec![]; let field_names = schema @@ -213,45 +356,101 @@ fn serialize_schema(schema: &SchemaRef) -> Vec { kind: Some(proto::r#type::Kind::Boolean.into()), ..Default::default() }, + ArrowDataType::Decimal128(precision, scale) => { + ensure!( + *scale >= 0, + UnexpectedSnafu { + msg: "negative decimal scales are not supported by the ORC writer" + } + ); + proto::Type { + kind: Some(proto::r#type::Kind::Decimal.into()), + precision: Some(*precision as u32), + scale: Some(*scale as u32), + ..Default::default() + } + } + ArrowDataType::Date32 => proto::Type { + kind: Some(proto::r#type::Kind::Date.into()), + ..Default::default() + }, + ArrowDataType::Timestamp(_, None) => proto::Type { + kind: Some(proto::r#type::Kind::Timestamp.into()), + ..Default::default() + }, + ArrowDataType::Timestamp(_, Some(tz)) if tz.as_ref() == "UTC" => proto::Type { + kind: Some(proto::r#type::Kind::TimestampInstant.into()), + ..Default::default() + }, + ArrowDataType::Timestamp(_, Some(_)) => { + ensure!( + false, + UnexpectedSnafu { + msg: "only UTC timestamp timezones are supported by the ORC writer" + } + ); + unreachable!() + } // TODO: support more types _ => unimplemented!("unsupported datatype"), }; types.push(t); } - types + Ok(types) } -fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer { +fn serialize_footer( + stripes: &[StripeInformation], + schema: &SchemaRef, + row_index_stride: Option, + statistics: Vec, +) -> Result { let body_length = stripes .iter() .map(|s| s.index_length + s.data_length + s.footer_length) .sum::(); let number_of_rows = stripes.iter().map(|s| s.row_count as u64).sum::(); let stripes = stripes.iter().map(From::from).collect(); - let types = serialize_schema(schema); - proto::Footer { + let types = serialize_schema(schema)?; + Ok(proto::Footer { header_length: Some(3), content_length: Some(body_length + 3), stripes, types, metadata: vec![], number_of_rows: Some(number_of_rows), - statistics: vec![], - row_index_stride: None, + statistics, + row_index_stride: row_index_stride.map(|stride| stride as u32), writer: Some(u32::MAX), encryption: None, calendar: None, software_version: None, + }) +} + +fn serialize_metadata(stripes: &[StripeInformation]) -> proto::Metadata { + proto::Metadata { + stripe_stats: stripes + .iter() + .map(|stripe| proto::StripeStatistics { + col_stats: stripe.column_statistics.clone(), + }) + .collect(), } } -fn serialize_postscript(footer_length: u64) -> proto::PostScript { +fn serialize_postscript( + footer_length: u64, + metadata_length: u64, + compression: WriterCompression, + compression_block_size: usize, +) -> proto::PostScript { proto::PostScript { footer_length: Some(footer_length), - compression: Some(proto::CompressionKind::None.into()), // TODO: support compression - compression_block_size: None, + compression: Some(compression.to_proto().into()), + compression_block_size: (!compression.is_none()).then_some(compression_block_size as u64), version: vec![0, 12], - metadata_length: Some(0), // TODO: statistics + metadata_length: Some(metadata_length), writer_version: Some(u32::MAX), // TODO: check which version to use stripe_statistics_length: None, magic: Some("ORC".to_string()), @@ -264,17 +463,17 @@ mod tests { use arrow::{ array::{ - Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatchReader, - StringArray, + Array, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + LargeStringArray, RecordBatchReader, StringArray, TimestampNanosecondArray, }, buffer::NullBuffer, compute::concat_batches, - datatypes::{DataType as ArrowDataType, Field, Schema}, + datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit}, }; use bytes::Bytes; - use crate::{stripe::Stripe, ArrowReaderBuilder}; + use crate::{statistics::TypeStatistics, stripe::Stripe, ArrowReaderBuilder}; use super::*; @@ -293,6 +492,29 @@ mod tests { reader.collect::, _>>().unwrap() } + fn write_to_bytes( + batch: &RecordBatch, + gzip_compression: bool, + row_index_stride: Option, + bloom_filters: bool, + ) -> Bytes { + let mut f = vec![]; + let mut builder = ArrowWriterBuilder::new(&mut f, batch.schema()); + if gzip_compression { + builder = builder.with_gzip_compression(); + } + if let Some(row_index_stride) = row_index_stride { + builder = builder.with_row_index_stride(row_index_stride); + } + if bloom_filters { + builder = builder.with_bloom_filters(); + } + let mut writer = builder.try_build().unwrap(); + writer.write(batch).unwrap(); + writer.close().unwrap(); + Bytes::from(f) + } + #[test] fn test_roundtrip_write() { let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); @@ -354,6 +576,161 @@ mod tests { assert_eq!(batch, rows[0]); } + #[test] + fn test_roundtrip_write_gzip_compression() { + let array = Arc::new(Int64Array::from((0..1024).collect::>())); + let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap(); + + let f = write_to_bytes(&batch, true, None, false); + let builder = ArrowReaderBuilder::try_new(f).unwrap(); + assert!(builder.file_metadata().compression().is_some()); + + let rows = builder.build().collect::, _>>().unwrap(); + assert_eq!(batch, rows[0]); + } + + #[test] + fn test_write_row_indexes() { + let array = Arc::new(Int64Array::from((0..12).collect::>())); + let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap(); + + let mut f = write_to_bytes(&batch, false, Some(5), false); + let builder = ArrowReaderBuilder::try_new(f.clone()).unwrap(); + assert_eq!(builder.file_metadata().row_index_stride(), Some(5)); + + let stripe = Stripe::new( + &mut f, + builder.file_metadata(), + builder.file_metadata().root_data_type(), + &builder.file_metadata().stripe_metadatas()[0], + ) + .unwrap(); + let row_index = stripe.read_row_indexes(builder.file_metadata()).unwrap(); + let column_index = row_index.column(1).unwrap(); + + assert_eq!(column_index.num_row_groups(), 3); + assert_eq!(row_index.total_rows(), 12); + assert_eq!(row_index.rows_per_group(), 5); + + let stats = column_index.row_group_stats(0).unwrap(); + assert_eq!(stats.number_of_values(), 5); + assert!(!stats.has_null()); + match stats.type_statistics().unwrap() { + TypeStatistics::Integer { min, max, sum } => { + assert_eq!((*min, *max, *sum), (0, 4, Some(10))); + } + other => panic!("expected integer stats, got {other:?}"), + } + } + + #[test] + fn test_write_bloom_filters() { + let array = Arc::new(StringArray::from(vec!["alpha", "beta", "gamma", "delta"])); + let schema = Schema::new(vec![Field::new("name", ArrowDataType::Utf8, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap(); + + let mut f = write_to_bytes(&batch, false, Some(2), true); + let builder = ArrowReaderBuilder::try_new(f.clone()).unwrap(); + let stripe = Stripe::new( + &mut f, + builder.file_metadata(), + builder.file_metadata().root_data_type(), + &builder.file_metadata().stripe_metadatas()[0], + ) + .unwrap(); + let row_index = stripe.read_row_indexes(builder.file_metadata()).unwrap(); + let column_index = row_index.column(1).unwrap(); + let first_group = column_index.entry(0).unwrap(); + let bloom = first_group.bloom_filter.as_ref().unwrap(); + + assert!(bloom.might_contain(b"alpha")); + assert!(!bloom.might_contain(b"definitely-not-present")); + } + + #[test] + fn test_write_file_and_stripe_statistics() { + let array = Arc::new(Int64Array::from((0..12).collect::>())); + let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap(); + + let f = write_to_bytes(&batch, false, None, false); + let builder = ArrowReaderBuilder::try_new(f).unwrap(); + let file_stats = builder.file_metadata().column_file_statistics(); + assert_eq!(file_stats.len(), 2); + assert_eq!(file_stats[0].number_of_values(), 12); + assert_eq!(file_stats[1].number_of_values(), 12); + match file_stats[1].type_statistics().unwrap() { + TypeStatistics::Integer { min, max, sum } => { + assert_eq!((*min, *max, *sum), (0, 11, Some(66))); + } + other => panic!("expected integer stats, got {other:?}"), + } + + let stripe_stats = builder.file_metadata().stripe_metadatas()[0].column_statistics(); + assert_eq!(stripe_stats.len(), 2); + assert_eq!(stripe_stats[0].number_of_values(), 12); + assert_eq!(stripe_stats[1].number_of_values(), 12); + } + + #[test] + fn test_roundtrip_write_decimal_date_timestamp() { + let decimal_array = Arc::new( + Decimal128Array::from(vec![Some(12345), None, Some(-678), Some(0)]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let date_array = Arc::new(Date32Array::from(vec![ + Some(19_358), + None, + Some(0), + Some(-1), + ])); + let timestamp_array = Arc::new(TimestampNanosecondArray::from(vec![ + Some(0), + None, + Some(1_672_531_200_123_456_789), + Some(-1_000_000_000), + ])); + let timestamp_utc_array = Arc::new( + TimestampNanosecondArray::from(vec![ + Some(0), + None, + Some(1_672_531_200_000_000_000), + Some(1_000_000_000), + ]) + .with_timezone("UTC"), + ); + let schema = Schema::new(vec![ + Field::new("decimal", ArrowDataType::Decimal128(10, 2), true), + Field::new("date", ArrowDataType::Date32, true), + Field::new( + "timestamp", + ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "timestamp_utc", + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + decimal_array, + date_array, + timestamp_array, + timestamp_utc_array, + ], + ) + .unwrap(); + + let rows = roundtrip(std::slice::from_ref(&batch)); + assert_eq!(batch, rows[0]); + } + #[test] fn test_roundtrip_write_large_type() { let large_utf8_array = Arc::new(LargeStringArray::from(vec![ diff --git a/src/async_arrow_writer.rs b/src/async_arrow_writer.rs new file mode 100644 index 0000000..ba7e266 --- /dev/null +++ b/src/async_arrow_writer.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Async API for writing Arrow [`RecordBatch`]es into ORC files. + +use arrow::{array::RecordBatch, datatypes::SchemaRef}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::{ + arrow_writer::{ArrowWriter, ArrowWriterBuilder}, + compression::{WriterCompression, DEFAULT_COMPRESSION_BLOCK_SIZE}, + error::{IoSnafu, Result, UnexpectedSnafu}, +}; +use snafu::{ensure, ResultExt}; + +/// Construct an [`AsyncArrowWriter`] backed by an async writer. +pub struct AsyncArrowWriterBuilder { + writer: W, + schema: SchemaRef, + batch_size: usize, + stripe_byte_size: usize, + compression: WriterCompression, + compression_block_size: usize, + row_index_stride: Option, + bloom_filters: bool, +} + +impl AsyncArrowWriterBuilder { + /// Create a new builder for writing ORC bytes to `writer`. + pub fn new(writer: W, schema: SchemaRef) -> Self { + Self { + writer, + schema, + batch_size: 1024, + stripe_byte_size: 64 * 1024 * 1024, + compression: WriterCompression::None, + compression_block_size: DEFAULT_COMPRESSION_BLOCK_SIZE as usize, + row_index_stride: None, + bloom_filters: false, + } + } + + /// Batch size controls the encoding behaviour. Default is `1024`. + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// The approximate size of stripes. Default is `64MiB`. + pub fn with_stripe_byte_size(mut self, stripe_byte_size: usize) -> Self { + self.stripe_byte_size = stripe_byte_size; + self + } + + /// Compress ORC streams and metadata with the provided writer codec. + pub fn with_compression(mut self, compression: WriterCompression) -> Self { + self.compression = compression; + self + } + + /// Enable ORC `ZLIB` compression. + pub fn with_gzip_compression(mut self) -> Self { + self.compression = WriterCompression::Zlib; + self + } + + /// The maximum uncompressed size of each ORC compression block. + pub fn with_compression_block_size(mut self, compression_block_size: usize) -> Self { + self.compression_block_size = compression_block_size; + self + } + + /// Enable writer row indexes with `rows_per_group` rows per row group. + pub fn with_row_index_stride(mut self, rows_per_group: usize) -> Self { + self.row_index_stride = Some(rows_per_group); + self + } + + /// Enable writer bloom filter streams for supported primitive columns. + pub fn with_bloom_filters(mut self) -> Self { + self.bloom_filters = true; + self + } + + /// Construct an [`AsyncArrowWriter`]. + pub fn try_build(self) -> Result> { + ensure!( + self.compression_block_size > 0, + UnexpectedSnafu { + msg: "compression block size must be greater than zero" + } + ); + ensure!( + self.row_index_stride.map_or(true, |stride| stride > 0), + UnexpectedSnafu { + msg: "row index stride must be greater than zero" + } + ); + ensure!( + !self.bloom_filters || self.row_index_stride.is_some(), + UnexpectedSnafu { + msg: "bloom filters require row indexes to define row group boundaries" + } + ); + + let mut builder = ArrowWriterBuilder::new(Vec::new(), self.schema) + .with_batch_size(self.batch_size) + .with_stripe_byte_size(self.stripe_byte_size) + .with_compression(self.compression) + .with_compression_block_size(self.compression_block_size); + if let Some(row_index_stride) = self.row_index_stride { + builder = builder.with_row_index_stride(row_index_stride); + } + if self.bloom_filters { + builder = builder.with_bloom_filters(); + } + let inner = builder.try_build()?; + Ok(AsyncArrowWriter { + writer: self.writer, + inner, + }) + } +} + +/// Encodes [`RecordBatch`]es into ORC and writes the final bytes asynchronously. +pub struct AsyncArrowWriter { + writer: W, + inner: ArrowWriter>, +} + +impl AsyncArrowWriter { + /// Encode the provided batch into the ORC stream. + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.inner.write(batch) + } + + /// Flush buffered ORC bytes to the async writer and close this writer. + pub async fn close(self) -> Result<()> { + self.finish().await.map(|_| ()) + } + + /// Flush buffered ORC bytes to the async writer and return the inner writer. + pub async fn finish(mut self) -> Result { + let buffer = self.inner.finish()?; + self.writer.write_all(&buffer).await.context(IoSnafu)?; + self.writer.flush().await.context(IoSnafu)?; + Ok(self.writer) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{Int64Array, RecordBatch}, + datatypes::{DataType as ArrowDataType, Field, Schema}, + }; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + use crate::{ArrowReaderBuilder, AsyncArrowWriterBuilder}; + + #[tokio::test] + async fn test_async_writer_roundtrip() { + let array = Arc::new(Int64Array::from(vec![1, 2, 3, 4])); + let schema = Arc::new(Schema::new(vec![Field::new( + "int64", + ArrowDataType::Int64, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + + let (client, mut server) = tokio::io::duplex(4096); + let read_handle = tokio::spawn(async move { + let mut out = Vec::new(); + server.read_to_end(&mut out).await.unwrap(); + out + }); + + let mut writer = AsyncArrowWriterBuilder::new(client, batch.schema()) + .try_build() + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().await.unwrap(); + + let bytes = Bytes::from(read_handle.await.unwrap()); + let rows = ArrowReaderBuilder::try_new(bytes) + .unwrap() + .build() + .collect::, _>>() + .unwrap(); + assert_eq!(batch, rows[0]); + } +} diff --git a/src/bloom_filter.rs b/src/bloom_filter.rs index 25e7e68..71c2804 100644 --- a/src/bloom_filter.rs +++ b/src/bloom_filter.rs @@ -84,15 +84,27 @@ impl BloomFilter { }) } - #[cfg(test)] /// Create a Bloom filter from raw parts (mainly for tests) - pub fn from_parts(num_hash_functions: u32, bitset: Vec) -> Self { + pub(crate) fn from_parts(num_hash_functions: u32, bitset: Vec) -> Self { Self { num_hash_functions: num_hash_functions.max(1), bitset, } } + pub(crate) fn to_proto_utf8(&self) -> proto::BloomFilter { + let utf8bitset = self + .bitset + .iter() + .flat_map(|word| word.to_le_bytes()) + .collect(); + proto::BloomFilter { + num_hash_functions: Some(self.num_hash_functions), + bitset: vec![], + utf8bitset: Some(utf8bitset), + } + } + /// Set bits for the provided 64-bit hash using ORC's double-hash scheme. pub fn add_hash(&mut self, hash64: u64) { let bit_count = self.bitset.len() * 64; diff --git a/src/compression.rs b/src/compression.rs index 012a3f7..4d53098 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -16,9 +16,9 @@ // under the License. // Modified from https://github.com/DataEngineeringLabs/orc-format/blob/416490db0214fc51d53289253c0ee91f7fc9bc17/src/read/decompress/mod.rs -//! Related code for handling decompression of ORC files. +//! Related code for handling compression and decompression of ORC files. -use std::io::Read; +use std::io::{Read, Write}; use bytes::{Bytes, BytesMut}; use fallible_streaming_iterator::FallibleStreamingIterator; @@ -28,7 +28,8 @@ use crate::error::{self, OrcError, Result}; use crate::proto::{self, CompressionKind}; // Spec states default is 256K -const DEFAULT_COMPRESSION_BLOCK_SIZE: u64 = 256 * 1024; +pub(crate) const DEFAULT_COMPRESSION_BLOCK_SIZE: u64 = 256 * 1024; +const MAX_COMPRESSION_BLOCK_SIZE: usize = (1 << 23) - 1; #[derive(Clone, Copy, Debug)] pub struct Compression { @@ -100,6 +101,32 @@ impl std::fmt::Display for CompressionType { } } +/// Compression codec used by the ORC writer. +/// +/// ORC's protobuf format does not define a separate `GZIP` compression kind. +/// Ecosystems that expose a "gzip" writer option typically store it as ORC +/// `ZLIB`, so [`WriterCompression::Zlib`] is also used by +/// [`ArrowWriterBuilder::with_gzip_compression`](crate::ArrowWriterBuilder::with_gzip_compression). +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub enum WriterCompression { + #[default] + None, + Zlib, +} + +impl WriterCompression { + pub(crate) fn to_proto(self) -> CompressionKind { + match self { + Self::None => CompressionKind::None, + Self::Zlib => CompressionKind::Zlib, + } + } + + pub(crate) fn is_none(self) -> bool { + matches!(self, Self::None) + } +} + /// Indicates length of block and whether it's compressed or not. #[derive(Debug, PartialEq, Eq)] enum CompressionHeader { @@ -141,6 +168,17 @@ struct Lz4 { impl DecompressorVariant for Zlib { fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec) -> Result<()> { + if let Ok(decoded) = read_zlib_wrapped(compressed_bytes) { + scratch.clear(); + scratch.extend(decoded); + return Ok(()); + } + if let Ok(decoded) = read_gzip_wrapped(compressed_bytes) { + scratch.clear(); + scratch.extend(decoded); + return Ok(()); + } + let mut gz = flate2::read::DeflateDecoder::new(compressed_bytes); scratch.clear(); gz.read_to_end(scratch).context(error::IoSnafu)?; @@ -212,6 +250,70 @@ fn get_decompressor_variant( } } +fn read_zlib_wrapped(compressed_bytes: &[u8]) -> std::io::Result> { + let mut decoder = flate2::read::ZlibDecoder::new(compressed_bytes); + let mut decoded = Vec::new(); + decoder.read_to_end(&mut decoded)?; + Ok(decoded) +} + +fn read_gzip_wrapped(compressed_bytes: &[u8]) -> std::io::Result> { + let mut decoder = flate2::read::GzDecoder::new(compressed_bytes); + let mut decoded = Vec::new(); + decoder.read_to_end(&mut decoded)?; + Ok(decoded) +} + +pub(crate) fn compress_stream( + bytes: Bytes, + compression: WriterCompression, + block_size: usize, +) -> Result { + if compression.is_none() || bytes.is_empty() { + return Ok(bytes); + } + + if block_size == 0 || block_size > MAX_COMPRESSION_BLOCK_SIZE { + return error::UnexpectedSnafu { + msg: format!( + "compression block size must be between 1 and {MAX_COMPRESSION_BLOCK_SIZE}, got {block_size}" + ), + } + .fail(); + } + + let mut output = BytesMut::new(); + for chunk in bytes.chunks(block_size) { + let compressed = match compression { + WriterCompression::None => unreachable!("None handled above"), + WriterCompression::Zlib => zlib_compress(chunk)?, + }; + + if compressed.len() >= chunk.len() { + output.extend_from_slice(&encode_header(chunk.len(), true)); + output.extend_from_slice(chunk); + } else { + output.extend_from_slice(&encode_header(compressed.len(), false)); + output.extend_from_slice(&compressed); + } + } + + Ok(output.freeze()) +} + +fn zlib_compress(bytes: &[u8]) -> Result> { + let mut encoder = flate2::write::ZlibEncoder::new(Vec::new(), flate2::Compression::default()); + encoder.write_all(bytes).context(error::IoSnafu)?; + encoder.finish().context(error::IoSnafu) +} + +fn encode_header(length: usize, is_original: bool) -> [u8; 3] { + debug_assert!(length <= MAX_COMPRESSION_BLOCK_SIZE); + let length_and_flag = ((length as u32) << 1) | u32::from(is_original); + let bytes = length_and_flag.to_le_bytes(); + [bytes[0], bytes[1], bytes[2]] +} + enum State { Original(Bytes), Compressed(Vec), @@ -349,6 +451,7 @@ impl std::io::Read for Decompressor { #[cfg(test)] mod tests { use super::*; + use std::io::{Read, Write}; #[test] fn decode_uncompressed() { @@ -368,4 +471,47 @@ mod tests { let actual = decode_header(bytes); assert_eq!(expected, actual); } + + #[test] + fn zlib_decompressor_accepts_zlib_gzip_and_raw_deflate_wrappers() { + let input = b"orc zlib compatibility"; + + let mut zlib_encoder = + flate2::write::ZlibEncoder::new(Vec::new(), flate2::Compression::default()); + zlib_encoder.write_all(input).unwrap(); + let zlib = zlib_encoder.finish().unwrap(); + + let mut gzip_encoder = + flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default()); + gzip_encoder.write_all(input).unwrap(); + let gzip = gzip_encoder.finish().unwrap(); + + let mut deflate_encoder = + flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default()); + deflate_encoder.write_all(input).unwrap(); + let deflate = deflate_encoder.finish().unwrap(); + + for compressed in [zlib, gzip, deflate] { + let mut scratch = Vec::new(); + Zlib.decompress_block(&compressed, &mut scratch).unwrap(); + assert_eq!(input, scratch.as_slice()); + } + } + + #[test] + fn zlib_compress_stream_roundtrips() { + let input = Bytes::from_static(b"abcdefghijklmnopqrstuvwxyz"); + let compressed = compress_stream(input.clone(), WriterCompression::Zlib, 8).unwrap(); + assert_ne!(input, compressed); + + let compression = Compression { + compression_type: CompressionType::Zlib, + max_decompressed_block_size: 8, + }; + let mut decompressor = Decompressor::new(compressed, Some(compression), vec![]); + let mut output = Vec::new(); + decompressor.read_to_end(&mut output).unwrap(); + + assert_eq!(input.as_ref(), output.as_slice()); + } } diff --git a/src/encoding/integer/mod.rs b/src/encoding/integer/mod.rs index 80457e2..9116e27 100644 --- a/src/encoding/integer/mod.rs +++ b/src/encoding/integer/mod.rs @@ -41,7 +41,7 @@ pub mod rle_v2; mod util; // TODO: consider having a separate varint.rs -pub use util::read_varint_zigzagged; +pub use util::{read_varint_zigzagged, write_varint_zigzagged}; #[derive(Debug, Clone, Copy)] pub enum RleVersion { diff --git a/src/lib.rs b/src/lib.rs index e639421..50c1fd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,8 @@ pub mod arrow_reader; pub mod arrow_writer; #[cfg(feature = "async")] pub mod async_arrow_reader; +#[cfg(feature = "async")] +pub mod async_arrow_writer; pub mod bloom_filter; mod column; pub mod compression; @@ -75,6 +77,8 @@ pub use arrow_reader::{ArrowReader, ArrowReaderBuilder}; pub use arrow_writer::{ArrowWriter, ArrowWriterBuilder}; #[cfg(feature = "async")] pub use async_arrow_reader::ArrowStreamReader; +#[cfg(feature = "async")] +pub use async_arrow_writer::{AsyncArrowWriter, AsyncArrowWriterBuilder}; pub use predicate::{ComparisonOp, Predicate, PredicateValue}; pub use row_selection::{RowSelection, RowSelector}; pub use schema::{ArrowSchemaOptions, TimestampPrecision}; diff --git a/src/writer/column.rs b/src/writer/column.rs index f1ec72f..25b9f1b 100644 --- a/src/writer/column.rs +++ b/src/writer/column.rs @@ -20,21 +20,26 @@ use std::marker::PhantomData; use arrow::{ array::{Array, ArrayRef, AsArray}, datatypes::{ - ArrowPrimitiveType, ByteArrayType, Float32Type, Float64Type, GenericBinaryType, - GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, + ArrowPrimitiveType, ArrowTimestampType, ByteArrayType, Date32Type, Decimal128Type, + Float32Type, Float64Type, GenericBinaryType, GenericStringType, Int16Type, Int32Type, + Int64Type, Int8Type, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, }, }; use bytes::{BufMut, BytesMut}; +use snafu::ensure; use crate::{ encoding::{ boolean::BooleanEncoder, byte::ByteRleEncoder, float::FloatEncoder, - integer::{rle_v2::RleV2Encoder, NInt, SignedEncoding, UnsignedEncoding}, + integer::{ + rle_v2::RleV2Encoder, write_varint_zigzagged, NInt, SignedEncoding, UnsignedEncoding, + }, PrimitiveValueEncoder, }, - error::Result, + error::{Result, UnexpectedSnafu}, memory::EstimateMemory, writer::StreamType, }; @@ -390,12 +395,256 @@ where } } +pub struct DecimalColumnEncoder { + data: BytesMut, + scale: RleV2Encoder, + present: Option, + encoded_count: usize, + fixed_scale: i8, +} + +impl DecimalColumnEncoder { + pub fn new(fixed_scale: i8) -> Self { + Self { + data: BytesMut::new(), + scale: RleV2Encoder::new(), + present: None, + encoded_count: 0, + fixed_scale, + } + } + + fn write_value(&mut self, value: i128) { + write_varint_zigzagged::(&mut self.data, value); + self.scale.write_one(self.fixed_scale as i32); + } +} + +impl EstimateMemory for DecimalColumnEncoder { + fn estimate_memory_size(&self) -> usize { + self.data.len() + + self.scale.estimate_memory_size() + + self + .present + .as_ref() + .map(|p| p.estimate_memory_size()) + .unwrap_or(0) + } +} + +impl ColumnStripeEncoder for DecimalColumnEncoder { + fn encode_array(&mut self, array: &ArrayRef) -> Result<()> { + let array = array.as_primitive::(); + match (array.nulls(), &mut self.present) { + (Some(null_buffer), Some(present)) => { + present.extend(null_buffer); + for index in null_buffer.valid_indices() { + self.write_value(array.value(index)); + } + } + (Some(null_buffer), None) => { + let mut present = BooleanEncoder::new(); + present.extend_present(self.encoded_count); + present.extend(null_buffer); + self.present = Some(present); + for index in null_buffer.valid_indices() { + self.write_value(array.value(index)); + } + } + (None, _) => { + for value in array.values() { + self.write_value(*value); + } + if let Some(present) = self.present.as_mut() { + present.extend_present(array.len()) + } + } + } + self.encoded_count += array.len() - array.null_count(); + Ok(()) + } + + fn column_encoding(&self) -> ColumnEncoding { + ColumnEncoding::DirectV2 + } + + fn finish(&mut self) -> Vec { + let data = Stream { + kind: StreamType::Data, + bytes: std::mem::take(&mut self.data).into(), + }; + let secondary = Stream { + kind: StreamType::Secondary, + bytes: self.scale.take_inner(), + }; + self.encoded_count = 0; + match &mut self.present { + Some(present) => { + let present = Stream { + kind: StreamType::Present, + bytes: present.finish(), + }; + vec![data, secondary, present] + } + None => vec![data, secondary], + } + } +} + +pub struct TimestampColumnEncoder { + data: RleV2Encoder, + secondary: RleV2Encoder, + present: Option, + encoded_count: usize, + _phantom: PhantomData, +} + +impl TimestampColumnEncoder { + pub fn new() -> Self { + Self { + data: RleV2Encoder::new(), + secondary: RleV2Encoder::new(), + present: None, + encoded_count: 0, + _phantom: PhantomData, + } + } + + fn write_value(&mut self, value: i64) -> Result<()> { + let nanos_since_epoch = timestamp_value_to_nanos::(value); + let (seconds_since_base, encoded_nanos) = encode_timestamp(nanos_since_epoch)?; + self.data.write_one(seconds_since_base); + self.secondary.write_one(encoded_nanos); + Ok(()) + } +} + +impl EstimateMemory for TimestampColumnEncoder { + fn estimate_memory_size(&self) -> usize { + self.data.estimate_memory_size() + + self.secondary.estimate_memory_size() + + self + .present + .as_ref() + .map(|p| p.estimate_memory_size()) + .unwrap_or(0) + } +} + +impl ColumnStripeEncoder for TimestampColumnEncoder { + fn encode_array(&mut self, array: &ArrayRef) -> Result<()> { + let array = array.as_primitive::(); + match (array.nulls(), &mut self.present) { + (Some(null_buffer), Some(present)) => { + present.extend(null_buffer); + for index in null_buffer.valid_indices() { + self.write_value(array.value(index))?; + } + } + (Some(null_buffer), None) => { + let mut present = BooleanEncoder::new(); + present.extend_present(self.encoded_count); + present.extend(null_buffer); + self.present = Some(present); + for index in null_buffer.valid_indices() { + self.write_value(array.value(index))?; + } + } + (None, _) => { + for value in array.values() { + self.write_value(*value)?; + } + if let Some(present) = self.present.as_mut() { + present.extend_present(array.len()) + } + } + } + self.encoded_count += array.len() - array.null_count(); + Ok(()) + } + + fn column_encoding(&self) -> ColumnEncoding { + ColumnEncoding::DirectV2 + } + + fn finish(&mut self) -> Vec { + let data = Stream { + kind: StreamType::Data, + bytes: self.data.take_inner(), + }; + let secondary = Stream { + kind: StreamType::Secondary, + bytes: self.secondary.take_inner(), + }; + self.encoded_count = 0; + match &mut self.present { + Some(present) => { + let present = Stream { + kind: StreamType::Present, + bytes: present.finish(), + }; + vec![data, secondary, present] + } + None => vec![data, secondary], + } + } +} + +const ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH: i128 = 1_420_070_400; +const NANOSECONDS_IN_SECOND: i128 = 1_000_000_000; + +fn timestamp_value_to_nanos(value: i64) -> i128 { + match T::UNIT { + TimeUnit::Second => (value as i128) * 1_000_000_000, + TimeUnit::Millisecond => (value as i128) * 1_000_000, + TimeUnit::Microsecond => (value as i128) * 1_000, + TimeUnit::Nanosecond => value as i128, + } +} + +fn encode_timestamp(nanos_since_epoch: i128) -> Result<(i64, i64)> { + let seconds_since_epoch = nanos_since_epoch.div_euclid(NANOSECONDS_IN_SECOND); + let nanos = nanos_since_epoch.rem_euclid(NANOSECONDS_IN_SECOND) as u32; + let seconds_since_base = seconds_since_epoch - ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH; + ensure!( + seconds_since_base >= i64::MIN as i128 && seconds_since_base <= i64::MAX as i128, + UnexpectedSnafu { + msg: "timestamp seconds are out of ORC writer range" + } + ); + Ok((seconds_since_base as i64, encode_timestamp_nanos(nanos))) +} + +fn encode_timestamp_nanos(nanos: u32) -> i64 { + if nanos == 0 { + return 0; + } + + let mut stripped = nanos; + let mut zeros = 0; + while zeros < 8 && stripped % 10 == 0 { + stripped /= 10; + zeros += 1; + } + + if zeros > 1 { + ((stripped as i64) << 3) | ((zeros - 1) as i64) + } else { + (nanos as i64) << 3 + } +} + pub type FloatColumnEncoder = PrimitiveColumnEncoder>; pub type DoubleColumnEncoder = PrimitiveColumnEncoder>; pub type ByteColumnEncoder = PrimitiveColumnEncoder; pub type Int16ColumnEncoder = PrimitiveColumnEncoder>; pub type Int32ColumnEncoder = PrimitiveColumnEncoder>; pub type Int64ColumnEncoder = PrimitiveColumnEncoder>; +pub type DateColumnEncoder = PrimitiveColumnEncoder>; +pub type TimestampSecondColumnEncoder = TimestampColumnEncoder; +pub type TimestampMillisecondColumnEncoder = TimestampColumnEncoder; +pub type TimestampMicrosecondColumnEncoder = TimestampColumnEncoder; +pub type TimestampNanosecondColumnEncoder = TimestampColumnEncoder; pub type StringColumnEncoder = GenericBinaryColumnEncoder>; pub type LargeStringColumnEncoder = GenericBinaryColumnEncoder>; pub type BinaryColumnEncoder = GenericBinaryColumnEncoder>; diff --git a/src/writer/index.rs b/src/writer/index.rs new file mode 100644 index 0000000..e4648b9 --- /dev/null +++ b/src/writer/index.rs @@ -0,0 +1,680 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Array, ArrayRef, AsArray}, + datatypes::{ + DataType as ArrowDataType, Date32Type, Decimal128Type, Float32Type, Float64Type, + GenericBinaryType, GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, + }, +}; + +use crate::{bloom_filter::BloomFilter, proto}; + +#[derive(Debug)] +pub(crate) struct RowIndexBuilder { + data_type: ArrowDataType, + rows_per_group: usize, + row_groups: Vec, + current: ColumnStatsBuilder, + rows_in_current_group: usize, +} + +impl RowIndexBuilder { + pub(crate) fn new(data_type: ArrowDataType, rows_per_group: usize) -> Self { + Self { + current: ColumnStatsBuilder::new(&data_type), + data_type, + rows_per_group, + row_groups: Vec::new(), + rows_in_current_group: 0, + } + } + + pub(crate) fn update(&mut self, array: &ArrayRef) { + for row_index in 0..array.len() { + if self.rows_in_current_group == self.rows_per_group { + self.finish_current_group(); + } + + self.current.update(&self.data_type, array, row_index); + self.rows_in_current_group += 1; + } + } + + pub(crate) fn finish(&mut self) -> proto::RowIndex { + if self.rows_in_current_group > 0 { + self.finish_current_group(); + } + + proto::RowIndex { + entry: std::mem::take(&mut self.row_groups), + } + } + + fn finish_current_group(&mut self) { + let statistics = self.current.finish(&self.data_type); + self.row_groups.push(proto::RowIndexEntry { + // The current reader uses row-group statistics for predicate pruning. + // Stream seek positions require encoder-level position recorders. + positions: vec![], + statistics: Some(statistics), + }); + self.current = ColumnStatsBuilder::new(&self.data_type); + self.rows_in_current_group = 0; + } +} + +#[derive(Debug)] +pub(crate) struct BloomFilterIndexBuilder { + data_type: ArrowDataType, + rows_per_group: usize, + row_groups: Vec, + current: BloomFilter, + rows_in_current_group: usize, + enabled: bool, +} + +impl BloomFilterIndexBuilder { + pub(crate) fn new(data_type: ArrowDataType, rows_per_group: usize) -> Self { + let enabled = bloom_supported(&data_type); + Self { + current: new_bloom_filter(rows_per_group), + data_type, + rows_per_group, + row_groups: Vec::new(), + rows_in_current_group: 0, + enabled, + } + } + + pub(crate) fn update(&mut self, array: &ArrayRef) { + if !self.enabled { + return; + } + + for row_index in 0..array.len() { + if self.rows_in_current_group == self.rows_per_group { + self.finish_current_group(); + } + + if !array.is_null(row_index) { + if let Some(hash) = bloom_hash(&self.data_type, array, row_index) { + self.current.add_hash(hash); + } + } + self.rows_in_current_group += 1; + } + } + + pub(crate) fn finish(&mut self) -> proto::BloomFilterIndex { + if self.enabled && self.rows_in_current_group > 0 { + self.finish_current_group(); + } + + proto::BloomFilterIndex { + bloom_filter: std::mem::take(&mut self.row_groups), + } + } + + fn finish_current_group(&mut self) { + self.row_groups.push(self.current.to_proto_utf8()); + self.current = new_bloom_filter(self.rows_per_group); + self.rows_in_current_group = 0; + } +} + +fn new_bloom_filter(rows_per_group: usize) -> BloomFilter { + let bit_count = rows_per_group.saturating_mul(16).max(1024); + let word_count = bit_count.div_ceil(64).min(8192); + BloomFilter::from_parts(6, vec![0; word_count]) +} + +fn bloom_supported(data_type: &ArrowDataType) -> bool { + matches!( + data_type, + ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 + | ArrowDataType::Float32 + | ArrowDataType::Float64 + | ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 + | ArrowDataType::Boolean + ) +} + +fn bloom_hash(data_type: &ArrowDataType, array: &ArrayRef, row_index: usize) -> Option { + match data_type { + ArrowDataType::Int8 => Some(BloomFilter::hash_long( + array.as_primitive::().value(row_index) as i64, + )), + ArrowDataType::Int16 => Some(BloomFilter::hash_long( + array.as_primitive::().value(row_index) as i64, + )), + ArrowDataType::Int32 => Some(BloomFilter::hash_long( + array.as_primitive::().value(row_index) as i64, + )), + ArrowDataType::Int64 => Some(BloomFilter::hash_long( + array.as_primitive::().value(row_index), + )), + ArrowDataType::Float32 => Some(BloomFilter::hash_long( + (array.as_primitive::().value(row_index) as f64).to_bits() as i64, + )), + ArrowDataType::Float64 => Some(BloomFilter::hash_long( + array + .as_primitive::() + .value(row_index) + .to_bits() as i64, + )), + ArrowDataType::Utf8 => Some(BloomFilter::hash_bytes( + array + .as_bytes::>() + .value(row_index) + .as_bytes(), + )), + ArrowDataType::LargeUtf8 => Some(BloomFilter::hash_bytes( + array + .as_bytes::>() + .value(row_index) + .as_bytes(), + )), + ArrowDataType::Boolean => Some(BloomFilter::hash_long( + if array.as_boolean().value(row_index) { + 1 + } else { + 0 + }, + )), + _ => None, + } +} + +#[derive(Debug)] +pub(crate) struct ColumnStatsBuilder { + number_of_values: u64, + has_null: bool, + type_stats: TypeStatsBuilder, +} + +impl ColumnStatsBuilder { + pub(crate) fn new(data_type: &ArrowDataType) -> Self { + Self { + number_of_values: 0, + has_null: false, + type_stats: TypeStatsBuilder::new(data_type), + } + } + + pub(crate) fn update_array(&mut self, data_type: &ArrowDataType, array: &ArrayRef) { + for row_index in 0..array.len() { + self.update(data_type, array, row_index); + } + } + + fn update(&mut self, data_type: &ArrowDataType, array: &ArrayRef, row_index: usize) { + if array.is_null(row_index) { + self.has_null = true; + return; + } + + self.number_of_values += 1; + + match data_type { + ArrowDataType::Int8 => { + self.type_stats + .update_integer(array.as_primitive::().value(row_index) as i64); + } + ArrowDataType::Int16 => { + self.type_stats + .update_integer(array.as_primitive::().value(row_index) as i64); + } + ArrowDataType::Int32 => { + self.type_stats + .update_integer(array.as_primitive::().value(row_index) as i64); + } + ArrowDataType::Int64 => { + self.type_stats + .update_integer(array.as_primitive::().value(row_index)); + } + ArrowDataType::Float32 => { + self.type_stats + .update_double(array.as_primitive::().value(row_index) as f64); + } + ArrowDataType::Float64 => { + self.type_stats + .update_double(array.as_primitive::().value(row_index)); + } + ArrowDataType::Utf8 => { + self.type_stats + .update_string(array.as_bytes::>().value(row_index)); + } + ArrowDataType::LargeUtf8 => { + self.type_stats + .update_string(array.as_bytes::>().value(row_index)); + } + ArrowDataType::Binary => { + self.type_stats.update_binary( + array + .as_bytes::>() + .value(row_index) + .len(), + ); + } + ArrowDataType::LargeBinary => { + self.type_stats.update_binary( + array + .as_bytes::>() + .value(row_index) + .len(), + ); + } + ArrowDataType::Boolean => { + self.type_stats + .update_boolean(array.as_boolean().value(row_index)); + } + ArrowDataType::Decimal128(_, _) => { + self.type_stats + .update_decimal(array.as_primitive::().value(row_index)); + } + ArrowDataType::Date32 => { + self.type_stats + .update_date(array.as_primitive::().value(row_index)); + } + ArrowDataType::Timestamp(TimeUnit::Second, _) => { + self.type_stats.update_timestamp_nanos( + (array.as_primitive::().value(row_index) as i128) + * 1_000_000_000, + ); + } + ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => { + self.type_stats.update_timestamp_nanos( + (array + .as_primitive::() + .value(row_index) as i128) + * 1_000_000, + ); + } + ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => { + self.type_stats.update_timestamp_nanos( + (array + .as_primitive::() + .value(row_index) as i128) + * 1_000, + ); + } + ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => { + self.type_stats.update_timestamp_nanos( + array + .as_primitive::() + .value(row_index) as i128, + ); + } + _ => {} + } + } + + pub(crate) fn finish(&self, data_type: &ArrowDataType) -> proto::ColumnStatistics { + let mut statistics = proto::ColumnStatistics { + number_of_values: Some(self.number_of_values), + has_null: Some(self.has_null), + ..Default::default() + }; + + if self.number_of_values == 0 { + return statistics; + } + + match (&self.type_stats, data_type) { + (TypeStatsBuilder::Integer(stats), _) => { + if let (Some(minimum), Some(maximum)) = (stats.minimum, stats.maximum) { + statistics.int_statistics = Some(proto::IntegerStatistics { + minimum: Some(minimum), + maximum: Some(maximum), + sum: stats.sum, + }); + } + } + (TypeStatsBuilder::Double(stats), _) => { + if let (Some(minimum), Some(maximum)) = (stats.minimum, stats.maximum) { + statistics.double_statistics = Some(proto::DoubleStatistics { + minimum: Some(minimum), + maximum: Some(maximum), + sum: stats.sum, + }); + } + } + (TypeStatsBuilder::String(stats), ArrowDataType::Utf8 | ArrowDataType::LargeUtf8) => { + if let (Some(minimum), Some(maximum)) = (&stats.minimum, &stats.maximum) { + statistics.string_statistics = Some(proto::StringStatistics { + minimum: Some(minimum.clone()), + maximum: Some(maximum.clone()), + sum: Some(stats.sum), + lower_bound: None, + upper_bound: None, + }); + } + } + ( + TypeStatsBuilder::Binary(stats), + ArrowDataType::Binary | ArrowDataType::LargeBinary, + ) => { + statistics.binary_statistics = Some(proto::BinaryStatistics { + sum: Some(stats.sum), + }); + } + (TypeStatsBuilder::Boolean(stats), ArrowDataType::Boolean) => { + statistics.bucket_statistics = Some(proto::BucketStatistics { + count: vec![stats.true_count], + }); + } + (TypeStatsBuilder::Decimal(stats), ArrowDataType::Decimal128(_, scale)) => { + if let (Some(minimum), Some(maximum)) = (stats.minimum, stats.maximum) { + statistics.decimal_statistics = Some(proto::DecimalStatistics { + minimum: Some(format_decimal(minimum, *scale)), + maximum: Some(format_decimal(maximum, *scale)), + sum: stats.sum.map(|sum| format_decimal(sum, *scale)), + }); + } + } + (TypeStatsBuilder::Date(stats), ArrowDataType::Date32) => { + if let (Some(minimum), Some(maximum)) = (stats.minimum, stats.maximum) { + statistics.date_statistics = Some(proto::DateStatistics { + minimum: Some(minimum), + maximum: Some(maximum), + }); + } + } + (TypeStatsBuilder::Timestamp(stats), ArrowDataType::Timestamp(_, _)) => { + if let (Some(minimum), Some(maximum)) = (stats.minimum_nanos, stats.maximum_nanos) { + statistics.timestamp_statistics = Some(proto::TimestampStatistics { + minimum: Some(timestamp_millis(minimum)), + maximum: Some(timestamp_millis(maximum)), + minimum_utc: Some(timestamp_millis(minimum)), + maximum_utc: Some(timestamp_millis(maximum)), + minimum_nanos: Some(timestamp_sub_millis_nanos(minimum)), + maximum_nanos: Some(timestamp_sub_millis_nanos(maximum)), + }); + } + } + _ => {} + } + + statistics + } +} + +pub(crate) fn root_column_statistics(number_of_rows: u64) -> proto::ColumnStatistics { + proto::ColumnStatistics { + number_of_values: Some(number_of_rows), + has_null: Some(false), + ..Default::default() + } +} + +#[derive(Debug)] +enum TypeStatsBuilder { + Integer(IntegerStatsBuilder), + Double(DoubleStatsBuilder), + String(StringStatsBuilder), + Binary(BinaryStatsBuilder), + Boolean(BooleanStatsBuilder), + Decimal(DecimalStatsBuilder), + Date(DateStatsBuilder), + Timestamp(TimestampStatsBuilder), + Unsupported, +} + +impl TypeStatsBuilder { + fn new(data_type: &ArrowDataType) -> Self { + match data_type { + ArrowDataType::Int8 + | ArrowDataType::Int16 + | ArrowDataType::Int32 + | ArrowDataType::Int64 => Self::Integer(IntegerStatsBuilder::default()), + ArrowDataType::Float32 | ArrowDataType::Float64 => { + Self::Double(DoubleStatsBuilder::default()) + } + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { + Self::String(StringStatsBuilder::default()) + } + ArrowDataType::Binary | ArrowDataType::LargeBinary => { + Self::Binary(BinaryStatsBuilder::default()) + } + ArrowDataType::Boolean => Self::Boolean(BooleanStatsBuilder::default()), + ArrowDataType::Decimal128(_, _) => Self::Decimal(DecimalStatsBuilder::default()), + ArrowDataType::Date32 => Self::Date(DateStatsBuilder::default()), + ArrowDataType::Timestamp(_, _) => Self::Timestamp(TimestampStatsBuilder::default()), + _ => Self::Unsupported, + } + } + + fn update_integer(&mut self, value: i64) { + if let Self::Integer(stats) = self { + stats.update(value); + } + } + + fn update_double(&mut self, value: f64) { + if let Self::Double(stats) = self { + stats.update(value); + } + } + + fn update_string(&mut self, value: &str) { + if let Self::String(stats) = self { + stats.update(value); + } + } + + fn update_binary(&mut self, length: usize) { + if let Self::Binary(stats) = self { + stats.update(length); + } + } + + fn update_boolean(&mut self, value: bool) { + if let Self::Boolean(stats) = self { + stats.update(value); + } + } + + fn update_decimal(&mut self, value: i128) { + if let Self::Decimal(stats) = self { + stats.update(value); + } + } + + fn update_date(&mut self, value: i32) { + if let Self::Date(stats) = self { + stats.update(value); + } + } + + fn update_timestamp_nanos(&mut self, value: i128) { + if let Self::Timestamp(stats) = self { + stats.update(value); + } + } +} + +#[derive(Debug, Default)] +struct IntegerStatsBuilder { + minimum: Option, + maximum: Option, + sum: Option, +} + +impl IntegerStatsBuilder { + fn update(&mut self, value: i64) { + self.minimum = Some(self.minimum.map_or(value, |minimum| minimum.min(value))); + self.maximum = Some(self.maximum.map_or(value, |maximum| maximum.max(value))); + self.sum = match self.sum { + Some(sum) => sum.checked_add(value), + None => Some(value), + }; + } +} + +#[derive(Debug, Default)] +struct DoubleStatsBuilder { + minimum: Option, + maximum: Option, + sum: Option, +} + +impl DoubleStatsBuilder { + fn update(&mut self, value: f64) { + if value.is_nan() { + return; + } + + self.minimum = Some(self.minimum.map_or(value, |minimum| minimum.min(value))); + self.maximum = Some(self.maximum.map_or(value, |maximum| maximum.max(value))); + self.sum = Some(self.sum.unwrap_or(0.0) + value); + } +} + +#[derive(Debug, Default)] +struct StringStatsBuilder { + minimum: Option, + maximum: Option, + sum: i64, +} + +impl StringStatsBuilder { + fn update(&mut self, value: &str) { + self.minimum = Some(self.minimum.as_deref().map_or_else( + || value.to_string(), + |minimum| minimum.min(value).to_string(), + )); + self.maximum = Some(self.maximum.as_deref().map_or_else( + || value.to_string(), + |maximum| maximum.max(value).to_string(), + )); + self.sum += value.len() as i64; + } +} + +#[derive(Debug, Default)] +struct BinaryStatsBuilder { + sum: i64, +} + +impl BinaryStatsBuilder { + fn update(&mut self, length: usize) { + self.sum += length as i64; + } +} + +#[derive(Debug, Default)] +struct BooleanStatsBuilder { + true_count: u64, +} + +impl BooleanStatsBuilder { + fn update(&mut self, value: bool) { + if value { + self.true_count += 1; + } + } +} + +#[derive(Debug, Default)] +struct DecimalStatsBuilder { + minimum: Option, + maximum: Option, + sum: Option, +} + +impl DecimalStatsBuilder { + fn update(&mut self, value: i128) { + self.minimum = Some(self.minimum.map_or(value, |minimum| minimum.min(value))); + self.maximum = Some(self.maximum.map_or(value, |maximum| maximum.max(value))); + self.sum = match self.sum { + Some(sum) => sum.checked_add(value), + None => Some(value), + }; + } +} + +#[derive(Debug, Default)] +struct DateStatsBuilder { + minimum: Option, + maximum: Option, +} + +impl DateStatsBuilder { + fn update(&mut self, value: i32) { + self.minimum = Some(self.minimum.map_or(value, |minimum| minimum.min(value))); + self.maximum = Some(self.maximum.map_or(value, |maximum| maximum.max(value))); + } +} + +#[derive(Debug, Default)] +struct TimestampStatsBuilder { + minimum_nanos: Option, + maximum_nanos: Option, +} + +impl TimestampStatsBuilder { + fn update(&mut self, value: i128) { + self.minimum_nanos = Some( + self.minimum_nanos + .map_or(value, |minimum| minimum.min(value)), + ); + self.maximum_nanos = Some( + self.maximum_nanos + .map_or(value, |maximum| maximum.max(value)), + ); + } +} + +fn timestamp_millis(nanos_since_epoch: i128) -> i64 { + nanos_since_epoch.div_euclid(1_000_000) as i64 +} + +fn timestamp_sub_millis_nanos(nanos_since_epoch: i128) -> i32 { + nanos_since_epoch.rem_euclid(1_000_000) as i32 +} + +fn format_decimal(value: i128, scale: i8) -> String { + if scale <= 0 { + let multiplier = 10_i128.pow((-scale) as u32); + return (value * multiplier).to_string(); + } + + let scale = scale as usize; + let sign = if value < 0 { "-" } else { "" }; + let digits = if value == i128::MIN { + "170141183460469231731687303715884105728".to_string() + } else { + value.abs().to_string() + }; + if digits.len() <= scale { + let padding = "0".repeat(scale + 1 - digits.len()); + let digits = format!("{padding}{digits}"); + let split = digits.len() - scale; + return format!("{sign}{}.{}", &digits[..split], &digits[split..]); + } + + let split = digits.len() - scale; + format!("{sign}{}.{}", &digits[..split], &digits[split..]) +} diff --git a/src/writer/mod.rs b/src/writer/mod.rs index 0fc8f72..fef9145 100644 --- a/src/writer/mod.rs +++ b/src/writer/mod.rs @@ -22,10 +22,13 @@ use bytes::Bytes; use crate::proto; pub mod column; +pub(crate) mod index; pub mod stripe; #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum StreamType { + RowIndex, + BloomFilterUtf8, Present, Data, Length, @@ -36,6 +39,8 @@ pub enum StreamType { impl From for proto::stream::Kind { fn from(value: StreamType) -> Self { match value { + StreamType::RowIndex => proto::stream::Kind::RowIndex, + StreamType::BloomFilterUtf8 => proto::stream::Kind::BloomFilterUtf8, StreamType::Present => proto::stream::Kind::Present, StreamType::Data => proto::stream::Kind::Data, StreamType::Length => proto::stream::Kind::Length, diff --git a/src/writer/stripe.rs b/src/writer/stripe.rs index e16ee8e..1d9382c 100644 --- a/src/writer/stripe.rs +++ b/src/writer/stripe.rs @@ -19,27 +19,36 @@ use std::io::Write; use arrow::array::RecordBatch; use arrow::datatypes::{DataType as ArrowDataType, FieldRef, SchemaRef}; +use bytes::Bytes; use prost::Message; use snafu::ResultExt; +use crate::compression::{compress_stream, WriterCompression}; use crate::error::{IoSnafu, Result}; use crate::memory::EstimateMemory; use crate::proto; use super::column::{ BinaryColumnEncoder, BooleanColumnEncoder, ByteColumnEncoder, ColumnStripeEncoder, - DoubleColumnEncoder, FloatColumnEncoder, Int16ColumnEncoder, Int32ColumnEncoder, - Int64ColumnEncoder, LargeBinaryColumnEncoder, LargeStringColumnEncoder, StringColumnEncoder, + DateColumnEncoder, DecimalColumnEncoder, DoubleColumnEncoder, FloatColumnEncoder, + Int16ColumnEncoder, Int32ColumnEncoder, Int64ColumnEncoder, LargeBinaryColumnEncoder, + LargeStringColumnEncoder, StringColumnEncoder, TimestampMicrosecondColumnEncoder, + TimestampMillisecondColumnEncoder, TimestampNanosecondColumnEncoder, + TimestampSecondColumnEncoder, +}; +use super::index::{ + root_column_statistics, BloomFilterIndexBuilder, ColumnStatsBuilder, RowIndexBuilder, }; use super::{ColumnEncoding, StreamType}; -#[derive(Copy, Clone, Eq, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct StripeInformation { pub start_offset: u64, pub index_length: u64, pub data_length: u64, pub footer_length: u64, pub row_count: usize, + pub column_statistics: Vec, } impl StripeInformation { @@ -68,6 +77,12 @@ pub struct StripeWriter { writer: W, /// Flattened columns, in order of their column ID. columns: Vec>, + data_types: Vec, + row_index_builders: Option>, + bloom_filter_builders: Option>, + stripe_stats_builders: Vec, + compression: WriterCompression, + compression_block_size: usize, pub row_count: usize, } @@ -80,11 +95,48 @@ impl EstimateMemory for StripeWriter { } impl StripeWriter { - pub fn new(writer: W, schema: &SchemaRef) -> Self { + pub fn new( + writer: W, + schema: &SchemaRef, + compression: WriterCompression, + compression_block_size: usize, + row_index_stride: Option, + bloom_filters: bool, + ) -> Self { let columns = schema.fields().iter().map(create_encoder).collect(); + let data_types = schema + .fields() + .iter() + .map(|field| field.data_type().clone()) + .collect::>(); + let row_index_builders = row_index_stride.map(|rows_per_group| { + schema + .fields() + .iter() + .map(|field| RowIndexBuilder::new(field.data_type().clone(), rows_per_group)) + .collect() + }); + let bloom_filter_builders = row_index_stride.and_then(|rows_per_group| { + bloom_filters.then(|| { + schema + .fields() + .iter() + .map(|field| { + BloomFilterIndexBuilder::new(field.data_type().clone(), rows_per_group) + }) + .collect() + }) + }); + let stripe_stats_builders = data_types.iter().map(ColumnStatsBuilder::new).collect(); Self { writer, columns, + data_types, + row_index_builders, + bloom_filter_builders, + stripe_stats_builders, + compression, + compression_block_size, row_count: 0, } } @@ -93,6 +145,26 @@ impl StripeWriter { /// to required batch size. pub fn encode_batch(&mut self, batch: &RecordBatch) -> Result<()> { // TODO: consider how to handle nested types (including parent nullability) + if let Some(row_index_builders) = self.row_index_builders.as_mut() { + for (array, row_index_builder) in batch.columns().iter().zip(row_index_builders) { + row_index_builder.update(array); + } + } + if let Some(bloom_filter_builders) = self.bloom_filter_builders.as_mut() { + for (array, bloom_filter_builder) in batch.columns().iter().zip(bloom_filter_builders) { + bloom_filter_builder.update(array); + } + } + + for ((array, data_type), stats_builder) in batch + .columns() + .iter() + .zip(self.data_types.iter()) + .zip(self.stripe_stats_builders.iter_mut()) + { + stats_builder.update_array(data_type, array); + } + for (array, encoder) in batch.columns().iter().zip(self.columns.iter_mut()) { encoder.encode_array(array)?; } @@ -119,8 +191,56 @@ impl StripeWriter { column_encodings.extend(child_column_encodings); let column_encodings = column_encodings.iter().map(From::from).collect(); - // Root type won't have any streams + // Root type won't have any streams. Row index streams must be written before data streams. let mut written_streams = vec![]; + let mut index_length = 0; + if let Some(row_index_builders) = self.row_index_builders.as_mut() { + for (index, row_index_builder) in row_index_builders.iter_mut().enumerate() { + let row_index = row_index_builder.finish(); + if row_index.entry.is_empty() { + continue; + } + + let bytes = row_index.encode_to_vec(); + let bytes = compress_stream( + Bytes::from(bytes), + self.compression, + self.compression_block_size, + )?; + let length = bytes.len(); + self.writer.write_all(&bytes).context(IoSnafu)?; + index_length += length as u64; + written_streams.push(WrittenStream { + kind: StreamType::RowIndex, + column: index + 1, + length, + }); + } + } + if let Some(bloom_filter_builders) = self.bloom_filter_builders.as_mut() { + for (index, bloom_filter_builder) in bloom_filter_builders.iter_mut().enumerate() { + let bloom_filter_index = bloom_filter_builder.finish(); + if bloom_filter_index.bloom_filter.is_empty() { + continue; + } + + let bytes = bloom_filter_index.encode_to_vec(); + let bytes = compress_stream( + Bytes::from(bytes), + self.compression, + self.compression_block_size, + )?; + let length = bytes.len(); + self.writer.write_all(&bytes).context(IoSnafu)?; + index_length += length as u64; + written_streams.push(WrittenStream { + kind: StreamType::BloomFilterUtf8, + column: index + 1, + length, + }); + } + } + let mut data_length = 0; for (index, c) in self.columns.iter_mut().enumerate() { // Offset by 1 to account for root of 0 @@ -129,6 +249,7 @@ impl StripeWriter { // Flush the streams to the writer for s in streams { let (kind, bytes) = s.into_parts(); + let bytes = compress_stream(bytes, self.compression, self.compression_block_size)?; let length = bytes.len(); self.writer.write_all(&bytes).context(IoSnafu)?; data_length += length as u64; @@ -148,8 +269,14 @@ impl StripeWriter { }; let footer_bytes = stripe_footer.encode_to_vec(); + let footer_bytes = compress_stream( + Bytes::from(footer_bytes), + self.compression, + self.compression_block_size, + )?; let footer_length = footer_bytes.len() as u64; let row_count = self.row_count; + let column_statistics = self.finish_column_statistics(row_count as u64); self.writer.write_all(&footer_bytes).context(IoSnafu)?; // Reset state for next stripe @@ -157,10 +284,11 @@ impl StripeWriter { Ok(StripeInformation { start_offset, - index_length: 0, + index_length, data_length, footer_length, row_count, + column_statistics, }) } @@ -168,6 +296,23 @@ impl StripeWriter { pub fn finish(self) -> W { self.writer } + + fn finish_column_statistics(&mut self, row_count: u64) -> Vec { + let mut column_statistics = Vec::with_capacity(self.stripe_stats_builders.len() + 1); + column_statistics.push(root_column_statistics(row_count)); + column_statistics.extend( + self.data_types + .iter() + .zip(self.stripe_stats_builders.iter()) + .map(|(data_type, builder)| builder.finish(data_type)), + ); + self.stripe_stats_builders = self + .data_types + .iter() + .map(ColumnStatsBuilder::new) + .collect(); + column_statistics + } } fn create_encoder(field: &FieldRef) -> Box { @@ -183,6 +328,20 @@ fn create_encoder(field: &FieldRef) -> Box { ArrowDataType::Binary => Box::new(BinaryColumnEncoder::new()), ArrowDataType::LargeBinary => Box::new(LargeBinaryColumnEncoder::new()), ArrowDataType::Boolean => Box::new(BooleanColumnEncoder::new()), + ArrowDataType::Decimal128(_, scale) => Box::new(DecimalColumnEncoder::new(*scale)), + ArrowDataType::Date32 => Box::new(DateColumnEncoder::new(ColumnEncoding::DirectV2)), + ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => { + Box::new(TimestampSecondColumnEncoder::new()) + } + ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => { + Box::new(TimestampMillisecondColumnEncoder::new()) + } + ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { + Box::new(TimestampMicrosecondColumnEncoder::new()) + } + ArrowDataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => { + Box::new(TimestampNanosecondColumnEncoder::new()) + } // TODO: support more datatypes _ => unimplemented!("unsupported datatype"), }