
Context:

I've implemented MuZero for the game Tic-tac-toe. Unfortunately, the self-play and training are very slow (around 10 hours until it plays quite well). I ran the Python profiler to find the parts that take the most time. The result is that most of the time is spent doing Monte Carlo tree searches, specifically querying the neural networks for the next hidden state and the value and policy predictions.

While running self-play, my GPU is only at 30% load while my CPU is at max load (a single process, because of the GIL).

Note: In the MuZero paper, they have done some fancy stuff like scaling some gradients, which I haven't implemented yet. This will probably also result in a small speedup.

What I'm trying to do:

I want to speed up the self-play by running multiple MCTS searches in different threads so that each one pauses whenever it wants to query the neural network, until enough other threads have queries for the network. Then I put all the queries into a batch and send the batch to the network. Once I have the results, I return them to each thread, and they continue until they try to query the network again.

Reason:

Let's say I want to play 100 self-play games. I do 50 simulations per move, the average game length is 7 and the observation shape is (3,3,3). With my current approach, this would result in 100 * 7 * 50 = 35000 network queries of shape (1,3,3,3).

With the approach described above, I could run all 100 games at once and batch the network queries, resulting in 7 * 50 = 350 network queries of shape (100,3,3,3).

I hope that this will result in a significant speedup.

Questions:

  • What are your thoughts on this plan?
  • Any frameworks/PyTorch features that can help me with my plan?
  • How would you implement something like this?
  • How would you tackle problems like threads never waking up because the batch never gets full enough (if there's something better than timeouts), or possible race conditions?

Don't feel obligated to answer all of these questions.

Lynix

1 Answer


Batching: A Good Idea

You're right, batching is a great way to speed up AlphaZero or MuZero self-play! Your proposed solution of running multiple games in parallel is the easiest way to achieve some batching. There are other solutions, most notably virtual loss, which lets you build batches even from a single tree search. You can even combine both approaches.
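To illustrate, here is a minimal sketch of the virtual-loss trick. All details here (the `Node` fields, the `VIRTUAL_LOSS` constant, the simplified single-perspective backup) are assumptions of mine, not a prescription:

```python
import math
from dataclasses import dataclass, field

VIRTUAL_LOSS = 1.0  # assumption: count each in-flight simulation as one loss


@dataclass
class Node:
    prior: float
    visit_count: int = 0
    value_sum: float = 0.0
    children: list = field(default_factory=list)


def uct_score(parent: Node, child: Node, c: float = 1.5) -> float:
    q = child.value_sum / child.visit_count if child.visit_count else 0.0
    u = c * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    return q + u


def select_leaf(root: Node):
    """Select a leaf by UCT, applying a virtual loss along the path so that
    the next selection (issued before this result is back) avoids it."""
    path, node = [root], root
    while node.children:
        parent = node
        node = max(node.children, key=lambda ch: uct_score(parent, ch))
        path.append(node)
    for n in path:
        n.visit_count += 1           # pretend the simulation already happened...
        n.value_sum -= VIRTUAL_LOSS  # ...and that it lost
    return node, path


def backup(path, value: float):
    """Once the real evaluation arrives, swap the virtual loss for it.
    (Sign flipping for alternating players is omitted for brevity.)"""
    for n in path:
        n.value_sum += VIRTUAL_LOSS + value  # undo the loss, add the real value
```

A node that holds a virtual loss looks less attractive to UCT, so a second selection started before the first evaluation has returned walks into a different part of the tree, giving you several leaves to evaluate in one batch.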

Implementation Details

Unfortunately, I don't know of any PyTorch features that help with this, and in general Python is not the best language for concurrency and multithreading.

The way I implemented this in kZero (in Rust):

  • There is a single thread responsible for NN execution that listens to a channel/queue of work items. Once it has collected enough items, it evaluates the batch and sends the results back to each work item through a dedicated temporary response channel that was part of the original item.
  • Different games run on different threads or async tasks, send their evaluation requests to the NN thread, and wait for the response.

The way to implement this in Python is with either threading and queues or with asyncio and its built-in queues.

See some of the diagrams in the kZero readme for some graphics to illustrate this.
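For concreteness, here is a minimal sketch of the threading-and-queues variant. Everything in it (`fake_network`, the dummy `play_game` loop, the constants) is placeholder code of mine, not kZero:

```python
import queue
import threading

import numpy as np

BATCH_SIZE = 8   # assumption: tune to your GPU
NUM_GAMES = 16   # at least BATCH_SIZE, ideally more (see below)

# Each work item is (observation, response_queue); the response queue plays
# the role of the dedicated temporary response channel described above.
request_queue = queue.Queue()


def fake_network(batch):
    """Stand-in for the real PyTorch model; returns one value per entry."""
    return batch.mean(axis=(1, 2, 3))


def nn_worker():
    """Collect BATCH_SIZE requests, evaluate them in one call, reply to each."""
    while True:
        items = [request_queue.get()]
        while len(items) < BATCH_SIZE:
            items.append(request_queue.get())
        batch = np.stack([obs for obs, _ in items])  # shape (BATCH_SIZE, 3, 3, 3)
        for (_, resp), value in zip(items, fake_network(batch)):
            resp.put(value)


def evaluate(observation):
    """Called from inside a game's MCTS whenever it needs the network."""
    resp = queue.Queue(maxsize=1)
    request_queue.put((observation, resp))
    return resp.get()  # blocks until the batch containing this request ran


def play_game(i):
    """Dummy game loop; a real one would run MCTS and call evaluate()."""
    for _ in range(3):
        value = evaluate(np.random.rand(3, 3, 3))
    print(f"game {i} finished, last value {value:.3f}")


threading.Thread(target=nn_worker, daemon=True).start()
games = [threading.Thread(target=play_game, args=(i,)) for i in range(NUM_GAMES)]
for g in games:
    g.start()
for g in games:
    g.join()
```

An asyncio version looks essentially the same, with `asyncio.Queue` and tasks in place of the threads and queues.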

An alternative is to share a single thread between multiple games and write a state machine to switch between them yourself. Effectively, this means you're implementing your own async-like system, which is tricky to get right! I did this at first but later switched to proper async, which simplified the code a lot.
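As a toy illustration of that state-machine idea, you can abuse generators as hand-rolled coroutines (the dummy `game_coroutine` and the fake network below are placeholders of mine); real MCTS state makes this considerably messier, which is why async proper pays off:

```python
import numpy as np


def game_coroutine(game_id):
    """Each game is a generator: it yields an observation whenever it needs
    the network and is resumed with the evaluation result."""
    for _ in range(3):                         # dummy: three queries per game
        value = yield np.random.rand(3, 3, 3)  # hand-rolled "await"
    print(f"game {game_id} finished, last value {value:.3f}")


games = [game_coroutine(i) for i in range(100)]
pending = {g: g.send(None) for g in games}  # prime each game to its first query
while pending:
    running = list(pending)
    batch = np.stack([pending[g] for g in running])  # shape (N, 3, 3, 3)
    values = batch.mean(axis=(1, 2, 3))              # stand-in for net(batch)
    pending = {}
    for g, v in zip(running, values):
        try:
            pending[g] = g.send(float(v))  # resume until the next query
        except StopIteration:
            pass                           # this game is done
```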

Yet another alternative is to do what mctx does and let JAX "magically" handle all the batching and concurrency. I'm not sure how well this actually works in practice; I haven't used JAX myself yet.

Deadlocks and Race Conditions

You can avoid the potential deadlock of batches never filling by simply running enough games at once! For example, if you're running a single NN execution thread with a batch size of 100, spawning at least 100 games makes a deadlock impossible. It's best to spawn at least twice as many games to get some pipelining: that way, by the time the first 100 requests have been evaluated, the next 100 are ready to go, and the NN thread never has to wait long.

The way to avoid deadlocks and race conditions caused by concurrency bugs is to write correct code! This is tricky and basically an entire domain by itself. I've found that communicating between threads only through (thread-safe) queues makes it relatively easy to write concurrency-bug-free code.

KarelPeeters