@@ -41,7 +41,7 @@ def get_scores(self, ignore_first_label=False, ignore_last_label=False):
4141 return iou , miou
4242
4343
44- def evaluate (model , testloader , criterion , num_classes : int , device , amp_enabled : bool ):
44+ def evaluate (model , testloader , criterion , num_classes : int , amp_enabled : bool , device , eval_fps = True ):
4545 model .eval ()
4646
4747 # Evaluate
@@ -55,12 +55,16 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
5555
5656 # 예측
5757 with torch .cuda .amp .autocast (enabled = amp_enabled ):
58- torch .cuda .synchronize ()
59- start_time = time .time ()
60- with torch .no_grad ():
61- output = model (image )
62- torch .cuda .synchronize ()
63- inference_time += time .time () - start_time
58+ if eval_fps :
59+ torch .cuda .synchronize ()
60+ start_time = time .time ()
61+ with torch .no_grad ():
62+ output = model (image )
63+ torch .cuda .synchronize ()
64+ inference_time += time .time () - start_time
65+ else :
66+ with torch .no_grad ():
67+ output = model (image )
6468
6569 # validation loss를 모두 합침
6670 val_loss += criterion (output , target ).item ()
@@ -79,8 +83,11 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
7983 val_loss /= len (testloader )
8084
8185 # 추론 시간과 fps를 계산 (추론 시간 단위: sec)
82- inference_time /= len (testloader .dataset )
83- fps = 1 / inference_time
86+ if eval_fps :
87+ inference_time /= len (testloader .dataset )
88+ fps = 1 / inference_time
89+ else :
90+ fps = 0
8491
8592 return val_loss , iou , miou , fps
8693
@@ -103,7 +110,8 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
103110 criterion = nn .CrossEntropyLoss ()
104111
105112 # 모델 평가
106- val_loss , iou , miou , fps = evaluate (model , testloader , criterion , config [config ['model' ]]['num_classes' ], device )
113+ val_loss , iou , miou , fps = evaluate (model , testloader , criterion , config ['dataset' ]['num_classes' ],
114+ config ['amp_enabled' ], device )
107115
108116 # 평가 결과를 csv 파일로 저장
109117 os .makedirs ('result' , exist_ok = True )
0 commit comments