Python 使用类实现一个计算矩阵加法的程序(一文讲透)

使用 Python 类实现矩阵加法的完整方案

快速解决

class Matrix:
    def __init__(self, data):
        self.data = data
        self.rows = len(data)
        self.cols = len(data[0]) if data else 0

    def __add__(self, other):
        if self.rows != other.rows or self.cols != other.cols:
            raise ValueError("矩阵维度必须一致")
        return Matrix([[a + b for a, b in zip(x, y)] for x, y in zip(self.data, other.data)])

常用方法

方法名称 使用频率 功能说明
__add__ ⭐⭐⭐⭐⭐ 重载加法运算符实现矩阵相加
validate_dimensions ⭐⭐⭐⭐ 验证两个矩阵维度是否匹配
element_wise_add ⭐⭐⭐ 执行元素级相加操作
__str__ ⭐⭐⭐ 矩阵格式化输出
from_list ⭐⭐ 类方法创建矩阵实例
add_multiple 支持多个矩阵相加

详细说明

验证矩阵维度

def validate_dimensions(self, other):
    # 检查矩阵行数是否匹配
    if self.rows != other.rows:
        raise ValueError("矩阵行数不匹配:{} vs {}".format(self.rows, other.rows))
    # 检查矩阵列数是否匹配
    if self.cols != other.cols:
        raise ValueError("矩阵列数不匹配:{} vs {}".format(self.cols, other.cols))

重载加法运算符

def __add__(self, other):
    # 调用维度验证方法
    self.validate_dimensions(other)
    # 使用zip函数配对行并逐元素相加
    result = [
        [x[i] + y[i] for i in range(self.cols)] 
        for x, y in zip(self.data, other.data)
    ]
    return Matrix(result)

元素级相加

def element_wise_add(self, other_matrix):
    # 初始化结果矩阵
    result = []
    # 遍历每一行
    for i in range(self.rows):
        row = []
        # 遍历每一列
        for j in range(self.cols):
            # 将对应位置的元素相加
            row.append(self.data[i][j] + other_matrix.data[i][j])
        result.append(row)
    return Matrix(result)

高级技巧

多异常处理

def add_matrix(self, other):
    try:
        if self.rows != other.rows or self.cols != other.cols:
            raise ValueError("矩阵维度不一致")
        return Matrix([[a + b for a, b in zip(x, y)] for x, y in zip(self.data, other.data)])
    except TypeError:
        raise TypeError("输入必须为Matrix实例")
    except ValueError as ve:
        raise ve

支持多种输入格式

@classmethod
def from_list(cls, input_list):
    if not all(isinstance(row, list) for row in input_list):
        raise ValueError("输入必须为二维列表")
    if not input_list:
        return cls([])
    cols = len(input_list[0])
    if not all(len(row) == cols for row in input_list):
        raise ValueError("所有行长度必须一致")
    return cls(input_list)

使用 NumPy 加速

import numpy as np

class NumpyMatrix:
    def __init__(self, data):
        self.data = np.array(data, dtype=object)
    
    def add(self, other):
        if self.data.shape != other.data.shape:
            raise ValueError("矩阵形状不匹配")
        return NumpyMatrix(self.data + other.data)

常见问题

Q1: 如何处理不同维度的矩阵相加?

A: 通过validate_dimensions方法强制检查行数和列数是否一致。若不一致会抛出ValueError异常,避免无效计算。

Q2: 可以用其他方式实现矩阵相加吗?

A: 可以使用map+lambda组合:

def add(self, other):
    return Matrix(list(map(
        lambda row1, row2: list(map(lambda a, b: a + b, row1, row2)),
        self.data, other.data
    )))

Q3: 如何测试矩阵加法程序?

A: 可通过以下测试用例验证:

m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])
m3 = m1 + m2
assert m3.data == [[6, 8], [10, 12]]

总结

通过 Python 类封装实现矩阵加法,既能保证代码复用性,又能通过运算符重载提升可读性。