Skip to content

Commit 4b37151

Browse files
authored
[NPU] replace ce loss with nll loss (#3759) (#3782)
1 parent 6758049 commit 4b37151

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

paddleseg/models/losses/cross_entropy_loss.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from paddleseg.cvlibs import manager
2020

21+
_IS_NPU = "npu" in paddle.get_device()
22+
2123

2224
@manager.LOSSES.add_component
2325
class CrossEntropyLoss(nn.Layer):
@@ -81,11 +83,20 @@ def forward(self, logit, label, semantic_weights=None):
8183
logit = paddle.transpose(logit, [0, 2, 3, 1])
8284
label = label.astype('int64')
8385

84-
loss = F.cross_entropy(logit,
85-
label,
86-
ignore_index=self.ignore_index,
87-
reduction='none',
88-
weight=self.weight)
86+
if _IS_NPU:
87+
logit = logit.transpose([0, 3, 1, 2])
88+
logit = F.log_softmax(logit, axis=1)
89+
loss = F.nll_loss(logit,
90+
label,
91+
weight=self.weight,
92+
ignore_index=self.ignore_index,
93+
reduction='none')
94+
else:
95+
loss = F.cross_entropy(logit,
96+
label,
97+
ignore_index=self.ignore_index,
98+
reduction='none',
99+
weight=self.weight)
89100

90101
return self._post_process_loss(logit, label, semantic_weights, loss)
91102

0 commit comments

Comments
 (0)