정올/KOI기출 문제

트리와 쿼리

juwanseo 2025. 5. 11. 14:07
#트리 입력 받기 (u, v 연결정보)
N = int(input())

arr = [[] for _ in range(N+1)]

for _ in range(N-1):
    u, v = map(int, input().split())
    arr[u].append(v)
    arr[v].append(u)

#쿼리 입력 받기 (S 집합)
Q = int(input())
for _ in range(Q):
    K, *S = list(map(int, input().split()))

    not_in_S = [True] (N + 1)
    for s in S:
        not_in_S[s] = False

    ans = 0

    for i in range(K-1):
        for j in range(i+1, K):
            visited = not_in_S[:] #얕은 복사 : visited리스트를 여러번 초기화해줘야 함
            # visited 리스트를 수정해도 not_in_S 리스트에 영향을 받지 않게 된다.
            ans += dfs(S[i], S[j])
    print(ans)

#dfs 함수 만들기 -> 부분문제 #2 10점 / N과 Q가 각각 50개 이하일 때에만 동작
#-> 방문을 하는데 너무 시간이 많이 걸림 -> 매 번 두 노드가 연결되어 있는지 확인해야되기 때문
#So, dfs로 방문 가능한 노드를 구한 다음에 해당 노드로 연결 가능한 개수를 찾는 것이 더 빠르다
def dfs(s, e):
    #해당 노드 방문 -> visited -> True
    #s와 e가 같을 경우 두 노드가 연결되어 있다 0> return 1
    #아직 e에 도달하지 못했다면 -> visited 리스트 확인해서 방문을 계속하면서 -> 두 노드가 만날 때까지 탐색
    #e노드에 도달하지 못하면 -> return 0
    visited[s] = True
    cnt = 0
    for a in arr[s]:
        if a == e:
            return 1
        if visited[a] == False:
            cnt += dfs(a, e)
    return cnt

#------------------------
#수학적 로직 추가하기 nCr
#반복문이 줄었음
#두 노드가 연결되어 있는지 확인하는 것이 아니라 dfs를 통해 노드의 갯수를 구해준다.
#구한 갯수를 수학적으로 계산해서 result를 구한다. -> 이후 dfs 함수 구현 ㄱㄱ
Q = int(input())
for _ in range(Q):
    K, *S = list(map(int, input().split()))

    visited = [True] * (N + 1)
    for s in S:
        visited[s] = False
    
    result = 0

    for s in S:
        if visited[s] == False:
            ans = dfs(s)
            result += ans * (ans-1) // 2
        print(result)

#------------------------
#bfs함수 구현
#노드들을 방문할때마다 cnt를 하나씩 늘려나가는 로직으로 변경
#위 dfs와 다른점 : 시작과 끝을 통해 노드가 연결되어 있는것이 아니라 노드를 만날 때마다 cnt를 늘려준다
#부분문제 #2 21점 N ≤​ 2,500, Q ≤​2,500
import sys
input = sys.stdin.readline

N = int(input())

tree = [0] * (N+1)

def find(n):
    if parent[n] != n:
        parent[n] = find(parent[n])
        
    return parent[n]

def union(a, b):
    pa = find(a)
    pb = find(b)
    if pa == pb:
        return

    cnt[pa] += cnt[pb]
    parent[pb] = pa

def dfs(s):
    visited[s] = True
    rr = [s]
    
    while rr:
        node = rr.pop()
        for a in arr[node]:
            if visited[a]:
                continue
            tree[a] = node
            visited[a] = True
            rr.append(a)

arr = [[] for _ in range(N+1)]

for _ in range(N-1):
    u, v = map(int, input().split())
    arr[u].append(v)
    arr[v].append(u)

visited = [False] * (N+1)
dfs(1)

Q = int(input())

not_in_S = [False] * (N+1)
cnt = [1] * (N+1)
parent = list(range(N+1))

for _ in range(Q):
    K, *S = list(map(int, input().split()))    

    for s in S:
        not_in_S[s] = True
    
    for s in S:
        if not_in_S[tree[s]]:
            union(s, tree[s])
    
    result = 0
    for s in S:
        if find(s) == s:
            result += cnt[s] * (cnt[s] - 1) // 2

        not_in_S[s] = False
        parent[s] = s
        cnt[s] = 1
    
    print(result)