Using HinSAGE with a different generator from the one used to construct the model gives non-obvious reshaping errors
Describe the bug
When training HinSAGE with DeepGraphInfomax, it's possible to hit an error like InvalidArgumentError: Input to reshape is a tensor with 1000 values, but the requested shape has 2250. This happens when the generator passed to the HinSAGE constructor is not the same one that the CorruptedGenerator and its flows are built from, so the sample sizes the model expects don't match the batches it is fed.
Approximate reproducer:
import stellargraph as sg
import tensorflow as tf

# Generator used to build the DeepGraphInfomax flows:
single_generator = sg.mapper.HinSAGENodeGenerator(hetero_single_graph, batch_size=100, num_samples=[5, 5], head_node_type="atom")
corrupted_generator = sg.mapper.CorruptedGenerator(single_generator)
train_corr_gen = corrupted_generator.flow(train_single_nodes)
test_corr_gen = corrupted_generator.flow(test_single_nodes)

# `generator` here is a *different* HinSAGENodeGenerator to `single_generator`
# (e.g. built with other num_samples); this mismatch is what triggers the error.
base_model = sg.layer.HinSAGE([64, 64], activations=["relu", "relu"], generator=generator)
dgi_model = sg.layer.DeepGraphInfomax(base_model, corrupted_generator)

x_in, x_out = dgi_model.in_out_tensors()
model = tf.keras.Model(inputs=x_in, outputs=x_out)
model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, optimizer="Adam")

small_gen = corrupted_generator.flow(train_single_nodes[:10])
history = model.fit(small_gen, epochs=30)
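Not in the original report, but for reference, a minimal sketch of the presumably intended version, matching the title's diagnosis: build HinSAGE from the same `single_generator` that the CorruptedGenerator and its flows are based on.

# Presumed fix: the HinSAGE layer and the flows share one generator.
base_model = sg.layer.HinSAGE([64, 64], activations=["relu", "relu"], generator=single_generator)
dgi_model = sg.layer.DeepGraphInfomax(base_model, corrupted_generator)
x_in, x_out = dgi_model.in_out_tensors()
model = tf.keras.Model(inputs=x_in, outputs=x_out)
model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, optimizer="Adam")
history = model.fit(corrupted_generator.flow(train_single_nodes[:10]), epochs=30)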
The graph being used (hetero_single_graph):

StellarGraph: Undirected multigraph
 Nodes: 17868, Edges: 52012

 Node types:
  atom: [17868]
    Features: float32 vector, length 4
    Edge types: atom-1JHC->atom, atom-1JHN->atom, atom-2JHC->atom, atom-2JHH->atom, atom-2JHN->atom, ... (3 more)

 Edge types:
    atom-3JHC->atom: [16642]
        Weights: all 1 (default)
    atom-2JHC->atom: [12584]
        Weights: all 1 (default)
    atom-1JHC->atom: [7966]
        Weights: all 1 (default)
    atom-3JHH->atom: [6543]
        Weights: all 1 (default)
    atom-2JHH->atom: [4340]
        Weights: all 1 (default)
    atom-3JHN->atom: [1989]
        Weights: all 1 (default)
    atom-2JHN->atom: [1423]
        Weights: all 1 (default)
    atom-1JHN->atom: [525]
        Weights: all 1 (default)
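Not part of the issue, but for readers unfamiliar with this layout: a small sketch, assuming the graph above was loaded from pandas DataFrames, of how a heterogeneous StellarGraph with one node type and several edge types can be constructed. The toy data and names (toy_graph, edges_1jhc, edges_2jhh) are made up for illustration.

import numpy as np
import pandas as pd
import stellargraph as sg

# Hypothetical toy data mirroring the real graph's shape: "atom" nodes with a
# 4-float feature vector, and one edge DataFrame per coupling type.
atoms = pd.DataFrame(
    np.random.rand(5, 4).astype("float32"),
    index=[f"a{i}" for i in range(5)],
)
edges_1jhc = pd.DataFrame({"source": ["a0", "a1"], "target": ["a1", "a2"]})
edges_2jhh = pd.DataFrame({"source": ["a2", "a3"], "target": ["a3", "a4"]})

toy_graph = sg.StellarGraph(
    nodes={"atom": atoms},
    edges={"1JHC": edges_1jhc, "2JHH": edges_2jhh},  # dict keys become the edge types
)
print(toy_graph.info())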
Full stack trace
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-166-d83cc857889d> in <module>
----> 1 history = model.fit(small_gen, epochs=30)
2 sg.utils.plot_history(history)
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
817 max_queue_size=max_queue_size,
818 workers=workers,
--> 819 use_multiprocessing=use_multiprocessing)
820
821 def evaluate(self,
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
340 mode=ModeKeys.TRAIN,
341 training_context=training_context,
--> 342 total_epochs=epochs)
343 cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
344
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
126 step=step, mode=mode, size=current_batch_size) as batch_logs:
127 try:
--> 128 batch_outs = execution_function(iterator)
129 except (StopIteration, errors.OutOfRangeError):
130 # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
96 # `numpy` translates Tensors to values in Eager mode.
97 return nest.map_structure(_non_none_constant_value,
---> 98 distributed_function(input_fn))
99
100 return execution_function
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
566 xla_context.Exit()
567 else:
--> 568 result = self._call(*args, **kwds)
569
570 if tracing_count == self._get_tracing_count():
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
630 # Lifting succeeded, so variables are initialized and we can run the
631 # stateless function.
--> 632 return self._stateless_fn(*args, **kwds)
633 else:
634 canon_args, canon_kwds = \
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
2361 with self._lock:
2362 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2363 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2364
2365 @property
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
1609 if isinstance(t, (ops.Tensor,
1610 resource_variable_ops.BaseResourceVariable))),
-> 1611 self.captured_inputs)
1612
1613 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1690 # No tape is watching; skip to running the function.
1691 return self._build_call_outputs(self._inference_function.call(
-> 1692 ctx, args, cancellation_manager=cancellation_manager))
1693 forward_backward = self._select_forward_and_backward_functions(
1694 args,
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
543 inputs=args,
544 attrs=("executor_type", executor_type, "config_proto", config),
--> 545 ctx=ctx)
546 else:
547 outputs = execute.execute_with_cancellation(
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
~/.pyenv/versions/sg-1.0.0rc1/lib/python3.6/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Input to reshape is a tensor with 1000 values, but the requested shape has 2250
[[node model_4/reshape_191/Reshape (defined at <ipython-input-166-d83cc857889d>:1) ]] [Op:__inference_distributed_function_54767]
Function call stack:
distributed_function
(StellarGraph 1.0.0rc1 dogfooding.)
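As a rough debugging aid (not from the issue; it assumes `x_in` and `train_corr_gen` from the reproducer above), comparing the model's expected input shapes against one batch from the flow makes the mismatch visible before calling fit: the Reshape named in the trace is sized from the generator passed to HinSAGE, so batches from a generator with different num_samples carry the wrong number of elements.

# Hypothetical sanity check: `train_corr_gen` is a Keras Sequence, so indexing it
# yields one (inputs, targets) batch whose shapes can be compared against the
# model's input placeholders.
batch_inputs, _ = train_corr_gen[0]
for placeholder, array in zip(x_in, batch_inputs):
    print(placeholder.shape, array.shape)  # mismatched hop sizes show up here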
Top GitHub Comments
- Thanks, I replied there 👍
- Yep, I got the same results with your above graph