diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java
new file mode 100644
index 000000000..967bdaa48
--- /dev/null
+++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.basic.operators;
+
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.basic.types.RecordType;
+import org.apache.wayang.core.plan.wayangplan.UnarySink;
+import org.apache.wayang.core.types.DataSetType;
+
+import java.util.Properties;
+
+/**
+ * {@link UnarySink} that writes {@link Record}s to a database table.
+ */
+public class TableSink extends UnarySink<Record> {
+
+    private final String tableName;
+
+    private String[] columnNames;
+
+    private final Properties props;
+
+    private String mode;
+
+    /**
+     * Creates a new instance.
+     *
+     * @param props       database connection properties
+     * @param mode        write mode, e.g. {@code "overwrite"} or {@code "append"}
+     * @param tableName   name of the table to be written
+     * @param columnNames names of the columns in the table
+     */
+    public TableSink(Properties props, String mode, String tableName, String... columnNames) {
+        this(props, mode, tableName, columnNames, createOutputDataSetType(columnNames));
+    }
+
+    public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType<Record> type) {
+        super(type);
+        this.tableName = tableName;
+        this.columnNames = columnNames;
+        this.props = props;
+        this.mode = mode;
+    }
+
+    /**
+     * Copies an instance (exclusive of broadcasts).
+     *
+     * @param that the instance that should be copied
+     */
+    public TableSink(TableSink that) {
+        super(that);
+        this.tableName = that.getTableName();
+        this.columnNames = that.getColumnNames();
+        this.props = that.getProperties();
+        this.mode = that.getMode();
+    }
+
+    public String getTableName() {
+        return this.tableName;
+    }
+
+    protected void setColumnNames(String[] columnNames) {
+        this.columnNames = columnNames;
+    }
+
+    public String[] getColumnNames() {
+        return this.columnNames;
+    }
+
+    public Properties getProperties() {
+        return this.props;
+    }
+
+    public String getMode() {
+        return this.mode;
+    }
+
+    public void setMode(String mode) {
+        this.mode = mode;
+    }
+
+    /**
+     * Constructs an appropriate output {@link DataSetType} for the given column names.
+     *
+     * @param columnNames the column names, or {@code null} or an empty array if unknown
+     * @return the output {@link DataSetType}, which is based on a {@link RecordType} unless {@code columnNames}
+     *         is {@code null} or empty
+     */
+    private static DataSetType<Record> createOutputDataSetType(String[] columnNames) {
+        return columnNames == null || columnNames.length == 0 ?
+                DataSetType.createDefault(Record.class) :
+                DataSetType.createDefault(new RecordType(columnNames));
+    }
+}
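For orientation, here is a minimal sketch of how this sink could sit at the end of a Wayang plan. The connection settings, table name, and records are made up for illustration, and the sketch assumes a reachable PostgreSQL instance as well as a registered mapping from `TableSink` to a platform operator such as the `JavaTableSink` added below:

    import org.apache.wayang.basic.data.Record;
    import org.apache.wayang.basic.operators.CollectionSource;
    import org.apache.wayang.basic.operators.TableSink;
    import org.apache.wayang.core.api.WayangContext;
    import org.apache.wayang.core.plan.wayangplan.WayangPlan;
    import org.apache.wayang.core.types.DataSetType;
    import org.apache.wayang.java.Java;

    import java.util.Arrays;
    import java.util.Properties;

    public class TableSinkExample {
        public static void main(String[] args) {
            // Hypothetical connection settings.
            Properties props = new Properties();
            props.setProperty("url", "jdbc:postgresql://localhost:5432/mydb");
            props.setProperty("user", "postgres");
            props.setProperty("password", "secret");
            props.setProperty("driver", "org.postgresql.Driver");

            // A small in-memory source of Records feeding the sink.
            CollectionSource<Record> source = new CollectionSource<>(
                    Arrays.asList(new Record(1, "Alice"), new Record(2, "Bob")),
                    DataSetType.createDefault(Record.class));
            TableSink sink = new TableSink(props, "overwrite", "people", "id", "name");
            source.connectTo(0, sink, 0);

            // Execute on the Java platform.
            new WayangContext().register(Java.basicPlugin()).execute(new WayangPlan(sink));
        }
    }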
diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml
index 9c58a78fb..b1ffca4c2 100644
--- a/wayang-platforms/wayang-java/pom.xml
+++ b/wayang-platforms/wayang-java/pom.xml
@@ -78,6 +78,12 @@
             <artifactId>log4j-slf4j-impl</artifactId>
             <version>2.20.0</version>
         </dependency>
+        <dependency>
+            <groupId>org.postgresql</groupId>
+            <artifactId>postgresql</artifactId>
+            <version>42.7.2</version>
+            <scope>test</scope>
+        </dependency>
         <dependency>
             <groupId>org.mockito</groupId>
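Note that the PostgreSQL driver is added with `test` scope, so it is only on the classpath for the test suite below. At runtime, `JavaTableSink` loads whatever class the `driver` connection property names, so applications are expected to supply their own JDBC driver dependency.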
diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java
new file mode 100644
index 000000000..8c4949689
--- /dev/null
+++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.java.operators;
+
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.basic.operators.TableSink;
+import org.apache.wayang.core.api.exception.WayangException;
+import org.apache.wayang.core.optimizer.OptimizationContext;
+import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
+import org.apache.wayang.core.platform.ChannelDescriptor;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.wayang.java.channels.CollectionChannel;
+import org.apache.wayang.java.channels.JavaChannelInstance;
+import org.apache.wayang.java.channels.StreamChannel;
+import org.apache.wayang.java.execution.JavaExecutor;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Properties;
+
+/**
+ * {@link TableSink} implementation for the Java platform that writes {@link Record}s to a database table via JDBC.
+ */
+public class JavaTableSink extends TableSink implements JavaExecutionOperator {
+
+    /**
+     * Binds {@code value} to the given 1-based parameter {@code index}, dispatching on its runtime type.
+     */
+    private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException {
+        if (value == null) {
+            ps.setNull(index, java.sql.Types.NULL);
+        } else if (value instanceof Integer) {
+            ps.setInt(index, (Integer) value);
+        } else if (value instanceof Long) {
+            ps.setLong(index, (Long) value);
+        } else if (value instanceof Double) {
+            ps.setDouble(index, (Double) value);
+        } else if (value instanceof Float) {
+            ps.setFloat(index, (Float) value);
+        } else if (value instanceof Boolean) {
+            ps.setBoolean(index, (Boolean) value);
+        } else {
+            ps.setString(index, value.toString());
+        }
+    }
+
+    public JavaTableSink(Properties props, String mode, String tableName) {
+        // A null column array means the column names are derived from the first record in evaluate().
+        this(props, mode, tableName, (String[]) null);
+    }
+
+    public JavaTableSink(Properties props, String mode, String tableName, String... columnNames) {
+        super(props, mode, tableName, columnNames);
+    }
+
+    public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType<Record> type) {
+        super(props, mode, tableName, columnNames, type);
+    }
+
+    public JavaTableSink(TableSink that) {
+        super(that);
+    }
+
+    @Override
+    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
+            ChannelInstance[] inputs,
+            ChannelInstance[] outputs,
+            JavaExecutor javaExecutor,
+            OptimizationContext.OperatorContext operatorContext) {
+        assert inputs.length == 1;
+        assert outputs.length == 0;
+        JavaChannelInstance input = (JavaChannelInstance) inputs[0];
+
+        // The stream is consumed through an Iterator so that we can read the first element w/o consuming the entire stream.
+        Iterator<Record> recordIterator = input.<Record>provideStream().iterator();
+
+        // Nothing to write if the input is empty.
+        if (!recordIterator.hasNext()) {
+            return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
+        }
+
+        // We read the first element to derive the Record schema.
+        Record schemaRecord = recordIterator.next();
+
+        // We assume that all records have the same length and only check the first record.
+        int recordLength = schemaRecord.size();
+        if (this.getColumnNames() != null) {
+            assert recordLength == this.getColumnNames().length;
+        } else {
+            String[] columnNames = new String[recordLength];
+            for (int i = 0; i < recordLength; i++) {
+                columnNames[i] = "c_" + i;
+            }
+            this.setColumnNames(columnNames);
+        }
+
+        // TODO: Check if we need this property.
+        this.getProperties().setProperty("streamingBatchInsert", "True");
+
+        try {
+            Class.forName(this.getProperties().getProperty("driver"));
+        } catch (ClassNotFoundException e) {
+            throw new WayangException("Could not load the JDBC driver specified in the 'driver' property.", e);
+        }
+
+        try (Connection conn = DriverManager.getConnection(this.getProperties().getProperty("url"), this.getProperties())) {
+            conn.setAutoCommit(false);
+
+            // Note: table and column names are spliced into the SQL verbatim; they are assumed to be trusted identifiers.
+            try (Statement stmt = conn.createStatement()) {
+                // Drop the existing table if the mode is 'overwrite'.
+                if (this.getMode().equals("overwrite")) {
+                    stmt.execute("DROP TABLE IF EXISTS " + this.getTableName());
+                }
+
+                // Create a new table if the specified table name does not exist yet.
+                StringBuilder createSql = new StringBuilder();
+                createSql.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" (");
+                String separator = "";
+                for (int i = 0; i < recordLength; i++) {
+                    createSql.append(separator).append(this.getColumnNames()[i]).append(" VARCHAR(255)");
+                    separator = ", ";
+                }
+                createSql.append(")");
+                stmt.execute(createSql.toString());
+            }
+
+            // Create a prepared statement to insert the values from the recordIterator.
+            StringBuilder insertSql = new StringBuilder();
+            insertSql.append("INSERT INTO ").append(this.getTableName()).append(" (");
+            String separator = "";
+            for (int i = 0; i < recordLength; i++) {
+                insertSql.append(separator).append(this.getColumnNames()[i]);
+                separator = ", ";
+            }
+            insertSql.append(") VALUES (");
+            separator = "";
+            for (int i = 0; i < recordLength; i++) {
+                insertSql.append(separator).append("?");
+                separator = ", ";
+            }
+            insertSql.append(")");
+
+            try (PreparedStatement ps = conn.prepareStatement(insertSql.toString())) {
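+                // For example, for table "test_table" with columns (id, name, value), the statements
+                // built above read:
+                //   CREATE TABLE IF NOT EXISTS test_table (id VARCHAR(255), name VARCHAR(255), value VARCHAR(255))
+                //   INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)
+                // All columns are created as VARCHAR(255); setRecordValue decides how values are bound.
+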
+                // The schema Record has to be pushed to the database, too.
+                for (int i = 0; i < recordLength; i++) {
+                    setRecordValue(ps, i + 1, schemaRecord.getField(i));
+                }
+                ps.addBatch();
+
+                // Iterate through all remaining records and add them to the prepared statement.
+                recordIterator.forEachRemaining(r -> {
+                    try {
+                        for (int i = 0; i < recordLength; i++) {
+                            setRecordValue(ps, i + 1, r.getField(i));
+                        }
+                        ps.addBatch();
+                    } catch (SQLException e) {
+                        throw new WayangException("Could not add a record to the insert batch.", e);
+                    }
+                });
+
+                ps.executeBatch();
+                conn.commit();
+            }
+        } catch (SQLException e) {
+            throw new WayangException("Writing to table " + this.getTableName() + " failed.", e);
+        }
+
+        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
+    }
+
+    @Override
+    public String getLoadProfileEstimatorConfigurationKey() {
+        return "rheem.java.tablesink.load";
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
+        return Arrays.asList(CollectionChannel.DESCRIPTOR, StreamChannel.DESCRIPTOR);
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
+        throw new UnsupportedOperationException("This operator has no outputs.");
+    }
+}
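The test below is effectively an integration test: it expects a PostgreSQL server at `localhost:5432` with a database `default`, user `postgres`, and password `123456` (see the constants in the test class). One way to provide that locally, assuming Docker and the official `postgres` image, is `docker run -d -p 5432:5432 -e POSTGRES_PASSWORD=123456 -e POSTGRES_DB=default postgres`.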
diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java
new file mode 100644
index 000000000..565cd9e20
--- /dev/null
+++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.java.operators;
+
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.Job;
+import org.apache.wayang.core.optimizer.OptimizationContext;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.java.channels.StreamChannel;
+import org.apache.wayang.java.execution.JavaExecutor;
+import org.apache.wayang.java.platform.JavaPlatform;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.Statement;
+import java.util.Properties;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Test suite for {@link JavaTableSink}.
+ */
+class JavaTableSinkTest extends JavaExecutionOperatorTestBase {
+
+    private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default";
+    private static final String USERNAME = "postgres";
+    private static final String PASSWORD = "123456";
+    private static final String TABLE_NAME = "test_table";
+
+    private Connection connection;
+
+    @BeforeEach
+    void setupTest() throws Exception {
+        // Load PostgreSQL driver
+        Class.forName("org.postgresql.Driver");
+
+        // Connect to the database; the sink creates (or overwrites) the test table itself.
+        connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD);
+    }
+
+    @AfterEach
+    void teardownTest() throws Exception {
+        // Drop the test table and close the connection
+        if (connection != null && !connection.isClosed()) {
+            try (Statement stmt = connection.createStatement()) {
+                stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME);
+            }
+            connection.close();
+        }
+    }
+
+    @Test
+    void testWritingToPostgres() throws Exception {
+        Configuration configuration = new Configuration();
+
+        // Configure database properties
+        Properties dbProps = new Properties();
+        dbProps.setProperty("url", JDBC_URL);
+        dbProps.setProperty("user", USERNAME);
+        dbProps.setProperty("password", PASSWORD);
+        dbProps.setProperty("driver", "org.postgresql.Driver");
+
+        JavaTableSink sink = new JavaTableSink(dbProps, "overwrite", TABLE_NAME,
+                new String[]{"id", "name", "value"},
+                DataSetType.createDefault(Record.class));
+
+        Job job = mock(Job.class);
+        when(job.getConfiguration()).thenReturn(configuration);
+        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);
+
+        // Create input channel with test data
+        StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
+                .createChannel(mock(OutputSlot.class), configuration)
+                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);
+
+        // Create test records
+        Record record1 = new Record(1, "Alice", 100.5);
+        Record record2 = new Record(2, "Bob", 200.75);
+        Record record3 = new Record(3, "Charlie", 300.25);
+
+        inputChannelInstance.accept(Stream.of(record1, record2, record3));
+
+        // Execute the sink
+        evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]);
+
+        // Verify data was written to database
+        try (Statement stmt = connection.createStatement();
+             ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) {
+            rs.next();
+            assertEquals(3, rs.getInt(1), "Should have written 3 records");
+        }
+
+        // Verify specific record; id is stored as VARCHAR (see JavaTableSink), so compare against a string literal.
+        try (Statement stmt = connection.createStatement();
+             ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = '1'")) {
+            rs.next();
+            assertEquals("Alice", rs.getString("name"));
+            assertEquals(100.5, rs.getDouble("value"), 0.01);
+        }
+    }
+}
\ No newline at end of file
diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml
index 1e89fd15e..3863f952c 100644
--- a/wayang-platforms/wayang-spark/pom.xml
+++ b/wayang-platforms/wayang-spark/pom.xml
@@ -121,5 +121,12 @@
             <version>4.8</version>
         </dependency>
+        <dependency>
+            <groupId>org.postgresql</groupId>
+            <artifactId>postgresql</artifactId>
+            <version>42.7.2</version>
+            <scope>compile</scope>
+        </dependency>
     </dependencies>
 </project>
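In contrast to the `test` scope used for the Java platform above, the Spark platform declares the driver with `compile` scope: `SparkTableSink` writes through Spark's JDBC `DataFrameWriter`, which needs the driver class at runtime on the driver and the executors, not only in tests.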
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java
new file mode 100644
index 000000000..ae25f3cb6
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.wayang.spark.operators;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.basic.operators.TableSink;
+import org.apache.wayang.core.api.exception.WayangException;
+import org.apache.wayang.core.optimizer.OptimizationContext;
+import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
+import org.apache.wayang.core.platform.ChannelDescriptor;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.wayang.spark.channels.RddChannel;
+import org.apache.wayang.spark.execution.SparkExecutor;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Properties;
+
+/**
+ * {@link TableSink} implementation for the Spark platform that writes {@link Record}s via Spark's JDBC
+ * {@link Dataset} writer.
+ */
+public class SparkTableSink extends TableSink implements SparkExecutionOperator {
+
+    /**
+     * Spark-specific view of the write mode; see {@link #setMode(String)}.
+     */
+    private SaveMode mode;
+
+    /**
+     * Maps a sample value to a Spark SQL data type; {@code null} and unrecognized types fall back to
+     * {@link DataTypes#StringType}.
+     */
+    private org.apache.spark.sql.types.DataType getDataType(Object value) {
+        if (value == null) return DataTypes.StringType;
+        if (value instanceof Integer) return DataTypes.IntegerType;
+        if (value instanceof Long) return DataTypes.LongType;
+        if (value instanceof Double) return DataTypes.DoubleType;
+        if (value instanceof Float) return DataTypes.FloatType;
+        if (value instanceof Boolean) return DataTypes.BooleanType;
+        return DataTypes.StringType;
+    }
+
+    public SparkTableSink(Properties props, String mode, String tableName, String... columnNames) {
+        super(props, mode, tableName, columnNames);
+        this.setMode(mode);
+    }
+
+    public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType<Record> type) {
+        super(props, mode, tableName, columnNames, type);
+        this.setMode(mode);
+    }
+
+    public SparkTableSink(TableSink that) {
+        super(that);
+        this.setMode(that.getMode());
+    }
+
+    @Override
+    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
+            ChannelInstance[] inputs,
+            ChannelInstance[] outputs,
+            SparkExecutor sparkExecutor,
+            OptimizationContext.OperatorContext operatorContext) {
+        assert inputs.length == 1;
+        assert outputs.length == 0;
+
+        JavaRDD<Record> recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd();
+
+        // Cache the RDD: it is consumed multiple times below (isEmpty, first, map).
+        recordRDD.cache();
+
+        // Nothing to write if the RDD is empty.
+        boolean isEmpty = recordRDD.isEmpty();
+
+        if (!isEmpty) {
+            // Derive the schema from the first record; all records are assumed to have the same layout.
+            Record firstRecord = recordRDD.first();
+            int recordLength = firstRecord.size();
+
+            JavaRDD<Row> rowRDD = recordRDD.map(record -> {
+                Object[] values = record.getValues();
+                return RowFactory.create(values);
+            });
+
+            StructField[] fields = new StructField[recordLength];
+            for (int i = 0; i < recordLength; i++) {
+                Object value = firstRecord.getField(i);
+                org.apache.spark.sql.types.DataType dataType = getDataType(value);
+                fields[i] = new StructField(this.getColumnNames()[i], dataType, true, Metadata.empty());
+            }
+            StructType schema = new StructType(fields);
+
+            SQLContext sqlContext = new SQLContext(sparkExecutor.sc.sc());
+            Dataset<Row> dataSet = sqlContext.createDataFrame(rowRDD, schema);
+            this.getProperties().setProperty("batchSize", "250000");
+            dataSet.write().mode(this.mode).jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties());
+        } else {
+            System.out.println("RDD is empty, nothing to write!");
+        }
+        return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext);
+    }
+
+    @Override
+    public void setMode(String mode) {
+        if (mode == null) {
+            throw new WayangException("Unspecified write mode for SparkTableSink.");
+        }
+        switch (mode) {
+            case "append":
+                this.mode = SaveMode.Append;
+                break;
+            case "overwrite":
+                this.mode = SaveMode.Overwrite;
+                break;
+            case "errorIfExists":
+                this.mode = SaveMode.ErrorIfExists;
+                break;
+            case "ignore":
+                this.mode = SaveMode.Ignore;
+                break;
+            default:
+                throw new WayangException(String.format("Specified write mode for SparkTableSink does not exist: %s", mode));
+        }
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
+        return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR);
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
+        throw new UnsupportedOperationException("This operator has no outputs.");
+    }
+
+    @Override
+    public boolean containsAction() {
+        return true;
+    }
+
+    @Override
+    public String getLoadProfileEstimatorConfigurationKey() {
+        return "rheem.spark.tablesink.load";
+    }
+}
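One caveat that follows from the `getDataType` fallback above: the schema is inferred from the first record only, so a leading `null` field degrades that column to `StringType` even if later records carry typed values. A hypothetical illustration:

    // Schema inference inspects only the first record:
    //   first record:  new Record(null, "Alice")  -> column 0 inferred as StringType
    //   later record:  new Record(2, "Bob")       -> Integer value meets a StringType column,
    //                                                which would likely fail at write time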
diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java
new file mode 100644
index 000000000..78665e90c
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.wayang.spark.operators;
+
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.spark.channels.RddChannel;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.Statement;
+import java.util.Arrays;
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Test suite for {@link SparkTableSink}.
+ */
+class SparkTableSinkTest extends SparkOperatorTestBase {
+
+    private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default";
+    private static final String USERNAME = "postgres";
+    private static final String PASSWORD = "123456";
+    private static final String TABLE_NAME = "spark_test_table";
+
+    private Connection connection;
+
+    @BeforeEach
+    void setupTest() throws Exception {
+        // Load PostgreSQL driver
+        Class.forName("org.postgresql.Driver");
+
+        // Connect to the database; the sink creates (or overwrites) the test table itself.
+        connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD);
+    }
+
+    @AfterEach
+    void teardownTest() throws Exception {
+        // Drop the test table and close the connection
+        if (connection != null && !connection.isClosed()) {
+            try (Statement stmt = connection.createStatement()) {
+                stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME);
+            }
+            connection.close();
+        }
+    }
+
+    @Test
+    void testWritingToPostgres() throws Exception {
+        // Configure database properties
+        Properties dbProps = new Properties();
+        dbProps.setProperty("url", JDBC_URL);
+        dbProps.setProperty("user", USERNAME);
+        dbProps.setProperty("password", PASSWORD);
+        dbProps.setProperty("driver", "org.postgresql.Driver");
+
+        SparkTableSink sink = new SparkTableSink(dbProps, "overwrite", TABLE_NAME,
+                new String[]{"id", "name", "value"},
+                DataSetType.createDefault(Record.class));
+
+        // Create input RDD with test data
+        Record record1 = new Record(1, "Alice", 100.5);
+        Record record2 = new Record(2, "Bob", 200.75);
+        Record record3 = new Record(3, "Charlie", 300.25);
+
+        RddChannel.Instance inputChannelInstance = this.createRddChannelInstance(
+                Arrays.asList(record1, record2, record3)
+        );
+
+        // Execute the sink
+        evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]);
+
+        // Verify data was written to database
+        try (Statement stmt = connection.createStatement();
+             ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) {
+            rs.next();
+            assertEquals(3, rs.getInt(1), "Should have written 3 records");
+        }
+
+        // Verify specific record
+        try (Statement stmt = connection.createStatement();
+             ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = 1")) {
+            rs.next();
+            assertEquals("Alice", rs.getString("name"));
+            assertEquals(100.5, rs.getDouble("value"), 0.01);
+        }
+    }
+}
\ No newline at end of file