Improve Day 23 perf

This commit is contained in:
Christian
2023-12-28 13:57:09 +01:00
parent ce80483f96
commit 9c4e8e8774

View File

@@ -1,5 +1,8 @@
use hashbrown::HashMap; use anyhow::{Context, Result};
use hashbrown::{HashMap, HashSet};
use itertools::iproduct;
use std::ops::{Index, IndexMut}; use std::ops::{Index, IndexMut};
use std::thread;
#[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash, Debug)]
struct Pos { struct Pos {
@@ -7,34 +10,12 @@ struct Pos {
y: usize, y: usize,
} }
#[rustfmt::skip]
impl Pos { impl Pos {
fn north(&self) -> Self { fn north(&self) -> Self { Pos { x: self.x - 1, y: self.y } }
Pos { fn south(&self) -> Self { Pos { x: self.x + 1, y: self.y } }
x: self.x - 1, fn west(&self) -> Self { Pos { x: self.x, y: self.y - 1 } }
y: self.y, fn east(&self) -> Self { Pos { x: self.x, y: self.y + 1 } }
}
}
fn south(&self) -> Self {
Pos {
x: self.x + 1,
y: self.y,
}
}
fn west(&self) -> Self {
Pos {
x: self.x,
y: self.y - 1,
}
}
fn east(&self) -> Self {
Pos {
x: self.x,
y: self.y + 1,
}
}
} }
struct Grid<T> { struct Grid<T> {
@@ -72,19 +53,31 @@ impl<T: Clone> Grid<T> {
.filter(|(x, y)| (x < &self.height) & (y < &self.width)) .filter(|(x, y)| (x < &self.height) & (y < &self.width))
.map(|(x, y)| Pos { x, y }) .map(|(x, y)| Pos { x, y })
} }
fn adj4p<'a>(
&'a self,
pos: &Pos,
pred: impl Fn(&Pos) -> bool + 'a,
) -> impl Iterator<Item = Pos> + 'a {
self.adj4(pos).filter(pred)
}
fn points(&self) -> impl Iterator<Item = Pos> {
iproduct!(0..self.height, 0..self.width).map(|(x, y)| Pos { x, y })
}
} }
impl<T> Index<Pos> for Grid<T> { impl<T> Index<Pos> for Grid<T> {
type Output = T; type Output = T;
fn index(&self, index: Pos) -> &Self::Output { fn index(&self, index: Pos) -> &Self::Output {
&self.buffer[index.x * &self.width + index.y] &self.buffer[index.x * self.width + index.y]
} }
} }
impl<T> IndexMut<Pos> for Grid<T> { impl<T> IndexMut<Pos> for Grid<T> {
fn index_mut(&mut self, index: Pos) -> &mut Self::Output { fn index_mut(&mut self, index: Pos) -> &mut Self::Output {
&mut self.buffer[index.x * &self.width + index.y] &mut self.buffer[index.x * self.width + index.y]
} }
} }
@@ -117,73 +110,46 @@ fn search(grid: &Grid<u8>, seen: &mut Grid<bool>, pos: Pos, steps: usize) -> usi
}; };
seen[pos] = false; seen[pos] = false;
return result; result
} }
fn search_p2(grid: &Grid<u8>, seen: &mut Grid<bool>, pos: Pos, steps: usize) -> usize { type Node = u8;
let destination = Pos { type Graph = Vec<[(Node, u32); 4]>;
fn constructs_graph(grid: &Grid<u8>) -> (Graph, Node, Node) {
let entry = Pos { x: 0, y: 1 };
let exit = Pos {
x: grid.height - 1, x: grid.height - 1,
y: grid.width - 2, y: grid.width - 2,
}; };
if pos == destination { let is_free = |pos: &Pos| grid[*pos] != b'#';
return steps; let is_node = |pos: &Pos| is_free(pos) && grid.adj4p(pos, is_free).count() > 2;
let mut nodes: HashSet<_> = grid.points().filter(is_node).collect();
nodes.insert(entry);
nodes.insert(exit);
let node_ids: HashMap<_, _> = nodes.iter().copied().zip(1..).collect();
let mut flat_graph = vec![[(0, 0), (0, 0), (0, 0), (0, 0)]; nodes.len() + 1];
for node in nodes.iter().copied() {
for (adj_idx, start) in grid.adj4p(&node, is_free).enumerate() {
let mut prev = node;
let mut pos = start;
let mut dist = 1;
while !nodes.contains(&pos) {
(prev, pos) = (pos, grid.adj4p(&pos, is_free).find(|&p| p != prev).unwrap());
dist += 1;
} }
flat_graph[node_ids[&node]][adj_idx] = (node_ids[&pos] as u8, dist as u32);
if seen[pos] {
return 0;
}
seen[pos] = true;
let result = match grid[pos] {
b'#' => 0,
_ => grid
.adj4(&pos)
.map(|pos| search(grid, seen, pos, steps + 1))
.fold(0, std::cmp::max),
};
seen[pos] = false;
return result;
}
type Node = u8;
type Graph = Vec<Vec<(Node, u32)>>;
fn find_longest_path(
graph: &Graph,
mut seen: u64,
source: Node,
dest: Node,
distance: usize,
) -> usize {
if source == dest {
return distance;
}
seen |= 1 << source;
let mut result = 0;
for (next, dist) in graph[source as usize].iter().copied() {
if seen & (1 << next) == 0 {
result = result.max(find_longest_path(
graph,
seen,
next,
dest,
distance + dist as usize,
));
} }
} }
return result; (flat_graph, node_ids[&entry] as u8, node_ids[&exit] as u8)
} }
fn find_longest_path_no_rec( fn find_longest_path(graph: &Graph, source: Node, dest: Node, forbidden: u64) -> u32 {
graph: &Graph, let mut stack = Vec::from([(source, forbidden, 0)]);
source: Node,
dest: Node
) -> usize {
let mut stack = Vec::from([(source, 0u64, 0)]);
stack.reserve(63); stack.reserve(63);
let mut result = 0; let mut result = 0;
while let Some((source, mut seen, distance)) = stack.pop() { while let Some((source, mut seen, distance)) = stack.pop() {
@@ -191,10 +157,13 @@ fn find_longest_path_no_rec(
result = result.max(distance); result = result.max(distance);
} else { } else {
seen |= 1 << source; seen |= 1 << source;
let v = unsafe { graph.get_unchecked(source as usize) }; for (next, dist) in graph[source as usize].iter().copied() {
for (next, dist) in v.iter().copied() { if next == 0 {
break;
}
if seen & (1 << next) == 0 { if seen & (1 << next) == 0 {
stack.push((next, seen, distance + dist as usize)); stack.push((next, seen, distance + dist));
} }
} }
} }
@@ -203,48 +172,34 @@ fn find_longest_path_no_rec(
result result
} }
fn construct_graph(grid: &Grid<u8>) -> (Graph, Node, Node) { fn main() -> Result<()> {
let mut seen = Grid::new(grid.height, grid.width, false); let filename = std::env::args()
let mut todo = Vec::new(); .nth(1)
.context("./day23 <path to puzzle input>")?;
let input = std::fs::read_to_string(filename)?;
let width = input.lines().next().unwrap().len();
let buffer: Vec<_> = input.lines().flat_map(|line| line.bytes()).collect();
let height = buffer.len() / width;
let grid = Grid::from_buffer(buffer, width);
let mut seen = Grid::new(height, width, false);
let part1 = search(&grid, &mut seen, Pos { x: 0, y: 1 }, 0);
let (graph, entry, exit) = constructs_graph(&grid);
// depth-bounded bfs tp determine starting locations
let collect_starts = |source: u8, depth: usize| {
let mut todo = Vec::from([(source, 1u64 << source, 0)]);
let mut next = Vec::new(); let mut next = Vec::new();
let mut graph = HashMap::new(); for _ in 0..depth {
for &(pos, seen, dist) in todo.iter() {
let destination = Pos { for &(next_pos, d) in graph[pos as usize].iter() {
x: grid.height - 1, if next_pos == 0 {
y: grid.width - 2, break;
};
let initial = Pos { x: 0, y: 1 };
todo.push((initial, initial, initial, 0));
while !todo.is_empty() {
for (pos, prev, mut last_node, mut dist) in todo.iter().copied() {
let is_node = 2 < grid.adj4(&pos).filter(|p| grid[*p] != b'#').count();
if is_node || pos == destination {
graph
.entry(last_node)
.or_insert(Vec::new())
.push((pos, dist));
graph
.entry(pos)
.or_insert(Vec::new())
.push((last_node, dist));
last_node = pos;
dist = 0;
if !seen[pos] {
seen[pos] = true;
for neigh in grid.adj4(&pos) {
if neigh != prev && grid[neigh] != b'#' {
next.push((neigh, pos, last_node, dist + 1));
}
}
}
} else {
for neigh in grid.adj4(&pos) {
if neigh != prev && grid[neigh] != b'#' {
next.push((neigh, pos, last_node, dist + 1));
} }
if seen & (1 << next_pos) == 0 {
next.push((next_pos, seen | 1 << next_pos, dist + d));
} }
} }
} }
@@ -253,67 +208,29 @@ fn construct_graph(grid: &Grid<u8>) -> (Graph, Node, Node) {
todo.append(&mut next); todo.append(&mut next);
} }
for v in graph.values_mut() { todo
v.sort(); };
let mut nv = Vec::new();
let mut last = v[0]; let starts = collect_starts(entry, 4);
for &v in v.iter().skip(1) { let penultimate = graph[exit as usize][0];
if v.0 != last.0 {
nv.push(last); let part2 = thread::scope(|scope| {
} let mut handles = Vec::new();
last = v;
} for &(start, forbidden, distance) in starts.iter() {
nv.push(last); let graph = &graph;
v.clear(); handles.push(scope.spawn(move || {
v.append(&mut nv); find_longest_path(graph, start, penultimate.0, forbidden) + distance + penultimate.1
}));
} }
let node_idx: HashMap<Pos, u8> = graph.keys().copied().zip(0..).collect(); handles
let start = node_idx[&initial]; .into_iter()
let destination = node_idx[&destination]; .fold(0, |acc, handle| acc.max(handle.join().unwrap()))
let mut graph_vec = vec![Vec::new(); node_idx.len()];
graph
.iter()
.map(|(k, v)| {
(
node_idx[k],
v.iter()
.map(|(n, dist)| (node_idx[n], *dist))
.collect::<Vec<(Node, u32)>>(),
)
})
.for_each(|(node, adj)| {
graph_vec[node as usize] = adj;
}); });
// let m = graph_vec.iter().map(|v| v.len()).reduce(std::cmp::max).unwrap(); println!("1) {}", part1);
println!("2) {}", part2);
// let mut gridify = vec![(63, 0); 63 * graph_vec.len()]; Ok(())
// for (id, adjs) in graph_vec.iter().enumerate() {
// for (aid, adj) in adjs.iter().enumerate() {
// gridify[id * m + aid] = *adj;
// }
// }
// println!("{}", graph_vec.len());
(graph_vec, start, destination)
}
fn main() {
let filename = std::env::args().nth(1).unwrap();
let input = std::fs::read_to_string(filename).unwrap();
let width = input.lines().next().unwrap().len();
let buffer: Vec<_> = input.lines().flat_map(|line| line.bytes()).collect();
let height = buffer.len() / width;
let grid = Grid::from_buffer(buffer, width);
let mut seen = Grid::new(height, width, false);
let longest_path = search(&grid, &mut seen, Pos { x: 0, y: 1 }, 0);
println!("1) {}", longest_path);
let (graph, start, destination) = construct_graph(&grid);
//let part2 = find_longest_path(&graph, 0, start, destination, 0);
let part2 = find_longest_path_no_rec(&graph, start, destination);
println!("{}", part2);
} }