Multiprocessing in Python and PyTorch
This is the first part of a 3-part series covering multiprocessing, distributed communication, and distributed training in PyTorch.
In this article, we will cover the basics of multiprocessing in Python first, then move on to PyTorch; so even if you don’t use PyTorch, you may still find helpful resources here :)
Multiprocessing is process-level parallelism, in the sense that each spawned process is allocated separate memory and resources. In Python, multiprocessing is often used to bypass the infamous GIL (Global Interpreter Lock), a per-interpreter lock that allows only one thread to execute Python bytecode at a time.
multiprocessing
In native Python, multiprocessing is achieved by using the multiprocessing
module.
The official documentation of multiprocessing
is here, and it is great! In this tutorial we will only cover
some of the most important and relevant features of the module; for more details, please refer to the official documentation.
Process
To spawn a new process, we create a Process
object, and then we call the start()
method.
import multiprocessing as mp
import time
def foo(x):
    print(f"foo({x})")
    time.sleep(2)
    return x
p = mp.Process(target=foo, args=(1,))
p.start()
p.join()
foo(1)
---
Runtime: 2.0 seconds
Very straightforward! The join() method blocks the calling process until the spawned process finishes. Without it, the main process would continue past start() without waiting for foo(1) to complete.
Cross-process communication
For communication between two processes, we can use a Pipe
object. For communication between two or more processes, we can use a Queue
object.
import multiprocessing as mp
import time
def foo(x, q):
    # Producer
    print(f"Putting {x} in queue")
    q.put(x)

def bar(q):
    # Consumer
    print(f"Got {q.get()} from queue")
queue = mp.Queue()
p1 = mp.Process(target=foo, args=(1, queue))
p2 = mp.Process(target=bar, args=(queue,))
p1.start()
p2.start()
p1.join()
p2.join()
Putting 1 in queue
Got 1 from queue
---
Runtime: 0.0 seconds
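For the two-process case, a minimal Pipe sketch (a hypothetical example that just sends a single message) could look like this:

import multiprocessing as mp

def foo(conn):
    conn.send("hello")  # send an object through the pipe
    conn.close()

parent_conn, child_conn = mp.Pipe()
p = mp.Process(target=foo, args=(child_conn,))
p.start()
print(parent_conn.recv())  # "hello"
p.join()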
Pool
multiprocessing.Pool
creates a pool of processes, each of which is allocated a separate memory space. It is a context manager, so it can be used in a with
statement.
import multiprocessing as mp
with mp.Pool(processes=4) as pool:
    ...  # do something
Otherwise, be sure to close the pool when you are done with it.
pool = mp.Pool(processes=4)
# do something
pool.close()
pool.join()
Once pool.close()
is invoked, no more tasks can be submitted to the pool. Once all tasks are completed, the worker processes will exit (gracefully). On the other hand, if you want to terminate the pool immediately, you can use pool.terminate()
.
If you want to wait for all tasks to finish, you can use pool.join()
. One must call close()
or terminate()
before using join()
.
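As a quick sketch of the terminate path (the slow worker function here is a placeholder, and apply_async is covered below):

import multiprocessing as mp
import time

def slow(x):
    time.sleep(10)
    return x

pool = mp.Pool(processes=2)
for i in range(4):
    pool.apply_async(slow, (i,))

pool.terminate()  # stop the workers immediately; outstanding tasks are discarded
pool.join()       # wait for the worker processes to actually exit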
To do some actual work, in most cases you would want to use either apply
, map
, starmap
or their _async
variants.
apply
We can submit a function to the pool to be executed in a worker process by using pool.apply
.
import time
import multiprocessing as mp
def foo(x, y):
    time.sleep(3)
    return x + y

with mp.Pool(processes=4) as pool:
    a = pool.apply(foo, (1, 2))
    b = pool.apply(foo, (3, 4))
    print(a, b)
3 7
---
Runtime: 6.0 seconds
We create a pool with 4 worker processes, and then submit two tasks to the pool to run. Since apply
is a blocking call, the main process will wait until the first task is completed before submitting the second task. This is mostly useless, because no parallelism is achieved here. In this case, if we want to run multiple tasks in parallel, we should use apply_async
like this
with mp.Pool(processes=4) as pool:
    handle1 = pool.apply_async(foo, (1, 2))
    handle2 = pool.apply_async(foo, (3, 4))
    a = handle1.get()
    b = handle2.get()
    print(a, b)
3 7
---
Runtime: 3.0 seconds
apply_async
is non-blocking and returns an AsyncResult
object immediately. We can then use get
to get the result of the task.
Note that get
will block until the task is completed; apply(fn, args, kwargs)
is equivalent to apply_async(fn, args, kwargs).get()
.
On the other hand, we can add a callback to apply_async
to be executed when the task is completed:
def callback(result):
    print(f"Got result: {result}")

with mp.Pool(processes=4) as pool:
    handle1 = pool.apply_async(foo, (1, 2), callback=callback)
    handle2 = pool.apply_async(foo, (3, 4), callback=callback)
    handle1.wait()  # wait for the tasks (and their callbacks) to finish
    handle2.wait()
Got result: 3
Got result: 7
---
Runtime: 3.0 seconds
In rare cases the second task may complete before the first one, and correspondingly, its callback will be called before the callback of the first task.
Because the number of worker processes is limited, if all workers are busy when a new task is submitted, the task will be queued and executed later.
with mp.Pool(processes=2) as pool:
    handles = [pool.apply_async(foo, (1, 2)) for _ in range(3)]
    for handle in handles:
        handle.wait()  # wait for all three tasks to finish
---
Runtime: 6.0 seconds
In the example above, the first and second foo
calls are executed in the 2 workers, but the third has to wait until a worker becomes available.
map
and starmap
map
divides the input iterable into chunks and submits each chunk to the pool as a separate task. The results of the tasks are then gathered and returned as a list.
import multiprocessing as mp
import time
def foo(x):
    print(f"Starting foo({x})")
    time.sleep(2)
    return x

with mp.Pool(processes=2) as pool:
    result = pool.map(foo, range(10), chunksize=None)
    print(result)
Starting foo(0)
Starting foo(2)
Starting foo(1)
Starting foo(3)
Starting foo(4)
Starting foo(6)
Starting foo(5)
Starting foo(7)
Starting foo(8)
Starting foo(9)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
---
Runtime: 12.0 seconds
In the example above, chunksize is set to its default value None. In that case the pool computes a chunk size from the length of the iterable and the number of workers (in CPython, roughly len(iterable) / (4 * processes), rounded up), which here works out to 2. This means the iterable is divided into 5 chunks of size 2: [0, 1], [2, 3], [4, 5], [6, 7], [8, 9]. At first, the first two chunks are submitted to the 2 workers, then the next two chunks. Finally, the last chunk [8, 9] goes to whichever worker frees up first, so only one worker processes it while the other sits idle. This is why the runtime is 12 seconds, which is sub-optimal. If we explicitly set the chunksize to 1 or 5, the runtime drops to 10 seconds, which is as good as it gets.
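For instance, here is a sketch that forces one element per chunk, reusing the foo defined above:

with mp.Pool(processes=2) as pool:
    # 10 chunks of size 1: each worker handles 5 elements, ~10 seconds total
    result = pool.map(foo, range(10), chunksize=1)
    print(result)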
map
is a blocking call, so it will wait until all tasks are completed before returning. Similar to apply
, we can use map_async
to submit tasks to the pool and get the results asynchronously.
with mp.Pool(processes=2) as pool:
    handle = pool.map_async(foo, range(10), chunksize=None)
    # do something else
    result = handle.get()
    print(result)
The limitation of map is that it passes each element of the iterable to the function as a single argument. Thus if we want to apply a multi-argument function, we either have to pass tuples and unpack them inside the function, which is ugly, or use starmap. For each element of the iterable, starmap will unpack it into the arguments of the function.
def bar(x, y):
    print(f"Starting bar({x}, {y})")
    time.sleep(2)
    return x + y

with mp.Pool(processes=2) as pool:
    pool.starmap(bar, [(1, 2), (3, 4), (5, 6)])
Starting bar(1, 2)
Starting bar(3, 4)
Starting bar(5, 6)
---
Runtime: 6.0 seconds
starmap blocks. The async variant starmap_async is also available and does exactly what you would expect.
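For example, a sketch analogous to the map_async usage above, reusing the bar defined earlier:

with mp.Pool(processes=2) as pool:
    handle = pool.starmap_async(bar, [(1, 2), (3, 4), (5, 6)])
    # do something else
    result = handle.get()
    print(result)  # [3, 7, 11]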
torch.multiprocessing
The official documentation for torch.multiprocessing
is here. Also checkout the best practices documentation.
torch.multiprocessing
is a wrapper around multiprocessing with extra functionality; its API is fully compatible with the original module, so we can use it as a drop-in replacement. Let's try running an example from the previous section, but using torch.multiprocessing
:
import torch.multiprocessing as mp
import time
def foo(x):
    print(f"Starting foo({x})")
    time.sleep(2)
    return x

with mp.Pool(processes=2) as pool:
    result = pool.map(foo, range(10), chunksize=None)
    print(result)
Starting foo(0)
Starting foo(2)
Starting foo(1)
Starting foo(3)
Starting foo(4)
Starting foo(6)
Starting foo(5)
Starting foo(7)
Starting foo(8)
Starting foo(9)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
---
Runtime: 12.0 seconds
No difference from the previous example!
multiprocessing
supports 3 process start methods: fork (default on Unix), spawn (default on Windows and macOS), and forkserver. To use CUDA in subprocesses, one must use either forkserver or spawn. The start method should be set once by using set_start_method()
in the if __name__ == '__main__'
clause of the main module:
import torch.multiprocessing as mp
if __name__ == '__main__':
    mp.set_start_method('forkserver')
    ...
Sharing tensors
torch
makes use of 2 sharing strategies for CPU tensors: file descriptor (default) and file system.
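A quick sketch of how to inspect and change the strategy (the available set and the default depend on the platform):

import torch.multiprocessing as mp

print(mp.get_all_sharing_strategies())  # e.g. {'file_descriptor', 'file_system'} on Linux
print(mp.get_sharing_strategy())        # the strategy currently in use
mp.set_sharing_strategy('file_system')  # switch, e.g. if file descriptors run out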
It is recommended to use the Queue approach shown earlier to share tensors between processes. The tensors must live in shared memory; if they are not already there, they are automatically moved to shared memory when Queue.put(tensor) is called. Queue.get() returns a handle to the tensor in shared memory.
To manually move a tensor to shared memory, we can use Tensor.share_memory_()
. This is a no-op if the tensor is already in shared memory, or if the tensor is a CUDA tensor. For nn.Module
, we can move the module to shared memory by calling .share_memory()
.
To check if a tensor is in shared memory, we can use Tensor.is_shared()
.
import torch
import torch.multiprocessing as mp

mat = torch.randn((200, 200))
print(mat.is_shared())

queue = mp.Queue()
queue.put(mat)
print(mat.is_shared())
False
True
---
Runtime: 0.0 seconds
Again, when we put a tensor into the queue, it is automatically moved to shared memory, which is why the second check returns True
.
Note that if Tensor.grad
is not None
, it is also shared.
If the provider process exits while its tensor is still in a shared queue, attempts to get the tensor will raise an exception.
import torch
import torch.multiprocessing as mp
import time
def foo(q):
    q.put(torch.randn(20, 20))
    q.put(torch.randn(10, 10))
    time.sleep(3)

def bar(q):
    t1 = q.get()
    print(f"Received {t1.size()}")
    time.sleep(4)
    t2 = q.get()
    print(f"Received {t2.size()}")

if __name__ == "__main__":
    mp.set_start_method('spawn')
    queue = mp.Queue()
    p1 = mp.Process(target=foo, args=(queue,))
    p2 = mp.Process(target=bar, args=(queue,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
Received torch.Size([20, 20])
Process Process-2:
Traceback (most recent call last):
...
File "/home/term1nal/miniconda3/envs/ML/lib/python3.9/multiprocessing/connection.py", line 635, in SocketClient
s.connect(address)
ConnectionRefusedError: [Errno 111] Connection refused
---
Runtime: 4.0 seconds
The first tensor is consumed while the provider process (running foo
) is still alive. The second tensor is consumed after the provider process has already exited, which raises the error shown above.
So make sure you consume the tensors in the queue before the provider process exits, or employ some waiting mechanism on the provider side.
Sharing CUDA tensors
It is basically the same as above, but must be handled with a bit more care.
CUDA tensors always use the CUDA API, and that is the only mechanism through which CUDA tensors can be shared. Tensor.share_memory_()
is a no-op for CUDA tensors.
Unlike CPU tensors, the provider process must stay alive (and keep the original tensor around) for as long as any consumer process holds a reference to the CUDA tensor. Once the consumer is done with the tensor, it should release it explicitly with del to free the memory. The following example is a bad practice:
import torch
import torch.multiprocessing as mp
import time
def foo(q):
    q.put(torch.randn(20, 20).cuda())
    time.sleep(2)

def bar(q):
    tensor = q.get()
    time.sleep(2)  # deliberately sleep to make sure that foo is done
    print(f"Received {tensor.size()}")

if __name__ == "__main__":
    mp.set_start_method('spawn')
    queue = mp.Queue()
    p1 = mp.Process(target=foo, args=(queue,))
    p2 = mp.Process(target=bar, args=(queue,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
Received torch.Size([20, 20])
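One way to avoid this warning is to keep the provider alive until the consumer has explicitly released the tensor. Below is a minimal sketch using an mp.Event for that handshake; the synchronization scheme is just one possible choice, not an official recipe:

import torch
import torch.multiprocessing as mp

def foo(q, done):
    q.put(torch.randn(20, 20).cuda())
    done.wait()  # keep the provider alive until the consumer signals it is done

def bar(q, done):
    tensor = q.get()
    print(f"Received {tensor.size()}")
    del tensor   # explicitly drop the reference to the shared CUDA tensor
    done.set()   # tell the provider it is now safe to exit

if __name__ == "__main__":
    mp.set_start_method('spawn')
    queue = mp.Queue()
    done = mp.Event()
    p1 = mp.Process(target=foo, args=(queue, done))
    p2 = mp.Process(target=bar, args=(queue, done))
    p1.start()
    p2.start()
    p1.join()
    p2.join()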
spawn
Creating multiple processes by hand is tedious. If we want to start several processes running the same function, we can do it like this:
import torch.multiprocessing as mp

def foo():
    pass

if __name__ == "__main__":
    num_proc = 4
    processes = [mp.Process(target=foo) for _ in range(num_proc)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
The problem lies in the join part: if the first process does not terminate, the termination of the others goes unnoticed, and there are no facilities for error propagation.
spawn
takes care of error propagation and out-of-order termination, and will actively terminate the remaining processes upon detecting an error in one of them.
import torch.multiprocessing as mp

def foo(idx):
    pass

if __name__ == "__main__":
    mp.spawn(foo, args=(), nprocs=4, join=True)
The function fn
passed to spawn
(foo
in this case) will be called as fn(idx, *args)
, where idx
is the index of the process.
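For example, a small sketch passing an extra argument (the message string is made up for illustration):

import torch.multiprocessing as mp

def foo(idx, message):
    # idx (0..nprocs-1) is injected by spawn; message comes from args
    print(f"[worker {idx}] {message}")

if __name__ == "__main__":
    mp.spawn(foo, args=("hello",), nprocs=4, join=True)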
Closing remarks
The knowledge covered in this article should familiarize you with basic multiprocessing in Python and PyTorch. Check out the next article in the series, where we discuss distributed communication in PyTorch.