Simon Fell > Its just code > It Depends, episode 2

Saturday, January 30, 2021

We're continuing with the array search covered in episode 1. I'd recommend reading that before this one.

Let's start with an easy one. One implementation we didn't cover in episode 1 is to use bytes.IndexByte. IndexByte does end up calling a hand-crafted assembly version, so there's some hope this might be decent.

func (n *nodeLoop) getIndexByte(k byte) interface{} {
	idx := bytes.IndexByte(n.key[:n.count], k)
	if idx < 0 {
		return nil
	}
	return n.val[idx]
}
Benchmark_IndexByte-8                           	11620299	       102 ns/op
Benchmark_Loop-8                                	11181962	       102 ns/op
Benchmark_LoopOneBoundsCheck-8                  	13485289	        87.1 ns/op
Benchmark_LoopRev-8                             	10736464	       111 ns/op
Benchmark_LoopRevOneBoundsCheck-8               	14661828	        81.0 ns/op
Benchmark_BinarySearch-8                        	 4895956	       246 ns/op
Benchmark_BinarySearchOneBoundsCheck-8          	 4806734	       249 ns/op
Benchmark_BinarySearchInlined-8                 	10140198	       116 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-8   	11359298	       106 ns/op
PASS

That's disappointing. I generated a CPU profile, and it seems like the overhead of having to deal with an arbitrarily sized slice outweighs any gains from the more efficient comparison eventually done. A variation that always passes all 16 bytes to bytes.IndexByte and checks the result against n.count performs almost the same. We'll skip the steps for profiling and disassembly; check out episode 1 if you need the details on how to do that.
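
For reference, that fixed-size variation looks roughly like this (a sketch of the idea, not necessarily the exact code from the repo):

func (n *nodeLoop) getIndexByteFull(k byte) interface{} {
	// always hand IndexByte the full 16 bytes, then discard any match
	// that lands in an unused key slot.
	idx := bytes.IndexByte(n.key[:], k)
	if idx < 0 || idx >= n.count {
		return nil
	}
	return n.val[idx]
}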

Loop unrolling is a common optimization. From the assembly we looked at previously we know the compiler isn't doing this, but we can write a manually unrolled version. Let's see how this does.

func (n *nodeLoop) getUnrolledLoop(k byte) interface{} {
	switch n.count {
	case 16:
		if n.key[15] == k {
			return n.val[15]
		}
		fallthrough
	case 15:
		if n.key[14] == k {
			return n.val[14]
		}
		fallthrough
	case 14:
		if n.key[13] == k {
			return n.val[13]
		}
		fallthrough
	case 13:
		if n.key[12] == k {
			return n.val[12]
		}
		fallthrough
	...
Benchmark_IndexByte-8                           	11525317	       104 ns/op
Benchmark_UnrolledLoop-8                        	13063185	        93.0 ns/op
Benchmark_Loop-8                                	11788951	       101 ns/op
Benchmark_LoopOneBoundsCheck-8                  	13612688	        87.5 ns/op
Benchmark_LoopRev-8                             	10756114	       112 ns/op
Benchmark_LoopRevOneBoundsCheck-8               	14641873	        82.3 ns/op
Benchmark_BinarySearch-8                        	 4833198	       248 ns/op
Benchmark_BinarySearchOneBoundsCheck-8          	 4789767	       250 ns/op
Benchmark_BinarySearchInlined-8                 	10150819	       118 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-8   	11339012	       107 ns/op

There are still a lot of branches in the unrolled version, so I'm not totally surprised by the outcome. Reviewing the assembly for this one, Go doesn't use a jump table for the switch statement, which means it's doing more comparisons than expected as well.

One thing that stands out from the loop versions' assembly is that we have a 64-bit CPU which we're then forcing to do things one byte at a time. Can we work 64 bits at a time instead, without resorting to hand-crafted assembly? What if we pack 8 keys into a uint64 and use bitwise operations to compare against the target key? Worst case we'd have to do this on 2 uint64s instead of 16 bytes. It's doable, but at the cost of being significantly harder to understand than the loop version.

Starting with put we need to work out which uint64 to update, and then shift the key value into the relevant byte of that uint64.

type nodeMasks struct {
	keys1 uint64
	keys2 uint64
	vals  [16]interface{}
	count int
}
func (n *nodeMasks) put(k byte, v interface{}) {
	m := &n.keys1
	c := n.count
	if n.count >= 8 {
		m = &n.keys2
		c = c - 8
	}
	*m = *m | (uint64(k) << (c * 8))
	n.vals[n.count] = v
	n.count++
}

get is more work. First off we need a uint64 that has the value k in each byte. This can be constructed with a bunch of shifts, but that turns out to be a significant percentage of the overall work. There are only 256 possible values for it though, so we can calculate them all once at startup and grab the right one when needed. The lookup table of masks takes 2KB (256 * 8 bytes), a more than reasonable tradeoff.
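
The get code below also relies on an active table and the byte masks b1 through b8, which aren't shown in the snippets here. Based on how get uses them, they could be built something like this (a sketch; the repo has the real definitions):

// b1..b8 each select a single byte position within a uint64.
const (
	b1 = uint64(0xFF)
	b2 = b1 << 8
	b3 = b1 << 16
	b4 = b1 << 24
	b5 = b1 << 32
	b6 = b1 << 40
	b7 = b1 << 48
	b8 = b1 << 56
)

var (
	// masks[k] has the byte k repeated in all 8 byte positions.
	masks [256]uint64
	// active[c] has FF in the byte positions above c, so ORing it in
	// forces key slots that aren't in use to never look like a match.
	active [16]uint64
)

func init() {
	for k := 0; k < 256; k++ {
		var m uint64
		for i := 0; i < 8; i++ {
			m |= uint64(k) << (i * 8)
		}
		masks[k] = m
	}
	for c := range active {
		for pos := c + 1; pos < 8; pos++ {
			active[c] |= uint64(0xFF) << (pos * 8)
		}
	}
}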

func (n *nodeMasks) get(k byte) interface{} {
	if n.count == 0 {
		return nil
	}
	// mask is a uint64 with k at each byte position
	mask := masks[k]
	// act has bytes with value FF in positions we don't want to consider
	act := active[n.count-1]
	// XOR the mask and the keys, for any bytes that match the result will be 0
	// for ones that don't match the result will be not 0.
	// OR the result with act so that any key positions we shouldn't consider get
	// set to FF
	r := (mask ^ n.keys1) | act
	// now check each byte in the result for a zero
	if (r & b1) == 0 {
		return n.vals[0]
	}
	if (r & b2) == 0 {
		return n.vals[1]
	}
	if (r & b3) == 0 {
		return n.vals[2]
	}
	if (r & b4) == 0 {
		return n.vals[3]
	}
	if (r & b5) == 0 {
		return n.vals[4]
	}
	if (r & b6) == 0 {
		return n.vals[5]
	}
	if (r & b7) == 0 {
		return n.vals[6]
	}
	if (r & b8) == 0 {
		return n.vals[7]
	}
	if n.count < 9 {
		return nil
	}
	// same again for the upper 8 keys
	r = (mask ^ n.keys2) | active[n.count-9]
	if (r & b1) == 0 {
		return n.vals[8]
	}
	if (r & b2) == 0 {
		return n.vals[9]
	}
	if (r & b3) == 0 {
		return n.vals[10]
	}
	if (r & b4) == 0 {
		return n.vals[11]
	}
	if (r & b5) == 0 {
		return n.vals[12]
	}
	if (r & b6) == 0 {
		return n.vals[13]
	}
	if (r & b7) == 0 {
		return n.vals[14]
	}
	if (r & b8) == 0 {
		return n.vals[15]
	}
	return nil
}

Once we've finished with the bit operations, we're left with a uint64. Where a key matches the target value, the corresponding byte in the uint64 is zero; otherwise it's non-zero. We could write a loop to then check each byte; the above code is the unrolled equivalent of that loop, and provides a decent gain over the loop version, 70.7ns vs 103ns.
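
For comparison, the MasksWithFinalLoop variant in the results below decodes the result with that loop; it looks something like this (a sketch, the repo has the exact version):

// getWithFinalLoop does the same bit operations as get, but checks for a
// zero byte with a loop instead of the unrolled ifs.
func (n *nodeMasks) getWithFinalLoop(k byte) interface{} {
	if n.count == 0 {
		return nil
	}
	mask := masks[k]
	r := (mask ^ n.keys1) | active[n.count-1]
	for i := 0; i < 8; i++ {
		if r&(uint64(0xFF)<<(i*8)) == 0 {
			return n.vals[i]
		}
	}
	if n.count < 9 {
		return nil
	}
	r = (mask ^ n.keys2) | active[n.count-9]
	for i := 0; i < 8; i++ {
		if r&(uint64(0xFF)<<(i*8)) == 0 {
			return n.vals[i+8]
		}
	}
	return nil
}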

Benchmark_IndexByte-8                           	11373612	       104 ns/op
Benchmark_UnrolledLoop-8                        	12968982	        92.5 ns/op
Benchmark_Masks-8                               	16908472	        70.7 ns/op
Benchmark_MasksWithFinalLoop-8                  	11663716	       103 ns/op
Benchmark_Loop-8                                	11187504	       105 ns/op
Benchmark_LoopOneBoundsCheck-8                  	13619496	        88.3 ns/op
Benchmark_LoopRev-8                             	10784320	       112 ns/op
Benchmark_LoopRevOneBoundsCheck-8               	14605256	        82.5 ns/op
Benchmark_BinarySearch-8                        	 4791508	       254 ns/op
Benchmark_BinarySearchOneBoundsCheck-8          	 4614759	       252 ns/op
Benchmark_BinarySearchInlined-8                 	10562584	       114 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-8   	11154210	       107 ns/op

But we're back to a byte at a time; can we do more bit twiddling to decode the index out of the uint64? Turns out we can. With some additional bit fiddling, we can get the result to have just the high bit of the byte set where there was a match, and zero otherwise. At that point we can count the number of trailing zeros to determine which high bit was set.

func (n *nodeMasks) getMoreBitTwiddling(k byte) interface{} {
	if n.count == 0 {
		return nil
	}
	// This follows the same approach as get. But uses additional
	// bit twiddling to determine if any of the bytes are zero.
	mask := masks[k]
	r := (mask ^ n.keys1) | active[n.count-1]

	// see https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
	x := (r - 0x0101010101010101) & ^(r) & 0x8080808080808080
	idx := bits.TrailingZeros64(x) / 8
	if idx < 8 {
		return n.vals[idx]
	}
	if n.count < 9 {
		return nil
	}
	r = (mask ^ n.keys2) | active[n.count-9]
	x = (r - 0x0101010101010101) & ^(r) & 0x8080808080808080
	idx = bits.TrailingZeros64(x) / 8
	if idx < 8 {
		return n.vals[idx+8]
	}
	return nil
}

As an aside, despite the fact that bits.TrailingZeros64 appears to be a pure Go implementation, the compiler treats it as an intrinsic and replaces it with the relevant machine instruction.

Benchmark_IndexByte-8                           	10000000	       103 ns/op
Benchmark_UnrolledLoop-8                        	10000000	        92.9 ns/op
Benchmark_Masks-8                               	10000000	        71.0 ns/op
Benchmark_MasksWithFinalLoop-8                  	10000000	       105 ns/op
Benchmark_MasksWithBitTwiddling-8               	10000000	        78.1 ns/op
Benchmark_Loop-8                                	10000000	       105 ns/op
Benchmark_LoopOneBoundsCheck-8                  	10000000	        88.2 ns/op
Benchmark_LoopRev-8                             	10000000	       113 ns/op
Benchmark_LoopRevOneBoundsCheck-8               	10000000	        82.2 ns/op
Benchmark_BinarySearch-8                        	10000000	       250 ns/op
Benchmark_BinarySearchOneBoundsCheck-8          	10000000	       254 ns/op
Benchmark_BinarySearchInlined-8                 	10000000	       115 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-8   	10000000	       107 ns/op

Unexpectedly, the bit-twiddling approach is slightly slower than the version that used the unrolled checks to decode the return index. Perhaps it's because the checks for a zero byte don't depend on each other, so they can all be speculatively executed, while in the bit-twiddling approach each instruction depends on the previous one. Concrete explanations for this are welcome. Notice though that either approach is faster than anything else so far.

Our final version is the one described in the ART paper, which uses SIMD instructions to compare all 16 keys at once. There are no SIMD intrinsics for Go, so this'll have to be in straight-up assembly. I used avo to help deal with some of the drudgery involved (see asm/asm.go in the repo for the full code). XMM registers are 16 bytes wide, and VPBROADCASTB & PCMPEQB are the primary SIMD instructions involved.

TEXT("Lookup", NOSPLIT, "func(k byte, x *[16]byte) int32")
Pragma("noescape")
Doc("Lookup returns the index into the array 'x' of the value 'k', or -1 if its not there." +
    " If k appears at multiple locations, you'll get one of them as the return value, it may not be the first one.")
x := Load(Param("x"), GP64())
k := Load(Param("k"), GP32())

xKey := XMM()
MOVD(k, xKey)
VPBROADCASTB(xKey, xKey) // xmm register now contains the value k in all 16 bytes

xArr := XMM()
MOVUPD(Mem{Base: x}, xArr) // xmm register now contains the 16 bytes of the array x

// Compare bytes for equality between the 2 xmm registers.
// xArr is updated with the result. Where they're equal the byte is set to FF
// otherwise its set to 0
PCMPEQB(xKey, xArr)

rv := GP64()
rOffset := GP64()
XORQ(rOffset, rOffset)       // resOffset = 0
MOVQ(xArr, rv)               // get the lower 8 bytes from the xmm register into rv
TESTQ(rv, rv)                // is rv 0? if not, at least one byte was equal
JNZ(LabelRef("returnCount")) // jump to converting that back to a index

MOVHLPS(xArr, xArr) // move top 64 bits to lower 64 bits in xmm register
MOVQ(xArr, rv)      // move lower 8 bytes into rv
TESTQ(rv, rv)
JZ(LabelRef("notFound")) // is rv 0? if so there's no matches, so return -1
// the match was found in the top 8 bytes, so we need
// to offset the final calculated index by 8.
MOVQ(U64(8), rOffset)

Label("returnCount") // return tailing zeros / 8 + offset
idx := GP64()
TZCNTQ(rv, idx)    // set idx to the number of trailing zeros in rv
SHRQ(Imm(3), idx)  // divide idx by 8 to get from bit position to byte posn.
ADDQ(rOffset, idx) // add the result offset in.

Store(idx.As32(), ReturnIndex(0)) // return the final index as the result.
RET()

Label("notFound")
rMiss := GP32()
MOVL(U32(0xFFFFFFFF), rMiss)
Store(rMiss, ReturnIndex(0)) // return -1
RET()

The VPBROADCASTB and PCMPEQB instructions do the heavy lifting, but otherwise it bears a strong resemblance to the bit-mask approach we wrote earlier.
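
avo also generates the Go stub declaration for Lookup. Wiring it into the node is then just a bounds check on the returned index, roughly like this (a sketch, assuming the nodeLoop type from episode 1 with a key [16]byte array; the repo's version may differ):

// Lookup is implemented in assembly, generated by avo (see asm/asm.go).
//go:noescape
func Lookup(k byte, x *[16]byte) int32

func (n *nodeLoop) getLookupAsm(k byte) interface{} {
	idx := Lookup(k, &n.key)
	// Lookup can report a match in a key slot past count (stale data),
	// so the index still needs checking against count.
	if idx < 0 || int(idx) >= n.count {
		return nil
	}
	return n.val[idx]
}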

Benchmark_IndexByte-8                           	11691013	       103 ns/op
Benchmark_UnrolledLoop-8                        	12977037	        91.9 ns/op
Benchmark_GetLookupAsm-8                        	23960334	        48.9 ns/op
Benchmark_Masks-8                               	16845604	        70.4 ns/op
Benchmark_MasksWithFinalLoop-8                  	11721877	       102 ns/op
Benchmark_MasksWithBitTwiddling-8               	15233120	        77.3 ns/op
Benchmark_Loop-8                                	11848963	       101 ns/op
Benchmark_LoopOneBoundsCheck-8                  	13606868	        87.6 ns/op
Benchmark_LoopRev-8                             	10635825	       113 ns/op
Benchmark_LoopRevOneBoundsCheck-8               	13355061	        88.9 ns/op
Benchmark_BinarySearch-8                        	 4826577	       253 ns/op
Benchmark_BinarySearchOneBoundsCheck-8          	 4798706	       250 ns/op
Benchmark_BinarySearchInlined-8                 	10516956	       114 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-8   	11263238	       107 ns/op

A clear winner: that's half the time our original loop took, and 30% faster than the next best version.

But which one would I really use? Despite the giant trip down the performance rabbit hole, I'd pick the one that performs best out of the ones that are easily understandable. That'd be the reverse loop with the tweak to remove all but one of the bounds checks, LoopRevOneBoundsCheck. If it later turned out that this particular part of the system was a bottleneck that needed improving, then I'd probably go with the assembly version. However, I'd need to do a lot more studying first to validate its correctness: not its logical correctness, but its correctness as an assembly function inside a Go execution environment.

As a final exercise, I'll leave you to work out at what size binary search actually starts to win. Let me know what results you get.
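
If you want a starting point, a size-parameterized benchmark over plain sorted byte slices (rather than the node types above) is one way to hunt for the crossover; a rough sketch, needing "fmt", "sort", and "testing" imported:

var sink int

func BenchmarkCrossover(b *testing.B) {
	for _, size := range []int{16, 32, 64, 128, 256} {
		keys := make([]byte, size)
		for i := range keys {
			keys[i] = byte(i) // sorted, unique keys
		}
		b.Run(fmt.Sprintf("linear_%d", size), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				k := keys[i%size]
				for j := 0; j < len(keys); j++ {
					if keys[j] == k {
						sink = j
						break
					}
				}
			}
		})
		b.Run(fmt.Sprintf("binary_%d", size), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				k := keys[i%size]
				sink = sort.Search(len(keys), func(j int) bool { return keys[j] >= k })
			}
		})
	}
}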

Remember how I said it might depend on the OS and CPU? Here are some runs from other machines, all using Go 1.15.7.

Windows 10 with an Intel 9900KS. Faster than the iMac runs, which isn't surprising given it's a processor three generations newer. But the relative performance between the approaches is pretty similar.

goos: windows
goarch: amd64
pkg: github.com/superfell/loopVsBinarySearch
Benchmark_IndexByte-16                                  14651924                81.8 ns/op
Benchmark_UnrolledLoop-16                               16365361                72.8 ns/op
Benchmark_GetLookupAsm-16                               31375419                37.9 ns/op
Benchmark_Masks-16                                      21862701                55.0 ns/op
Benchmark_MasksWithFinalLoop-16                         14606589                81.3 ns/op
Benchmark_MasksWithBitTwiddling-16                      18442527                63.3 ns/op
Benchmark_Loop-16                                       15587856                77.0 ns/op
Benchmark_LoopOneBoundsCheck-16                         17430231                68.1 ns/op
Benchmark_LoopRev-16                                    14259435                85.0 ns/op
Benchmark_LoopRevOneBoundsCheck-16                      19040486                64.0 ns/op
Benchmark_BinarySearch-16                                6129547               195 ns/op
Benchmark_BinarySearchOneBoundsCheck-16                  6129829               196 ns/op
Benchmark_BinarySearchInlined-16                        13593127                86.6 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-16          14756770                80.6 ns/op
PASS

And finally ARM/Linux running on a Raspberry Pi 4 (this doesn't include a version of the SIMD solution). The relative performance between the approaches is different now: the unrolled loop does better than any of the others, and the inlined binary search is not far behind it.

goos: linux
goarch: arm
pkg: github.com/superfell/loopVsBinarySearch
Benchmark_IndexByte-4                           	 1443998	       830 ns/op
Benchmark_UnrolledLoop-4                        	 2522744	       477 ns/op
Benchmark_Masks-4                               	 2023891	       594 ns/op
Benchmark_MasksWithFinalLoop-4                  	 1422764	       857 ns/op
Benchmark_MasksWithBitTwiddling-4               	 1911872	       629 ns/op
Benchmark_Loop-4                                	 1708760	       700 ns/op
Benchmark_LoopOneBoundsCheck-4                  	 1619178	       679 ns/op
Benchmark_LoopRev-4                             	 1728384	       737 ns/op
Benchmark_LoopRevOneBoundsCheck-4               	 1907144	       610 ns/op
Benchmark_BinarySearch-4                        	 1000000	      1309 ns/op
Benchmark_BinarySearchOneBoundsCheck-4          	  990722	      1219 ns/op
Benchmark_BinarySearchInlined-4                 	 2160622	       556 ns/op
Benchmark_BinarySearchInlinedOneBoundsCheck-4   	 2185908	       549 ns/op