Skip to content

Commit b6521c6

Browse files
authored
[API Compatibility] Improve compat.nn.Linear init and param reset (#76196)
1 parent 23d449e commit b6521c6

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

python/paddle/compat/nn/__init__.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -250,20 +250,11 @@ def reset_parameters(self) -> None:
250250
Resets parameters based on their initialization used in ``__init__``.
251251
"""
252252

253-
# KaimingUniform initializer should be more flexible: user should be able to specify place
254-
expected_place = paddle.base.framework._current_expected_place()
255-
original_place = self.weight.place
256253
nn.init.kaiming_uniform_(self.weight, a=sqrt(5))
257-
258-
place_mismatch = expected_place != original_place
259-
if place_mismatch and in_dynamic_mode():
260-
self.weight = self.weight.to(original_place)
261254
if self.bias is not None:
262255
# nn.init._calculate_fan_in_and_fan_out(self.weight) for 2D array
263256
# is equivalent to returning (weight.shape[1], weight.shape[0])
257+
# TODO(heqianyue): use _calculate_fan_in_and_fan_out when available
264258
fan_in = self.weight.shape[1]
265259
bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
266260
nn.init.uniform_(self.bias, -bound, bound)
267-
268-
if place_mismatch and in_dynamic_mode():
269-
self.bias = self.bias.to(original_place)

python/paddle/nn/initializer/kaiming.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,19 @@ def forward(
173173
-limit,
174174
limit,
175175
self._seed,
176-
_current_expected_place(),
176+
var.place
177+
if var.place._type()
178+
else _current_expected_place(),
177179
)
178180
else:
179181
gain = calculate_gain(self._nonlinearity, self._negative_slope)
180182
std = gain / math.sqrt(float(fan_in))
181-
place = _current_expected_place()
183+
# A falsy var.place._type() means the place is undefined, which happens when the initializer is specified in ParamAttr; fall back to the current expected place in that case
184+
place = (
185+
var.place
186+
if var.place._type()
187+
else _current_expected_place()
188+
)
182189
out_var = _C_ops.gaussian(
183190
out_var.shape, 0.0, std, self._seed, out_dtype, place
184191
)

python/paddle/nn/initializer/uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def forward(
126126
self._low,
127127
self._high,
128128
self._seed,
129-
_current_expected_place(),
129+
var.place if var.place._type() else _current_expected_place(),
130130
)
131131
if var.dtype == core.VarDesc.VarType.FP16:
132132
var_tmp = _C_ops.cast(out_var, var.dtype)

test/legacy_test/test_compat_nn_linear.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,18 @@ def test_weight_transpose_behavior(self):
347347

348348
np.testing.assert_allclose(y_pd.numpy(), expected, rtol=1e-5, atol=1e-8)
349349

350+
def test_reset_parameters(self):
351+
if not paddle.base.is_compiled_with_cuda():
352+
return
353+
devices = ['cpu', None] # None means the default device
354+
for device_ in devices:
355+
dummy_tensor = paddle.zeros(1, device=device_)
356+
lin = paddle.compat.nn.Linear(4, 8, bias=True, device=device_)
357+
expected_device = dummy_tensor.place
358+
lin.reset_parameters()
359+
self.assertEqual(lin.weight.place, expected_device)
360+
self.assertEqual(lin.bias.place, expected_device)
361+
350362
def test_error_handling(self):
351363
"""Test error handling for invalid inputs"""
352364
# Shape mismatch between input and weight

0 commit comments

Comments
 (0)