Skip to content

Commit ba39537

Browse files
authored
Revert "fix cat bug" (DeepLink-org#784)
Revert "fix cat bug (DeepLink-org#772)" This reverts commit 285cca8.
1 parent 285cca8 commit ba39537

File tree

12 files changed

+128
-133
lines changed

12 files changed

+128
-133
lines changed

impl/ascend/common/acloprunner.hpp

100755100644
File mode changed.

impl/ascend/device_configs.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -207,28 +207,9 @@
207207
),
208208

209209
'conv_2d_no_contiguous': dict(
210-
name=["conv2d"],
211-
tensor_para=dict(
212-
args=[
213-
{
214-
"ins": ["input"],
215-
"dtype": [Skip(np.float32), Skip(np.float16), Skip(np.float64)],
216-
},
217-
]
218-
),
219-
),
220-
221-
'relu_no_contiguous': dict(
222-
name=["relu"],
223-
is_inplace=True,
224-
tensor_para=dict(
225-
args=[
226-
{
227-
"ins": ['input'],
228-
"dtype": [Skip(np.float32), Skip(np.float64)],
229-
},
230-
],
231-
),
210+
name=['conv2d'],
211+
atol=1e-1,
212+
rtol=1e-2,
232213
),
233214

234215
'hardswish': dict(
@@ -1607,6 +1588,78 @@
16071588
),
16081589
),
16091590

1591+
'copy': dict(
1592+
name=["copy_"],
1593+
tensor_para=dict(
1594+
# FIXME data type DT_COMPLEX128 of input [dst] is not supported
1595+
args=[
1596+
{
1597+
"ins": ["input"],
1598+
"shape": [Skip((12, 0, 9)), Skip((8,))],
1599+
"dtype": [Skip(np.complex128), Skip(np.complex64)],
1600+
},
1601+
{
1602+
"ins": ["other"],
1603+
"dtype": [Skip(np.complex128)]
1604+
},
1605+
]
1606+
)
1607+
),
1608+
1609+
'copy_input_no_contiguous': dict(
1610+
name=["copy_"],
1611+
tensor_para=dict(
1612+
# FIXME not supported complex
1613+
args=[
1614+
{
1615+
"ins": ["input"],
1616+
"shape": [Skip((12, 1, 12)),],
1617+
"dtype": [Skip(np.complex128), Skip(np.complex64)],
1618+
},
1619+
{
1620+
"ins": ["other"],
1621+
"dtype": [Skip(np.complex64)]
1622+
},
1623+
]
1624+
)
1625+
),
1626+
1627+
'copy_other_no_contiguous': dict(
1628+
name=["copy_"],
1629+
tensor_para=dict(
1630+
# FIXME data type DT_COMPLEX64 of input [dst] is not supported
1631+
# FIXME data type DT_COMPLEX128 of input [dst] is not supported
1632+
args=[
1633+
{
1634+
"ins": ["input"],
1635+
"shape": [Skip((6, 5, 384))],
1636+
"dtype": [Skip(np.complex128), Skip(np.complex64)],
1637+
},
1638+
{
1639+
"ins": ["other"],
1640+
"dtype": [Skip(np.complex128)],
1641+
},
1642+
]
1643+
)
1644+
),
1645+
1646+
'copy_all_no_contiguous': dict(
1647+
name=["copy_"],
1648+
tensor_para=dict(
1649+
# FIXME data type DT_COMPLEX64 of input [dst] is not supported
1650+
args=[
1651+
{
1652+
"ins": ["input"],
1653+
"shape": [Skip((192, 147, 2)), Skip((2, 12, 38, 45, 3))],
1654+
},
1655+
{
1656+
"ins": ["other"],
1657+
"dtype": [Skip(np.complex64)],
1658+
},
1659+
]
1660+
)
1661+
),
1662+
16101663
'fill_not_float': dict(
16111664
name=["fill_"],
16121665
tensor_para=dict(

impl/ascend/functions/cast.cpp

100755100644
Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,8 @@
77
#include "../common/acloprunner.hpp"
88

99
namespace impl {
10-
11-
// TODO(zhaoguochun): fix me
12-
namespace ascend_npu {
13-
extern diopiError_t diopiCastDtype(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
14-
}
15-
1610
namespace ascend {
17-
#if 0
11+
1812
diopiError_t diopiCastDtype(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
1913
int64_t numel = 0;
2014
diopiGetTensorNumel(input, &numel);
@@ -63,11 +57,6 @@ diopiError_t diopiCastDtype(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
6357

6458
return diopiSuccess;
6559
}
66-
#endif
67-
68-
diopiError_t diopiCastDtype(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
69-
return ascend_npu::diopiCastDtype(ctx, out, input);
70-
}
7160

7261
} // namespace ascend
7362
} // namespace impl

impl/ascend_npu/ascend_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ ascend:
1717
- diopiBitwiseAnd
1818
- diopiBitwiseNot
1919
- diopiBmm
20+
- diopiCastDtype
21+
- diopiCat
2022
- diopiClamp
2123
- diopiClampInp
2224
- diopiClampInpScalar
@@ -30,6 +32,7 @@ ascend:
3032
- diopiClampMinScalar
3133
- diopiClampScalar
3234
- diopiContiguous
35+
- diopiCopyInp
3336
- diopiCos
3437
- diopiCosInp
3538
- diopiCrossEntropyLoss
@@ -201,9 +204,6 @@ ascend:
201204
- diopiScatterInpScalar
202205
- diopiApplyPenalty
203206
ascend_npu:
204-
- diopiCastDtype
205-
- diopiCopyInp
206-
- diopiCat
207207
- diopiRemainderTensor
208208
- diopiRemainderScalar
209209
- diopiRemainder

impl/ascend_npu/diopi_impl/cast.cpp

Lines changed: 0 additions & 23 deletions
This file was deleted.

impl/ascend_npu/diopi_impl/cat.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

impl/ascend_npu/diopi_impl/copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace OP_IMPL_NS {
1212

1313
diopiError_t diopiCopyInp(diopiContextHandle_t ctx, diopiConstTensorHandle_t src, diopiTensorHandle_t dest) {
1414
BEGIN_CALL_ACL_OP(src, dest);
15-
if (src == nullptr || dest == nullptr || !srcAt.defined() || !destAt.defined() || srcAt.numel() <= 0 || destAt.numel() <= 0) {
15+
if (!srcAt.defined() || !destAt.defined()) {
1616
return diopiSuccess;
1717
}
1818
at_npu::native::NPUNativeFunctions::copy_(destAt, srcAt, false);

impl/ascend_npu/diopi_impl/helper.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,8 @@ inline int debugLevel() {
108108
impl::aten::setCurCtx(ctx); \
109109
BUILD_ATEN_ARGS(__VA_ARGS__)
110110

111-
#define END_CALL_ACL_OP() \
112-
impl::aten::unsetCurCtx(); \
113-
if (debugLevel()) { \
114-
std::cout << __FILE__ << ":" << __LINE__ << " :" << __FUNCTION__ << " over" << std::endl; \
115-
} \
111+
#define END_CALL_ACL_OP() \
112+
impl::aten::unsetCurCtx(); \
116113
return diopiSuccess;
117114

118115
inline void logError() { std::cerr << std::endl; }

impl/ascend_npu/torch_npu/csrc/CopyKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ bool try_to_optimize_copy_with_any_format(at::Tensor& self, const at::Tensor& sr
282282
}
283283

284284
at::Tensor& NPUNativeFunctions::copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) {
285-
if (!self.defined() || self.numel() == 0) {
285+
if (self.numel() == 0) {
286286
return self;
287287
}
288288
// save tensor dim name

impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp

100644100755
Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,26 +1572,16 @@ class AclTensorDescMaker {
15721572
dims = storageDesc.base_sizes_;
15731573
}
15741574
auto format = storageDesc.origin_format_;
1575-
if (debugLevel()) {
1576-
std::cout << __FUNCTION__ << ":" << dataType << "," << dims << "," << format << std::endl;
1577-
}
1578-
15791575
desc = aclCreateTensorDesc(dataType, dims.size(), dims.data(), format);
15801576
return *this;
15811577
}
15821578

15831579
inline AclTensorDescMaker& Create(aclDataType dataType, c10::IntArrayRef dims, aclFormat format) {
1584-
if (debugLevel()) {
1585-
std::cout << __FUNCTION__ << ":" << dataType << "," << dims << "," << format << std::endl;
1586-
}
15871580
desc = aclCreateTensorDesc(dataType, dims.size(), dims.data(), format);
15881581
return *this;
15891582
}
15901583

15911584
inline AclTensorDescMaker& Create(aclDataType dataType, aclFormat format) {
1592-
if (debugLevel()) {
1593-
std::cout << __FUNCTION__ << ":" << dataType << "," << format << std::endl;
1594-
}
15951585
desc = aclCreateTensorDesc(dataType, 0, nullptr, format);
15961586
return *this;
15971587
}
@@ -2176,7 +2166,19 @@ std::tuple<aclTensorDesc*, aclDataBuffer*> CovertToAclOutput(const at::Tensor& t
21762166
// This class maintain the position of the current
21772167
// OpCommandImpl object in vector, the resources in
21782168
// the object is
2169+
class OpCommandImpls {
2170+
public:
2171+
TORCH_NPU_API static OpCommandImpls* GetInstanceByTid(std::thread::id tid);
2172+
TORCH_NPU_API void Push(OpCommandImpl*& ptr);
2173+
TORCH_NPU_API void Pop();
21792174

2175+
private:
2176+
int32_t offset = -1;
2177+
c10::SmallVector<OpCommandImpl, N> objs;
2178+
}; // class OpCommandImpls
2179+
2180+
static std::unordered_map<std::thread::id, OpCommandImpls> opcommand_impls_map;
2181+
static std::mutex map_mutex;
21802182
static bool deterministicaclnn_oldstatus = false;
21812183

21822184
void OpCommandImpl::SetDeterministic() {
@@ -2190,15 +2192,38 @@ void OpCommandImpl::SetDeterministic() {
21902192
}
21912193
}
21922194

2195+
OpCommandImpls* OpCommandImpls::GetInstanceByTid(std::thread::id tid) {
2196+
if (opcommand_impls_map.find(tid) == opcommand_impls_map.end()) {
2197+
OpCommandImpls impl;
2198+
std::lock_guard<std::mutex> lock(map_mutex);
2199+
opcommand_impls_map[tid] = std::move(impl);
2200+
}
2201+
return &opcommand_impls_map[tid];
2202+
}
2203+
2204+
void OpCommandImpls::Push(OpCommandImpl*& ptr) {
2205+
++offset;
2206+
if (static_cast<int32_t>(objs.size()) <= offset) {
2207+
OpCommandImpl impl;
2208+
objs.emplace_back(std::move(impl));
2209+
}
2210+
TORCH_CHECK(objs.size() > offset, "OpCommand size (", objs.size(), ") is smaller than offset (", offset, ")");
2211+
ptr = &objs[offset];
2212+
}
2213+
2214+
void OpCommandImpls::Pop() {
2215+
TORCH_CHECK(offset >= 0, "OpCommand current offset should not be less than ", offset);
2216+
offset -= 1;
2217+
}
2218+
21932219
OpCommand::OpCommand() {
2194-
aclCmd = new OpCommandImpl();
2220+
aclCmds = OpCommandImpls::GetInstanceByTid(std::this_thread::get_id());
2221+
2222+
aclCmds->Push(aclCmd);
21952223
aclCmd->SetCustomHandler(nullptr);
21962224
}
21972225

2198-
OpCommand::~OpCommand() {
2199-
OpCommandImpl* impl = static_cast<OpCommandImpl*>(aclCmd);
2200-
delete impl;
2201-
}
2226+
OpCommand::~OpCommand() {}
22022227

22032228
OpCommand& OpCommand::Name(const string& name) {
22042229
aclCmd->SetName(name);
@@ -2390,6 +2415,7 @@ void OpCommand::Run() {
23902415
Sync();
23912416
}
23922417
aclCmd->releaseSource();
2418+
aclCmds->Pop();
23932419
}
23942420

23952421
OpCommand& OpCommand::Sync(c10::SmallVector<int64_t, N>& index) {

0 commit comments

Comments
 (0)