别再只盯着YOLO了!用ByteTrack在Python里实现一个简易的车辆跟踪器(附完整代码)
用PythonByteTrack打造高精度车辆追踪系统从原理到实战在智能交通和视频监控领域目标追踪技术正发挥着越来越重要的作用。当我们需要分析交通流量、统计车辆类型或监测异常行为时仅仅依靠目标检测是远远不够的——我们还需要知道同一个目标在不同帧之间的对应关系。这就是多目标追踪(Multi-Object Tracking, MOT)技术的用武之地。1. 为什么选择ByteTrack进行车辆追踪1.1 目标追踪技术的演进目标追踪算法大致可以分为两类基于检测的追踪(Detection-Based Tracking)和联合检测追踪(Joint Detection and Tracking)。前者如经典的SORT和DeepSORT后者如FairMOT和CenterTrack。ByteTrack属于基于检测的追踪方法但它在处理低分检测框方面做出了创新。传统方法通常会忽略低置信度的检测框认为这些可能是误检。但ByteTrack的作者发现这些低分框中其实包含了很多被遮挡或模糊的真实目标。通过巧妙地利用这些低分框ByteTrack在保持高追踪精度的同时显著降低了ID切换(Identity Switch)的次数。1.2 ByteTrack的核心优势高低分框协同利用不像SORT只使用高分检测框ByteTrack分两阶段利用所有检测结果简单高效无需外观特征提取仅依赖运动信息(IoU)进行关联强鲁棒性对遮挡、模糊等挑战性场景有更好的适应性易于部署算法复杂度低适合实时应用场景下表对比了几种常见追踪算法的特点算法需要外观特征利用低分框典型FPSMOTASORT否否600.59DeepSORT是否400.61FairMOT是部分300.64ByteTrack否是500.662. 环境搭建与依赖安装2.1 创建Python虚拟环境为了避免依赖冲突我们首先创建一个干净的Python环境python -m venv byte-tracker source byte-tracker/bin/activate # Linux/Mac # 或者 byte-tracker\Scripts\activate # Windows2.2 安装必要库ByteTrack的核心依赖包括pip install numpy opencv-python scipy lap loguru pip install torch torchvision # 如果使用GPU请安装对应版本的torch对于目标检测部分我们将使用YOLOv5git clone https://github.com/ultralytics/yolov5 cd yolov5 pip install -r requirements.txt提示建议使用Python 3.8或更高版本某些库在旧版本中可能存在兼容性问题3. 实现基础车辆追踪器3.1 构建ByteTrack核心类我们先实现ByteTrack的核心逻辑。创建一个byte_tracker.py文件import numpy as np from scipy.spatial.distance import cdist from collections import deque import lap class STrack: # 单个目标的追踪状态管理 def __init__(self, tlwh, score): self._tlwh np.asarray(tlwh, dtypenp.float32) self.kalman_filter None self.mean, self.covariance None, None self.is_activated False self.score score self.tracklet_len 0 self.state TrackState.New def predict(self): # Kalman滤波预测 if self.state ! TrackState.Tracked: return self.mean, self.covariance self.kalman_filter.predict( self.mean, self.covariance) staticmethod def multi_predict(stracks): # 批量预测 if len(stracks) 0: return multi_mean np.asarray([st.mean for st in stracks]) multi_covariance np.asarray([st.covariance for st in stracks]) for i, st in enumerate(stracks): if st.state ! TrackState.Tracked: multi_mean[i] st.mean multi_mean, multi_covariance stracks[0].kalman_filter.multi_predict( multi_mean, multi_covariance) for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): stracks[i].mean mean stracks[i].covariance cov class ByteTracker: def __init__(self, track_thresh0.5, match_thresh0.8, frame_rate30): self.tracked_stracks [] # 已确认的轨迹 self.lost_stracks [] # 丢失的轨迹 self.removed_stracks [] # 移除的轨迹 self.frame_id 0 self.track_thresh track_thresh self.match_thresh match_thresh self.buffer_size int(frame_rate / 30.0 * 30) self.max_time_lost self.buffer_size def update(self, detections): self.frame_id 1 activated_stracks [] refind_stracks [] lost_stracks [] removed_stracks [] # 将检测结果分为高分和低分 scores np.array([d.score for d in detections]) remain_inds scores self.track_thresh inds_low scores 0.1 inds_high scores self.track_thresh inds_second np.logical_and(inds_low, inds_high) dets [detections[i] for i in range(len(detections)) if remain_inds[i]] dets_second [detections[i] for i in range(len(detections)) if inds_second[i]] # 第一步与高分检测框匹配 strack_pool joint_stracks(self.tracked_stracks, self.lost_stracks) STrack.multi_predict(strack_pool) dists matching.iou_distance(strack_pool, dets) matches, u_track, u_detection matching.linear_assignment(dists, threshself.match_thresh) # 更新匹配成功的轨迹 for itracked, idet in matches: track strack_pool[itracked] det dets[idet] if track.state TrackState.Tracked: track.update(det, self.frame_id) activated_stracks.append(track) else: track.re_activate(det, self.frame_id, new_idFalse) refind_stracks.append(track) # 第二步与低分检测框匹配 if len(dets_second) 0: r_tracked_stracks [strack_pool[i] for i in u_track if strack_pool[i].state TrackState.Tracked] dists matching.iou_distance(r_tracked_stracks, dets_second) matches, u_track, u_detection_second matching.linear_assignment(dists, thresh0.5) for itracked, idet in matches: track r_tracked_stracks[itracked] det dets_second[idet] if track.state TrackState.Tracked: track.update(det, self.frame_id) activated_stracks.append(track) else: track.re_activate(det, self.frame_id, new_idFalse) refind_stracks.append(track) # 处理未匹配的轨迹 for it in u_track: track r_tracked_stracks[it] if not track.state TrackState.Lost: track.mark_lost() lost_stracks.append(track) # 更新轨迹状态 self.tracked_stracks [t for t in self.tracked_stracks if t.state TrackState.Tracked] self.tracked_stracks joint_stracks(self.tracked_stracks, activated_stracks) self.tracked_stracks joint_stracks(self.tracked_stracks, refind_stracks) self.lost_stracks sub_stracks(self.lost_stracks, self.tracked_stracks) self.lost_stracks.extend(lost_stracks) self.lost_stracks sub_stracks(self.lost_stracks, self.removed_stracks) self.removed_stracks.extend(removed_stracks) self.tracked_stracks, self.lost_stracks remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) output_stracks [track for track in self.tracked_stracks if track.is_activated] return output_stracks3.2 集成YOLOv5检测器创建一个detector.py文件来封装YOLOv5检测器import torch from yolov5.models.experimental import attempt_load from yolov5.utils.general import non_max_suppression class YOLOv5Detector: def __init__(self, weights_pathyolov5s.pt, devicecuda:0): self.device torch.device(device) self.model attempt_load(weights_path, map_locationself.device) self.names self.model.module.names if hasattr(self.model, module) else self.model.names self.img_size 640 def detect(self, img, conf_thres0.25, iou_thres0.45): # 预处理 img torch.from_numpy(img).to(self.device) img img.float() / 255.0 if img.ndimension() 3: img img.unsqueeze(0) # 推理 pred self.model(img, augmentFalse)[0] # NMS pred non_max_suppression(pred, conf_thres, iou_thres, classesNone, agnosticFalse) # 后处理 detections [] for i, det in enumerate(pred): if det is not None and len(det): for *xyxy, conf, cls in det: x1, y1, x2, y2 [x.item() for x in xyxy] detections.append({ bbox: [x1, y1, x2, y2], score: conf.item(), class_id: int(cls), class_name: self.names[int(cls)] }) return detections4. 完整视频处理流程实现4.1 主程序框架创建一个main.py文件作为入口import cv2 import numpy as np from detector import YOLOv5Detector from byte_tracker import ByteTracker def process_video(input_path, output_path): # 初始化 detector YOLOv5Detector() tracker ByteTracker(track_thresh0.5, match_thresh0.8) cap cv2.VideoCapture(input_path) width int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps cap.get(cv2.CAP_PROP_FPS) fourcc cv2.VideoWriter_fourcc(*mp4v) out cv2.VideoWriter(output_path, fourcc, fps, (width, height)) frame_count 0 while cap.isOpened(): ret, frame cap.read() if not ret: break # 检测 detections detector.detect(frame) # 过滤只保留车辆类 (COCO数据集中车辆类ID为2,5,7) vehicle_dets [d for d in detections if d[class_id] in [2,5,7]] # 追踪 tracks tracker.update(vehicle_dets) # 可视化 for track in tracks: bbox track.tlwh cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[0]bbox[2]), int(bbox[1]bbox[3])), (0,255,0), 2) cv2.putText(frame, fID:{track.track_id}, (int(bbox[0]), int(bbox[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2) out.write(frame) frame_count 1 print(fProcessed frame {frame_count}) cap.release() out.release() if __name__ __main__: process_video(input.mp4, output.mp4)4.2 参数调优指南ByteTrack的性能很大程度上依赖于几个关键参数检测阈值(track_thresh)默认0.5值越高误检越少但可能漏检交通场景建议0.4-0.6之间匹配阈值(match_thresh)默认0.8控制轨迹与检测框的匹配严格程度对于高速移动车辆可降低到0.7丢失帧数(buffer_size)默认30帧(1秒30fps)对于遮挡频繁的场景可适当增加注意这些参数需要根据具体场景进行调整建议先用小段视频测试不同参数组合的效果5. 高级功能扩展5.1 轨迹分析与统计我们可以扩展追踪器来收集交通统计数据class TrafficAnalyzer: def __init__(self): self.vehicle_count 0 self.vehicle_types {} self.speed_estimator {} def update(self, tracks, frame_id, fps): for track in tracks: if track.track_id not in self.speed_estimator: self.vehicle_count 1 self.speed_estimator[track.track_id] { prev_pos: track.tlwh[:2], prev_frame: frame_id, speed: 0 } else: # 计算速度 (像素/秒) curr_pos track.tlwh[:2] prev_pos self.speed_estimator[track.track_id][prev_pos] frames_passed frame_id - self.speed_estimator[track.track_id][prev_frame] if frames_passed 0: distance np.sqrt((curr_pos[0]-prev_pos[0])**2 (curr_pos[1]-prev_pos[1])**2) self.speed_estimator[track.track_id][speed] distance * fps / frames_passed self.speed_estimator[track.track_id][prev_pos] curr_pos self.speed_estimator[track.track_id][prev_frame] frame_id5.2 跨摄像头追踪要实现跨摄像头的车辆追踪我们需要考虑外观特征提取虽然ByteTrack本身不依赖外观特征但可以额外添加时空约束不同摄像头之间的车辆出现时间和位置关系轨迹匹配当车辆从一个摄像头视野进入另一个时进行ID传递class MultiCameraTracker: def __init__(self): self.camera_trackers {} # 每个摄像头独立的追踪器 self.global_tracks {} # 全局轨迹ID映射 def add_camera(self, cam_id, calib_paramsNone): self.camera_trackers[cam_id] { tracker: ByteTracker(), calib: calib_params, last_global_id: 0 } def update(self, cam_id, detections): local_tracks self.camera_trackers[cam_id][tracker].update(detections) # 将本地ID映射到全局ID global_tracks [] for track in local_tracks: if track.track_id not in self.camera_trackers[cam_id][local_to_global]: self.camera_trackers[cam_id][last_global_id] 1 global_id self.camera_trackers[cam_id][last_global_id] self.camera_trackers[cam_id][local_to_global][track.track_id] global_id self.global_tracks[global_id] { last_cam: cam_id, last_seen: time.time() } global_id self.camera_trackers[cam_id][local_to_global][track.track_id] global_tracks.append({ global_id: global_id, bbox: track.tlwh, features: self.extract_features(track) }) return global_tracks在实际交通监控项目中ByteTrack展现出了令人印象深刻的性能。特别是在处理车辆相互遮挡的场景时其高低分框协同匹配的策略显著减少了ID切换。一个实用的建议是对于固定摄像头场景可以预先标定场景中的消失点和尺度信息这将有助于更准确地估计车辆速度和行为分析。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2495777.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!