Transferring Boxes to torch device changes their dtype to float32.
I am trying to support mixed-precision training in Detectron2 for my use case with NVIDIA Apex, and I am unable to use it as a drop-in replacement, as the library suggests, because of some internal wiring. I think this should be fixed regardless, because the current behavior is non-intuitive.
Instructions To Reproduce the Issue:
>>> import torch
>>> from detectron2.structures import Boxes
>>> f16_boxes = Boxes(torch.tensor([[1.0, 2.0, 3.0, 4.0]]).to(torch.float16))
>>> f16_boxes = f16_boxes.to(torch.device("cuda:0"))
>>> f16_boxes.tensor.dtype  # Expected: torch.float16
torch.float32
Location of Issue:
The problem lies in this method: https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/boxes.py#L154
This definition calls the constructor:
def to(self, device: str) -> "Boxes":
    return Boxes(self.tensor.to(device))
and the constructor forcibly casts everything to torch.float32:
def __init__(self, tensor: torch.Tensor):
    device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
    tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)  # This line!
    ...
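A minimal standalone check of this behavior (plain torch.as_tensor with an explicit dtype, not Detectron2 code):

import torch

t = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float16)
# torch.as_tensor casts to the requested dtype, so an fp16 input
# silently becomes fp32 inside Boxes.__init__.
print(torch.as_tensor(t, dtype=torch.float32).dtype)  # torch.float32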
Proposed Solution:
The method body of to should look like:
def to(self, device: str) -> "Boxes":
    boxes_new = self.clone()
    boxes_new.tensor = self.tensor.to(device)
    return boxes_new
This method body keeps the current behavior in pure FP32 settings, makes mixed-precision training easier, and reads more intuitively.
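A minimal sketch of the expected behavior, using a stripped-down stand-in class (hypothetical, not the real detectron2 Boxes) whose constructor does not force float32 and whose to() uses the proposed body:

import torch

class _BoxesSketch:
    # Hypothetical stand-in for detectron2.structures.Boxes, without the forced fp32 cast.
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def clone(self) -> "_BoxesSketch":
        return _BoxesSketch(self.tensor.clone())

    def to(self, device: str) -> "_BoxesSketch":
        # Proposed method body: move the tensor, keep its dtype.
        boxes_new = self.clone()
        boxes_new.tensor = self.tensor.to(device)
        return boxes_new

f16 = _BoxesSketch(torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float16))
print(f16.to("cpu").tensor.dtype)  # torch.float16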
[I trimmed down other parts from issue template because they aren’t needed here.]
Top GitHub Comments
It seems the original issue is no longer valid? Feel free to reopen if that’s not the case.
I don’t think it’s necessary to use float16 for boxes: I can’t think of any place where it would improve speed, and it would cause issues for box-related operations that might not support fp16.
It might be better to just keep using float32 and apply appropriate casting when needed (e.g. from predicted fp16 deltas to fp32 deltas).
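A sketch of that casting pattern (illustrative names, not Detectron2’s actual code): keep the box coordinates in float32 and upcast fp16 network outputs before doing any box arithmetic.

import torch

# Hypothetical fp16 deltas produced by a model head under mixed precision.
pred_deltas_fp16 = torch.randn(8, 4, dtype=torch.float16)
boxes_fp32 = torch.rand(8, 4)  # box coordinates stay in float32

# Cast the deltas up before combining them with fp32 boxes,
# so box-related ops never have to support fp16.
refined = boxes_fp32 + pred_deltas_fp16.float()
print(refined.dtype)  # torch.float32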
(Unrelated to your issue, but Apex’s claim of being a “drop-in replacement” is just not true for any sufficiently complicated model. Hacks might be required to make it work.)