Improve Day 23 perf

This commit is contained in:
Christian
2023-12-28 13:57:09 +01:00
parent ce80483f96
commit 9c4e8e8774

View File

@@ -1,5 +1,8 @@
use hashbrown::HashMap; use anyhow::{Context, Result};
use hashbrown::{HashMap, HashSet};
use itertools::iproduct;
use std::ops::{Index, IndexMut}; use std::ops::{Index, IndexMut};
use std::thread;
#[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash, Debug)]
struct Pos { struct Pos {
@@ -7,34 +10,12 @@ struct Pos {
y: usize, y: usize,
} }
#[rustfmt::skip]
impl Pos { impl Pos {
fn north(&self) -> Self { fn north(&self) -> Self { Pos { x: self.x - 1, y: self.y } }
Pos { fn south(&self) -> Self { Pos { x: self.x + 1, y: self.y } }
x: self.x - 1, fn west(&self) -> Self { Pos { x: self.x, y: self.y - 1 } }
y: self.y, fn east(&self) -> Self { Pos { x: self.x, y: self.y + 1 } }
}
}
fn south(&self) -> Self {
Pos {
x: self.x + 1,
y: self.y,
}
}
fn west(&self) -> Self {
Pos {
x: self.x,
y: self.y - 1,
}
}
fn east(&self) -> Self {
Pos {
x: self.x,
y: self.y + 1,
}
}
} }
struct Grid<T> { struct Grid<T> {
@@ -72,19 +53,31 @@ impl<T: Clone> Grid<T> {
.filter(|(x, y)| (x < &self.height) & (y < &self.width)) .filter(|(x, y)| (x < &self.height) & (y < &self.width))
.map(|(x, y)| Pos { x, y }) .map(|(x, y)| Pos { x, y })
} }
fn adj4p<'a>(
&'a self,
pos: &Pos,
pred: impl Fn(&Pos) -> bool + 'a,
) -> impl Iterator<Item = Pos> + 'a {
self.adj4(pos).filter(pred)
}
fn points(&self) -> impl Iterator<Item = Pos> {
iproduct!(0..self.height, 0..self.width).map(|(x, y)| Pos { x, y })
}
} }
impl<T> Index<Pos> for Grid<T> { impl<T> Index<Pos> for Grid<T> {
type Output = T; type Output = T;
fn index(&self, index: Pos) -> &Self::Output { fn index(&self, index: Pos) -> &Self::Output {
&self.buffer[index.x * &self.width + index.y] &self.buffer[index.x * self.width + index.y]
} }
} }
impl<T> IndexMut<Pos> for Grid<T> { impl<T> IndexMut<Pos> for Grid<T> {
fn index_mut(&mut self, index: Pos) -> &mut Self::Output { fn index_mut(&mut self, index: Pos) -> &mut Self::Output {
&mut self.buffer[index.x * &self.width + index.y] &mut self.buffer[index.x * self.width + index.y]
} }
} }
@@ -117,203 +110,127 @@ fn search(grid: &Grid<u8>, seen: &mut Grid<bool>, pos: Pos, steps: usize) -> usi
}; };
seen[pos] = false; seen[pos] = false;
return result; result
}
fn search_p2(grid: &Grid<u8>, seen: &mut Grid<bool>, pos: Pos, steps: usize) -> usize {
let destination = Pos {
x: grid.height - 1,
y: grid.width - 2,
};
if pos == destination {
return steps;
}
if seen[pos] {
return 0;
}
seen[pos] = true;
let result = match grid[pos] {
b'#' => 0,
_ => grid
.adj4(&pos)
.map(|pos| search(grid, seen, pos, steps + 1))
.fold(0, std::cmp::max),
};
seen[pos] = false;
return result;
} }
type Node = u8; type Node = u8;
type Graph = Vec<Vec<(Node, u32)>>; type Graph = Vec<[(Node, u32); 4]>;
fn find_longest_path( fn constructs_graph(grid: &Grid<u8>) -> (Graph, Node, Node) {
graph: &Graph, let entry = Pos { x: 0, y: 1 };
mut seen: u64, let exit = Pos {
source: Node,
dest: Node,
distance: usize,
) -> usize {
if source == dest {
return distance;
}
seen |= 1 << source;
let mut result = 0;
for (next, dist) in graph[source as usize].iter().copied() {
if seen & (1 << next) == 0 {
result = result.max(find_longest_path(
graph,
seen,
next,
dest,
distance + dist as usize,
));
}
}
return result;
}
fn find_longest_path_no_rec(
graph: &Graph,
source: Node,
dest: Node
) -> usize {
let mut stack = Vec::from([(source, 0u64, 0)]);
stack.reserve(63);
let mut result = 0;
while let Some((source, mut seen, distance)) = stack.pop() {
if source == dest {
result = result.max(distance);
} else {
seen |= 1 << source;
let v = unsafe { graph.get_unchecked(source as usize) };
for (next, dist) in v.iter().copied() {
if seen & (1 << next) == 0 {
stack.push((next, seen, distance + dist as usize));
}
}
}
}
result
}
fn construct_graph(grid: &Grid<u8>) -> (Graph, Node, Node) {
let mut seen = Grid::new(grid.height, grid.width, false);
let mut todo = Vec::new();
let mut next = Vec::new();
let mut graph = HashMap::new();
let destination = Pos {
x: grid.height - 1, x: grid.height - 1,
y: grid.width - 2, y: grid.width - 2,
}; };
let initial = Pos { x: 0, y: 1 }; let is_free = |pos: &Pos| grid[*pos] != b'#';
todo.push((initial, initial, initial, 0)); let is_node = |pos: &Pos| is_free(pos) && grid.adj4p(pos, is_free).count() > 2;
while !todo.is_empty() { let mut nodes: HashSet<_> = grid.points().filter(is_node).collect();
for (pos, prev, mut last_node, mut dist) in todo.iter().copied() { nodes.insert(entry);
let is_node = 2 < grid.adj4(&pos).filter(|p| grid[*p] != b'#').count(); nodes.insert(exit);
let node_ids: HashMap<_, _> = nodes.iter().copied().zip(1..).collect();
if is_node || pos == destination { let mut flat_graph = vec![[(0, 0), (0, 0), (0, 0), (0, 0)]; nodes.len() + 1];
graph
.entry(last_node)
.or_insert(Vec::new())
.push((pos, dist));
graph
.entry(pos)
.or_insert(Vec::new())
.push((last_node, dist));
last_node = pos;
dist = 0;
if !seen[pos] { for node in nodes.iter().copied() {
seen[pos] = true; for (adj_idx, start) in grid.adj4p(&node, is_free).enumerate() {
for neigh in grid.adj4(&pos) { let mut prev = node;
if neigh != prev && grid[neigh] != b'#' { let mut pos = start;
next.push((neigh, pos, last_node, dist + 1)); let mut dist = 1;
} while !nodes.contains(&pos) {
} (prev, pos) = (pos, grid.adj4p(&pos, is_free).find(|&p| p != prev).unwrap());
} dist += 1;
} else {
for neigh in grid.adj4(&pos) {
if neigh != prev && grid[neigh] != b'#' {
next.push((neigh, pos, last_node, dist + 1));
}
}
} }
flat_graph[node_ids[&node]][adj_idx] = (node_ids[&pos] as u8, dist as u32);
} }
todo.clear();
todo.append(&mut next);
} }
for v in graph.values_mut() { (flat_graph, node_ids[&entry] as u8, node_ids[&exit] as u8)
v.sort();
let mut nv = Vec::new();
let mut last = v[0];
for &v in v.iter().skip(1) {
if v.0 != last.0 {
nv.push(last);
}
last = v;
}
nv.push(last);
v.clear();
v.append(&mut nv);
}
let node_idx: HashMap<Pos, u8> = graph.keys().copied().zip(0..).collect();
let start = node_idx[&initial];
let destination = node_idx[&destination];
let mut graph_vec = vec![Vec::new(); node_idx.len()];
graph
.iter()
.map(|(k, v)| {
(
node_idx[k],
v.iter()
.map(|(n, dist)| (node_idx[n], *dist))
.collect::<Vec<(Node, u32)>>(),
)
})
.for_each(|(node, adj)| {
graph_vec[node as usize] = adj;
});
// let m = graph_vec.iter().map(|v| v.len()).reduce(std::cmp::max).unwrap();
// let mut gridify = vec![(63, 0); 63 * graph_vec.len()];
// for (id, adjs) in graph_vec.iter().enumerate() {
// for (aid, adj) in adjs.iter().enumerate() {
// gridify[id * m + aid] = *adj;
// }
// }
// println!("{}", graph_vec.len());
(graph_vec, start, destination)
} }
fn main() { fn find_longest_path(graph: &Graph, source: Node, dest: Node, forbidden: u64) -> u32 {
let filename = std::env::args().nth(1).unwrap(); let mut stack = Vec::from([(source, forbidden, 0)]);
let input = std::fs::read_to_string(filename).unwrap(); stack.reserve(63);
let mut result = 0;
while let Some((source, mut seen, distance)) = stack.pop() {
if source == dest {
result = result.max(distance);
} else {
seen |= 1 << source;
for (next, dist) in graph[source as usize].iter().copied() {
if next == 0 {
break;
}
if seen & (1 << next) == 0 {
stack.push((next, seen, distance + dist));
}
}
}
}
result
}
fn main() -> Result<()> {
let filename = std::env::args()
.nth(1)
.context("./day23 <path to puzzle input>")?;
let input = std::fs::read_to_string(filename)?;
let width = input.lines().next().unwrap().len(); let width = input.lines().next().unwrap().len();
let buffer: Vec<_> = input.lines().flat_map(|line| line.bytes()).collect(); let buffer: Vec<_> = input.lines().flat_map(|line| line.bytes()).collect();
let height = buffer.len() / width; let height = buffer.len() / width;
let grid = Grid::from_buffer(buffer, width); let grid = Grid::from_buffer(buffer, width);
let mut seen = Grid::new(height, width, false); let mut seen = Grid::new(height, width, false);
let longest_path = search(&grid, &mut seen, Pos { x: 0, y: 1 }, 0); let part1 = search(&grid, &mut seen, Pos { x: 0, y: 1 }, 0);
println!("1) {}", longest_path);
let (graph, start, destination) = construct_graph(&grid); let (graph, entry, exit) = constructs_graph(&grid);
//let part2 = find_longest_path(&graph, 0, start, destination, 0);
let part2 = find_longest_path_no_rec(&graph, start, destination); // depth-bounded bfs tp determine starting locations
println!("{}", part2); let collect_starts = |source: u8, depth: usize| {
let mut todo = Vec::from([(source, 1u64 << source, 0)]);
let mut next = Vec::new();
for _ in 0..depth {
for &(pos, seen, dist) in todo.iter() {
for &(next_pos, d) in graph[pos as usize].iter() {
if next_pos == 0 {
break;
}
if seen & (1 << next_pos) == 0 {
next.push((next_pos, seen | 1 << next_pos, dist + d));
}
}
}
todo.clear();
todo.append(&mut next);
}
todo
};
let starts = collect_starts(entry, 4);
let penultimate = graph[exit as usize][0];
let part2 = thread::scope(|scope| {
let mut handles = Vec::new();
for &(start, forbidden, distance) in starts.iter() {
let graph = &graph;
handles.push(scope.spawn(move || {
find_longest_path(graph, start, penultimate.0, forbidden) + distance + penultimate.1
}));
}
handles
.into_iter()
.fold(0, |acc, handle| acc.max(handle.join().unwrap()))
});
println!("1) {}", part1);
println!("2) {}", part2);
Ok(())
} }