1.背景:
在DETR中backbone中,resnet50 的构建继承了backbonebase的类,backbonebase的前向过程如下,这里引入了NestedTensor类。
# 前向中输入的是NestedTensor这个类的实例,实质就是将图像张量与对应的mask封装到一起。 def forward(self, tensor_list: NestedTensor): xs = self.body(tensor_list.tensors) out: Dict[str, NestedTensor] = {} for name, x in xs.items(): m = tensor_list.mask assert m is not None # 将mask插值到与输出特征图一致 mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] out[name] = NestedTensor(x, mask) return out
NestedTensor,包括tensor和mask两个成员,tensor就是输入的图像。mask跟tensor同高宽但是单通道。
DETR把resnet作为backbone套到了另一个子网络里,这个子网络主要是把tensor list送进resnet网络,然后逐个提取出来其中的节点(也就是里面的Tensor),把每个节点的“mask”提出来做一次采样,然后再打包进自定义的“NestedTensor”中,按照“名称”:Tensor的方式存入输出的out。(这个NestedTensor一个Tensor里打包存了两个变量:x和mask)。
2. DETR网络下NestedTensor的前世今生示例:
2.1 输入
假如我们输入的是如下两张图片,也就说batch为2:
img1 = torch.rand(3, 200, 200), img2 = torch.rand(3, 200, 250)
x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
这里会转成nested_tensor, 为什么要转为nested_tensor呢?
这个nestd_tensor的类型简单说就是把{tensor, mask}打包在一起, tensor就是我们的图片的值,那么mask是什么呢?
当一个batch中的图片大小不一样的时候,我们要把它们处理的整齐,简单说就是把图片都padding成最大的尺寸,padding的方式就是补零,那么batch中的每一张图都有一个mask矩阵,所以mask大小为[2, 200,250], 在img有值的地方是1,补零的地方是0,tensor大小为[2,3,200,250]是经过padding后的。
2.2 提取特征
DETR 提取特征,是把NestedTensor中的tensor, 也就是图片输入到特征提取器中。这里使用的是残差网络resnet-50,tensor经过backbone后的结果就是[2,2048,7,8],下面是残差网络最后一层的结构
(2): Bottleneck( (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): FrozenBatchNorm2d() (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): FrozenBatchNorm2d() (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): FrozenBatchNorm2d() (relu): ReLU(inplace=True)
另外,关于NestedTensor中的mask, mask采用的方式F.interpolate,最后得到的结果是[2,7,8],backboneBase的前向过程如下:
class BackboneBase(nn.Module): def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): super().__init__() for name, parameter in backbone.named_parameters(): if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: parameter.requires_grad_(False) if return_interm_layers: return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} else: return_layers = {'layer4': "0"} self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.num_channels = num_channels def forward(self, tensor_list: NestedTensor): xs = self.body(tensor_list.tensors) out: Dict[str, NestedTensor] = {} for name, x in xs.items(): m = tensor_list.mask assert m is not None mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] out[name] = NestedTensor(x, mask) return out
Featuremap的位置编码,position_embedding 的前向如下:
利用三角函数的方式获取position_embedding,输入是上面的NestedTensor={tensor,mask}, 输出最终pos的size为[1,2,256,7,8]
def forward(self, tensor_list: NestedTensor): #tensor_list的类型是NestedTensor,内部自动附加了mask,用于表示动态shape,是pytorch中tensor新特性 x = tensor_list.tensors mask = tensor_list.mask assert mask is not None not_mask = ~mask #因为图像是2d的,所以位置编码也分为x,y方向 # 1 1 1 1 .. 2 2 2 2... 3 3 3... y_embed = not_mask.cumsum(1, dtype=torch.float32) # 1 2 3 4 ... 1 2 3 4... x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale #num_pos_feats = 128 ## 0~127 self.num_pos_feats=128,因为前面输入向量是256,编码是一半sin,一半cos dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) ## 输出shape=b,h,w,128 pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # 每个特征图的xy位置都编码成256的向量,其中前128是y方向编码,而128是x方向编码 return pos ## b,n=256,h,w
backbone+ position_embedding 中的NestedTensor流程如下:最终输出为
NestedTensor{tensor,mask},和pos。
tensor=[ 2, 2048,7,8],mask=[2,7,8], pos=[1,2,256,7,8]
class Joiner(nn.Sequential): def __init__(self, backbone, position_embedding): super().__init__(backbone, position_embedding) def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) out: List[NestedTensor] = [] pos = [] for name, x in xs.items(): out.append(x) # position encoding pos.append(self[1](x).to(x.tensors.dtype)) return out, pos
以上是DETR中关于NestedTensor输入在resnet50 backbone以及position_embedding中的历程。
3.COCO格式数据集下的NestedTensor的来龙去脉(一):
我们找到输入数据封装为NesteTensor类型的最初:
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
在Dataloder中涉及两个重要的参数,Sample()和collate_fn()。
3.1 Dataloder:数据预处理DataLoader及各参数详解
pytorch关于数据处理的功能模块均在torch.utils.data 中,pytorch输入数据PipeLine一般遵循一个“三步走”的策略,操作顺序是这样的:
① 继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。在实现自定义类时,一般需要对图像数据做增强处理,和标签处理,__getitem__返回图像和对应label,图像增强的方法可以使用pytorch自带的torchvision.transforms内模块,也可以使用自定义或者其他第三方增强库。
② 导入 DataLoader类,传入参数(上面自定义类的对象) 创建一个DataLoader对象。
③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练
dataset = MyDataset() # 第一步:构造Dataset对象 dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象 num_epoches = 100 for epoch in range(num_epoches):# 第三步:逐步迭代数据 for img, label in dataloader: # 训练代码
pytorch内部默认的数据处理类有如下:
class Dataset(object): class IterableDataset(Dataset): class TensorDataset(Dataset): # 封装成tensor的数据集,每一个样本都通过索引张量来获得。 class ConcatDataset(Dataset): # 连接不同的数据集以构成更大的新数据集 class Subset(Dataset): # 获取指定一个索引序列对应的子数据集 class ChainDataset(IterableDataset):
DataLoader类详解,数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代。
class DataLoader(object): Arguments: dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.三步走第一步创建的对象 batch_size (int, optional): 每一个batch加载多少个样本,即指定batch_size,默认是 1 shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False ------------------------------------------------------------------------------------ sampler (Sampler, optional): 自定义从数据集中抽取样本的策略,如果指定这个参数,那么shuffle必须为False batch_sampler (Sampler, optional): 此参数很少使用,与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥) ------------------------------------------------------------------------------------ num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0) collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数 pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.默认是False ------------------------------------------------------------------------------------ drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了,如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。 ------------------------------------------------------------------------------------
以上使用的是dataset, batch_size, shuffe, sampler, num_workers, collate_fn, pin_memory,这几个参数。
3.2 sampler参数
sampler参数其实就是一个“采样器”,表示从样本中究竟如何取样,pytorch采样器有如下几个
class Sampler(object): class SequentialSampler(Sampler):# 顺序采样样本,始终按照同一个顺序。 class RandomSampler(Sampler): # 无放回地随机采样样本元素。 class SubsetRandomSampler(Sampler): # 无放回地按照给定的索引列表采样样本元素 class WeightedRandomSampler(Sampler): # 按照给定的概率来采样样本。 class BatchSampler(Sampler): # 在一个batch中封装一个其他的采样器。 # torch.utils.data.distributed.DistributedSampler class DistributedSampler(Sampler): # 采样器可以约束数据加载进数据集的子集。
默认是采用的采样器如下:
if batch_sampler is None: # 没有手动传入batch_sampler参数时 if sampler is None: # 没有手动传入sampler参数时 if shuffle: sampler = RandomSampler(dataset) # 随机采样 else: sampler = SequentialSampler(dataset) # 顺序采样 batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler self.__initialized = True
3.3 collate_fn 参数
当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。
default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。
默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。
当我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。
例:目标检测时的自定义collate_fn(),给每个图像添加索引
def collate_fn(self, batch): paths, imgs, targets = list(zip(*batch)) # Remove empty placeholder targets # 有可能__getitem__返回的图像是None, 所以需要过滤掉 targets = [boxes for boxes in targets if boxes is not None] # Add sample index to targets # boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。 for i, boxes in enumerate(targets): boxes[:, 0] = i targets = torch.cat(targets, 0) # Selects new image size every tenth batch if self.multiscale and self.batch_count % 10 == 0: self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32)) # Resize images to input shape # 每个图像大小不同呢,所以resize到统一大小 imgs = torch.stack([resize(img, self.img_size) for img in imgs]) self.batch_count += 1 return paths, imgs, targets
其实也可结合使用默认的default_collate
from torch.utils.data.dataloader import default_collate # 导入这个函数 def collate_fn(batch): """ params: batch :是一个列表,列表的长度是 batch_size 列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y 大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)] returns: 整理之后的新的batch """ # 这一部分是对 batch 进行重新 “校对、整理”的代码 return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate 进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。
tip: 在使用pytorch时,当加载数据训练for i, batch in enumerate(train_loader):时,可能会出现TypeError: ‘NoneType’ object is not callable这个错误,若遇到更换pytorch版本即可.
4.COCO格式数据集下的NestedTensor的来龙去脉(二):
collate_fn 方法来重新组装一个batch的数据:
它的作用是将一个batch的数据重新组装为自定义的形式,输入参数batch就是原始的一个batch数据,通常在Pytorch中的Dataloader中,会将一个batch的数据组装为((data1, label1), (data2, label2), ...)这样的形式,于是第一行代码的作用就是将其变为[(data1, data2, data3, ...),(label1, label2, label3,...)]这样的形式,然后取出batch[0]即一个batch的图像输入到nested_tensor _from_tensor_list()方法中进行处理,最后将返回结果替代原始的这一个batch图像数据。
def collate_fn(batch): batch = list(zip(*batch)) batch[0] = nested_tensor_from_tensor_list(batch[0]) return tuple(batch)
为了能够统一batch中所有图像的尺寸,以便形成一个batch,我们需要得到其中的最大尺度(在所有维度上),然后对尺度较小的图像进行填充(padding),同时设置mask以指示哪些部分是padding得来的,以便后续模型能够在有效区域内去学习目标,相当于加入了一部分先验知识。
nested_tensor_from_tensor_list(tensor_list: List[Tensor])实现如下:
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # TODO make this more general #得到一个batch中所有图像张量每个维度的最大尺寸 if tensor_list[0].ndim == 3: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX # call _onnx_nested_tensor_from_tensor_list() instead return _onnx_nested_tensor_from_tensor_list(tensor_list) # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) #指示图像中哪些位置是padding部分 mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) #原始图像中有效部分设为false,以区分padding m[: img.shape[1], :img.shape[2]] = False else: raise ValueError('not supported') return NestedTensor(tensor, mask)
如何得到batch中每张图像在每个维度上的最大值。
def _max_by_axis(the_list): # type: (List[List[int]]) -> List[int] maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes
5. DETR网络结构一览:
The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
class DETR(nn.Module): """ This is the DETR module that performs object detection """ def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): """ Initializes the model. Parameters: backbone: torch module of the backbone to be used. See backbone.py transformer: torch module of the transformer architecture. See transformer.py num_classes: number of object classes num_queries: number of object queries, ie detection slot. This is the maximal number of objects DETR can detect in a single image. For COCO, we recommend 100 queries. aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. """ super().__init__() self.num_queries = num_queries self.transformer = transformer hidden_dim = transformer.d_model self.class_embed = nn.Linear(hidden_dim, num_classes + 1) self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) self.query_embed = nn.Embedding(num_queries, hidden_dim) self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) self.backbone = backbone self.aux_loss = aux_loss def forward(self, samples: NestedTensor): """ The forward expects a NestedTensor, which consists of: - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels It returns a dict with the following elements: - "pred_logits": the classification logits (including no-object) for all queries. Shape= [batch_size x num_queries x (num_classes + 1)] - "pred_boxes": The normalized boxes coordinates for all queries, represented as (center_x, center_y, height, width). These values are normalized in [0, 1], relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the unnormalized bounding box. - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of dictionnaries containing the two above keys for each decoder layer. """ if not isinstance(samples, NestedTensor): samples = nested_tensor_from_tensor_list(samples) features, pos = self.backbone(samples) # backbone是一个CNN用于特征提取 src, mask = features[-1].decompose() #?? assert mask is not None hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # 这里是吧features的其中一部分信息作为src传进Transformer,input_proj是一个卷积层,用来收缩输入的维度,把维度控制到d_model的尺寸(model dimension) outputs_class = self.class_embed(hs) # 为了把Transformer应用于目标检测问题上,作者引入了“类别嵌入网络”和“框嵌入网络” outputs_coord = self.bbox_embed(hs).sigmoid() # 在框嵌入后加入一层sigmoid输出框坐标(原论文中提到是四点坐标,但是要考虑到原图片的尺寸) out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} if self.aux_loss: out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) return out @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_coord): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. return [{'pred_logits': a, 'pred_boxes': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
6.NestedTensor 在transformer中的摸爬滚打:
transformer encoder
接上第2节,输入NestdTensor 经backbone和positionembedding变为[tensor,mask,pos]后的故事:目前我们拥有src=[ 2, 2048,7,8],mask=[2,7,8], pos=[1,2,256,7,8]
hs = transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # transformer 的输出是元组,分别为Decoder 和Encoder 的输出,因此这里取第一个代表的是Decoder的输出
input_proj:一个卷积层,卷积核为1*1,将压缩通道的作用,将2048压缩到256,所以传入transformer的维度是压缩后的[2,256,7,8]。
self.input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1) # input_proj是将CNN提取的特征维度映射到Transformer隐层的维度, src, mask = features[-1].decompose() #取backbone最后一层featuremap, 然后将特征图映射为序列形式
看DETR的前向过程:
# 前向输入是一个NestedTensor类的对象 def forward(self, samples, postprocessors=None, targets=None, criterion=None): # 首先将样本转换为NestedTensor 对象 if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) #第一部分如下所示,先利用CNN提取特征,然后将特征图映射为序列形式,最后输入Transformer进行编、解码得到输出结果。 # ********************************************************************************* # 输入到cnn提取特征 features, pos = self.backbone(samples) #todo list num = self.args.layer1_num src, mask = features[num].decompose() # 然后将特征图映射为序列形式 assert mask is not None # transformer 的输出是元组,分别为Decoder 和Encoder 的输出,因此这里取第一个代表的是Decoder的输出[0] hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[num])[0] # 将query_embedding 的权重作为参数输入到Transformer的前向过程,使用时与position encoding的方式相同,直接相加。 # 第二部分对输出的维度进行转化,与分类和回归任务所要求的相对应 # 生成分类与回归的预测结果 outputs_class = self.class_embed(hs) outputs_coord = self.lines_embed(hs).sigmoid() # 由于hs包含了Transformer中Decoder每层的输出,因此索引为-1 代表去最后一层的输出 out = {'pred_logits': outputs_class[-1], 'pred_lines': outputs_coord[-1]} # 若指定要计算Decoder 每层预测输出对应的loss,则记录对应的输出结果 if self.aux_loss: out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) return out 接着看Transformer: class Transformer(nn.Module): def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False, return_intermediate_dec=False): super().__init__() # encode # 单层 encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) encoder_norm = nn.LayerNorm(d_model) if normalize_before else None # 由6个单层组成整个encoder self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) #decode decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before) decoder_norm = nn.LayerNorm(d_model) self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec)
在进行encoder之前先还有个处理:
bs, c, h, w = src.shape# 这个和我们上面说的一样[2,256,7,8] src = src.flatten(2).permute(2, 0, 1) # src转为[56,2,256] pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# pos_embed 转为[56,2,256] mask = mask.flatten(1) #mask 转为[2,56]
encoder的输入为:src, mask, pos_embed
q = k = self.with_pos_embed(src, pos)# pos + src src2 = self.self_attn(q, k, value=src, key_padding_mask=mask)[0] #做self_attention,这个不懂的需要补一下transfomer的知识 src = src + self.dropout1(src2)# 类似于残差网络的加法 src = self.norm1(src)# norm,这个不是batchnorm,很简单不在详述 src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))#两个ffn src = src + self.dropout2(src2)# 同上残差加法 src = self.norm2(src)# norm return src
单层的输出依然为src[56, 2, 256],第二个单层的输入依然是:src, mask, pos_embed。循环往复6次结束encoder,得到输出memory, memory的size依然为[56, 2, 256].
Decoder的输入:
tgt = torch.zeros_like(query_embed) hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
query_embed其实是一个varible,size=[100,2,256],由训练得到,结束后就固定下来了。
class TransformerDecoder(nn.Module): def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm self.return_intermediate = return_intermediate def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None): # tgt 是query embedding,shape是(num_queries,b,hidden_dim) # query_pos 是对应tgt的位置编码,shanpe是和tgt一致 # memory是Encode的输出,shape是(h×w,b,hidden_dim) # memory_key_padding_mask 对应encoder的src_key_padding_mask,也是EncoderLayer的key_padding_mask,shape是(b,h×w) # pos对应输入到Encoder的位置编码,这里代表memory的位置编码,shape和memory一致。 output = tgt intermediate = [] # intermediate = []中记录的是每一层输出后的归一化结果,而每一层的输入是前一层输出(没有归一化)的结果 for layer in self.layers: output = layer(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos) if self.return_intermediate: intermediate.append(self.norm(output)) # self.norm 是通过初始化时传进来的参数norm(默认none)设置的,那么self.norm就有可能是none,故以下对此作了判断。 if self.norm is not None: output = self.norm(output) if self.return_intermediate: intermediate.pop() intermediate.append(output) if self.return_intermediate: return torch.stack(intermediate) return output.unsqueeze(0)