Skip to content

Commit 93b48eb

Browse files
authored
Merge pull request #325 from JdeRobot/issue-323
Issue 323 - Changes made in preprocessing images for torch detection models.
2 parents 3b1aaab + ada0a05 commit 93b48eb

File tree

2 files changed: +36 −16 lines changed

app.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def browse_folder():
187187
"Confidence Threshold",
188188
min_value=0.0,
189189
max_value=1.0,
190-
value=st.session_state.get("confidence_threshold", 0.5),
191190
step=0.01,
192191
key="confidence_threshold",
193192
help="Minimum confidence score for detections",
@@ -196,7 +195,6 @@ def browse_folder():
196195
"NMS Threshold",
197196
min_value=0.0,
198197
max_value=1.0,
199-
value=st.session_state.get("nms_threshold", 0.5),
200198
step=0.01,
201199
key="nms_threshold",
202200
help="Non-maximum suppression threshold",
@@ -205,15 +203,22 @@ def browse_folder():
205203
"Max Detections/Image",
206204
min_value=1,
207205
max_value=1000,
208-
value=st.session_state.get("max_detections", 100),
209206
step=1,
210207
key="max_detections",
211208
)
209+
st.number_input(
210+
"Image Resize Height",
211+
min_value=1,
212+
max_value=4096,
213+
value=640,
214+
step=1,
215+
key="resize_height",
216+
help="Height to resize images for inference",
217+
)
212218
with col2:
213219
st.selectbox(
214220
"Device",
215221
["cpu", "cuda", "mps"],
216-
index=0 if st.session_state.get("device", "cpu") == "cpu" else 1,
217222
key="device",
218223
)
219224
st.selectbox(
@@ -231,20 +236,26 @@ def browse_folder():
231236
"Batch Size",
232237
min_value=1,
233238
max_value=256,
234-
value=st.session_state.get("batch_size", 1),
235239
step=1,
236240
key="batch_size",
237241
)
238242
st.number_input(
239243
"Evaluation Step",
240244
min_value=0,
241245
max_value=1000,
242-
value=st.session_state.get("evaluation_step", 10),
243246
step=1,
244247
key="evaluation_step",
245248
help="Update UI with intermediate metrics every N images (0 = disable intermediate updates)",
246249
)
247-
250+
st.number_input(
251+
"Image Resize Width",
252+
min_value=1,
253+
max_value=4096,
254+
value=640,
255+
step=1,
256+
key="resize_width",
257+
help="Width to resize images for inference",
258+
)
248259
# Load model action in sidebar
249260
from detectionmetrics.models.torch_detection import TorchImageDetectionModel
250261
import json, tempfile
@@ -292,6 +303,8 @@ def browse_folder():
292303
device = st.session_state.get("device", "cpu")
293304
batch_size = int(st.session_state.get("batch_size", 1))
294305
evaluation_step = int(st.session_state.get("evaluation_step", 5))
306+
resize_height = int(st.session_state.get("resize_height", 640))
307+
resize_width = int(st.session_state.get("resize_width", 640))
295308
model_format = st.session_state.get("model_format", "torchvision")
296309
config_data = {
297310
"confidence_threshold": confidence_threshold,
@@ -300,6 +313,8 @@ def browse_folder():
300313
"device": device,
301314
"batch_size": batch_size,
302315
"evaluation_step": evaluation_step,
316+
"resize_height": resize_height,
317+
"resize_width": resize_width,
303318
"model_format": model_format.lower(),
304319
}
305320
with tempfile.NamedTemporaryFile(

detectionmetrics/models/torch_detection.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,21 @@ def __init__(
260260
# Build input transforms (resize, normalize, etc.)
261261
self.transform_input = []
262262

263+
# Default resize to 640x640 if not specified
263264
if "resize" in self.model_cfg:
264-
self.transform_input += [
265-
transforms.Resize(
266-
size=(
267-
self.model_cfg["resize"].get("height", None),
268-
self.model_cfg["resize"].get("width", None),
269-
),
270-
interpolation=transforms.InterpolationMode.BILINEAR,
271-
)
272-
]
265+
resize_height = self.model_cfg["resize"].get("height", 640)
266+
resize_width = self.model_cfg["resize"].get("width", 640)
267+
else:
268+
# Default to 640x640 when no resize is specified
269+
resize_height = 640
270+
resize_width = 640
271+
272+
self.transform_input += [
273+
transforms.Resize(
274+
size=(resize_height, resize_width),
275+
interpolation=transforms.InterpolationMode.BILINEAR,
276+
)
277+
]
273278

274279
if "crop" in self.model_cfg:
275280
crop_size = (

0 commit comments

Comments (0)