Commit e264503 ("udpate codes")
Parent: f64cd19

File tree: 4 files changed (+7 −30 lines)

dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py

Lines changed: 7 additions & 1 deletion
@@ -835,9 +835,15 @@ def infer_result(
         self, x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, dtype
     ):
         M, K = x.shape
+        ranks = lora_a.size(0)
         N = lora_b.size(1)
         output = torch.empty((M, N), dtype=x.dtype, device=x.device)
-        return output, output
+        # assume totalRank is the max rank
+        internal_output_x_lora_a = torch.empty(
+            (M, ranks * M), dtype=x.dtype, device=x.device
+        )
+        internal_lora_a_transpose = torch.empty_like(lora_a)
+        return output, internal_output_x_lora_a, internal_lora_a_transpose
 
 
 class AclNnInplaceAdd(Operator):
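
As a reading aid, here is a minimal, hypothetical sketch of the shapes the updated infer_result now reports. The concrete sizes and the (total_rank, K) / (total_rank, N) stacked-weight layouts are assumptions for illustration only, not part of the diff:

import torch

# Illustrative sizes (assumed): 4 tokens, hidden dim 16, output dim 32, total rank 8.
M, K, N, total_rank = 4, 16, 32, 8
x = torch.empty(M, K, dtype=torch.float16)
lora_a = torch.empty(total_rank, K, dtype=torch.float16)  # stacked LoRA-A weights (assumed layout)
lora_b = torch.empty(total_rank, N, dtype=torch.float16)  # stacked LoRA-B weights (assumed layout)

# Mirrors the updated infer_result above:
ranks = lora_a.size(0)                                                   # 8
output = torch.empty((M, lora_b.size(1)), dtype=x.dtype)                 # (4, 32)
internal_output_x_lora_a = torch.empty((M, ranks * M), dtype=x.dtype)    # (4, 32), sized per the diff
internal_lora_a_transpose = torch.empty_like(lora_a)                     # (8, 16)

The op thus exposes two internal workspaces alongside the real output (judging by their names, the x·lora_a intermediate and a transposed copy of lora_a) instead of returning output twice.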

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py

Lines changed: 0 additions & 1 deletion
@@ -1251,7 +1251,6 @@ def CustomFusedLora(
     name, x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, dtype
 ):
     op = Operation(name, "CustomFusedLoraOperation")
-    # TODO: add param
     param = infer_param.CustomFusedLoraParam()
     param.name = name
     param.dtype = get_ascend_dtype(dtype)

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp

Lines changed: 0 additions & 22 deletions
@@ -102,9 +102,6 @@ void CustomFusedLoraOperation::ClearInternal() {
     aclWeightA_.clear();
     aclWeightB_.clear();
     aclWeightATranspose_.clear();
-    weightA_.clear();
-    weightB_.clear();
-    weightATranspose_.clear();
 
     aclScalingInput_.clear();
     scalingInput_.clear();
@@ -115,19 +112,6 @@
     aclScalingExecutor_.clear();
 }
 
-// Helper function to create weight tensor
-atb::Tensor CustomFusedLoraOperation::CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset) {
-    atb::Tensor weightTensor;
-    weightTensor.desc.dtype = baseTensor.desc.dtype;
-    weightTensor.desc.format = baseTensor.desc.format;
-    weightTensor.desc.shape.dimNum = baseTensor.desc.shape.dimNum;
-    weightTensor.desc.shape.dims[0] = rank;
-    weightTensor.desc.shape.dims[1] = dim;
-    weightTensor.dataSize = atb::Utils::GetTensorSize(weightTensor.desc);
-    weightTensor.deviceData = static_cast<uint8_t*>(baseTensor.deviceData) + offset;
-    return weightTensor;
-}
-
 // Helper function to calculate offset for weight tensors
 uint64_t CustomFusedLoraOperation::CalculateWeightOffset(const std::vector<int32_t>& ranksVec, size_t adapterId, uint64_t tensorSizePerRank) {
     uint64_t offset = 0;
@@ -183,12 +167,6 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
     const int64_t loraBDim = variantPack.inTensors.at(2).desc.shape.dims[1];
 
     ClearInternal();
-
-    // Pre-allocate vectors to avoid reallocations
-    weightA_.reserve(adapterIdsVec.size());
-    weightATranspose_.reserve(adapterIdsVec.size());
-    weightB_.reserve(adapterIdsVec.size());
-
     aclWeightA_.reserve(adapterIdsVec.size());
     aclWeightB_.reserve(adapterIdsVec.size());
     aclWeightATranspose_.reserve(adapterIdsVec.size());
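
The retained CalculateWeightOffset helper's body is cut off in this diff; judging from its signature and the initial "uint64_t offset = 0;", it presumably sums the ranks of the adapters stored ahead of adapterId and scales by the per-rank tensor size. A minimal Python sketch of that assumed logic, not the actual C++ implementation:

# Hypothetical sketch: assumes the weights of all adapters are packed
# contiguously and each adapter occupies rank * tensor_size_per_rank bytes.
def calculate_weight_offset(ranks_vec, adapter_id, tensor_size_per_rank):
    return sum(ranks_vec[:adapter_id]) * tensor_size_per_rank

# Example: adapters with ranks [8, 16, 8]; adapter 2 starts after 8 + 16 ranks.
assert calculate_weight_offset([8, 16, 8], 2, 64) == 24 * 64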

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h

Lines changed: 0 additions & 6 deletions
@@ -28,19 +28,13 @@ class CustomFusedLoraOperation : public atb::Operation {
     void ClearAclScalrs();
     void ClearInternal();
 
-    // Helper functions for weight tensor creation and offset calculation
-    atb::Tensor CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset);
     uint64_t CalculateWeightOffset(const std::vector<int32_t>& ranksVec, size_t adapterId, uint64_t tensorSizePerRank);
 
 private:
     std::string opName_;
     std::string dtype_;
     std::vector<aclScalar*> aclScalingScalar_;
 
-    std::vector<atb::Tensor> weightA_;
-    std::vector<atb::Tensor> weightB_;
-    std::vector<atb::Tensor> weightATranspose_;
-
     std::vector<AclNnTensor> aclWeightA_;
     std::vector<AclNnTensor> aclWeightB_;
     std::vector<AclNnTensor> aclWeightATranspose_;
