Tuesday, February 4, 2020

USACO Wormsort / Union-Find Algorithm

wormsort3

USACO 2020 January Contest, Silver

Problem 3. Wormhole Sort

In [1]:
# SAMPLE INPUT:
#
# 4 4
# 3 2 1 4
# 1 2 9
# 1 3 7
# 2 3 10
# 2 4 3
In [2]:
# Read the whole problem input inside a context manager so the input handle
# is closed as soon as parsing finishes (it used to stay open forever).
with open('wormsort.in', 'r') as fin:
    # The output file is opened (and truncated) right after the input file,
    # in the same order as before; nothing in this cell writes to it.
    fout = open('wormsort.out', 'w')
    N,M  = map(int, fin.readline().split())
    #print(N,M)
    # Cow labels are shifted to 0-based so they can index PARENT directly.
    cows = [int(x)-1 for x in fin.readline().split()]
    edges = []
    for i in range(M):
        j,k,w = map(int, fin.readline().split())
        # Width goes first so that a plain tuple sort orders edges by width.
        edges.append((w,j-1,k-1))

print("cows =",cows)
edges.sort()
print("sorted edges =",edges)
cows = [2, 1, 0, 3]
sorted edges = [(3, 1, 3), (7, 0, 2), (9, 0, 1), (10, 1, 2)]
In [3]:
# cut-and-paste Union-Find code

# Union-find forest: every element starts out as the root of its own set.
PARENT = list(range(N))
def find(x):
    """Return the root of the set containing x without modifying PARENT."""
    up = PARENT[x]
    # Climb one link at a time until a node that is its own parent is reached.
    while up != x:
        x, up = up, PARENT[up]
    return x

# replace calls to find() with this call if you want to do path compression
def find_with_path_compression(x):
    """Return the root of the set containing x, compressing the walked path.

    After the call every node that was on the path from x to the root has
    the root as its direct parent in PARENT.

    Implemented iteratively on purpose: union() links roots without any
    rank/size balancing, so a parent chain can become as long as the number
    of elements, and a recursive walk over such a chain exceeds Python's
    default recursion limit and raises RecursionError.
    """
    # Pass 1: walk up to the root without touching anything.
    root = x
    while root != PARENT[root]:
        root = PARENT[root]
    # Pass 2: walk the same path again and hang every node under the root.
    while x != root:
        above = PARENT[x]
        PARENT[x] = root
        x = above
    return root

def union(x,y):
    """Merge the sets containing x and y by hanging x's root under y's root.

    Returns None. When both elements already share a root the assignment
    simply rewrites that root's self-link, so the call is a no-op.
    """
    root_x, root_y = find_with_path_compression(x), find_with_path_compression(y)
    PARENT[root_x] = root_y
In [4]:
# function to see if cows are sortable for the current 
# set of connected components defined in PARENT
# FAIL_INDEX is a shortcut so that the search for an unsortable cow
# begins with the index of an unsortable cow found earlier
# Position at which the next sortability scan starts (see cows_sortable).
FAIL_INDEX = 0
def cows_sortable():
    """Return True when every position from FAIL_INDEX on is in the same
    component as the cow standing there.

    On failure the first offending position is stored in the module-level
    FAIL_INDEX, so the next call resumes the scan there instead of at 0.
    FAIL_INDEX is left untouched when the scan succeeds.
    """
    # Rebinding a module-level name inside a function needs this declaration.
    global FAIL_INDEX
    first_bad = next(
        (pos for pos in range(FAIL_INDEX, N) if find(pos) != find(cows[pos])),
        None,
    )
    if first_bad is None:
        return True
    FAIL_INDEX = first_bad
    return False
In [5]:
# Add wormholes from widest to narrowest (edges is sorted ascending, so pop()
# yields the widest remaining one).  last_weight is the width of the
# narrowest wormhole used so far; it stays -1 if no wormhole is needed.
last_weight = -1
while edges:
    w,a,b = edges.pop()
    if w != last_weight:
        # A narrower width is about to be used: stop if the wormholes
        # already added are enough to sort the cows.
        if cows_sortable():
            break
        last_weight = w
    union(a,b)

print(last_weight)
# 'wormsort.out' was opened earlier as fout but the answer was only printed;
# write it to the file as well and close the handle so it is flushed.
with fout:
    fout.write(str(last_weight) + "\n")
9