用于生成热力图,记录过程,方便之后直接使用。
使用场景:联邦学习中显示客户端数据分布,或者显示数据分布的各类其他场景
文章目录
- 一、代码
- hot.py
- 使用方法
 
- 二、参数解释
- 三、样图
- 关键词
一、代码
写这段代码时主要考虑联邦学习中显示客户端数据分布这一场景
hot.py
import numpy as np
import matplotlib.pyplot as plt
def hot_map(y_train, dataidx_map):
    # CIFAR-10 数据集共有 10 个类别
    num_classes = 10
    # 有 10 个客户端
    num_clients = 10
    #图片中字体大小
    font_size = 32
    # 初始化一个矩阵来存储每个客户端的数据分布
    client_data_distribution = np.zeros((num_clients, num_classes), dtype=int)
    # 统计每个客户端中每个类别的样本数量
    for client_id in range(num_clients):
        indices = dataidx_map[client_id]
        client_labels = y_train[indices]
        unique_labels, label_counts = np.unique(client_labels, return_counts=True)
        for label, count in zip(unique_labels, label_counts):
            client_data_distribution[client_id, label] = count
    # 转置矩阵,这里的转置主要是为了让横坐标是客户端,纵坐标是类标签。如果不转置,横纵坐标会交换
    client_data_distribution = client_data_distribution.T
    # 设置全局字体为新罗马字体
    plt.rcParams["font.family"] = "Times New Roman"
    # 绘制热力图
    plt.figure(figsize=(10, 6))
    plt.imshow(client_data_distribution, cmap='Reds', interpolation='nearest')
    #设置图片标题(上方)
    # plt.title('Clients Data Distribution in CIFAR-10 Dataset')
    # 隐藏坐标轴的边框,更美观
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    plt.xlabel('Client', fontsize=font_size)
    plt.ylabel('Label', fontsize=font_size)
    cbar = plt.colorbar()
        
    # 隐藏颜色条的边框
    cbar.outline.set_visible(False)
    cbar.ax.tick_params(labelsize=font_size)  # 设置颜色条刻度标签的字体大小 
    plt.xticks(np.arange(num_classes), np.arange(num_classes), fontsize =font_size)
    plt.yticks(np.arange(num_clients), np.arange(num_clients), fontsize=font_size)
    
    # 设置坐标(i, j)显示的数值,可直接注释去除
    for i in range(num_clients):
        for j in range(num_classes):
            # text((x, y)=坐标, s=数值, ha=水平对齐, va=垂直对齐, color=颜色)
            plt.text(x=i, y=j, s=client_data_distribution[j][i], ha='center', va='center', color='white')
    plt.tight_layout()
    plt.savefig('Fig.jpg',dpi = 400, bbox_inches='tight')# bbox_inches用于在保存时将图片位于画布中间,保持紧凑;dpi是一个关于图片清晰度的参数,数值越大,图片越高清
    plt.show()
使用方法
首先在需要调用热力图的地方引入文件
from hot import hot_map
接着在需要画图的地方调用,通常是刚对客户端分配好数据或者对数据分布进行处理后的位置
hot_map(y_train,net_dataidx_map)
二、参数解释
y_train:[6 9 9 … 9 1 1],就是训练数据的标签,用列表表示。
net_dataidx_map:{0:[39982, 40086, 49891, 13047, 8170, 94, 4697,],1:[…], …},这是各客户端的数据分配情况,使用字典显示,字典的键表示客户端标记,表示几号客户端;值用列表显示,列表中的各数值表示y_train的下标,举例来说,以0的39982为例,表示0号客户端包含了y_train中第39982个标签,是客户端与数据标签的映射。
三、样图

关键词
热力图; 联邦学习; 数据分布;python


















