@@ -22,7 +22,8 @@ import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getSchema, schemaString}
 import com.microsoft.sqlserver.jdbc.{SQLServerBulkCopy, SQLServerBulkCopyOptions}
 
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.Breaks.{breakable, break}
 
 /**
  * BulkCopyUtils Object implements common utility function used by both datapool and
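The newly imported scala.util.control.Breaks is what lets the rewritten column-matching loop further down skip a table column mid-iteration, since Scala's for loop has no built-in continue. A minimal sketch of the pattern, with made-up column names:

    import scala.util.control.Breaks.{breakable, break}

    val cols = Seq("id", "name", "row_hash")
    for (c <- cols) {
      breakable {
        // break exits only the breakable block, acting like `continue`
        if (c == "row_hash") break
        println(s"mapping column $c")
      }
    }

Because breakable wraps the loop body rather than the loop itself, break abandons only the current iteration; wrapping the whole for in breakable would instead terminate the loop.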
@@ -179,47 +180,6 @@ object BulkCopyUtils extends Logging {
     conn.createStatement.executeQuery(queryStr)
   }
 
-  /**
-   * getComputedCols
-   * utility function to get computed columns.
-   * Use computed column names to exclude computed column when matching schema.
-   */
-  private[spark] def getComputedCols(
-      conn: Connection,
-      table: String): List[String] = {
-    val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
-    val computedColRs = conn.createStatement.executeQuery(queryStr)
-    val computedCols = ListBuffer[String]()
-    while (computedColRs.next()) {
-      val colName = computedColRs.getString("name")
-      computedCols.append(colName)
-    }
-    computedCols.toList
-  }
-
-  /**
-   * dfComputedColCount
-   * utility function to get number of computed columns in dataframe.
-   * Use number of computed columns in dataframe to get number of non computed column in df,
-   * and compare with the number of non computed column in sql table
-   */
-  private[spark] def dfComputedColCount(
-      dfColNames: List[String],
-      computedCols: List[String],
-      dfColCaseMap: Map[String, String],
-      isCaseSensitive: Boolean): Int = {
-    var dfComputedColCt = 0
-    for (j <- 0 to computedCols.length - 1) {
-      if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-        !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-        && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
-        dfComputedColCt += 1
-      }
-    }
-    dfComputedColCt
-  }
-
-
   /**
    * getColMetadataMap
    * Utility function convert result set meta data to array.
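The deleted dfComputedColCount matched SQL-side computed columns against DataFrame columns either exactly (case-sensitive) or via dfColCaseMap, a map keyed by lowercased column names. A rough standalone sketch of that lookup idea, with hypothetical names (the deleted code additionally compared the mapped original name back to the SQL-side name):

    val dfColNames = List("Id", "FullName", "Total")
    // lowercased name -> original DataFrame column name
    val dfColCaseMap: Map[String, String] =
      dfColNames.map(c => c.toLowerCase -> c).toMap

    def dfHasColumn(col: String, isCaseSensitive: Boolean): Boolean =
      if (isCaseSensitive) dfColNames.contains(col)
      else dfColCaseMap.contains(col.toLowerCase)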
@@ -303,37 +263,32 @@ object BulkCopyUtils extends Logging {
     val dfCols = df.schema
 
     val tableCols = getSchema(rs, JdbcDialects.get(url))
-    val computedCols = getComputedCols(conn, dbtable)
-
     val prefix = "Spark Dataframe and SQL Server table have differing"
 
-    if (computedCols.length == 0) {
-      assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
-        s"${prefix} numbers of columns")
-    } else if (strictSchemaCheck) {
-      val dfColNames = df.schema.fieldNames.toList
-      val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
-      // if df has computed column(s), check column length using non computed column in df and table.
-      // non computed column number in df: dfCols.length - dfComputedColCt
-      // non computed column number in table: tableCols.length - computedCols.length
-      assertIfCheckEnabled(dfCols.length - dfComputedColCt == tableCols.length - computedCols.length, strictSchemaCheck,
-        s"${prefix} numbers of columns")
-    }
-
+    assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
+      s"${prefix} numbers of columns")
 
-    val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
-    var nonAutoColIndex = 0
+    val result = new ArrayBuffer[ColumnMetadata]()
 
     for (i <- 0 to tableCols.length - 1) {
-      val tableColName = tableCols(i).name
-      var dfFieldIndex = -1
-      // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-      if (computedCols.contains(tableColName)) {
-        logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
-      } else {
+      breakable {
+        val tableColName = tableCols(i).name
+        var dfFieldIndex = 0
         var dfColName: String = ""
         if (isCaseSensitive) {
-          dfFieldIndex = dfCols.fieldIndex(tableColName)
+          // skip mapping / metadata if table col not in df col (strictSchema check disabled)
+          logDebug(s"df contains ${tableColName}: ${dfCols.fieldNames.contains(tableColName)}")
+          if (!strictSchemaCheck && !dfCols.fieldNames.contains(tableColName)) {
+            logDebug(s"skipping index $i sql col name $tableColName dfFieldIndex $dfFieldIndex")
+            break
+          }
+          try {
+            dfFieldIndex = dfCols.fieldIndex(tableColName)
+          } catch {
+            case ex: IllegalArgumentException => {
+              throw new SQLException(s"SQL table column ${tableColName} does not exist in df columns")
+            }
+          }
           dfColName = dfCols(dfFieldIndex).name
           assertIfCheckEnabled(
             tableColName == dfColName, strictSchemaCheck,
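Replacing the pre-sized Array[ColumnMetadata] with an ArrayBuffer fits the new control flow: entries are appended only for columns that survive the skip logic, so the final length need not be known up front. A tiny illustration of the append-then-materialize pattern:

    import scala.collection.mutable.ArrayBuffer

    val kept = ArrayBuffer[Int]()
    for (i <- 0 to 4) {
      if (i % 2 == 0) kept += i  // append only the entries we keep
    }
    val arr: Array[Int] = kept.toArray  // Array(0, 2, 4)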
@@ -362,28 +317,29 @@ object BulkCopyUtils extends Logging {
             dfCols(dfFieldIndex).dataType == tableCols(i).dataType,
             strictSchemaCheck,
             s"${prefix} column data types at column index ${i}. " +
-            s"DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
-            s"Table col ${tableColName} dataType ${tableCols(i).dataType}")
+              s"DF col ${dfColName} dataType ${dfCols(dfFieldIndex).dataType} " +
+              s"Table col ${tableColName} dataType ${tableCols(i).dataType}")
         }
         assertIfCheckEnabled(
           dfCols(dfFieldIndex).nullable == tableCols(i).nullable,
           strictSchemaCheck,
           s"${prefix} column nullable configurations at column index ${i}" +
-          s"DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
-          s"Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
+            s"DF col ${dfColName} nullable config is ${dfCols(dfFieldIndex).nullable} " +
+            s"Table col ${tableColName} nullable config is ${tableCols(i).nullable}")
 
-        // Schema check passed for element, Create ColMetaData only for non auto generated column
-        result(nonAutoColIndex) = new ColumnMetadata(
+        // Schema check passed for element, Create ColMetaData
+        result += new ColumnMetadata(
           rs.getMetaData().getColumnName(i+1),
           rs.getMetaData().getColumnType(i+1),
           rs.getMetaData().getPrecision(i+1),
           rs.getMetaData().getScale(i+1),
           dfFieldIndex
         )
-        nonAutoColIndex += 1
+        logDebug(s"one col metadata name: ${rs.getMetaData().getColumnName(i + 1)}")
       }
     }
-    result
+    logDebug(s"metadata: ${result.toArray}, column length: ${result.length}")
+    result.toArray
   }
 
   /**
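The try/catch introduced above relies on StructType.fieldIndex throwing IllegalArgumentException for an unknown field, which the connector rethrows as a SQLException. A small self-contained illustration of that contract, with an invented schema and column name:

    import java.sql.SQLException
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    // schema.fieldIndex("name") returns 1; a missing column throws
    val idx =
      try schema.fieldIndex("price")
      catch {
        case _: IllegalArgumentException =>
          throw new SQLException("SQL table column price does not exist in df columns")
      }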