From 95088b29fbf885cfcbaf7c0c2c0c1b0491e0efb1 Mon Sep 17 00:00:00 2001
From: luxu1-ms
Date: Wed, 20 Oct 2021 16:19:06 -0700
Subject: [PATCH 1/2] skip sql cols if not in df when strictSchema check disabled

---
 .../jdbc/spark/utils/BulkCopyUtils.scala      | 109 ++++------
 1 file changed, 34 insertions(+), 75 deletions(-)

diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
index 8b06e8e..1bd5894 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
 import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}
 
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.Breaks.{breakable,break}
 
 /**
  * BulkCopyUtils Object implements common utility function used by both datapool and
@@ -179,47 +180,6 @@ object BulkCopyUtils extends Logging {
         conn.createStatement.executeQuery(queryStr)
     }
 
-    /**
-     * getComputedCols
-     * utility function to get computed columns.
-     * Use computed column names to exclude computed column when matching schema.
-     */
-    private[spark] def getComputedCols(
-        conn: Connection,
-        table: String): List[String] = {
-        val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
-        val computedColRs = conn.createStatement.executeQuery(queryStr)
-        val computedCols = ListBuffer[String]()
-        while (computedColRs.next()) {
-            val colName = computedColRs.getString("name")
-            computedCols.append(colName)
-        }
-        computedCols.toList
-    }
-
-    /**
-     * dfComputedColCount
-     * utility function to get number of computed columns in dataframe.
-     * Use number of computed columns in dataframe to get number of non computed column in df,
-     * and compare with the number of non computed column in sql table
-     */
-    private[spark] def dfComputedColCount(
-        dfColNames: List[String],
-        computedCols: List[String],
-        dfColCaseMap: Map[String, String],
-        isCaseSensitive: Boolean): Int ={
-        var dfComputedColCt = 0
-        for (j <- 0 to computedCols.length-1){
-            if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-                !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-                && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
-                dfComputedColCt += 1
-            }
-        }
-        dfComputedColCt
-    }
-
-
     /**
      * getColMetadataMap
      * Utility function convert result set meta data to array.
@@ -303,44 +263,43 @@ object BulkCopyUtils extends Logging {
         val dfCols = df.schema
 
         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable)
-
         val prefix = "Spark Dataframe and SQL Server table have differing"
 
-        if (computedCols.length == 0) {
-            assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        } else if (strictSchemaCheck) {
-            val dfColNames = df.schema.fieldNames.toList
-            val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
-            // if df has computed column(s), check column length using non computed column in df and table.
-            // non computed column number in df: dfCols.length - dfComputedColCt
-            // non computed column number in table: tableCols.length - computedCols.length
-            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
-                s"${prefix} numbers of columns")
-        }
-
+        assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
+            s"${prefix} numbers of columns")
 
-        val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
-        var nonAutoColIndex = 0
+        val result = new ArrayBuffer[ColumnMetadata]()
 
         for (i <- 0 to tableCols.length-1) {
-            val tableColName = tableCols(i).name
-            var dfFieldIndex = -1
-            // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-            if (computedCols.contains(tableColName)) {
-                logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
-            }else{
+            breakable {
+                val tableColName = tableCols(i).name
+                var dfFieldIndex = 0
                 var dfColName:String = ""
                 if (isCaseSensitive) {
-                    dfFieldIndex = dfCols.fieldIndex(tableColName)
+                    if (!dfCols.fieldNames.contains(tableColName)) {
+                        if (strictSchemaCheck) {
+                            throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
+                        } else {
+                            // when strictSchema check disabled, skip mapping / metadata if table col not in df col
+                            logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
+                            break
+                        }
+                    }
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName == dfColName, strictSchemaCheck,
                         s"""${prefix} column names '${tableColName}' and '${dfColName}' at column index ${i} (case sensitive)""")
                 } else {
-                    dfFieldIndex = dfCols.fieldIndex(dfColCaseMap(tableColName.toLowerCase()))
+                    if (!dfColCaseMap.contains(tableColName.toLowerCase())) {
+                        if (strictSchemaCheck) {
+                            throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
+                        } else {
+                            // when strictSchema check disabled, skip mapping / metadata if table col not in df col
+                            logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
+                            break
+                        }
+                    }
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName.toLowerCase() == dfColName.toLowerCase(),
@@ -362,28 +321,28 @@ object BulkCopyUtils extends Logging {
                 dfCols(dfFieldIndex).dataType == tableCols(i).dataType,
                 strictSchemaCheck,
                 s"${prefix} column data types at column index ${i}." +
-                s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
-                s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
+                    s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
+                    s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
             }
 
             assertIfCheckEnabled(
                 dfCols(dfFieldIndex).nullable == tableCols(i).nullable,
                 strictSchemaCheck,
                 s"${prefix} column nullable configurations at column index ${i}" +
-                s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
-                s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
+                    s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
+                    s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
 
-            // Schema check passed for element, Create ColMetaData only for non auto generated column
-            result(nonAutoColIndex) = new ColumnMetadata(
+                // Schema check passed for element, Create ColMetaData
+                result += new ColumnMetadata(
                     rs.getMetaData().getColumnName(i+1),
                     rs.getMetaData().getColumnType(i+1),
                     rs.getMetaData().getPrecision(i+1),
                     rs.getMetaData().getScale(i+1),
                     dfFieldIndex
                     )
-            nonAutoColIndex += 1
             }
         }
-        result
+        logDebug(s"metadata includes ${result.length} columns")
+        result.toArray
     }
 
     /**

From 5d467d982faa68da79349fc5c9f2951a355e455a Mon Sep 17 00:00:00 2001
From: luxu1-ms
Date: Fri, 22 Oct 2021 00:57:44 -0700
Subject: [PATCH 2/2] remove exception catch

---
 .../jdbc/spark/utils/BulkCopyUtils.scala | 26 +++++++-------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
index 1bd5894..fc9cc0c 100644
--- a/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
+++ b/src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala
@@ -276,30 +276,24 @@ object BulkCopyUtils extends Logging {
                 var dfFieldIndex = 0
                 var dfColName:String = ""
                 if (isCaseSensitive) {
-                    if (!dfCols.fieldNames.contains(tableColName)) {
-                        if (strictSchemaCheck) {
-                            throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
-                        } else {
-                            // when strictSchema check disabled, skip mapping / metadata if table col not in df col
-                            logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
-                            break
-                        }
+                    if (!strictSchemaCheck && !dfCols.fieldNames.contains(tableColName)) {
+                        // when strictSchema check disabled, skip mapping / metadata if table col not in df col
+                        logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
+                        break
                     }
+                    dfFieldIndex = dfCols.fieldIndex(tableColName)
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName == dfColName, strictSchemaCheck,
                         s"""${prefix} column names '${tableColName}' and '${dfColName}' at column index ${i} (case sensitive)""")
                 } else {
-                    if (!dfColCaseMap.contains(tableColName.toLowerCase())) {
-                        if (strictSchemaCheck) {
-                            throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
-                        } else {
-                            // when strictSchema check disabled, skip mapping / metadata if table col not in df col
-                            logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
-                            break
-                        }
+                    if (!strictSchemaCheck && !dfColCaseMap.contains(tableColName.toLowerCase())) {
+                        // when strictSchema check disabled, skip mapping / metadata if table col not in df col
+                        logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
+                        break
                    }
+                    dfFieldIndex = dfCols.fieldIndex(dfColCaseMap(tableColName.toLowerCase()))
                     dfColName = dfCols(dfFieldIndex).name
                     assertIfCheckEnabled(
                         tableColName.toLowerCase() == dfColName.toLowerCase(),
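
For illustration, the following is a minimal, dependency-free Scala sketch of the matching behaviour the two patches converge on: when the strict schema check is disabled, a SQL table column that is missing from the DataFrame is skipped via breakable/break instead of failing the write; when it is enabled, looking up the index fails for a missing column, much as dfCols.fieldIndex does in the real code. SqlColumn, DfColumn and matchColumns are illustrative stand-ins, not the connector's API, and the case-insensitive dfColCaseMap branch is omitted for brevity.

import scala.collection.mutable.ArrayBuffer
import scala.util.control.Breaks.{breakable, break}

object ColumnMatchSketch {
  // Illustrative stand-ins for a SQL table column and a DataFrame field.
  final case class SqlColumn(name: String)
  final case class DfColumn(name: String)

  // Maps each SQL table column name to a DataFrame field index.
  // With strictSchemaCheck disabled, columns absent from the df are skipped;
  // with it enabled, a missing column raises an error.
  def matchColumns(
      tableCols: Seq[SqlColumn],
      dfCols: Seq[DfColumn],
      strictSchemaCheck: Boolean): Array[(String, Int)] = {
    val result = new ArrayBuffer[(String, Int)]()
    for (i <- tableCols.indices) {
      breakable {
        val tableColName = tableCols(i).name
        if (!strictSchemaCheck && !dfCols.exists(_.name == tableColName)) {
          // skip mapping / metadata for a table column the dataframe does not have
          println(s"skipping index $i sql col name $tableColName")
          break
        }
        val dfFieldIndex = dfCols.indexWhere(_.name == tableColName)
        if (dfFieldIndex == -1) {
          throw new IllegalArgumentException(
            s"SQL table column $tableColName does not exist in df columns")
        }
        result += ((tableColName, dfFieldIndex))
      }
    }
    result.toArray
  }

  def main(args: Array[String]): Unit = {
    val tableCols = Seq(SqlColumn("id"), SqlColumn("name"), SqlColumn("rowguid"))
    val dfCols = Seq(DfColumn("id"), DfColumn("name"))
    // "rowguid" is absent from the df: with strictSchemaCheck = false it is skipped
    // rather than aborting the bulk write.
    matchColumns(tableCols, dfCols, strictSchemaCheck = false).foreach(println)
  }
}

The per-column breakable { ... break } pattern mirrors what patch 1 introduces and patch 2 simplifies: skipping is a per-iteration decision, so the remaining SQL columns are still matched and ColumnMetadata is built only for columns that exist on both sides.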