蒋蒋的学习笔记

Qwen2.5VL-Vision Encoder-2D_RoPE

位置编码参考视频

1D_RoPE参考视频

1D_RoPE参考博客

2D_RoPE参考博客

简单复习位置编码

简单复习RoPE

rope

RoPE扩展-2D

rust

pub fn rot_pos_emb(&self, grid_thw: &Tensor) -> Result<Tensor> {
        let mut pos_ids = Vec::new();
        for i in 0..grid_thw.dim(0)? {
            let [t, h, w] = grid_thw.i(i)?.to_vec1::<u32>()?[..] else {
                return Err(Error::Msg(format!("grid_thw Expected exactly 3 elements")));
            };
            // hpos_ids shape (h, w)
            let hpos_ids = Tensor::arange(0, h, grid_thw.device())?
                .unsqueeze(1)?
                .expand((h as usize, w as usize))?;
            let hpos_ids = hpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            let hpos_ids = hpos_ids.permute((0, 2, 1, 3))?.flatten_all()?;
            let wpos_ids = Tensor::arange(0, w, grid_thw.device())?
                .unsqueeze(0)?
                .expand((h as usize, w as usize))?;
            let wpos_ids = wpos_ids.reshape((
                h as usize / self.spatial_merge_size,
                self.spatial_merge_size,
                w as usize / self.spatial_merge_size,
                self.spatial_merge_size,
            ))?;
            let wpos_ids = wpos_ids.permute((0, 2, 1, 3))?.flatten_all()?;
            // thw_pos_ids shape (h*w, 2)
            let thw_pos_ids =
                Tensor::stack(&[&hpos_ids, &wpos_ids], D::Minus1)?.repeat((t as usize, 1))?;
            pos_ids.push(thw_pos_ids);
        }
        let pos_ids = Tensor::cat(&pos_ids, 0)?.contiguous()?;
        let max_grid_size = grid_thw.i((.., 1..))?.max_all()?.to_scalar::<u32>()?;
        let rotary_pos_emb_full = self
            .rotary_pos_emb
            .forward(max_grid_size as usize, grid_thw.device())?;

        // contiguous()一定要加!!!很重要!!!!,不然index_select出来的是错的
        // 找错找了半天,都是泪啊,做维度索引操作后contiguous顺手写上总没错
        // 第一列是h维度的索引
        let pos_ids_h = pos_ids.i((.., 0))?.contiguous()?;
        // 第二列是w维度的索引
        let pos_ids_w = pos_ids.i((.., 1))?.contiguous()?;
        let rotary_pos_emb_h = rotary_pos_emb_full.index_select(&pos_ids_h, 0)?;
        let rotary_pos_emb_w = rotary_pos_emb_full.index_select(&pos_ids_w, 0)?;
        // 每个patch融合h索引和w索引两个的位置编码信息
        let rotary_pos_emb = Tensor::cat(&[rotary_pos_emb_h, rotary_pos_emb_w], 1)?.contiguous()?;
        Ok(rotary_pos_emb)
    }