
Communication timeout on DGX-2 with UCX+NVLink

We’ve been experiencing a communication timeout on a DGX-2 with UCX+NVLink for a particular problem. The code is below, but I can’t share the data due to its size (210 GB):

import time
import sys
import dask_cudf
import cudf
import numpy as np
import cupy as cp
from distributed import wait

sys.path.append("../../tools/")
from readers import build_reader

from dask.distributed import Client
from dask_cuda import DGX
from dask_cuda.initialize import initialize

from dask_cuda import LocalCUDACluster

# ON/OFF settings for various devices
enable_tcp_over_ucx = True
enable_nvlink = True
enable_infiniband = False
interface = "enp134s0f1"
protocol = "ucx"


if __name__ == "__main__":
    # initialize client with the same settings as workers
    initialize(create_cuda_context=True,
               enable_tcp_over_ucx=enable_tcp_over_ucx,
               enable_infiniband=enable_infiniband,
               enable_nvlink=enable_nvlink)

    if protocol == "tcp":
        cluster = LocalCUDACluster(protocol='tcp',
                                   silence_logs=False,
                                   interface=interface,
                                   CUDA_VISIBLE_DEVICES=list(range(16)))
    elif protocol == "ucx":
        cluster = LocalCUDACluster(protocol='ucx',
                                   silence_logs=False,
                                   enable_tcp_over_ucx=enable_tcp_over_ucx,
                                   enable_infiniband=enable_infiniband,
                                   enable_nvlink=enable_nvlink,
                                   interface=interface,
                                   ucx_net_devices=interface,
                                   CUDA_VISIBLE_DEVICES=list(range(16)))

    client = Client(cluster)
    client.run(cudf.set_allocator, "default", pool=True)

    data_dir = '/data/parquet/'
    q30_session_timeout_inSec = 3600
    q30_limit = 1000
    file_format = "parquet"

    from dask.distributed import performance_report
    with performance_report(filename="dask_report.html"):

        t = time.time()
        def vec_arange(end_sr):
            """
            Returns a flattened arange output with start=0 and end=end_sr[i]
            for each element of the series.
            """
            ar = cp.arange(end_sr.max())

            ### get flag matrix for values that are < end
            m = ar < end_sr.values[:, None]

            # ranges in a flattened 1d array
            ranges = (ar * m)[m]
            return ranges


        def get_session_id_from_session_boundry(session_change_df, last_session_len):
            """
                Returns a series of session ids given a session-change
                dataframe and the length of the last session.
            """

            user_val_counts = session_change_df.wcs_user_sk.value_counts(sort=False)
            user_val_counts = user_val_counts.reset_index(drop=False)
            user_val_counts = user_val_counts.rename(
                {"index": "wcs_user_sk", "wcs_user_sk": "user_count"}
            )

            ### sort again by user_sk because we want our starts to be aligned
            user_val_counts = user_val_counts.sort_values(by="wcs_user_sk").reset_index(
                drop=True
            )

            end_range = user_val_counts["user_count"]

            user_session_ids = vec_arange(end_range)

            ### shift the session length series up by one
            session_len = session_change_df["t_index"].diff().reset_index(drop=True)
            session_len = session_len.shift(-1)
            session_len.iloc[-1] = last_session_len

            session_id_final_series = (
                cudf.Series(user_session_ids).repeat(session_len).reset_index(drop=True)
            )
            return session_id_final_series


        def get_session_id(df, time_out):
            """
                This function creates a session id column for each click.
                The session id increments for each user's subsequent session.
                A session boundary is defined by the time_out.
            """

            df["user_change_flag"] = df["wcs_user_sk"].diff(periods=1) != 0
            df["time_delta"] = df["tstamp_inSec"].diff(periods=1)
            df["session_timeout_flag"] = df["tstamp_inSec"].diff(periods=1) > time_out

            df["session_change_flag"] = df["session_timeout_flag"] | df["user_change_flag"]

            # print(f"Total session change = {df['session_change_flag'].sum():,}")

            cols_keep = ["wcs_user_sk", "i_category_id", "session_change_flag"]
            df = df[cols_keep]

            df = df.reset_index(drop=True)
            df["t_index"] = cudf.utils.cudautils.arange(start=0, stop=len(df), dtype=np.int32)

            session_change_df = df[df["session_change_flag"]]
            last_session_len = len(df) - session_change_df["t_index"].iloc[-1]

            session_ids = get_session_id_from_session_boundry(
                session_change_df, last_session_len
            )

            assert len(session_ids) == len(df)
            return session_ids

        table_reader = build_reader(file_format, basepath=data_dir)

        wcs_cols = ["wcs_user_sk", "wcs_item_sk", "wcs_click_date_sk", "wcs_click_time_sk"]
        wcs_df = table_reader.read("web_clickstreams", relevant_cols=wcs_cols)

        item_cols = ["i_category_id", "i_item_sk"]
        item_df = table_reader.read("item", relevant_cols=item_cols)

        f_wcs_df = wcs_df[wcs_df["wcs_user_sk"].notnull()]
        f_item_df = item_df[item_df["i_category_id"].notnull()]

        merged_df = f_wcs_df.merge(f_item_df, left_on=["wcs_item_sk"], right_on=["i_item_sk"])

        merged_df["tstamp_inSec"] = (
            merged_df["wcs_click_date_sk"] * 24 * 60 * 60 + merged_df["wcs_click_time_sk"]
        )
        cols_keep = ["wcs_user_sk", "tstamp_inSec", "i_category_id"]
        merged_df = merged_df[cols_keep]

        ### set the index so that the clicks for each user end up in the same partition
        merged_df = merged_df.set_index("wcs_user_sk")
        merged_df = merged_df.reset_index(drop=False)

        def get_sessions(df):
            df = df.sort_values(by=["wcs_user_sk", "tstamp_inSec"]).reset_index(drop=True)
            df["session_id"] = get_session_id(df, q30_session_timeout_inSec)
            df = df[["wcs_user_sk", "i_category_id", "session_id"]]
            return df

        session_df = merged_df.map_partitions(get_sessions)
        del merged_df


        def get_distinct_sessions(df):
            df = df.drop_duplicates().reset_index(drop=True)
            return df

        distinct_session_df = session_df.map_partitions(get_distinct_sessions)

        ### get_pair_helper
        def get_pairs(
            df,
            merge_col=["session_id", "wcs_user_sk"],
            pair_col="i_category_id",
            output_col_1="category_id_1",
            output_col_2="category_id_2",
        ):
            """
                Gets pairs after doing an inner merge
            """
            pair_df = df.merge(df, on=merge_col, suffixes=["_t1", "_t2"], how="inner")
            pair_df = pair_df[[f"{pair_col}_t1", f"{pair_col}_t2"]]
            pair_df = pair_df[pair_df[f"{pair_col}_t1"] < pair_df[f"{pair_col}_t2"]]
            pair_df = pair_df.rename(
                columns={f"{pair_col}_t1": output_col_1, f"{pair_col}_t2": output_col_2}
            )
            return pair_df

        pair_df = distinct_session_df.map_partitions(get_pairs)
        del distinct_session_df

        print("Time:", time.time() - t)

The problem with this code is that communication times out. That can be circumvented by increasing distributed.comm.timeouts.connect, but I believe the timeout is just a consequence of a problem with how data transfer behaves. Watching the Dask task stream, I noticed that data transfers take longer as time passes: the first transfers take hundreds of milliseconds, then gradually increase to a few seconds, then tens and eventually hundreds of seconds. I’ve uploaded a Dask performance report of the code above; it was taken on a DGX-2 running with 16 GPUs and UCX+NVLink. It’s worth noting that this doesn’t happen with TCP communication.
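
For reference, a minimal sketch of raising that timeout via Dask's configuration (the "100s" value is only illustrative, and this has to run before the cluster is created):

import dask

# Raise distributed's connect timeout; this only hides the symptom,
# it does not make the slow transfers themselves any faster.
dask.config.set({"distributed.comm.timeouts.connect": "100s"})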

This issue seems to be caused by a combination of factors. Small data transfers over NVLink are certainly inefficient without caching CUDA IPC handles, and that probably contributes. I also believe there’s some sort of blocking task somewhere that eventually causes transfer times to be reported as longer than they actually are, because the transfers are in fact waiting for that blocking task to finish.
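
As an aside, UCX exposes the IPC handle cache as an environment variable; below is a minimal sketch of turning it on from Python, assuming the UCX build in use honors UCX_CUDA_IPC_CACHE and that it is set before the workers initialize UCX:

import os

# Keep UCX's CUDA IPC handle cache enabled so repeated transfers between the
# same pair of GPUs don't re-map the IPC handle every time.
os.environ["UCX_CUDA_IPC_CACHE"] = "y"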

Any ideas here @mrocklin @quasiben @jakirkham @madsbk ?

cc @randerzander @beckernick @VibhuJawa

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 28 (25 by maintainers)

Top GitHub Comments

1 reaction
jakirkham commented, May 5, 2020

Was going to add the option to fall back to CuPy in Distributed, as CuPy also has a memory pool and is likely in use by users working with Python on a GPU.

IOW the order would be RMM -> CuPy -> Numba
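
A minimal sketch of that fallback order (illustrative only; allocate_device_buffer is a hypothetical helper, not the actual Distributed code):

def allocate_device_buffer(nbytes):
    """Allocate a device buffer, preferring RMM, then CuPy, then Numba."""
    try:
        import rmm
        return rmm.DeviceBuffer(size=nbytes)         # RMM (pool-backed if a pool was set up)
    except ImportError:
        pass
    try:
        import cupy
        return cupy.empty((nbytes,), dtype="u1")     # CuPy, which has its own memory pool
    except ImportError:
        pass
    from numba import cuda
    return cuda.device_array((nbytes,), dtype="u1")  # Numba as the last resort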

1 reaction
pentschev commented, May 5, 2020

I’m somewhat convinced that this problem will remain when users can’t use an RMM pool or when they disable the IPC cache. In such situations, all transfers would require mapping the memory handle, which has a cost in the range of 100 ms. For that reason, I think we should encourage users to always use an RMM pool and turn on the IPC cache (in UCX-Py we’ve enabled it back by default). In principle things would work without those, but with very little benefit (if any), so I suggest we keep things as is and tell users to rely on an RMM pool plus the IPC cache, as there isn’t much we’ll be able to do otherwise. Any objections to this proposal?
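
For users hitting this, a minimal sketch of that recommended setup, assuming a dask-cuda version that accepts rmm_pool_size (the pool size and interface below are placeholders):

import os
from distributed import Client
from dask_cuda import LocalCUDACluster

os.environ["UCX_CUDA_IPC_CACHE"] = "y"   # keep the CUDA IPC handle cache on

cluster = LocalCUDACluster(
    protocol="ucx",
    enable_tcp_over_ucx=True,
    enable_nvlink=True,
    rmm_pool_size="24GB",                # pre-allocate an RMM pool on each worker
    interface="enp134s0f1",
)
client = Client(cluster)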

Read more comments on GitHub >
