回顾一下卷积计算:当卷积核大小与步长相同、且每个输入样本恰好只有一个卷积核那么大时,卷积等价于把输入和卷积核分别展平后做一次矩阵乘法,下面用代码验证这一点。
图片来源:飞桨深度学习基础篇卷积算子
import torch
import torch.nn as nn
# Demo: a Conv3d whose kernel size equals its stride, applied to samples that
# are exactly one kernel in size, produces the same result as a single matrix
# multiplication of the flattened input with the flattened kernel.
patches = torch.randn((1024, 3, 2, 14, 14))
conv_3d = nn.Conv3d(3, 1280, (2, 14, 14), (2, 14, 14), bias=False)
print(f"conv_3d.weight.shape: {conv_3d.weight.shape}")

# Path 1: the real convolution. Call the module itself instead of `.forward()`
# so that any registered forward hooks are honoured.
out1 = conv_3d(patches)
# (1024, 1280, 1, 1, 1) -> (1024, 1280): drop the three singleton spatial dims.
out1 = out1.squeeze(-1).squeeze(-1).squeeze(-1)
print(f"out1.shape: {out1.shape}")
print(f"out1: {out1}")

# Path 2: flatten each sample and each filter to a row vector, then matmul.
# (1024, 3, 2, 14, 14) -> (1024, 1176) and (1280, 3, 2, 14, 14) -> (1280, 1176).
patches_2d = patches.reshape((1024, -1))
weight_2d = conv_3d.weight.reshape((1280, -1))
out2 = patches_2d @ weight_2d.T
print(f"out2.shape: {out2.shape}")
print(f"out2: {out2}")

# The two paths should agree up to floating-point rounding.
abs_diff = torch.abs(out1 - out2)
dif_sum = torch.sum(abs_diff).item()
dif_max = torch.max(abs_diff).item()
print(f"max dif: {dif_max}, sum dif: {dif_sum}")
conv_3d.weight.shape: torch.Size([1280, 3, 2, 14, 14])
out1.shape: torch.Size([1024, 1280])
out1: tensor([[-0.5432, -0.1391, 0.8107, ..., -0.3918, -0.0495, -0.0942],
[-0.3619, -0.1890, -0.1352, ..., -0.3478, -0.1874, 0.8309],
[ 0.2297, 0.2138, -0.6129, ..., -0.0177, -0.2678, 0.0514],
...,
[ 0.3529, -0.6613, -0.9834, ..., -2.3169, 0.0557, 0.4819],
[-0.3649, -1.4611, 0.3064, ..., -0.3660, 0.8573, -0.6109],
[ 0.3370, 0.4705, -0.3167, ..., -0.8647, 0.0209, -0.3970]],
grad_fn=<SqueezeBackward1>)
out2.shape: torch.Size([1024, 1280])
out2: tensor([[-0.5432, -0.1391, 0.8107, ..., -0.3918, -0.0495, -0.0942],
[-0.3619, -0.1890, -0.1352, ..., -0.3478, -0.1874, 0.8309],
[ 0.2297, 0.2138, -0.6129, ..., -0.0177, -0.2678, 0.0514],
...,
[ 0.3529, -0.6613, -0.9834, ..., -2.3169, 0.0557, 0.4819],
[-0.3649, -1.4611, 0.3064, ..., -0.3660, 0.8573, -0.6109],
[ 0.3370, 0.4705, -0.3167, ..., -0.8647, 0.0209, -0.3970]],
grad_fn=<MmBackward0>)
max dif: 4.0531158447265625e-06, sum dif: 0.3553368151187897
/// Patch-embedding layer of the Qwen2.5 vision encoder.
///
/// The checkpoint stores this projection as a Conv3d kernel, but the
/// constructor flattens and transposes it so that `forward` can apply it as a
/// plain matrix multiplication.
pub struct Qwen2_5VisionPatchEmbed {
/// Conv3d kernel flattened over its last four dimensions and transposed,
/// i.e. `(in_channels * temporal_patch_size * patch_size * patch_size, hidden_size)`.
conv3d_weight: Tensor,
}
impl Qwen2_5VisionPatchEmbed {
    /// Loads the patch-embedding projection and stores it in matmul-ready form.
    ///
    /// The checkpoint tensor `proj.weight` is a Conv3d kernel shaped
    /// `(hidden_size, in_channels, temporal_patch_size, patch_size, patch_size)`.
    /// It is flattened over the last four dimensions and transposed, so the
    /// stored weight is
    /// `(in_channels * temporal_patch_size * patch_size * patch_size, hidden_size)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from fetching the tensor from `vb` or from the
    /// flatten/transpose operations.
    pub fn new(cfg: &Qwen2_5VLVisionConfig, vb: VarBuilder) -> Result<Self> {
        // Conv3d kernel layout: (out_channels, in_channels, depth, height, width).
        let kernel_shape = (
            cfg.hidden_size,
            cfg.in_channels,
            cfg.temporal_patch_size,
            cfg.patch_size,
            cfg.patch_size,
        );
        // e.g. (1280, 3, 2, 14, 14) -> (1280, 1176) -> (1176, 1280)
        let conv3d_weight = vb
            .get_with_hints(kernel_shape, "proj.weight", Init::Const(1.))?
            .flatten(1, 4)?
            .t()?;
        Ok(Self { conv3d_weight })
    }

    /// Projects flattened patches into the embedding space.
    ///
    /// `hidden_states` is expected to be shaped
    /// `(grid_t * grid_h * grid_w, in_channels * temporal_patch_size * patch_size * patch_size)`;
    /// the result has shape `(grid_t * grid_h * grid_w, hidden_size)`.
    ///
    /// The input is cast to the weight's dtype before the multiplication, so
    /// that e.g. f32 pixel patches can be fed to bf16 weights instead of
    /// failing with a dtype mismatch. The output therefore has the weight's dtype.
    ///
    /// # Errors
    ///
    /// Propagates any error from the dtype conversion or the matrix
    /// multiplication (for example, an inner-dimension mismatch).
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Converting to the same dtype leaves the values unchanged, so
        // already-matching inputs behave exactly as before.
        let patches = hidden_states.to_dtype(self.conv3d_weight.dtype())?;
        // (N, 1176) x (1176, 1280) -> (N, 1280) for the dimensions noted above.
        let embedded = patches.matmul(&self.conv3d_weight)?;
        Ok(embedded)
    }
}