Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" | |
| import torch | |
| # ---------------------------------------------------------------------------- | |
def fma(a, b, c):  # => a * b + c
    """Return ``a * b + c`` via the custom ``_FusedMultiplyAdd`` autograd op.

    The forward pass of that op is ``torch.addcmul(c, a, b)``; its backward
    pass only computes the gradients autograd requests.
    """
    fused_op = _FusedMultiplyAdd.apply
    result = fused_op(a, b, c)
    return result
| # ---------------------------------------------------------------------------- | |
| class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c | |
| def forward(ctx, a, b, c): # pylint: disable=arguments-differ | |
| out = torch.addcmul(c, a, b) | |
| ctx.save_for_backward(a, b) | |
| ctx.c_shape = c.shape | |
| return out | |
| def backward(ctx, dout): # pylint: disable=arguments-differ | |
| a, b = ctx.saved_tensors | |
| c_shape = ctx.c_shape | |
| da = None | |
| db = None | |
| dc = None | |
| if ctx.needs_input_grad[0]: | |
| da = _unbroadcast(dout * b, a.shape) | |
| if ctx.needs_input_grad[1]: | |
| db = _unbroadcast(dout * a, b.shape) | |
| if ctx.needs_input_grad[2]: | |
| dc = _unbroadcast(dout, c_shape) | |
| return da, db, dc | |
| # ---------------------------------------------------------------------------- | |
| def _unbroadcast(x, shape): | |
| extra_dims = x.ndim - len(shape) | |
| assert extra_dims >= 0 | |
| dim = [i for i in range(x.ndim) if x.shape[i] > 1 and ( | |
| i < extra_dims or shape[i - extra_dims] == 1)] | |
| if len(dim): | |
| x = x.sum(dim=dim, keepdim=True) | |
| if extra_dims: | |
| x = x.reshape(-1, *x.shape[extra_dims+1:]) | |
| assert x.shape == shape | |
| return x | |
| # ---------------------------------------------------------------------------- | |