数据集简介
本数据集包含:
训练集:43685张;
验证集:5363张;
测试集:5363张;
总类别数:158类。





部分代码:
定义数据集
class MyDataset(Dataset):
    def __init__(self, mode='train', transform=None):
        """Load the ``[image path, label]`` pairs listed in ``<mode>.txt``.

        Args:
            mode: Name of the split; the list file
                ``f'{data_path}{mode}.txt'`` is read.
            transform: Optional callable stored on the instance and applied
                to each image in ``__getitem__``.

        Each usable line holds an image path and a label separated by a
        single space. Lines with fewer than two fields (for example a blank
        trailing line) are skipped instead of raising ``IndexError``.
        """
        super().__init__()
        self.data = []
        self.transform = transform
        # NOTE(review): `data_path` is not defined in this excerpt. The list
        # file is opened as data_path + mode (no separator) while image paths
        # are built as data_path + '/' + name, so data_path presumably ends
        # with '/' -- confirm against where it is defined.
        with open(f'{data_path}{mode}.txt') as f:
            for line in f:
                info = line.strip().split(' ')
                # str.split(' ') never returns an empty list, so a
                # `len(info) > 0` check cannot filter anything; require both
                # the path and the label field before indexing info[1].
                if len(info) < 2:
                    continue
                self.data.append(
                    [data_path+'/'+info[0].strip(), info[1].strip()])
    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image file is opened with PIL, converted to RGB and turned into
        a NumPy array; ``self.transform`` is applied to that array when it is
        not ``None``. The label is returned as a one-element ``int64`` array.
        """
        path, raw_label = self.data[idx]
        pixels = np.array(Image.open(path).convert('RGB'))
        # NOTE(review): a pasted tensor repr in the original hinted that the
        # transformed sample is a float32 tensor of shape [3, 227, 227];
        # that depends on the transform supplied by the caller -- confirm.
        sample = pixels if self.transform is None else self.transform(pixels)
        return sample, np.array([raw_label], dtype="int64")
    def __len__(self):
        return len(self.data)

定义ResNet网络
resnet50 = paddle.vision.models.resnet50(num_classes=158)

取单张测试图片进行可视化展示
import pylab as pl
import matplotlib.font_manager as fm
# Show one test image with its predicted garbage category as the title.
# NOTE(review): `lab1` (the label being looked up), `json`, `cv2` and `plt`
# are not defined in this excerpt -- presumably provided by earlier cells;
# confirm before running this snippet on its own.
test_path = '/home/aistudio/Mydata/test1.txt'
# Font file used for the plot title, which contains Chinese text.
myfont = fm.FontProperties(fname=r'/home/aistudio/simkai.ttf')
jetson_path = '/home/aistudio/Mydata/garbage_classification.json'
# Dictionary keyed by the label rendered as a string (see the lookup below).
# JSON text is read explicitly as UTF-8 rather than the locale default.
with open(jetson_path, 'r', encoding='utf-8') as f1:
    load_dict = json.load(f1)
# Only the first line of the list file is used; its first space-separated
# field is the image path relative to the data directory.
with open(test_path, 'r') as f2:
    img_path = f2.readline().strip().split(' ')
test_img_path = '/home/aistudio/Mydata/' + img_path[0]
print('输入测试图片路径为:')
print(test_img_path)
# Look up the garbage category that corresponds to the label in the dictionary.
clas = load_dict[f'{lab1}']
img = cv2.imread(test_img_path)
# cv2.imread signals failure by returning None instead of raising.
if img is None:
    raise OSError(f'cv2.imread could not read image: {test_img_path}')
# OpenCV loads images in BGR order while matplotlib expects RGB; convert so
# the displayed colours are not swapped.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.title(f'预测:{clas}', fontproperties=myfont, fontsize=20)


















