原文链接:Einsums in the wild

引言

机器学习和统计学中存在大量的线性组合,许多统计学和机器学习中的算法和模型都可以写成矩阵与向量之间的运算。Einsums 是一种表示向量、矩阵和高维数组之间的线性运算的方法。

这篇文章中列出了使用 Einsums 的示例。我们假设读者熟悉 爱因斯坦求和(Einstein summation) 的基础知识。但是,我们将在下一节中简单介绍 einsum。

在下一节中,我们将简要介绍 einsum 表达式及其在 numpy / jax.numpy 中的用法。

Einsums 简介

令 \(\mathbf{a} \in \mathbb{R}^{M}\) 表示为一个一维向量,\(a_{m}\) 表示向量 \(\mathbf{a}\) 中的第 \(m\) 个元素。假设我们想要表示 \(\mathbf{a}\) 中所有元素的和,可以写成如下形式:

\[ \sum_{m=1}^{M} a_{m} \]

为了引入 einsum 标记,我们注意到这个等式中的和符号(\(\Sigma\))只是表明我们应该考虑 \(\mathbf{a}\) 中的所有元素并将它们相加。如果我们假设:1)向量 \(\mathbf{a}\) 中的维数没有歧义;2)我们对它的所有元素求和,我们可以将一维向量 \(\mathbf{a}\) 中所有元素之和的 einsum 标记定义为:

\[ \sum_{m=1}^N a_m\stackrel{\text{einsum}}{\equiv} \mathbf{a}_m \]

为了保持我们的符号一致,我们将带括号的索引表示为静态维度(static dimensions),静态维度允许我们扩展 einsum 的表达能力。也就是说,我们将 \(\mathbf{a}\) 中的所有元素在 einsum 标记下表示为 \(\mathbf{a}_{(m)}\)。

由于数组的名称对于定义这些表达式不一定有意义,因此我们在 numpy 中通过仅关注索引来定义 einsum 表达式。为了表示哪些维度是静态的哪些应该相加,我们引入了 -> 符号。-> 左侧的元素定义数组的索引集,-> 右侧的元素表示我们不求和的索引。

例如,对向量 \(\mathbf{a}\) 中的所有元素求和可以写成:

\[ \mathbf{a}_m \equiv \texttt{m->} \]

另外,选择向量 \(\mathbf{a}\) 中的所有元素可以写成:

\[ \mathbf{a}_{(m)} \equiv \texttt{m->m} \]

在下面的代码片段中,我们展示了这个符号的作用:

1
2
3
4
5
>>> a = np.array([1, 2, 3, 4])
>>> np.einsum("m->", a)
10
>>> np.einsum("m->m", a)
array([1, 2, 3, 4])

高维数组

令 \(\mathbf{a} \in \mathbb{R}^{M}\) 和 \(\mathbf{b} \in \mathbb{R}^{M}\) 表示两个一维向量,则 \(\mathbf{a}\) 和 \(\mathbf{b}\) 之间的点乘可以写为:

\[ \begin{aligned} \mathbf{a}^{\mathrm{T}} \mathbf{b} &= a_1 b_1 + \ldots + a_M b_M \\ &= \sum_{m=1}^M a_m b_m \end{aligned} \]

按照我们之前的标记,我们看到这个 einsum 表达式在数学和 numpy 形式中的表示是:

\[ \mathbf{a}_m \mathbf{b}_m \equiv \texttt{m,m->} \]

此外,\(\mathbf{a}\) 和 \(\mathbf{b}\) 之间的逐元素乘积的 einsum 标记由下式给出:

\[ \mathbf{a}_{(m)} \mathbf{b}_{(m)} \equiv \texttt{m,m->m} \]

例如,考虑以下一维数组 ab

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
>>> a = np.array([1, 3, 1])
>>> b = np.array([-1, 2, -2])
>>> # dot product
>>> a @ b
3
>>> np.einsum("m,m->", a, b)
3
>>> # Element-wise product
>>> a * b
array([-1,  6, -2])
>>> np.einsum("m,m->m", a, b)
array([-1,  6, -2])

我们可以进一步对上面的想法进行推广。考虑矩阵 \(\mathbf{A} \in \mathbb{R}^{N\times M}\) 和向量 \(\mathbf{x} \in \mathbb{R}^M\) 之间的乘积。通过线性代数,我们可以写成:

\[ \mathbf{A} \mathbf{x} = \begin{bmatrix} \mathbf{a}_1^{\mathrm{T}} \\ \vdots \\ \mathbf{a}_N^{\mathrm{T}} \end{bmatrix} \mathbf{x} = \begin{bmatrix} \mathbf{a}_1^{\mathrm{T}} \mathbf{x} \\ \vdots \\ \mathbf{a}_N^{\mathrm{T}} \mathbf{x} \end{bmatrix} \]

其中,\(\mathbf{a}_n^{\mathrm{T}}\) 表示矩阵 \(\mathbf{A}\) 中的第 \(n\) 行元素。

根据上面的式子,我们注意到 \(\mathbf{A} \mathbf{x} \in \mathbb{R}^{N}\) 可以表示为:

\[ (\mathbf{A} \mathbf{x})_{n} = \sum_{m=1}^M a_{n, m} x_m \]

我们还可以观察到这个表达式中,\(\mathbf{A}\) 的第一个维度是静态的。于是,einsum 标记可以写为:

\[ \mathbf{A}_{(n),m}\mathbf{x}_m \equiv \texttt{nm,m->n} \]

通过最后一个例子的结果,我们可以很容易地表达出两个矩阵相乘后的第 \((i, j)\) 项。

令 \(\mathbf{A} \in \mathbb{R}^{N \times M}\),\(\mathbf{B} \in \mathbb{R}^{M \times K}\),则 \(\mathbf{A}\) 和 \(\mathbf{B}\) 的乘积为:

\[ \begin{aligned} \mathbf{A}\mathbf{B} &= \begin{bmatrix} \mathbf{a}_1^{\mathrm{T}} \\ \vdots \\ \mathbf{a}_N^{\mathrm{T}} \end{bmatrix} \begin{bmatrix} \mathbf{b}_1, \ldots, \mathbf{b}_M \\ \end{bmatrix} \\ &= \begin{bmatrix} \mathbf{a}_1^{\mathrm{T}} \mathbf{b}_1 & \mathbf{a}_1^{\mathrm{T}} \mathbf{b}_2 & \ldots & \mathbf{a}_1^{\mathrm{T}} \mathbf{b}_M \\ \mathbf{a}_2^{\mathrm{T}} \mathbf{b}_1 & \mathbf{a}_2^{\mathrm{T}} \mathbf{b}_2 & \ldots & \mathbf{a}_2^{\mathrm{T}} \mathbf{b}_M\\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{a}_N^{\mathrm{T}} \mathbf{b}_1 & \mathbf{a}_N^{\mathrm{T}} \mathbf{b}_2 & \ldots & \mathbf{a}_N^{\mathrm{T}} \mathbf{b}_M \end{bmatrix} \end{aligned} \]

于是,矩阵乘积 \(\mathbf{A}\mathbf{B}\) 中的第 \((i, j)\) 项为:

\[ \begin{aligned} \mathbf{A}\mathbf{B} &= \mathbf{a}_i^{\mathbf{T}} \mathbf{b}_j \\ &= \sum_{m=1}^M a_{i,m} b_{m, j} \end{aligned} \]

从上面的等式中,我们看到 \(\mathbf{A}\) 的第一个维度和 \(\mathbf{B}\) 的第二个维度是静态的。其 einsum 标记为:

\[ \mathbf{A}_{(i),m} \mathbf{B}_{m, (j)}\equiv \texttt{im,mj->ij} \]
1
2
3
4
5
6
7
8
>>> A = np.array([[1, 2], [-2, 1]])
>>> B = np.array([[0, 1], [1, 0]])
>>> A @ B
array([[ 2,  1],
       [ 1, -2]])
>>> np.einsum("im,mj->ij", A, B)
array([[ 2,  1],
       [ 1, -2]])

更高维的数组

在机器学习中使用 einsum 的优势在于它们在处理高维数组时的具有强大的表达能力。正如我们将看到的,知道矩阵向量乘法运算的 einsum 表示很容易让我们将其推广到多个维度。这是因为当输出中存在静态维度时,可以将枚举视为线性变换的表达式。

为了促进使用 einsums 标记在机器学习中表示线性组合,我们考虑以下示例。

机器学习中的 Einsums

令 \(\mathbf{x} \in \mathbb{R}^M\) 和 \(\mathbf{A} \in \mathbb{R}^{M \times M}\) 分别表示一个一维数组和一个二维数组,则以零为中心,精度矩阵 \(\mathbf{A}\) 的平方马氏距离定义为:

\[ D_{\mathbf{A}}(\mathbf{x}) = \mathbf{x}^{\mathrm{T}} \mathbf{A} \mathbf{x} \]

根据矩阵乘法法则,对于任意给定的 \(\mathbf{x}\) 和一个有效的精度矩阵 \(\mathbf{A}\) 我们都可以计算 \(D_{\mathbf{A}}(\mathbf{x})\)。我们可以得到 \(D_{\mathbf{A}}(\mathbf{x})\) 的 einsum 标注为 i,ij,j->,这是因为:

\[ \begin{aligned} \mathbf{x}^{\mathbf{T}} \mathbf{A} \mathbf{x} &= \sum_{i,j} x_i A_{i,j} x_j \\ &\stackrel{\text{einsum}}{\equiv} \mathbf{x}_i \mathbf{A}_{i,j} \mathbf{x}_j \end{aligned} \]

一个更有趣的场景是:考虑这样一种情况,我们有 \(N\) 个观测值存储在一个二维数组 \(\mathbf{X} \in \mathbb{R}^{N \times M}\) 中。如果我们记 \(\mathbf{x}_{n} \in \mathbb{R}^M\) 为 \(\mathbf{X}\) 中的第 \(n\) 个观测,我们需要计算每个观测值的平方马氏距离,以获得:

\[ \mathbf{x}_n^{\mathbf{T}} \mathbf{A} \mathbf{x}_n \quad \forall n \in \{1, \ldots, N\} \]

获得上面结果的一种方法是通过矩阵计算得到:

\[ \text{Diag}(\mathbf{X}^{\mathbf{T}}{\bf A}\mathbf{X}) \]

其中:\(\text{Diag}(\mathbf{M})_{i} = \mathbf{M}_{i,j}\)。

根据矩阵乘法,我们可以证明上面的结果:

\[ \begin{aligned} (\mathbf{X} {\bf A}\mathbf{X}^{\mathbf{T}}) &= \begin{bmatrix}\mathbf{x}_1^{\mathbf{T}} \\ \vdots \\\mathbf{x}_N^{\mathbf{T}}\end{bmatrix} {\bf A} \begin{bmatrix}{\bf x_1} & \ldots & \mathbf{x}_N\end{bmatrix}\\ &= \begin{bmatrix}\mathbf{x}_1^{\mathbf{T}} {\bf A}\\ \vdots \\\mathbf{x}_N^{\mathbf{T}} {\bf A}\end{bmatrix} \begin{bmatrix}{\bf x_1} & \ldots & \mathbf{x}_N\end{bmatrix}\\ &= \begin{bmatrix} \mathbf{x}_1^{\mathbf{T}} {\bf A} \mathbf{x}_1 & \mathbf{x}_1^{\mathbf{T}} {\bf A} \mathbf{x}_2 & \ldots & \mathbf{x}_1^{\mathbf{T}} {\bf A} \mathbf{x}_N \\ \mathbf{x}_2^{\mathbf{T}} {\bf A} \mathbf{x}_1 & \mathbf{x}_2^{\mathbf{T}} {\bf A} \mathbf{x}_2 & \ldots & \mathbf{x}_2^{\mathbf{T}} {\bf A} \mathbf{x}_N \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{x}_N^{\mathbf{T}} {\bf A} \mathbf{x}_1 & \mathbf{x}_N^{\mathbf{T}} {\bf A} \mathbf{x}_2 & \ldots & \mathbf{x}_N^{\mathbf{T}} {\bf A} \mathbf{x}_N \end{bmatrix} \end{aligned} \]

但是,上述表达式的计算效率很低,因为我们需要计算 \(N^2\) 项来获得我们想要的大小为 \(N\) 的一维数组。更高效的方式是利用 einsum 标记来计算和表达上述的平方马氏距离问题。正如我们已经看到的,\(D_{\mathbf{A}}(\mathbf{x})\) 的 einsum 标记形式为:

\[ \mathbf{x}_{i}\mathbf{A}_{i,j}\mathbf{x}_j \]

我们可以很简单地将上面的表达式扩展到 \(N\) 个元素,我们只需要注意 \(\mathbf{X}\) 的第一个维度是静态的。我们可以得到:

\[ \mathbf{X}_{(n), i}\mathbf{A}_{i,j}\mathbf{X}_{(n),j} \equiv \texttt{ni,ij,nj->n} \]

这个特定的例子中,使用 einsum 标注可以帮助我们避免计算不在对角线上的元素,与直接利用矩阵相乘方法相比,提高了我们的速度。

1
2
3
4
5
6
7
In [1]: %%timeit -n 10 -r 10
   ...: np.diag(X @ A @ X.T)
10 loops, best of 10: 5.8 s per loop

In [2]: %%timeit -n 10 -r 10
   ...: np.einsum("ni,ij,nj->n", X, A, X, optimize="optimal")
10 loops, best of 10: 598 ms per loop

为了推广这一结果,我们考虑三维数组 \(\mathbf{X} \in \mathbb{R}^{N_{1} \times N_{2} \times M}\)。使用线性代数的基本工具,我们已经没法用代数形式表示在 \(\mathbf{X}\) 的最后一维上进行平方马氏距离的计算。从我们之前的结果中,我们看到为了扩展 \(\mathbf{X}\) 的计算,我们只需要引入一个额外的静态维度:

\[ \mathbf{X}_{(n), (m), i}\mathbf{A}_{i,j}\mathbf{X}_{(n),(m), j} \equiv \texttt{nmi,ij,nmj->nm} \]

最后这些表达式表明,当我们必须计算未使用索引的已知线性变换时,einsums 可以提供很大帮助

如果我们继续增加 \(\mathbf{X}\) 的维度,我们会得到以下结果:

  1. i,ij,j-> 得到标量输出,
  2. ni,ij,nj->n 得到一维数组输出,
  3. nmi,ij,nmj->nm 得到二维数组输出,
  4. nmli,ij,nmlj->nml 得到三维数组是输出,
  5. …i,ij,…j->… 得到 \(d\) 维数组输出。

此外,einsums 表达式在索引块上是可交换的。这意味着 einsum 表达式的结果与数组的定位顺序无关。对于我们之前的示例,以下三个表达式是等价的:

  1. ni,ij,nj->n
  2. ni,nj,ij->n
  3. ni,nj,ij->n

高斯分布的对数密度函数

令 \(\mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})\),则 \(\mathbf{x}\) 的对数密度函数为:

\[ \log p(\mathbf{x} \vert \boldsymbol{\mu}, \boldsymbol{\Sigma}) = -\frac{1}{2}(\mathbf{x} - \boldsymbol{\mu})^{\mathrm{T}}\boldsymbol{\Sigma}^{-1}(\mathbf{x} - \boldsymbol{\mu}) + \text{const} \]

假设我们想要绘制二元高斯分布的对数密度函数在区域 \(\mathcal{X} \subseteq \mathbb{R}^2\) 中的图像。正如我们之前看到的,表达式 \(\mathbf{x}^{\mathrm{T}} \mathbf{A} \mathbf{x}\) 的 einsum 标注为 i,ij,j->。通过引入静态维度 nm,我们可以计算在 \(\mathcal{X}\) 上的对数密度函数,只需要将 nm 索引加入 einsum 表达式中,并且将它们指定为输出的大小。于是,我们可以得到:inm,ij,jnm->nm

下面,我们展示了一个例子:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
mean_vec = jnp.array([1, 0]) 
cov_matrix = jnp.array([[4, -2], [-2, 4]])
prec_matrix = jnp.linalg.inv(cov_matrix)

step = 0.1
xmin, xmax = -8, 8
ymin, ymax = -10, 10
X_grid = jnp.mgrid[xmin:xmax:step, ymin:ymax:step] + mean_vec[:, None, None]

diff_grid = (X_grid - mean_vec[:, None, None])
log_prob_grid = -jnp.einsum("inm,ij,jnm->nm", diff_grid, prec_matrix, diff_grid) / 2
plt.contour(*X_grid, log_prob_grid, 30)
高斯分布对数密度

高斯分布对数密度

我们将上面的想法扩展到 一组具有恒定均值和多个协方差矩阵的多元高斯 的情况。

回想一下,计算区域 \(\mathcal{X} \subseteq \mathbb{R}^2\) 上二元高斯分布的对数密度的 einsum 表达式由 inm,ij,jnm->nm 给出。假设我们有一组 \(K\) 个高斯分布,对于每个 \(k\),我们有一个精度矩阵 \(\mathbf{S}_{k}\) 和一个同样的均值 \(\boldsymbol{\mu} \in \mathbb{R}^{M}\)。为了计算每个区域的密度函数,我们只需修改之前的表达式以考虑新的静态维度 \(k\)。我们得到:inm,kij,jnm->knm

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
C1 = jnp.array([
    [4, -2],
    [-2, 4]
])

C2 = jnp.array([
    [4, 0],
    [0, 4]
])

C3 = jnp.array([
    [4, 2],
    [2, 4]
])

C4 = jnp.array([
    [1, -2],
    [2, 4]
])

C = jnp.stack([C1, C2, C3, C4], axis=0)
S = jnp.linalg.inv(C) # inversion over the fist dimension

log_prob_grid_multiple = -jnp.einsum("inm,kij,jnm->knm", diff_grid, S, diff_grid) / 2

fig, ax = plt.subplots(2, 2)
ax = ax.ravel()
for axi, log_prob_grid in zip(ax, log_prob_grid_multiple):
    axi.contour(*X_grid, log_prob_grid, 30)
K 组均值相同精度不同的高斯分布

K 组均值相同精度不同的高斯分布

回顾一下,计算的马氏距离的 einsum 表达式由下式给出:

  • i,ij,ij-> \(\mathbf{x}\) 为一维数组(向量);
  • ni,ij,nj->n 对于 \(N\) 组 \(\mathbf{x}\) 构成的观测矩阵;
  • nmi,ij,nmj->nm 对于 \(N \times M\) 组 \(\mathbf{x}\) 构成的观测值网格;
  • nmi,kij,nmj->knm 对于 \(N \times M\) 组 \(\mathbf{x}\) 构成的观测值网格,并且具有不同的精度矩阵。

贝叶斯逻辑回归模型的预测面

只要我们的运算中最里面的操作存在元素的线性组合,我们就可以使用 einsum 表达式。

下一个示例,我们计算高斯先验下的贝叶斯逻辑回归的预测曲面。也就是说,我们要计算:

\[ \begin{aligned} p(\hat{y} = 1 \vert \mathbf{x}) &= \int_{\mathbb{R}^2} \sigma(\mathbf{w}^{\mathrm{T}} \mathbf{x}) p(\mathbf{w} \vert \mathcal{D}) d\mathbf{x}\\ &= \mathbb{E}_{\mathbf{w} \vert \mathcal{D}}\left[\sigma(\mathbf{w}^{\mathrm{T}} \mathbf{x})\right]\\ &\approx \frac{1}{S} \sum_{s=1}^S \sigma\left({\bf w^{(s)}}^{\mathrm{T}} \mathbf{x}\right) \end{aligned} \]

假设我们已经估计了后验分布的参数 \(\hat{\mathbf{w}}\),\(\hat{\boldsymbol{\Sigma}}\)。因此预测分布的后验在计算上是难以处理的,我用使用后验预测分布表面的蒙特卡罗近似。和之前的例子相同,我们想要计算在区域上计算 \(p(\hat{y} = 1 \vert \mathbf{x})\)。

在这种情况下,我们从后验分布 \(p(\mathbf{w} \vert \mathcal{D})\) 中采样了 \(S\) 个权重。对于每个权重 \(s\),我们要计算网格点 \(\mathcal{X}\) 处的值。

回想一下,两个向量之间的点乘可以用 einsum 表示为 m,m->。为了得到由一组 \(S\) 在格网 \(\mathcal{X}\) 上每个点计算的值(结果是一个三维数组),我们可以简单地在点乘中引入一个静态维度 s,表示每一个模拟的权重。令 ij 表示网格上的位置,我们可以得到 einsum 表达式为:sm,mij->sij

在计算了 sij 后,我们对结果使用 logistic 函数,并在 s 维度上对每个元素取平均,得到近似的预测分布。

下面是代码片段:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
# Estimated posterior mean and precision matrix
w_hat = jnp.array([4.29669692, 1.6520908])
S_hat = jnp.array([[2.74809243, 0.76832627],
					[0.76832627, 0.88442754]])
C_hat = jnp.linalg.inv(S_hat)

n_samples = 1_000
boundary, step = 8, 0.1
key = jax.random.PRNGKey(314)
X_grid = jnp.mgrid[-boundary:boundary:step, -boundary:boundary:step]

w_samples = jax.random.multivariate_normal(key, w_hat, S_hat, shape=(n_samples,))

logit_grid = jnp.einsum("sm,mij->sij", w_samples, X_grid)
P_grid = jax.nn.sigmoid(logit_grid).mean(axis=0)

plt.contour(*X_grid, P_grid, 30)
plt.title(r"$p(\hat{y} = 1 \vert \mathbf{x})$", fontsize=15)
贝叶斯逻辑回归模型的预测面

贝叶斯逻辑回归模型的预测面

图像压缩:奇异值分解(SVD)

在学习奇异值分解时,一个典型的应用是 使用奇异值分解来压缩图像。作为一个启发式示例,假设我们想要在多个阈值上比较图像的奇异值分解结果。

我们将图像 \(\mathbf{P}\) 分解为:

\[ \mathbf{P} \stackrel{\text{einsum}}{\equiv} \mathbf{c}_{k}\mathbf{L}_{(n),(m),k} \in\mathbb{R}^{N\times M} \equiv \texttt{c,ijc->ij} \]

这是线性代数的经典结果,我们可以将矩阵 \(\mathbf{P}\) 可以分解为:

\[ \mathbf{P} = \mathbf{U} \boldsymbol{\Sigma} \mathbf{V}^{\mathrm{T}} \]

其中:\(\mathbf{U} \in \mathbb{R}^{M \times M}\),\(\mathbf{V} \in \mathbb{R}^{N \times N}\),以及 \(\boldsymbol{\Sigma} \in \mathbb{R}^{M \times N}\)。\(\boldsymbol{\Sigma}\) 中的对角线元素为 \(\{\sigma_1, \sigma_2, \ldots, \sigma_{\min(n,m)}\}\),其他元素为 \(0\)。

scipy 中,\(\mathbf{P}\) 的 SVD 分解被方便地分解(以 einsum 形式)为:

\[ \mathbf{P} \stackrel{\text{einsum}}{\equiv}\hat{\mathbf{U}}_{(n),k}\hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, (m)} \]

其中:\(\hat{\mathbf{U}} \in \mathbb{R}^{M \times R}\),\(\hat{\mathbf{V}} \in \mathbb{R}^{N \times R}\),\(\boldsymbol{\sigma} = \{\sigma_1 \ldots, \sigma_R\}\),并且 \(R = \min(M,N)\)。

作为一个示例,假设我们想要用前 \(K\) 个奇异值分量来近似矩阵 \(\mathbf{P}\)。首先我们观察到矩阵 \(\mathbf{P}\) 的第 \((n,m)\) 个元素可由下式计算:

\[ \mathbf{P}_{n,m} = \sum_{k=1}^R \hat{\mathbf{U}}_{n,k} \hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, m} \]

如果我们想考虑 \(\mathbf{P}_{n,m}\) 的前 \(K\) 个分量,我们只需要修改求和中的项数即可得到:

\[ \sum_{k=1}^K \hat{\mathbf{U}}_{n,k} \hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, m} \]

但是,上面的表达式 不能用 einsum 标注表示。正如我在开头提到的,每个 einsum 表达式都假设在索引的每个元素上求和

为了绕过这个约束,我们简单的引入大小为 \(R\) 的一维向量 \(\mathbf{1}_{\cdot \leq K}\),其中前 \(K\) 个元素的值为 \(1\),剩下 \(R-K\) 个元素的值为 \(0\)。

因此,使用前 \(K\) 个奇异值分量近似矩阵 \(\mathbf{P}\) 的表达式可以写为:

\[ \sum_{k=1}^R \hat{\mathbf{U}}_{n,k} \hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, m} (\mathbf{1}_{\cdot \leq K})_{k} \]

上式可以简单的用 einsum 表达式写为:

\[ \hat{\mathbf{U}}_{(n),k} \hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, {m}} (\mathbf{1}_{\cdot \leq K})_{k} \equiv \texttt{nk,k,km,k->nm} \]

我们同样可以考虑 \(K\) 取不同值的情况。我们定义二维数组:

\[ \mathbf{I} = \begin{bmatrix} \mathbf{1}_{\cdot \leq K}\\ \mathbf{1}_{\cdot \leq K_2}\\ \vdots\\ \mathbf{1}_{\cdot \leq K_C}\\ \end{bmatrix} \]

接下来,我们简单地修改我们之前的表达式,以考虑到矩阵 \(\mathbf{I}\) 的附加静态维度。我们可以得到:

\[ \begin{aligned} \hat{\mathbf{U}}_{(n),k}\hat{\boldsymbol{\sigma}}_{k} \hat{\mathbf{V}}_{k, (m)}{\mathbf{I}}_{(c), k} &\equiv \texttt{nk,k,km,ck->nmc}\\ &\equiv \texttt{nk,k,km,ck->cnm} \end{aligned} \]

我们在下一个代码中提供了一个例子:首先,我们加载一张图像存到三维数组 img 中。接下来,我们对 img 进行变换,得到一个二维数组 img_gray。我们在 img_gray 上执行奇异值分解并定义了一个矩阵 indexer 包含不同的阈值。最后,我们利用我们之前定义的表达式来计算在 indexer 中定义的不同阈值的图像的 SVD 近似值。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
FILEPATH = "https://i.imgur.com/p91lldL.png"
img = plt.imread(FILEPATH)

c_weights = jnp.array([0.2989, 0.5870, 0.1140])
img_gray = jnp.einsum("c,ijc->ij", c_weights, img)
U, s, Vh = jax.scipy.linalg.svd(img_gray, full_matrices=False)

indexer = s[:, None] > jnp.array([10, 100, 1_000, 5_000])
img_svd_collection = jnp.einsum("nk,k,km,ck->cnm", U, s, Vh, indexer)

fig, ax = plt.subplots(2, 2, figsize=(5, 6))
ax = ax.ravel()
for axi, img_svd in zip(ax, img_svd_collection):
    axi.imshow(img_svd, cmap="bone")
    axi.axis("off")
plt.tight_layout(w_pad=-1)
SVD 近似图像

SVD 近似图像