侧边栏壁纸
  • 累计撰写 7 篇文章
  • 累计创建 10 个标签
  • 累计收到 2 条评论

目录 CONTENT

文章目录

Rust线段树模板

Novice
2024-03-04 / 0 评论 / 0 点赞 / 61 阅读 / 10431 字

普通线段树

impl ArgTrait<i32> for i32 {}

impl NodeTrait<i32> for i32 {}

/// Value stored in every tree node.
///
/// Node values are copied freely, unused slots are filled with
/// `Default::default()`, and two children are merged into their parent with `+`.
trait NodeTrait<N>: Copy + Default + std::ops::Add<Output = N> {}

/// Input element type; each element is converted into a node value via `Into<N>`.
trait ArgTrait<N>: Into<N> + Copy {}

/// Array-backed segment tree supporting point assignment and inclusive range
/// queries, where the aggregate of a range is the `+` of its elements.
///
/// Nodes are stored 1-based in `tree`: node `p` has children `2 * p` and
/// `2 * p + 1`, and the root (node 1) covers the index range `[0, n - 1]`.
#[derive(Debug)]
struct SegTree<N> {
    tree: Vec<N>,
    n: usize,
}

impl<N: NodeTrait<N>> SegTree<N> {
    /// Builds a tree over the elements of `raw`.
    ///
    /// An empty slice yields an empty tree (no nodes are allocated) instead of
    /// underflowing while computing the last index.
    fn from<T: ArgTrait<N>>(raw: &[T]) -> SegTree<N> {
        /// Fills node `p`, which covers `[s, t]`, and everything below it.
        fn build<T: ArgTrait<N>, N: NodeTrait<N>>(raw: &[T], tree: &mut [N], s: usize, t: usize, p: usize) {
            if s == t {
                tree[p] = raw[s].into();
                return;
            }
            let mid = s + (t - s) / 2;
            build(raw, tree, s, mid, 2 * p);
            build(raw, tree, mid + 1, t, 2 * p + 1);
            tree[p] = tree[2 * p] + tree[2 * p + 1];
        }

        let n = raw.len();
        // 4 * n slots are always enough for the 1-based implicit binary tree.
        let mut tree = vec![N::default(); 4 * n];
        if n > 0 {
            build(raw, &mut tree, 0, n - 1, 1);
        }
        SegTree { tree, n }
    }

    /// Returns the aggregate of the elements in the inclusive range
    /// `[left, right]`. A `right` past the last element is treated as the last
    /// element.
    ///
    /// # Panics
    ///
    /// Panics if `left > right` or `left` is not a valid index (which includes
    /// every query on an empty tree). Without this check such a query would
    /// recurse without ever reaching a covered node.
    fn query(&self, left: usize, right: usize) -> N {
        /// `[l, r]` is the queried range, `[s, t]` the range covered by node `p`.
        fn query_node<N: NodeTrait<N>>(tree: &[N], l: usize, r: usize, s: usize, t: usize, p: usize) -> N {
            if l <= s && t <= r {
                return tree[p];
            }
            let mid = s + (t - s) / 2;
            if r <= mid {
                // The query lies entirely in the left half.
                query_node(tree, l, r, s, mid, 2 * p)
            } else if l > mid {
                // The query lies entirely in the right half.
                query_node(tree, l, r, mid + 1, t, 2 * p + 1)
            } else {
                let lower = query_node(tree, l, r, s, mid, 2 * p);
                let upper = query_node(tree, l, r, mid + 1, t, 2 * p + 1);
                lower + upper
            }
        }

        assert!(
            left <= right && left < self.n,
            "invalid query range [{}, {}] for a tree of {} elements",
            left,
            right,
            self.n
        );
        query_node(&self.tree, left, right, 0, self.n - 1, 1)
    }

    /// Replaces the element at `index` with `value` and refreshes every
    /// aggregate on the path back to the root.
    ///
    /// # Panics
    ///
    /// Panics if `index >= n`. Without this check an out-of-range index would
    /// silently overwrite the last element.
    fn update<T: ArgTrait<N>>(&mut self, index: usize, value: T) {
        /// Descends from node `p` (covering `[s, t]`) to the leaf for `index`.
        fn update_node<N: NodeTrait<N>, T: ArgTrait<N>>(tree: &mut [N], s: usize, t: usize, p: usize, index: usize, value: T) {
            if s == t {
                tree[p] = value.into();
                return;
            }
            let mid = s + (t - s) / 2;
            if index <= mid {
                update_node(tree, s, mid, 2 * p, index, value);
            } else {
                update_node(tree, mid + 1, t, 2 * p + 1, index, value);
            }
            tree[p] = tree[2 * p] + tree[2 * p + 1];
        }

        assert!(
            index < self.n,
            "index {} is out of bounds for a tree of {} elements",
            index,
            self.n
        );
        update_node(&mut self.tree, 0, self.n - 1, 1, index, value);
    }
}

懒更新线段树

/// Array-backed segment tree over `i32` supporting "add `val` to every element
/// of a range" and "sum of a range", both with lazy propagation.
///
/// Nodes are stored 1-based: node `p` has children `2 * p` and `2 * p + 1`.
/// `lazy[p]` holds an increment that has already been applied to `tree[p]` but
/// not yet to the children of `p`.
#[derive(Debug)]
struct SegmentTree {
    tree: Vec<i32>,
    lazy: Vec<i32>,
    n: usize,
}

impl SegmentTree {
    /// Builds the tree from `buf`. An empty `buf` yields an empty tree on which
    /// every update is a no-op and every query returns 0.
    fn build(buf: &[i32]) -> Self {
        let n = buf.len();
        let mut tree = vec![0; n * 4];
        let lazy = vec![0; n * 4];
        if n > 0 {
            Self::inner_build(buf, &mut tree, 0, n - 1, 1);
        }
        SegmentTree { tree, lazy, n }
    }

    /// Builds the subtree rooted at node `p`, which covers `[s, t]`.
    fn inner_build(buf: &[i32], tree: &mut [i32], s: usize, t: usize, p: usize) {
        if s == t {
            tree[p] = buf[s];
            return;
        }
        let mid = s + ((t - s) >> 1);
        // Build both halves, then combine them.
        Self::inner_build(buf, tree, s, mid, 2 * p);
        Self::inner_build(buf, tree, mid + 1, t, 2 * p + 1);
        tree[p] = tree[2 * p] + tree[2 * p + 1];
    }

    /// Adds `val` to every element in the inclusive range `[l, r]`.
    ///
    /// The part of the range beyond the last element is ignored; an empty range
    /// (`l > r`) or one that starts past the end does nothing.
    fn update(&mut self, l: usize, r: usize, val: i32) {
        if l > r || l >= self.n {
            return;
        }
        Self::inner_update(&mut self.tree, &mut self.lazy, l, r, 0, self.n - 1, 1, val);
    }

    /// `[l, r]` is the range being modified, `val` the increment, `[s, t]` the
    /// range covered by the current node and `p` that node's index.
    fn inner_update(tree: &mut [i32], lazy: &mut [i32], l: usize, r: usize, s: usize, t: usize, p: usize, val: i32) {
        // The node is fully inside the modified range: adjust its sum, leave a
        // tag for the children and stop.
        if l <= s && t <= r {
            // NOTE: assumes the segment length fits in i32.
            let len = (t + 1 - s) as i32;
            // Must accumulate (`+=`): the node already holds the sum of its range.
            tree[p] += len * val;
            lazy[p] += val;
            return;
        }
        let m = s + ((t - s) >> 1);
        Self::push_down(tree, lazy, s, t, m, p);
        if l <= m {
            Self::inner_update(tree, lazy, l, r, s, m, 2 * p, val);
        }
        if r > m {
            Self::inner_update(tree, lazy, l, r, m + 1, t, 2 * p + 1, val);
        }
        tree[p] = tree[2 * p] + tree[2 * p + 1];
    }

    /// Returns the sum of the elements in the inclusive range `[l, r]`.
    ///
    /// The part of the range beyond the last element is ignored; an empty range
    /// (`l > r`) or one that starts past the end sums to 0.
    fn query(&mut self, l: usize, r: usize) -> i32 {
        if l > r || l >= self.n {
            return 0;
        }
        Self::inner_query(&mut self.tree, &mut self.lazy, l, r, 0, self.n - 1, 1)
    }

    /// `[l, r]` is the queried range, `[s, t]` the range covered by the current
    /// node and `p` that node's index.
    fn inner_query(tree: &mut [i32], lazy: &mut [i32], l: usize, r: usize, s: usize, t: usize, p: usize) -> i32 {
        // The node is fully inside the queried range: its sum is the answer.
        if l <= s && t <= r {
            return tree[p];
        }
        let m = s + ((t - s) >> 1);
        Self::push_down(tree, lazy, s, t, m, p);
        let mut total = 0;
        if l <= m {
            total += Self::inner_query(tree, lazy, l, r, s, m, 2 * p);
        }
        if r > m {
            total += Self::inner_query(tree, lazy, l, r, m + 1, t, 2 * p + 1);
        }
        total
    }

    /// Moves the pending increment of node `p` (covering `[s, t]`, split at `m`)
    /// down to its two children and clears it. Leaves have no children, so
    /// nothing is pushed for them.
    fn push_down(tree: &mut [i32], lazy: &mut [i32], s: usize, t: usize, m: usize, p: usize) {
        let pending = lazy[p];
        if pending == 0 || s == t {
            return;
        }
        // Left child covers [s, m], right child covers [m + 1, t].
        tree[2 * p] += pending * (m + 1 - s) as i32;
        lazy[2 * p] += pending;
        tree[2 * p + 1] += pending * (t - m) as i32;
        lazy[2 * p + 1] += pending;
        lazy[p] = 0;
    }
}

动态开点线段树

/// Dynamically allocated ("open on demand") segment tree over the inclusive
/// index range `[l, r]`, supporting point add and range sum.
///
/// Child nodes are only created when an update touches their half, so very
/// large index ranges can be handled with memory proportional to the number of
/// updates.
#[derive(Debug)]
pub struct SegmentTree {
    l: i32,
    r: i32,
    value: i32,
    left: Option<Box<SegmentTree>>,
    right: Option<Box<SegmentTree>>,
}

impl SegmentTree {
    /// Aggregate of a range that has never been updated.
    fn value_default() -> i32 {
        0
    }

    /// Applies a point update to this (leaf) node.
    fn accept(&mut self, val: i32) {
        self.value += val;
    }

    /// Combines the aggregates of two halves; a missing half counts as
    /// `value_default()`.
    fn merge(left: Option<i32>, right: Option<i32>) -> i32 {
        left.unwrap_or_else(Self::value_default) + right.unwrap_or_else(Self::value_default)
    }

    /// Creates an empty tree covering the inclusive range `[l, r]`.
    pub fn new(l: i32, r: i32) -> Self {
        SegmentTree { l, r, value: Self::value_default(), left: None, right: None }
    }

    /// Adds `val` to the element at index `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` lies outside the range covered by this tree. Without this
    /// check the value would silently be credited to the nearest boundary leaf.
    pub fn update(&mut self, i: i32, val: i32) {
        assert!(
            self.l <= i && i <= self.r,
            "index {} is outside the tree range [{}, {}]",
            i,
            self.l,
            self.r
        );
        if self.l == self.r {
            self.accept(val);
            return;
        }
        let (lo, hi) = (self.l, self.r);
        let mid = lo + (hi - lo) / 2;
        // Create the child lazily: the closure only runs when the slot is empty,
        // so no throw-away node is allocated on already-populated paths.
        let child = if i <= mid {
            self.left.get_or_insert_with(|| Box::new(SegmentTree::new(lo, mid)))
        } else {
            self.right.get_or_insert_with(|| Box::new(SegmentTree::new(mid + 1, hi)))
        };
        child.update(i, val);
        self.value = Self::merge(
            self.left.as_ref().map(|node| node.value),
            self.right.as_ref().map(|node| node.value),
        );
    }

    /// Returns the sum of the elements in the inclusive range `[l, r]`.
    /// Indices outside the tree's own range contribute nothing.
    pub fn query(&self, l: i32, r: i32) -> i32 {
        if l <= self.l && self.r <= r {
            return self.value;
        }
        let mid = self.l + (self.r - self.l) / 2;
        // Only descend into a half that overlaps the query; a half that was
        // never created holds the default aggregate.
        let lower = if l <= mid {
            Some(self.left.as_ref().map_or_else(Self::value_default, |node| node.query(l, r)))
        } else {
            None
        };
        let upper = if r > mid {
            Some(self.right.as_ref().map_or_else(Self::value_default, |node| node.query(l, r)))
        } else {
            None
        };
        Self::merge(lower, upper)
    }
}

懒更新动态开点线段树

/// Dynamically allocated segment tree over the inclusive index range `[l, r]`
/// with lazy propagation, supporting "assign `val` to every element of a range"
/// and "maximum of a range".
///
/// `lazy` holds an assignment that has been applied to this node's `value` but
/// not yet to its children. Child nodes are created on demand.
#[derive(Debug)]
pub struct SegmentTree {
    l: i32,
    r: i32,
    value: i32,
    lazy: Option<i32>,
    left: Option<Box<SegmentTree>>,
    right: Option<Box<SegmentTree>>,
}

impl SegmentTree {
    /// Aggregate of a range that has never been updated.
    fn value_default() -> i32 {
        0
    }

    /// Combines the aggregates of two halves; a missing half counts as
    /// `value_default()`.
    fn merge(left: Option<i32>, right: Option<i32>) -> i32 {
        std::cmp::max(
            left.unwrap_or_else(Self::value_default),
            right.unwrap_or_else(Self::value_default),
        )
    }

    /// Assigns `val` to the whole range of this node and remembers it for the
    /// children.
    fn accept(&mut self, val: i32) {
        self.lazy = Some(val);
        self.value = val;
    }

    /// Creates an empty tree covering the inclusive range `[l, r]`.
    pub fn new(l: i32, r: i32) -> Self {
        SegmentTree { l, r, value: Self::value_default(), lazy: None, left: None, right: None }
    }

    /// Returns `true` when `[l, r]` is empty or shares no index with this node.
    fn is_disjoint(&self, l: i32, r: i32) -> bool {
        l > r || r < self.l || self.r < l
    }

    /// Assigns `val` to every element in the inclusive range `[l, r]`.
    ///
    /// Only the part of `[l, r]` that overlaps the tree is affected. A range
    /// that is empty or lies entirely outside the tree is ignored; without this
    /// guard such a call would create degenerate nodes whose value leaks into
    /// the root maximum.
    pub fn update(&mut self, l: i32, r: i32, val: i32) {
        if self.is_disjoint(l, r) {
            return;
        }
        if l <= self.l && self.r <= r {
            self.accept(val);
            return;
        }
        self.push_down();
        let (lo, hi) = (self.l, self.r);
        let mid = lo + (hi - lo) / 2;
        if l <= mid {
            self.left
                .get_or_insert_with(|| Box::new(SegmentTree::new(lo, mid)))
                .update(l, r, val);
        }
        if r > mid {
            self.right
                .get_or_insert_with(|| Box::new(SegmentTree::new(mid + 1, hi)))
                .update(l, r, val);
        }
        self.value = Self::merge(
            self.left.as_ref().map(|node| node.value),
            self.right.as_ref().map(|node| node.value),
        );
    }

    /// Returns the maximum over the inclusive range `[l, r]`.
    ///
    /// Takes `&mut self` because pending assignments are pushed down on the way.
    /// A range that is empty or lies outside the tree yields `value_default()`.
    pub fn query(&mut self, l: i32, r: i32) -> i32 {
        if self.is_disjoint(l, r) {
            return Self::value_default();
        }
        if l <= self.l && self.r <= r {
            return self.value;
        }
        self.push_down();
        let mid = self.l + (self.r - self.l) / 2;
        let lower = if l <= mid {
            Some(self.left.as_mut().map_or_else(Self::value_default, |node| node.query(l, r)))
        } else {
            None
        };
        let upper = if r > mid {
            Some(self.right.as_mut().map_or_else(Self::value_default, |node| node.query(l, r)))
        } else {
            None
        };
        Self::merge(lower, upper)
    }

    /// Hands a pending assignment to both children (creating them if needed)
    /// and clears it on this node.
    fn push_down(&mut self) {
        if let Some(val) = self.lazy.take() {
            let (lo, hi) = (self.l, self.r);
            let mid = lo + (hi - lo) / 2;
            self.left
                .get_or_insert_with(|| Box::new(SegmentTree::new(lo, mid)))
                .accept(val);
            self.right
                .get_or_insert_with(|| Box::new(SegmentTree::new(mid + 1, hi)))
                .accept(val);
        }
    }
}

0

评论区