.. _ReferenceMPMD:

Multiple Program Multiple Data (MPMD)
=====================================


The AMReX-MPMD interface is used to send data to another program or application. To enable this feature, build the
executable with ``-DMFIX_MPMD=yes``.

:ref:`Input parameters controlling the AMReX-MPMD interface<InputsMPMD>` are defined in the User Guide run-time inputs section.


Sample Python Program
---------------------

A sample Python script that gathers and plots velocity statistics for fluid flow through a pipe can be found
in ``tutorials/mpmd/main.py``. The script is organized into the following sections:

Initialize
~~~~~~~~~~

* Initialize AMReX::MPMD and use MPI from ``mpi4py`` to split the communicator.

    .. code-block:: python

       amr.MPMD_Initialize_without_split([])
       app_comm = MPI.COMM_WORLD.Split(amr.MPMD_AppNum(), amr.MPMD_MyProc())
       app_world_size = app_comm.Get_size()
       app_rank = app_comm.Get_rank()
       amr.initialize_when_MPMD([], app_comm)

* Determine the rank of the C++ app's root process in ``MPI.COMM_WORLD``.

    .. code-block:: python

       if app_rank == 0:
           if amr.MPMD_MyProc() == app_rank: # first program
               other_root = app_comm.Get_size()
           else: # second program, so the other app's root is world rank 0
               other_root = 0
           print(f'other_root = {other_root}')

* Create an MPMD::Copier object that gets the BoxArray information from the C++ app.

    .. code-block:: python

       copr = amr.MPMD_Copier(True)

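The root-rank logic in the second step above can be sketched as a pure function, without MPI. This is only an illustration: it assumes exactly two programs whose ranks are laid out contiguously in ``MPI.COMM_WORLD``, and the helper name ``other_root_rank`` is hypothetical (it is not part of AMReX or of the tutorial script).

.. code-block:: python

   def other_root_rank(world_rank: int, app_rank: int, app_size: int) -> int:
       """Return the world rank of the other program's root process.

       Assumes two programs with contiguous world ranks (an assumption
       of this sketch, not something the tutorial script checks).
       """
       if world_rank == app_rank:
           # This program comes first in the world communicator, so the
           # other program starts right after this program's last rank.
           return app_size
       # Otherwise the other program comes first; its root is world rank 0.
       return 0


   # Python app launched first with 2 ranks, C++ app second.
   print(other_root_rank(world_rank=0, app_rank=0, app_size=2))
   # Python app (2 ranks) launched second, after a 4-rank C++ app.
   print(other_root_rank(world_rank=4, app_rank=0, app_size=2))

Only the root of each application needs this value, because the header, flags, and reals are exchanged root to root and then broadcast within the application.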

Receive Once
~~~~~~~~~~~~

* Receive the ``Header`` information as a JSON string on the Python root from the C++ root,
  broadcast it to all Python ranks, and parse it.

    .. code-block:: python

       header_json = ""

       if app_rank == 0:
           buf = bytearray(10000)  # Create a buffer to receive the message
           MPI.COMM_WORLD.Recv([buf, MPI.CHAR], source=other_root)
           header_json = buf.decode().strip('\x00')  # Decode and strip null characters

       header_json = app_comm.bcast(header_json, root=0)

       # Parse the JSON string into a dictionary (requires "import json")
       header = json.loads(header_json)

* Receive all the static ``MultiFab`` data.

    .. code-block:: python

       my_static_data = MyData()
       for mf in header["data"]["static_mfs"]:
           my_static_data.define_mf(copr, mf["n"], mf["c"])
           my_static_data.copy_mf(copr, mf["n"], mf["c"])

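As a minimal sketch of the header handling above, the snippet below emulates a partially filled, null-padded receive buffer and parses it with ``json.loads``. The key names follow the snippets in this section; the field values, the list-of-names layout of ``int_flags_root`` and ``reals_root``, and the ``decode_header`` helper are hypothetical.

.. code-block:: python

   import json


   def decode_header(buf: bytearray) -> dict:
       """Decode a null-padded byte buffer into a header dictionary."""
       return json.loads(buf.decode().strip("\x00"))


   # Hypothetical header; the real one is produced by the C++ application.
   example = {
       "data": {
           "static_mfs": [{"n": "volfrac", "c": 1}],
           "mfs": [{"n": "vel_g", "c": 3}],
           "int_flags_root": ["end"],
           "reals_root": ["time"],
       }
   }

   # Emulate a fixed-size receive buffer that is only partially filled.
   payload = json.dumps(example).encode()
   buf = bytearray(10000)
   buf[: len(payload)] = payload

   header = decode_header(buf)
   for mf in header["data"]["mfs"]:
       print(mf["n"], mf["c"])

The lengths of ``int_flags_root`` and ``reals_root`` are what size the receive arrays in the time loop, so every rank needs the parsed header, not just the root.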

Receive Until *End*
~~~~~~~~~~~~~~~~~~~

* Receive the ``End Flag`` on the Python root from the C++ root and broadcast it to all Python
  ranks. If the flag is ``1``, break out of the loop.

    .. code-block:: python

       if app_rank == 0:
           int_flags = np.empty(len(header["data"]["int_flags_root"]), dtype='i')
           MPI.COMM_WORLD.Recv(int_flags, source=other_root)
           print(f"app_rank = {app_rank}, int_flags = {int_flags}")
           end = int_flags[0]

       end = app_comm.bcast(end, root=0)

       if end == 1:
           break

* Receive the ``Reals`` on the Python root from the C++ root and broadcast the time to all Python
  ranks. Save ``time`` to an array on the Python root for plotting.

    .. code-block:: python

       if app_rank == 0:
           reals = np.empty(len(header["data"]["reals_root"]), dtype=np.double)
           MPI.COMM_WORLD.Recv(reals, source=other_root)
           print(f"app_rank = {app_rank}, reals = {reals}")
           time = reals[0]
           time_arr.append(time)

       time = app_comm.bcast(time, root=0)

* Receive the ``MultiFab`` data and store it in arrays on the Python root as necessary.
  In this example, the ``centerline`` u-velocity and the data needed to compute the
  mean and variance of the u-velocity on the central ``y-plane`` are stored at each
  time step.

    .. code-block:: python

       for mf in header["data"]["mfs"]:
           my_data.copy_mf(copr, mf["n"], mf["c"])

       for mfi in my_data.mfs["vel_g"]:
           bx = mfi.validbox()
           y_intrst_exists = True
           z_intrst_exists = True

           if (j_intrst < bx.small_end[1] or j_intrst > bx.big_end[1]):
               y_intrst_exists = False

           if (k_intrst < bx.small_end[2] or k_intrst > bx.big_end[2]):
               z_intrst_exists = False

           if not (y_intrst_exists or z_intrst_exists):
               continue

           ###..............

           if ( y_intrst_exists and z_intrst_exists ):
               np_array = np.array(vel_g_array[0,k_intrst,j_intrst,:])
               u_centerline[bx.small_end[0]:bx.small_end[0] + np_array.size] = np_array

           if (y_intrst_exists):

               ###..............

               y_pl_npts += np.sum(y_volfrac_array)
               y_pl_u_mn += np.sum(y_vel_g_array[0,:,:])
               y_pl_u2_mn += np.sum(y_vel_g_array[0,:,:]*y_vel_g_array[0,:,:])

       # Reduce from all python ranks
       y_pl_npts = app_comm.reduce(y_pl_npts,op=MPI.SUM,root=0)
       y_pl_u_mn = app_comm.reduce(y_pl_u_mn,op=MPI.SUM,root=0)
       y_pl_u2_mn = app_comm.reduce(y_pl_u2_mn,op=MPI.SUM,root=0)
       u_centerline = app_comm.reduce(u_centerline,op=MPI.SUM,root=0)

       #...................

       if app_rank == 0:
           y_pl_u_mn /= y_pl_npts
           y_pl_u2_mn /= y_pl_npts

           y_pl_u_mn_arr.append(y_pl_u_mn)
           y_pl_u_var_arr.append(y_pl_u2_mn-y_pl_u_mn*y_pl_u_mn)
           u_centerline_arr.append(u_centerline)

       app_comm.barrier()

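The plane statistics above rely on the identity ``Var(u) = E[u^2] - E[u]^2``, which lets each rank accumulate only a count, a sum, and a sum of squares before the reduction. The following is a minimal sketch without MPI, in which two lists stand in for the data held by two ranks and the count is simply the number of samples (the script itself accumulates ``y_volfrac_array`` for the count). All names and values here are hypothetical.

.. code-block:: python

   import statistics

   # Hypothetical u-velocity samples held by two ranks.
   rank_samples = [[1.0, 2.0, 3.0], [4.0, 5.0]]

   # Per-rank partial sums: (count, sum of u, sum of u*u).
   partials = [(len(s), sum(s), sum(u * u for u in s)) for s in rank_samples]

   # Stand-in for the three reductions with op=MPI.SUM.
   npts = sum(p[0] for p in partials)
   u_sum = sum(p[1] for p in partials)
   u2_sum = sum(p[2] for p in partials)

   # Normalize on the root, as the script does after the reduction.
   u_mean = u_sum / npts
   u_var = u2_sum / npts - u_mean * u_mean

   # Cross-check against the population variance of all samples.
   all_samples = [u for s in rank_samples for u in s]
   assert abs(u_var - statistics.pvariance(all_samples)) < 1e-12
   print(u_mean, u_var)

Accumulating sums rather than per-rank means keeps the reduction a plain ``MPI.SUM`` and avoids weighting ranks by how many cells of the plane they own.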

Plot
~~~~

* Plot figures on the Python root using the collected arrays.

    .. code-block:: python

       if app_rank == 0:
           fig, (ax1, ax2, ax3) = plt.subplots(1,3)

           ax1.plot(time_arr, y_pl_u_mn_arr)
           ax1.set_title('mean')
           ax1.set_xlabel('Time (s)')

           ax2.plot(time_arr, y_pl_u_var_arr)
           ax2.set_title('var')
           ax2.set_xlabel('Time (s)')

           ax3.plot(range(xlen), np.array(u_centerline_arr).mean(axis=0))
           ax3.set_title('centerline U (m/s)')
           ax3.set_xlabel('i')

           plt.savefig('my_plot.png')


Finalize
~~~~~~~~

* Finalize AMReX and AMReX::MPMD.

    .. code-block:: python

       amr.finalize()
       amr.MPMD_Finalize()