Using HinSAGE with a different generator from the one used to construct the model gives non-obvious reshaping errors


Describe the bug

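When training HinSAGE with DeepGraphInfomax, feeding the model data from a generator other than the one the HinSAGE model was constructed with can fail with InvalidArgumentError: Input to reshape is a tensor with 1000 values, but the requested shape has 2250. The message gives no hint that the two generators are the cause.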

Approximate reproducer:

import stellargraph as sg
import tensorflow as tf

# hetero_single_graph, train_single_nodes and test_single_nodes are defined elsewhere (not shown)
single_generator = sg.mapper.HinSAGENodeGenerator(hetero_single_graph, batch_size=100, num_samples=[5, 5], head_node_type="atom")
corrupted_generator = sg.mapper.CorruptedGenerator(single_generator)
train_corr_gen = corrupted_generator.flow(train_single_nodes)
test_corr_gen = corrupted_generator.flow(test_single_nodes)

# NB: `generator` is a different generator from `single_generator` (not shown);
# building the model from it while training on `corrupted_generator` triggers the error
base_model = sg.layer.HinSAGE([64, 64], activations=["relu", "relu"], generator=generator)
dgi_model = sg.layer.DeepGraphInfomax(base_model, corrupted_generator)

x_in, x_out = dgi_model.in_out_tensors()
model = tf.keras.Model(inputs=x_in, outputs=x_out)
model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, optimizer="Adam")

# train on a small subset of the nodes
small_gen = corrupted_generator.flow(train_single_nodes[:10])
history = model.fit(small_gen, epochs=30)

Summary of the graph:

StellarGraph: Undirected multigraph
 Nodes: 17868, Edges: 52012

 Node types:
  atom: [17868]
    Features: float32 vector, length 4
    Edge types: atom-1JHC->atom, atom-1JHN->atom, atom-2JHC->atom, atom-2JHH->atom, atom-2JHN->atom, ... (3 more)

 Edge types:
    atom-3JHC->atom: [16642]
        Weights: all 1 (default)
    atom-2JHC->atom: [12584]
        Weights: all 1 (default)
    atom-1JHC->atom: [7966]
        Weights: all 1 (default)
    atom-3JHH->atom: [6543]
        Weights: all 1 (default)
    atom-2JHH->atom: [4340]
        Weights: all 1 (default)
    atom-3JHN->atom: [1989]
        Weights: all 1 (default)
    atom-2JHN->atom: [1423]
        Weights: all 1 (default)
    atom-1JHN->atom: [525]
        Weights: all 1 (default)
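One way to make this failure obvious would be to compare, before calling fit, the properties of each generator that fix the model's input shapes. The sketch below assumes those properties are the head node type, num_samples and the per-type feature sizes; SamplingConfig and check_same_sampling are hypothetical helpers, not part of StellarGraph, and the values are filled in by hand from the reproducer (the (10, 10) sampling standing in for the unshown `generator` is made up for illustration).

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class SamplingConfig:
    """Hypothetical summary of what fixes a HinSAGE model's tensor shapes.

    Not a StellarGraph class: the fields are filled in by hand for illustration.
    """
    head_node_type: str
    num_samples: Tuple[int, ...]
    feature_sizes: Dict[str, int]


def check_same_sampling(model_cfg: SamplingConfig, data_cfg: SamplingConfig) -> None:
    """Raise a descriptive error instead of letting a reshape fail deep inside fit()."""
    problems = []
    for field in ("head_node_type", "num_samples", "feature_sizes"):
        built, fed = getattr(model_cfg, field), getattr(data_cfg, field)
        if built != fed:
            problems.append(f"{field}: model built with {built!r}, data generated with {fed!r}")
    if problems:
        raise ValueError("generator mismatch; " + "; ".join(problems))


# Data generator as in the reproducer: head node type "atom", [5, 5] samples, 4 features.
data_cfg = SamplingConfig(head_node_type="atom", num_samples=(5, 5), feature_sizes={"atom": 4})
# Hypothetical configuration for the other `generator` the model was built from.
model_cfg = SamplingConfig(head_node_type="atom", num_samples=(10, 10), feature_sizes={"atom": 4})

check_same_sampling(data_cfg, data_cfg)  # matching configurations: no error
try:
    check_same_sampling(model_cfg, data_cfg)
except ValueError as err:
    print(err)  # generator mismatch; num_samples: model built with (10, 10), data generated with (5, 5)
```

The check names the mismatched setting directly, which is the information the TensorFlow reshape error leaves out.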
Full stack trace
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-166-d83cc857889d> in <module>
----> 1 history = model.fit(small_gen, epochs=30)
      2 sg.utils.plot_history(history)

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    340                 mode=ModeKeys.TRAIN,
    341                 training_context=training_context,
--> 342                 total_epochs=epochs)
    343             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    344 

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    126         step=step, mode=mode, size=current_batch_size) as batch_logs:
    127       try:
--> 128         batch_outs = execution_function(iterator)
    129       except (StopIteration, errors.OutOfRangeError):
    130         # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
     96     # `numpy` translates Tensors to values in Eager mode.
     97     return nest.map_structure(_non_none_constant_value,
---> 98                               distributed_function(input_fn))
     99 
    100   return execution_function

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    630         # Lifting succeeded, so variables are initialized and we can run the
    631         # stateless function.
--> 632         return self._stateless_fn(*args, **kwds)
    633     else:
    634       canon_args, canon_kwds = \

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2361     with self._lock:
   2362       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2363     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2364 
   2365   @property

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
   1609          if isinstance(t, (ops.Tensor,
   1610                            resource_variable_ops.BaseResourceVariable))),
-> 1611         self.captured_inputs)
   1612 
   1613   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1690       # No tape is watching; skip to running the function.
   1691       return self._build_call_outputs(self._inference_function.call(
-> 1692           ctx, args, cancellation_manager=cancellation_manager))
   1693     forward_backward = self._select_forward_and_backward_functions(
   1694         args,

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    543               inputs=args,
    544               attrs=("executor_type", executor_type, "config_proto", config),
--> 545               ctx=ctx)
    546         else:
    547           outputs = execute.execute_with_cancellation(

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     keras_symbolic_tensors = [

~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError:  Input to reshape is a tensor with 1000 values, but the requested shape has 2250
	 [[node model_4/reshape_191/Reshape (defined at <ipython-input-166-d83cc857889d>:1) ]] [Op:__inference_distributed_function_54767]

Function call stack:
distributed_function

(StellarGraph 1.0.0rc1 dogfooding.)
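Why the error is about a reshape: GraphSAGE-style layers reshape the sampled neighbour features of each head node into a block whose size is fixed by the generator the model was built with. The NumPy sketch below is a simplified model of that step under this assumption, not StellarGraph's actual code; neighbour_block_size and reshape_neighbours are hypothetical helpers, and the [10, 10] sampling is a made-up stand-in for the unshown `generator`. The 1000 in the message is consistent with 10 head nodes × 25 two-hop samples × 4 atom features from the reproducer, although the issue does not confirm which tensor failed.

```python
import numpy as np


def neighbour_block_size(num_samples, hop):
    """Sampled neighbours per head node at `hop` (hop 0 is the head node itself)."""
    size = 1
    for n in num_samples[:hop]:
        size *= n
    return size


def reshape_neighbours(flat_features, batch_size, num_samples, hop, feature_dim):
    """Reshape flat features into the (batch, neighbours, features) block that a
    layer built for `num_samples` expects at `hop`; raises ValueError on mismatch."""
    target = (batch_size, neighbour_block_size(num_samples, hop), feature_dim)
    return flat_features.reshape(target)


# Data sampled as in the reproducer: 10 head nodes, num_samples=[5, 5], 4 features.
batch, data_samples, feature_dim = 10, [5, 5], 4
two_hop = np.zeros(batch * neighbour_block_size(data_samples, 2) * feature_dim)
print(two_hop.size)  # 1000

# A layer built with the same sampling accepts the data.
ok = reshape_neighbours(two_hop, batch, data_samples, hop=2, feature_dim=feature_dim)
print(ok.shape)  # (10, 25, 4)

# A layer built for a different (hypothetical) sampling of [10, 10] expects
# 10 * 100 * 4 = 4000 values, so the reshape fails, mirroring the TensorFlow error.
try:
    reshape_neighbours(two_hop, batch, [10, 10], hop=2, feature_dim=feature_dim)
except ValueError as err:
    print("reshape failed:", err)
```

As in the real error, the failure surfaces only at the reshape, far from the line where the mismatched generator was chosen.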

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 10 (9 by maintainers)

Top GitHub Comments

huonw commented, Sep 21, 2020

> UPDATE: I opened my issue as a new one (#1802), given that my issue only has one generator, so it may be different (and may not be a bug).

Thanks, I replied there 👍

kieranricardo commented, Apr 24, 2020

Yep, I got the same results with your graph above.
