Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
SeetaResearch
/
Dragon
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit ddc5f3c2
authored
Nov 30, 2017
by
Ting PAN
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix wrong arguments on conv
1 parent
58284aa4
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
Dragon/src/utils/op_kernel.cu
Dragon/src/utils/op_kernel.cu
View file @
ddc5f3c
...
...
@@ -3027,22 +3027,22 @@ template <> void Im2Col2d<float, CUDAContext>(const int C, const int H, const in
const int count = (C * col_h * col_w);
_Im2Col2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
H, W,
col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col_h, col_w,
im,
col);
} else if (data_format == "NHWC") {
const int count = (col_h * col_w * C);
_Im2Col2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
C, H, W,
col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col_h, col_w,
im,
col);
} else LOG(FATAL) << "Unknown data format: " << data_format;
...
...
@@ -3148,22 +3148,22 @@ template <> void Col2Im2d<float, CUDAContext>(const int C, const int H, const in
const int count = (C * H * W);
_Col2Im2d_NCHW<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
H, W,
col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col_h, col_w,
col,
im);
} else if (data_format == "NHWC") {
const int count = (H * W * C);
_Col2Im2d_NHWC<float> << <GET_BLOCKS(count), CUDA_NUM_THREADS >> >(count,
C, H, W,
col_h, col_w,
kernel_h, kernel_w,
stride_h, stride_w,
pad_h, pad_w,
dilation_h, dilation_w,
col_h, col_w,
col,
im);
} else LOG(FATAL) << "Unknown data format: " << data_format;
...
...
@@ -3979,9 +3979,9 @@ __global__ void _ROIAlign(const int count,
for (T w = wstart + w_stride; w <= wend - w_stride + 0.01; w += max(w_stride, 0.01)) {
x_idx++;
int hlow = min(max(static_cast<int>(floor(h)), 0), height - 1);
int hhigh =
hlow + 1
;
int hhigh =
min(hlow + 1, height - 1)
;
int wleft = min(max(static_cast<int>(floor(w)), 0), width - 1);
int wright =
wleft + 1
;
int wright =
min(wleft + 1, width - 1)
;
int topleft = hlow * width + wleft;
int topright = hlow * width + wright;
int bottomleft = hhigh * width + wleft;
...
...
@@ -4098,9 +4098,9 @@ __global__ void _ROIAlignGrad(const int count,
x_idx++;
if (offset_mask[pool_idx] != x_idx) continue;
int hlow = min(max(static_cast<int>(floor(rh)), 0), height - 1);
int hhigh =
hlow + 1
;
int hhigh =
min(hlow + 1, height - 1)
;
int wleft = min(max(static_cast<int>(floor(rw)), 0), width - 1);
int wright =
wleft + 1
;
int wright =
min(wleft + 1, width - 1)
;
if (h != hlow && h != hhigh && w != wleft && w != wright) continue;
T alpha = (hlow == hhigh) ? static_cast<T>(0.5) : (rh - hlow) / (hhigh - hlow);
T beta = (wleft == wright) ? static_cast<T>(0.5) : (rw - wleft) / (wright - wleft);
...
...
Write
Preview
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment