@@ -22,7 +22,8 @@ import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer
import scala.util.control.Breaks.{breakable,break}
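
For reference, the new imports back the rewritten column-matching loop further down: ArrayBuffer collects only the columns that actually get mapped, and breakable/break skips a column without aborting the whole loop. A minimal sketch of that pattern, using hypothetical column names:

    import scala.collection.mutable.ArrayBuffer
    import scala.util.control.Breaks.{breakable, break}

    val tableColNames = Seq("id", "name", "rowguid")   // hypothetical SQL table columns
    val dfColNames = Set("id", "name")                  // hypothetical DataFrame columns
    val matched = ArrayBuffer[String]()
    for (col <- tableColNames) {
      breakable {
        if (!dfColNames.contains(col)) break             // skip this column, keep looping
        matched += col                                   // only mapped columns are appended
      }
    }
    // matched.toArray == Array("id", "name")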

/**
* BulkCopyUtils Object implements common utility function used by both datapool and
@@ -179,47 +180,6 @@ object BulkCopyUtils extends Logging {
conn.createStatement.executeQuery(queryStr)
}

/**
* getComputedCols
* utility function to get computed columns.
* Use computed column names to exclude computed column when matching schema.
*/
private[spark] def getComputedCols(
conn: Connection,
table: String): List[String] = {
val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
val computedColRs = conn.createStatement.executeQuery(queryStr)
val computedCols = ListBuffer[String]()
while (computedColRs.next()) {
val colName = computedColRs.getString("name")
computedCols.append(colName)
}
computedCols.toList
}
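
A usage sketch of this helper with a hypothetical table name; `conn` is assumed to be an open java.sql.Connection to SQL Server:

    // Runs: SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('dbo.Sales');
    // For a table whose only computed column is TotalDue, this returns List("TotalDue").
    val computedCols: List[String] = getComputedCols(conn, "dbo.Sales")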

/**
* dfComputedColCount
* utility function to get number of computed columns in dataframe.
* Use number of computed columns in dataframe to get number of non computed column in df,
* and compare with the number of non computed column in sql table
*/
private[spark] def dfComputedColCount(
dfColNames: List[String],
computedCols: List[String],
dfColCaseMap: Map[String, String],
isCaseSensitive: Boolean): Int ={
var dfComputedColCt = 0
for (j <- 0 to computedCols.length-1){
if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
!isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
&& dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
dfComputedColCt += 1
}
}
dfComputedColCt
}
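
A small worked example of how this count behaves under the two case-sensitivity modes, with hypothetical column names:

    // Hypothetical inputs: the table has computed columns TotalDue and rowver,
    // and the DataFrame carries TotalDue but not rowver.
    val dfColNames   = List("id", "TotalDue")
    val computedCols = List("TotalDue", "rowver")
    val dfColCaseMap = Map("id" -> "id", "totaldue" -> "TotalDue")  // lower-cased name -> df name

    // Case sensitive:   dfColNames.contains("TotalDue") is true, "rowver" is absent   -> count 1
    // Case insensitive: dfColCaseMap("totaldue") == "TotalDue", "rowver" not in map   -> count 1
    val ct = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive = true)
    // ct == 1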


/**
* getColMetadataMap
* Utility function convert result set meta data to array.
@@ -303,43 +263,36 @@
val dfCols = df.schema

val tableCols = getSchema(rs, JdbcDialects.get(url))
val computedCols = getComputedCols(conn, dbtable)

val prefix = "Spark Dataframe and SQL Server table have differing"

if (computedCols.length == 0) {
assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
} else if (strictSchemaCheck) {
val dfColNames = df.schema.fieldNames.toList
val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
// if df has computed column(s), check column length using non computed column in df and table.
// non computed column number in df: dfCols.length - dfComputedColCt
// non computed column number in table: tableCols.length - computedCols.length
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")
}

assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
s"${prefix} numbers of columns")

val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
var nonAutoColIndex = 0
val result = new ArrayBuffer[ColumnMetadata]()

for (i <- 0 to tableCols.length-1) {
val tableColName = tableCols(i).name
var dfFieldIndex = -1
// set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
if (computedCols.contains(tableColName)) {
logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
}else{
breakable {
val tableColName = tableCols(i).name
var dfFieldIndex = 0
var dfColName:String = ""
if (isCaseSensitive) {
if (!strictSchemaCheck && !dfCols.fieldNames.contains(tableColName)) {
// when strictSchema check disabled, skip mapping / metadata if table col not in df col
logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
break
}
dfFieldIndex = dfCols.fieldIndex(tableColName)
dfColName = dfCols(dfFieldIndex).name
assertIfCheckEnabled(
tableColName == dfColName, strictSchemaCheck,
s"""${prefix} column names '${tableColName}' and
'${dfColName}' at column index ${i} (case sensitive)""")
} else {
if (!strictSchemaCheck && !dfColCaseMap.contains(tableColName.toLowerCase())) {
// when strictSchema check disabled, skip mapping / metadata if table col not in df col
logDebug(s"skipping index $i sql col name $tableColName mapping and column metadata")
break
}
dfFieldIndex = dfCols.fieldIndex(dfColCaseMap(tableColName.toLowerCase()))
dfColName = dfCols(dfFieldIndex).name
assertIfCheckEnabled(
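
The net effect of the breakable blocks above, sketched with hypothetical columns: when strictSchemaCheck is disabled and the SQL table has a column the DataFrame lacks, that column is skipped instead of letting dfCols.fieldIndex throw on the missing name.

    import scala.util.control.Breaks.{breakable, break}

    val strictSchemaCheck = false
    val dfFieldNames  = Array("id", "name")              // hypothetical DataFrame columns
    val tableColNames = Seq("id", "name", "created_at")  // hypothetical SQL table columns

    for (tableColName <- tableColNames) {
      breakable {
        if (!strictSchemaCheck && !dfFieldNames.contains(tableColName)) {
          // no mapping or ColumnMetadata is built for created_at; the loop moves on
          break
        }
        // ... fieldIndex lookup, name/type/nullability checks, metadata append ...
      }
    }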
@@ -362,28 +315,28 @@
dfCols(dfFieldIndex).dataType == tableCols(i).dataType,
strictSchemaCheck,
s"${prefix} column data types at column index ${i}." +
s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
s" DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
s" Table col ${tableColName} dataType ${tableCols(i).dataType} ")
}
assertIfCheckEnabled(
dfCols(dfFieldIndex).nullable == tableCols(i).nullable,
strictSchemaCheck,
s"${prefix} column nullable configurations at column index ${i}" +
s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
s" DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
s" Table col ${tableColName} nullable config is ${tableCols(i).nullable}")

// Schema check passed for element, Create ColMetaData only for non auto generated column
result(nonAutoColIndex) = new ColumnMetadata(
// Schema check passed for element, Create ColMetaData
result += new ColumnMetadata(
rs.getMetaData().getColumnName(i+1),
rs.getMetaData().getColumnType(i+1),
rs.getMetaData().getPrecision(i+1),
rs.getMetaData().getScale(i+1),
dfFieldIndex
)
nonAutoColIndex += 1
}
}
result
logDebug(s"metadata includes ${result.length} columns")
result.toArray
}
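
For reference, the result container switches from a preallocated Array indexed by nonAutoColIndex to a growable ArrayBuffer; this also covers the case where some table columns are skipped and the final length is not known up front. A minimal sketch of the accumulate-then-toArray pattern:

    import scala.collection.mutable.ArrayBuffer

    val buf = ArrayBuffer[Int]()
    for (i <- 0 to 4) {
      if (i % 2 == 0) buf += i         // append only the entries that pass the checks
    }
    val out: Array[Int] = buf.toArray  // Array(0, 2, 4); length equals the number appended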

/**