
Speeding up a multiarmed bandit by 30x

Published: at 12:00 AM

Slow Code

Step 1 of becoming a 30x bandit developer: write a very slow implementation of an epsilon-greedy multi-armed bandit while working through Sutton & Barto’s Reinforcement Learning book.

To be clear, slow code can still very much be useful code.

Slow is also relative. A single run of 1000 steps was clocking in at 250 ms, which seemed fine at first glance.
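For anyone following along with chapter 2 of Sutton & Barto, here is a minimal sketch of a single epsilon-greedy run with sample-average updates. It is an illustration under assumed defaults (10 arms, ε = 0.1, unit-variance Gaussian rewards; the name `run_epsilon_greedy` is hypothetical), not the original slow implementation from this post:

```python
import numpy as np


def run_epsilon_greedy(n_arms=10, steps=1000, epsilon=0.1, seed=0):
    """One run of a sample-average epsilon-greedy agent on a k-armed testbed."""
    rng = np.random.default_rng(seed)
    q_true = rng.normal(0.0, 1.0, n_arms)   # true action values q*(a)
    q_est = np.zeros(n_arms)                # estimates Q(a)
    counts = np.zeros(n_arms, dtype=int)    # N(a): times each arm was pulled
    rewards = np.empty(steps)

    for t in range(steps):
        if rng.random() < epsilon:
            action = rng.integers(n_arms)   # explore: any arm uniformly
        else:
            # exploit: greedy arm, breaking ties uniformly at random
            action = rng.choice(np.flatnonzero(q_est == q_est.max()))
        reward = rng.normal(q_true[action], 1.0)
        counts[action] += 1
        # incremental sample average: Q <- Q + (R - Q) / N
        q_est[action] += (reward - q_est[action]) / counts[action]
        rewards[t] = reward

    return rewards, counts


rewards, counts = run_epsilon_greedy()
print(rewards.shape, counts.sum())  # (1000,) 1000
```

Everything is a numpy array from the start, which avoids the list/dict/dataframe juggling that shows up later in this post.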

Too Slow Code

So why are we still here if I’m singing the praises of poorly optimized code?

My song quickly changes its tune once our slow code is approaching 10 minutes per grid point!
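As a back-of-the-envelope check, assume a grid point means the 2000 independent runs of 1000 steps used in the sweep command later in this post (the same setup as Sutton & Barto’s 10-armed testbed), executed serially at roughly 250 ms per run:

```latex
T_{\text{grid point}} \approx n_{\text{runs}} \cdot t_{\text{run}} = 2000 \times 0.25\,\text{s} = 500\,\text{s} \approx 8.3\,\text{min}
```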

Multiprocessing and Amdahl’s Law

Luckily, bandit runs are independent of one another and lend themselves well to a multiprocessing paradigm. I ended up using the ProcessPoolExecutor[1] class from python’s standard-library concurrent.futures module to abstract away some of the setup.
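A minimal sketch of that pattern, assuming each run is seeded by its index. `bandit_run` and `run_all` are hypothetical names, the body of `bandit_run` is a placeholder rather than a real bandit, and mapping the sweep’s `n_jobs` parameter onto `max_workers` is an assumption about how it would be wired up:

```python
from concurrent.futures import ProcessPoolExecutor

import numpy as np


def bandit_run(seed: int, steps: int = 1000) -> float:
    """Hypothetical stand-in for one independent bandit run.

    Seeding each run by its index keeps results reproducible no matter
    which worker process ends up executing it.
    """
    rng = np.random.default_rng(seed)
    rewards = rng.normal(loc=0.0, scale=1.0, size=steps)  # placeholder rewards
    return float(rewards.mean())


def run_all(n_runs: int = 2000, n_jobs: int = 4) -> np.ndarray:
    """Fan independent runs out across n_jobs worker processes."""
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # chunksize batches several runs per task to reduce inter-process overhead
        results = pool.map(bandit_run, range(n_runs), chunksize=50)
        return np.fromiter(results, dtype=float, count=n_runs)


if __name__ == "__main__":
    # The __main__ guard matters: with "spawn"-style start methods,
    # each worker re-imports this module.
    mean_rewards = run_all(n_runs=200, n_jobs=2)
    print(mean_rewards.shape)  # (200,)
```

Because every run is independent, the only coordination needed is collecting results; separate processes also sidestep the GIL for CPU-bound work.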

Less fortunately, hardware limitations start to rear their ugly heads. This is probably (hopefully) not a concern in a work setting, where you’ll have access to the full range of compute from your preferred, friendly neighborhood cloud vendor (or even a beefy on-prem rack[2]). But in the world of personal side projects and a single, trusty but aging desktop, we start running into diminishing returns, including Amdahl’s Law.

A quick screen grab of the wandb dashboard for the sweep below (where n_jobs parameterized the number of parallel processes to spin up):

    python run.py -m experiment.upload=true experiment.tag=cpu_profile run.steps=1000 run.n_runs=2000 run.n_jobs=1,2,3,4,5,6,7,8

and we can immediately observe:

  1. Scaling from one to two processes nearly halves our original run time!
  2. We quickly hit diminishing returns with each additional process, until there are no further incremental gains at all.

Looks like we’ve hit Amdahl’s Law:

the overall performance improvement gained by optimizing a single part of a system is limited by the fraction of time that the improved part is actually used
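In formula form, Amdahl’s Law gives the ideal speedup as S(n) = 1 / ((1 - p) + p/n), where p is the fraction of the original runtime that can be parallelized and n is the number of workers. A small sketch (the function name is hypothetical, and p = 0.9 is purely illustrative rather than measured from the sweep):

```python
def amdahl_speedup(p: float, n: int) -> float:
    """Ideal speedup predicted by Amdahl's Law.

    p: fraction of the serial runtime that can be parallelized (0 <= p <= 1)
    n: number of parallel workers (n >= 1)
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1.0 / ((1.0 - p) + p / n)


# Illustrative only: assume 90% of the work parallelizes.
for n in (1, 2, 4, 8):
    print(n, round(amdahl_speedup(0.9, n), 2))

# As n grows, the speedup approaches 1 / (1 - p), i.e. a 10x ceiling here.
```

Even with p close to 1 the curve flattens quickly, and on a single desktop the physical core count and per-process overhead (start-up, pickling arguments and results) cap real-world gains further.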

Parallel programming is hard.

Profilers and Flamegraphs

Although multiprocessing had run its course, I thought there was probably still plenty of headroom for optimization, so I broke out the excellent austin profiler[3] and piped its output to flamegraph (note the -C flag, which enables profiling of child processes):

    austin -C -i 1ms python run.py run.steps=1000 run.n_runs=20 | sed '/^#/d' | ~/projects/FlameGraph/flamegraph.pl --countname=μs > fg_v1_original.svg

We can pretty easily pick out the 8 different process pool stacks running our bandits in parallel on the right side of the flamegraph.

A closer zoom into one of the individual worker processes reveals something puzzling: our code is spending a lot of time interacting with the omegaconf dictionary objects I used to configure my experiment runs via yaml files and hydra. (On a flamegraph, the width of a frame along the x-axis reflects how often the profiler sampled that code, and therefore roughly how long it ran.)

I didn’t realize that taking advantage of hydra’s class instantiation feature would, by default, pass along the omegaconf object rather than a plain python dict. This apparently came with a significant amount of overhead: a simple two-line change to convert it to a plain python dict resulted in a large performance gain, taking what had been over 100 seconds down to less than 10.

Beware diligent f-strings

The bandit runs that originally took up the majority of the processing time were now comfortably taking a back seat to the actual setup (e.g. wandb) and the package imports seen on the left-hand side of the flamegraph.

Zooming in again on the top of the stack for a single process, I could see that my homespun implementation of argmax (with its head-scratching assortment of type conversions) was the next bottleneck to tackle.

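    import numpy as np

    def argmax(self, Q):
        """Return the action with the max estimate in Q, choosing at random among tied actions."""
        # dict -> list -> array: convert the estimates so numpy can find the max
        Q_array = np.array(list(Q.values()))
        # positions of every action tied for the max, converted back to a list
        At = np.argwhere(Q_array == np.max(Q_array)).flatten().tolist()

        if len(At) > 1:
            At = np.random.choice(At)  # break the tie at random
        else:
            At = At[0]

        # map the position back to the original action key
        return list(Q.keys())[At]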

Swapping it out for numpy’s argmax took some rearranging so that the codebase consistently handled numpy arrays rather than an eclectic mix of lists, dicts, and pandas dataframes sprinkled throughout. The refactor was well worth it, judging by the final result: runs that were 3 times slower.
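One detail worth knowing when making that swap: `np.argmax` always returns the first index of the maximum, whereas Sutton & Barto’s bandit pseudocode breaks ties randomly. A minimal array-only sketch that keeps the random tie-break (the function name `argmax_random_tiebreak` is hypothetical):

```python
import numpy as np


def argmax_random_tiebreak(q: np.ndarray, rng: np.random.Generator) -> int:
    """Index of the max entry of q, choosing uniformly among ties."""
    # flatnonzero returns the indices of every entry equal to the max
    best = np.flatnonzero(q == q.max())
    return int(rng.choice(best))


rng = np.random.default_rng(42)
q = np.array([0.2, 1.5, 1.5, -0.3])
print(np.argmax(q))                    # 1: np.argmax always returns the first max
print(argmax_random_tiebreak(q, rng))  # 1 or 2, chosen at random
```

No dict-to-list round trips are involved: the estimates stay in one array, and the action is just its index.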

Slower wasn’t quite the direction I was aiming for. Taking the profiler out again, I noticed some curious references to printing arrays, for some reason.

It turns out f-strings in debug logging statements are not lazy: the string is built before the logger ever checks its level. While refactoring in the numpy arrays, I had added some extra logging, and those messages were always being evaluated, even when the logger wasn’t running in DEBUG mode.
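A minimal sketch of the difference, using a hypothetical `ExpensiveRepr` class as a stand-in for a large numpy array so we can count how often it gets turned into a string (the logger name `"bandit"` is also just an example). The f-string is formatted before `logger.debug` runs, while `%s`-style arguments are only formatted if the record is actually emitted:

```python
import logging

logger = logging.getLogger("bandit")
logger.setLevel(logging.INFO)  # DEBUG records are filtered out


class ExpensiveRepr:
    """Stand-in for a large numpy array: counts how often it is stringified."""

    def __init__(self):
        self.str_calls = 0

    def __str__(self):
        self.str_calls += 1
        return "<big array>"


q = ExpensiveRepr()

# Eager: the f-string calls str(q) before logger.debug even checks the level.
logger.debug(f"Q estimates: {q}")
eager_calls = q.str_calls

# Lazy: %-style arguments are only formatted if the record is emitted.
logger.debug("Q estimates: %s", q)
lazy_calls = q.str_calls - eager_calls

# Explicit guard, useful when building the arguments is itself expensive.
if logger.isEnabledFor(logging.DEBUG):
    logger.debug("Q estimates: %s", q)

print(eager_calls, lazy_calls)  # 1 0
```

The lazy form only defers the formatting; the arguments themselves are still evaluated at the call site, so genuinely expensive computations (e.g. summarizing a large array) are worth wrapping in the `isEnabledFor` check.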

30x Bandit

One last look through the profiler shows our bandits humming along so fast that their process stacks no longer even show up as distinct stacks! Most of what the flamegraph sampled now consists of import statements rather than the bandit runs themselves, which seems like a good stopping point for this performance refactor: we’ve probably picked the low-hanging fruit (and I want an excuse to go back to implementing rather than writing!).

Footnotes

  1. When should you use a process pool vs threads? I like the rule of thumb cited in this stackoverflow thread, the gist being: “use threads when I/O is the bottleneck, and process pools when most of the processing is CPU bound. Even external libraries like Numpy that nominally run multithreaded can sometimes hold the GIL.”

  2. Beowulf clusters are more a relic of my shoestring-budget academic past than anything you’d find in an office these days. Even back then, we were starting to hear about how “gaming cards” could be used instead for certain computational tasks and that we should all check out something called CUDA.

  3. What is profiling? Which profiler should you use? The author of the austin profiler has a great post here.