Skip to content

Commit 11830b9

Browse files
authored
feat: refactor to use only singular classifcation model for color (#533)
* feat: refactor to use only singular classifcation model for color
1 parent 4f68701 commit 11830b9

File tree

6 files changed

+12
-37
lines changed

6 files changed

+12
-37
lines changed

sketch_map_tool/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
"wms-read-timeout": 600,
2323
"max-nr-simultaneous-uploads": 100,
2424
"max_pixel_per_image": 10e8, # 10.000*10.000
25-
"yolo_osm_cls": "SMT-OSM-CLS",
26-
"yolo_esri_cls": "SMT-ESRI-CLS",
25+
"yolo_cls": "SMT-CLS",
2726
"yolo_osm_obj": "SMT-OSM",
2827
"yolo_esri_obj": "SMT-ESRI",
2928
"model_type_sam": "vit_b",

sketch_map_tool/routes.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,6 @@ def weights_smt_osm(lang="en") -> Response: # pyright: ignore
9999
return send_from_directory(dir, name, as_attachment=True)
100100

101101

102-
@app.get("/weights/SMT-OSM-CLS.pt")
103-
@app.get("/<lang>/weights/SMT-OSM-CLS.pt")
104-
def weights_smt_osm_cls(lang="en") -> Response: # pyright: ignore
105-
dir = Path(config.get_config_value("weights-dir"))
106-
name = "SMT-OSM-CLS.pt"
107-
return send_from_directory(dir, name, as_attachment=True)
108-
109-
110102
@app.get("/weights/SMT-ESRI.pt")
111103
@app.get("/<lang>/weights/SMT-ESRI.pt")
112104
def weights_smt_esri(lang="en") -> Response: # pyright: ignore
@@ -115,11 +107,11 @@ def weights_smt_esri(lang="en") -> Response: # pyright: ignore
115107
return send_from_directory(dir, name, as_attachment=True)
116108

117109

118-
@app.get("/weights/SMT-ESRI-CLS.pt")
119-
@app.get("/<lang>/weights/SMT-ESRI-CLS.pt")
120-
def weights_smt_esri_cls(lang="en") -> Response: # pyright: ignore
110+
@app.get("/weights/SMT-CLS.pt")
111+
@app.get("/<lang>/weights/SMT-CLS.pt")
112+
def weights_smt_cls(lang="en") -> Response: # pyright: ignore
121113
dir = Path(config.get_config_value("weights-dir"))
122-
name = "SMT-ESRI-CLS.pt"
114+
name = "SMT-CLS.pt"
123115
return send_from_directory(dir, name, as_attachment=True)
124116

125117

sketch_map_tool/tasks.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ def init_worker_ml_models(**_):
5151
logging.info("Initialize ml-models.")
5252
global sam_predictor
5353
global yolo_obj_osm
54-
global yolo_cls_osm
5554
global yolo_obj_esri
56-
global yolo_cls_esri
55+
global yolo_cls
5756

5857
path = init_sam2()
5958
device = select_computation_device()
@@ -65,9 +64,8 @@ def init_worker_ml_models(**_):
6564
sam_predictor = SAM2ImagePredictor(sam2_model)
6665

6766
yolo_obj_osm = YOLO_MB(init_model(get_config_value("yolo_osm_obj")))
68-
yolo_cls_osm = YOLO(init_model(get_config_value("yolo_osm_cls")))
6967
yolo_obj_esri = YOLO_MB(init_model(get_config_value("yolo_esri_obj")))
70-
yolo_cls_esri = YOLO(init_model(get_config_value("yolo_esri_cls")))
68+
yolo_cls = YOLO(init_model(get_config_value("yolo_cls")))
7169

7270

7371
@worker_process_shutdown.connect
@@ -140,10 +138,8 @@ def digitize_sketches(
140138
) -> FeatureCollection:
141139
if layer == "osm":
142140
yolo_obj = yolo_obj_osm
143-
yolo_cls = yolo_cls_osm
144141
elif layer == "esri-world-imagery":
145142
yolo_obj = yolo_obj_esri
146-
yolo_cls = yolo_cls_esri
147143
else:
148144
raise ValueError("Unexpected layer: " + layer)
149145

tests/integration/upload_processing/test_detect_markings.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,6 @@ def yolo_osm_obj() -> YOLO_MB:
4141
return YOLO_MB(path)
4242

4343

44-
@pytest.fixture
45-
def yolo_osm_cls() -> YOLO:
46-
"""YOLO Classification"""
47-
path = init_model(get_config_value("yolo_osm_cls"))
48-
return YOLO(path)
49-
50-
5144
@pytest.fixture
5245
def yolo_esri_obj() -> YOLO_MB:
5346
"""YOLO Object Detection"""
@@ -56,9 +49,9 @@ def yolo_esri_obj() -> YOLO_MB:
5649

5750

5851
@pytest.fixture
59-
def yolo_esri_cls() -> YOLO:
52+
def yolo_cls() -> YOLO:
6053
"""YOLO Classification"""
61-
path = init_model(get_config_value("yolo_osm_cls"))
54+
path = init_model(get_config_value("yolo_cls"))
6255
return YOLO(path)
6356

6457

@@ -68,17 +61,14 @@ def test_detect_markings(
6861
map_frame_marked,
6962
map_frame,
7063
yolo_osm_obj,
71-
yolo_osm_cls,
7264
yolo_esri_obj,
73-
yolo_esri_cls,
65+
yolo_cls,
7466
sam_predictor,
7567
):
7668
if layer.value == "osm":
7769
yolo_obj = yolo_osm_obj
78-
yolo_cls = yolo_osm_cls
7970
else:
8071
yolo_obj = yolo_esri_obj
81-
yolo_cls = yolo_esri_cls
8272
markings = detect_markings(
8373
map_frame_marked,
8474
np.asarray(Image.open(map_frame)),

tests/integration/upload_processing/test_ml_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
@pytest.mark.parametrize(
1010
"id",
1111
(
12-
get_config_value("yolo_osm_cls"),
13-
get_config_value("yolo_esri_cls"),
1412
get_config_value("yolo_osm_obj"),
1513
get_config_value("yolo_esri_obj"),
14+
get_config_value("yolo_cls"),
1615
),
1716
)
1817
def test_init_model(id):

tests/unit/test_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def config_keys():
3434
"wms-read-timeout",
3535
"max-nr-simultaneous-uploads",
3636
"max_pixel_per_image",
37-
"yolo_osm_cls",
38-
"yolo_esri_cls",
37+
"yolo_cls",
3938
"yolo_osm_obj",
4039
"yolo_esri_obj",
4140
"model_type_sam",

0 commit comments

Comments
 (0)