Skip to content

Commit 0d6c170

Browse files
committed
Review fixes
Signed-off-by: Marek Dabek <mdabek@nvidia.com>
1 parent 683bc6e commit 0d6c170

4 files changed

Lines changed: 9 additions & 13 deletions

File tree

dali/python/nvidia/dali/experimental/torchvision/v2/color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def verify(cls, data_input) -> None:
9999
layout = data_input.property("layout")[1]
100100

101101
# CHW
102-
if layout.cpu() == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]:
102+
if layout == np.frombuffer(bytes("C", "utf-8"), dtype=np.uint8)[0]:
103103
raise NotImplementedError(
104104
"NCHW and CHW layout are not supported for Grayscale, expecting HWC or NHWC"
105105
)

dali/python/nvidia/dali/experimental/torchvision/v2/functional/color.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ def _grayscale(
4444
return ndd.cat(inpt, inpt, inpt, axis_name="C")
4545
else:
4646
return ndd.hsv(inpt, saturation=0, device=device)
47-
else:
48-
return ndd.hsv(inpt, saturation=0, device=device)
4947

5048

5149
@adjust_input

dali/python/nvidia/dali/experimental/torchvision/v2/functional/normalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
import nvidia.dali.experimental.dynamic as ndd
2020

21-
from ..operator import adjust_input, VerificationIsTensor # noqa: E402
21+
from ..operator import adjust_input, _ValidateIsTensor # noqa: E402
2222
from ..normalize import Normalize # noqa: E402
2323

2424

@@ -54,7 +54,7 @@ def normalize(
5454
std = np.asarray(std)[:, None, None]
5555

5656
Normalize.verify_args(std=std, mean=mean)
57-
VerificationIsTensor.verify(input_data)
57+
_ValidateIsTensor.verify(input_data)
5858

5959
if inplace:
6060
raise NotImplementedError("inplace is not implemented, yet")

dali/python/nvidia/dali/experimental/torchvision/v2/normalize.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from typing import Sequence, Literal
1818
import nvidia.dali.fn as fn
1919

20-
from .operator import Operator, ArgumentVerificationRule
20+
from .operator import Operator, _ArgumentValidateRule
2121

2222

23-
class VerifyStd(ArgumentVerificationRule):
23+
class _ValidateStd(_ArgumentValidateRule):
2424
"""
2525
Verify the standard deviation argument for the Normalize operator.
2626
@@ -32,17 +32,15 @@ class VerifyStd(ArgumentVerificationRule):
3232

3333
@classmethod
3434
def verify(cls, *, std, **_) -> None:
35-
if (
36-
not isinstance(std, (int, float, Sequence, torch.Tensor, np.ndarray))
37-
or isinstance(std, Sequence)
38-
and isinstance(std, str)
35+
if not isinstance(std, (int, float, Sequence, torch.Tensor, np.ndarray)) or (
36+
isinstance(std, Sequence) and isinstance(std, str)
3937
):
4038
raise TypeError(f"Std must be an int, a float or a Sequence, got {type(std)}")
4139
if np.any(np.array(std) == 0):
4240
raise ValueError("Std must not be 0")
4341

4442

45-
class VerifyMean(ArgumentVerificationRule):
43+
class _ValidateMean(_ArgumentValidateRule):
4644
"""
4745
Verify the mean argument for the Normalize operator.
4846
@@ -77,7 +75,7 @@ class Normalize(Operator):
7775
Bool to make this operation in-place. Not supported.
7876
"""
7977

80-
arg_rules = [VerifyStd, VerifyMean]
78+
arg_rules = [_ValidateStd, _ValidateMean]
8179
# TODO: currently not supported
8280
# input_rules = [VerificationIsTensor]
8381

0 commit comments

Comments
 (0)