Merged
80 changes: 33 additions & 47 deletions api/tests_v2/batch_norm.py
@@ -22,14 +22,6 @@ def __init__(self):
     def init_from_json(self, filename, config_id=0, unknown_dim=16):
         super(BatchNormConfig, self).init_from_json(filename, config_id,
                                                     unknown_dim)
-        # tf's batch_norm does not have a data_format param; it only supports the NHWC format.
-        if self.data_format == "NCHW":
-            print(
-                "Warning:\n"
-                " 1. tf's batch_norm does not have a data_format param; it only supports the NHWC format.\n"
-            )
-            self.run_tf = False
-
         if len(self.x_shape) == 4:
             if self.data_format == "NCHW":
                 self.num_channels = self.x_shape[1]
@@ -41,48 +33,46 @@ def init_from_json(self, filename, config_id=0, unknown_dim=16):
     def to_tensorflow(self):
         tf_config = super(BatchNormConfig, self).to_tensorflow()
         if len(tf_config.x_shape) == 4:
-            tf_config.axes = [0, 1, 2]
+            tf_config.axis = 1 if self.data_format == "NCHW" else 3
         else:
-            tf_config.axes = [0]
+            tf_config.axis = 1
         return tf_config
 
 
 class PDBatchNorm(PaddleAPIBenchmarkBase):
     def build_program(self, config):
+        def _create_parameter(name, value, stop_gradient):
+            param = paddle.create_parameter(
+                name=name,
+                shape=[config.num_channels],
+                dtype=config.x_dtype,
+                attr=paddle.ParamAttr(
+                    initializer=paddle.nn.initializer.Constant(value)))
+            param.stop_gradient = stop_gradient
+            return param
+
         x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
 
-        running_mean = paddle.create_parameter(
-            name='running_mean',
-            shape=[config.num_channels],
-            dtype=config.x_dtype,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.5)))
-        running_mean.stop_gradient = True
-        running_var = paddle.create_parameter(
-            name='running_var',
-            shape=[config.num_channels],
-            dtype=config.x_dtype,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.1)))
-        running_var.stop_gradient = True
-
-        scale = self.variable(
-            name='scale', shape=[config.num_channels], dtype=config.x_dtype)
-        bias = self.variable(
-            name='bias', shape=[config.num_channels], dtype=config.x_dtype)
+        running_mean = _create_parameter(
+            name='running_mean', value=0.5, stop_gradient=True)
+        running_var = _create_parameter(
+            name='running_var', value=0.1, stop_gradient=True)
+
+        scale = _create_parameter(name='scale', value=0.5, stop_gradient=False)
+        bias = _create_parameter(name='bias', value=0.1, stop_gradient=False)
 
         result = paddle.nn.functional.batch_norm(
             x=x,
             running_mean=running_mean,
             running_var=running_var,
-            weight=scale,  # scale
-            bias=bias,  # bias
+            weight=scale,
+            bias=bias,
             epsilon=config.epsilon,
             momentum=config.momentum,
             training=config.training,
             data_format=config.data_format)
 
-        self.feed_vars = [x, scale, bias]
+        self.feed_vars = [x]
         self.fetch_vars = [result]
         if config.backward:
             self.append_gradients(result, [x, scale, bias])
@@ -91,24 +81,20 @@ def build_program(self, config):
 class TFBatchNorm(TensorflowAPIBenchmarkBase):
     def build_graph(self, config):
         x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
-        scale = self.variable(
-            name='scale', shape=[config.num_channels], dtype=config.x_dtype)
-        bias = self.variable(
-            name='bias', shape=[config.num_channels], dtype=config.x_dtype)
-        mean, var = tf.nn.moments(
-            x=x, axes=config.axes, shift=None, keepdims=False)
-        result = tf.nn.batch_normalization(
-            x=x,
-            mean=mean,
-            variance=var,
-            offset=bias,
-            scale=scale,
-            variance_epsilon=config.epsilon)
+        bn = tf.keras.layers.BatchNormalization(
+            axis=config.axis,
+            momentum=config.momentum,
+            epsilon=config.epsilon,
+            beta_initializer=tf.constant_initializer(0.1),
+            gamma_initializer=tf.constant_initializer(0.5),
+            moving_mean_initializer=tf.constant_initializer(0.5),
+            moving_variance_initializer=tf.constant_initializer(0.1))
+        result = bn(x, training=config.training)
 
-        self.feed_list = [x, scale, bias]
+        self.feed_list = [x]
         self.fetch_list = [result]
         if config.backward:
-            self.append_gradients(result, [x, scale, bias])
+            self.append_gradients(result, [x, bn.gamma, bn.beta])
 
 
 if __name__ == '__main__':
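Note: the old TF graph computed statistics with tf.nn.moments over config.axes and applied tf.nn.batch_normalization, which only supports NHWC, so NCHW configs had to skip the TF run. tf.keras.layers.BatchNormalization takes a channel axis instead, which is what the new config.axis feeds, and it owns its gamma/beta variables, so the benchmark now feeds only x and takes gradients against bn.gamma and bn.beta. A minimal standalone sketch (TF2 eager mode, made-up shapes; not part of the PR or the benchmark harness) of the layer as configured above:

import tensorflow as tf

# NCHW input: channels live on axis 1, so pass axis=1 (config.axis for NCHW).
x = tf.random.normal([2, 8, 4, 4])
bn = tf.keras.layers.BatchNormalization(
    axis=1,
    momentum=0.9,
    epsilon=1e-5,
    gamma_initializer=tf.constant_initializer(0.5),   # matches Paddle's scale
    beta_initializer=tf.constant_initializer(0.1),    # matches Paddle's bias
    moving_mean_initializer=tf.constant_initializer(0.5),
    moving_variance_initializer=tf.constant_initializer(0.1))
y = bn(x, training=True)   # normalizes over axes (0, 2, 3), keeping axis 1
print(y.shape)             # (2, 8, 4, 4)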
6 changes: 4 additions & 2 deletions api/tests_v2/configs/batch_norm.json
@@ -23,7 +23,8 @@
       "type": "float",
       "value": "0.9"
     }
-  }
+  },
+  "atol": 1E-5
 }, {
   "config_id": 1,
   "op": "batch_norm",
@@ -49,7 +50,8 @@
       "type": "float",
       "value": "0.9"
     }
-  }
+  },
+  "atol": 1E-4
 }, {
   "config_id": 2,
   "op": "batch_norm",
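Note: the per-config atol values loosen the output comparison for these batch_norm cases, since the Paddle and TF implementations accumulate floating-point rounding differently. A hedged sketch of how such a tolerance is typically consumed — the actual consistency check lives in the benchmark framework, so outputs_match and the 1e-6 default here are illustrative only:

import numpy as np

def outputs_match(paddle_out, tf_out, atol=1e-6):
    # np.allclose checks |a - b| <= atol + rtol * |b| element-wise;
    # the JSON overrides raise atol to 1E-5 / 1E-4 for the configs above.
    return np.allclose(paddle_out, tf_out, atol=atol)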
22 changes: 7 additions & 15 deletions api/tests_v2/while_loop.py
@@ -73,21 +73,13 @@ def cond(i, loop_len, input, result):
             return i < loop_len
 
         def body(i, loop_len, input, result):
-            if tf.__version__ <= "1.15.0":
-                result = tf.contrib.layers.fully_connected(
-                    inputs=input,
-                    num_outputs=config.size,
-                    weights_initializer=tf.constant_initializer(0.5),
-                    biases_initializer=tf.constant_initializer(0.1),
-                    activation_fn=None)
-            else:
-                result = tf.compat.v1.layers.dense(
-                    inputs=input,
-                    units=config.size,
-                    activation=None,
-                    use_bias=True,
-                    kernel_initializer=tf.constant_initializer(0.5),
-                    bias_initializer=tf.constant_initializer(0.1))
+            result = tf.compat.v1.layers.dense(
+                inputs=input,
+                units=config.size,
+                activation=None,
+                use_bias=True,
+                kernel_initializer=tf.constant_initializer(0.5),
+                bias_initializer=tf.constant_initializer(0.1))
             return [i + 1, loop_len, input, result]
 
         input = self.variable(
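Note: tf.contrib was removed in TF 2.0, while tf.compat.v1.layers.dense exists in both TF 1.15 and TF 2.x, so the version branch (which also compared version strings lexicographically, a bug for versions like "1.9.0") can simply be dropped. A hedged standalone sketch (made-up shapes; not the benchmark harness) of the one-for-one argument mapping:

import tensorflow as tf

# v1 layers are graph-mode APIs; disable eager when running under TF 2.x.
tf.compat.v1.disable_eager_execution()

x = tf.ones([4, 16])
# fully_connected(num_outputs, weights_initializer, biases_initializer,
# activation_fn) maps to dense(units, kernel_initializer, bias_initializer,
# activation) argument for argument.
y = tf.compat.v1.layers.dense(
    inputs=x,
    units=8,
    activation=None,
    use_bias=True,
    kernel_initializer=tf.constant_initializer(0.5),
    bias_initializer=tf.constant_initializer(0.1))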