자료구조와 알고리즘
250224 세그먼트 트리(Segment Tree)
juwanseo
2025. 2. 25. 10:45
🔹 세그먼트 트리의 특징
- 구간(Query) 연산을 빠르게 수행
- 구간 합, 최댓값, 최솟값 등의 연산을 **O(log N)**의 시간 복잡도로 해결할 수 있음.
- 배열의 특정 원소를 빠르게 수정
- 배열의 특정 값을 변경할 때도 **O(log N)**의 시간 복잡도로 갱신 가능.
- 메모리 사용량
- 일반적으로 4배 크기의 배열을 사용하여 트리를 구현함.
🔹 세그먼트 트리의 동작 방식
세그먼트 트리는 일반적으로 완전 이진 트리(Complete Binary Tree) 형태로 구현됩니다.
즉, 부모 노드는 두 자식 노드의 정보를 이용해 값을 저장합니다.
- 트리 구성 (Build, O(N))
- 원래 배열을 기반으로 트리를 구축하는 과정.
- 각 노드는 특정 구간을 담당하며, 부모 노드는 두 자식 노드의 값을 이용해 자신의 값을 계산.
- 쿼리 처리 (Query, O(log N))
- 특정 구간의 합이나 최댓값, 최솟값을 구할 때 사용.
- 필요한 구간만 탐색하여 연산을 수행하므로 효율적.
- 값 업데이트 (Update, O(log N))
- 배열의 특정 값을 변경할 때, 관련된 모든 부모 노드도 갱신해야 함.
- 따라서 트리의 높이(log N) 만큼만 연산하면 됨.
🔹 세그먼트 트리 예제
예를 들어, 배열 [1, 3, 5, 7, 9, 11] 이 주어졌을 때,
이 배열의 구간 합을 빠르게 구하는 세그먼트 트리를 구축할 수 있습니다.
📌 구간 [1, 4]의 합을 구하는 경우
- 일반적인 배열 탐색: O(N)
- 세그먼트 트리 사용: O(log N)
🔹 세그먼트 트리 vs 펜윅 트리(Fenwick Tree)
- 세그먼트 트리
- 구간 연산과 특정 값 업데이트 모두 가능
- 하지만 메모리 사용량이 많고 구현이 복잡함
- 펜윅 트리(Fenwick Tree, BIT)
- 구간 합과 업데이트가 가능하지만, 최댓값/최솟값 연산에는 부적합
- 메모리 사용량이 적고 구현이 간단
✅ 세그먼트 트리는 다양한 연산이 가능하므로 더 범용적이지만, 펜윅 트리는 단순한 문제(구간 합)에서는 더 효율적!
🔹 정리
- 세그먼트 트리는 구간 연산을 효율적으로 수행하는 트리 구조의 자료구조
- **O(log N)**의 시간 복잡도로 구간 연산(Query) 및 업데이트(Update) 가능
- 최댓값, 최솟값, 구간 합 등 다양한 연산이 가능
- 구현은 어렵지만, 빠른 연산 속도가 필요할 때 유용
🔹 사용 예시
✔ 누적 합 문제
✔ 최댓값/최솟값 구하기
✔ 게임 랭킹 시스템 (예: 특정 점수 구간의 플레이어 수 계산)
✔ 범위 내 특정 조건 만족 개수 찾기
class SegmentTree:
def __init__(self, arr):
"""세그먼트 트리 초기화"""
self.n = len(arr)
self.tree = [0] * (4 * self.n) # 충분한 크기의 트리 배열 생성
self.build(arr, 0, 0, self.n - 1)
def build(self, arr, node, start, end):
"""트리 초기 구축 (O(N))"""
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
self.build(arr, left_child, start, mid)
self.build(arr, right_child, mid + 1, end)
self.tree[node] = self.tree[left_child] + self.tree[right_child] # 부모는 자식 합
def query(self, node, start, end, left, right):
"""구간 합 쿼리 (O(log N))"""
if right < start or end < left: # 범위 밖
return 0
if left <= start and end <= right: # 범위 안
return self.tree[node]
mid = (start + end) // 2
left_sum = self.query(2 * node + 1, start, mid, left, right)
right_sum = self.query(2 * node + 2, mid + 1, end, left, right)
return left_sum + right_sum
def update(self, node, start, end, index, value):
"""특정 원소 값 변경 (O(log N))"""
if start == end:
self.tree[node] = value
else:
mid = (start + end) // 2
if index <= mid:
self.update(2 * node + 1, start, mid, index, value)
else:
self.update(2 * node + 2, mid + 1, end, index, value)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2] # 부모 값 갱신
def range_sum(self, left, right):
"""사용자가 호출하는 구간 합 함수"""
return self.query(0, 0, self.n - 1, left, right)
def point_update(self, index, value):
"""사용자가 호출하는 업데이트 함수"""
self.update(0, 0, self.n - 1, index, value)
구현코드