class QuickCumsumMean(torch.autograd.Function):
    """Mean-pool runs of consecutive rows of ``x`` that share the same rank.

    Rows are grouped by comparing adjacent entries of ``ranks``: a run ends
    wherever ``ranks[i + 1] != ranks[i]`` (and at the final row). Each run is
    reduced to the mean of its rows using a cumulative sum followed by a
    difference at the run boundaries.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        """Average the rows of ``x`` within every run of equal adjacent ranks.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: tensor reduced along dim 0.
                NOTE(review): presumably shaped (N, C) -- confirm with caller.
            geom_feats: tensor indexed along dim 0 in lockstep with ``x``;
                the row at the end of each run is returned.
            ranks: 1-D tensor with one entry per row of ``x``; only adjacent
                entries are compared.

        Returns:
            Tuple ``(pooled, kept_geom_feats)`` with one row per run.
        """
        x = x.cumsum(0)

        # kept[i] is True on the LAST row of every run (the final row always
        # closes a run).
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[:-1] = ranks[1:] != ranks[:-1]

        # Indices of the run ends; the distance between consecutive ends is
        # the run length, and the first run spans rows 0..interval_ends[0].
        # Building this with ``cat`` (instead of assigning element 0) keeps
        # the empty-input case from raising IndexError.
        interval_ends = torch.where(kept)[0]
        interval_lengths = torch.cat(
            (interval_ends[:1] + 1, interval_ends[1:] - interval_ends[:-1])
        ).unsqueeze(-1)

        # Cumulative sum sampled at run ends, differenced -> per-run sums.
        x, geom_feats = x[kept], geom_feats[kept]
        x = torch.cat((x[:1], x[1:] - x[:-1]))

        # Per-run sum -> per-run mean.
        x = x / interval_lengths

        # save_for_backward must be called once with every tensor: a second
        # call replaces what the first one stored.
        ctx.save_for_backward(kept, interval_lengths)

        # no gradient for geom_feats
        ctx.mark_non_differentiable(geom_feats)

        return x, geom_feats

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        """Spread each run's gradient evenly over the rows of that run.

        Args:
            ctx: autograd context populated by ``forward``.
            gradx: gradient w.r.t. the pooled output (one row per run).
            gradgeom: unused; ``geom_feats`` is non-differentiable.

        Returns:
            Gradient w.r.t. ``x`` and ``None`` for ``geom_feats`` and
            ``ranks``.
        """
        kept, interval_lengths = ctx.saved_tensors

        # back[i] is the index of the run (output row) that input row i
        # belongs to: the inclusive count of run ends up to i, minus one on
        # the rows that are themselves run ends.
        back = torch.cumsum(kept, 0)
        back[kept] -= 1

        # Out-of-place division: the incoming gradient may be shared with
        # other consumers and must not be modified in place.
        val = (gradx / interval_lengths)[back]

        return val, None, None
*cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; const float *cur_x = x + interval_start * c + cur_c; float *cur_out = out + cur_geom_feats[3] * d * h * w * c + cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c + @@ -75,7 +75,7 @@ __global__ void bev_pool_grad_kernel(int b, int d, int h, int w, int n, int c, i int interval_start = interval_starts[index]; int interval_length = interval_lengths[index]; - const int *cur_geom_feats = geom_feats + interval_start * 4; + const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; float *cur_x_grad = x_grad + interval_start * c + cur_c; const float *cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c + @@ -103,7 +103,7 @@ __global__ void bev_mean_pool_kernel(int b, int d, int h, int w, int n, int c, i return; int interval_start = interval_starts[index]; int interval_length = interval_lengths[index]; - const int *cur_geom_feats = geom_feats + interval_start * 4; + const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; const float *cur_x = x + interval_start * c + cur_c; float *cur_out = out + cur_geom_feats[3] * d * h * w * c + cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c + @@ -134,7 +134,7 @@ __global__ void bev_mean_pool_grad_kernel(int b, int d, int h, int w, int n, int int interval_start = interval_starts[index]; int interval_length = interval_lengths[index]; - const int *cur_geom_feats = geom_feats + interval_start * 4; + const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4; float *cur_x_grad = x_grad + interval_start * c + cur_c; const float *cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +