diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
index 5c6333917dc54c..01f49149351617 100644
--- a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
+++ b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
@@ -58,7 +58,7 @@ class ov::pass::RoPEFusionChatGLMHF : public ov::pass::MatcherPass {
 class ov::pass::RoPEFusionQwen : public ov::pass::MatcherPass {
 public:
     OPENVINO_MATCHER_PASS_RTTI("RoPEFusionQwen");
-    RoPEFusionQwen(int split_output_id);
+    RoPEFusionQwen();
 };
 
 class ov::pass::RoPEFusionIOSlicing : public ov::pass::MatcherPass {
diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
index 60f3d2775829b5..672be530d9cb35 100644
--- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
@@ -71,9 +71,7 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
         symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(true);
         symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLMHF>();
     }
-    symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(0);
-    symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(1);
-
+    symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>();
     symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionIOSlicing>();
     return symbolic_optimizations.run_on_model(model);
 }
@@ -843,7 +841,7 @@ ov::pass::RoPEFusionChatGLMHF::RoPEFusionChatGLMHF() {
     this->register_matcher(m, callback);
 }
 
-ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
+ov::pass::RoPEFusionQwen::RoPEFusionQwen() {
     using namespace ov::op::util;
     MATCHER_SCOPE(RoPEFusionQwen);
 
@@ -857,10 +855,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         pattern::wrap_type<v1::VariadicSplit>({qkv_proj, 2, {"head_cnt*head_size", "head_cnt*head_size", "?"}});
     ListUnpack_410_VariadicSplit->set_output_size(3);
     // B,L,H,S
-    auto view_Reshape_424 =
-        pattern::wrap_type<v1::Reshape>({ListUnpack_410_VariadicSplit->output(split_output_id), pattern::any_input()},
-                                        pattern::shape_matches("[?, ?, head_cnt, head_size]"),
-                                        {{"special_zero", true}});
+    auto view_Reshape_424 = pattern::wrap_type<v1::Reshape>({ListUnpack_410_VariadicSplit, pattern::any_input()},
+                                                            pattern::shape_matches("[?, ?, head_cnt, head_size]"),
+                                                            {{"special_zero", true}});
     auto slice_Slice_543 = NewGenSlice(view_Reshape_424, 0, "head_size", 1, 3);
 
     auto ShapeOf_485735 = pattern::wrap_type<v3::ShapeOf>({pattern::any_input()}, {});
@@ -946,6 +943,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto head_size = symbols["head_size"];
     auto head_size_over_2 = symbols["head_size/2"];
     auto head_cnt_by_head_size = symbols["head_cnt*head_size"];
+
     if (!head_cnt.is_integer() || !head_size.is_integer() || !head_size_over_2.is_integer() ||
         !head_cnt_by_head_size.is_integer() || head_size_over_2.i() * 2 != head_size.i() ||
         head_cnt.i() * head_size.i() != head_cnt_by_head_size.i()) {
@@ -958,14 +956,18 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     config.head_size = static_cast<size_t>(head_size.i());
     config.rotary_ndims = config.head_size;
 
-    if (split_output_id == 0) {
-        // query : split_output_id == 0
+    const auto& qkv_proj_split_node = pattern_map.at(ListUnpack_410_VariadicSplit);
+    const size_t qkv_proj_split_id = qkv_proj_split_node.get_index();
+    if (qkv_proj_split_id == 0) {
+        // query : split output id == 0
         config.slice_start = 0;
         config.slice_stop = config.head_cnt * config.head_size;
-    } else {
-        // key : split_output_id == 1
+    } else if (qkv_proj_split_id == 1) {
+        // key : split output id == 1
         config.slice_start = config.head_cnt * config.head_size;
         config.slice_stop = config.slice_start + config.head_cnt * config.head_size;
+    } else {
+        return false;
     }
 
     new_args.push_back(pattern_map.at(qkv_proj));
diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
index d5b3f412232c9c..0e8da3858fe974 100644
--- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
+++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
@@ -1354,72 +1354,70 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
     }
 }
 
-TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
+TEST_P(ConvertToROPETest, ConvertToROPE_Qwen_PagedAttention) {
     using namespace ov;
-    {
-        auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
-        auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096});
-
-        auto qkv_proj = makeOP<ov::opset1::VariadicSplit>({qkv, 2, {4096, 4096, -1}});
-
-        auto view_Reshape = makeOP<ov::opset1::Reshape>({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
-        auto slice_Slice_4 = makeOP<ov::opset8::Slice>({view_Reshape, {0}, {128}, {1}, {3}});
-        auto slice_Slice = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
-
-        auto Convert_50535 = makeOP<ov::opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
-        auto Unsqueeze_23750 = makeOP<ov::opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
-
-        auto slice_Slice_1 = makeOP<ov::opset8::Gather>({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
-        auto Reshape_27400 = makeOP<ov::opset1::Reshape>({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}});
-
-        auto mul_Multiply = makeOP<ov::opset1::Multiply>({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}});
-        auto reshape_Reshape = makeOP<ov::opset1::Reshape>({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
-        auto ListUnpack_Split = makeOP<ov::opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
-        auto Multiply_54136 =
-            makeOP<ov::opset1::Multiply>({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
-        auto ListUnpack_Squeeze_0 =
-            makeOP<ov::opset1::Reshape>({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}});
-        auto ListUnpack_Squeeze =
-            makeOP<ov::opset1::Reshape>({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
-        auto cat_Concat = makeOP<ov::opset1::Concat>({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}});
-
-        auto slice_Slice_2 = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
-        auto slice_Slice_6 = makeOP<ov::opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
-        auto Reshape_27408 = makeOP<ov::opset1::Reshape>({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}});
-        auto mul_Multiply_1 = makeOP<ov::opset1::Multiply>({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}});
-        auto add_Add = makeOP<ov::opset1::Add>({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
+    constexpr int head_cnt = 32, head_size = 128;
+    int output_idx = GetParam();
 
-        auto slice_Slice_10 = makeConst(element::f32, ov::Shape({1, 32767, 1, 1}), {1});
-        auto view_Reshape_1 = makeOP<ov::opset1::Reshape>({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}});
-        auto slice_Slice_11 = makeOP<ov::opset8::Slice>({view_Reshape_1, {0}, {128}, {1}, {3}});
-        auto mul_Multiply_2 = makeOP<ov::opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}});
-        auto reshape_Reshape_1 = makeOP<ov::opset1::Reshape>({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
-        auto ListUnpack_Split_1 = makeOP<ov::opset1::Split>({reshape_Reshape_1, -2}, {{"num_splits", 2}});
-        auto Multiply_54139 =
-            makeOP<ov::opset1::Multiply>({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
-        auto ListUnpack_Squeeze_0_1 =
-            makeOP<ov::opset1::Reshape>({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}});
-        auto ListUnpack_Squeeze_1 =
-            makeOP<ov::opset1::Reshape>({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
-        auto cat_Concat_2 = makeOP<ov::opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}});
-        auto mul_Multiply_3 = makeOP<ov::opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}});
-        auto add_Add_1 = makeOP<ov::opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}});
-        model = std::make_shared<ov::Model>(ov::OutputVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
+    {
+        // Parameters
+        auto position_ids = std::make_shared<opset1::Parameter>(element::i64, PartialShape{-1, -1});
+        auto qkv = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 1, 3 * head_cnt * head_size});
+
+        // Split QKV and reshape to [batch, 1, head_cnt, head_size]
+        auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {head_cnt * head_size, head_cnt * head_size, -1}});
+        auto view = makeOP<opset1::Reshape>({qkv_proj->output(output_idx), {0, 0, head_cnt, head_size}},
+                                            {{"special_zero", true}});
+
+        // Slice out rotary dims
+        auto slice = makeOP<opset8::Slice>({view, {0}, {128}, {1}, {3}});
+
+        // Prepare rotary embedding table and gather by position
+        auto rotary_emp = makeConst(element::f32, {1, 4096, 1, 128}, {1});
+        auto pos_i32 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
+        auto pos_reshaped = makeOP<opset1::Reshape>({pos_i32, {-1, 1}}, {{"special_zero", false}});
+        auto gathered = makeOP<opset8::Gather>({rotary_emp, pos_reshaped, 1}, {{"batch_dims", 0}});
+        auto gathered_reshape = makeOP<opset1::Reshape>({gathered, {-1, 1, 1, 128}}, {{"special_zero", false}});
+
+        // Elementwise multiply
+        auto mul = makeOP<opset1::Multiply>({slice, gathered_reshape}, {{"auto_broadcast", "numpy"}});
+
+        // Interleave/stack for rotary
+        auto reshaped = makeOP<opset1::Reshape>({slice, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        auto split = makeOP<opset1::Split>({reshaped, -2}, {{"num_splits", 2}});
+        auto neg = makeOP<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});
+        auto squeeze0 = makeOP<opset1::Reshape>({neg, {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto squeeze1 = makeOP<opset1::Reshape>({split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto cat = makeOP<opset1::Concat>({squeeze0, squeeze1}, {{"axis", -1}});
+
+        // Second rotary embedding gather and multiply
+        auto rotary_emp2 = makeConst(element::f32, {1, 4096, 1, 128}, {1});
+        auto gathered2 = makeOP<opset8::Gather>({rotary_emp2, pos_reshaped, 1}, {{"batch_dims", 0}});
+        auto gathered2_reshape = makeOP<opset1::Reshape>({gathered2, {-1, 1, 1, 128}}, {{"special_zero", false}});
+        auto mul2 = makeOP<opset1::Multiply>({cat, gathered2_reshape}, {{"auto_broadcast", "numpy"}});
+
+        // Final add
+        auto add = makeOP<opset1::Add>({mul, mul2}, {{"auto_broadcast", "numpy"}});
+
+        model = std::make_shared<ov::Model>(OutputVector{add}, ParameterVector{position_ids, qkv});
     }
     manager.register_pass<ov::pass::RoPEFusion>(false);
     {
-        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3});
-        auto rotary_emp_sin = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
-        auto rotary_emp_cos = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
-        auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
-        auto Convert_50535 = makeOP<ov::opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
-        auto Unsqueeze_23750 = makeOP<ov::opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
-        auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, Unsqueeze_23750},
-                                                   {{"config.slice_start", 4096},
-                                                    {"config.slice_stop", 8192},
+        int slice_start = output_idx == 0 ? 0 : head_cnt * head_size;
+        int slice_stop = slice_start + head_cnt * head_size;
+
+        auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape{-1, 1, 4096 * 3});
+        auto rotary_emp_sin = makeConst(element::f32, {1, 4096, 1, 128}, {1});
+        auto rotary_emp_cos = makeConst(element::f32, {1, 4096, 1, 128}, {1});
+        auto position_ids = std::make_shared<opset1::Parameter>(element::i64, PartialShape{-1, -1});
+        auto pos_i32 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
+        auto pos_reshaped = makeOP<opset1::Reshape>({pos_i32, {-1, 1}}, {{"special_zero", false}});
+        auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos, pos_reshaped},
+                                                   {{"config.slice_start", slice_start},
+                                                    {"config.slice_stop", slice_stop},
                                                     {"config.input_trans0213", false},
                                                     {"config.output_trans0213", false},
                                                     {"config.is_interleaved", false},
@@ -1428,11 +1426,12 @@ TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
                                                     {"config.support_2d_rope", false},
                                                     {"config.is_qwen", true},
                                                     {"config.use_rope_cache", false},
-                                                    {"config.head_cnt", 32},
-                                                    {"config.head_size", 128},
+                                                    {"config.head_cnt", head_cnt},
+                                                    {"config.head_size", head_size},
                                                     {"config.gather_position_arg_id", 3}});
-        model_ref = std::make_shared<ov::Model>(ov::OutputVector{rope}, ov::ParameterVector{input, position_ids});
+        model_ref = std::make_shared<ov::Model>(OutputVector{rope}, ParameterVector{input, position_ids});
     }
+    comparator.enable(FunctionsComparator::ATTRIBUTES);
 }
 
 TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
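
Note: the hunks above switch the test to TEST_P(ConvertToROPETest, ...) and read GetParam(), but the fixture declaration and its instantiation are outside the shown context. A minimal sketch of what the test file presumably declares elsewhere, assuming the parameter is the VariadicSplit output index (0 = query, 1 = key); everything here except the ConvertToROPETest name is illustrative, not taken from the patch:

    // Hypothetical fixture: TransformationTestsF extended with an int parameter
    // carrying the QKV split output index exercised by the test above.
    class ConvertToROPETest : public TransformationTestsF, public ::testing::WithParamInterface<int> {};

    // Run the test once per split output: 0 (query) and 1 (key). Output 2 (value)
    // is deliberately not covered, since RoPE is applied only to Q and K.
    INSTANTIATE_TEST_SUITE_P(ConvertToROPE, ConvertToROPETest, ::testing::Values(0, 1));

This mirrors how the reference model picks slice_start/slice_stop from output_idx, so both branches of the new else-if in the matcher callback get exercised.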
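For context, a sketch of how the fused pass is typically applied outside the test harness, using standard ov::pass::Manager calls; the boolean mirrors the manager.register_pass<ov::pass::RoPEFusion>(false) call in the test and is assumed to be the support_2d_rope flag:

    #include "openvino/pass/manager.hpp"
    #include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

    void fuse_rope(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // After this patch a single RoPEFusionQwen registration covers both the
        // query and key branches: the matcher reads the VariadicSplit output index
        // from the matched Output<Node> instead of being constructed with a fixed
        // split_output_id.
        manager.register_pass<ov::pass::RoPEFusion>(false);
        manager.run_passes(model);
    }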