Skip to content

Commit 4f950d8

Browse files
committed
fix window attention
1 parent afc9dc5 commit 4f950d8

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

eqnet/models/phasenet.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,7 @@ def __init__(
298298
event_time_loss_weight=1.0,
299299
polarity_loss_weight=1.0,
300300
prompt_loss_weight=1.0,
301+
window_attention=False,
301302
**kwargs,
302303
) -> None:
303304
super().__init__(**kwargs)
@@ -341,6 +342,7 @@ def __init__(
341342
add_polarity=add_polarity,
342343
add_event=add_event,
343344
add_prompt=add_prompt,
345+
window_attention=window_attention,
344346
**kwargs,
345347
)
346348
else:
@@ -455,12 +457,12 @@ def forward(self, batched_inputs: Tensor) -> Dict[str, Tensor]:
455457
def build_model(
456458
backbone="unet",
457459
log_scale=True,
458-
shift_window=False,
460+
window_attention=False,
459461
*args,
460462
**kwargs,
461463
) -> PhaseNet:
462464
return PhaseNet(
463465
backbone=backbone,
464466
log_scale=log_scale,
465-
shift_window=shift_window,
467+
window_attention=window_attention,
466468
)

eqnet/models/phasenet_plus.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ def build_model(
1010
event_center_loss_weight=1.0,
1111
event_time_loss_weight=1.0,
1212
polarity_loss_weight=1.0,
13-
shift_window=False,
13+
window_attention=False,
1414
*args,
1515
**kwargs,
1616
) -> PhaseNet:
@@ -23,5 +23,5 @@ def build_model(
2323
event_center_loss_weight=event_center_loss_weight,
2424
event_time_loss_weight=event_time_loss_weight,
2525
polarity_loss_weight=polarity_loss_weight,
26-
shift_window=shift_window,
26+
window_attention=window_attention,
2727
)

eqnet/models/x_unet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -343,7 +343,7 @@ def time_window_mod(b, h, q_idx, kv_idx):
343343

344344
return prefix_mask | suffix_mask | mid_mask
345345

346-
block_mask = create_block_mask(time_window_mod, B, H, q_len, kv_len, device=device, BLOCK_SIZE=window_size, _compile=False)
346+
block_mask = create_block_mask(time_window_mod, B, H, q_len, kv_len, device=device, BLOCK_SIZE=128, _compile=False)
347347
return block_mask
348348

349349
@lru_cache

predict.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -443,12 +443,15 @@ def main(args):
443443
# checkpoint = torch.load(glob(os.path.join(artifact_dir, "*.pth"))[0], map_location="cpu")
444444
# model.load_state_dict(checkpoint["model"], strict=True)
445445

446+
model.load_state_dict(checkpoint["model"], strict=True)
447+
if args.window_attention:
448+
model = torch.compile(model)
446449
model_without_ddp = model
447450
if args.distributed:
448451
torch.distributed.barrier()
449452
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
450453
model_without_ddp = model.module
451-
model_without_ddp.load_state_dict(checkpoint["model"], strict=True)
454+
#model_without_ddp.load_state_dict(checkpoint["model"], strict=True)
452455

453456
if args.model == "phasenet_das":
454457
pred_phasenet_das(args, model, data_loader, pick_path, figure_path)
@@ -494,7 +497,7 @@ def get_args_parser(add_help=True):
494497
parser.add_argument("--result_path", type=str, default="results", help="path to result directory")
495498
parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
496499
parser.add_argument("--min_prob", default=0.3, type=float, help="minimum probability for picking")
497-
parser.add_argument("--shift_window", action="store_true", help="If use shift window for transformer")
500+
parser.add_argument("--window-attention", action="store_true", help="If use shift window for transformer")
498501

499502
## Seismic
500503
parser.add_argument("--add_polarity", action="store_true", help="If use polarity information")

0 commit comments

Comments
 (0)