@@ -197,6 +197,28 @@ namespace matx
 
       tp->TransformExec(tp->Shape(), ex);
     }
+    else if constexpr (is_tensor_view_v<typename T::tensor_type> && is_tensor_view_v<typename T::op_type> && is_cuda_executor_v<Ex>) {
+      // If we are doing a tensor-to-tensor assignment we should prefer cudaMemcpyAsync instead of a kernel
+      if (detail::check_aliased_memory(tp->get_lhs(), tp->get_rhs(), true)) {
+        MATX_THROW(matxInvalidParameter, "Possible aliased memory detected: LHS and RHS memory ranges overlap");
+      }
+
+      if (tp->get_lhs().IsContiguous() && tp->get_rhs().IsContiguous() && tp->get_lhs().Rank() == tp->get_rhs().Rank()) {
+        MATX_ASSERT_STR(tp->get_lhs().Bytes() >= tp->get_rhs().Bytes(), matxInvalidSize, "LHS tensor is smaller than RHS tensor in assignment");
+        MATX_LOG_TRACE("Copying {} bytes from {} to {} using cudaMemcpyAsync",
+                       tp->get_lhs().Bytes(), reinterpret_cast<void*>(tp->get_rhs().Data()), reinterpret_cast<void*>(tp->get_lhs().Data()));
+        cudaMemcpyAsync(reinterpret_cast<void*>(tp->get_lhs().Data()),
+                        reinterpret_cast<void*>(tp->get_rhs().Data()),
+                        tp->get_lhs().Bytes(),
+                        cudaMemcpyDefault,
+                        ex.getStream());
+      }
+      else {
+        MATX_LOG_TRACE("Copying {} bytes from {} to {} using kernel",
+                       tp->get_lhs().Bytes(), reinterpret_cast<void*>(tp->get_rhs().Data()), reinterpret_cast<void*>(tp->get_lhs().Data()));
+        ex.Exec(*tp);
+      }
+    }
     else {
       if (detail::check_aliased_memory(tp->get_lhs(), tp->get_rhs(), true)) {
         MATX_THROW(matxInvalidParameter, "Possible aliased memory detected: LHS and RHS memory ranges overlap");
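For context, a minimal sketch (not part of this commit) of the kind of user code that could take the new fast path: two contiguous, same-rank device tensors assigned on a CUDA executor. Names such as make_tensor, cudaExecutor, and (a = b).run(exec) are the usual MatX entry points assumed here; shapes and stream handling are illustrative only.

#include <matx.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  matx::cudaExecutor exec{stream};

  // Contiguous, same-rank tensors: with this change the assignment below should
  // be able to lower to a single cudaMemcpyAsync on the executor's stream
  // rather than launching an elementwise copy kernel.
  auto a = matx::make_tensor<float>({1024, 1024});
  auto b = matx::make_tensor<float>({1024, 1024});

  (a = b).run(exec);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}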