-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
Stroll raises the following error for documents containing 'empty' sentences (the example document below contains only punctuation):
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-3f6234a81ebe> in <module>
----> 1 doc = nlp( '.')
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/stanza/pipeline/core.py in __call__(self, doc)
164 assert any([isinstance(doc, str), isinstance(doc, list),
165 isinstance(doc, Document)]), 'input should be either str, list or Document'
--> 166 doc = self.process(doc)
167 return doc
168
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/stanza/pipeline/core.py in process(self, doc)
158 for processor_name in PIPELINE_NAMES:
159 if self.processors.get(processor_name):
--> 160 doc = self.processors[processor_name].process(doc)
161 return doc
162
~/filter-bubble/stroll/stroll/stanza.py in process(self, doc)
112
113 frame_labels, role_labels, \
--> 114 frame_chance, role_chance = self.net.label(gs)
115
116 word_offset = 0
~/filter-bubble/stroll/stroll/model.py in label(self, gs)
354
355 def label(self, gs):
--> 356 logitsf, logitsr = self(gs)
357 logitsf = torch.softmax(logitsf, dim=1)
358 logitsr = torch.softmax(logitsr, dim=1)
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/filter-bubble/stroll/stroll/model.py in forward(self, g)
345
346 # Hidden layers, each of h_dims to h_dims
--> 347 g = self.kernel(g)
348
349 # MLP output
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/filter-bubble/stroll/stroll/model.py in forward(self, graph)
283 # and create a new output and hidden state for step t
284 for l in range(self.num_layers):
--> 285 graph.update_all(rgcn_msg, rgcn_reduce, rgcn_apply)
286
287 # Batchnorm
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/dgl-0.6a210127-py3.8-linux-x86_64.egg/dgl/heterograph.py in update_all(self, message_func, reduce_func, apply_node_func, etype)
4660 _, dtid = self._graph.metagraph.find_edge(etid)
4661 g = self if etype is None else self[etype]
-> 4662 ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
4663 self._set_n_repr(dtid, ALL, ndata)
4664
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/dgl-0.6a210127-py3.8-linux-x86_64.egg/dgl/core.py in message_passing(g, mfunc, rfunc, afunc)
286 else:
287 orig_eid = g.edata.get(EID, None)
--> 288 msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
289 # reduce phase
290 if is_builtin(rfunc):
~/anaconda3/envs/filtbubtest/lib/python3.8/site-packages/dgl-0.6a210127-py3.8-linux-x86_64.egg/dgl/core.py in invoke_edge_udf(graph, eid, etype, func, orig_eid)
80 ebatch = EdgeBatch(graph, eid if orig_eid is None else orig_eid,
81 etype, srcdata, edata, dstdata)
---> 82 return func(ebatch)
83
84 def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
~/filter-bubble/stroll/stroll/model.py in rgcn_msg(edges)
252 n = edges.data['norm']
253 msg = torch.bmm(edges.src['output'].unsqueeze(1), w).squeeze()
--> 254 msg = torch.bmm(n.reshape(-1, 1, 1), msg.unsqueeze(1)).squeeze()
255
256 return {'m': msg}
RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #2 'batch2' (while checking arguments for bmm)
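I would expect Stroll to simply tag no SRLs for such a sentence rather than throw an error. (Judging from the traceback, the likely cause is the argument-less `.squeeze()` in `rgcn_msg`: when the graph has only a single edge it also drops the batch dimension, so the second `torch.bmm` receives a 2-dimensional tensor instead of a 3-dimensional one.)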
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels