[KED-1458] Versioning extremely slow on DBFS
Description
Our pipeline running on Azure Databricks has gotten progressively slower. Somebody noticed that running on a fresh set of paths (without so many versions) was significantly faster. Further investigation showed that it wasn’t the data itself; rather, it’s finding the existing versions that is prohibitively slow. Specifically, the underlying functions used by iglob (such as os.scandir) are much slower than their DBFS-native counterparts (e.g. dbutils.fs.ls).
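To see the difference directly in a Databricks notebook, the two listing approaches can be timed side by side. This is a minimal sketch, assuming a versioned dataset directory at the hypothetical path /dbfs/mnt/data/my_dataset and that dbutils is available in the notebook’s global namespace:
import os
import time

PATH = "/dbfs/mnt/data/my_dataset"  # hypothetical versioned dataset directory

# POSIX-style listing through the /dbfs FUSE mount (what glob.iglob/os.scandir end up doing)
start = time.time()
with os.scandir(PATH) as entries:
    names_via_scandir = [entry.name for entry in entries]
print("os.scandir: %.2fs for %d entries" % (time.time() - start, len(names_via_scandir)))

# DBFS-native listing via dbutils (the /dbfs prefix is dropped for the DBFS API)
start = time.time()
names_via_dbutils = [info.name for info in dbutils.fs.ls(PATH[len("/dbfs"):])]
print("dbutils.fs.ls: %.2fs for %d entries" % (time.time() - start, len(names_via_dbutils)))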
Context
Pipelines that should take less than 2 hours are taking 3-5 times that.
Steps to Reproduce
On a Databricks cluster:
1. Save a versioned SparkDataSet.
2. Load it back.
3. Save the same dataset 500 more times.
4. Load it back.
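A rough script for the steps above, assuming a DBFS mount at the hypothetical location /dbfs/mnt/tmp/versioning_repro and the 0.15.x contrib import path:
import time

from pyspark.sql import SparkSession
from kedro.contrib.io.pyspark import SparkDataSet
from kedro.io import Version

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alex", 31), ("Bob", 12)], ["name", "age"])
FILEPATH = "/dbfs/mnt/tmp/versioning_repro"  # hypothetical mount path

def save_once():
    # use a fresh instance per call so each save gets its own autogenerated version
    SparkDataSet(filepath=FILEPATH, version=Version(None, None)).save(df)

def timed_load():
    # loading the latest version forces a glob over all existing version directories
    data_set = SparkDataSet(filepath=FILEPATH, version=Version(None, None))
    start = time.time()
    data_set.load()
    return time.time() - start

save_once()
print("Load with 1 version: %.1fs" % timed_load())
for _ in range(500):
    save_once()
print("Load with ~500 versions: %.1fs" % timed_load())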
Expected Result
Time taken for Step 4 should remain quite similar to that for Step 2.
Actual Result
Time explodes. 🌋
Possible Implementation
See DBFSDirEntry, _get_dbutils, _dbfs_scandir, and the patches below:
# Copyright 2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""``AbstractDataSet`` implementation to access Spark data frames using
``pyspark``
"""
from contextlib import contextmanager
from copy import deepcopy
from fnmatch import fnmatch
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List, Tuple
from unittest.mock import patch
from warnings import warn
import IPython
from hdfs import HdfsError, InsecureClient
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.utils import AnalysisException
from s3fs import S3FileSystem
from kedro.contrib.io import DefaultArgumentsMixIn # isort:skip
from kedro.io import AbstractVersionedDataSet, Version # isort:skip
def _parse_glob_pattern(pattern: str) -> str:
special = ("*", "?", "[")
clean = []
for part in pattern.split("/"):
if any(char in part for char in special):
break
clean.append(part)
return "/".join(clean)
def _split_filepath(filepath: str) -> Tuple[str, str]:
split_ = filepath.split("://", 1)
if len(split_) == 2:
return split_[0] + "://", split_[1]
return "", split_[0]
def _strip_dbfs_prefix(path: str) -> str:
return path[len("/dbfs") :] if path.startswith("/dbfs") else path
class DBFSDirEntry:
"""Mock ``DirEntry`` object yielded by DBUtils-based ``scandir``."""
def __init__(self, file_info):
self.name = file_info.name
self.path = file_info.path
self._is_dir = file_info.isDir()
def is_dir(self):
"""Return True if the entry is a directory."""
return self._is_dir
def _get_dbutils():
return IPython.get_ipython().user_ns["dbutils"]
@contextmanager
def _dbfs_scandir(path):
yield map(DBFSDirEntry, _get_dbutils().fs.ls(_strip_dbfs_prefix(path)))
class KedroHdfsInsecureClient(InsecureClient):
"""Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists``
and ``hdfs_glob`` methods required by ``SparkDataSet``"""
def hdfs_exists(self, hdfs_path: str) -> bool:
"""Determines whether given ``hdfs_path`` exists in HDFS.
Args:
hdfs_path: Path to check.
Returns:
True if ``hdfs_path`` exists in HDFS, False otherwise.
"""
return bool(self.status(hdfs_path, strict=False))
def hdfs_glob(self, pattern: str) -> List[str]:
"""Perform a glob search in HDFS using the provided pattern.
Args:
pattern: Glob pattern to search for.
Returns:
List of HDFS paths that satisfy the glob pattern.
"""
prefix = _parse_glob_pattern(pattern) or "/"
matched = set()
try:
for dpath, _, fnames in self.walk(prefix):
if fnmatch(dpath, pattern):
matched.add(dpath)
matched |= set(
"{}/{}".format(dpath, fname)
for fname in fnames
if fnmatch("{}/{}".format(dpath, fname), pattern)
)
except HdfsError: # pragma: no cover
# HdfsError is raised by `self.walk()` if prefix does not exist in HDFS.
# Ignore and return an empty list.
pass
return sorted(matched)
class SparkDataSet(DefaultArgumentsMixIn, AbstractVersionedDataSet):
"""``SparkDataSet`` loads and saves Spark data frames.
Example:
::
>>> from pyspark.sql import SparkSession
>>> from pyspark.sql.types import (StructField, StringType,
>>> IntegerType, StructType)
>>>
>>> from kedro.contrib.io.pyspark import SparkDataSet
>>>
>>> schema = StructType([StructField("name", StringType(), True),
>>> StructField("age", IntegerType(), True)])
>>>
>>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)]
>>>
>>> spark_df = SparkSession.builder.getOrCreate()\
>>> .createDataFrame(data, schema)
>>>
>>> data_set = SparkDataSet(filepath="test_data")
>>> data_set.save(spark_df)
>>> reloaded = data_set.load()
>>>
>>> reloaded.take(4)
"""
def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._fs_prefix + str(self._filepath),
file_format=self._file_format,
load_args=self._load_args,
save_args=self._save_args,
version=self._version,
)
def __init__( # pylint: disable=too-many-arguments
self,
filepath: str,
file_format: str = "parquet",
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
version: Version = None,
credentials: Dict[str, Any] = None,
) -> None:
"""Creates a new instance of ``SparkDataSet``.
Args:
filepath: path to a Spark data frame. When using Databricks
and working with data written to mount path points,
specify ``filepath``s for (versioned) ``SparkDataSet``s
starting with ``/dbfs/mnt``.
file_format: file format used during load and save
operations. Formats supported by the running
SparkContext include parquet and csv. For a list of supported
formats please refer to Apache Spark documentation at
https://spark.apache.org/docs/latest/sql-programming-guide.html
load_args: Load args passed to Spark DataFrameReader load method.
It is dependent on the selected file format. You can find
a list of read options for each supported format
in Spark DataFrame read documentation:
https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame
save_args: Save args passed to Spark DataFrame write options.
Similar to load_args this is dependent on the selected file
format. You can pass ``mode`` and ``partitionBy`` to specify
your overwrite mode and partitioning respectively. You can find
a list of options for each format in Spark DataFrame
write documentation:
https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame
version: If specified, should be an instance of
``kedro.io.core.Version``. If its ``load`` attribute is
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
credentials: Credentials to access the S3 bucket, such as
``aws_access_key_id``, ``aws_secret_access_key``, if ``filepath``
prefix is ``s3a://`` or ``s3n://``. Optional keyword arguments passed to
``hdfs.client.InsecureClient`` if ``filepath`` prefix is ``hdfs://``.
Ignored otherwise.
"""
credentials = deepcopy(credentials) or {}
fs_prefix, filepath = _split_filepath(filepath)
if fs_prefix in ("s3a://", "s3n://"):
if fs_prefix == "s3n://":
warn(
"`s3n` filesystem has now been deprecated by Spark, "
"please consider switching to `s3a`",
DeprecationWarning,
)
_s3 = S3FileSystem(client_kwargs=credentials)
exists_function = _s3.exists
glob_function = _s3.glob
path = PurePosixPath(filepath)
elif fs_prefix == "hdfs://" and version:
warn(
"HDFS filesystem support for versioned {} is in beta and uses "
"`hdfs.client.InsecureClient`, please use with caution".format(
self.__class__.__name__
)
)
# default namenode address
credentials.setdefault("url", "http://localhost:9870")
credentials.setdefault("user", "hadoop")
_hdfs_client = KedroHdfsInsecureClient(**credentials)
exists_function = _hdfs_client.hdfs_exists
glob_function = _hdfs_client.hdfs_glob
path = PurePosixPath(filepath)
else:
exists_function = glob_function = None # type: ignore
path = Path(filepath) # type: ignore
super().__init__(
load_args=load_args,
save_args=save_args,
filepath=path,
version=version,
exists_function=exists_function,
glob_function=glob_function,
)
self._file_format = file_format
self._fs_prefix = fs_prefix
@staticmethod
def _get_spark():
return SparkSession.builder.getOrCreate()
@patch("os.scandir", _dbfs_scandir)
def _load(self) -> DataFrame:
with patch("os.path.lexists", return_value=True):
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
self._logger.info('`load_path = "%s"`', load_path)
return self._get_spark().read.load(
load_path, self._file_format, **self._load_args
)
@patch("os.scandir", _dbfs_scandir)
def _save(self, data: DataFrame) -> None:
with patch("os.path.lexists", return_value=True):
save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path()))
self._logger.info('`save_path = "%s"`', save_path)
data.write.save(save_path, self._file_format, **self._save_args)
@patch("os.scandir", _dbfs_scandir)
def _exists(self) -> bool:
with patch("os.path.lexists", return_value=True):
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
self._logger.info('`load_path = "%s"`', load_path)
try:
self._get_spark().read.load(load_path, self._file_format)
except AnalysisException as exception:
if exception.desc.startswith("Path does not exist:"):
return False
raise
return True
I believe it’s fine to have DBFS-specific code, as we already do. However, a few necessary enhancements:
- Detect DBFS (e.g. look for the DATABRICKS_RUNTIME_VERSION environment variable) and apply the patches conditionally; a sketch of this follows the list.
- Apply this to any dataset, not just SparkDataSet, when the filesystem is DBFS (versioning should be the same). => Actually implement this code as part of the versioning mix-in?
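On the first point, here is a minimal sketch of the detection idea, assuming the patches are only worth applying when the DATABRICKS_RUNTIME_VERSION environment variable is present and reusing _dbfs_scandir from the patch above (the helper names are hypothetical):
import os
from contextlib import ExitStack, contextmanager
from unittest.mock import patch

def _running_on_databricks() -> bool:
    # Databricks runtimes set this environment variable on every cluster node
    return "DATABRICKS_RUNTIME_VERSION" in os.environ

@contextmanager
def _maybe_patch_for_dbfs():
    """Apply the DBFS-specific patches only when actually running on Databricks."""
    if not _running_on_databricks():
        yield
        return
    with ExitStack() as stack:
        stack.enter_context(patch("os.scandir", _dbfs_scandir))
        stack.enter_context(patch("os.path.lexists", return_value=True))
        yield
_load, _save and _exists could then wrap their bodies in with _maybe_patch_for_dbfs(): instead of applying the decorator unconditionally.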
Possible Alternatives
- Define a DBFS glob function from scratch (DBUtils doesn’t provide one, AFAIK); see the sketch below.
- Actually manage (archive?) versions instead of keeping hundreds. ¯\_(ツ)_/¯
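For the first alternative, a DBFS-native glob could mirror hdfs_glob above by walking dbutils.fs.ls and matching paths with fnmatch. A rough sketch, reusing the helpers defined in the patch above (error handling for non-existent prefixes is omitted):
from fnmatch import fnmatch
from typing import List

def _dbfs_glob(pattern: str) -> List[str]:
    """Glob over DBFS using dbutils.fs.ls instead of os.scandir."""
    prefix = _parse_glob_pattern(pattern) or "/"
    dbutils = _get_dbutils()
    matched = set()
    to_visit = [_strip_dbfs_prefix(prefix)]
    while to_visit:
        for info in dbutils.fs.ls(to_visit.pop()):
            # dbutils returns dbfs:/... URIs; convert back to the /dbfs mount convention
            path = ("/dbfs" + info.path[len("dbfs:"):]).rstrip("/")
            if info.isDir():
                to_visit.append(info.path)
            if fnmatch(path, pattern):
                matched.add(path)
    return sorted(matched)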
Your Environment
Include as many relevant details about the environment in which you experienced the bug:
- Kedro version used (pip show kedro or kedro -V): 0.15.5
- Python version used (python -V): Python 3.7.3
- Operating system and version: Databricks Runtime Version 5.5 Conda Beta (includes Apache Spark 2.4.3, Scala 2.11)
Issue Analytics
- Created 4 years ago
- Reactions: 1
- Comments: 5 (5 by maintainers)
Top GitHub Comments
This is a great writeup, thanks for such a detailed issue 👌
Fixed in merge commit: 6bf1066