2222class ShardConfigTest (parameterized .TestCase ):
2323
2424 @parameterized .named_parameters (
25- ('imagenet train, 137 GiB' , 137 << 30 , 1281167 , True , 1024 ),
26- ('imagenet evaluation, 6.3 GiB' , 6300 * (1 << 20 ), 50000 , True , 64 ),
27- ('very large, but few examples, 52 GiB' , 52 << 30 , 512 , True , 512 ),
28- ('xxl, 10 TiB' , 10 << 40 , 10 ** 9 , True , 11264 ),
29- ('xxl, 10 PiB, 100B examples' , 10 << 50 , 10 ** 11 , True , 10487808 ),
30- ('xs, 100 MiB, 100K records' , 10 << 20 , 100 * 10 ** 3 , True , 1 ),
31- ('m, 499 MiB, 200K examples' , 400 << 20 , 200 * 10 ** 3 , True , 4 ),
25+ dict (
26+ testcase_name = 'imagenet train, 137 GiB' ,
27+ total_size = 137 << 30 ,
28+ num_examples = 1281167 ,
29+ uses_precise_sharding = True ,
30+ max_size = None ,
31+ expected_num_shards = 1024 ,
32+ ),
33+ dict (
34+ testcase_name = 'imagenet evaluation, 6.3 GiB' ,
35+ total_size = 6300 * (1 << 20 ),
36+ num_examples = 50000 ,
37+ uses_precise_sharding = True ,
38+ max_size = None ,
39+ expected_num_shards = 64 ,
40+ ),
41+ dict (
42+ testcase_name = 'very large, but few examples, 52 GiB' ,
43+ total_size = 52 << 30 ,
44+ num_examples = 512 ,
45+ uses_precise_sharding = True ,
46+ max_size = None ,
47+ expected_num_shards = 512 ,
48+ ),
49+ dict (
50+ testcase_name = 'xxl, 10 TiB' ,
51+ total_size = 10 << 40 ,
52+ num_examples = 10 ** 9 ,
53+ uses_precise_sharding = True ,
54+ max_size = None ,
55+ expected_num_shards = 11264 ,
56+ ),
57+ dict (
58+ testcase_name = 'xxl, 10 PiB, 100B examples' ,
59+ total_size = 10 << 50 ,
60+ num_examples = 10 ** 11 ,
61+ uses_precise_sharding = True ,
62+ max_size = None ,
63+ expected_num_shards = 10487808 ,
64+ ),
65+ dict (
66+ testcase_name = 'xs, 100 MiB, 100K records' ,
67+ total_size = 10 << 20 ,
68+ num_examples = 100 * 10 ** 3 ,
69+ uses_precise_sharding = True ,
70+ max_size = None ,
71+ expected_num_shards = 1 ,
72+ ),
73+ dict (
74+ testcase_name = 'm, 499 MiB, 200K examples' ,
75+ total_size = 400 << 20 ,
76+ num_examples = 200 * 10 ** 3 ,
77+ uses_precise_sharding = True ,
78+ max_size = None ,
79+ expected_num_shards = 4 ,
80+ ),
81+ dict (
82+ testcase_name = '100GiB, even example sizes' ,
83+ num_examples = 1e9 , # 1B examples
84+ total_size = 1e9 * 1000 , # On average 1000 bytes per example
85+ max_size = 1000 , # Max example size is 4000 bytes
86+ uses_precise_sharding = True ,
87+ expected_num_shards = 1024 ,
88+ ),
89+ dict (
90+ testcase_name = '100GiB, uneven example sizes' ,
91+ num_examples = 1e9 , # 1B examples
92+ total_size = 1e9 * 1000 , # On average 1000 bytes per example
93+ max_size = 4 * 1000 , # Max example size is 4000 bytes
94+ uses_precise_sharding = True ,
95+ expected_num_shards = 4096 ,
96+ ),
97+ dict (
98+ testcase_name = '100GiB, very uneven example sizes' ,
99+ num_examples = 1e9 , # 1B examples
100+ total_size = 1e9 * 1000 , # On average 1000 bytes per example
101+ max_size = 16 * 1000 , # Max example size is 16x the average bytes
102+ uses_precise_sharding = True ,
103+ expected_num_shards = 15360 ,
104+ ),
32105 )
33106 def test_get_number_shards_default_config (
34- self , total_size , num_examples , uses_precise_sharding , expected_num_shards
107+ self ,
108+ total_size : int ,
109+ num_examples : int ,
110+ uses_precise_sharding : bool ,
111+ max_size : int ,
112+ expected_num_shards : int ,
35113 ):
36114 shard_config = shard_utils .ShardConfig ()
37115 self .assertEqual (
38116 expected_num_shards ,
39117 shard_config .get_number_shards (
40118 total_size = total_size ,
41119 num_examples = num_examples ,
120+ max_example_size = max_size , # max(1, total_size // num_examples),
42121 uses_precise_sharding = uses_precise_sharding ,
43122 ),
44123 )
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
48127 self .assertEqual (
49128 42 ,
50129 shard_config .get_number_shards (
51- total_size = 100 , num_examples = 1 , uses_precise_sharding = True
130+ total_size = 100 ,
131+ max_example_size = 100 ,
132+ num_examples = 1 ,
133+ uses_precise_sharding = True ,
52134 ),
53135 )
54136
0 commit comments