11from dataclasses import dataclass
2- from typing import Iterable , Tuple , Optional , Sequence
2+ from typing import Iterable , Tuple , Optional , Sequence , List , cast
33
44import torch
55
66
@dataclass
class Path:
    """One path context: a start token, the nodes along the path, and an end token.

    Every field is a plain list of ints (presumably vocabulary indices --
    confirm against the code that builds these objects). The lists are only
    converted to tensors when samples are batched.
    """

    from_token: List[int]  # [max token parts]
    path_node: List[int]  # [path length]
    to_token: List[int]  # [max token parts]
1212
1313
@dataclass
class LabeledPathContext:
    """A label together with the path contexts that belong to it.

    The label is a plain list of ints; it is turned into a tensor only when
    samples are batched.
    """

    label: List[int]  # [max label parts]
    path_contexts: Sequence[Path]
1818
1919
def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    """Swap the two axes of a rectangular list of lists.

    ``transpose([[1, 2, 3], [4, 5, 6]])`` returns ``[[1, 4], [2, 5], [3, 6]]``.
    An empty input returns an empty list.

    Args:
        list_of_lists: rows of equal length.

    Returns:
        A new list holding one list per column of the input.

    Raises:
        ValueError: if the rows do not all have the same length. ``zip`` would
            otherwise silently truncate to the shortest row and hide the
            shape mismatch from the caller.
    """
    if not list_of_lists:
        return []

    expected = len(list_of_lists[0])
    for index, row in enumerate(list_of_lists):
        if len(row) != expected:
            raise ValueError(
                f"cannot transpose ragged input: row 0 has length {expected}, "
                f"row {index} has length {len(row)}"
            )

    # Build real lists: zip yields tuples, which do not match the declared type.
    return [list(column) for column in zip(*list_of_lists)]
22+
23+
class BatchedLabeledPathContext:
    """A batch of labeled path contexts packed into ``torch.long`` tensors."""

    def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]):
        """Pack the given samples into batch tensors.

        ``None`` entries in ``all_samples`` are dropped. The per-sample int
        lists are transposed so that the batch / context index becomes the
        second tensor dimension.

        Args:
            all_samples: samples to batch; entries may be ``None``.
        """
        samples = [s for s in all_samples if s is not None]
        # Flatten the path contexts of all samples once; every per-path
        # tensor below is built from this same ordering.
        paths = [path for s in samples for path in s.path_contexts]

        # [max label parts; batch size]
        self.labels = torch.tensor(transpose([s.label for s in samples]), dtype=torch.long)
        # [batch size]
        # Explicit dtype keeps this an integer tensor even for an empty batch.
        self.contexts_per_label = torch.tensor(
            [len(s.path_contexts) for s in samples], dtype=torch.long
        )

        # [max token parts; n contexts]
        self.from_token = torch.tensor(
            transpose([path.from_token for path in paths]), dtype=torch.long
        )
        # [path length; n contexts]
        self.path_nodes = torch.tensor(
            transpose([path.path_node for path in paths]), dtype=torch.long
        )
        # [max token parts; n contexts]
        self.to_token = torch.tensor(
            transpose([path.to_token for path in paths]), dtype=torch.long
        )

    def __len__(self) -> int:
        """Return the number of non-``None`` samples in the batch."""
        return len(self.contexts_per_label)
@@ -53,8 +63,8 @@ def move_to_device(self, device: torch.device):
5363
@dataclass
class TypedPath(Path):
    """A ``Path`` that also carries a type for its start and end token.

    Both type fields are plain lists of ints, like the inherited token fields.
    """

    from_type: List[int]  # [max type parts]
    to_type: List[int]  # [max type parts]
5868
5969
6070@dataclass
@@ -67,6 +77,10 @@ def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
6777 super ().__init__ (all_samples )
6878 samples = [s for s in all_samples if s is not None ]
6979 # [max type parts; n contexts]
70- self .from_type = torch .cat ([path .from_type .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
80+ self .from_type = torch .tensor (
81+ transpose ([path .from_type for s in samples for path in s .path_contexts ]), dtype = torch .long
82+ )
7183 # [max type parts; n contexts]
72- self .to_type = torch .cat ([path .to_type .unsqueeze (1 ) for s in samples for path in s .path_contexts ], dim = 1 )
84+ self .to_type = torch .tensor (
85+ transpose ([path .to_type for s in samples for path in s .path_contexts ]), dtype = torch .long
86+ )
0 commit comments