Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ov::pass::RoPEFusionChatGLMHF : public ov::pass::MatcherPass {
// Matcher pass for the Qwen flavour of the RoPE subgraph. It is registered by
// ov::pass::RoPEFusion::run_on_model together with the other RoPEFusion* matchers.
// The pattern is anchored on a 3-way VariadicSplit of the fused QKV projection whose
// selected output is reshaped to [?, ?, head_cnt, head_size].
//
// NOTE(review): this listing is a rendered diff without +/- markers, so two constructor
// declarations appear below. The one taking `split_output_id` looks like the pre-change
// declaration that the parameterless one replaces -- confirm against the real header.
class ov::pass::RoPEFusionQwen : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("RoPEFusionQwen");
// Builds the matcher for one fixed VariadicSplit output: 0 selects the query slice,
// 1 selects the key slice (the pass was registered once per index).
RoPEFusionQwen(int split_output_id);
// Builds a matcher that accepts any VariadicSplit output; the callback reads the index
// of the matched output to choose the query (0) or key (1) slice range and rejects
// any other index.
RoPEFusionQwen();
};

class ov::pass::RoPEFusionIOSlicing : public ov::pass::MatcherPass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(true);
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLMHF>();
}
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(0);
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(1);

symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>();
symbolic_ctx_manager->register_pass<ov::pass::RoPEShareCosSin>();
return symbolic_optimizations.run_on_model(model);
}
Expand Down Expand Up @@ -843,7 +841,7 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
this->register_matcher(m, callback);
}

ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
ov::pass::RoPEFusionQwen::RoPEFusionQwen() {
using namespace ov::op::util;
MATCHER_SCOPE(RoPEFusionQwen);

Expand All @@ -857,10 +855,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
pattern::wrap_type<v1::VariadicSplit>({qkv_proj, 2, {"head_cnt*head_size", "head_cnt*head_size", "?"}});
ListUnpack_410_VariadicSplit->set_output_size(3);
// B,L,H,S
auto view_Reshape_424 =
pattern::wrap_type<v1::Reshape>({ListUnpack_410_VariadicSplit->output(split_output_id), pattern::any_input()},
pattern::shape_matches("[?, ?, head_cnt, head_size]"),
{{"special_zero", true}});
auto view_Reshape_424 = pattern::wrap_type<v1::Reshape>({ListUnpack_410_VariadicSplit, pattern::any_input()},
pattern::shape_matches("[?, ?, head_cnt, head_size]"),
{{"special_zero", true}});
auto slice_Slice_543 = NewGenSlice(view_Reshape_424, 0, "head_size", 1, 3);

auto ShapeOf_485735 = pattern::wrap_type<ov::op::util::ShapeOfBase>({pattern::any_input()}, {});
Expand Down Expand Up @@ -946,6 +943,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto head_size = symbols["head_size"];
auto head_size_over_2 = symbols["head_size/2"];
auto head_cnt_by_head_size = symbols["head_cnt*head_size"];

if (!head_cnt.is_integer() || !head_size.is_integer() || !head_size_over_2.is_integer() ||
!head_cnt_by_head_size.is_integer() || head_size_over_2.i() * 2 != head_size.i() ||
head_cnt.i() * head_size.i() != head_cnt_by_head_size.i()) {
Expand All @@ -958,14 +956,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
config.head_size = static_cast<size_t>(head_size.i());
config.rotary_ndims = config.head_size;

if (split_output_id == 0) {
// query : split_output_id == 0
const auto& qkv_proj_split_node = pattern_map.at(ListUnpack_410_VariadicSplit);
const size_t qkv_proj_split_id = qkv_proj_split_node.get_index();
if (qkv_proj_split_id == 0) {
// query : split output id == 0
config.slice_start = 0;
config.slice_stop = config.head_cnt * config.head_size;
} else {
// key : split_output_id == 1
;
} else if (qkv_proj_split_id == 1) {
// key : split output id == 1
config.slice_start = config.head_cnt * config.head_size;
config.slice_stop = config.slice_start + config.head_cnt * config.head_size;
} else {
return false;
}

new_args.push_back(pattern_map.at(qkv_proj));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1354,72 +1354,70 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
}
}

TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
TEST_P(ConvertToROPETest, ConvertToROPE_Qwen_PagedAttention) {
using namespace ov;

{
auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096});

auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {4096, 4096, -1}});

auto view_Reshape = makeOP<opset1::Reshape>({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
auto slice_Slice_4 = makeOP<opset8::Slice>({view_Reshape, {0}, {128}, {1}, {3}});
auto slice_Slice = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});

auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});

auto slice_Slice_1 = makeOP<opset8::Gather>({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
auto Reshape_27400 = makeOP<opset1::Reshape>({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}});

auto mul_Multiply = makeOP<opset1::Multiply>({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}});
auto reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
auto ListUnpack_Split = makeOP<opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
auto Multiply_54136 =
makeOP<opset1::Multiply>({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
auto ListUnpack_Squeeze_0 =
makeOP<opset1::Reshape>({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}});
auto ListUnpack_Squeeze =
makeOP<opset1::Reshape>({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
auto cat_Concat = makeOP<opset1::Concat>({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}});

auto slice_Slice_2 = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
auto slice_Slice_6 = makeOP<opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
auto Reshape_27408 = makeOP<opset1::Reshape>({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}});
auto mul_Multiply_1 = makeOP<opset1::Multiply>({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}});
auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
constexpr int head_cnt = 32, head_size = 128;
int output_idx = GetParam();

auto slice_Slice_10 = makeConst(element::f32, ov::Shape({1, 32767, 1, 1}), {1});
auto view_Reshape_1 = makeOP<opset1::Reshape>({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}});
auto slice_Slice_11 = makeOP<opset8::Slice>({view_Reshape_1, {0}, {128}, {1}, {3}});
auto mul_Multiply_2 = makeOP<opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}});
auto reshape_Reshape_1 = makeOP<opset1::Reshape>({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
auto ListUnpack_Split_1 = makeOP<opset1::Split>({reshape_Reshape_1, -2}, {{"num_splits", 2}});
auto Multiply_54139 =
makeOP<opset1::Multiply>({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
auto ListUnpack_Squeeze_0_1 =
makeOP<opset1::Reshape>({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}});
auto ListUnpack_Squeeze_1 =
makeOP<opset1::Reshape>({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
auto cat_Concat_2 = makeOP<opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}});
auto mul_Multiply_3 = makeOP<opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}});
auto add_Add_1 = makeOP<opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}});
model = std::make_shared<ov::Model>(ov::OutputVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
{
// Parameters
auto position_ids = std::make_shared<opset1::Parameter>(element::i64, PartialShape{-1, -1});
auto qkv = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 1, 3 * head_cnt * head_size});

// Split QKV and reshape to [batch, 1, head_cnt, head_size]
auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {head_cnt * head_size, head_cnt * head_size, -1}});
auto view = makeOP<opset1::Reshape>({qkv_proj->output(output_idx), {0, 0, head_cnt, head_size}},
{{"special_zero", true}});

// Slice out rotary dims
auto slice = makeOP<opset8::Slice>({view, {0}, {128}, {1}, {3}});

// Prepare rotary embedding table and gather by position
auto rotary_emp = makeConst(element::f32, {1, 4096, 1, 128}, {1});
auto pos_i32 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
auto pos_reshaped = makeOP<opset1::Reshape>({pos_i32, {-1, 1}}, {{"special_zero", false}});
auto gathered = makeOP<opset8::Gather>({rotary_emp, pos_reshaped, 1}, {{"batch_dims", 0}});
auto gathered_reshape = makeOP<opset1::Reshape>({gathered, {-1, 1, 1, 128}}, {{"special_zero", false}});

// Elementwise multiply
auto mul = makeOP<opset1::Multiply>({slice, gathered_reshape}, {{"auto_broadcast", "numpy"}});

// Interleave/stack for rotary
auto reshaped = makeOP<opset1::Reshape>({slice, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
auto split = makeOP<opset1::Split>({reshaped, -2}, {{"num_splits", 2}});
auto neg = makeOP<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});
auto squeeze0 = makeOP<opset1::Reshape>({neg, {-1, 1, 32, 64}}, {{"special_zero", false}});
auto squeeze1 = makeOP<opset1::Reshape>({split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
auto cat = makeOP<opset1::Concat>({squeeze0, squeeze1}, {{"axis", -1}});

// Second rotary embedding gather and multiply
auto rotary_emp2 = makeConst(element::f32, {1, 4096, 1, 128}, {1});
auto gathered2 = makeOP<opset8::Gather>({rotary_emp2, pos_reshaped, 1}, {{"batch_dims", 0}});
auto gathered2_reshape = makeOP<opset1::Reshape>({gathered2, {-1, 1, 1, 128}}, {{"special_zero", false}});
auto mul2 = makeOP<opset1::Multiply>({cat, gathered2_reshape}, {{"auto_broadcast", "numpy"}});

// Final add
auto add = makeOP<opset1::Add>({mul, mul2}, {{"auto_broadcast", "numpy"}});

model = std::make_shared<Model>(OutputVector{add}, ParameterVector{position_ids, qkv});
}

manager.register_pass<ov::pass::RoPEFusion>(false);

{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3});
auto rotary_emp_sin = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
auto rotary_emp_cos = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, Unsqueeze_23750},
{{"config.slice_start", 4096},
{"config.slice_stop", 8192},
int slice_start = output_idx == 0 ? 0 : head_cnt * head_size;
int slice_stop = slice_start + head_cnt * head_size;

auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 1, 4096 * 3});
auto rotary_emp_sin = makeConst(element::f32, {1, 4096, 1, 128}, {1});
auto rotary_emp_cos = makeConst(element::f32, {1, 4096, 1, 128}, {1});
auto position_ids = std::make_shared<opset1::Parameter>(element::i64, PartialShape{-1, -1});
auto pos_i32 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
auto pos_reshaped = makeOP<opset1::Reshape>({pos_i32, {-1, 1}}, {{"special_zero", false}});
auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, pos_reshaped},
{{"config.slice_start", slice_start},
{"config.slice_stop", slice_stop},
{"config.input_trans0213", false},
{"config.output_trans0213", false},
{"config.is_interleaved", false},
Expand All @@ -1428,11 +1426,12 @@ TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
{"config.support_2d_rope", false},
{"config.is_qwen", true},
{"config.use_rope_cache", false},
{"config.head_cnt", 32},
{"config.head_size", 128},
{"config.head_cnt", head_cnt},
{"config.head_size", head_size},
{"config.gather_position_arg_id", 3}});
model_ref = std::make_shared<ov::Model>(ov::OutputVector{rope}, ov::ParameterVector{input, position_ids});
model_ref = std::make_shared<Model>(OutputVector{rope}, ParameterVector{input, position_ids});
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
Expand Down
Loading