
Commit c2393f4

Fix: init_hidden always returns None (Mamba manages state internally)
The fallback init_hidden was returning a zero tensor instead of None, causing test_hidden_state_api to fail. The fallback's _selective_scan already handles None by initializing zeros internally, so init_hidden should always return None for consistent API behavior.
1 parent e4a7883 commit c2393f4
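
The contract described in the commit message can be illustrated with a small, dependency-free sketch. It is not the code in core/ssm_mamba.py: `ToySSM`, `selective_scan`, the `decay` parameter, and the toy recurrence are hypothetical stand-ins, and nested lists take the place of tensors. It only shows the pattern the message relies on: `init_hidden` returns None, and the scan allocates zeros itself when it receives None.

```python
from typing import List, Optional

State = List[List[float]]  # toy stand-in for a (d_inner, state_dim) tensor


class ToySSM:
    """Hypothetical stand-in for the fallback path (assumption, not the real class)."""

    def __init__(self, d_inner: int = 2, state_dim: int = 3, decay: float = 0.5) -> None:
        self.d_inner = d_inner
        self.state_dim = state_dim
        self.decay = decay

    def init_hidden(self, batch_size: int = 1) -> Optional[State]:
        # Mirrors the commit: callers always get None, on every code path.
        return None

    def selective_scan(self, xs: List[float], hidden: Optional[State] = None) -> State:
        # None means "start from zeros"; the state is allocated here, not by the caller.
        if hidden is None:
            hidden = [[0.0] * self.state_dim for _ in range(self.d_inner)]
        # Toy recurrence h <- decay * h + x, applied elementwise for each input.
        for x in xs:
            hidden = [[self.decay * h + x for h in row] for row in hidden]
        return hidden


model = ToySSM()
assert model.init_hidden(batch_size=4) is None

# Passing None and passing explicit zeros give the same final state.
from_none = model.selective_scan([1.0, 2.0], model.init_hidden())
explicit_zeros = [[0.0] * model.state_dim for _ in range(model.d_inner)]
from_zeros = model.selective_scan([1.0, 2.0], explicit_zeros)
assert from_none == from_zeros
```

Under this assumption, returning a zero tensor from `init_hidden` adds nothing to the computation, and it makes the return type depend on which backend is active. That dependence is the kind of inconsistency an API test can trip over.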

1 file changed

Lines changed: 1 addition & 7 deletions

core/ssm_mamba.py

@@ -338,14 +338,8 @@ def init_hidden(self, batch_size: int = 1) -> Optional[torch.Tensor]:
             batch_size: Number of sequences in batch
 
         Returns:
-            Zero tensor of shape (batch_size, d_inner, state_dim)
-            if using the fallback.
+            None (Mamba manages state internally via selective scan).
         """
-        if not self._using_official_mamba:
-            d_inner = int(self.expand * self.d_model)
-            return torch.zeros(
-                batch_size, d_inner, self.state_dim, device=self.device
-            )
         return None
 
     def forward(
