论文速读|ParGo: Bridging Vision-Language with Partial and Global Views.AAAI25
论文地址:https://arxiv.org/abs/2408.12928
代码地址:https://github.com/bytedance/ParGo
bib引用:
@misc{wang2025pargobridgingvisionlanguagepartial,
title={ParGo: Bridging Vision-Language with Partial and Global Views},
author={An-Lan Wang and Bin Shan and Wei Shi and Kun-Yu Lin and Xiang Fei and Guozhi Tang and Lei Liao and Jingqun Tang and Can Huang and Wei-Shi Zheng},
year={2025},
eprint={2408.12928},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2408.12928},
}
In Short
提出了一种新颖的 Partial-Global 投影器(ParGo),用于连接视觉和语言模态,以提升多模态大语言模型(MLLMs)的性能,并通过构建新数据集和大量实验验证了其有效性。
- 研究背景:在多模态大语言模型(MLLMs)中,视觉 - 语言投影器至关重要,但现有方法存在缺陷。
- 基于线性层的投影器难以控制提供给 LLM 的视觉标记数量,计算成本高;
- 基于全局注意力的投影器会使生成的标记集中在突出区域而忽略细节。
- ParGo 投影器
- 结构:主要包含 Partial-Global Perception 块和 Cascaded Partial Perception 块。
- 前者使用部分令牌和全局令牌分别提取图像的局部和全局信息,通过特定的交叉注意力掩码与图像特征交互;
- 后者通过掩码自注意力块实现不同局部令牌之间的交互,考虑图像局部区域间的关系。
- 优势:整合全局和局部视图,有效弥合预训练视觉编码器和 LLM 之间的表示差距,缓解对突出区域的过度关注,能更好地为 LLM 提供代表图像的视觉特征。
- 结构:主要包含 Partial-Global Perception 块和 Cascaded Partial Perception 块。
- ParGoCap-1M-PT 数据集
- 构建方法:从 Laion 数据集中随机选大量图像,用强大的闭源 MLLMs(如 GPT4-V 和 Gemini)根据特定提示生成详细描述图像的字幕,包括全局和局部信息。再通过计算图像与字幕相似度进行质量控制,过滤掉低质量数据。
- 作用:现有预训练数据集通常来自互联网,字幕短且缺乏局部细节描述,不利于模态对齐。ParGoCap-1M-PT 提供大量高质量详细字幕样本,有助于模型学习细粒度细节,促进视觉和语言特征空间的对齐。
- 实验
- 设置:在 MME、MMBench、SEED-Bench 和 MM-Vet 四个基准上进行实验,使用预训练的 EVA-02-CLIPL/14 作为视觉编码器、7B Vicuna 作为大语言模型,ParGo 投影器默认六层,部分和全局令牌分别为 288 和 16 个。采用两阶段训练,包括粗 - 细字幕预训练和监督微调,并使用多种现有数据和任务。
- 结果:与基于注意力和线性的投影器方法相比,ParGo 在多个基准测试中表现更优。如在 MMBench 和 MME 基准上分别比 Honeybee(D-Abstractor)有 2.9 和 77.6 的提升,比 InstructBLIP(使用 Q-former)在 MMBench 上高 37.7,比 LLaVA-1.5(使用线性或 MLP 投影器)在 MME 上提升 117.4。消融实验表明,ParGo 的各个组件、令牌数量设置和预训练数据选择均有效,且在需要细节感知能力的任务中优势明显,如在 OCR 和 CNT 任务上比 Q-Former 分别有 32.5 和 16.33 的提升。
- 结论:ParGo 投影器通过整合局部和全局视图,在连接视觉和语言模态方面表现出色,显著优于其他投影器,尤其在需要细节感知的任务中。新构建的数据集也有助于提升模型性能,为 MLLMs 对视觉内容的理解提供了更有效的方法。
摘要
This work presents ParGo, a novel Partial-Global projector designed to connect the vision and language modalities for Multimodal Large Language Models (MLLMs). Unlike previous works that rely on global attention-based projectors, our ParGo bridges the representation gap between the separately pre-trained vision encoders and the LLMs by integrating global and partial views, which alleviates the overemphasis on prominent regions. To facilitate the effective training of ParGo, we collect a large-scale detail-captioned imagetext dataset named ParGoCap-1M-PT, consisting of 1 million images paired with high-quality captions. Extensive experiments on several MLLM benchmarks demonstrate the effectiveness of our ParGo, highlighting its superiority in aligning vision and language modalities. Compared to conventional Q-Former projector, our ParGo achieves an improvement of 259.96 in MME benchmark. Furthermore, our experiments reveal that ParGo significantly outperforms other projectors, particularly in tasks that emphasize detail perception ability.
这项工作介绍了 ParGo,这是一种新颖的 Partial-Global projector,旨在连接多模态大型语言模型 (MLLM) 的视觉和语言模态。与以前依赖基于全局注意力的projector 的工作不同,我们的 ParGo 通过集成全局视图和部分视图来弥合单独预先训练的视觉编码器和 LLM 之间的表示差距,从而缓解了对突出区域的过度强调。为了促进 ParGo 的有效训练,我们收集了一个名为 ParGoCap-1M-PT 的大规模细节字幕图像文本数据集,该数据集由 100 万张图像和高质量字幕配对组成。对几个 MLLM 基准的广泛实验证明了我们的 ParGo 的有效性,突出了它在调整视觉和语言模式方面的优势。与传统的 Q-Former 投影仪相比,我们的 ParGo 在 MME 基准测试中提高了 259.96。此外,我们的实验表明,ParGo 的性能明显优于其他projectors,特别是在强调细节感知能力的任务中。
Introduction
最近多模态大语言在各种任务(e.g.VQA)应用广泛。vision language projector是经常用到的component之一。因为在弥合模态差距之间的关键作用,所以近年来被广泛研究。
【研究现状】
- 两种图像特征投影方法的问题。
- 直接用 MLP(多层感知机)作为投影器,难以控制提供给语言模型的视觉标记数量,例如在处理细粒度特征时会遇到困难,这导致了高计算成本。
- 基于全局注意力的投影器,使用注意力操作将图像特征全局投影到固定数量的视觉标记上。然而,这些基于全局投影的投影器会使生成的标记集中在突出区域,而忽略更精细的细节。(例如,以图 1 中的图像为例,以前的方法往往会聚焦在堡垒上,很容易忽略顶部的两个人。)
所以:①本文的目标是构建一个视觉语言投影仪,它可以为 LLM 提供更好地表示图像的视觉特征,同时使用固定数量的视觉标记。其灵感来自于这样一种观察,即图像可以用两种信息来正确描述,即全局信息呈现对图像的整体理解,而多个部分信息强调微妙的细节,一个例子如图 1 所示。
所以:②提出了一种基于部分全局注意力机制的新型 Partial-Global 投影仪 (ParGo)。通过集成全局视图和部分视图,我们的 ParGo 有效地弥合了单独预先训练的视觉编码器和 LLM 之间的表示差距,从而减轻了对突出区域的过度强调。此外,考虑到图像中不同部分区域之间的关系,ParGo 包含一个级联部分感知块,可实现图像不同部分区域之间的交互。
最后:③为了促进 ParGo 的有效训练,我们收集了一个大规模的细节标题图像文本数据集命名为 ParGoCap-1M-PT 用于预训练。大多数现有的预训练数据集(通常来自互联网)包含的字幕通常很短,强调突出的视觉特征,而缺乏对部分区域的详细描述。在此类数据集上进行训练使模型难以学习精细细节。相比之下,我们的 ParGoCap-1M-PT 包含图像中多个区域的更长、更详细的描述。在这两种标题数据上进行预训练后,我们使用几个公开可用的指令调优数据集将我们的模型转移到多个下游任务中,例如 LLaVA-150k (Liu et al. 2023a)。
相关工作
Vision-language Projector【线性projector、attention based projector】
Vision-language projectors play a crucial role and are widely used components in MLLM. They aim to connect the visual feature space and language feature space, which can be divided into linear-based and attention-based projectors. Linear-based projectors (Liu et al. 2023b,a; Zhu et al. 2023; Chen et al. 2023b; Dong et al. 2024) employ a linear layer to connect the vision encoder seamlessly with the language model (LLM). Despite their straightforward implementation, the linear-based projectors encounter challenges in producing a large number of visual tokens to LLMs, leading to high computational costs. Another line of research (Alayrac et al. 2022; Li et al. 2023b; Bai et al. 2023; Dai et al. 2023; Ye et al. 2023b) explore more flexible projectors (e.g., Qformer (Li et al. 2022) and Perceiver Resampler (Alayrac et al. 2022)) based on attention mechanism. Such attentionbased methods often extract prominent image features, leading to a loss of detail and a drop in the model performance. Similar findings are also mentioned in a recent work, Honeybee (Cha et al. 2023), which proposes a D-Abstractor that uses a Deformable attention (Zhu et al. 2020) to retain the local information and achieve superior performance. To efficiently provide comprehensive information to LLMs using a fixed number of visual tokens, we propose Partial-Global projector, which uses a partial-global projection that simultaneously extracts both partial and global information.
视觉语言投影仪起着至关重要的作用,是 MLLM 中广泛使用的组件。他们的目标是连接视觉特征空间和语言特征空间,从而可以 分为线性投影仪和注意力投影仪。基于线性的投影仪(Liu 等人,2023b,a;Zhu 等人,2023 年;Chen 等人,2023b;Dong et al. 2024) 采用线性层将视觉编码器与语言模型 (LLM) 无缝连接。尽管线性投影仪的实现很简单,但在为 LLM 生成大量视觉标记时遇到了挑战,从而导致高计算成本。另一条研究线(Alayrac 等人,2022 年;Li 等人,2023b;Bai 等人,2023 年;Dai 等人,2023 年;Ye et al. 2023b) 探索基于注意力机制的更灵活的投影仪(例如,Qformer (Li et al. 2022) 和 Perceiver Resampler (Alayrac et al. 2022))。这种基于注意力的方法通常会提取突出的图像特征,从而导致细节丢失和模型性能下降。最近的一项工作 Honeybee (Cha et al. 2023) 中也提到了类似的发现,该研究提出了一种 D-Abstractor,它使用可变形注意力(Zhu et al. 2020)来保留局部信息并实现卓越的性能。为了使用固定数量的视觉标记有效地向 LLM 提供全面信息,我们提出了 Partial-Global projector,它使用部分-全局投影同时提取部分和全局信息。
Multi-modal Pre-training Data【很多用LLM生成数据构成新的合成数据集】
The recent remarkable progress achieved by close-sourced MLLMs has led recent researchers (Chen et al. 2023b; Yu et al. 2024a) to consider using MLLM to synthesize detailcaptioned data, supplementing the limitations of conventional web-crawled datasets. In this work, we further contribute a detail-captioned dataset for pre-training, aimed at enhancing the alignment between the two modalities from a data perspective.
使用网络爬取的大规模图像文本数据集(例如,(Schuhmann 等人,2021 年;Byeon 等人,2022 年;Changpinyo 等人,2021 年;Sharma 等人,2018 年)) 已成为 MLLM 最常见的策略。然而,网络爬虫数据集主要使用嘈杂和简短的标题来呈现图像的主要特征,缺乏详细的描述。为了获得详细的描述,一些作品(Wang et al. 2023, 2025)提供了方框(或蒙版)级别的标题,但受到方框生成(接地)模型的限制(Fang et al. 2024, 2025;Tang 等人,2022a,b)。近源 MLLM 最近取得的显着进展引领了最近的研究人员(Chen 等人,2023b;Yu et al. 2024a) 考虑使用 MLLM 来合成详细标题数据,补充传统网络爬虫数据集的局限性。在这项工作中,我们进一步贡献了一个用于预训练的细节说明数据集,旨在从数据角度增强两种模式之间的一致性。
Notes:这个过程中会涉及生成数据风格一致性这种问题,还有LLM典型的视觉幻觉问题【也有人专门针对视觉幻觉做了新的工作】
图 2. (a).MLLM 的管道,以我们提议的 ParGo 作为视觉语言投影仪。首先,我们使用冻结图像编码器来提取图像特征。为了更好地将预先训练的视觉编码器与 LLM 对齐,我们提出了一个 PartialGlobal 投影仪,使用两种标记 i . e . i.e. i.e. 来投影图像特征,即部分和全局标记。最后,输出的部分和全局视觉标记以及标记化的文本被馈送到 LLM 中,以自动回归的方式生成文本输出。具体来说,每个 Partial-Global 投影仪层都包含一个 Partial-Global Perception 块,该块利用两种令牌来提取图像特征。此外,为了充分考虑图像中不同部分区域之间的关系,我们合并了一个级联部分感知块,以实现部分标记之间以级联方式进行交互。(b).部分全局和级联部分注意力掩码的演示。值得注意的是,Partial-Global Attention 掩码在不同层中保持不变,而 Cascaded Partial Attention 掩码在各个层中保持不变。
结论
In this work, we focus on the vision-language projector in MLLMs, proposing Partial-Global projector (ParGo). ParGo employs partial and global tokens with specially designed attention masks to extract two kinds of information separately, with considering the relation between different partial regions in an image. Moreover, to further facilitate the alignment between the two modalities, we contribute a large-scale detail-captioned dataset ParGoCap-1M-PT for pre-training. Extensive ablations and experiments are conducted, which illustrate the effectiveness of our ParGo. We find that ParGo significantly outperforms other projectors, particularly in tasks that emphasize detail perception. These results highlight ParGo’s potential to enhance MLLMs by providing a more nuanced understanding of visual content through the integration of both partial and global views.
在这项工作中,我们专注于 MLLM 中的视觉语言投影仪,提出了 Partial-Global 投影仪 (ParGo)。ParGo 采用部分和全局标记以及专门设计的注意力掩码,分别提取两种信息,同时考虑图像中不同部分区域之间的关系。此外,为了进一步促进两种模式之间的对齐,我们提供了一个大规模的细节标题数据集 ParGoCap-1M-PT 用于预训练。进行了广泛的消融和实验,这说明了我们的 ParGo 的有效性。我们发现 ParGo 的性能明显优于其他投影仪,尤其是在强调细节感知的任务中。这些结果凸显了 ParGo 通过集成部分视图和全局视图来提供更细致入微的视觉内容理解来增强 MLLM 的潜力。
代码中值得学习的部分👍👍👍
1. get_extended_attention_mask
- https://github.com/bytedance/ParGo/blob/main/pargo/backbone/language/qformer_bert.py
BertModel
中,get_extended_attention_mask
方法用于生成扩展的注意力掩码,以控制模型在注意力计算时忽略某些令牌。
def get_extended_attention_mask(
self,
attention_mask: Tensor,
input_shape: Tuple[int],
device: device,
is_decoder: bool,
has_query: bool = False,
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = (
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
<= seq_ids[None, :, None]
)
# add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
if has_query: # UniLM style attention mask
causal_mask = torch.cat(
[
torch.zeros(
(batch_size, prefix_seq_len, seq_length),
device=device,
dtype=causal_mask.dtype,
),
causal_mask,
],
axis=1,
)
causal_mask = torch.cat(
[
torch.ones(
(batch_size, causal_mask.shape[1], prefix_seq_len),
device=device,
dtype=causal_mask.dtype,
),
causal_mask,
],
axis=-1,
)
extended_attention_mask = (
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(
dtype=self.dtype
) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
这个方法根据输入的 attention_mask
以及模型是否为解码器等条件,生成不同的扩展注意力掩码,用于后续的注意力计算。
2. get_global_local_attention
BertModel
中,get_global_local_attention
生成自注意力掩码和交叉注意力掩码,论文中提到的"通过特定的交叉注意力掩码与图像特征交互"。
def get_global_local_attention(self, batch_size=3, total_query=160, total_image_embeds=1024, local_visible_query_num=128, device=None):
assert (total_image_embeds % local_visible_query_num == 0), "invalid globle query ratio"
global_visible_query_num = total_query - local_visible_query_num
self_attention_mask = torch.ones((total_query, total_query))
self_attention_mask[:local_visible_query_num,:local_visible_query_num] = 0
self_attention_mask[-global_visible_query_num:,-global_visible_query_num:] = 0
self_attention_mask_ret = self_attention_mask.masked_fill(self_attention_mask==1, float("-inf")).to(device)
attention_mask = torch.zeros((total_query, total_image_embeds)) # inititalize the attention_mask
index = torch.arange(0,total_image_embeds).reshape(local_visible_query_num, int(total_image_embeds / local_visible_query_num))
attention_mask = attention_mask.scatter_(dim=1, index=index, value=1) # local attention mask
cross_attention_mask_ret = torch.zeros((total_query, total_image_embeds))
cross_attention_mask_ret[:local_visible_query_num] = \
cross_attention_mask_ret[:local_visible_query_num].masked_fill(attention_mask[:local_visible_query_num]==0, float("-inf")) # fill -inf
cross_attention_mask_ret = cross_attention_mask_ret.unsqueeze(0).unsqueeze(0).repeat(batch_size,1,1,1).to(device) # (2,1,128,1024)
return self_attention_mask_ret, cross_attention_mask_ret
通过生成特定的掩码,控制后面部分令牌和全局令牌与图像特征之间的注意力交互
3. forward
(BertModel)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
query_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
):
# ... ...
if is_decoder:
extended_attention_mask = self.get_extended_attention_mask(
attention_mask,
input_ids.shape,
device,
is_decoder,
has_query=(query_embeds is not None),
)
else:
extended_attention_mask = self.get_extended_attention_mask(
attention_mask, input_shape, device, is_decoder
) # (3,1,1,128)
# ......
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
0
].size()
else:
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [
self.invert_attention_mask(mask) for mask in encoder_attention_mask
]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else:
local_visible_query_num=self.config.local_query_length[0]
if local_visible_query_num == 0:
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask)
else:
extended_attention_mask, encoder_extended_attention_mask = self.get_global_local_attention(
batch_size = query_embeds.shape[0],
total_query = query_embeds.shape[1],
total_image_embeds = encoder_hidden_states.shape[1],
local_visible_query_num=local_visible_query_num,
device = device
)
else:
encoder_extended_attention_mask = None
# ......
encoder_outputs = self.encoder(
embedding_output, # query_embedding (3,128,1024)
attention_mask=extended_attention_mask, # (3,128)
head_mask=head_mask, # [None, None, None, None]
encoder_hidden_states=encoder_hidden_states, # image embedding (3,1024,1408)
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
query_length=query_length,
)
# ... ...
这里根据不同条件生成相应的注意力掩码,并将其传递给编码器进行计算【Cascaded Partial Perception 块中通过掩码自注意力块实现不同局部令牌之间的交互】
4. 一个简化的代码整合
Aim at:结合 全局-局部注意力机制 和 多模态对齐
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
class GlobalLocalMultimodalAttention(nn.Module):
def __init__(self, text_encoder_type="bert-base-uncased", embed_dim=512, num_heads=8, local_query_length=64):
super(GlobalLocalMultimodalAttention, self).__init__()
# 文本编码器 (BERT)
self.text_encoder = BertModel.from_pretrained(text_encoder_type)
self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, embed_dim) # 映射到嵌入空间
# 图像编码器 (简单的卷积网络用于示例)
self.image_encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, embed_dim, kernel_size=3, stride=1, padding=1),
nn.ReLU()
)
self.flatten = nn.Flatten(start_dim=2)
# 全局注意力
self.global_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
# 局部注意力
self.local_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
# 超参数
self.local_query_length = local_query_length
def forward(self, input_ids, attention_mask, images):
"""
Params:
- input_ids: 文本的 token ids (batch_size, seq_len)
- attention_mask: 文本的注意力掩码 (batch_size, seq_len)
- images: 图像输入 (batch_size, 3, height, width)
"""
batch_size, _, height, width = images.size()
# 1. 文本嵌入
text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
text_embedding = self.text_projection(text_output.last_hidden_state) # (batch_size, seq_len, embed_dim)
# 2. 图像嵌入
image_features = self.image_encoder(images) # (batch_size, embed_dim, height//2, width//2)
image_features = self.flatten(image_features) # (batch_size, embed_dim, num_patches)
image_embedding = image_features.permute(0, 2, 1) # (batch_size, num_patches, embed_dim)
# 全局注意力
text_global_embedding, _ = self.global_attention(
query=text_embedding.permute(1, 0, 2), # (seq_len, batch_size, embed_dim)
key=image_embedding.permute(1, 0, 2), # (num_patches, batch_size, embed_dim)
value=image_embedding.permute(1, 0, 2) # (num_patches, batch_size, embed_dim)
)
text_global_embedding = text_global_embedding.permute(1, 0, 2) # (batch_size, seq_len, embed_dim)
# 局部注意力
# 获取局部的查询和图像块
local_query = text_embedding[:, :self.local_query_length, :] # (batch_size, local_query_length, embed_dim)
local_image = image_embedding[:, :self.local_query_length, :] # (batch_size, local_query_length, embed_dim)
text_local_embedding, _ = self.local_attention(
query=local_query.permute(1, 0, 2), # (local_query_length, batch_size, embed_dim)
key=local_image.permute(1, 0, 2), # (local_query_length, batch_size, embed_dim)
value=local_image.permute(1, 0, 2) # (local_query_length, batch_size, embed_dim)
)
text_local_embedding = text_local_embedding.permute(1, 0, 2) # (batch_size, local_query_length, embed_dim)
# 全局和局部融合
fused_embedding = torch.cat([text_global_embedding, text_local_embedding], dim=1) # (batch_size, seq_len + local_query_length, embed_dim)
return fused_embedding
代码解析:
- 文本处理:使用预训练的 BERT 提取文本嵌入。【为了匹配图像嵌入的维度,使用 self.text_projection 将文本嵌入投影到相同的嵌入空间。】
- 图像处理:使用一个简单的卷积网络对图像进行特征提取。【将图像划分为 局部块(patches),并将其展平为二维表示。】
- 全局注意力:在整个文本和图像之间计算全局注意力。【text_embedding 作为查询,image_embedding 作为键和值。】
- 局部注意力:只选择文本和图像的部分片段(局部查询)参与注意力计算。【e.g.,local_query_length 决定了文本和图像中参与局部注意力计算的片段长度。】
- 全局和局部融合:将全局注意力和局部注意力的输出拼接在一起,形成最终的融合嵌入。
原代码更注重动态调整注意力范围(通过 query_embeds 和 local_visible_query_num 控制)。
上面的例子直接通过全局注意力和局部注意力的划分显式处理。
原文地址:https://blog.csdn.net/Romaga/article/details/145231753
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!