Skip to content

NumPy — 数值计算基石

NumPy 是整个 Python 数据科学生态的基础。理解它的内存模型和向量化思维,是高效数据处理的关键。

为什么 NumPy 快

NumPy 的核心是 ndarray(N 维数组),底层用 C 实现:

Python list:  [obj_ptr, obj_ptr, obj_ptr, ...]  — 指针数组,元素分散在内存
NumPy array:  [1.0, 2.0, 3.0, ...]              — 连续内存块,类型统一
  • 连续内存:CPU 缓存友好,减少 cache miss
  • SIMD 指令:向量化运算,一条指令处理多个数据
  • 无 Python 循环:避免解释器开销
python
import numpy as np
import time

n = 10_000_000

# Python 列表
lst = list(range(n))
start = time.perf_counter()
result = [x * 2 for x in lst]
print(f"Python list: {time.perf_counter() - start:.3f}s")

# NumPy
arr = np.arange(n)
start = time.perf_counter()
result = arr * 2
print(f"NumPy: {time.perf_counter() - start:.3f}s")
# NumPy 通常快 50-100 倍

ndarray 内存模型

python
arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)

print(arr.shape)    # (2, 3) — 形状
print(arr.dtype)    # float64 — 数据类型
print(arr.ndim)     # 2 — 维度数
print(arr.size)     # 6 — 元素总数
print(arr.itemsize) # 8 — 每个元素字节数
print(arr.nbytes)   # 48 — 总字节数
print(arr.strides)  # (24, 8) — 步长(字节)

步长(strides) 是理解 NumPy 的关键:

python
# strides = (24, 8) 意味着:
# 行方向移动 1 步 = 跳过 24 字节(3个float64)
# 列方向移动 1 步 = 跳过 8 字节(1个float64)

# 转置不复制数据,只改变 strides
arr_T = arr.T
print(arr_T.strides)  # (8, 24) — strides 互换
print(np.shares_memory(arr, arr_T))  # True

创建数组

python
# 从数据创建
a = np.array([1, 2, 3, 4, 5])
b = np.array([[1, 2], [3, 4]], dtype=np.float32)

# 特殊数组
zeros = np.zeros((3, 4))
ones = np.ones((2, 3), dtype=int)
eye = np.eye(3)                    # 单位矩阵
empty = np.empty((2, 2))           # 未初始化(快)
full = np.full((3, 3), 7.0)

# 序列
arange = np.arange(0, 10, 2)      # [0, 2, 4, 6, 8]
linspace = np.linspace(0, 1, 5)   # [0, 0.25, 0.5, 0.75, 1.0]
logspace = np.logspace(0, 3, 4)   # [1, 10, 100, 1000]

# 随机
rng = np.random.default_rng(42)   # 推荐:新式随机生成器
rand = rng.random((3, 3))
normal = rng.normal(0, 1, (100,))
integers = rng.integers(0, 10, size=5)

索引与切片

python
arr = np.arange(12).reshape(3, 4)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

# 基础索引
print(arr[1, 2])      # 6
print(arr[0])         # [0 1 2 3]
print(arr[:, 1])      # [1 5 9] — 第 1 列

# 切片(返回视图,不复制)
sub = arr[0:2, 1:3]
print(sub)
# [[1 2]
#  [5 6]]

# 布尔索引
mask = arr > 5
print(arr[mask])      # [ 6  7  8  9 10 11]

# 花式索引(返回副本)
rows = np.array([0, 2])
cols = np.array([1, 3])
print(arr[rows, cols])  # [1, 11] — arr[0,1] 和 arr[2,3]

广播机制 Broadcasting

广播是 NumPy 最强大也最容易混淆的特性:

python
# 规则:从右对齐维度,1 可以扩展到任意大小
a = np.array([[1], [2], [3]])   # shape (3, 1)
b = np.array([10, 20, 30])      # shape (3,) → (1, 3)

print(a + b)
# [[11 21 31]
#  [12 22 32]
#  [13 23 33]]

# 实战:归一化每行
data = np.random.rand(100, 5)
row_mean = data.mean(axis=1, keepdims=True)  # shape (100, 1)
row_std = data.std(axis=1, keepdims=True)    # shape (100, 1)
normalized = (data - row_mean) / row_std     # 广播

广播规则:

  1. 维度数不同时,在左边补 1
  2. 某维度为 1 时,可以扩展到另一个数组的对应维度
  3. 维度不兼容(非 1 且不相等)则报错

通用函数 ufunc

python
arr = np.array([1.0, 4.0, 9.0, 16.0])

# 数学函数
print(np.sqrt(arr))    # [1. 2. 3. 4.]
print(np.exp(arr))
print(np.log(arr))
print(np.sin(arr))

# 聚合
print(arr.sum())       # 30.0
print(arr.mean())      # 7.5
print(arr.std())       # 标准差
print(arr.min(), arr.max())
print(arr.argmin(), arr.argmax())  # 最小/大值的索引

# 多维聚合
matrix = np.arange(12).reshape(3, 4)
print(matrix.sum(axis=0))  # 按列求和 [12 15 18 21]
print(matrix.sum(axis=1))  # 按行求和 [ 6 22 38]

线性代数

python
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# 矩阵乘法
print(A @ B)           # 矩阵乘法(推荐)
print(np.dot(A, B))    # 等价

# 线性代数操作
print(np.linalg.det(A))          # 行列式
print(np.linalg.inv(A))          # 逆矩阵
eigenvalues, eigenvectors = np.linalg.eig(A)  # 特征值/向量
U, S, Vt = np.linalg.svd(A)     # SVD 分解
x = np.linalg.solve(A, [1, 2])  # 解线性方程组 Ax = b

性能优化技巧

避免循环,使用向量化

python
# 慢:Python 循环
def slow_distance(points):
    n = len(points)
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            diff = points[i] - points[j]
            distances[i, j] = np.sqrt(np.sum(diff**2))
    return distances

# 快:向量化
def fast_distance(points):
    # points shape: (n, d)
    diff = points[:, np.newaxis, :] - points[np.newaxis, :, :]
    return np.sqrt((diff**2).sum(axis=-1))

points = np.random.rand(100, 3)
# fast_distance 比 slow_distance 快 100x+

内存布局选择

python
# C 顺序(行优先,默认)— 按行遍历快
c_arr = np.zeros((1000, 1000), order='C')

# Fortran 顺序(列优先)— 按列遍历快,适合线性代数
f_arr = np.zeros((1000, 1000), order='F')

# 检查是否连续
print(c_arr.flags['C_CONTIGUOUS'])  # True
print(f_arr.flags['F_CONTIGUOUS'])  # True

视图 vs 副本

python
arr = np.arange(10)

# 视图(共享内存,修改会影响原数组)
view = arr[2:5]
view[0] = 99
print(arr)  # [ 0  1 99  3  4  5  6  7  8  9]

# 副本(独立内存)
copy = arr[2:5].copy()
copy[0] = 0
print(arr)  # 不变

# 检查
print(np.shares_memory(arr, view))  # True
print(np.shares_memory(arr, copy))  # False

实战:图像处理

python
# 图像本质上是 NumPy 数组
from PIL import Image
import numpy as np

# 加载图像
img = np.array(Image.open("photo.jpg"))  # shape: (H, W, 3)
print(img.shape, img.dtype)  # (480, 640, 3) uint8

# 灰度化
gray = img.mean(axis=2).astype(np.uint8)

# 亮度调整
brighter = np.clip(img * 1.5, 0, 255).astype(np.uint8)

# 裁剪
cropped = img[100:300, 200:400]

# 水平翻转
flipped = img[:, ::-1]

核心思维转变

从"循环处理每个元素"转变为"对整个数组做操作"。每当你想写 for 循环处理数组,先想想能否用 NumPy 的向量化操作替代。

本站内容由 褚成志 整理编写,仅供学习参考