IGA implementation
See original GitHub issue.

Thank you for the great work! Your code is highly functional and I hope it will become the go-to open-source software for domain generalization with fair comparisons.
My question concerns the current implementation of the IGA algorithm, in https://github.com/facebookresearch/DomainBed/blob/54c2f8c614a96067dee2d961f0b34575753c9df0/domainbed/algorithms.py#L986
```python
for i, (x, y) in enumerate(minibatches):
    # [....]
    grads.append(autograd.grad(env_loss, self.network.parameters(),
                               retain_graph=True))
mean_loss = total_loss / len(minibatches)
mean_grad = autograd.grad(mean_loss, self.network.parameters(),
                          retain_graph=True)
# [....]
penalty_value += (g - mean_g).pow(2).sum()
# [....]
(mean_loss + self.hparams['penalty'] * penalty_value).backward()
```
I believe there is a small error: `create_graph=True` is missing in the calls to `autograd.grad`.
In detail, with the current implementation the `penalty_value` has no effect on training. As proof, if we replace the final `backward` call with

```python
(self.hparams['penalty'] * penalty_value).backward()
```

we obtain the error:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
Another proof is that changing the hyper-parameter `self.hparams['penalty']` has no impact on the results.
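To illustrate the underlying PyTorch behavior, here is a minimal standalone snippet (independent of DomainBed; the scalar tensor is made up for the example). Without `create_graph=True`, the gradients returned by `autograd.grad` are detached, so any penalty built from them cannot be backpropagated:

```python
import torch
from torch import autograd

# A toy parameter and loss, just to probe autograd.grad's behavior.
w = torch.tensor(2.0, requires_grad=True)
loss = w ** 2

# Default (create_graph=False): the returned gradient is detached.
(g,) = autograd.grad(loss, [w], retain_graph=True)
print(g.requires_grad)  # False -> a penalty on g has no grad_fn

# With create_graph=True: the gradient stays in the autograd graph.
(g2,) = autograd.grad(loss, [w], create_graph=True)
print(g2.requires_grad)  # True -> a penalty on g2 is differentiable
```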
The fix is simple: add `create_graph=True` to the `autograd.grad` calls. This explicitly tells autograd to build a backpropagable graph for those gradients, so that further operations on them can be differentiated:
```python
for i, (x, y) in enumerate(minibatches):
    # [....]
    grads.append(autograd.grad(env_loss, self.network.parameters(),
                               retain_graph=True, create_graph=True))
mean_loss = total_loss / len(minibatches)
mean_grad = autograd.grad(mean_loss, self.network.parameters(),
                          retain_graph=True, create_graph=True)
```
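Putting the pieces together, the corrected update can be sketched end-to-end on a toy setup. This is a hedged, self-contained illustration, not DomainBed's code: `model`, `envs`, and `penalty_weight` are invented names, and the linear model with two synthetic environments stands in for the real network and minibatches:

```python
import torch
import torch.nn.functional as F
from torch import autograd

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)          # stand-in for self.network
envs = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(2)]
penalty_weight = 1.0                   # stand-in for self.hparams['penalty']

grads, total_loss = [], 0.0
for x, y in envs:
    env_loss = F.mse_loss(model(x), y)
    total_loss = total_loss + env_loss
    # create_graph=True keeps per-environment gradients differentiable
    grads.append(autograd.grad(env_loss, model.parameters(),
                               retain_graph=True, create_graph=True))

mean_loss = total_loss / len(envs)
mean_grad = autograd.grad(mean_loss, model.parameters(),
                          retain_graph=True, create_graph=True)

# IGA penalty: squared distance of each env gradient to the mean gradient.
penalty_value = 0.0
for grad in grads:
    for g, mean_g in zip(grad, mean_grad):
        penalty_value = penalty_value + (g - mean_g).pow(2).sum()

objective = mean_loss + penalty_weight * penalty_value
objective.backward()  # now also flows through the penalty term
print(penalty_value.requires_grad)  # True
```

With `create_graph=True`, `penalty_value.requires_grad` is `True`, so the penalty actually contributes to the parameter updates; without it, the earlier `RuntimeError` shows the penalty is disconnected from the graph.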
Overall, I am still (currently) unable to reproduce the results from the paper, but I believe this is a step forward.
Issue Analytics
- Created 2 years ago
- Comments: 5 (5 by maintainers)
Left a quick review, almost ready to merge @jc-audet .
Hello @alexrame,
Wow! Good catch! I was also having issues reproducing the results from the paper, and the lack of open code from the paper made it hard to check my implementation. When I was coding it, I somehow managed to get ~60% test accuracy on CMNIST with the right timing of penalty activation and number of training steps, so I figured it was alright.
We should add this fix ASAP. I will create a PR with `create_graph`, along with the penalty annealing I have in my personal fork for IGA. Sorry for the wait!