蒋蒋的学习笔记

插值

什么是插值

一种通过已知的、离散的数据点来估算未知位置数据点的过程或方法。

应用

aha库中插值的使用

Deepseek-OCR

Paddle-OCR / Hunyuan-OCR

RMBG2.0

理解插值

问题

假设我们有一个一维张量,包含10个数据点(坐标 0~9)。

目标:通过插值得到5个数据点

坐标映射

插值的本质是回答一个问题:“输出的第 i 个数据,对应输入数据中的哪个位置?”

* 中心对齐:align_corners=false
```text
输出(5个点):
□────□────□────□────□
0    1    2    3    4    --坐标
0.5  1.5  2.5  3.5  4.5  --中心点
总长度 = 5.0

缩放比例 scale = 10/5 = 2.0
输出点在输入中的位置:i_in = (i_out + 0.5) * scale - 0.5
```

* 角点对齐:align_corners=true    
```text
把数据点看作"点",只关心数据中心点的位置
输入(10个点):
●────●────●────●────●────●────●────●────●────●
0    1    2    3    4    5    6    7    8    9
总长度 = 9.0

输出(5个点):
●────●────●────●────●
0    1    2    3    4
总长度 = 4.0

缩放比例 = 9/4 = 2.25
输出点在输入中的位置:i_in = i_out * scale
  1. 计算缩放因子
    /// Returns the ratio that maps an output index onto the input axis.
    ///
    /// * `align_corners == true` and `output_size > 1`: the first and last
    ///   samples of input and output coincide, so the ratio is
    ///   `(input_size - 1) / (output_size - 1)`.
    /// * otherwise: the plain size ratio `input_size / output_size`.
    ///
    /// `output_size` is expected to be non-zero; with `output_size == 0` the
    /// float division yields `inf` or `NaN` rather than an error.
    fn compute_scale(input_size: usize, output_size: usize, align_corners: bool) -> f32 {
        if align_corners && output_size > 1 {
            // `saturating_sub` keeps `input_size == 0` from underflowing
            // (a panic in debug builds, a huge wrapped value in release).
            input_size.saturating_sub(1) as f32 / (output_size - 1) as f32
        } else {
            input_size as f32 / output_size as f32
        }
    }
    
  2. 计算输出点在输入中的位置
    /// Computes, for every output index, its (fractional) position on the
    /// input axis.
    ///
    /// * `align_corner == Some(true)`: `i_in = i_out * scale`.
    /// * `align_corner == Some(false)` or `None`:
    ///   `i_in = (i_out + 0.5) * scale - 0.5`.
    ///
    /// `scale` comes from `compute_scale`. Every returned coordinate is
    /// clamped to `[0, input_size - 1]`, so callers can take `floor()` of it
    /// and use the result as an index. When `input_size == 1` all
    /// coordinates are `0.0`.
    ///
    /// # Errors
    ///
    /// Returns an error when `input_size` or `output_size` is zero.
    pub fn compute_1d_coords(
        input_size: usize,
        output_size: usize,
        align_corner: Option<bool>,
    ) -> Result<Vec<f32>> {
        if input_size == 0 {
            return Err(anyhow!("input_size must be > 0"));
        }
        if output_size == 0 {
            return Err(anyhow!("output_size must be > 0"));
        }
        if input_size == 1 {
            // Only one source sample exists; every output maps onto it.
            return Ok(vec![0f32; output_size]);
        }
        let align_corners = align_corner.unwrap_or(false);
        let scale = compute_scale(input_size, output_size, align_corners);
        let last = (input_size - 1) as f32;
        let coords: Vec<f32> = (0..output_size)
            .map(|i| {
                let coord = if align_corners {
                    i as f32 * scale
                } else {
                    (i as f32 + 0.5) * scale - 0.5
                };
                // Clamp in both modes: the half-pixel formula leaves the valid
                // range at the borders, and `i * scale` can overshoot `last`
                // through float rounding.
                coord.clamp(0.0, last)
            })
            .collect();
        Ok(coords)
    }
    

最近邻插值

// 2d: (bs, c, h, w) -> (bs*c, h, w) for c in 0..dim0 { for i in 0..target_h { // 计算高度方向的最近邻坐标 let coord_h = if target_h == 1 { (orig_h - 1) as f32 / 2.0 } else { (i as f32 + 0.5) * (orig_h as f32 / target_h as f32) - 0.5 }; let nearest_h = coord_h.round() as usize; let clamped_h = nearest_h.clamp(0, orig_h - 1);

    for j in 0..target_w {
        // 计算宽度方向的最近邻坐标
        let coord_w = if target_w == 1 {
            (orig_w - 1) as f32 / 2.0
        } else {
            (j as f32 + 0.5) * (orig_w as f32 / target_w as f32) - 0.5
        };
        let nearest_w = coord_w.round() as usize;
        let clamped_w = nearest_w.clamp(0, orig_w - 1);

        output_data[c][i][j] = input_data[c][clamped_h][clamped_w];
    }
} } ``` ### 线性插值 * align_corner可选 * 考虑左右两个点,使用权重来组合两个点的值 * 权重通过距离计算 ```text 当align_corner=false时 输入10个点 □────□────□────□────□────□────□────□────□────□ 0    1    2    3    4    5    6    7    8    9 --坐标 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5  9.5 --中心点

输出(4个点):
□────□────□────□
0    1    2    3    --坐标
0.5  1.5  2.5  3.5  --中心点
```

* scale = 10 / 4 = 2.5
* 当 i_out=1时,i_in = (i_out + 0.5) * scale - 0.5 = 3.25
* 计算左右邻点:
    * left_index = floor(3.25) = 3
    * right_index = left_index + 1 = 4
* 计算权重
    * dis_to_left = i_in - left_index = 0.25
    * dis_to_right = right_index - i_in = 0.75
    * 左邻点权重: weight_left = 1.0 - dis_to_left = 0.75
    * 右邻点权重: weight_right = 1.0 - dis_to_right = 0.25
* 插值计算
    * num_out = left_num * weight_left + right_num * weight_right

```rust
// 1D case, layout (bs, c, dim): interpolate along the last axis.
let coords = compute_1d_coords(orig_size, target_size, align_corner)?;
for b in 0..bs {
    for c in 0..channels {
        for (out_idx, src_pos) in coords.iter().copied().enumerate() {
            // Keep the source position inside the valid index range.
            let src_pos = src_pos.clamp(0.0, (orig_size - 1) as f32);
            // Left neighbour, and right neighbour capped at the last index.
            let left = src_pos.floor() as usize;
            let right = usize::min(left + 1, orig_size - 1);
            // Distance from the left neighbour = weight of the right one.
            let frac = src_pos - left as f32;
            let (left_val, right_val) = (input_data[b][c][left], input_data[b][c][right]);

            output_data[b][c][out_idx] = left_val * (1.0 - frac) + right_val * frac;
        }
    }
}

双线性插值

双三次插值

/// Cubic convolution function 2: evaluates
/// `a * (x^3 - 5x^2 + 8x - 4)` in Horner form.
///
/// The same polynomial appears as the `1 <= |x| < 2` branch of the bicubic
/// kernel; `a` is the kernel's free parameter.
fn cubic_convolution2(x: f64, a: f64) -> f64 {
    (((x - 5.0) * x + 8.0) * x - 4.0) * a
}

/// Returns the four cubic weights for a sample located at fractional offset
/// `t` past the second of four consecutive taps.
///
/// The outer taps (at distance `t + 1` and `(1 - t) + 1`) go through
/// `cubic_convolution2`; the inner taps (at distance `t` and `1 - t`) go
/// through `cubic_convolution1`. `a` is forwarded unchanged to both.
fn get_cubic_coefficients(t: f64, a: f64) -> [f64; 4] {
    let mirrored = 1.0 - t;
    [
        cubic_convolution2(t + 1.0, a),
        cubic_convolution1(t, a),
        cubic_convolution1(mirrored, a),
        cubic_convolution2(mirrored + 1.0, a),
    ]
}

/// Blends four consecutive samples `x0..x3` using the weights returned by
/// `get_cubic_coefficients(t, a)`; the f64 weights are narrowed to f32
/// before multiplying.
fn cubic_interp1d(x0: f32, x1: f32, x2: f32, x3: f32, t: f64, a: f64) -> f32 {
    let coeffs = get_cubic_coefficients(t, a);
    x0 * coeffs[0] as f32
        + x1 * coeffs[1] as f32
        + x2 * coeffs[2] as f32
        + x3 * coeffs[3] as f32
}
// Bicubic resampling: for each output pixel, interpolate along x on four
// neighbouring rows, then interpolate those four results along y.
for c in 0..dim0 {
    for out_y in 0..target_height {
        // Source row coordinate, clamped to the valid row range.
        let center_y = if align_corners {
            out_y as f32 * scale_h
        } else {
            (out_y as f32 + 0.5) * scale_h - 0.5
        }
        .clamp(0.0, (input_height - 1) as f32);
        // Integer row at or below the coordinate, and the fractional remainder.
        let in_y = center_y.floor() as isize;
        let t_y = center_y - in_y as f32;
        for out_x in 0..target_width {
            // Source column coordinate, handled the same way as the row.
            let center_x = if align_corners {
                out_x as f32 * scale_w
            } else {
                (out_x as f32 + 0.5) * scale_w - 0.5
            }
            .clamp(0.0, (input_width - 1) as f32);
            let in_x: isize = center_x.floor() as isize;
            let t_x = center_x - in_x as f32;
            // One horizontally interpolated value per row in_y-1 ..= in_y+2.
            let mut coefficients = [0.0; 4];
            for k in 0..4 {
                // Out-of-range rows/columns are clamped to the border sample.
                let row = (in_y - 1 + k as isize).clamp(0, input_height as isize - 1) as usize;
                let x_minus_1 =
                    input_data[c][row][(in_x - 1).clamp(0, input_width as isize - 1) as usize];
                let x_plus_0 = input_data[c][row][in_x.clamp(0, input_width as isize - 1) as usize];
                let x_plus_1 =
                    input_data[c][row][(in_x + 1).clamp(0, input_width as isize - 1) as usize];
                let x_plus_2 =
                    input_data[c][row][(in_x + 2).clamp(0, input_width as isize - 1) as usize];

                coefficients[k] =
                    cubic_interp1d(x_minus_1, x_plus_0, x_plus_1, x_plus_2, t_x as f64, -0.75);
            }
            // Vertical pass over the four row results.
            output_data[c][out_y][out_x] = cubic_interp1d(
                coefficients[0],
                coefficients[1],
                coefficients[2],
                coefficients[3],
                t_y as f64,
                -0.75,
            );
        }
    }
}
```
* 抗锯齿设置 antialias = true且下采样时,a = -0.5
* 根据scale, 扩大考虑范围
```rust
/// Piecewise cubic kernel evaluated on `|x|`: one polynomial for `|x| < 1`,
/// a second one scaled by `a` for `1 <= |x| < 2`, and `0.0` from 2 onwards.
fn bicubic_filter(x: f32, a: f32) -> f32 {
    let x = x.abs();
    if x < 1.0 {
        ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
    } else if x < 2.0 {
        (((x - 5.0) * x + 8.0) * x - 4.0) * a
    } else {
        0.0
    }
}

// Antialiased bicubic resampling: the sampling window grows with the
// downscale factor and the kernel argument is divided by that factor.
let scale = scale_h.max(scale_w);
// Half-width of the window: 2 source pixels, stretched by `scale` when scale >= 1.
let support_size = if scale >= 1.0 {
    (2.0 * scale).ceil()
} else {
    2.0
};
for c in 0..dim0 {
    for out_y in 0..target_height {
        // Centre of the output row in source coordinates (not clamped here).
        let center_y = (out_y as f32 + 0.5) * scale_h - 0.5;
        let start_y = (center_y - support_size).ceil() as isize;
        let end_y = (center_y + support_size).floor() as isize;
        for out_x in 0..target_width {
            let center_x = (out_x as f32 + 0.5) * scale_w - 0.5;
            let start_x = (center_x - support_size).ceil() as isize;
            let end_x = (center_x + support_size).floor() as isize;
            let mut sum = 0.0;
            let mut weight_sum = 0.0;
            for iy in start_y..end_y + 1 {
                for ix in start_x..end_x + 1 {
                    // Positions outside the input are skipped; the division by
                    // `weight_sum` below renormalises the remaining weights.
                    if iy >= 0
                        && iy < input_height as isize
                        && ix >= 0
                        && ix < input_width as isize
                    {
                        let dx = (ix as f32 - center_x).abs();
                        let dy = (iy as f32 - center_y).abs();
                        // Distances are divided by the per-axis scale, but never
                        // by less than 1.0.
                        let wx = bicubic_filter(dx / scale_w.max(1.0), -0.5);
                        let wy = bicubic_filter(dy / scale_h.max(1.0), -0.5);
                        let weight = wx * wy;
                        sum += input_data[c][iy as usize][ix as usize] * weight;
                        weight_sum += weight;
                    }
                }
            }
            if weight_sum > 0.0 {
                output_data[c][out_y][out_x] = sum / weight_sum;
            } else {
                // No usable weight: fall back to the nearest in-range source pixel.
                let y = center_y.round().clamp(0.0, (input_height - 1) as f32) as usize;
                let x = center_x.round().clamp(0.0, (input_width - 1) as f32) as usize;
                output_data[c][out_y][out_x] = input_data[c][y][x];
            }
        }
    }
}
```

完整代码地址: https://github.com/jhqxxx/aha/blob/main/src/utils/interpolate.rs