Python 创建一个类,通过自定义方法进行矩阵乘法运算(超详细)

快速解决

直接使用自定义类实现矩阵乘法运算。以下代码展示核心实现方式,适用于任意二维矩阵的乘法计算,无需依赖第三方库。

class Matrix:
    def __init__(self, data):
        self.data = data  # 初始化矩阵数据

    def multiply(self, other):
        if len(self.data[0]) != len(other.data):
            raise ValueError("矩阵维度不匹配")  # 检查相乘条件
        result = [[0] * len(other.data[0]) for _ in range(len(self.data))]  # 初始化结果矩阵
        for i in range(len(self.data)):
            for j in range(len(other.data[0])):
                for k in range(len(other.data)):
                    result[i][j] += self.data[i][k] * other.data[k][j]  # 核心乘法逻辑
        return Matrix(result)

常用方法

方法名称 使用场景 代码示例 说明
__init__ 矩阵初始化 Matrix([[1,2],[3,4]]) 创建矩阵对象
multiply 矩阵相乘 a.multiply(b) 实现矩阵乘法
validate 维度验证 a.validate(b) 检查能否相乘
transpose 矩阵转置 a.transpose() 交换行列位置
__str__ 输出显示 print(a) 友好格式输出
__add__ 矩阵相加 a + b 实现加法运算

详细说明

初始化与数据校验

class Matrix:
    def __init__(self, data):
        if not isinstance(data, list) or not all(isinstance(row, list) for row in data):
            raise TypeError("输入数据必须为二维列表")  # 检查数据结构
        if not data or not data[0]:
            raise ValueError("矩阵不能为空")  # 检查空矩阵
        row_length = len(data[0])
        if not all(len(row) == row_length for row in data):
            raise ValueError("所有行必须长度相同")  # 检查矩阵完整性
        self.data = data

核心乘法实现

    def multiply(self, other):
        if not isinstance(other, Matrix):
            raise TypeError("只能与Matrix对象相乘")  # 类型检查
        if len(self.data[0]) != len(other.data):
            raise ValueError(f"矩阵维度不匹配: {len(self.data[0]} vs {len(other.data)}")  # 维度检查
        result = []
        for i in range(len(self.data)):
            row = []
            for j in range(len(other.data[0])):
                total = 0
                for k in range(len(other.data)):
                    total += self.data[i][k] * other.data[k][j]  # 行列对应相乘求和
                row.append(total)
            result.append(row)
        return Matrix(result)

验证方法补充

    def validate(self, other):
        return len(self.data[0]) == len(other.data)  # 返回布尔值判断能否相乘

高级技巧

1. 性能优化方案

from itertools import product

def optimized_multiply(self, other):
    if not self.validate(other):
        raise ValueError("矩阵维度不匹配")
    # 使用zip(*other.data)实现矩阵转置
    return Matrix([
        [sum(a * b for a, b in zip(self_row, other_col)) 
         for other_col in zip(*other.data)]
        for self_row in self.data
    ])

2. 支持类型转换

    def to_numpy(self):
        import numpy as np  # 动态导入避免依赖
        return np.array(self.data)  # 转换为NumPy数组

    @classmethod
    def from_numpy(cls, array):
        return cls(array.tolist())  # 从NumPy数组创建

3. 稀疏矩阵处理

class SparseMatrix(Matrix):
    def __init__(self, data):
        super().__init__(data)
        self.sparse_data = {  # 转换为稀疏存储
            (i, j): val for i, row in enumerate(data)
            for j, val in enumerate(row) if val != 0
        }

    def multiply(self, other):
        # 优化稀疏矩阵计算逻辑
        result = {}
        for (i1, j1), val1 in self.sparse_data.items():
            for (i2, j2), val2 in other.sparse_data.items():
                if j1 == i2:
                    result[(i1, j2)] = result.get((i1, j2), 0) + val1 * val2
        # 重构结果矩阵
        max_i = max(i for i, j in result.keys())
        max_j = max(j for i, j in result.keys())
        matrix = [[0] * (max_j + 1) for _ in range(max_i + 1)]
        for (i, j), val in result.items():
            matrix[i][j] = val
        return SparseMatrix(matrix)

常见问题

Q1: 如何处理不同矩阵维度相乘?
A: 调用validate()方法前先检查self.data[0]的列数是否等于other.data的行数,若不匹配会抛出异常

Q2: 如何将结果转换为普通列表?
A: 定义to_list()方法返回self.data属性,示例代码:

def to_list(self):
    return self.data  # 获取矩阵数据

Q3: 支持非整数元素吗?
A: 支持任意数字类型(int/float/complex),但需确保元素可进行乘法运算

Q4: 矩阵转置实现方式?
A: 添加transpose()方法,使用zip(*self.data)配合列表推导式实现:

def transpose(self):
    return Matrix([list(row) for row in zip(*self.data)])  # 生成转置矩阵

总结

本文通过自定义Matrix类完整实现了矩阵乘法运算,提供基础实现、性能优化和稀疏矩阵处理方案,帮助开发者快速构建矩阵计算功能。