
not receiving frames when running long image processing

See original GitHub issue

I’ve modified the edge detection sample code by replacing the edge detection step with a deep learning network that performs semantic segmentation (DeepLabv3 from Google).

The model can process around 10 frames per second, but most of the time the client doesn’t display the processed frames at all. I wonder whether this is because my code takes too long to process each incoming frame?

I’m pasting the code below. Run it with python server.py, then open your browser at http://0.0.0.0:8080.

import argparse
import asyncio
import json
import logging
import os
import tarfile

import cv2
import numpy as np
import tensorflow as tf
from aiohttp import web
from av import VideoFrame
from PIL import Image
from six.moves import urllib

from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder

ROOT = os.path.dirname(__file__)

# GPU settings (from Yulius): pin TensorFlow to GPU 0 and let it grow its
# memory allocation on demand.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1, allow_growth=True)
tf_config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)

class DeepLabModel(object):
  """Class to load deeplab model and run inference."""

  # Tensor names inside the frozen graph, the model's expected input size,
  # and the name of the graph file to extract from the tarball.
  INPUT_TENSOR_NAME = 'ImageTensor:0'
  OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
  INPUT_SIZE = 513
  FROZEN_GRAPH_NAME = 'frozen_inference_graph'

  def __init__(self, tarball_path):
    """Creates and loads pretrained deeplab model."""
    self.graph = tf.Graph()

    graph_def = None
    # Extract frozen graph from tar archive.
    tar_file = tarfile.open(tarball_path)
    for tar_info in tar_file.getmembers():
      if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
        file_handle = tar_file.extractfile(tar_info)
        graph_def = tf.GraphDef.FromString(file_handle.read())
        break

    tar_file.close()

    if graph_def is None:
      raise RuntimeError('Cannot find inference graph in tar archive.')

    with self.graph.as_default():
      tf.import_graph_def(graph_def, name='')

    self.sess = tf.Session(graph=self.graph, config=tf_config)

  def run(self, image):
    """Runs inference on a single image.
    Args:
      image: A PIL.Image object, raw input image.
    Returns:
      resized_image: RGB image resized from original input image.
      seg_map: Segmentation map of `resized_image`.
    """
    width, height = image.size
    resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
    target_size = (int(resize_ratio * width), int(resize_ratio * height))
    resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
    batch_seg_map = self.sess.run(
        self.OUTPUT_TENSOR_NAME,
        feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
    seg_map = batch_seg_map[0]
    return seg_map


def create_pascal_label_colormap():
  """Creates a label colormap used in PASCAL VOC segmentation benchmark.
  Returns:
    A Colormap for visualizing segmentation results.
  """
  colormap = np.zeros((256, 3), dtype=int)
  ind = np.arange(256, dtype=int)

  # Distribute the bits of each label index across the RGB channels,
  # three bits per pass (the standard PASCAL VOC colormap construction).
  for shift in reversed(range(8)):
    for channel in range(3):
      colormap[:, channel] |= ((ind >> channel) & 1) << shift
    ind >>= 3

  # Override the generated colormap: everything black except the
  # 'person' class (index 15), which becomes white.
  for i in range(len(colormap)):
    colormap[i] = [0, 0, 0]
  colormap[15] = [255, 255, 255]
  return colormap

colormap = create_pascal_label_colormap()

print('colormap.ndim =', colormap.ndim, ', shape =', colormap.shape)

def label_to_color_image(label):
  """Adds color defined by the dataset colormap to the label.
  Args:
    label: A 2D array with integer type, storing the segmentation label.
  Returns:
    result: A 2D array with floating type. The element of the array
      is the color indexed by the corresponding element in the input label
      to the PASCAL color map.
  Raises:
    ValueError: If label is not of rank 2 or its value is larger than color
      map maximum entry.
  """
  if label.ndim != 2:
    raise ValueError('Expect 2-D input label')

  if np.max(label) >= len(colormap):
    raise ValueError('label value too large.')

  return colormap[label]


# options: 'mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval',
#          'xception_coco_voctrainaug', 'xception_coco_voctrainval'
MODEL_NAME = 'mobilenetv2_coco_voctrainaug'

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {
    'mobilenetv2_coco_voctrainaug':
        'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    'mobilenetv2_coco_voctrainval':
        'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    'xception_coco_voctrainaug':
        'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval':
        'deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}
_TARBALL_NAME = 'deeplab_model.tar.gz'


model_dir = '../models'
tf.gfile.MakeDirs(model_dir)

download_path = os.path.join(model_dir, _TARBALL_NAME)
if os.path.isfile(download_path):
  print('model already on disk:', download_path)
else:
  print('downloading model, this might take a while...', download_path)
  urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
  print('download completed! loading DeepLab model...')

model = DeepLabModel(download_path)
print('model loaded successfully!')

def apply_mask(image, mask):
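    # Keep foreground pixels (mask != 0) and paint background pixels
    # green (BGR 0, 255, 0), one channel at a time.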
    image[:, :, 0] = np.where(
        mask == 0,
        0,
        image[:, :, 0]
    )
    image[:, :, 1] = np.where(
        mask == 0,
        255,
        image[:, :, 1]
    )
    image[:, :, 2] = np.where(
        mask == 0,
        0,
        image[:, :, 2]
    )
    return image


class VideoTransformTrack(VideoStreamTrack):
    def __init__(self, track, transform):
        super().__init__()  # don't forget this!
        self.counter = 0
        self.track = track
        self.transform = transform

    async def recv(self):
        frame = await self.track.recv()
        self.counter += 1

        if self.transform == 'edges':
            # run semantic segmentation (the transform keeps its original
            # name 'edges' from the edge detection demo it replaced)
            img = frame.to_ndarray(format='bgr24')
            pil_im = Image.fromarray(img)
            shape = img.shape

            # segment at the model's working resolution, apply the mask,
            # then resize back to the original frame size
            seg_map = model.run(pil_im)
            width = seg_map.shape[1]
            height = seg_map.shape[0]
            img = cv2.resize(img, (width, height))
            apply_mask(img, seg_map)
            img = cv2.resize(img, (shape[1], shape[0]))

            # rebuild a VideoFrame, preserving timing information
            new_frame = VideoFrame.from_ndarray(img, format='bgr24')
            new_frame.pts = frame.pts
            new_frame.time_base = frame.time_base
            return new_frame
        elif self.transform == 'rotate':
            # rotate image
            img = frame.to_ndarray(format='bgr24')
            rows, cols, _ = img.shape
            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
            img = cv2.warpAffine(img, M, (cols, rows))

            # rebuild a VideoFrame, preserving timing information
            new_frame = VideoFrame.from_ndarray(img, format='bgr24')
            new_frame.pts = frame.pts
            new_frame.time_base = frame.time_base
            return new_frame
        else:
            return frame


async def index(request):
    content = open(os.path.join(ROOT, 'index.html'), 'r').read()
    return web.Response(content_type='text/html', text=content)


async def javascript(request):
    content = open(os.path.join(ROOT, 'client.js'), 'r').read()
    return web.Response(content_type='application/javascript', text=content)


async def offer(request):
    params = await request.json()
    offer = RTCSessionDescription(
        sdp=params['sdp'],
        type=params['type'])

    pc = RTCPeerConnection()
    pcs.add(pc)

    # prepare local media
    player = MediaPlayer(os.path.join(ROOT, 'demo-instruct.wav'))
    if args.write_audio:
        recorder = MediaRecorder(args.write_audio)
    else:
        recorder = MediaBlackhole()

    @pc.on('datachannel')
    def on_datachannel(channel):
        @channel.on('message')
        def on_message(message):
            channel.send('pong')

    @pc.on('iceconnectionstatechange')
    async def on_iceconnectionstatechange():
        print('ICE connection state is %s' % pc.iceConnectionState)
        if pc.iceConnectionState == 'failed':
            await pc.close()
            pcs.discard(pc)

    @pc.on('track')
    def on_track(track):
        print('Track %s received' % track.kind)

        if track.kind == 'audio':
            pc.addTrack(player.audio)
            recorder.addTrack(track)
        elif track.kind == 'video':
            local_video = VideoTransformTrack(track, transform=params['video_transform'])
            pc.addTrack(local_video)

        @track.on('ended')
        async def on_ended():
            print('Track %s ended' % track.kind)
            await recorder.stop()

    # handle offer
    await pc.setRemoteDescription(offer)
    await recorder.start()

    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    return web.Response(
        content_type='application/json',
        text=json.dumps({
            'sdp': pc.localDescription.sdp,
            'type': pc.localDescription.type
        }))


pcs = set()


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='WebRTC audio / video / data-channels demo')
    parser.add_argument('--port', type=int, default=8080,
                        help='Port for HTTP server (default: 8080)')
    parser.add_argument('--verbose', '-v', action='count')
    parser.add_argument('--write-audio', help='Write received audio to a file')
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.router.add_get('/', index)
    app.router.add_get('/client.js', javascript)
    app.router.add_post('/offer', offer)
    web.run_app(app, port=args.port)

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Comments: 5

Top GitHub Comments

4 reactions
fake-name commented, Mar 29, 2019

This is an old issue, but I’m doing something similar to OP (expensive processing on the video stream, which requires dropping frames to avoid running out of memory).

The way I figured out to do this is to look at the queue in the incoming track:


	async def recv(self):
		frame = await self.track.recv()

		# Consume all available frames.
		# If we don't do this, we'll bloat indefinitely.
		while not self.track._queue.empty():
			frame = await self.track.recv()

This basically just drops every frame but the most recent.

Even with an intentionally short lifetime on the datachannel (set to unordered, 100 ms lifetime), I still had queue bloat issues. The above hack resolved them.
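For reference, here’s the same hack wrapped into a self-contained relay track (a sketch, not from the original thread; note that _queue is a private aiortc attribute and may change between releases):

from aiortc import VideoStreamTrack


class LatestFrameTrack(VideoStreamTrack):
    """Relay another track, returning only its newest frame."""

    def __init__(self, track):
        super().__init__()
        self.track = track

    async def recv(self):
        # Block until at least one frame arrives...
        frame = await self.track.recv()
        # ...then drain anything that queued up in the meantime, keeping
        # only the most recent frame.
        while not self.track._queue.empty():
            frame = await self.track.recv()
        return frame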

One thing that’d be nice to consider in the future would be a lossy transport of some sort. I dunno if that’d be something you’d want to integrate into the library.

1 reaction
jlaine commented, Nov 9, 2018

Hi Laurent,

This is a pretty cool example! Once we get it tidied up, it might make a nice addition to aiortc’s examples.

I think the first thing you want to address is moving the processing off the main thread, as it is very expensive and blocks the main asyncio loop. Two symptoms of this:

  • if you enable the datachannel you’ll notice you’re not getting a “pong” for each “ping”.
  • the audio is breaking up

The easiest way to do this is to move your background removal code to a remove_background function that takes a VideoFrame and returns a VideoFrame. You can then do this in VideoStreamTrack.recv:

        if self.transform == 'edges':
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, remove_background, frame)

This however won’t fix a more fundamental issue: the algorithm is so heavy that you won’t ever be able to match the incoming frame rate (when I run it on my laptop I get closer to 1 fps!). This means you’ll need to drop frames and not submit them all to the processing algorithm. A super naive hack which decimates the frames by a factor of 10:

    async def recv(self):
        for i in range(10):
            frame = await self.track.recv()

Finally, I noticed there is a lot of conversion and resizing going on in the code which can be optimized. You don’t need to use PIL at all.

Modify DeepLabModel.run like this:

    def run(self, image):
        """Runs inference on a single image."""
        width = image.shape[1]
        height = image.shape[0]
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        resized_image = cv2.resize(image, target_size)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [resized_image]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map

And your remove_background function looks like this:

def remove_background(frame):
    # remove background
    img = frame.to_ndarray(format='bgr24')
    original_size = (img.shape[1], img.shape[0])
    img, seg_map = model.run(img)
    apply_mask(img, seg_map)
    img = cv2.resize(img, original_size)

    # rebuild a VideoFrame, preserving timing information
    new_frame = VideoFrame.from_ndarray(img, format='bgr24')
    new_frame.pts = frame.pts
    new_frame.time_base = frame.time_base
    return new_frame
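
Putting the pieces together, the transform track’s recv could look roughly like this (a sketch, not from the thread: it assumes the remove_background helper above and reuses the naive factor-of-10 decimation; asyncio is already imported by the server script):

    async def recv(self):
        # naive decimation: read 10 frames and keep only the last one, so
        # the model sees roughly a tenth of the incoming frame rate
        for _ in range(10):
            frame = await self.track.recv()

        # run the heavy segmentation in a worker thread so the asyncio
        # loop stays free for audio and datachannel traffic
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, remove_background, frame)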
