蒋蒋的学习笔记

Qwen2.5-VL Vision Encoder — PatchEmbed

import torch
import torch.nn as nn

# Numerical check that Qwen2.5-VL's patch-embedding Conv3d is equivalent to a
# plain matmul: because stride == kernel_size and there is no padding, every
# output position sees exactly one non-overlapping (C, T, P, P) patch, so the
# convolution collapses to `flattened_patches @ weight.reshape(E, -1).T`.
torch.manual_seed(0)  # make the printed numbers reproducible between runs

NUM_PATCHES = 1024
IN_CHANNELS = 3
TEMPORAL_PATCH_SIZE = 2
PATCH_SIZE = 14
EMBED_DIM = 1280
KERNEL = (TEMPORAL_PATCH_SIZE, PATCH_SIZE, PATCH_SIZE)

# One pre-cut patch per row: (N, C, T, P, P). Named `patches` rather than
# `input` so the Python builtin is not shadowed.
patches = torch.randn((NUM_PATCHES, IN_CHANNELS, *KERNEL))
conv_3d = nn.Conv3d(IN_CHANNELS, EMBED_DIM, KERNEL, KERNEL, bias=False)
print(f"conv_3d.weight.shape: {conv_3d.weight.shape}")

# Path 1: the convolution. Each patch yields an (E, 1, 1, 1) output; drop the
# three trailing singleton dims to get (N, E). Call the module itself rather
# than `.forward` so registered hooks run, as in normal usage.
out1 = conv_3d(patches)
out1 = out1.squeeze(-1).squeeze(-1).squeeze(-1)
print(f"out1.shape: {out1.shape}")
print(f"out1: {out1}")

# Path 2: the equivalent linear map. Flattening both the patch and the kernel
# in the same (C, T, P, P) order keeps their elements aligned, so one matmul
# reproduces the convolution: (N, C*T*P*P) @ (C*T*P*P, E) -> (N, E).
flat_patches = patches.reshape((NUM_PATCHES, -1))
flat_weight = conv_3d.weight.reshape((EMBED_DIM, -1))
out2 = flat_patches @ flat_weight.T
print(f"out2.shape: {out2.shape}")
print(f"out2: {out2}")

# The two paths should differ only by float32 summation-order rounding.
abs_diff = torch.abs(out1 - out2)
dif_sum = torch.sum(abs_diff).item()
dif_max = torch.max(abs_diff).item()
print(f"max dif: {dif_max}, sum dif: {dif_sum}")
assert torch.allclose(out1, out2, atol=1e-4, rtol=1e-4), "conv3d and matmul disagree"
Output:

conv_3d.weight.shape: torch.Size([1280, 3, 2, 14, 14])
out1.shape: torch.Size([1024, 1280])
out1: tensor([[-0.5432, -0.1391,  0.8107,  ..., -0.3918, -0.0495, -0.0942],
        [-0.3619, -0.1890, -0.1352,  ..., -0.3478, -0.1874,  0.8309],
        [ 0.2297,  0.2138, -0.6129,  ..., -0.0177, -0.2678,  0.0514],
        ...,
        [ 0.3529, -0.6613, -0.9834,  ..., -2.3169,  0.0557,  0.4819],
        [-0.3649, -1.4611,  0.3064,  ..., -0.3660,  0.8573, -0.6109],
        [ 0.3370,  0.4705, -0.3167,  ..., -0.8647,  0.0209, -0.3970]],
       grad_fn=<SqueezeBackward1>)
out2.shape: torch.Size([1024, 1280])
out2: tensor([[-0.5432, -0.1391,  0.8107,  ..., -0.3918, -0.0495, -0.0942],
        [-0.3619, -0.1890, -0.1352,  ..., -0.3478, -0.1874,  0.8309],
        [ 0.2297,  0.2138, -0.6129,  ..., -0.0177, -0.2678,  0.0514],
        ...,
        [ 0.3529, -0.6613, -0.9834,  ..., -2.3169,  0.0557,  0.4819],
        [-0.3649, -1.4611,  0.3064,  ..., -0.3660,  0.8573, -0.6109],
        [ 0.3370,  0.4705, -0.3167,  ..., -0.8647,  0.0209, -0.3970]],
       grad_fn=<MmBackward0>)
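max dif: 4.0531158447265625e-06, sum dif: 0.3553368151187897

The two paths agree up to float32 rounding error (maximum absolute difference ≈ 4e-6). This confirms that a Conv3d whose stride equals its kernel size is just a matrix multiplication over flattened patches, which is what the Rust version below relies on.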

Rust 代码

/// Patch-embedding layer of the Qwen2.5-VL vision encoder.
///
/// The checkpoint stores this projection as a bias-free `Conv3d` weight
/// (`visual.patch_embed.proj.weight`). Because that convolution's stride
/// equals its kernel size, applying it is equivalent to a single matrix
/// multiplication over flattened patches, so only the reshaped weight is kept.
#[derive(Debug, Clone)]
pub struct Qwen2_5VisionPatchEmbed {
    /// The Conv3d weight, flattened from `(embed_dim, C, T, P, P)` to
    /// `(embed_dim, C*T*P*P)` and then transposed. It is therefore stored as
    /// `(in_features, embed_dim)` and is ready to be right-multiplied.
    conv3d_weight: Tensor,
}

impl Qwen2_5VisionPatchEmbed {
    /// Loads the patch-embedding projection from `vb` (tensor `proj.weight`
    /// under the builder's current prefix).
    ///
    /// The tensor is requested with the Conv3d shape
    /// `(hidden_size, in_channels, temporal_patch_size, patch_size, patch_size)`,
    /// e.g. `(1280, 3, 2, 14, 14)` -> `(1280, 1176)` -> `(1176, 1280)`. It is
    /// flattened and transposed once here so that `forward` is a single matmul.
    ///
    /// # Errors
    /// Returns an error if the tensor cannot be obtained with that shape, or if
    /// the flatten/transpose fails.
    pub fn new(cfg: &Qwen2_5VLVisionConfig, vb: VarBuilder) -> Result<Self> {
        let patch_size = cfg.patch_size;
        let temporal_patch_size = cfg.temporal_patch_size;
        let in_channels = cfg.in_channels;
        let embed_dim = cfg.hidden_size;
        // Checkpoint entry, e.g. key `visual.patch_embed.proj.weight`,
        // value Tensor[dims 1280, 3, 2, 14, 14; bf16, cuda:0].
        let shape = (
            embed_dim,
            in_channels,
            temporal_patch_size,
            patch_size,
            patch_size,
        );
        // NOTE(review): the `Init::Const(1.)` hint presumably only matters for
        // builders that create missing variables rather than load them — confirm.
        let conv3d_weight = vb
            .get_with_hints(shape, "proj.weight", Init::Const(1.))?
            // (E, C, T, P, P) -> (E, C*T*P*P)
            .flatten(1, 4)?
            // (E, C*T*P*P) -> (C*T*P*P, E)
            .t()?;
        Ok(Self { conv3d_weight })
    }

    /// Projects pixel patches to the vision hidden size.
    ///
    /// `hidden_states` may already be flattened to
    /// `(num_patches, in_channels * temporal_patch_size * patch_size^2)`, or it
    /// may be any layout whose elements group into rows of that size (for
    /// example `(num_patches, C, T, P, P)`). Either way it is reshaped to
    /// `(num_patches, in_features)` first. It is also cast to the weight's
    /// dtype, so inputs of another dtype (e.g. f32 pixels with bf16 weights)
    /// are accepted.
    ///
    /// Returns a tensor of shape `(num_patches, hidden_size)`.
    ///
    /// # Errors
    /// Returns an error if the element count is not a multiple of
    /// `in_features`, or if the cast or the matmul fails.
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let (in_features, _embed_dim) = self.conv3d_weight.dims2()?;
        let hidden_states = hidden_states
            // `()` lets the number of patches be inferred from the element count.
            .reshape(((), in_features))?
            // matmul requires both operands to share a dtype.
            .to_dtype(self.conv3d_weight.dtype())?;
        // (N, in_features) @ (in_features, E) -> (N, E)
        hidden_states.matmul(&self.conv3d_weight)
    }
}