Commit ddc5f3c2 by Ting PAN

Fix wrong arguments on conv

1 parent 58284aa4
Showing with 8 additions and 8 deletions
...@@ -3027,22 +3027,22 @@ template <> void Im2Col2d<float, CUDAContext>(const int C, const int H, const in ...@@ -3027,22 +3027,22 @@ template <> void Im2Col2d<float, CUDAContext>(const int C, const int H, const in
const int count = (C * col_h * col_w); const int count = (C * col_h * col_w);
_Im2Col2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _Im2Col2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
H, W, H, W,
col_h, col_w,
kernel_h, kernel_w, kernel_h, kernel_w,
stride_h, stride_w, stride_h, stride_w,
pad_h, pad_w, pad_h, pad_w,
dilation_h, dilation_w, dilation_h, dilation_w,
col_h, col_w,
im, im,
col); col);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
const int count = (col_h * col_w * C); const int count = (col_h * col_w * C);
_Im2Col2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _Im2Col2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
C, H, W, C, H, W,
col_h, col_w,
kernel_h, kernel_w, kernel_h, kernel_w,
stride_h, stride_w, stride_h, stride_w,
pad_h, pad_w, pad_h, pad_w,
dilation_h, dilation_w, dilation_h, dilation_w,
col_h, col_w,
im, im,
col); col);
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -3148,22 +3148,22 @@ template <> void Col2Im2d<float, CUDAContext>(const int C, const int H, const in ...@@ -3148,22 +3148,22 @@ template <> void Col2Im2d<float, CUDAContext>(const int C, const int H, const in
const int count = (C * H * W); const int count = (C * H * W);
_Col2Im2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _Col2Im2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
H, W, H, W,
col_h, col_w,
kernel_h, kernel_w, kernel_h, kernel_w,
stride_h, stride_w, stride_h, stride_w,
pad_h, pad_w, pad_h, pad_w,
dilation_h, dilation_w, dilation_h, dilation_w,
col_h, col_w,
col, col,
im); im);
} else if (data_format == "NHWC") { } else if (data_format == "NHWC") {
const int count = (H * W * C); const int count = (H * W * C);
_Col2Im2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count, _Col2Im2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
C, H, W, C, H, W,
col_h, col_w,
kernel_h, kernel_w, kernel_h, kernel_w,
stride_h, stride_w, stride_h, stride_w,
pad_h, pad_w, pad_h, pad_w,
dilation_h, dilation_w, dilation_h, dilation_w,
col_h, col_w,
col, col,
im); im);
} else LOG(FATAL) << "Unknown data format: " << data_format; } else LOG(FATAL) << "Unknown data format: " << data_format;
...@@ -3979,9 +3979,9 @@ __global__ void _ROIAlign(const int count, ...@@ -3979,9 +3979,9 @@ __global__ void _ROIAlign(const int count,
for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) { for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) {
x_idx++; x_idx++;
int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1); int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1);
int hhigh = hlow + 1; int hhigh = min(hlow + 1, height - 1);
int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1); int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1);
int wright = wleft + 1; int wright = min(wleft + 1, width - 1);
int topleft = hlow * width + wleft; int topleft = hlow * width + wleft;
int topright = hlow * width + wright; int topright = hlow * width + wright;
int bottomleft = hhigh * width + wleft; int bottomleft = hhigh * width + wleft;
...@@ -4098,9 +4098,9 @@ __global__ void _ROIAlignGrad(const int count, ...@@ -4098,9 +4098,9 @@ __global__ void _ROIAlignGrad(const int count,
x_idx++; x_idx++;
if (offset_mask[pool_idx] != x_idx) continue; if (offset_mask[pool_idx] != x_idx) continue;
int hlow = min(max(static_cast<int>(floor(rh)), 0), height - 1); int hlow = min(max(static_cast<int>(floor(rh)), 0), height - 1);
int hhigh = hlow + 1; int hhigh = min(hlow + 1, height - 1);
int wleft = min(max(static_cast<int>(floor(rw)), 0), width - 1); int wleft = min(max(static_cast<int>(floor(rw)), 0), width - 1);
int wright = wleft + 1; int wright = min(wleft + 1, width - 1);
if (h != hlow && h != hhigh && w != wleft && w != wright) continue; if (h != hlow && h != hhigh && w != wleft && w != wright) continue;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (rh - hlow) / (hhigh - hlow); T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (rh - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (rw - wleft) / (wright - wleft); T beta = (wleft == wright) ? static_cast<T>(0.5) : (rw - wleft) / (wright - wleft);
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!