1. 思路小结
要优化你提供的稀疏矩阵乘法代码,我们可以引入CSR(压缩稀疏行)格式来避免遍历零元素,从而提高效率。CSR格式通过仅存储非零元素以及它们的行和列索引,可以有效减少稀疏矩阵计算时的时间复杂度。下面是对代码的优化版本,采用CSR格式进行稀疏矩阵的乘法:
优化步骤:
将稀疏矩阵转换为CSR格式,存储非零元素的位置和对应的值。
在矩阵乘法过程中,仅对非零元素进行计算,从而跳过零值。
对每一行的非零元素,在相应的列上执行乘法操作。
1.1 优化思路
进行的是两个稀疏矩阵的乘法。稀疏矩阵通常具有大量的零元素,因此直接使用常规矩阵乘法会导致大量的无效计算。为了提高效率,常用的优化方法是只对非零元素进行计算,而跳过零值。为此,我们采用**CSR(压缩稀疏行,Compressed Sparse Row)**格式进行稀疏矩阵存储和乘法计算。
1.1.1 核心步骤如下:
-  矩阵的稀疏表示: - 原矩阵A和B可能有大量的零元素,因此我们采用CSR格式来存储这些矩阵。
- CSR格式由以下三个部分组成: 
    - values[]: 存储所有非零元素的值。
- colIndex[]: 存储每个非零元素所在的列索引。
- rowPtr[]: 记录每行的非零元素在- values[]中的起始位置。
 
 
-  矩阵的稀疏乘法: - 对于矩阵A的每一行,我们找到其所有非零元素的位置及其值。
- 对于每一个非零元素,我们在矩阵B的相应列中查找与之匹配的非零元素。
- 最后将这些匹配的非零元素相乘,并累加到结果矩阵的对应位置。
 
-  优化: - 通过CSR格式,避免了遍历和处理零元素,从而减少了不必要的计算。
- 我们直接对非零元素进行乘法运算,结果累积到结果矩阵C的对应位置。
 
1.2 算法复杂度分析

1.2.1 常规矩阵乘法的复杂度:
对于两个大小分别为 m x n 和 n x p 的矩阵,常规的矩阵乘法复杂度为O(m * n * p)。因为对于每一个 m x p 的结果元素,我们需要计算 n 次乘法操作。
1.2.2 稀疏矩阵乘法的复杂度:
由于稀疏矩阵大部分元素为零,我们只需要处理非零元素。假设矩阵A和矩阵B的非零元素分别为 nnzA 和 nnzB,稀疏矩阵乘法的复杂度可以近似表示为:
- 对于每个非零元素 A[i][k],我们只需遍历矩阵B的第k列的非零元素进行乘法。因此稀疏矩阵乘法的复杂度大约为 O(nnzA * nnzB),其中nnzA和nnzB是矩阵A和矩阵B的非零元素数量。
这相比于常规矩阵乘法的复杂度有了显著的提升,尤其是当矩阵非常稀疏时(即大部分元素为0),非零元素的数量远小于矩阵的总大小。
1.2.3 空间复杂度:
使用CSR格式的空间复杂度为:
- O(nnz):用于存储所有非零元素及其列索引。
- O(m):用于存储每一行的起始位置。
 总体空间复杂度为 O(nnz + m),其中- nnz是矩阵的非零元素数量,- m是矩阵的行数。
1.2.4 总结
通过使用CSR格式存储稀疏矩阵,我们能够有效避免对零元素的计算,显著提升了稀疏矩阵乘法的计算效率。时间复杂度从常规的O(m * n * p)降低到接近于非零元素的数量 O(nnzA * nnzB),特别适合处理大规模稀疏矩阵的场景。
2. 优化后代码及其复杂度为
代码解析:
 toCSR 函数:将普通的二维稀疏矩阵转换为CSR格式。values数组存储非零元素,colIndex存储每个非零元素的列索引,rowPtr则记录每行的非零元素在 values 数组中的起始位置。
multiplySparseMatricesCSR 函数:使用CSR格式进行矩阵乘法。通过 rowPtr 和 colIndex 来快速定位非零元素,避免了对零值的无效计算。
优化效果:
 通过CSR格式存储非零元素,并跳过零元素的乘法操作,能够显著减少计算时间。
 避免遍历零值,提高了计算效率,尤其在大规模稀疏矩阵的场景下。
#include <iostream>
#include <vector>
using namespace std;
// CSR格式的稀疏矩阵
struct CSRMatrix {
    vector<int> values;      // 存储非零元素的值
    vector<int> colIndex;    // 存储非零元素的列索引
    vector<int> rowPtr;      // 每一行的开始位置
};
// 将稀疏矩阵转换为CSR格式
CSRMatrix toCSR(const vector<vector<int>>& matrix) {
    CSRMatrix csr;
    int row = matrix.size();
    int col = matrix[0].size();
    
    csr.rowPtr.push_back(0);  // 第一行的开始位置是0
    
    // 遍历矩阵,收集非零元素的信息
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            if (matrix[i][j] != 0) {
                csr.values.push_back(matrix[i][j]);
                csr.colIndex.push_back(j);
            }
        }
        csr.rowPtr.push_back(csr.values.size());  // 记录下一行的开始位置
    }
    
    return csr;
}
// 使用CSR格式进行稀疏矩阵乘法
vector<vector<int>> multiplySparseMatricesCSR(const CSRMatrix& A, const CSRMatrix& B, int colB) {
    int rowA = A.rowPtr.size() - 1;
    vector<vector<int>> C(rowA, vector<int>(colB, 0));  // 初始化结果矩阵
    
    // 遍历A的每一行
    for (int i = 0; i < rowA; i++) {
        // A的第i行的非零元素从A.rowPtr[i]到A.rowPtr[i+1]-1
        for (int aPos = A.rowPtr[i]; aPos < A.rowPtr[i+1]; aPos++) {
            int colA = A.colIndex[aPos];  // 该非零元素所在的列
            int aValue = A.values[aPos];  // 非零元素的值
            
            // 对应B的第colA行
            for (int j = B.rowPtr[colA]; j < B.rowPtr[colA+1]; j++) {
                int colBIndex = B.colIndex[j];
                int bValue = B.values[j];
                C[i][colBIndex] += aValue * bValue;
            }
        }
    }
    
    return C;
}
int main() {
    // 定义稀疏矩阵A
    vector<vector<int>> A = {
        {1, 0, 0},
        {-1, 0, 3}
    };
    // 定义稀疏矩阵B
    vector<vector<int>> B = {
        {7, 0, 0},
        {0, 0, 0},
        {0, 0, 1}
    };
    // 将矩阵A和B转换为CSR格式
    CSRMatrix csrA = toCSR(A);
    CSRMatrix csrB = toCSR(B);
    // 计算A和B的乘积
    vector<vector<int>> C = multiplySparseMatricesCSR(csrA, csrB, B[0].size());
    // 输出结果矩阵
    cout << "Result of A * B:" << endl;
    for (const auto& row : C) {
        for (int elem : row) {
            cout << elem << " ";
        }
        cout << endl;
    }
    return 0;
}
3. 优化前原始代码及其复杂度为O(m * n * p),这里是最朴素的思路,没有利用稀疏特性做任何优化
#include <iostream>
#include <vector>
using namespace std;
// 定义稀疏矩阵乘法函数
vector<vector<int>> multiplySparseMatrices(vector<vector<int>>& A, vector<vector<int>>& B) {
    int rowA = A.size();
    int colA = A[0].size();
    int rowB = B.size();
    int colB = B[0].size();
    // 初始化结果矩阵,大小为rowA * colB
    vector<vector<int>> C(rowA, vector<int>(colB, 0));
    // 遍历矩阵A的每一行
    for (int i = 0; i < rowA; i++) {
        // 遍历矩阵A的每个列,寻找非零元素
        for (int k = 0; k < colA; k++) {
            if (A[i][k] != 0) {
                // 当A的某个位置非零时,计算该元素和矩阵B的第k行
                for (int j = 0; j < colB; j++) {
                    if (B[k][j] != 0) {
                        C[i][j] += A[i][k] * B[k][j];
                    }
                }
            }
        }
    }
    
    return C;
}
int main() {
    // 定义稀疏矩阵A
    vector<vector<int>> A = {
        {1, 0, 0},
        {-1, 0, 3}
    };
    // 定义稀疏矩阵B
    vector<vector<int>> B = {
        {7, 0, 0},
        {0, 0, 0},
        {0, 0, 1}
    };
    // 计算A和B的乘积
    vector<vector<int>> C = multiplySparseMatrices(A, B);
    // 输出结果矩阵
    cout << "Result of A * B:" << endl;
    for (const auto& row : C) {
        for (int elem : row) {
            cout << elem << " ";
        }
        cout << endl;
    }
    return 0;
}




![[WEBPWN]BaseCTF week1 题解(新手友好教程版)](https://i-blog.csdnimg.cn/direct/db366a8c70c54a7995f3fb63ca12dcea.png)














