深度学习中评估指标计算库TorchMetrics的使用
TorchMetrics是一个包含100多个PyTorch指标实现的集合(如分类、检测、分割、回归等),并提供易于使用的API来创建自定义指标。可以将TorchMetrics与任何PyTorch模型或PyTorch Lightning结合使用。源码地址:https://github.com/Lightning-AI/torchmetrics ,最新发布版本为v1.9.0,license为Apache-2.0。

安装完YOLO环境后,执行以下命令(评估指标任务不同,安装命令也不同):

```
pip install torchmetrics
pip install torchmetrics[detection]
```

YOLOv8/YOLO11/YOLO26有自己内置的评估逻辑,一般不建议直接在训练循环内部强行替换其指标计算。这里训练完YOLOv8后评估只是为了演示TorchMetrics的使用。

Classify:主要测试代码如下:

```python
def _parse_label_file(label_file):
    idx_to_class = {}
    class_to_idx = {}

    with open(label_file, mode="r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            idx, name = line.split()
            idx = int(idx)

            idx_to_class[idx] = name
            class_to_idx[name] = idx

    return idx_to_class, class_to_idx

def _get_images(images_path):
    image_files = list(Path(images_path).rglob("*.*"))
    image_files = [p for p in image_files if p.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]]
    if len(image_files) == 0:
        raise RuntimeError(colorama.Fore.RED + f"no images found: {images_path}")

    return image_files

def test_classify(model_name, images_path, label_file):
    if model_name is None or not model_name or not Path(model_name).is_file():
        raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file")
    if images_path is None or not images_path or not Path(images_path).is_dir():
        raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory")
    if label_file is None or not label_file or not Path(label_file).is_file():
        raise FileNotFoundError(colorama.Fore.RED + f"{label_file} is not a file")

    _, class_to_idx = _parse_label_file(label_file)
    print(f"class to idx: {class_to_idx}")
    num_classes = len(class_to_idx)

    acc_metric = MulticlassAccuracy(num_classes=num_classes)
    f1_metric = MulticlassF1Score(num_classes=num_classes)
    acc_metric.reset()
    f1_metric.reset()

    image_files = _get_images(images_path)

    model = YOLO(model_name)
    model.eval()

    with torch.no_grad():
        for img_path in image_files:
            class_name = img_path.parent.name
            if class_name not in class_to_idx:
                print(colorama.Fore.YELLOW + f"invalid image file: {img_path}")
                continue
            gt_label = class_to_idx[class_name]

            results = model(str(img_path), verbose=False)
            probs = results[0].probs.data
            pred_label = int(torch.argmax(probs).item())

            pred_tensor = torch.tensor([pred_label])
            gt_tensor = torch.tensor([gt_label])

            acc_metric.update(pred_tensor, gt_tensor)
            f1_metric.update(pred_tensor, gt_tensor)

    acc = acc_metric.compute().item()
    f1 = f1_metric.compute().item()
    print(colorama.Fore.GREEN + f"Accuracy: {acc:.4f}\nF1 Score: {f1:.4f}")
```

执行结果如下图所示:

Detect:主要测试代码如下:

```python
def test_detect(model_name, images_path, txts_path):
    if model_name is None or not model_name or not Path(model_name).is_file():
        raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file")
    if images_path is None or not images_path or not Path(images_path).is_dir():
        raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory")
    if txts_path is None or not txts_path or not Path(txts_path).is_dir():
        raise FileNotFoundError(colorama.Fore.RED + f"{txts_path} is not a directory")

    image_files = _get_images(images_path)

    preds_all = []
    targets_all = []

    model = YOLO(model_name)
    model.eval()

    with torch.no_grad():
        for img_path in image_files:
            txt_path = txts_path + "/" + img_path.stem + ".txt"
            if not Path(txt_path).exists():
                raise FileNotFoundError(colorama.Fore.RED + f"{txt_path} does not exist")

            img = cv2.imread(str(img_path))
            if img is None:
                raise FileNotFoundError(colorama.Fore.RED + f"unable to load image file: {img_path}")
            h, w = img.shape[:2]

            gt_boxes = []
            gt_labels = []

            with open(txt_path, mode="r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) != 5:
                        raise RuntimeError(colorama.Fore.RED + f"{txt_path}: file content is incorrect")

                    cls = int(parts[0])
                    cx, cy, bw, bh = map(float, parts[1:])

                    x1 = (cx - bw / 2) * w
                    y1 = (cy - bh / 2) * h
                    x2 = (cx + bw / 2) * w
                    y2 = (cy + bh / 2) * h

                    gt_boxes.append([x1, y1, x2, y2])
                    gt_labels.append(cls)

            if len(gt_boxes) == 0:
                gt_boxes = torch.zeros((0, 4))
                gt_labels = torch.zeros((0,), dtype=torch.int64)
            else:
                gt_boxes = torch.tensor(gt_boxes, dtype=torch.float32)
                gt_labels = torch.tensor(gt_labels, dtype=torch.int64)

            results = model(str(img_path), verbose=False)[0]
            if results.boxes is None or len(results.boxes) == 0:
                pred_boxes = torch.zeros((0, 4))
                pred_scores = torch.zeros((0,))
                pred_labels = torch.zeros((0,), dtype=torch.int64)
            else:
                pred_boxes = results.boxes.xyxy.cpu()
                pred_scores = results.boxes.conf.cpu()
                pred_labels = results.boxes.cls.cpu().to(torch.int64)

            preds_all.append({"boxes": pred_boxes, "scores": pred_scores, "labels": pred_labels})
            targets_all.append({"boxes": gt_boxes, "labels": gt_labels})

    print(f"total samples: {len(preds_all)}")

    metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True)
    metric.update(preds_all, targets_all)
    result = metric.compute()
    print(f"metrics result: {result}")

    map50 = result["map_50"].item()
    map5095 = result["map"].item()
    print(colorama.Fore.GREEN + f"mAP50: {map50:.4f}\nmAP50-95: {map5095:.4f}")
```

执行结果如下图所示:

Segment:主要测试代码如下:

```python
def _polygon_to_mask(polygons, h, w):
    mask = np.zeros((h, w), dtype=np.uint8)
    for poly in polygons:
        pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
        cv2.fillPoly(mask, [pts], 1)
    return mask

def test_segment(model_name, images_path, txts_path):
    if model_name is None or not model_name or not Path(model_name).is_file():
        raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file")
    if images_path is None or not images_path or not Path(images_path).is_dir():
        raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory")
    if txts_path is None or not txts_path or not Path(txts_path).is_dir():
        raise FileNotFoundError(colorama.Fore.RED + f"{txts_path} is not a directory")

    image_files = _get_images(images_path)

    model = YOLO(model_name)
    num_classes = len(model.names) + 1 # 0:background
    metric = MeanIoU(num_classes=num_classes, per_class=True, input_format="index")
    metric.reset()
    total = 0
    target_size = (480, 480)

    model.eval()
    with torch.no_grad():
        for img_path in image_files:
            txt_path = txts_path + "/" + img_path.stem + ".txt"
            if not Path(txt_path).exists():
                raise FileNotFoundError(colorama.Fore.RED + f"{txt_path} does not exist")

            img = cv2.imread(str(img_path))
            if img is None:
                raise FileNotFoundError(colorama.Fore.RED + f"unable to load image file: {img_path}")
            h, w = img.shape[:2]

            gt_mask = np.zeros((h, w), dtype=np.uint8)
            pred_mask = np.zeros((h, w), dtype=np.uint8)

            with open(txt_path, mode="r", encoding="utf-8") as f:
                for line in f:
                    parts = list(map(float, line.strip().split()))
                    cls = int(parts[0])
                    coords = parts[1:]

                    pts = []
                    for i in range(0, len(coords), 2):
                        x = coords[i] * w
                        y = coords[i + 1] * h
                        pts.append([x, y])

                    mask = _polygon_to_mask([pts], h, w)
                    gt_mask[mask == 1] = cls + 1

            results = model(str(img_path), verbose=False)[0]
            if results.masks is not None:
                masks = results.masks.data.cpu().numpy()
                classes = results.boxes.cls.cpu().numpy().astype(int)

                for i in range(len(masks)):
                    m = masks[i]
                    cls = classes[i]

                    m = (m > 0.5).astype(np.uint8)
                    m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)
                    pred_mask[m == 1] = cls + 1

            pred_tensor = torch.tensor(cv2.resize(pred_mask, target_size, interpolation=cv2.INTER_NEAREST)).long()
            gt_tensor = torch.tensor(cv2.resize(gt_mask, target_size, interpolation=cv2.INTER_NEAREST)).long()

            metric.update(pred_tensor.unsqueeze(0), gt_tensor.unsqueeze(0))
            total += 1

    miou_per_class = metric.compute()
    print(f"metrics result(per class): {miou_per_class}")
    miou = miou_per_class[1:].mean().item() # remove background
    print(colorama.Fore.GREEN + f"total samples: {total}\nmIoU: {miou:.4f}")
```

执行结果如下图所示:

GitHub:https://github.com/fengbingchun/NN_Test
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2558395.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!