@@ -64,13 +64,12 @@ void StridedCopyKernel(const Context& dev_ctx,
 #if defined(PADDLE_WITH_CUDA)
 // not support Windows
 #if !defined(_WIN32)
-  if (FLAGS_use_stride_kernel && FLAGS_use_stride_compute_kernel &&
+  if (FLAGS_use_stride_kernel &&
       input.place().GetType() == phi::AllocationType::CPU &&
       out->place().GetType() == phi::AllocationType::GPU &&
-      input.dtype() == out->dtype() && !input.meta().is_contiguous()) {
+      input.dtype() == out->dtype() &&
+      (!input.meta().is_contiguous() || !out->meta().is_contiguous())) {
     phi::DenseTensor dst_gpu;
-    phi::DenseTensor src_cpu;
-
     if (out->meta().is_contiguous()) {
       dst_gpu = *out;
     } else {
@@ -81,176 +80,191 @@ void StridedCopyKernel(const Context& dev_ctx,
       dev_ctx.Alloc(&dst_gpu, input.dtype());
     }

-    phi::DenseTensor cpu_input = input;
-    phi::DenseTensor* cpu_out = &src_cpu;
-    void* cpu_output_data;
+    auto src_cpu_place = input.place();
+    auto dst_gpu_place = out->place();
+    auto& pool = phi::DeviceContextPool::Instance();
+    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
+    auto stream = gpu_dev_ctx->stream();
+
+    if (input.meta().is_contiguous()) {
+      auto src_cpu_place = input.place();
+      auto dst_gpu_place = out->place();
+      auto size = phi::SizeOf(input.dtype()) * input.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, input.data<T>(), size, stream);
+
+    } else {
+      phi::DenseTensor src_cpu;
+      phi::DenseTensor cpu_input = input;
+      phi::DenseTensor* cpu_out = &src_cpu;
+      void* cpu_output_data;

-    phi::DenseTensorMeta cpu_meta = cpu_input.meta();
-    cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
-    cpu_meta.offset = 0;
-    cpu_out->set_meta(cpu_meta);
+      phi::DenseTensorMeta cpu_meta = cpu_input.meta();
+      cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
+      cpu_meta.offset = 0;
+      cpu_out->set_meta(cpu_meta);

 #if defined(PADDLE_WITH_OPENMP)
-    dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
+      dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
 #endif
-    const void* cpu_input_data = cpu_input.data();
-    cpu_output_data = malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());
+      const void* cpu_input_data = cpu_input.data();
+      cpu_output_data =
+          malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());

-    if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
-      constexpr int64_t TRANS_NUMEL = 60;
-      void* trans_buffer =
-          malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);
+      if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
+        constexpr int64_t TRANS_NUMEL = 60;
+        void* trans_buffer =
+            malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);

-      const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
+        const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
 #if defined(PADDLE_WITH_OPENMP)
-      T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
+        T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
 #else
-      T* tmp_out_ptr = cpu_out->data<T>();
+        T* tmp_out_ptr = cpu_out->data<T>();
 #endif
-      T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);
+        T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);

-      int64_t dim0 = cpu_out->dims()[0];
-      int64_t dim1 = cpu_out->dims()[1];
+        int64_t dim0 = cpu_out->dims()[0];
+        int64_t dim1 = cpu_out->dims()[1];

-      for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
-        for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
-          const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
-          T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
+        for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
+          for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
+            const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
+            T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;

-          int nr = std::min(dim0 - d0, TRANS_NUMEL);
-          int nc = std::min(dim1 - d1, TRANS_NUMEL);
+            int nr = std::min(dim0 - d0, TRANS_NUMEL);
+            int nc = std::min(dim1 - d1, TRANS_NUMEL);

-          for (int c = 0; c < nc; c++) {
-            memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
-                   src_ptr_inter + c * dim0,
-                   nr * sizeof(T));
-          }
+            for (int c = 0; c < nc; c++) {
+              memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
+                     src_ptr_inter + c * dim0,
+                     nr * sizeof(T));
+            }

-          int rc_max = std::max(nr, nc);
-          int rc_min = std::min(nr, nc);
-          for (int r = 0; r < rc_max; r++) {
-            int end = std::min(r, rc_min);
-            for (int c = 0; c < end; c++) {
-              T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
-              tmp_buf_ptr[r + TRANS_NUMEL * c] =
-                  tmp_buf_ptr[r * TRANS_NUMEL + c];
-              tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
+            int rc_max = std::max(nr, nc);
+            int rc_min = std::min(nr, nc);
+            for (int r = 0; r < rc_max; r++) {
+              int end = std::min(r, rc_min);
+              for (int c = 0; c < end; c++) {
+                T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
+                tmp_buf_ptr[r + TRANS_NUMEL * c] =
+                    tmp_buf_ptr[r * TRANS_NUMEL + c];
+                tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
+              }
             }
-          }

-          for (int r = 0; r < nr; r++) {
-            memcpy(out_ptr_inter + r * dim1,
-                   tmp_buf_ptr + r * TRANS_NUMEL,
-                   nc * sizeof(T));
+            for (int r = 0; r < nr; r++) {
+              memcpy(out_ptr_inter + r * dim1,
+                     tmp_buf_ptr + r * TRANS_NUMEL,
+                     nc * sizeof(T));
+            }
           }
         }
-      }
-      free(trans_buffer);
-    } else {
+        free(trans_buffer);
+      } else {
 #if defined(PADDLE_WITH_OPENMP)
-      phi::DenseTensorIteratorConfig config;
-      config.add_output(*cpu_out);
-      config.add_const_input(cpu_input);
-      config.is_alloc_out_ = true;
-      phi::DenseTensorIterator iter = config.build();
-
-      std::vector<int64_t> tmp_strides(
-          iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));
+        phi::DenseTensorIteratorConfig config;
+        config.add_output(*cpu_out);
+        config.add_const_input(cpu_input);
+        config.is_alloc_out_ = true;
+        phi::DenseTensorIterator iter = config.build();

-      DealWithStride(iter, tmp_strides.data());
+        std::vector<int64_t> tmp_strides(
+            iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));

-      std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
-                                      tmp_strides.end());
+        DealWithStride(iter, tmp_strides.data());

-      std::vector<int64_t> output_stride = iter.strides(0);
-      std::vector<int64_t> input_stride = iter.strides(1);
+        std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
+                                        tmp_strides.end());

-      const int64_t& numel = iter.numel();
+        std::vector<int64_t> output_stride = iter.strides(0);
+        std::vector<int64_t> input_stride = iter.strides(1);

-      const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
-      char* out_ptr = reinterpret_cast<char*>(cpu_output_data);
+        const int64_t& numel = iter.numel();

-      int64_t end = numel;
-      int64_t begin = 0;
-      int64_t grain_size = 32768;
+        const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
+        char* out_ptr = reinterpret_cast<char*>(cpu_output_data);

-      int64_t* whole_stride = tmp_strides.data();
+        int64_t end = numel;
+        int64_t begin = 0;
+        int64_t grain_size = 32768;

-      omp_set_num_threads(std::thread::hardware_concurrency());
+        int64_t* whole_stride = tmp_strides.data();

 #pragma omp parallel
-      {
-        int64_t num_threads = omp_get_num_threads();
+        {
+          int64_t num_threads = omp_get_num_threads();

-        if (grain_size > 0) {
-          num_threads = std::min(num_threads, DivUp((end - begin), grain_size));
-        }
+          if (grain_size > 0) {
+            num_threads =
+                std::min(num_threads, DivUp((end - begin), grain_size));
+          }

-        int64_t tid = omp_get_thread_num();
-        int64_t chunk_size = DivUp((end - begin), num_threads);
-        int64_t begin_tid = begin + tid * chunk_size;
-
-        if (begin_tid < end) {
-          int64_t range_start = begin_tid;
-          int64_t range_end = std::min(end, chunk_size + begin_tid);
-
-          auto dimiter = DimIter(iter.shape(), range_start, range_end);
-          while (!dimiter.iter_to_end()) {
-            const auto v_ndim = dimiter.values.size();
-            const char* tmp_in_data = in_ptr;
-            char* tmp_out_data = out_ptr;
-            for (size_t dim = 0; dim < v_ndim; dim++) {
-              int64_t value = dimiter.values[dim];
-              tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
-              tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
-            }
+          int64_t tid = omp_get_thread_num();
+          int64_t chunk_size = DivUp((end - begin), num_threads);
+          int64_t begin_tid = begin + tid * chunk_size;
+
+          if (begin_tid < end) {
+            int64_t range_start = begin_tid;
+            int64_t range_end = std::min(end, chunk_size + begin_tid);
+
+            auto dimiter = DimIter(iter.shape(), range_start, range_end);
+            while (!dimiter.iter_to_end()) {
+              const auto v_ndim = dimiter.values.size();
+              const char* tmp_in_data = in_ptr;
+              char* tmp_out_data = out_ptr;
+              for (size_t dim = 0; dim < v_ndim; dim++) {
+                int64_t value = dimiter.values[dim];
+                tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
+                tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
+              }

-            auto step = dimiter.iter_for_step();
+              auto step = dimiter.iter_for_step();

-            for (int64_t i = 0; i < step[1]; i++) {
-              for (int64_t j = 0; j < step[0]; j++) {
-                const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
-                char* real_out_ptr = tmp_out_data + j * whole_stride[0];
+              for (int64_t i = 0; i < step[1]; i++) {
+                for (int64_t j = 0; j < step[0]; j++) {
+                  const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
+                  char* real_out_ptr = tmp_out_data + j * whole_stride[0];

-                *reinterpret_cast<T*>(real_out_ptr) =
-                    *reinterpret_cast<const T*>(real_in_ptr);
+                  *reinterpret_cast<T*>(real_out_ptr) =
+                      *reinterpret_cast<const T*>(real_in_ptr);
+                }
+                tmp_in_data = tmp_in_data + out_stride[1];
+                tmp_out_data = tmp_out_data + out_stride[0];
               }
-              tmp_in_data = tmp_in_data + out_stride[1];
-              tmp_out_data = tmp_out_data + out_stride[0];
-            }

-            dimiter.iter_to_next(step);
+              dimiter.iter_to_next(step);
+            }
           }
         }
-      }
 #else
-      phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
+        phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
 #endif
-    }
-
-    auto src_cpu_place = input.place();
-    auto dst_gpu_place = out->place();
-
-    auto& pool = phi::DeviceContextPool::Instance();
-    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
-    auto stream = gpu_dev_ctx->stream();
+      }
 #if defined(PADDLE_WITH_OPENMP)
-    auto* src_ptr = cpu_output_data;
+      auto* src_ptr = cpu_output_data;
 #else
-    auto* src_ptr = cpu_out->data<T>();
+      auto* src_ptr = cpu_out->data<T>();
 #endif

-    auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
-    void* dst_ptr = gpu_dev_ctx->Alloc(
-        &dst_gpu,
-        dst_gpu.dtype(),
-        0,
-        dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+      auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);

-    phi::memory_utils::Copy(
-        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);

-    free(cpu_output_data);
+      free(cpu_output_data);
+    }
     if (out != &dst_gpu) {
       PD_VISIT_ALL_TYPES(
           out->dtype(), "StridedCopyKernel", ([&] {
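
For reference, the fast path in this diff is a cache-blocked transpose: it stages 60x60 tiles in a small scratch buffer so both the reads from the column-major source and the writes to the row-major destination stay contiguous. Below is a minimal standalone sketch of the same tiling idea; the `blocked_transpose` name, the `std::vector` scratch buffer, and the free-function signature are illustrative assumptions, not the kernel's actual helpers.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int64_t TILE = 60;  // same tile size the kernel uses (TRANS_NUMEL)

// Copies a dim0 x dim1 matrix stored column-major in `src` into row-major
// `dst`, one TILE x TILE tile at a time so every memcpy is contiguous.
template <typename T>
void blocked_transpose(const T* src, T* dst, int64_t dim0, int64_t dim1) {
  std::vector<T> buf(TILE * TILE);  // scratch tile that stays in cache
  for (int64_t d0 = 0; d0 < dim0; d0 += TILE) {
    for (int64_t d1 = 0; d1 < dim1; d1 += TILE) {
      const int64_t nr = std::min(dim0 - d0, TILE);
      const int64_t nc = std::min(dim1 - d1, TILE);
      // 1) contiguous reads: pull nc source columns into the tile buffer
      for (int64_t c = 0; c < nc; ++c) {
        std::memcpy(buf.data() + c * TILE, src + (d1 + c) * dim0 + d0,
                    nr * sizeof(T));
      }
      // 2) transpose the tile in place inside the scratch buffer
      const int64_t rc_min = std::min(nr, nc);
      for (int64_t r = 0; r < std::max(nr, nc); ++r) {
        for (int64_t c = 0; c < std::min(r, rc_min); ++c) {
          std::swap(buf[r * TILE + c], buf[r + TILE * c]);
        }
      }
      // 3) contiguous writes: push nr transposed rows into the destination
      for (int64_t r = 0; r < nr; ++r) {
        std::memcpy(dst + (d0 + r) * dim1 + d1, buf.data() + r * TILE,
                    nc * sizeof(T));
      }
    }
  }
}
```

Staging through the tile keeps one side of every copy sequential, which is what makes it worthwhile to materialize a contiguous host buffer before issuing the single stream copy to the GPU.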
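The OpenMP branch partitions the flat [begin, end) element range by hand rather than using `omp for`, capping the effective thread count with a grain size so small tensors do not fan out across every core. Here is a sketch of just that chunking scheme, assuming `DivUp` is a ceiling division; `div_up`, `parallel_chunks`, and the printf body are illustrative stand-ins, not the kernel's code.

```cpp
#include <omp.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Ceiling division, assumed to match the kernel's DivUp helper.
static int64_t div_up(int64_t x, int64_t y) { return (x + y - 1) / y; }

// Splits [begin, end) across OpenMP threads so each chunk holds at least
// grain_size elements; excess threads simply get no work. Assumes
// end > begin, as numel > 0 in the kernel.
void parallel_chunks(int64_t begin, int64_t end, int64_t grain_size) {
#pragma omp parallel
  {
    int64_t num_threads = omp_get_num_threads();
    if (grain_size > 0) {
      num_threads = std::min(num_threads, div_up(end - begin, grain_size));
    }
    const int64_t tid = omp_get_thread_num();
    const int64_t chunk_size = div_up(end - begin, num_threads);
    const int64_t chunk_begin = begin + tid * chunk_size;
    if (chunk_begin < end) {
      const int64_t chunk_end = std::min(end, chunk_begin + chunk_size);
      // The kernel walks a DimIter over [chunk_begin, chunk_end) here.
      std::printf("thread %lld: [%lld, %lld)\n",
                  static_cast<long long>(tid),
                  static_cast<long long>(chunk_begin),
                  static_cast<long long>(chunk_end));
    }
  }
}
```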