Streaming connections interrupted for extremely long time series
Description
Using streaming inference with Triton tends to lead to broken connections after on the order of millions of inferences; these interrupt the service and aren't easy to recover from. How many inferences it takes, and even whether it happens at all, is somewhat inconsistent, though by rough eyeball I see it about 50% of the time. The errors reported on the client side vary, but the most common I've seen in production are:
[StatusCode.UNAVAILABLE] Connection reset by peer
[StatusCode.UNAVAILABLE] Socket closed
and additionally, in the repro I'm providing below:
inference request for sequence 1001 to model 'mlp' must specify the START flag on the first request of the sequence
even though the same sequence has been going for millions of inferences.
No logs or errors are reported on the server side when this happens.
Triton Information
v2.5.0 container build. Upgrading to more recent versions is difficult because of instabilities in the corresponding TensorRT releases, but if this is a known issue that has been fixed more recently, upgrading is not necessarily out of the question.
To Reproduce
The exact model shouldn't matter, but for repro purposes the following code exports a model that exhibits these issues (even though it's not truly stateful):
import argparse
import os

import tensorflow as tf
from tritonclient.grpc import model_config_pb2 as model_config


def main(
    repo_dir: str,
    model_name: str,
    model_version: int = 1,
    input_dim: int = 1024
):
    # create the versioned model directory if it doesn't exist
    output_dir = os.path.join(repo_dir, model_name)
    os.makedirs(os.path.join(output_dir, str(model_version)), exist_ok=True)

    # build a generic linear MLP
    input = tf.keras.Input(
        name="input", shape=(input_dim,), dtype="float32", batch_size=1
    )
    x = input
    for dim in [256, 64, 1]:
        x = tf.keras.layers.Dense(dim)(x)
    model = tf.keras.Model(inputs=input, outputs=x)
    model.save(os.path.join(output_dir, str(model_version), "model.savedmodel"))

    # write a config that enables sequence batching with the Direct strategy
    config = model_config.ModelConfig(
        name=model_name,
        platform="tensorflow_savedmodel",
        input=[
            model_config.ModelInput(
                name="input",
                dims=[1, input_dim],
                data_type=model_config.DataType.TYPE_FP32
            )
        ],
        output=[
            model_config.ModelOutput(
                name=x.name.split("/")[0],
                dims=[1, 1],
                data_type=model_config.DataType.TYPE_FP32
            )
        ],
        sequence_batching=model_config.ModelSequenceBatching(
            max_sequence_idle_microseconds=10000000,
            direct=model_config.ModelSequenceBatching.StrategyDirect(),
        ),
        instance_group=[model_config.ModelInstanceGroup(
            gpus=[0],
            count=4
        )]
    )
    with open(os.path.join(output_dir, "config.pbtxt"), "w") as f:
        f.write(str(config))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-dir", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--model-version", type=int, default=1)
    parser.add_argument("--input-dim", type=int, default=1024)
    flags = parser.parse_args()
    main(**vars(flags))
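For reference, the config.pbtxt this writes out should look roughly like the following (the output tensor name depends on Keras's auto-generated layer names, so dense_2 here is just what a fresh session tends to produce, and mlp is the --model-name from the error message above):

name: "mlp"
platform: "tensorflow_savedmodel"
input {
  name: "input"
  data_type: TYPE_FP32
  dims: 1
  dims: 1024
}
output {
  name: "dense_2"
  data_type: TYPE_FP32
  dims: 1
  dims: 1
}
instance_group {
  count: 4
  gpus: 0
}
sequence_batching {
  max_sequence_idle_microseconds: 10000000
  direct {
  }
}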
Then start the server pointing at this repository (e.g. tritonserver --model-repository=<repo-dir>) and run the following client code:
import argparse
import time
from threading import Event

import numpy as np
import tritonclient.grpc as triton


class Callback:
    def __init__(self, stop_event):
        self.stop_event = stop_event
        self.start_time = time.time()
        self.total_requests = 0

    def __call__(self, result, error=None):
        if error is not None:
            print("Error {} got raised after {} s and {} requests".format(
                str(error), time.time() - self.start_time, self.total_requests
            ))
            self.stop_event.set()
            # note: this raise happens on the stream's callback thread;
            # the main loop is actually stopped via the event above
            raise error
        self.total_requests += 1
        if self.total_requests % 100000 == 0:
            print("Completed {} requests after {} s".format(
                self.total_requests, time.time() - self.start_time
            ))


def main(
    url: str,
    model_name: str,
    model_version: int = 1,
    request_rate: float = 1000.
):
    client = triton.InferenceServerClient(url)
    model_metadata = client.get_model_metadata(model_name)
    input = triton.InferInput(
        name=model_metadata.inputs[0].name,
        shape=model_metadata.inputs[0].shape,
        datatype=model_metadata.inputs[0].datatype
    )

    stop_event = Event()
    with client:
        client.start_stream(callback=Callback(stop_event))
        last_request_time = time.time()
        sequence_start = True
        while not stop_event.is_set():
            x = np.random.randn(*input.shape()).astype("float32")
            input.set_data_from_numpy(x)

            # do some throttling to avoid overloading the server
            while (time.time() - last_request_time) < 1 / request_rate - 5e-4:
                time.sleep(1e-6)
            last_request_time = time.time()

            # make the request, setting the START flag only on the first
            # request of the (single, very long) sequence
            client.async_stream_infer(
                model_name,
                model_version=str(model_version),
                sequence_id=1001,
                inputs=[input],
                sequence_start=sequence_start
            )
            sequence_start = False


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--model-version", type=int, default=1)
    parser.add_argument("--request-rate", type=float, default=1000)
    flags = parser.parse_args()
    main(**vars(flags))
Expected behavior
Ideally the connection would never break, but at minimum I'd like advice on how to catch this failure and reconnect quickly enough not to interrupt service. The difficulty is that any attempt to exit the current client context blocks until all outstanding requests have completed, which can take a substantial amount of time. An attempted reconnect can also hit other problems, e.g. a Too many pings complaint from the server.
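To make the ask concrete, the kind of workaround I have in mind is a sketch like the following (stream_with_reconnect and send_requests are hypothetical names, not Triton APIs; it assumes errors surface through the stream callback as in the repro above):

import time
from threading import Event

import tritonclient.grpc as triton


def stream_with_reconnect(url, send_requests, max_retries=10):
    # hypothetical wrapper: restart the stream whenever the callback
    # reports an error, rather than letting the whole client die with it
    for attempt in range(max_retries):
        stop_event = Event()
        errors = []

        def callback(result, error=None):
            if error is not None:
                errors.append(error)
                stop_event.set()

        # use a fresh client per attempt rather than reusing the broken
        # channel; backing off below is also meant to avoid tripping the
        # server's "Too many pings" complaint on reconnect
        client = triton.InferenceServerClient(url)
        try:
            client.start_stream(callback=callback)
            # send_requests should loop like the client above and must
            # pass sequence_start=True on its first request after each
            # reconnect, since the server will have dropped the sequence
            send_requests(client, stop_event)
        finally:
            # this is the painful part: stop_stream() waits for all
            # outstanding requests before the stream can be torn down
            client.stop_stream()
            client.close()
        if not errors:
            break  # clean exit rather than a broken connection
        time.sleep(min(2 ** attempt, 30))  # back off before reconnecting

Even with something like this, the teardown in the finally block still has to drain outstanding requests, which is exactly the wait I'd like to avoid.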
Top GitHub Comments
FWIW I've been unable to reproduce this error again, and have been able to run sequences for >5M requests over ~1 hour several times without issue. It's possible this arises more frequently when multiple clients connect to a single server, but I'll keep this closed unless and until I encounter that. Thanks for your help!
Closing issue for now. Please re-open when you have the additional information.