정올/KOI기출 문제
트리와 쿼리
juwanseo
2025. 5. 11. 14:07
#트리 입력 받기 (u, v 연결정보)
N = int(input())
arr = [[] for _ in range(N+1)]
for _ in range(N-1):
u, v = map(int, input().split())
arr[u].append(v)
arr[v].append(u)
#쿼리 입력 받기 (S 집합)
Q = int(input())
for _ in range(Q):
K, *S = list(map(int, input().split()))
not_in_S = [True] (N + 1)
for s in S:
not_in_S[s] = False
ans = 0
for i in range(K-1):
for j in range(i+1, K):
visited = not_in_S[:] #얕은 복사 : visited리스트를 여러번 초기화해줘야 함
# visited 리스트를 수정해도 not_in_S 리스트에 영향을 받지 않게 된다.
ans += dfs(S[i], S[j])
print(ans)
#dfs 함수 만들기 -> 부분문제 #2 10점 / N과 Q가 각각 50개 이하일 때에만 동작
#-> 방문을 하는데 너무 시간이 많이 걸림 -> 매 번 두 노드가 연결되어 있는지 확인해야되기 때문
#So, dfs로 방문 가능한 노드를 구한 다음에 해당 노드로 연결 가능한 개수를 찾는 것이 더 빠르다
def dfs(s, e):
#해당 노드 방문 -> visited -> True
#s와 e가 같을 경우 두 노드가 연결되어 있다 0> return 1
#아직 e에 도달하지 못했다면 -> visited 리스트 확인해서 방문을 계속하면서 -> 두 노드가 만날 때까지 탐색
#e노드에 도달하지 못하면 -> return 0
visited[s] = True
cnt = 0
for a in arr[s]:
if a == e:
return 1
if visited[a] == False:
cnt += dfs(a, e)
return cnt
#------------------------
#수학적 로직 추가하기 nCr
#반복문이 줄었음
#두 노드가 연결되어 있는지 확인하는 것이 아니라 dfs를 통해 노드의 갯수를 구해준다.
#구한 갯수를 수학적으로 계산해서 result를 구한다. -> 이후 dfs 함수 구현 ㄱㄱ
Q = int(input())
for _ in range(Q):
K, *S = list(map(int, input().split()))
visited = [True] * (N + 1)
for s in S:
visited[s] = False
result = 0
for s in S:
if visited[s] == False:
ans = dfs(s)
result += ans * (ans-1) // 2
print(result)
#------------------------
#bfs함수 구현
#노드들을 방문할때마다 cnt를 하나씩 늘려나가는 로직으로 변경
#위 dfs와 다른점 : 시작과 끝을 통해 노드가 연결되어 있는것이 아니라 노드를 만날 때마다 cnt를 늘려준다
#부분문제 #2 21점 N ≤ 2,500, Q ≤2,500
import sys
input = sys.stdin.readline
N = int(input())
tree = [0] * (N+1)
def find(n):
if parent[n] != n:
parent[n] = find(parent[n])
return parent[n]
def union(a, b):
pa = find(a)
pb = find(b)
if pa == pb:
return
cnt[pa] += cnt[pb]
parent[pb] = pa
def dfs(s):
visited[s] = True
rr = [s]
while rr:
node = rr.pop()
for a in arr[node]:
if visited[a]:
continue
tree[a] = node
visited[a] = True
rr.append(a)
arr = [[] for _ in range(N+1)]
for _ in range(N-1):
u, v = map(int, input().split())
arr[u].append(v)
arr[v].append(u)
visited = [False] * (N+1)
dfs(1)
Q = int(input())
not_in_S = [False] * (N+1)
cnt = [1] * (N+1)
parent = list(range(N+1))
for _ in range(Q):
K, *S = list(map(int, input().split()))
for s in S:
not_in_S[s] = True
for s in S:
if not_in_S[tree[s]]:
union(s, tree[s])
result = 0
for s in S:
if find(s) == s:
result += cnt[s] * (cnt[s] - 1) // 2
not_in_S[s] = False
parent[s] = s
cnt[s] = 1
print(result)