#!/usr/bin/env python # coding: utf-8 # # Map-style computation via SPMD using IPython Parallel Broadcast View # # In order to use BroadcastView efficiently, tasks must be able to be expressed as a single `apply` (or `execute`) call. # This is best done through [SPMD-style](https://en.wikipedia.org/wiki/Single_program,_multiple_data) tasks. # # If you've written code for MPI, this pattern ought to be familiar. # # The main thing that differs SPMD tasks is that each engine gets the same _code_ to execute, # but the results of what the execute depend on the _state_ of the process. # Typically the engine's _rank_ in the cluster and the cluster _size_, though it can get more complex. # # So rather than calling `map(func, inputs)`, you call `apply(map_func)`, where `map_func` computes a _partition_ and calls the map. # Computing the partition becomes part of the task itself. # So the first thing we need is a function to compute the partition, given inputs, rank, and size. # Here is a very simple example partitioning function for one-dimensional sequences (e.g. lists) # In[1]: def get_partition(n_items: int, rank: int, size: int)-> tuple[int, int]: """ Compute the partition Returns (start, end) of partition """ chunk_size = n_items // size if n_items % size: chunk_size += 1 start = rank * chunk_size if rank + 1 == size: end = n_items else: end = start + chunk_size return (start, end) # In[2]: n = 10 for size in (3, 4, 5): for rank in range(size): print(size, rank, get_partition(n, rank, size)) # Now set up our fake workload. # # It is a bunch of random files. # For our purposes, 5 files per engine. # In[3]: import tempfile from pathlib import Path tmp_dir = tempfile.TemporaryDirectory() tmp_path = Path(tmp_dir.name) n_engines = 100 tasks_per_engine = 5 n_items = n_engines * tasks_per_engine for i in range(n_items): with (tmp_path / f"file-{i:03}.txt").open("wb") as f: f.write(os.urandom(1024)) # In[4]: input_files = list(tmp_path.glob("*.txt")) len(input_files) # Here's our task: compute the md5sum of the contents of one file # In[5]: from hashlib import md5 def compute_one(fname): hasher = md5() with open(fname, "rb") as f: hasher.update(f.read()) return hasher.hexdigest() # In[6]: compute_one(input_files[0]) # In[7]: get_ipython().run_line_magic('time', 'local_result = list(map(compute_one, input_files))') # Now we need to define a task that takes as input: # # - tmp_path (same everywhere) # - rank (unique per engine) # - size (same everywhere) # # which will compute the same thing as computing a chunk of `map(compute_one, input_files)` # In[8]: def spmd_task(tmp_path, rank, size): # identify all inputs all_input_files = list(Path(tmp_path).glob("*.txt")) # partition inputs n_items = len(all_input_files) start, end = get_partition(n_items, rank, size) my_input_files = all_input_files[start:end] # compute result return list(map(compute_one, my_input_files)) # In[9]: for rank in range(5): print(spmd_task(tmp_path, rank, n_items // 2)) # In[10]: local_result[:10] # Now it's time to do it in parallel # In[11]: import logging import ipyparallel as ipp try: # stop previous cluster, if we're re-running cells cluster.stop_cluster_sync() except NameError: pass cluster = ipp.Cluster(n=n_engines, log_level=logging.WARNING) rc = cluster.start_and_connect_sync() # In[12]: broadcast_view = rc.broadcast_view() # Distribute rank and size. # This is unnecessary if engines are created with MPI. # In[13]: broadcast_view.scatter("rank", rc.ids, flatten=True) broadcast_view["size"] = size = len(rc) # enable cloudpickle, which handles imports; # we could also explicitly push everything we are going to use. # In[14]: broadcast_view.use_cloudpickle().get(); # We can now send this SPMD task as a _single_ task on all engines, # each of which will compute its own partition as part of the task and do its work: # In[15]: ar = broadcast_view.apply(spmd_task, tmp_path, ipp.Reference("rank"), size) # Finally, we can reconstruct the result: # # Because we called `apply`, the result is a list of lists, when we want a single flat sequence. # `itertools.chain` takes care of that! # In[16]: ar.get()[:2] # In[17]: from itertools import chain parallel_result = list(chain(*ar)) # In[18]: parallel_result[:6], local_result[:6] # In[19]: assert parallel_result == local_result # In[20]: cluster.stop_cluster_sync()