
Commit 22a6bd9

fixed
1 parent 5520427 commit 22a6bd9

File tree: 1 file changed, +1 -7 lines changed


minimal_llama/hyper/finetune_peft_acc.py

Lines changed: 1 addition & 7 deletions
@@ -98,12 +98,6 @@ def run():
         # optimizer = bitsandbytes.optim.AdamW(prefix_maker.parameters(), lr=args.lr, is_paged=True, optim_bits=32)
         optimizer = torch.optim.Adam(prefix_maker.parameters(), lr=args.lr, betas=(0.9, 0.99))
         save_model = prefix_maker
-
-        if accelerator.process_index == 0:
-            for k, v in prefix_maker.state_dict().items():
-                print(k, tuple(v.shape))
-            import time
-            time.sleep(10)
     elif args.peft_type == "lora":
         config = lora_llama.LLAMA_CONFIG_DICT[args.model_size]
         config.dtype = torch.bfloat16
@@ -139,7 +133,7 @@ def run():
         completed_steps = train_state["completed_steps"]
         load_path = os.path.join(args.save_dir, f"checkpoint_{completed_steps:05d}.pt")
         print0("Resuming from", load_path)
-        loaded = torch.load(load_path)
+        loaded = torch.load(load_path, map_location="cpu")
         if args.peft_type == "prefix":
             # noinspection PyUnboundLocalVariable
             prefix_maker.load_state_dict(loaded["model"])
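The functional change is in the resume path: `torch.load` without `map_location` restores each tensor onto the device it was saved from (e.g. `cuda:0`), so when several accelerate processes resume from the same checkpoint they would all try to materialize it on that one GPU. Remapping to CPU keeps deserialization off the GPU; `load_state_dict` then copies the values into the module's existing parameters on whatever device they already occupy. A minimal sketch of the pattern (the checkpoint path and the Linear stand-in for prefix_maker are illustrative, not from this repo):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for prefix_maker / the PEFT module

# Save a checkpoint in the same {"model": state_dict} layout the script uses.
torch.save({"model": model.state_dict()}, "checkpoint_00100.pt")

# map_location="cpu" deserializes every tensor onto CPU, regardless of the
# device it was saved from, so a resuming process never allocates the
# checkpoint on the saving GPU.
loaded = torch.load("checkpoint_00100.pt", map_location="cpu")

# load_state_dict copies the CPU values into the module's existing
# parameters, which stay on whatever device the module already lives on.
model.load_state_dict(loaded["model"])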
