···11+import fileinput
22+from collections import deque, defaultdict
33+44+from utils import Point, N, S, E, W, DIRS
55+66+77+SLOPES = {
88+ '^': S,
99+ 'v': N,
1010+ '<': W,
1111+ '>': E,
1212+}
1313+1414+1515+def get_neighbours(graph, node, part_2=False):
1616+ """In part 2, we ignore slopes."""
1717+ if graph.get(node) == '#':
1818+ return
1919+2020+ if not part_2:
2121+ if graph.get(node) in SLOPES:
2222+ yield node + SLOPES[graph.get(node)]
2323+ return
2424+2525+ for d in DIRS:
2626+ np = node + d
2727+ if np not in graph:
2828+ continue
2929+3030+ neighbour = graph.get(np)
3131+3232+ if neighbour == '#':
3333+ continue
3434+3535+ if not part_2 and neighbour in SLOPES and d != SLOPES[neighbour]:
3636+ continue
3737+3838+ yield np
3939+4040+4141+def longest_path(graph, start, end, part_2=False):
4242+ horizon = [(start, 0, set())]
4343+ best = 0
4444+4545+ while horizon:
4646+ curr, dist, seen = horizon.pop()
4747+4848+ if curr == end:
4949+ best = max(best, dist)
5050+ continue
5151+5252+ if curr in seen:
5353+ continue
5454+5555+ for neighbour, weight in graph[curr]:
5656+ horizon.append((neighbour, dist + weight, seen | set([curr])))
5757+5858+ return best
5959+6060+6161+def compress_graph(graph, part_2=False):
6262+ # Sort all points in graph by their degree.
6363+ degrees = defaultdict(set)
6464+ for node in graph:
6565+ degrees[len(list(get_neighbours(graph, node, part_2)))].add(node)
6666+6767+ key_points = degrees[1] | degrees[3] | degrees[4]
6868+6969+7070+ # Find the distance from node to all other "key points" it can reach.
7171+ def bfs(start):
7272+ horizon = deque([(start, 0)])
7373+ seen = set()
7474+7575+ while horizon:
7676+ curr, dist = horizon.pop()
7777+7878+ if curr != start and curr in key_points:
7979+ yield curr, dist
8080+ continue
8181+8282+ if curr in seen:
8383+ continue
8484+8585+ seen.add(curr)
8686+8787+ for neighbour in get_neighbours(graph, curr, part_2):
8888+ horizon.appendleft((neighbour, dist + 1))
8989+9090+9191+ # Create the compressed weighted graph.
9292+ compressed = defaultdict(list)
9393+9494+ for node in key_points:
9595+ for neighbour, weight in bfs(node):
9696+ compressed[node].append((neighbour, weight))
9797+9898+ return compressed
9999+100100+101101+# Parse problem input.
102102+GRAPH = {}
103103+START = None
104104+END = None
105105+for y, line in enumerate(fileinput.input()):
106106+ for x, c in enumerate(line.strip()):
107107+ p = Point(x, y)
108108+ if c == '.':
109109+ if y == 0:
110110+ START = p
111111+112112+ END = p
113113+114114+ GRAPH[p] = c
115115+116116+print("Part 1:", longest_path(compress_graph(GRAPH), START, END))
117117+print("Part 2:", longest_path(compress_graph(GRAPH, part_2=True), START, END, part_2=True))
118118+