|
23 | 23 | from smdebug.core.utils import SagemakerSimulator, ScriptSimulator |
24 | 24 |
|
25 | 25 |
|
| 26 | +class CustomCrossEntropyLoss(nn.modules.loss._WeightedLoss): |
| 27 | + __constants__ = ["weight", "ignore_index", "reduction"] |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean" |
| 31 | + ): |
| 32 | + super(CustomCrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) |
| 33 | + self.ignore_index = ignore_index |
| 34 | + |
| 35 | + def forward(self, input, target): |
| 36 | + return F.cross_entropy( |
| 37 | + input, |
| 38 | + target, |
| 39 | + weight=self.weight, |
| 40 | + ignore_index=self.ignore_index, |
| 41 | + reduction=self.reduction, |
| 42 | + ) |
| 43 | + |
| 44 | + |
26 | 45 | @pytest.mark.skipif( |
27 | 46 | torch.__version__ == "1.7.0", |
28 | 47 | reason="Disabling the test temporarily until we root cause the version incompatibility", |
29 | 48 | ) |
30 | 49 | @pytest.mark.parametrize("script_mode", [False]) |
31 | 50 | @pytest.mark.parametrize("use_loss_module", [True, False]) |
32 | | -def test_pytorch(script_mode, use_loss_module): |
| 51 | +@pytest.mark.parametrize("use_custom_loss_module", [True, False]) |
| 52 | +def test_pytorch(script_mode, use_loss_module, use_custom_loss_module): |
33 | 53 | smd.del_hook() |
34 | 54 |
|
35 | 55 | sim_class = ScriptSimulator if script_mode else SagemakerSimulator |
36 | 56 | with sim_class() as sim: |
37 | 57 | trainloader, testloader = get_dataloaders() |
38 | 58 | net = Net() |
39 | | - criterion = nn.CrossEntropyLoss() |
| 59 | + if use_custom_loss_module: |
| 60 | + criterion = CustomCrossEntropyLoss() |
| 61 | + else: |
| 62 | + criterion = nn.CrossEntropyLoss() |
40 | 63 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) |
41 | 64 |
|
42 | 65 | if script_mode: |
|
0 commit comments