Skip to content

Commit 198d49f

Browse files
committed
change to jnp.bfloat16
1 parent cc78a3a commit 198d49f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _precompile_pooling(self) -> None:
115115

116116
for num_tokens in self.runner.num_tokens_paddings:
117117
hidden_states = self._create_dummy_tensor(
118-
(num_tokens, hidden_size), t2j_dtype(dtype), sharding=hidden_sharding)
118+
(num_tokens, hidden_size), jnp.bfloat16, sharding=hidden_sharding)
119119

120120
for num_reqs in self.runner.num_reqs_paddings:
121121
if num_reqs == 0 or num_reqs > num_tokens:

0 commit comments

Comments
 (0)