Optimize iterator chains better
The iterator adaptor .chain() and similar code uses a state flag that changes mid-loop; the compiler should recognize this and try to split it into two consecutive loops if possible.
Reproduction in code (playground)
```rust
pub fn visit_both(a: &[u32], b: &[u32], f: fn(u32)) {
    for &elt in chain(a, b) {
        f(elt)
    }
}

/// this is a simplified version of std::iter::Chain
pub struct Chain<I> { front: bool, a: I, b: I }

fn chain<I>(a: I, b: I) -> Chain<I::IntoIter> where I: IntoIterator {
    Chain { front: true, a: a.into_iter(), b: b.into_iter() }
}

impl<I> Iterator for Chain<I> where I: Iterator {
    type Item = I::Item;
    fn next(&mut self) -> Option<Self::Item> {
        if self.front {
            if let elt @ Some(_) = self.a.next() {
                return elt;
            }
            self.front = false;
        }
        self.b.next()
    }
}
```
Actual Behaviour
There is only one call instruction in visit_both's generated code, indicating the loop was not split.
Expected Behaviour
Generated code should be more similar to two consecutive loops.
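To make the expected behaviour concrete, here is a minimal sketch of the two shapes being compared. The names visit_chained and visit_split are hypothetical, and both take a generic FnMut(u32) closure instead of fn(u32) so the visited elements can be checked. visit_chained uses the standard .chain(); visit_split is the hand-split form the optimizer is hoped to arrive at.

```rust
// Hypothetical helper: one loop over a chained iterator, whose `next()`
// has to check which half it is in on every step.
pub fn visit_chained<F: FnMut(u32)>(a: &[u32], b: &[u32], mut f: F) {
    for &elt in a.iter().chain(b.iter()) {
        f(elt)
    }
}

// Hypothetical helper: two consecutive loops, so no per-element state check.
pub fn visit_split<F: FnMut(u32)>(a: &[u32], b: &[u32], mut f: F) {
    for &elt in a {
        f(elt)
    }
    for &elt in b {
        f(elt)
    }
}

fn main() {
    let (a, b) = ([1u32, 2, 3], [4u32, 5]);
    let mut chained = Vec::new();
    let mut split = Vec::new();
    visit_chained(&a, &b, |x| chained.push(x));
    visit_split(&a, &b, |x| split.push(x));
    // Both visit the same elements in the same order.
    assert_eq!(chained, split);
    println!("{:?}", split); // [1, 2, 3, 4, 5]
}
```

Semantically the two are identical; the issue is only about whether the compiler can turn the first shape into the second.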
Hi, I'm new to rustc and would like to help! This issue seems like a nice one to start with; do you mind if I take it?
Liang Ying-Ruei at 2016-11-29 17:26:59
@TheKK I'm sure you can take this. This is not my area, but there will be others that can help you along the way.
bluss at 2016-11-29 20:01:21
cc @pcwalton
Unsure how/if this can be approached.
bluss at 2016-11-29 21:54:45
It can be done by turning the conditions inside out: the first condition must check a.next() and return, and only otherwise do something more complicated.
```rust
/// this is a simplified version of std::iter::Chain
pub struct Chain2<I> {
    a: I,
    b: Option<I>,
}

fn chain2<I>(a: I, b: I) -> Chain2<I> where I: Iterator {
    Chain2 { a: a, b: Some(b) }
}

impl<I> Iterator for Chain2<I> where I: Iterator {
    type Item = I::Item;
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(item) = self.a.next() {
            return Some(item);
        }
        if let Some(b) = self.b.take() {
            self.a = b;
        }
        self.a.next()
    }
}

fn main() {
    let v0 = vec![0, 1];
    let v1 = vec![0, 1, 2];
    let mut v3 = vec![];
    for &x in chain2(v0.iter(), v1.iter()) {
        v3.push(x);
    }
    assert_eq!(v3, [0, 1, 0, 1, 2]);
}
```
basyg at 2017-10-06 19:08:25
Benchmark:
```
test chained  ... bench:   3,138 ns/iter (+/- 209)
test chained2 ... bench:     696 ns/iter (+/- 7)
test unrolled ... bench:     733 ns/iter (+/- 27)
```
```rust
#![feature(test)]
extern crate test;

/// this is a simplified version of std::iter::Chain
pub struct Chain2<I> {
    a: I,
    b: Option<I>,
}

fn chain2<I>(a: I, b: I) -> Chain2<I> where I: Iterator {
    Chain2 { a: a, b: Some(b) }
}

impl<I> Iterator for Chain2<I> where I: Iterator {
    type Item = I::Item;
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(item) = self.a.next() {
            return Some(item);
        }
        if let Some(b) = self.b.take() {
            self.a = b;
        }
        self.a.next()
    }
}

fn init() -> [u8; 1024] {
    let mut x: [u8; 1024] = unsafe { std::mem::zeroed() };
    for i in 0..x.len() {
        x[i] = i as u8;
    }
    x
}

#[bench]
fn unrolled(b: &mut test::Bencher) {
    let x = init();
    let y = init();
    b.iter(|| {
        for &v in x.iter() {
            test::black_box(v);
        }
        for &v in y.iter() {
            test::black_box(v);
        }
    });
}

#[bench]
fn chained(b: &mut test::Bencher) {
    let x = init();
    let y = init();
    b.iter(|| {
        for &v in x.iter().chain(y.iter()) {
            test::black_box(v);
        }
    });
}

#[bench]
fn chained2(b: &mut test::Bencher) {
    let x = init();
    let y = init();
    b.iter(|| {
        for &v in chain2(x.iter(), y.iter()) {
            test::black_box(v);
        }
    });
}
```
basyg at 2017-10-06 19:09:25
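The benchmark above needs nightly (#![feature(test)]). As a rough stable-Rust alternative, assuming Rust 1.66 or later for std::hint::black_box, a hand-rolled timing loop built on std::time::Instant can compare the unrolled and chained loops. The helper time_it is hypothetical, and unlike the libtest harness it does no warm-up and no variance estimate, so its numbers are only indicative.

```rust
use std::hint::black_box;
use std::time::Instant;

// Same input as the benchmark's `init()`: bytes 0..=255 repeated four times.
fn init() -> [u8; 1024] {
    let mut x = [0u8; 1024];
    for (i, v) in x.iter_mut().enumerate() {
        *v = i as u8;
    }
    x
}

// Hypothetical helper: runs `body` `iters` times, returns average ns per run.
fn time_it<F: FnMut()>(iters: u32, mut body: F) -> f64 {
    let start = Instant::now();
    for _ in 0..iters {
        body();
    }
    start.elapsed().as_nanos() as f64 / iters as f64
}

fn main() {
    let (x, y) = (init(), init());
    let iters = 10_000;
    // Two consecutive loops, written by hand.
    let unrolled = time_it(iters, || {
        for &v in x.iter() {
            black_box(v);
        }
        for &v in y.iter() {
            black_box(v);
        }
    });
    // One loop over std's chained iterator.
    let chained = time_it(iters, || {
        for &v in x.iter().chain(y.iter()) {
            black_box(v);
        }
    });
    println!("unrolled: {:.0} ns/iter", unrolled);
    println!("chained:  {:.0} ns/iter", chained);
}
```

Build with --release; in a debug build neither loop is optimized and the comparison is meaningless.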
@basyg Nice!! Would it have to have an inferior DoubleEndedIterator implementation?
I can see that this seems to work -- but without the same nice performance.
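```rust
impl<I> DoubleEndedIterator for Chain2<I> where I: DoubleEndedIterator {
    fn next_back(&mut self) -> Option<Self::Item> {
        // While `b` is still separate, take elements from its back first.
        if let Some(ref mut b) = self.b {
            if let elt @ Some(_) = b.next_back() {
                return elt;
            }
        }
        // `b` is exhausted (or was already moved into `a`): continue with `a`.
        self.b.take();
        self.a.next_back()
    }
}
```
bluss at 2017-10-06 20:10:30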
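A self-contained sketch combining Chain2's next with a next_back that drains b from the back (via next_back) before falling back to a, plus checks that forward and reversed iteration yield the expected order. It is assembled from the snippets in this thread for illustration only; its performance is not measured here.

```rust
/// Simplified chain: `b` is moved into `a` once `a` runs out.
pub struct Chain2<I> {
    a: I,
    b: Option<I>,
}

fn chain2<I: Iterator>(a: I, b: I) -> Chain2<I> {
    Chain2 { a, b: Some(b) }
}

impl<I: Iterator> Iterator for Chain2<I> {
    type Item = I::Item;
    fn next(&mut self) -> Option<I::Item> {
        if let Some(item) = self.a.next() {
            return Some(item);
        }
        // `a` is exhausted: swap `b` into its place (this happens only once).
        match self.b.take() {
            Some(b) => {
                self.a = b;
                self.a.next()
            }
            None => None,
        }
    }
}

impl<I: DoubleEndedIterator> DoubleEndedIterator for Chain2<I> {
    fn next_back(&mut self) -> Option<I::Item> {
        // Drain `b` from the back first, as long as it has not been moved into `a`.
        if let Some(b) = self.b.as_mut() {
            if let item @ Some(_) = b.next_back() {
                return item;
            }
            self.b = None;
        }
        self.a.next_back()
    }
}

fn main() {
    let (v0, v1) = ([0, 1], [0, 1, 2]);
    let fwd: Vec<i32> = chain2(v0.iter(), v1.iter()).cloned().collect();
    let rev: Vec<i32> = chain2(v0.iter(), v1.iter()).rev().cloned().collect();
    assert_eq!(fwd, [0, 1, 0, 1, 2]);
    assert_eq!(rev, [2, 1, 0, 1, 0]);
}
```

Because next moves b into a, mixing next and next_back on one iterator still works: whatever next_back left of b is what next continues with.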
I think this fix is needed to behave correctly w.r.t. fusing (not calling .next() on a subiterator that has already returned None) -- but it still compiles well:
```rust
fn next(&mut self) -> Option<Self::Item> {
    if let Some(item) = self.a.next() {
        return Some(item);
    }
    if let Some(b) = self.b.take() {
        self.a = b;
        self.a.next()
    } else {
        None
    }
}
```
bluss at 2017-10-06 20:18:27
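To see the fusing difference concretely, the sketch below wraps each half in a hypothetical Probe iterator that counts next() calls made after it has already returned None, then collects with both the originally posted next (Chain2) and the fixed one (Chain2Fused, a name introduced here). With these inputs, the original polls the second iterator once more after its None, while the fixed version does not. The first iterator is dropped when b is swapped in, so neither version polls it again.

```rust
use std::cell::Cell;

// Hypothetical probe: yields `items`, then counts every `next()` call
// made after it has already returned `None`.
struct Probe<'a> {
    items: &'a [u32],
    pos: usize,
    done: bool,
    polls_after_none: &'a Cell<u32>,
}

impl<'a> Iterator for Probe<'a> {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        if self.done {
            self.polls_after_none.set(self.polls_after_none.get() + 1);
            return None;
        }
        let item = self.items.get(self.pos).copied();
        self.pos += 1;
        self.done = item.is_none();
        item
    }
}

// `next` as first posted: falls through to `self.a.next()` even when `b` is gone.
struct Chain2<I> { a: I, b: Option<I> }
impl<I: Iterator> Iterator for Chain2<I> {
    type Item = I::Item;
    fn next(&mut self) -> Option<I::Item> {
        if let Some(item) = self.a.next() { return Some(item); }
        if let Some(b) = self.b.take() { self.a = b; }
        self.a.next()
    }
}

// The fixed `next`: polls again only right after swapping `b` in.
struct Chain2Fused<I> { a: I, b: Option<I> }
impl<I: Iterator> Iterator for Chain2Fused<I> {
    type Item = I::Item;
    fn next(&mut self) -> Option<I::Item> {
        if let Some(item) = self.a.next() { return Some(item); }
        if let Some(b) = self.b.take() {
            self.a = b;
            self.a.next()
        } else {
            None
        }
    }
}

// Hypothetical helper: collects the chain of [1, 2] and [3]; returns the items
// and the polls-after-None counts of (a, b).
fn run(fused: bool) -> (Vec<u32>, u32, u32) {
    let (a_polls, b_polls) = (Cell::new(0), Cell::new(0));
    let a = Probe { items: &[1, 2], pos: 0, done: false, polls_after_none: &a_polls };
    let b = Probe { items: &[3], pos: 0, done: false, polls_after_none: &b_polls };
    let out: Vec<u32> = if fused {
        Chain2Fused { a, b: Some(b) }.collect()
    } else {
        Chain2 { a, b: Some(b) }.collect()
    };
    (out, a_polls.get(), b_polls.get())
}

fn main() {
    assert_eq!(run(false), (vec![1, 2, 3], 0, 1));
    assert_eq!(run(true), (vec![1, 2, 3], 0, 0));
}
```

Neither variant is fully fused at the Chain2 level: calling next() again after the chain has returned None still polls the already exhausted second iterator.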
@bluss I couldn't get good performance for both iterator traits (Iterator and DoubleEndedIterator) at once. Changing the code even slightly brings back the poor performance. I tried writing analogous code in C++ (gcc and clang) and it's the same there: adding some code (e.g. std::swap() or array indexing) causes a big drop in iteration speed. So it isn't only a Rust compiler issue.
The best universal solution is to patch the compiler.
basyg at 2017-10-07 21:26:34
Here's one variant, it's completely symmetric in a and b and it looks pretty good.
The benchmark results with your benchmark look like this:
```
test chained     ... bench:     599 ns/iter (+/- 12)
test chained_rev ... bench:   1,144 ns/iter (+/- 35)
test unrolled    ... bench:     612 ns/iter (+/- 20)
```
So rev is still slower, for some reason.
Anyway, I switched the benchmark to something even more cut-throat. Unfair! We use a summation instead, so the compiler can really show how many loop optimizations it can realize for a given iteration.
The results are as follows:
```
test chained       ... bench:     885 ns/iter (+/- 6)
test chained2      ... bench:     877 ns/iter (+/- 5)
test chained_rev   ... bench:     750 ns/iter (+/- 21)
test std_chain_sum ... bench:      30 ns/iter (+/- 0)
test unrolled      ... bench:      31 ns/iter (+/- 1)
```
Code: https://gist.github.com/2dfbf131280050cdf76982ea54a363f2
std_chain_sum shows the regular .chain()'s special-case fold implementation paying off. And the summation tests show that even if it manages to make it into two consecutive loops, it can't go on and turn it into the same function as unrolled.
bluss at 2017-10-07 22:03:36
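The symmetric variant mentioned in that comment is only linked, not reproduced in this thread. As a hedged illustration of what a design symmetric in a and b can look like (SymChain and sym_chain are hypothetical names, not the linked code), both halves can be stored as Options and dropped as soon as they return None. next_back is then the mirror image of next, and neither half is polled again after it is exhausted.

```rust
/// Hypothetical symmetric chain: each half is dropped (set to `None`)
/// as soon as it returns `None`.
pub struct SymChain<A, B> {
    a: Option<A>,
    b: Option<B>,
}

pub fn sym_chain<A, B>(a: A, b: B) -> SymChain<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    SymChain { a: Some(a), b: Some(b) }
}

impl<A, B> Iterator for SymChain<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;
    fn next(&mut self) -> Option<A::Item> {
        if let Some(a) = self.a.as_mut() {
            match a.next() {
                None => self.a = None, // front half exhausted: drop it
                item => return item,
            }
        }
        self.b.as_mut()?.next()
    }
}

impl<A, B> DoubleEndedIterator for SymChain<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    // Mirror image of `next`, with the roles of `a` and `b` swapped.
    fn next_back(&mut self) -> Option<A::Item> {
        if let Some(b) = self.b.as_mut() {
            match b.next_back() {
                None => self.b = None, // back half exhausted: drop it
                item => return item,
            }
        }
        self.a.as_mut()?.next_back()
    }
}

fn main() {
    let (x, y) = ([0, 1], [0, 1, 2]);
    let fwd: Vec<i32> = sym_chain(x.iter(), y.iter()).copied().collect();
    let rev: Vec<i32> = sym_chain(x.iter(), y.iter()).rev().copied().collect();
    assert_eq!(fwd, [0, 1, 0, 1, 2]);
    assert_eq!(rev, [2, 1, 0, 1, 0]);
}
```

Whether this layout compiles to two consecutive loops is exactly what the benchmarks in this thread measure; the sketch makes no performance claim.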
Here's one variant, it's completely symmetric in a and b and it looks pretty good.
Yeah! It's very nice. Something better without looping can't be expected.
And the summation tests show that even if it manages to make it into two consecutive loops, it can't go on and turn it into the same function as unrolled.
What do you mean? The results of std_chain_sum and unrolled are the same; std_chain_sum was completely optimized.
basyg at 2017-10-07 23:16:35
std_chain_sum goes through a different path, fold, which doesn't pass through Iterator::next for std::iter::Chain. So it just shows a completely different code path for computing the same result.
bluss at 2017-10-08 00:48:41
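The point about fold can be illustrated with a sketch: a hypothetical Chain3, a copy of the issue's flag-based Chain that additionally overrides Iterator::fold. External iteration (a for loop calling next) still checks the front flag on every element, while fold runs a.fold and then b.fold, which amounts to two consecutive loops with the flag read only once. This shows the idea; it is not a copy of std's actual Chain::fold.

```rust
/// Hypothetical copy of the issue's simplified Chain, plus a `fold` override.
pub struct Chain3<I> {
    front: bool,
    a: I,
    b: I,
}

impl<I: Iterator> Iterator for Chain3<I> {
    type Item = I::Item;

    // External iteration: the same flag-based logic as the issue's Chain.
    fn next(&mut self) -> Option<I::Item> {
        if self.front {
            if let elt @ Some(_) = self.a.next() {
                return elt;
            }
            self.front = false;
        }
        self.b.next()
    }

    // Internal iteration: fold `a`, then fold `b`, i.e. two consecutive loops.
    fn fold<Acc, F>(self, init: Acc, mut f: F) -> Acc
    where
        F: FnMut(Acc, I::Item) -> Acc,
    {
        let acc = if self.front { self.a.fold(init, &mut f) } else { init };
        self.b.fold(acc, f)
    }
}

fn main() {
    let (x, y) = ([1u32, 2, 3], [10u32, 20]);
    // Goes through the overridden `fold`.
    let via_fold = Chain3 { front: true, a: x.iter(), b: y.iter() }
        .fold(0u32, |acc, &v| acc + v);
    // Goes through `next` and its per-element flag check.
    let mut via_next = 0u32;
    let c = Chain3 { front: true, a: x.iter(), b: y.iter() };
    for &v in c {
        via_next += v;
    }
    assert_eq!(via_fold, 36);
    assert_eq!(via_fold, via_next);
}
```

Both paths compute the same result, which is why a fold-based benchmark can look as fast as unrolled while saying nothing about how well the next-based loop is optimized.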