-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodelutils.py
More file actions
26 lines (23 loc) · 834 Bytes
/
modelutils.py
File metadata and controls
26 lines (23 loc) · 834 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from transformers import OPTForCausalLM
# Default device for this module. Hard-coded to the first CUDA GPU —
# NOTE(review): assumes a CUDA-capable machine; confirm callers never run CPU-only.
DEV = torch.device('cuda:0')
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Recursively collect submodules whose exact type is in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Module classes to match. Matching uses ``type(...) in layers``,
            i.e. exact type — subclasses of the listed types do NOT match.
            (Default is a tuple rather than a list to avoid the mutable
            default-argument pitfall; membership semantics are unchanged.)
        name: Dot-separated prefix for the returned keys; used internally
            during recursion, callers normally leave it empty.

    Returns:
        Dict mapping the qualified submodule name (e.g. ``'decoder.0.fc1'``)
        to the matching module instance. Empty dict if nothing matches.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        # Build the dotted path; the root level has no leading dot.
        prefix = name + '.' + child_name if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=prefix))
    return res
def get_opt(model):
    """Load an OPT causal-LM checkpoint in float16, skipping weight init.

    Args:
        model: Hugging Face model name or local path, passed straight to
            ``OPTForCausalLM.from_pretrained``.

    Returns:
        The loaded ``OPTForCausalLM``, with an extra ``seqlen`` attribute set
        to ``config.max_position_embeddings`` for downstream code.
    """
    def skip(*args, **kwargs):
        # No-op initializer: the random init is immediately overwritten by
        # the checkpoint weights, so performing it is wasted work.
        pass
    # Temporarily disable the expensive default initializers. The original
    # code patched these globals permanently, leaking the no-op into any
    # torch code run later in the process; save and restore them instead.
    saved_inits = (
        torch.nn.init.kaiming_uniform_,
        torch.nn.init.uniform_,
        torch.nn.init.normal_,
    )
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    try:
        model = OPTForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
    finally:
        (torch.nn.init.kaiming_uniform_,
         torch.nn.init.uniform_,
         torch.nn.init.normal_) = saved_inits
    model.seqlen = model.config.max_position_embeddings
    return model