Cannot find result of scattered array?
Hello everyone,
I am currently trying to run a simple script in Jupyter that takes a 4 x 2 array and scatters it to all available cores (there are 4). Afterwards, each element in the scattered 1 x 2 array is multiplied by 2. Here’s the script.
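To make the intended data flow concrete, here is a serial NumPy simulation of what the scatter is meant to do. It is not mpi4jax and runs in a single process. It assumes 4 ranks, as in the question, and that the scatter splits the root's array along its first axis, so rank r receives row r. Each rank therefore ends up with a 1-D array of shape (2,). In a real MPI run, every rank must call the scatter (it is a collective operation), not only rank 0.

```python
import numpy as np

n_ranks = 4  # assumption: 4 MPI processes, as in the question

# On the root rank: the full 4 x 2 array, one row per rank
arr = np.arange(1, 9).reshape((n_ranks, 2))

# Simulated scatter along axis 0: rank r receives row r, a 1-D array of shape (2,)
received = [arr[r] for r in range(n_ranks)]

# Each rank then doubles its own piece independently
doubled = [piece * 2 for piece in received]

for r, piece in enumerate(doubled):
    print(f"rank {r}: {piece}")
```

The printed pieces should be [2 4], [6 8], [10 12] and [14 16] for ranks 0 to 3.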
%%file test2.py
from time import time
from mpi4py import MPI
import jax
import numpy as np
import jax.numpy as jnp
import mpi4jax
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
time_init = time()
print("hello from rank", rank)
@jax.jit
def foo():
    if rank == 0:
        arr = jnp.arange(1,9).reshape((4, 2))
        print(arr)
        arr_sum, data = mpi4jax.scatter(arr, root=0, comm=comm)
        print(arr_sum)
        print("This took", time() - time_init, "s")
foo()
arr_sum *= 2
And to execute, I write the following:
!mpirun -np 4 -x PMIX_MCA_gds=^ds12 python test2.py | tee finding_scatter.txt
The -x PMIX_MCA_gds=^ds12 option is there to suppress the spurious PMIx error output that sometimes appears.
However, whenever I try to use the scattered data after the scatter, it isn't available, and I get this error:
hello from rank 0
hello from rank 1
hello from rank 3
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "test2.py", line 25, in <module>
arr_sum *= 2
NameError: name 'arr_sum' is not defined
Traceback (most recent call last):
File "test2.py", line 25, in <module>
arr_sum *= 2
NameError: name 'arr_sum' is not defined
Traced<ShapedArray(int32[4,2])>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=0/1)>
This took 0.009289026260375977 s
hello from rank 2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "test2.py", line 25, in <module>
arr_sum *= 2
NameError: name 'arr_sum' is not defined
Traceback (most recent call last):
File "test2.py", line 25, in <module>
arr_sum *= 2
NameError: name 'arr_sum' is not defined
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:
Process name: [[30243,1],1]
Exit code: 1
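The Traced<ShapedArray(...)> lines in the output above come from calling print inside a jax.jit-compiled function: print runs while JAX traces the function, when the arguments are abstract tracers rather than numbers. A minimal sketch, assuming only jax is installed (the function name double is hypothetical), shows that concrete values are available on what the jitted function returns:

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(row):
    # Runs at trace time: `row` is an abstract tracer here, so this prints
    # something like Traced<ShapedArray(int32[2])>... instead of numbers
    print("inside jit:", row)
    return row * 2

result = double(jnp.array([1, 2]))
print("returned:", result)  # concrete values live on the returned array
```

Recent JAX versions also provide jax.debug.print for printing runtime values from inside jitted code.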
What's going on? From the error itself, I gather that Python is looking for arr_sum but cannot find it because it is apparently not defined. But how can that be, since I've presumably scattered a fragment to each core? Shouldn't assigning the scattered fragment to arr_sum be enough to define it, so that each core automatically picks up its own fragment and can compute with it? I am trying to use mpi4jax as analogously as I would regular mpi4py code, but it doesn't seem to work. I've also tried moving the scatter call outside the if rank == 0 statement, and that doesn't work either. My guess is that there is some syntax or formatting issue in my test2.py file that I cannot find.
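The NameError itself is ordinary Python scoping rather than anything MPI-specific: arr_sum is assigned inside foo, so it is a local variable that disappears when foo returns, and jax.jit does not change that. A minimal pure-Python sketch (the function names foo_broken and foo_fixed are hypothetical) reproduces the error and shows the usual fix, returning the value and binding it at module level:

```python
import numpy as np

def foo_broken():
    # arr_sum is a *local* variable: it is discarded when foo_broken returns
    arr_sum = np.array([1, 2])

def foo_fixed():
    arr_sum = np.array([1, 2])
    return arr_sum  # hand the value back to the caller instead

foo_broken()
try:
    arr_sum *= 2  # no module-level arr_sum exists: same error as in the traceback
except NameError as err:
    print("NameError:", err)

arr_sum = foo_fixed()  # bind the returned value outside the function
arr_sum *= 2
print(arr_sum)  # [2 4]
```

Applied to the script, the same idea means having the jitted function return the scattered array; in addition, every rank (not only rank 0) has to take part in the scatter.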
Ok, I see. Well, if none of you mind, I’ll close the issue. I think I’ve solved my biggest issues here. Again, thank you to both of you for helping.
Yes, we are saying the same thing. On 21 May 2021, 22:46 +0200, ISDementyev @.***> wrote: