Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions models/bevdet/ops/bev_pool/bev_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,45 @@ def backward(ctx, out_grad):

return x_grad, None, None, None, None, None, None

class QuickCumsumMean(torch.autograd.Function):
    """Mean-pool consecutive rows of ``x`` that share the same rank.

    Uses a cumulative sum along dim 0: the sum of each run of equal
    ``ranks`` is the difference between the cumulative sums at the ends
    of consecutive runs, and the mean is that sum divided by the run
    length.

    NOTE(review): this only groups *adjacent* equal ranks, so ``ranks``
    is presumably sorted by the caller -- confirm at the call site.
    """

    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        """Return per-run means of ``x`` and the matching ``geom_feats`` rows.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: features, reduced along dim 0. The run lengths are
                broadcast with a trailing singleton dim, so ``x`` is
                presumably 2-D ``(N, C)`` -- TODO confirm.
            geom_feats: per-row tensor indexed along dim 0 like ``x``;
                the row at the last position of each run is returned.
            ranks: per-row key; a run ends wherever the next rank differs.

        Returns:
            Tuple ``(means, geom_feats_kept)`` with one row per run.
        """
        x = x.cumsum(0)

        # A row is kept when it is the last row of its run (the final row
        # always is).
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[:-1] = ranks[1:] != ranks[:-1]

        # Index of the last row of each run; run length is the distance
        # to the previous run's last row (first run: index + 1).
        interval_ends = torch.where(kept)[0].int()
        interval_lengths = torch.zeros_like(interval_ends)
        if interval_ends.numel() > 0:  # empty input has no first run
            interval_lengths[1:] = interval_ends[1:] - interval_ends[:-1]
            interval_lengths[0] = interval_ends[0] + 1

        # Differences of the cumulative sum at run ends give per-run sums.
        x, geom_feats = x[kept], geom_feats[kept]
        x = torch.cat((x[:1], x[1:] - x[:-1]))

        interval_lengths = interval_lengths.unsqueeze(-1)
        x = x / interval_lengths

        # Save both tensors in ONE call: each save_for_backward call
        # replaces what was saved before, so two calls would drop `kept`.
        ctx.save_for_backward(kept, interval_lengths)

        # no gradient for geom_feats
        ctx.mark_non_differentiable(geom_feats)

        return x, geom_feats

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        """Spread each run's gradient, scaled by 1/length, over its rows.

        Returns the gradient for ``x`` and ``None`` for ``geom_feats``
        and ``ranks``.
        """
        kept, interval_lengths = ctx.saved_tensors

        # Map every input row to the index of the run it belongs to:
        # the cumsum counts run ends up to and including the row, so
        # rows that are themselves a run end are shifted back by one.
        back = torch.cumsum(kept, 0)
        back[kept] -= 1

        # Out-of-place division: the incoming gradient must not be
        # modified in place.
        scaled = gradx / interval_lengths

        val = scaled[back]

        return val, None, None

def bev_pool(feats, coords, B, D, H, W, mean_pool=False):
assert feats.shape[0] == coords.shape[0]
Expand Down
8 changes: 4 additions & 4 deletions models/bevdet/ops/bev_pool/src/bev_pool_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ __global__ void bev_pool_kernel(int b, int d, int h, int w, int n, int c, int n_
return;
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];
const int *cur_geom_feats = geom_feats + interval_start * 4;
const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4;
const float *cur_x = x + interval_start * c + cur_c;
float *cur_out = out + cur_geom_feats[3] * d * h * w * c +
cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
Expand Down Expand Up @@ -75,7 +75,7 @@ __global__ void bev_pool_grad_kernel(int b, int d, int h, int w, int n, int c, i
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];

const int *cur_geom_feats = geom_feats + interval_start * 4;
const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4;
float *cur_x_grad = x_grad + interval_start * c + cur_c;

const float *cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +
Expand Down Expand Up @@ -103,7 +103,7 @@ __global__ void bev_mean_pool_kernel(int b, int d, int h, int w, int n, int c, i
return;
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];
const int *cur_geom_feats = geom_feats + interval_start * 4;
const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4;
const float *cur_x = x + interval_start * c + cur_c;
float *cur_out = out + cur_geom_feats[3] * d * h * w * c +
cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
Expand Down Expand Up @@ -134,7 +134,7 @@ __global__ void bev_mean_pool_grad_kernel(int b, int d, int h, int w, int n, int
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];

const int *cur_geom_feats = geom_feats + interval_start * 4;
const int *cur_geom_feats = geom_feats + (interval_start + interval_length - 1) * 4;
float *cur_x_grad = x_grad + interval_start * c + cur_c;

const float *cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +
Expand Down