Skip to content

Commit 08befe2

Browse files
committed
tftraininghelper evaluate
1 parent 6c2ea24 commit 08befe2

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

pyzoo/zoo/pipeline/api/net/tf_optimizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from bigdl.nn.criterion import Criterion
2525
from bigdl.nn.layer import Layer
26-
from bigdl.util.common import to_list, JavaValue
26+
from bigdl.util.common import to_list, JavaValue, callBigDlFunc
2727
from bigdl.optim.optimizer import EveryEpoch, MaxEpoch, SeveralIteration
2828
from zoo.pipeline.api.keras.engine.topology import to_bigdl_metric
2929
from zoo.pipeline.api.keras.optimizers import DistriOptimizer
@@ -54,6 +54,12 @@ def __init__(self, path, configProto):
5454
byte_arr = None
5555
super(TFTrainingHelper, self).__init__(None, "float", path, byte_arr)
5656

57+
def evaluate(self, dataset, batch_size, val_methods):
58+
return callBigDlFunc(self.bigdl_type,
59+
"tfEvaluate",
60+
self.value,
61+
dataset, batch_size, val_methods)
62+
5763

5864
class TFOptimizer:
5965
def __init__(self, loss, optim_method, sess=None, dataset=None, inputs=None,
@@ -182,10 +188,12 @@ def to_floats(vs):
182188

183189
if val_outputs is not None and val_labels is not None:
184190
val_rdd = self.dataset.get_validation_data()
191+
self.val_rdd = val_rdd
185192
if val_rdd is not None:
186193
val_method = [TFValidationMethod(m, len(val_outputs), len(val_labels))
187194
for m in to_list(val_method)]
188195
training_rdd = sample_rdd
196+
self.val_method = val_method
189197

190198
elif val_split != 0.0:
191199
training_rdd, val_rdd = sample_rdd.randomSplit([1 - val_split, val_split])

zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/keras/python/PythonZooKeras.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,4 +1414,23 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
14141414
def createEpochStep(stepSize: Int, gamma: Double): SGD.EpochStep = {
14151415
SGD.EpochStep(stepSize, gamma)
14161416
}
1417+
1418+
1419+
def tfEvaluate(model: AbstractModule[Activity, Activity, T],
1420+
valRDD: JavaRDD[Sample],
1421+
batchSize: Int,
1422+
valMethods: JList[ValidationMethod[T]])
1423+
: JList[EvaluatedResult] = {
1424+
val sampleRDD = toJSample(valRDD)
1425+
val featureSize = sampleRDD.first().numFeature()
1426+
val dataSet = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize)
1427+
val rdd = dataSet.toDistributed().data(train = false)
1428+
val resultArray = model.evaluate(rdd,
1429+
valMethods.asScala.toArray)
1430+
val testResultArray = resultArray.map { result =>
1431+
EvaluatedResult(result._1.result()._1, result._1.result()._2,
1432+
result._2.toString())
1433+
}
1434+
testResultArray.toList.asJava
1435+
}
14171436
}

0 commit comments

Comments
 (0)