线段树的定义

线段树(segment tree)是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。 使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。

线段树的作用

有一列数,初始值全为0,每次可以进行以下三种操作中的一种: (1)给指定区间每个数加上一个值 (2)将指定区间所有数置成一个统一的值 (3)询问一个区间上的最小值、最大值、所有数的和

线段树的存储

struct Node
{
	int left,right;
	int min,max,sum;
	Node *leftchild,*rightchild;
}

线段树的构造

void build(Node *cur,int l,int r)
{
	cur->left=l;
	cur->right=r;
	if(l!=r)
	{
		cur->leftchild=new Node;
		cur->rightchild=new Node;
		build(cur->leftchild,l,(l+r)/2);
		build(cur->rightchild,(l+r)/2+1,r);
	}
	else
		cur->leftchild=cur->rightchild=NULL;
}

在线段树上对元素进行修改

将线段树上位置为x的元素修改成num

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
		cur->min=cur->max=cur->sum=num;
	else
	{
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=lc->sum+rc->sum;
		cur->max=max(lc->max,rc->max);
		cur->min=min(lc->min,rc->min);
	}
}

对线段树的某一个区间进行询问

询问区间[l,r]所有元素和

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))
		ret+=cur->sum;
	else
	{
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
	}
	return ret;
}

那么对于整个区间进行操作呢?

是不是需要对每一个元素都进行insert()操作呢? 显然不行,那样的话每一次操作都要$O(nlogn)$,失去了线段树的优势。

lazy-tag技术:

lazy-tag的思想就是在对区间进行更新的时候,在分解出来的区间上打上作用于整个子树的标记 当在对这个区间进行维护或者查询的时候便将这个标记进行分解,并将其传递到它的两个子树上。

对一个区间整体加上一个数

将区间[l,r]整体加上delta

void update(Node *cur,int l,int r,int delta)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if((l<=cur->left)&&(cur->right<=r))
		cur->delta+=delta;
	else
	{
		if(l<=(cur->left+cur->right)/2)
			update(lc,l,r,delta);
		if(r>(cur->left+cur->right)/2)
			update(rc,l,r,delta);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
}

使用lazy-tag技术对元素进行修改

将位置为x的元素修改为num

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
	{
		cur->sum=num;
		cur->delta=0;
	}
	else
	{
		lc->delta+=cur->delta;
		rc->delta+=cur->delta;
		cur->delta=0;
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
}

使用lazy-tag技术对区间进行查询

询问区间[l,r]上所有元素的和

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))
		ret+=cur->sum+(cur->right-cur->left+1)*cur->delta;
	else
	{
		lc->delta+=cur->delta;
		rc->delta+=cur->delta;
		cur->delta=0;
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
	return ret;
}

将一个区间整体置为一个数

额外维护两个域en表示一个区间是否为统一的数,若en有效,则data域表示区间统一的数的数值。 同样,我们可以用lazy-tag来维护这种操作

求一个节点表示区间上所有元素的和

inline int clacsum(Node *cur)
{
	if(cur->en)
		return (cur->right-cur->left+1)*cur->data;
	return cur->sum;
}

修改一个节点的值

将位置为x的元素修改为num

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
	{
		cur->sum=num;
		cur->en=false;
	}
	else
	{
		if(cur->en)
		{
			lc->data=rc->data=cur->data;
			lc->en=rc->en=true;
			cur->en=false;
		}
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=clacsum(lc)+clacsum(rc);
	}
}

修改整个区间的值

将区间[l,r]上所有元素的值置为num

void update(Node *cur,int l,int r,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->en)
	{
		if(lc!=NULL)
		{
			lc->data=num;
			lc->en=true;
		}
		if(rc!=NULL)
		{
			rc->data=num;
			rc->en=true;
		}
	}
	if((l<=cur->left)&&(cur->right<=r))
	{
		cur->en=true;
		cur->data=num;
	}
	else
	{
		if(l<=(cur->left+cur->right)/2)
			update(lc,l,r,num);
		if(r>(cur->left+cur->right))
			update(rc,l,r,num);
		cur->sum=calcsum(lc)+calcsum(rc);
	}
}

查询区间的和

询问区间[l,r]上所有元素的和

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))\
		ret+=clacsum(cur);
	else
	{
		if(cur->en)
		{
			lc->data=cur->data;
			lc->en=true;
			rc->data=cur->data;
			rc->en=true;
			cur->en=false;
		}
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
	return ret;
}

总结

到这里就差不多写完了,把所有操作的代码汇个总:

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;

struct Node
{
	int left,right;
	int min,max,sum,delta;
	bool en;
	Node *leftchild,*rightchild;
};

void build(Node *cur,int l,int r)
{
	cur->left=l;
	cur->right=r;
	if(l!=r)
	{
		cur->leftchild=new Node;
		cur->rightchild=new Node;
		build(cur->leftchild,l,(l+r)/2);
		build(cur->rightchild,(l+r)/2+1,r);
	}
	else
		cur->leftchild=cur->rightchild=NULL;
}

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
		cur->min=cur->max=cur->sum=num;
	else
	{
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=lc->sum+rc->sum;
		cur->max=max(lc->max,rc->max);
		cur->min=min(lc->min,rc->min);
	}
}

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))
		ret+=cur->sum;
	else
	{
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
	}
	return ret;
}

//range
void update(Node *cur,int l,int r,int delta)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if((l<=cur->left)&&(cur->right<=r))
		cur->delta+=delta;
	else
	{
		if(l<=(cur->left+cur->right)/2)
			update(lc,l,r,delta);
		if(r>(cur->left+cur->right)/2)
			update(rc,l,r,delta);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
}

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
	{
		cur->sum=num;
		cur->delta=0;
	}
	else
	{
		lc->delta+=cur->delta;
		rc->delta+=cur->delta;
		cur->delta=0;
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
}

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))
		ret+=cur->sum+(cur->right-cur->left+1)*cur->delta;
	else
	{
		lc->delta+=cur->delta;
		rc->delta+=cur->delta;
		cur->delta=0;
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
	return ret;
}

//all range
inline int clacsum(Node *cur)
{
	if(cur->en)
		return (cur->right-cur->left+1)*cur->data;
	return cur->sum;
}

void insert(Node *cur,int x,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->left==cur->right)
	{
		cur->sum=num;
		cur->en=false;
	}
	else
	{
		if(cur->en)
		{
			lc->data=rc->data=cur->data;
			lc->en=rc->en=true;
			cur->en=false;
		}
		if(x<=(cur->left+cur->right)/2)
			insert(lc,x,num);
		if(x>(cur->left+cur->right)/2)
			insert(rc,x,num);
		cur->sum=clacsum(lc)+clacsum(rc);
	}
}

void update(Node *cur,int l,int r,int num)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	if(cur->en)
	{
		if(lc!=NULL)
		{
			lc->data=num;
			lc->en=true;
		}
		if(rc!=NULL)
		{
			rc->data=num;
			rc->en=true;
		}
	}
	if((l<=cur->left)&&(cur->right<=r))
	{
		cur->en=true;
		cur->data=num;
	}
	else
	{
		if(l<=(cur->left+cur->right)/2)
			update(lc,l,r,num);
		if(r>(cur->left+cur->right))
			update(rc,l,r,num);
		cur->sum=calcsum(lc)+calcsum(rc);
	}
}

int querysum(Node *cur,int l,int r)
{
	Node *lc=cur->leftchild,*rc=cur->rightchild;
	int ret=0;
	if((l<=cur->left)&&(cur->right<=r))\
		ret+=clacsum(cur);
	else
	{
		if(cur->en)
		{
			lc->data=cur->data;
			lc->en=true;
			rc->data=cur->data;
			rc->en=true;
			cur->en=false;
		}
		if(l<=(cur->left+cur->right)/2)
			ret+=querysum(lc,l,r);
		if(r>(cur->left+cur->right)/2)
			ret+=querysum(rc,l,r);
		cur->sum=lc->sum+lc->delta*(lc->right-lc->left+1);
		cur->sum+=rc->sum+rc->delta*(rc->right-rc->left+1);
	}
	return ret;
}