diff --git a/README.md b/README.md
index 96e31b7..6ed9dd6 100644
--- a/README.md
+++ b/README.md
@@ -1,93 +1 @@
-
-YOLOv8 Segmentation with DeepSORT Object Tracking(ID + Trails)
-
-## Google Colab File Link (A Single Click Solution)
-The google colab file link for yolov8 segmentation and tracking is provided below, you can check the implementation in Google Colab, and its a single click implementation
-,you just need to select the Run Time as GPU, and click on Run All.
-
-[`Google Colab File`](https://colab.research.google.com/drive/1wRkrquf_HMV7tyKy2zDAuqqK9G4zZub5?usp=sharing)
-
-## YouTube Video Tutorial Link
-
-[`YouTube Link`](https://www.youtube.com/watch?v=0JIPNk21ivU)
-
-
-## YOLOv8 with DeepSORT Object Tracking
-
-[`Github Repo Link`](https://github.com/MuhammadMoinFaisal/YOLOv8-DeepSORT-Object-Tracking.git)
-
-## Object Segmentation and Tracking (ID + Trails) using YOLOv8 on Custom Data
-## Google Colab File Link (A Single Click Solution)
-[`Google Colab File`](https://colab.research.google.com/drive/1cnr9Jjj5Pag5myK6Ny8v5gtHgOqf6uoF?usp=sharing)
-
-## YouTube Video Tutorial Link
-
-[`YouTube Link`](https://www.youtube.com/watch?v=e-uzr2Sm0DA)
-
-## Steps to run Code
-
-- Clone the repository
-```
-git clone https://github.com/MuhammadMoinFaisal/YOLOv8_Segmentation_DeepSORT_Object_Tracking.git
-```
-- Goto the cloned folder.
-```
-cd YOLOv8_Segmentation_DeepSORT_Object_Tracking
-```
-- Install the Dependencies
-```
-pip install -e '.[dev]'
-
-```
-- Setting the Directory.
-```
-cd ultralytics/yolo/v8/segment
-
-```
-- Downloading the DeepSORT Files From The Google Drive
-```
-
-https://drive.google.com/drive/folders/1kna8eWGrSfzaR6DtNJ8_GchGgPMv3VC8?usp=sharing
-```
-- After downloading the DeepSORT Zip file from the drive, unzip it go into the subfolders and place the deep_sort_pytorch folder into the ultralytics/yolo/v8/segment folder
-
-- Downloading a Sample Videos from the Google Drive
-- Demo Video 1
-```
-gdown "https://drive.google.com/uc?id=19P9Cf9UiJ9gU9KHnAfud6hrFOgobETTX"
-```
-
-- Demo Video 2
-```
-gdown "https://drive.google.com/uc?id=1rjBn8Fl1E_9d0EMVtL24S9aNQOJAveR5&confirm=t"
-```
-- Demo Video 3
-```
-gdown "https://drive.google.com/uc?id=1aL0XIoOQlHf9FBvUx3FMfmPbmRu0-rF-&confirm=t"
-```
-- Run the code with mentioned command below.
-
-- For yolov8 segmentation + Tracking
-```
-python predict.py model=yolov8x-seg.pt source="test1.mp4"
-```
-
-### RESULTS
-
-#### Object Segmentation and DeepSORT Tracking (ID + Trails) and Vehicles Counting
-
-
-#### Object Segmentation and DeepSORT Tracking (ID + Trails)
-
-
-
-### Watch the Complete Step by Step Explanation
-
-- Video Tutorial Link [`YouTube Link`](https://www.youtube.com/watch?v=0JIPNk21ivU)
-
-
-[]([https://www.youtube.com/watch?v=0JIPNk21ivU&t=244s](https://www.youtube.com/watch?v=0JIPNk21ivU))
-
-
-## References
-- https://github.com/ultralytics/ultralytics
+Coming soon.
diff --git a/ultralytics/yolo/v8/segment/predict_cam2.py b/ultralytics/yolo/v8/segment/predict_cam2.py
new file mode 100644
index 0000000..06cea71
--- /dev/null
+++ b/ultralytics/yolo/v8/segment/predict_cam2.py
@@ -0,0 +1,310 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import hydra
+import torch
+
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
+from ultralytics.yolo.utils.checks import check_imgsz
+from ultralytics.yolo.utils.plotting import colors, save_one_box
+
+from ultralytics.yolo.v8.detect.predict import DetectionPredictor
+from numpy import random
+import math
+
+import cv2
+from deep_sort_pytorch.utils.parser import get_config
+from deep_sort_pytorch.deep_sort import DeepSort
+# A deque is a double-ended queue; we prefer it over a list here because appends and pops
+# from either end are O(1), which suits the per-track trail buffers used below.
+from collections import deque
+import numpy as np
+from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
+palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
+data_deque = {}
+deepsort = None
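+# Background exclusion region given as (x1, y1) / (x2, y2) corners: it is filled black on
+# every frame and any detection overlapping it is dropped from the running count.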
+rectangle_top_left_back = (0, 96)
+rectangle_bottom_right_back = (285, 435)
+# rectangle_top_left = (141, 151)
+# rectangle_bottom_right = (1279, 450)
+font_scale = 1
+font_thickness = 2
+text_position = (10, 50)
+font = cv2.FONT_HERSHEY_SIMPLEX
+font_color = (255, 255, 255)
+line_thickness = 3
+hide_labels = False  # hide labels
+hide_conf = False  # hide confidence scores
+
+def init_tracker():
+ global deepsort
+ cfg_deep = get_config()
+ cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
+
+ deepsort= DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
+ max_dist=cfg_deep.DEEPSORT.MAX_DIST, min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
+ nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
+ max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT, nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
+ use_cuda=True)
+##########################################################################################
+def xyxy_to_xywh(*xyxy):
+ """" Calculates the relative bounding box from absolute pixel values. """
+ bbox_left = min([xyxy[0].item(), xyxy[2].item()])
+ bbox_top = min([xyxy[1].item(), xyxy[3].item()])
+ bbox_w = abs(xyxy[0].item() - xyxy[2].item())
+ bbox_h = abs(xyxy[1].item() - xyxy[3].item())
+ x_c = (bbox_left + bbox_w / 2)
+ y_c = (bbox_top + bbox_h / 2)
+ w = bbox_w
+ h = bbox_h
+ return x_c, y_c, w, h
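+# Example (tensor inputs, as passed from write_results):
+#   xyxy_to_xywh(torch.tensor(100.), torch.tensor(50.), torch.tensor(200.), torch.tensor(150.))
+#   -> (150.0, 100.0, 100.0, 100.0), i.e. the box centre plus width and height.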
+
+def check_rect_overlap(R1, R2):
+ if (R1[0]>=R2[2]) or (R1[2]<=R2[0]) or (R1[3]<=R2[1]) or (R1[1]>=R2[3]):
+ return False
+ else:
+ return True
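+# Rectangles are flat (x1, y1, x2, y2) sequences; boxes that merely touch do not overlap, e.g.
+# check_rect_overlap((0, 0, 10, 10), (5, 5, 15, 15)) is True, while
+# check_rect_overlap((0, 0, 10, 10), (10, 0, 20, 10)) is False.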
+
+def xyxy_to_tlwh(bbox_xyxy):
+ tlwh_bboxs = []
+ for i, box in enumerate(bbox_xyxy):
+ x1, y1, x2, y2 = [int(i) for i in box]
+ top = x1
+ left = y1
+ w = int(x2 - x1)
+ h = int(y2 - y1)
+ tlwh_obj = [top, left, w, h]
+ tlwh_bboxs.append(tlwh_obj)
+ return tlwh_bboxs
+
+def compute_color_for_labels(label):
+ """
+ Simple function that adds fixed color depending on the class
+ """
+ if label == 0: #person
+ color = (85,45,255)
+ elif label == 2: # Car
+ color = (222,82,175)
+    elif label == 3:  # Motorbike
+ color = (0, 204, 255)
+ elif label == 5: # Bus
+ color = (0, 149, 255)
+ else:
+ color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
+ return tuple(color)
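+# Fixed colours are returned for a few COCO classes (person/car/motorbike/bus); any other
+# class id is hashed against `palette` above to give a stable pseudo-random colour.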
+
+
+
+def UI_box(x, img, color=None, label=None, line_thickness=None):
+ # Plots one bounding box on image img
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
+ color = color or [random.randint(0, 255) for _ in range(3)]
+ c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+ cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+ if label:
+ tf = max(tl - 1, 1) # font thickness
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+
+ img = cv2.rectangle(img, (c1[0], c1[1] - t_size[1] -3), (c1[0] + t_size[0], c1[1]+3), color,-1, cv2.LINE_AA)
+
+ cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+
+
+
+def draw_boxes(img, bbox, names,object_id, identities=None, offset=(0, 0)):
+ #cv2.line(img, line[0], line[1], (46,162,112), 3)
+ counter =0
+ height, width, _ = img.shape
+ # remove tracked point from buffer if object is lost
+ for key in list(data_deque):
+ if key not in identities:
+ data_deque.pop(key)
+
+ for i, box in enumerate(bbox):
+
+ x1, y1, x2, y2 = [int(i) for i in box]
+ x1 += offset[0]
+ x2 += offset[0]
+ y1 += offset[1]
+ y2 += offset[1]
+
+ # code to find center of bottom edge
+        center = (int((x1 + x2) / 2), int(y2))
+
+ # get ID of object
+ id = int(identities[i]) if identities is not None else 0
+
+ # create new buffer for new object
+ if id not in data_deque:
+ data_deque[id] = deque(maxlen= 64)
+ color = compute_color_for_labels(object_id[i])
+ obj_name = names[object_id[i]]
+ label = '{}{:d}'.format("", id) + ":"+ '%s' % (obj_name)
+
+ # add center to buffer
+ data_deque[id].appendleft(center)
+ UI_box(box, img, label=label, color=color, line_thickness=2)
+ # draw trail
+ for i in range(1, len(data_deque[id])):
+            # skip if either buffered point is None
+ if data_deque[id][i - 1] is None or data_deque[id][i] is None:
+ continue
+ # generate dynamic thickness of trails
+ thickness = int(np.sqrt(64 / float(i + i)) * 1.5)
+ # draw trails
+ cv2.line(img, data_deque[id][i - 1], data_deque[id][i], color, thickness)
+ return img
+
+
+class SegmentationPredictor(DetectionPredictor):
+
+ def postprocess(self, preds, img, orig_img):
+ masks = []
+ # TODO: filter by classes
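+        # Each surviving detection row is (x1, y1, x2, y2, conf, cls) followed by 32 mask
+        # coefficients (nm=32), which are combined with the prototype masks in `proto` below.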
+ p = ops.non_max_suppression(preds[0],
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nm=32)
+ proto = preds[1][-1]
+ for i, pred in enumerate(p):
+ shape = orig_img[i].shape if self.webcam else orig_img.shape
+ if not len(pred):
+ continue
+ if self.args.retina_masks:
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+ masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2])) # HWC
+ else:
+ masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)) # HWC
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+
+        return (p, masks)
+
+    def write_results(self, idx, preds, batch):
+        counter = 0
+        count = 0
+        min_counter = 0
+        max_counter = 0
+        avg_count = 0
+        total_count = 0
+        frame_count = 0
+ p, im, im0 = batch
+ log_string = ""
+ if len(im.shape) == 3:
+ im = im[None] # expand for batch dim
+ self.seen += 1
+ if self.webcam: # batch_size >= 1
+ log_string += f'{idx}: '
+ frame = self.dataset.count
+ else:
+ frame = getattr(self.dataset, 'frame', 0)
+
+ self.data_path = p
+ self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
+ log_string += '%gx%g ' % im.shape[2:] # print string
+ self.annotator = self.get_annotator(im0)
+
+ preds, masks = preds
+ det = preds[idx]
+ if len(det) == 0:
+ return log_string
+ # Segments
+ mask = masks[idx]
+ if self.args.save_txt:
+ segments = [
+ ops.scale_segments(im0.shape if self.args.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
+ for x in reversed(ops.masks2segments(mask))]
+
+ # Print results
+ for c in det[:, 5].unique():
+ n = (det[:, 5] == c).sum()
+ log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
+
+ # for c in det[:, 5].unique():
+ # n = (det[:, 5] == c).sum() # detections per class
+ # s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
+
+ # print(self.model.names)
+ annotator = Annotator(im0, line_width=line_thickness, example=str(self.model.names))
+ # Mask plotting
+ self.annotator.masks(
+ mask,
+ colors=[colors(x, True) for x in det[:, 5]],
+ im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
+ 255 if self.args.retina_masks else im[idx])
+
+ det = reversed(det[:, :6])
+        counter += len(det)
+        print(f"Detections in frame: {counter}")
+ self.all_outputs.append([det, mask])
+ xywh_bboxs = []
+ confs = []
+ oids = []
+ outputs = []
+ counters = []
+ # Write results
+
+ for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
+ # seg = segments[j].reshape(-1) # (n,2) to (n*2)
+ line = (cls, conf)
+ x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
+            bbox_xyxy = [tensor.item() for tensor in xyxy]
+ xywh_obj = [x_c, y_c, bbox_w, bbox_h]
+            label = f'{self.model.names[int(cls)]} {conf:.2f}'
+ xywh_bboxs.append(xywh_obj)
+ confs.append([conf.item()])
+ oids.append(int(cls))
+ cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
+            if check_rect_overlap(bbox_xyxy, rectangle_top_left_back + rectangle_bottom_right_back):
+                counter -= 1
+ else:
+                annotator.box_label(xyxy, label, color=colors(int(cls), True))
+# print(label)
+ counters.append(counter)
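+        # Per-frame statistics over the running counter; all of this state is re-initialised
+        # at the top of write_results, so Max/Avg/Min describe the current frame only.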
+ max_counter = max(counters)
+ min_counter = min(counters)
+ total_count = sum(counters)
+ frame_count += len(counters)
+ if frame_count > 0:
+ avg_count = total_count / frame_count
+ else:
+ avg_count = 0
+
+ print(f"objects detected {count}")
+ print(f"Total objects detected: {counter}")
+ print(f"Max objects detected: {max_counter}")
+ xywhs = torch.Tensor(xywh_bboxs)
+ confss = torch.Tensor(confs)
+ cv2.putText(im0, f'piglets detected: {counter} Max: {max_counter} Avg: {math.ceil(avg_count)} Min: {min_counter}', text_position, font, font_scale, font_color, font_thickness)
+# cv2.putText(im0, f'objects detected {count}', (141, 500), font, font_scale, font_color,font_thickness)
+# cv2.rectangle(im0, rectangle_top_left, rectangle_bottom_right, (255, 255, 255), 2)
+ cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
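+        # NOTE: the DeepSORT association and trail-drawing path below is currently commented
+        # out; only the per-frame segmentation, box drawing and counting above is active.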
+ # outputs = deepsort.update(xywhs, confss, oids, im0)
+ # if len(outputs) >= 0:
+ # if check_rect_overlap(bbow_xyxy, rectangle_top_left_back+rectangle_bottom_right_back) or check_rect_overlap(bbow_xyxy, rectangle_top_left+rectangle_bottom_right) :
+ # counter = counter -1
+ # if check_rect_overlap(bbow_xyxy, rectangle_top_left+rectangle_bottom_right):
+ # count +=1
+ # else:
+ # bbox_xyxy = outputs[:, :4]
+ # identities = outputs[:, -2]
+ # object_id = outputs[:, -1]
+ # cv2.putText(im0, f'piglets detected: {counter} Max: {max_counter} Avg: {math.ceil(avg_count)} Min: {min_counter}', text_position, font, font_scale, font_color, font_thickness)
+ # cv2.rectangle(im0, rectangle_top_left, rectangle_bottom_right, (255, 255, 255), 2)
+ # cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
+ # draw_boxes(im0, bbox_xyxy, self.model.names, object_id,identities)
+ return log_string
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def predict(cfg):
+ init_tracker()
+ cfg.model = cfg.model or "yolov8n-seg.pt"
+ cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
+ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+
+ predictor = SegmentationPredictor(cfg)
+ predictor()
+
+
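+# Example invocation (hydra-style overrides, mirroring predict.py in this repo; the model and
+# source values below are placeholders):
+#   python predict_cam2.py model=yolov8x-seg.pt source="cam2.mp4"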
+if __name__ == "__main__":
+ predict()
diff --git a/ultralytics/yolo/v8/segment/predict_line_cam1.py b/ultralytics/yolo/v8/segment/predict_line_cam1.py
new file mode 100644
index 0000000..80a11bd
--- /dev/null
+++ b/ultralytics/yolo/v8/segment/predict_line_cam1.py
@@ -0,0 +1,313 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+import hydra
+import torch
+
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
+from ultralytics.yolo.utils.checks import check_imgsz
+from ultralytics.yolo.utils.plotting import colors, save_one_box
+
+from ultralytics.yolo.v8.detect.predict import DetectionPredictor
+from numpy import random
+import math
+
+import cv2
+from deep_sort_pytorch.utils.parser import get_config
+from deep_sort_pytorch.deep_sort import DeepSort
+# A deque is a double-ended queue; we prefer it over a list here because appends and pops
+# from either end are O(1), which suits the per-track trail buffers used below.
+from collections import deque
+import numpy as np
+from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
+palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
+data_deque = {}
+deepsort = None
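+# Two regions given as (x1, y1) / (x2, y2) corners: the "back" rectangle is filled black and
+# excluded from the count, while the second rectangle is outlined in white and detections
+# overlapping it are tallied in `count` rather than the running `counter`.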
+rectangle_top_left_back = (750, 470)
+rectangle_bottom_right_back = (1279, 719)
+rectangle_top_left = (141, 151)
+rectangle_bottom_right = (1279, 450)
+font_scale = 1
+font_thickness = 2
+text_position = (10, 50)
+font = cv2.FONT_HERSHEY_SIMPLEX
+font_color = (255, 255, 255)
+line_thickness = 3
+hide_labels = False  # hide labels
+hide_conf = False  # hide confidence scores
+
+def init_tracker():
+ global deepsort
+ cfg_deep = get_config()
+ cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
+
+ deepsort= DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
+ max_dist=cfg_deep.DEEPSORT.MAX_DIST, min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
+ nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
+ max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT, nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
+ use_cuda=True)
+##########################################################################################
+def xyxy_to_xywh(*xyxy):
+ """" Calculates the relative bounding box from absolute pixel values. """
+ bbox_left = min([xyxy[0].item(), xyxy[2].item()])
+ bbox_top = min([xyxy[1].item(), xyxy[3].item()])
+ bbox_w = abs(xyxy[0].item() - xyxy[2].item())
+ bbox_h = abs(xyxy[1].item() - xyxy[3].item())
+ x_c = (bbox_left + bbox_w / 2)
+ y_c = (bbox_top + bbox_h / 2)
+ w = bbox_w
+ h = bbox_h
+ return x_c, y_c, w, h
+
+def check_rect_overlap(R1, R2):
+ if (R1[0]>=R2[2]) or (R1[2]<=R2[0]) or (R1[3]<=R2[1]) or (R1[1]>=R2[3]):
+ return False
+ else:
+ return True
+
+def xyxy_to_tlwh(bbox_xyxy):
+ tlwh_bboxs = []
+ for i, box in enumerate(bbox_xyxy):
+ x1, y1, x2, y2 = [int(i) for i in box]
+ top = x1
+ left = y1
+ w = int(x2 - x1)
+ h = int(y2 - y1)
+ tlwh_obj = [top, left, w, h]
+ tlwh_bboxs.append(tlwh_obj)
+ return tlwh_bboxs
+
+def compute_color_for_labels(label):
+ """
+ Simple function that adds fixed color depending on the class
+ """
+ if label == 0: #person
+ color = (85,45,255)
+ elif label == 2: # Car
+ color = (222,82,175)
+    elif label == 3:  # Motorbike
+ color = (0, 204, 255)
+ elif label == 5: # Bus
+ color = (0, 149, 255)
+ else:
+ color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
+ return tuple(color)
+
+
+
+def UI_box(x, img, color=None, label=None, line_thickness=None):
+ # Plots one bounding box on image img
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
+ color = color or [random.randint(0, 255) for _ in range(3)]
+ c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+ cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+ if label:
+ tf = max(tl - 1, 1) # font thickness
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+
+ img = cv2.rectangle(img, (c1[0], c1[1] - t_size[1] -3), (c1[0] + t_size[0], c1[1]+3), color,-1, cv2.LINE_AA)
+
+ cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+
+
+
+def draw_boxes(img, bbox, names,object_id, identities=None, offset=(0, 0)):
+ #cv2.line(img, line[0], line[1], (46,162,112), 3)
+ counter =0
+ height, width, _ = img.shape
+ # remove tracked point from buffer if object is lost
+ for key in list(data_deque):
+ if key not in identities:
+ data_deque.pop(key)
+
+ for i, box in enumerate(bbox):
+
+ x1, y1, x2, y2 = [int(i) for i in box]
+ x1 += offset[0]
+ x2 += offset[0]
+ y1 += offset[1]
+ y2 += offset[1]
+
+ # code to find center of bottom edge
+        center = (int((x1 + x2) / 2), int(y2))
+
+ # get ID of object
+ id = int(identities[i]) if identities is not None else 0
+
+ # create new buffer for new object
+ if id not in data_deque:
+ data_deque[id] = deque(maxlen= 64)
+ color = compute_color_for_labels(object_id[i])
+ obj_name = names[object_id[i]]
+ label = '{}{:d}'.format("", id) + ":"+ '%s' % (obj_name)
+
+ # add center to buffer
+ data_deque[id].appendleft(center)
+ UI_box(box, img, label=label, color=color, line_thickness=2)
+ # draw trail
+ for i in range(1, len(data_deque[id])):
+            # skip if either buffered point is None
+ if data_deque[id][i - 1] is None or data_deque[id][i] is None:
+ continue
+ # generate dynamic thickness of trails
+ thickness = int(np.sqrt(64 / float(i + i)) * 1.5)
+ # draw trails
+ cv2.line(img, data_deque[id][i - 1], data_deque[id][i], color, thickness)
+ return img
+
+
+class SegmentationPredictor(DetectionPredictor):
+
+ def postprocess(self, preds, img, orig_img):
+ masks = []
+ # TODO: filter by classes
+ p = ops.non_max_suppression(preds[0],
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ nm=32)
+ proto = preds[1][-1]
+ for i, pred in enumerate(p):
+ shape = orig_img[i].shape if self.webcam else orig_img.shape
+ if not len(pred):
+ continue
+ if self.args.retina_masks:
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+ masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2])) # HWC
+ else:
+ masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)) # HWC
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+
+        return (p, masks)
+
+    def write_results(self, idx, preds, batch):
+        counter = 0
+        count = 0
+        min_counter = 0
+        max_counter = 0
+        avg_count = 0
+        total_count = 0
+        frame_count = 0
+ p, im, im0 = batch
+ log_string = ""
+ if len(im.shape) == 3:
+ im = im[None] # expand for batch dim
+ self.seen += 1
+ if self.webcam: # batch_size >= 1
+ log_string += f'{idx}: '
+ frame = self.dataset.count
+ else:
+ frame = getattr(self.dataset, 'frame', 0)
+
+ self.data_path = p
+ self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
+ log_string += '%gx%g ' % im.shape[2:] # print string
+ self.annotator = self.get_annotator(im0)
+
+ preds, masks = preds
+ det = preds[idx]
+ if len(det) == 0:
+ return log_string
+ # Segments
+ mask = masks[idx]
+ if self.args.save_txt:
+ segments = [
+ ops.scale_segments(im0.shape if self.args.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
+ for x in reversed(ops.masks2segments(mask))]
+
+ # Print results
+ for c in det[:, 5].unique():
+ n = (det[:, 5] == c).sum()
+ log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
+
+ # for c in det[:, 5].unique():
+ # n = (det[:, 5] == c).sum() # detections per class
+ # s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
+
+ # print(self.model.names)
+ annotator = Annotator(im0, line_width=line_thickness, example=str(self.model.names))
+ # Mask plotting
+ self.annotator.masks(
+ mask,
+ colors=[colors(x, True) for x in det[:, 5]],
+ im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
+ 255 if self.args.retina_masks else im[idx])
+
+ det = reversed(det[:, :6])
+        counter += len(det)
+        print(f"Detections in frame: {counter}")
+ self.all_outputs.append([det, mask])
+ xywh_bboxs = []
+ confs = []
+ oids = []
+ outputs = []
+ counters = []
+ # Write results
+
+ for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
+ # seg = segments[j].reshape(-1) # (n,2) to (n*2)
+ line = (cls, conf)
+ x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
+            bbox_xyxy = [tensor.item() for tensor in xyxy]
+ xywh_obj = [x_c, y_c, bbox_w, bbox_h]
+            label = f'{self.model.names[int(cls)]} {conf:.2f}'
+ xywh_bboxs.append(xywh_obj)
+ confs.append([conf.item()])
+ oids.append(int(cls))
+ cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
+            if check_rect_overlap(bbox_xyxy, rectangle_top_left_back + rectangle_bottom_right_back) or check_rect_overlap(bbox_xyxy, rectangle_top_left + rectangle_bottom_right):
+                counter -= 1
+                if check_rect_overlap(bbox_xyxy, rectangle_top_left + rectangle_bottom_right):
+                    count += 1
+            else:
+                annotator.box_label(xyxy, label, color=colors(int(cls), True))
+                print(label)
+ counters.append(counter)
+ max_counter = max(counters)
+ min_counter = min(counters)
+ total_count = sum(counters)
+ frame_count += len(counters)
+ if frame_count > 0:
+ avg_count = total_count / frame_count
+ else:
+ avg_count = 0
+
+ print(f"objects detected {count}")
+ print(f"Total objects detected: {counter}")
+ print(f"Max objects detected: {max_counter}")
+ xywhs = torch.Tensor(xywh_bboxs)
+ confss = torch.Tensor(confs)
+ cv2.putText(im0, f'piglets detected: {counter} Max: {max_counter} Avg: {math.ceil(avg_count)} Min: {min_counter}', text_position, font, font_scale, font_color, font_thickness)
+ cv2.putText(im0, f'objects detected {count}', (141, 500), font, font_scale, font_color,font_thickness)
+ cv2.rectangle(im0, rectangle_top_left, rectangle_bottom_right, (255, 255, 255), 2)
+ cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
+ # outputs = deepsort.update(xywhs, confss, oids, im0)
+ # if len(outputs) >= 0:
+ # if check_rect_overlap(bbow_xyxy, rectangle_top_left_back+rectangle_bottom_right_back) or check_rect_overlap(bbow_xyxy, rectangle_top_left+rectangle_bottom_right) :
+ # counter = counter -1
+ # if check_rect_overlap(bbow_xyxy, rectangle_top_left+rectangle_bottom_right):
+ # count +=1
+ # else:
+ # bbox_xyxy = outputs[:, :4]
+ # identities = outputs[:, -2]
+ # object_id = outputs[:, -1]
+ # cv2.putText(im0, f'piglets detected: {counter} Max: {max_counter} Avg: {math.ceil(avg_count)} Min: {min_counter}', text_position, font, font_scale, font_color, font_thickness)
+ # cv2.rectangle(im0, rectangle_top_left, rectangle_bottom_right, (255, 255, 255), 2)
+ # cv2.rectangle(im0, rectangle_top_left_back, rectangle_bottom_right_back, (0, 0, 0), -1)
+ # draw_boxes(im0, bbox_xyxy, self.model.names, object_id,identities)
+ return log_string
+
+
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
+def predict(cfg):
+ init_tracker()
+ cfg.model = cfg.model or "yolov8n-seg.pt"
+ cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
+ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+
+ predictor = SegmentationPredictor(cfg)
+ predictor()
+
+
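+# Example invocation (hydra-style overrides, as for predict_cam2.py; values are placeholders):
+#   python predict_line_cam1.py model=yolov8x-seg.pt source="cam1.mp4"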
+if __name__ == "__main__":
+ predict()
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index acfe8b1..50e0815 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -14,7 +14,8 @@
 from ultralytics.yolo.utils.tal import make_anchors
 from ultralytics.yolo.utils.torch_utils import de_parallel
-from ..detect.train import Loss
+from ultralytics.yolo.v8.detect.train import Loss
+
 # BaseTrainer python usage