如何在 JAX 中正确计算批量矩阵指数(expm)

如何在 JAX 中正确计算批量矩阵指数(expm)

本文详解 jax 中 `jax.scipy.linalg.expm` 批量计算失败的常见原因与解决方案,涵盖新版原生支持、旧版兼容写法及关键形状调试技巧。

在使用 JAX 计算矩阵指数(如量子线路中的参数化幺正演化 $ e^{iA} $)时,一个典型错误是:

ValueError: expected A to be a square matrix

尽管你确认最后两维是方阵(如 (4, 4)),但报错仍发生——这往往源于 输入张量的维度结构不符合 expm 的隐式批处理规则

? 根本原因:expm 对输入形状有严格要求

jax.scipy.linalg.expm 自 JAX v0.4.7 起原生支持批量输入,但前提是:
✅ 输入数组的最后两个轴必须构成方阵(如 (…, n, n));
❌ 其余前导维度将被自动视为 batch 维度;
❌ 若中间存在非 batch 的冗余维度(如你的 A.shape = (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)),它仍能工作;
⚠️ 但若 A 的最后两维不满足 n == n(例如 (4, 5)),或 A.ndim

在你的代码中,问题出在 pauli_matrix(num_qubits) 的构造逻辑:

def pauli_matrix(num_qubits):
    _pauli_matrices = jnp.array(
        [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    # ❌ 错误:对 _pauli_matrices 重复 kronecker 积,却未指定作用于哪一组 qubit
    # 且 [1:] 切片导致维度混乱,最终使 tensordot 结果 A 的 shape 不符合预期
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

该函数实际生成的是 (15, 4**num_qubits, 4**num_qubits) 形状的 Pauli 基(对 2-qubit 应为 (15, 4, 4)),但 reduce(jnp.kron, …) 在 num_qubits=2 时会生成 (4^2, 4^2) = (16, 16) 矩阵,再 [1:] 切片得 (15, 16, 16) —— 而你 theta 是 (15, 2,2,2,2,2,2,2,2),tensordot 后 A 实际为 (2,2,2,2,2,2,2,2, 16, 16),并非你误以为的 (2,…,2,4,4)。因此 expm 接收的不是 (N, 4, 4),而是高维张量,但只要末两维是方阵,新版 JAX 就能处理。

✅ 正确做法:确保 A 的 shape 为 (…, d, d),其中 d = 2**num_qubits。

Copy Leaks

Copy Leaks

AI内容检测和分级,帮助创建和保护原创内容

下载

✅ 解决方案一:升级 JAX 并规范输入(推荐)

确保使用 JAX ≥ 0.4.7:

pip install --upgrade jax jaxlib

然后修正 pauli_matrix 和 SpecialUnitary:

import jax.numpy as jnp
import jax.scipy.linalg as linalg
from functools import reduce

def pauli_basis_1q():
    return jnp.array([
        [[1., 0.], [0., 1.]],   # I
        [[0., 1.], [1., 0.]],   # X
        [[0., -1j], [1j, 0.]],  # Y
        [[1., 0.], [0., -1.]],  # Z
    ])

def pauli_matrix(num_qubits):
    """返回 (4**num_qubits - 1) 个 traceless n-qubit Pauli 算符,shape (15, 4, 4) for n=2"""
    basis = pauli_basis_1q()
    # 构造所有非恒等的 n-qubit Pauli 张量积:共 4^n - 1 个
    from itertools import product
    ops = []
    for indices in product(range(4), repeat=num_qubits):
        if all(i == 0 for i in indices):  # skip identity
            continue
        op = basis[indices[0]]
        for i in indices[1:]:
            op = jnp.kron(op, basis[i])
        ops.append(op)
    return jnp.stack(ops)  # shape: (15, 4, 4) for num_qubits=2

num_qubits = 2
d = 2 ** num_qubits  # 4
theta = jnp.pi * jnp.random.uniform(shape=(15,))  # 简化:单组参数,shape (15,)

A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])  # -> (4, 4)
U = linalg.expm(1j * A / 2)  # ✅ works: (4, 4)

# 批量示例:theta shape (8, 15) → A shape (8, 4, 4) → U shape (8, 4, 4)
theta_batch = jnp.pi * jnp.random.uniform(shape=(8, 15))
A_batch = jnp.einsum('bi,ij->bjk', theta_batch, pauli_matrix(num_qubits))  # (8, 4, 4)
U_batch = linalg.expm(1j * A_batch / 2)  # ✅ native batch support
print(U_batch.shape)  # (8, 4, 4)

⚙️ 解决方案二:旧版 JAX 兼容写法(jnp.vectorize)

若受限于旧版 JAX(

expm_vec = jnp.vectorize(linalg.expm, signature='(n,n)->(n,n)')

# A_batch shape: (B, d, d)
U_batch = expm_vec(1j * A_batch / 2)  # returns (B, d, d)

⚠️ 注意:vectorize 在 JIT 下可能不如原生批量高效,仅作兼容之用。

? 关键检查清单

  • ✅ 使用 A.shape[-2] == A.shape[-1] 验证末两维是否为方阵;
  • ✅ 避免在 tensordot 或 einsum 中引入意外维度(如你的原始 theta 有 9 维,极易出错);
  • ✅ 优先用 einsum 替代嵌套 tensordot 提升可读性;
  • ✅ 调试时打印 A.shape 和 A.dtype,确认无 float64(JAX 默认 float32,expm 要求浮点)。

掌握这些要点,你就能稳健地在 JAX 中实现量子态演化、李群指数映射等核心计算。

https://www.php.cn/faq/1986954.html

发表回复

Your email address will not be published. Required fields are marked *