diff --git a/include/matx/core/tensor.h b/include/matx/core/tensor.h index 0d52b496..5386d069 100644 --- a/include/matx/core/tensor.h +++ b/include/matx/core/tensor.h @@ -741,7 +741,7 @@ class tensor_t : public detail::tensor_impl_t { int dev; cudaGetDevice(&dev); - #if CUDA_VERSION <= 12000 + #if CUDART_VERSION <= 12000 cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), dev, stream); #else cudaMemLocation loc; @@ -765,7 +765,7 @@ class tensor_t : public detail::tensor_impl_t { { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - #if CUDA_VERSION <= 12000 + #if CUDART_VERSION <= 12000 cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), cudaCpuDeviceId, stream); #else diff --git a/include/matx/operators/base_operator.h b/include/matx/operators/base_operator.h index 5a4e30e6..bfba5a3a 100644 --- a/include/matx/operators/base_operator.h +++ b/include/matx/operators/base_operator.h @@ -197,6 +197,28 @@ namespace matx tp->TransformExec(tp->Shape(), ex); } + else if constexpr (is_tensor_view_v && is_tensor_view_v && is_cuda_executor_v) { + // If we are doing a tensor to tensor assignment we should prefer cudaMemcpyAsync instead of a kernel + if (detail::check_aliased_memory(tp->get_lhs(), tp->get_rhs(), true)) { + MATX_THROW(matxInvalidParameter, "Possible aliased memory detected: LHS and RHS memory ranges overlap"); + } + + if (tp->get_lhs().IsContiguous() && tp->get_rhs().IsContiguous() && tp->get_lhs().Rank() == tp->get_rhs().Rank()) { + MATX_ASSERT_STR(tp->get_lhs().Bytes() >= tp->get_rhs().Bytes(), matxInvalidSize, "LHS tensor is smaller than RHS tensor in assignment"); + MATX_LOG_TRACE("Copying {} bytes from {} to {} using cudaMemcpyAsync", + tp->get_lhs().Bytes(), reinterpret_cast(tp->get_rhs().Data()), reinterpret_cast(tp->get_lhs().Data())); + cudaMemcpyAsync(reinterpret_cast(tp->get_lhs().Data()), + reinterpret_cast(tp->get_rhs().Data()), + tp->get_rhs().Bytes(), + cudaMemcpyDefault, + ex.getStream()); + } + else { + MATX_LOG_TRACE("Copying {} bytes from {} to {} using kernel", + tp->get_lhs().Bytes(), reinterpret_cast(tp->get_rhs().Data()), reinterpret_cast(tp->get_lhs().Data())); + ex.Exec(*tp); + } + } else { if (detail::check_aliased_memory(tp->get_lhs(), tp->get_rhs(), true)) { MATX_THROW(matxInvalidParameter, "Possible aliased memory detected: LHS and RHS memory ranges overlap");