소개
문제 제시
어떠한 배열이 있다. 이 배열의 길이
세그먼트 트리
세그먼트 트리는 갱신이 많이 일어나는 배열의 어떤 구간에 대한 연산 결과를 빠르게 구할 수 있는 자료구조이다. 세그먼트 트리가 지원하는 연산의 종류는 다음과 같다.
- 배열의 임의의 어떤 구간에 대해 연산(합/곱, 최소/최댓값 등)한다.
- 배열의 임의의 인덱스의 값을 수정한다.
세그먼트 트리의 원리는 전체 구간을 분할하여 여러 구간들에 대한 연산 결과를 미리 트리 형태로 저장해 사용하는 것이다. 일종의 분할 정복 매커니즘을 사용하여 공간을 더 사용하는 대신 시간을 줄이는 것이다.
원리
길이가
현재 블로그 개편의 일환으로 다이어그램을 그리는 방식을 최적화하고 있습니다. 개편이 완료될 동안 다이어그램은 표시되지 않습니다.
먼저 우리가 전 범위
구간 쿼리
이렇게 트리를 구성하면 임의의 구간의 합을 특별한 방식으로 구할 수 있다. 이해를 돕기 위해
현재 블로그 개편의 일환으로 다이어그램을 그리는 방식을 최적화하고 있습니다. 개편이 완료될 동안 다이어그램은 표시되지 않습니다.
만약 1 ~ 5에 해당하는 구간의 합을 구하려면 어떤 정점을 조사하면 될까? 위 그림의 3개의 빨간색 정점들의 합을 조사하면 것이다. 빨간색 정점에 도달하기 위해서는 트리의 루트에서부터 정점이 나타내는 구간을 바탕으로 DFS를 수행하면 된다. 사실 어떠한 구간을 잡더라도 우리가 조사해야 할 정점의 개수는
원소 갱신
이제 배열에서 원소 하나를 갱신하는 상황을 살펴보자. 만약 3번 인덱스의 값을 수정한다면 3번 인덱스를 범위에 포함하는 정점들도 같이 갱신해주어야 할 것이다.
현재 블로그 개편의 일환으로 다이어그램을 그리는 방식을 최적화하고 있습니다. 개편이 완료될 동안 다이어그램은 표시되지 않습니다.
위 그림의 초록색 정점들이 수정 대상이다. 이들은 루트 정점부터 3번 인덱스를 표현하는 리프 정점까지의 경로와 같다. 이들을 갱신하는데는 트리의 높이만큼의 시간이 소요되므로 세그먼트 트리에서 임의의 인덱스의 값을 수정하는 연산의 시간복잡도는
구현
세그먼트 트리를 구현해보자. 이 글에서는 구간 합을 계산하는 세그먼트 트리를 구현할 것이며, 구간 곱이나 구간 최댓값과 같은 연산을 지원하는 세그먼트 트리 역시 얼마든지 응용하여 구현할 수 있다. 우리가 구현해야 할 것은 세 가지 함수이다.
- 트리 구축 생성자
- 구간 쿼리 함수
- 인덱스 갱신 함수
세 가지 기능 모두 재귀 함수를 이용한 DFS를 실행하는 함수이다.
트리 구축
먼저 주어진 배열을 가지고 초기 세그먼트 트리를 구축해야 한다. 이는 재귀 함수를 이용해 루트 정점부터 각 정점의 구간을 정해주면서 전개한다. 아래 코드에서는 해당 정점이 나타내는 구간 left
와 right
에 새로운 정점을 만들어 할당해주고 있다. 자식 정점의 생성자를 호출할 때 해당 자식 정점이 나타내는 범위를 인자로 주고 있다.
class Node:
def __init__(self, array: list[int], l: int, r: int):
if r - l > 1:
m = (l + r) // 2
self.left = Node(array, l, m)
self.right = Node(array, m, r)
self.value = self.left.value + self.right.value
else:
self.value = array[l]
구간 쿼리
구간 쿼리를 처리하는 함수도 마찬가지로 연산 결과를 구해야 하는 구간 정보와 함께 해당 정점이 나타내는 범위를 인자로 물려주면서 재귀한다. 연산의 대상이 되는 구간
구간이 구간 안에 포함되는 경우 구간이 구간에 포함되진 않지만 걸쳐있는 경우 구간과 구간이 겹치지 않는 경우
1의 경우 정점 자기 자신의 값을 돌려주면 되며, 2의 경우 재귀를 통해 자신의 구간의 연산 값을 구해 돌려주어야 하며, 3의 경우 자기 자신을 연산에 포함시키지 않으면 된다.
# class Node:
def get_sum(self, i: int, j: int, l: int, r: int) -> int:
if i <= l and r <= j:
return self.value
elif not (r <= i or j <= l):
m = (l + r) // 2
return self.left.get_sum(i, j, l, m) + self.right.get_sum(i, j, m, r)
else:
return 0
원소 갱신
특정 원소를 갱신하는 함수도 마찬가지로 인덱스 번호, 수정할 값과 함께 해당 정점이 나타내는 범위를 인자로 물려주면서 재귀한다. 양 쪽 자식 정점 중 인덱스 번호를 포함하고 있는 방향으로 재귀 호출을 해준다. 리프 정점에서 값을 교체한 후 재귀를 반환하면서 정점들의 값을 갱신해준다.
# class Node:
def modify(self, index: int, value: int, l: int, r: int):
if r - l > 1:
m = (l + r) // 2
if index < m:
self.left.modify(index, value, l, m)
else:
self.right.modify(index, value, m, r)
self.value = self.left.value + self.right.value
else:
self.value = value
전체 구현
아래 코드는 전체 코드이다. 위에서 구현한 Node
클래스를 외부 클래스로 한 번 더 감싼 것이다.
class SegmentTree:
class Node:
def __init__(self, array: list[int], l: int, r: int):
if r - l > 1:
m = (l + r) // 2
self.left = SegmentTree.Node(array, l, m)
self.right = SegmentTree.Node(array, m, r)
self.value = self.left.value + self.right.value
else:
self.value = array[l]
def get_sum(self, i: int, j: int, l: int, r: int) -> int:
if i <= l and r <= j:
return self.value
elif not (r <= i or j <= l):
m = (l + r) // 2
return self.left.get_sum(i, j, l, m) + self.right.get_sum(i, j, m, r)
else:
return 0
def modify(self, index: int, value: int, l: int, r: int):
if r - l > 1:
m = (l + r) // 2
if index < m:
self.left.modify(index, value, l, m)
else:
self.right.modify(index, value, m, r)
self.value = self.left.value + self.right.value
else:
self.value = value
# 생성자
def __init__(self, array: list[int]):
self.size = len(array)
self.root = SegmentTree.Node(array, 0, len(array))
# 구간 쿼리
def get_sum(self, i: int, j: int) -> int:
return self.root.get_sum(i, j, 0, self.size)
# 원소 갱신
def modify(self, index: int, value: int):
self.root.modify(index, value, 0, self.size)