笔记|多模态融合(一):从特征拼接到注意力融合——多模态学习基础
系列说明:本文是多模态融合系列的第一篇,从最基本的概念出发,建立三级融合的数学框架,为后续 CLIP、BLIP-2、LLaVA 等模型打下理论基础。 前置知识:线性代数基础、Transformer 注意力机制(可参考本站第十一篇 UIT/DiT 详解)。
0. 从一个图文检索任务说起
假设你手上有一个小型数据集:5 张图片和 5 段文字描述。
| 编号 | 图片内容 | 文字描述 |
|---|---|---|
| 1 | 一只橘猫趴在窗台上 | "an orange cat lying on a windowsill" |
| 2 | 一辆红色跑车停在路边 | "a red sports car parked on the street" |
| 3 | 一碗拉面冒着热气 | "a steaming bowl of ramen noodles" |
| 4 | 一片雪山风景 | "a snowy mountain landscape" |
| 5 | 一个小女孩在画画 | "a little girl painting on a canvas" |
你的任务是:给定一张新的图片(比如另一只猫),从这 5 段文字中找到最匹配的描述。
这个问题的核心挑战是——图片和文字是完全不同的信号。图片是一个三维张量 \(\mathbf{I} \in \mathbb{R}^{H \times W \times 3}\),文字是一个变长离散序列 \(\mathbf{w} = (w_1, w_2, \ldots, w_L)\),\(w_i \in \mathcal{V}\)(词汇表)。如何让它们"对话"?
这就是多模态融合(Multimodal Fusion)要解决的核心问题。
1. 什么是"模态"?
模态(Modality)指信息的来源或表现形式。在机器学习中,常见模态包括:
| 模态 | 原始形式 | 典型编码器 | 编码后维度 |
|---|---|---|---|
| 视觉(图像) | \(\mathbf{I} \in \mathbb{R}^{H \times W \times 3}\) | ViT, ResNet | \(\mathbf{v} \in \mathbb{R}^{N_v \times d_v}\) |
| 语言(文本) | \((w_1, \ldots, w_L), w_i \in \mathcal{V}\) | BERT, GPT | \(\mathbf{t} \in \mathbb{R}^{N_t \times d_t}\) |
| 音频 | \(\mathbf{a} \in \mathbb{R}^{T \times F}\)(频谱) | Whisper, HuBERT | \(\mathbf{a} \in \mathbb{R}^{N_a \times d_a}\) |
| 3D 点云 | \(\mathbf{P} \in \mathbb{R}^{N \times 3}\) | PointNet, PCT | \(\mathbf{p} \in \mathbb{R}^{N_p \times d_p}\) |
关键观察:不同模态经过各自的编码器后,都变成了序列化的向量表示。模态之间的差异体现在两个维度上:
- 序列长度不同:\(N_v \neq N_t\)(ViT-B/16 对 224×224 图片产生 196 个 patch token,而一段文字可能只有 20 个 token)
- 特征维度不同:\(d_v \neq d_t\)(ViT-B 的 \(d_v = 768\),而某些文本编码器 \(d_t = 512\))
多模态融合的目标:将来自不同模态的表示整合为统一的联合表示,使下游任务(分类、检索、生成)能够同时利用多个模态的信息。
数学化地说,我们要找到一个融合函数:
\[ \mathcal{F}: \underbrace{\mathbb{R}^{N_v \times d_v}}_{\text{视觉}} \times \underbrace{\mathbb{R}^{N_t \times d_t}}_{\text{语言}} \longrightarrow \underbrace{\mathbb{R}^{N_z \times d_z}}_{\text{联合表示}} \]
不同的 \(\mathcal{F}\) 设计,构成了多模态融合的不同流派。
2. 三级融合框架
按照融合发生的时机,可以将多模态融合分为三个层级。用一个直观的类比:
| 层级 | 类比 | 信息交换时机 |
|---|---|---|
| 早期融合 | 把中文和英文混在一起写 | 原始特征层 |
| 中期融合 | 两个翻译官反复交流后各自总结 | 中间表征层 |
| 晚期融合 | 各自写完报告再合并结论 | 决策/输出层 |
2.1 早期融合(Early Fusion)
核心思想:在编码的最早阶段就将不同模态的特征合并,然后用统一的网络处理。
最简单的形式——拼接:
给定视觉特征 \(\mathbf{v} \in \mathbb{R}^{d_v}\) 和文本特征 \(\mathbf{t} \in \mathbb{R}^{d_t}\)(这里先考虑单向量的简化场景),拼接融合为:
\[ \mathbf{z}_{\text{concat}} = [\mathbf{v}; \mathbf{t}] \in \mathbb{R}^{d_v + d_t} \]
然后通过 MLP 映射到任务空间:\(\hat{y} = \text{MLP}(\mathbf{z}_{\text{concat}})\)。
问题:拼接是线性组合,无法建模模态之间的交互。"猫"和"窗台"的组合含义不等于"猫"的含义加上"窗台"的含义。
改进——双线性池化(Bilinear Pooling):
双线性池化通过外积捕获二阶交互:
\[ \mathbf{z}_{\text{bilinear}} = \mathbf{v}^\top \mathbf{W} \mathbf{t} \]
其中 \(\mathbf{W} \in \mathbb{R}^{d_v \times d_t}\) 是可学习的权重矩阵。完整的双线性池化输出维度为 \(d_v \times d_t\):
\[ z_{ij} = \sum_k v_k W_{ki} t_j = \mathbf{v}^\top \mathbf{W}_{\cdot, j} \cdot t_j \]
用外积的形式表达更清晰——令 \(\tilde{\mathbf{z}} = \mathbf{v} \otimes \mathbf{t} \in \mathbb{R}^{d_v \times d_t}\),则 \(\tilde{z}_{ij} = v_i \cdot t_j\)。这个 \(d_v \times d_t\) 维的向量能表达所有二阶交互项。
定理(双线性模型的表达能力):设 \(\phi(\mathbf{v}, \mathbf{t}) = \mathbf{w}^\top (\mathbf{v} \otimes \mathbf{t})\) 为双线性分类器,则其等价于在 \(\mathbf{v}\) 和 \(\mathbf{t}\) 的所有分量对之间施加线性权重。形式化地,\(\phi(\mathbf{v}, \mathbf{t}) = \sum_{i,j} W_{ij} v_i t_j\),即一个关于输入的二次函数。(参考:Tenenbaum & Freeman, 2000)
问题:当 \(d_v = d_t = 768\) 时,\(\tilde{\mathbf{z}}\) 的维度高达 \(768^2 \approx 59\) 万,计算和存储开销巨大。
解决方案——紧凑双线性池化(Compact Bilinear Pooling):
利用 Count Sketch 投影将外积压缩到低维:
\[ \mathbf{z}_{\text{compact}} = \text{CountSketch}(\mathbf{v}) \circledast \text{CountSketch}(\mathbf{t}) \in \mathbb{R}^{d_z} \]
其中 \(\circledast\) 表示卷积(等价于 FFT 域的逐元素乘法),\(d_z \ll d_v \cdot d_t\)。这利用了如下性质:
定理(Count Sketch 的卷积性质):若 \(\psi_h^s(\mathbf{x})\) 为 Count Sketch 投影(由哈希函数 \(h\) 和符号函数 \(s\) 定义),则 \(\psi_{h_1}^{s_1}(\mathbf{v}) \circledast \psi_{h_2}^{s_2}(\mathbf{t})\) 是 \(\mathbf{v} \otimes \mathbf{t}\) 的 Count Sketch 投影的无偏估计。(Pham & Pagh, 2013)
实际计算通过 FFT 加速:\(\mathbf{z}_{\text{compact}} = \text{FFT}^{-1}(\text{FFT}(\psi(\mathbf{v})) \odot \text{FFT}(\psi(\mathbf{t})))\)。
PyTorch 实现:
1 | import torch |
用前面的例子验证:
1 | # 模拟 5 个图文对 |
2.2 晚期融合(Late Fusion)
核心思想:每个模态用独立的编码器处理到最终表示,然后在决策层合并。
设视觉编码器 \(f_v: \mathbb{R}^{H \times W \times 3} \to \mathbb{R}^{d}\) 和文本编码器 \(f_t: \mathcal{V}^* \to \mathbb{R}^{d}\) 分别将输入映射到共享的 \(d\) 维空间。
Score-level 融合(对检索任务):
\[ \text{sim}(\mathbf{I}, \mathbf{w}) = \frac{f_v(\mathbf{I})^\top f_t(\mathbf{w})}{\|f_v(\mathbf{I})\| \cdot \|f_t(\mathbf{w})\|} \]
这正是 CLIP 采用的方式(下一篇详细介绍)。
Decision-level 融合(对分类任务):
\[ \hat{y} = \lambda \cdot f_v^{\text{cls}}(\mathbf{I}) + (1 - \lambda) \cdot f_t^{\text{cls}}(\mathbf{w}) \]
其中 \(f_v^{\text{cls}}\) 和 \(f_t^{\text{cls}}\) 分别输出各模态的 logits,\(\lambda\) 为混合权重。
优点:
- 各模态编码器可以独立预训练
- 推理时可以单模态部署(图像检索只需图像编码器)
- 计算效率高——编码后只需一次点积
缺点:
- 模态之间的交互仅限于最终的相似度计算
- 无法捕获细粒度的跨模态对应关系(如"红色"对应图片中车的颜色)
PyTorch 实现:
1 | class LateFusion(nn.Module): |
2.3 中期融合(Mid Fusion / Cross-Modal Attention)
核心思想:在编码过程的中间层引入跨模态信息交换,让不同模态在处理过程中互相看到彼此。
这是现代多模态模型最常用的策略,核心工具是交叉注意力(Cross-Attention)。
交叉注意力的数学推导
回顾标准的自注意力(Self-Attention)。给定输入序列 \(\mathbf{X} \in \mathbb{R}^{N \times d}\):
\[ \text{SelfAttn}(\mathbf{X}) = \text{softmax}\!\left(\frac{\mathbf{X}\mathbf{W}_Q (\mathbf{X}\mathbf{W}_K)^\top}{\sqrt{d_k}}\right) \mathbf{X}\mathbf{W}_V \]
其中 \(\mathbf{W}_Q, \mathbf{W}_K \in \mathbb{R}^{d \times d_k}\),\(\mathbf{W}_V \in \mathbb{R}^{d \times d_v}\)。
交叉注意力的关键修改:Query 来自一个模态,Key 和 Value 来自另一个模态。
设视觉 token 序列 \(\mathbf{V} \in \mathbb{R}^{N_v \times d}\),文本 token 序列 \(\mathbf{T} \in \mathbb{R}^{N_t \times d}\),则"以视觉查询文本"的交叉注意力为:
\[ \text{CrossAttn}(\mathbf{V}, \mathbf{T}) = \text{softmax}\!\left(\frac{(\mathbf{V}\mathbf{W}_Q)(\mathbf{T}\mathbf{W}_K)^\top}{\sqrt{d_k}}\right) \mathbf{T}\mathbf{W}_V \]
拆解每个视觉 token \(\mathbf{v}_i\) 的输出:
\[ \text{out}_i = \sum_{j=1}^{N_t} \alpha_{ij} \cdot (\mathbf{t}_j \mathbf{W}_V) \]
其中注意力权重为:
\[ \alpha_{ij} = \frac{\exp\!\left(\frac{(\mathbf{v}_i \mathbf{W}_Q)(\mathbf{t}_j \mathbf{W}_K)^\top}{\sqrt{d_k}}\right)}{\sum_{k=1}^{N_t} \exp\!\left(\frac{(\mathbf{v}_i \mathbf{W}_Q)(\mathbf{t}_k \mathbf{W}_K)^\top}{\sqrt{d_k}}\right)} \]
几何直觉:视觉 token \(\mathbf{v}_i\) 通过 \(\mathbf{W}_Q\) 生成一个"问题"向量,文本 token \(\mathbf{t}_j\) 通过 \(\mathbf{W}_K\) 生成一个"钥匙"向量。点积 \((\mathbf{v}_i \mathbf{W}_Q)(\mathbf{t}_j \mathbf{W}_K)^\top\) 衡量这个"问题"和"钥匙"的匹配程度。匹配度高的文本 token 贡献更多的信息(通过 \(\mathbf{W}_V\) 投影后加权求和)。
定理(多层交叉注意力的 Bayes 最优性):考虑线性化的多模态 in-context learning 设置,单层线性自注意力不能对所有任务分布一致地达到 Bayes 最优预测。但一个包含交叉注意力 + 自注意力 + 残差连接的多层模型,在梯度流优化下可以收敛到 Bayes 最优的 in-context 预测器。(引自:Akyürek et al., 2025, arXiv:2602.04872)
这个定理表明,交叉注意力不仅是一种工程直觉,它在理论上也是多模态信息融合的最优选择(至少在线性化设置下)。
双向交叉注意力
在实践中,通常采用双向交叉注意力——视觉看文本,文本也看视觉:
\[ \begin{aligned} \mathbf{V}' &= \mathbf{V} + \text{CrossAttn}(\mathbf{V}, \mathbf{T}) \\ \mathbf{T}' &= \mathbf{T} + \text{CrossAttn}(\mathbf{T}, \mathbf{V}) \end{aligned} \]
这种双向设计让两个模态的信息对称流动。ViLBERT(2019)是最早采用这种设计的模型之一。
计算复杂度分析
| 操作 | 复杂度 |
|---|---|
| 自注意力(视觉) | \(O(N_v^2 \cdot d)\) |
| 自注意力(文本) | \(O(N_t^2 \cdot d)\) |
| 交叉注意力(V→T) | \(O(N_v \cdot N_t \cdot d)\) |
| 拼接后自注意力 | \(O((N_v + N_t)^2 \cdot d)\) |
当 \(N_v \gg N_t\) 时(ViT 产生 196 个 token,文本只有 20 个),交叉注意力 \(O(N_v \cdot N_t \cdot d)\) 比拼接后自注意力 \(O((N_v + N_t)^2 \cdot d)\) 更高效。
PyTorch 实现:
1 | class CrossAttention(nn.Module): |
3. 三种融合的对比实验
回到第 0 节的图文检索任务。用一个简化实验来对比三种融合方式:
1 | import torch |
直观对比:
| 维度 | 早期融合 | 中期融合 | 晚期融合 |
|---|---|---|---|
| 交互深度 | 一阶/二阶 | 注意力级别 | 仅相似度 |
| 计算开销 | \(O(d_v \cdot d_t)\) | \(O(N_v \cdot N_t \cdot d)\) | \(O(d)\) |
| 预训练独立性 | ✗ 需要联合训练 | 部分独立 | ✓ 完全独立 |
| 细粒度对齐 | ✗ | ✓ token 级 | ✗ |
| 检索效率 | ✗ 需要联合编码 | ✗ | ✓ 可预计算 |
关键洞察:没有"最好的"融合策略,选择取决于任务需求。
- 检索任务(需要在百万数据库中快速搜索)→ 晚期融合(可预计算嵌入,线下建索引)
- VQA 任务(需要理解"图片左边的红色物体是什么")→ 中期融合(需要 token 级交互)
- 简单分类任务(已知模态组合固定)→ 早期融合(参数量最少)
4. 从融合框架到模型演进
上面介绍的三级融合框架,恰好对应了多模态大模型的三条演进路线:
1 | 晚期融合 ─────────> CLIP (2021) |
| 时间线 | 模型 | 融合策略 | 核心特点 |
|---|---|---|---|
| 2019 | ViLBERT / LXMERT | 中期(双流交叉注意力) | 首次将 BERT 扩展到视觉-语言 |
| 2021 | CLIP | 晚期(对比学习) | 4 亿图文对, 零样本迁移 |
| 2022 | Flamingo | 中期(Gated Cross-Attention) | 少样本学习, 冻结预训练模型 |
| 2023 | BLIP-2 | 中期(Q-Former) | 冻结 ViT + 冻结 LLM, 轻量桥接 |
| 2023 | LLaVA | 中期(MLP 投影) | 视觉指令微调, 简洁架构 |
| 2024 | SD3 / Flux (MMDiT) | 中期(Joint Attention) | 扩散模型中的多模态融合 |
| 2024 | Chameleon | 早期(统一 token) | 图像/文本共享 next-token-prediction |
| 2025 | InternVL 2.5 | 中期(ViT-MLP-LLM) | MMMU 70.1%, 开源 SOTA |
| 2025 | Qwen2.5-VL | 中期 + 早期混合 | 原生视频理解, 动态分辨率 |
值得特别一提的是 MMDiT(Multimodal Diffusion Transformer),它是 Stable Diffusion 3 和 Flux 的核心架构,将多模态融合直接嵌入了扩散模型的去噪过程中(详见本站第十四篇 SD3 架构解析和第十五篇 Flux 架构解析)。
MMDiT 的融合策略结合了双流和单流设计:
- Double Stream(SD3 / Flux 前 19 层):文本和图像各自拥有独立的 QKV 投影和 MLP,但在注意力计算时拼接 KV 做 Joint Attention——让文本 token 能关注图像 token,反之亦然
- Single Stream(Flux 后 38 层):文本和图像完全合并为一个序列,共享所有参数
这种"先分后合"的混合设计,正是中期融合在生成式 AI 中的典型落地形式:网络前期保留模态特性(各自的 MLP 处理不同的特征分布),后期充分融合(共享参数减少冗余)。
这个演进并非单向的——早期融合回归。最初人们用早期融合但效果差(因为缺乏大规模预训练);CLIP 证明了晚期融合 + 大数据的威力;BLIP-2/LLaVA 用中期融合实现了更细粒度的理解;而 Chameleon 又回到了早期融合——但这次有了 VQ-VAE 图像 token 化和海量数据,效果远超当年。
5. Tensor Fusion Network:一个完整的数学例子
为了展示如何从数学出发设计融合方法,这里详解 Tensor Fusion Network(TFN, Zadeh et al., 2017),它是理解更高阶交互的经典框架。
5.1 问题设定
三模态情感分析任务:给定视频中某一时刻的视觉表情 \(\mathbf{v} \in \mathbb{R}^{d_v}\)、语音语调 \(\mathbf{a} \in \mathbb{R}^{d_a}\)、文字内容 \(\mathbf{t} \in \mathbb{R}^{d_t}\),预测情感极性 \(y \in \{-1, +1\}\)。
5.2 张量融合
TFN 的核心是三阶张量积。首先,为每个模态增加一个常数维度(对应偏置项):
\[ \tilde{\mathbf{v}} = [\mathbf{v}; 1] \in \mathbb{R}^{d_v + 1}, \quad \tilde{\mathbf{a}} = [\mathbf{a}; 1] \in \mathbb{R}^{d_a + 1}, \quad \tilde{\mathbf{t}} = [\mathbf{t}; 1] \in \mathbb{R}^{d_t + 1} \]
张量融合计算三者的外积:
\[ \mathbf{Z} = \tilde{\mathbf{v}} \otimes \tilde{\mathbf{a}} \otimes \tilde{\mathbf{t}} \in \mathbb{R}^{(d_v+1) \times (d_a+1) \times (d_t+1)} \]
展开后,\(\mathbf{Z}\) 包含了所有可能的交互项:
\[ Z_{ijk} = \tilde{v}_i \cdot \tilde{a}_j \cdot \tilde{t}_k \]
由于加了常数 1,\(\mathbf{Z}\) 实际上包含:
- 一阶项:\(v_i, a_j, t_k\)(当其他两个取到常数 1 时)
- 二阶项:\(v_i a_j, v_i t_k, a_j t_k\)
- 三阶项:\(v_i a_j t_k\)
命题(TFN 的完备性):张量融合 \(\mathbf{Z} = \tilde{\mathbf{v}} \otimes \tilde{\mathbf{a}} \otimes \tilde{\mathbf{t}}\) 包含了三个模态之间所有一阶、二阶和三阶多项式交互项。
这个命题的证明是直接的:展开 \((d_v+1)(d_a+1)(d_t+1)\) 维的张量积,按常数 1 出现的位置分类即可得到上述三类项。
PyTorch 实现:
1 | class TensorFusionNetwork(nn.Module): |
注意当 \(d_v = d_a = d_t = 768\) 时,融合维度为 \((769)^3 \approx 4.5 \times 10^8\)——完全不可行。这就是为什么实际应用中需要低秩近似(Low-Rank Tensor Fusion, LMF)或直接转向注意力机制。
6. 总结与下一篇预告
| 概念 | 数学表达 | 适用场景 |
|---|---|---|
| 拼接融合 | \(\mathbf{z} = [\mathbf{v}; \mathbf{t}]\) | 简单分类, 维度小 |
| 双线性池化 | \(z_{ij} = v_i t_j\) | 二阶交互, 中等维度 |
| 紧凑双线性 | FFT + Count Sketch | 高维二阶交互 |
| 张量融合 | \(\mathbf{Z} = \tilde{\mathbf{v}} \otimes \tilde{\mathbf{a}} \otimes \tilde{\mathbf{t}}\) | 多模态高阶交互 |
| 晚期融合 | \(\text{sim} = \mathbf{v}^\top \mathbf{t} / \|\mathbf{v}\|\|\mathbf{t}\|\) | 检索, 零样本迁移 |
| 交叉注意力 | \(\text{softmax}(QK^\top / \sqrt{d_k})V\) | 细粒度对齐, VQA |
本文建立了多模态融合的基本数学框架。但一个关键问题还没回答:不同模态的编码器从哪里来? 如果两个编码器各自独立训练,它们输出的特征空间毫无对齐——"猫"的视觉特征和"cat"的文本特征可能方向完全不同。
下一篇我们将看到 CLIP 如何通过对比学习,在 4 亿图文对上训练出天然对齐的视觉和语言编码器,从而让晚期融合的余弦相似度真正有意义。
参考文献
- Tenenbaum, J. B., & Freeman, W. T. (2000). Separating style and content with bilinear models. Neural Computation.
- Gao, Y., et al. (2016). Compact Bilinear Pooling. CVPR 2016.
- Zadeh, A., et al. (2017). Tensor Fusion Network for Multimodal Sentiment Analysis. EMNLP 2017.
- Lu, J., et al. (2019). ViLBERT: Pretraining Task-Agnostic Visiolinguistic Representations. NeurIPS 2019.
- Akyürek, E., et al. (2025). Multi-layer Cross-Attention is Provably Optimal for Multi-modal In-context Learning. arXiv:2602.04872.