diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java index b32f84f57..6066bb55a 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java @@ -19,6 +19,7 @@ package org.apache.wayang.basic.util; import org.apache.wayang.basic.data.Record; +import java.math.BigDecimal; import java.sql.Date; import java.sql.Timestamp; import java.time.LocalDate; @@ -42,10 +43,15 @@ public class SqlTypeUtils { defaultMap.put(int.class, "INT"); defaultMap.put(Long.class, "BIGINT"); defaultMap.put(long.class, "BIGINT"); + defaultMap.put(Short.class, "SMALLINT"); + defaultMap.put(short.class, "SMALLINT"); defaultMap.put(Double.class, "DOUBLE"); defaultMap.put(double.class, "DOUBLE"); defaultMap.put(Float.class, "FLOAT"); defaultMap.put(float.class, "FLOAT"); + // Explicit precision/scale (matches Spark's DecimalType default); a bare + // NUMERIC defaults to scale 0 on most engines and would truncate decimals. + defaultMap.put(BigDecimal.class, "NUMERIC(38,18)"); defaultMap.put(Boolean.class, "BOOLEAN"); defaultMap.put(boolean.class, "BOOLEAN"); defaultMap.put(String.class, "VARCHAR(255)"); diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java index dfa5a8e5f..378a87816 100644 --- a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java @@ -51,10 +51,21 @@ public void testGetSqlTypeDefault() { assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, DatabaseProduct.UNKNOWN)); assertEquals("INT", SqlTypeUtils.getSqlType(int.class, DatabaseProduct.UNKNOWN)); assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, DatabaseProduct.UNKNOWN)); + assertEquals("BIGINT", SqlTypeUtils.getSqlType(long.class, DatabaseProduct.UNKNOWN)); + assertEquals("SMALLINT", SqlTypeUtils.getSqlType(Short.class, DatabaseProduct.UNKNOWN)); + assertEquals("SMALLINT", SqlTypeUtils.getSqlType(short.class, DatabaseProduct.UNKNOWN)); assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, DatabaseProduct.UNKNOWN)); + assertEquals("DOUBLE", SqlTypeUtils.getSqlType(double.class, DatabaseProduct.UNKNOWN)); + assertEquals("FLOAT", SqlTypeUtils.getSqlType(Float.class, DatabaseProduct.UNKNOWN)); + assertEquals("FLOAT", SqlTypeUtils.getSqlType(float.class, DatabaseProduct.UNKNOWN)); + assertEquals("NUMERIC(38,18)", SqlTypeUtils.getSqlType(java.math.BigDecimal.class, DatabaseProduct.UNKNOWN)); + assertEquals("BOOLEAN", SqlTypeUtils.getSqlType(Boolean.class, DatabaseProduct.UNKNOWN)); + assertEquals("BOOLEAN", SqlTypeUtils.getSqlType(boolean.class, DatabaseProduct.UNKNOWN)); assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, DatabaseProduct.UNKNOWN)); assertEquals("DATE", SqlTypeUtils.getSqlType(java.sql.Date.class, DatabaseProduct.UNKNOWN)); + assertEquals("DATE", SqlTypeUtils.getSqlType(java.time.LocalDate.class, DatabaseProduct.UNKNOWN)); assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(java.sql.Timestamp.class, DatabaseProduct.UNKNOWN)); + assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(java.time.LocalDateTime.class, DatabaseProduct.UNKNOWN)); } @Test @@ -63,6 +74,10 @@ public void testPostgresqlOverrides() { assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, DatabaseProduct.POSTGRESQL)); assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, DatabaseProduct.POSTGRESQL)); assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, DatabaseProduct.POSTGRESQL)); + // Short and BigDecimal are not overridden for PostgreSQL, they inherit from the default map. + assertEquals("SMALLINT", SqlTypeUtils.getSqlType(Short.class, DatabaseProduct.POSTGRESQL)); + assertEquals("SMALLINT", SqlTypeUtils.getSqlType(short.class, DatabaseProduct.POSTGRESQL)); + assertEquals("NUMERIC(38,18)", SqlTypeUtils.getSqlType(java.math.BigDecimal.class, DatabaseProduct.POSTGRESQL)); } @Test diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java index 7e12b1403..5d5086897 100644 --- a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -59,6 +59,10 @@ private void setRecordValue(PreparedStatement ps, int index, Object value) throw ps.setDouble(index, (Double) value); } else if (value instanceof Float) { ps.setFloat(index, (Float) value); + } else if (value instanceof Short) { + ps.setShort(index, (Short) value); + } else if (value instanceof java.math.BigDecimal) { + ps.setBigDecimal(index, (java.math.BigDecimal) value); } else if (value instanceof Boolean) { ps.setBoolean(index, (Boolean) value); } else if (value instanceof java.sql.Date) { diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java index d56b8d838..f8e1443a5 100644 --- a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -32,10 +32,13 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.Statement; +import java.sql.Types; import java.util.Properties; import java.util.stream.Stream; @@ -305,6 +308,60 @@ void testSupportedTypes() throws Exception { } } + @Test + void testWritingAllSupportedTypesToDatabase() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "int_col", "long_col", "short_col", "double_col", "float_col", + "decimal_col", "bool_col", "string_col" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + BigDecimal decimalValue = new BigDecimal("12.345"); + input.accept(Stream.of(new Record( + 42, 9_000_000_000L, (short) 7, 3.14d, 1.5f, decimalValue, true, "hello"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) { + rs.next(); + + // values written through the Java sink are read back from the database unchanged + assertEquals(42, rs.getInt("int_col")); + assertEquals(9_000_000_000L, rs.getLong("long_col")); + assertEquals((short) 7, rs.getShort("short_col")); + assertEquals(3.14d, rs.getDouble("double_col"), 1e-9); + assertEquals(1.5f, rs.getFloat("float_col"), 1e-6f); + assertEquals(0, decimalValue.compareTo(rs.getBigDecimal("decimal_col"))); + assertTrue(rs.getBoolean("bool_col")); + assertEquals("hello", rs.getString("string_col")); + + // the two new types were created with the right SQL column types + ResultSetMetaData md = rs.getMetaData(); + + int shortIdx = rs.findColumn("short_col"); + assertEquals(Types.SMALLINT, md.getColumnType(shortIdx), "Short -> SMALLINT"); + + int decimalIdx = rs.findColumn("decimal_col"); + assertEquals(Types.NUMERIC, md.getColumnType(decimalIdx), "BigDecimal -> NUMERIC"); + assertEquals(38, md.getPrecision(decimalIdx), "BigDecimal precision must be 38"); + assertEquals(18, md.getScale(decimalIdx), "BigDecimal scale must be 18"); + } + } + public static class TestPojo { private int id; private String name; diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java index c2bcc143a..466cca381 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -141,15 +141,21 @@ public Tuple, Collection> eval return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } - private org.apache.spark.sql.types.DataType getSparkDataType(Class cls) { + // Package-private (not private) so SparkTableSinkTest can unit-test this + // Java-class -> Spark DataType mapping directly; it uses no instance state. + org.apache.spark.sql.types.DataType getSparkDataType(Class cls) { if (cls == Integer.class || cls == int.class) return DataTypes.IntegerType; if (cls == Long.class || cls == long.class) return DataTypes.LongType; + if (cls == Short.class || cls == short.class) + return DataTypes.ShortType; if (cls == Double.class || cls == double.class) return DataTypes.DoubleType; if (cls == Float.class || cls == float.class) return DataTypes.FloatType; + if (cls == java.math.BigDecimal.class) + return DataTypes.createDecimalType(38, 18); if (cls == Boolean.class || cls == boolean.class) return DataTypes.BooleanType; if (cls == java.sql.Date.class || cls == java.time.LocalDate.class) diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java index bd9ca40de..6b68ff1e3 100644 --- a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -21,14 +21,20 @@ import org.apache.wayang.core.platform.ChannelInstance; import org.apache.wayang.core.types.DataSetType; import org.apache.wayang.spark.channels.RddChannel; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.Statement; +import java.sql.Types; import java.util.Arrays; import java.util.Properties; @@ -239,6 +245,109 @@ void testSupportedTypes() throws Exception { } } + @Test + void testWritingAllSupportedTypesToDatabase() throws Exception { + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + String[] columns = { "int_col", "long_col", "short_col", "double_col", "float_col", + "decimal_col", "bool_col", "string_col" }; + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, columns, + DataSetType.createDefault(Record.class)); + + BigDecimal decimalValue = new BigDecimal("12.345"); + + Record record = new Record( + 42, // int_col + 9_000_000_000L, // long_col + (short) 7, // short_col + 3.14d, // double_col + 1.5f, // float_col + decimalValue, // decimal_col + true, // bool_col + "hello"); // string_col + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(record)); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM \"" + TABLE_NAME + "\"")) { + rs.next(); + + // values written through Spark are read back from the database unchanged + assertEquals(42, rs.getInt("int_col")); + assertEquals(9_000_000_000L, rs.getLong("long_col")); + assertEquals((short) 7, rs.getShort("short_col")); + assertEquals(3.14d, rs.getDouble("double_col"), 1e-9); + assertEquals(1.5f, rs.getFloat("float_col"), 1e-6f); + assertEquals(0, decimalValue.compareTo(rs.getBigDecimal("decimal_col"))); + assertTrue(rs.getBoolean("bool_col")); + assertEquals("hello", rs.getString("string_col")); + + // the two new types were created with the right SQL column types in H2 + ResultSetMetaData md = rs.getMetaData(); + + int shortIdx = rs.findColumn("short_col"); + assertEquals(Types.SMALLINT, md.getColumnType(shortIdx), "Short -> SMALLINT"); + + int decimalIdx = rs.findColumn("decimal_col"); + assertEquals(Types.NUMERIC, md.getColumnType(decimalIdx), "BigDecimal -> NUMERIC"); + assertEquals(38, md.getPrecision(decimalIdx), "BigDecimal precision must be 38"); + assertEquals(18, md.getScale(decimalIdx), "BigDecimal scale must be 18"); + } + } + + // Type-mapping checks (no database). + // Unlike the database tests above, these tests call getSparkDataType(...) directly to pin the Java-class -> Spark DataType + // contract used to build the write schema. + + // Throwaway sink instance, getSparkDataType uses no instance state. + private SparkTableSink mappingProbe() { + return new SparkTableSink<>(new Properties(), "overwrite", "probe", + new String[] { "c" }, DataSetType.createDefault(Record.class)); + } + + @Test + void getSparkDataType_mapsBigDecimalToDecimal38_18() { + DataType type = mappingProbe().getSparkDataType(BigDecimal.class); + assertTrue(type instanceof DecimalType, "BigDecimal must map to a DecimalType"); + DecimalType decimal = (DecimalType) type; + assertEquals(38, decimal.precision(), "BigDecimal precision must be 38"); + assertEquals(18, decimal.scale(), "BigDecimal scale must be 18"); + } + + @Test + void getSparkDataType_mapsAllSupportedTypes() { + SparkTableSink s = mappingProbe(); + assertEquals(DataTypes.IntegerType, s.getSparkDataType(Integer.class)); + assertEquals(DataTypes.IntegerType, s.getSparkDataType(int.class)); + assertEquals(DataTypes.LongType, s.getSparkDataType(Long.class)); + assertEquals(DataTypes.LongType, s.getSparkDataType(long.class)); + assertEquals(DataTypes.ShortType, s.getSparkDataType(Short.class)); + assertEquals(DataTypes.ShortType, s.getSparkDataType(short.class)); + assertEquals(DataTypes.DoubleType, s.getSparkDataType(Double.class)); + assertEquals(DataTypes.DoubleType, s.getSparkDataType(double.class)); + assertEquals(DataTypes.FloatType, s.getSparkDataType(Float.class)); + assertEquals(DataTypes.FloatType, s.getSparkDataType(float.class)); + assertEquals(DataTypes.BooleanType, s.getSparkDataType(Boolean.class)); + assertEquals(DataTypes.BooleanType, s.getSparkDataType(boolean.class)); + assertEquals(DataTypes.DateType, s.getSparkDataType(java.sql.Date.class)); + assertEquals(DataTypes.DateType, s.getSparkDataType(java.time.LocalDate.class)); + assertEquals(DataTypes.TimestampType, s.getSparkDataType(java.sql.Timestamp.class)); + assertEquals(DataTypes.TimestampType, s.getSparkDataType(java.time.LocalDateTime.class)); + } + + @Test + void getSparkDataType_fallsBackToStringForUnsupportedTypes() { + SparkTableSink s = mappingProbe(); + assertEquals(DataTypes.StringType, s.getSparkDataType(Object.class)); + assertEquals(DataTypes.StringType, s.getSparkDataType(java.util.UUID.class)); + } + public static class TestPojo implements java.io.Serializable { private int id; private String name;