Skip to content

Commit 6adfcf7

Browse files
committed
is_reference_to 优化, 允许从 table 查找到 class
1 parent 7317314 commit 6adfcf7

File tree

5 files changed

+223
-19
lines changed

5 files changed

+223
-19
lines changed

crates/emmylua_code_analysis/src/semantic/reference/mod.rs

Lines changed: 139 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
use emmylua_parser::LuaSyntaxNode;
1+
use std::collections::HashSet;
22

3-
use crate::{DbIndex, LuaMemberId, LuaSemanticDeclId};
3+
use emmylua_parser::{
4+
LuaAssignStat, LuaAstNode, LuaSyntaxKind, LuaSyntaxNode, LuaTableExpr, LuaTableField,
5+
};
46

5-
use super::{semantic_info::infer_node_semantic_decl, LuaInferCache, SemanticDeclLevel};
7+
use crate::{DbIndex, LuaMemberId, LuaMemberKey, LuaSemanticDeclId, LuaType};
8+
9+
use super::{
10+
infer_table_should_be, member::find_members, semantic_info::infer_node_semantic_decl,
11+
LuaInferCache, SemanticDeclLevel,
12+
};
613

714
pub fn is_reference_to(
815
db: &DbIndex,
@@ -16,9 +23,13 @@ pub fn is_reference_to(
1623
return Some(true);
1724
}
1825

19-
match (node_semantic_decl_id, semantic_decl) {
26+
match (node_semantic_decl_id, &semantic_decl) {
2027
(LuaSemanticDeclId::Member(node_member_id), LuaSemanticDeclId::Member(member_id)) => {
21-
is_member_reference_to(db, node_member_id, member_id)
28+
if let Some(true) = is_member_reference_to(db, node_member_id, *member_id) {
29+
return Some(true);
30+
}
31+
32+
is_member_origin_reference_to(db, infer_config, node_member_id, semantic_decl)
2233
}
2334
_ => Some(false),
2435
}
@@ -34,3 +45,126 @@ fn is_member_reference_to(
3445

3546
Some(node_owner == owner)
3647
}
48+
49+
fn is_member_origin_reference_to(
50+
db: &DbIndex,
51+
infer_config: &mut LuaInferCache,
52+
node_member_id: LuaMemberId,
53+
semantic_decl: LuaSemanticDeclId,
54+
) -> Option<bool> {
55+
let node_origin = find_member_origin_owner(db, infer_config, node_member_id)
56+
.unwrap_or(LuaSemanticDeclId::Member(node_member_id));
57+
58+
match (node_origin, semantic_decl) {
59+
(LuaSemanticDeclId::Member(node_owner), LuaSemanticDeclId::Member(member_owner)) => {
60+
is_member_reference_to(db, node_owner, member_owner)
61+
}
62+
(node_origin, member_origin) => Some(node_origin == member_origin),
63+
}
64+
}
65+
66+
pub fn find_member_origin_owner(
67+
db: &DbIndex,
68+
infer_config: &mut LuaInferCache,
69+
member_id: LuaMemberId,
70+
) -> Option<LuaSemanticDeclId> {
71+
const MAX_ITERATIONS: usize = 50;
72+
let mut visited_members = HashSet::new();
73+
74+
let mut current_owner = resolve_member_owner(db, infer_config, &member_id);
75+
let mut final_owner = current_owner.clone();
76+
let mut iteration_count = 0;
77+
78+
while let Some(LuaSemanticDeclId::Member(current_member_id)) = &current_owner {
79+
if visited_members.contains(current_member_id) || iteration_count >= MAX_ITERATIONS {
80+
break;
81+
}
82+
83+
visited_members.insert(current_member_id.clone());
84+
iteration_count += 1;
85+
86+
match resolve_member_owner(db, infer_config, current_member_id) {
87+
Some(next_owner) => {
88+
final_owner = Some(next_owner.clone());
89+
current_owner = Some(next_owner);
90+
}
91+
None => break,
92+
}
93+
}
94+
95+
final_owner
96+
}
97+
98+
fn resolve_member_owner(
99+
db: &DbIndex,
100+
infer_config: &mut LuaInferCache,
101+
member_id: &LuaMemberId,
102+
) -> Option<LuaSemanticDeclId> {
103+
let root = db
104+
.get_vfs()
105+
.get_syntax_tree(&member_id.file_id)?
106+
.get_red_root();
107+
let current_node = member_id.get_syntax_id().to_node_from_root(&root)?;
108+
match member_id.get_syntax_id().get_kind() {
109+
LuaSyntaxKind::TableFieldAssign => {
110+
if LuaTableField::can_cast(current_node.kind().into()) {
111+
let table_field = LuaTableField::cast(current_node.clone())?;
112+
// 如果表是类, 那么通过类型推断获取 owner
113+
if let Some(owner_id) =
114+
resolve_table_field_through_type_inference(db, infer_config, &table_field)
115+
{
116+
return Some(owner_id);
117+
}
118+
// 非类, 那么通过右值推断
119+
let value_expr = table_field.get_value_expr()?;
120+
let value_node = value_expr.get_syntax_id().to_node_from_root(&root)?;
121+
infer_node_semantic_decl(db, infer_config, value_node, SemanticDeclLevel::default())
122+
} else {
123+
None
124+
}
125+
}
126+
LuaSyntaxKind::IndexExpr => {
127+
let assign_node = current_node.parent()?;
128+
let assign_stat = LuaAssignStat::cast(assign_node)?;
129+
let (vars, exprs) = assign_stat.get_var_and_expr_list();
130+
131+
for (var, expr) in vars.iter().zip(exprs.iter()) {
132+
if var.syntax().text_range() == current_node.text_range() {
133+
let expr_node = expr.get_syntax_id().to_node_from_root(&root)?;
134+
return infer_node_semantic_decl(
135+
db,
136+
infer_config,
137+
expr_node,
138+
SemanticDeclLevel::default(),
139+
);
140+
}
141+
}
142+
None
143+
}
144+
_ => None,
145+
}
146+
}
147+
148+
fn resolve_table_field_through_type_inference(
149+
db: &DbIndex,
150+
infer_config: &mut LuaInferCache,
151+
table_field: &LuaTableField,
152+
) -> Option<LuaSemanticDeclId> {
153+
let parent = table_field.syntax().parent()?;
154+
let table_expr = LuaTableExpr::cast(parent)?;
155+
let table_type = infer_table_should_be(db, infer_config, table_expr).ok()?;
156+
157+
if !matches!(table_type, LuaType::Ref(_) | LuaType::Def(_)) {
158+
return None;
159+
}
160+
161+
let field_key = table_field.get_field_key()?;
162+
let key = LuaMemberKey::from_index_key(db, infer_config, &field_key).ok()?;
163+
let member_infos = find_members(db, &table_type)?;
164+
165+
member_infos
166+
.iter()
167+
.find(|m| m.key == key)?
168+
.property_owner_id
169+
.clone()
170+
}

crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ pub fn goto_str_tpl_ref_definition(
179179
None
180180
}
181181

182-
fn find_table_member_definition(
182+
pub fn find_table_member_definition(
183183
semantic_model: &SemanticModel,
184184
trigger_token: &LuaSyntaxToken,
185185
member_key: &LuaMemberKey,

crates/emmylua_ls/src/handlers/implementation/implementation_searcher.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ use emmylua_code_analysis::{
55
SemanticDeclLevel, SemanticModel,
66
};
77
use emmylua_parser::{
8-
LuaAstNode, LuaDocTagField, LuaIndexExpr, LuaStat, LuaSyntaxNode, LuaSyntaxToken,
8+
LuaAstNode, LuaDocTagField, LuaIndexExpr, LuaStat, LuaSyntaxNode, LuaSyntaxToken, LuaTableField,
99
};
1010
use lsp_types::Location;
1111

12+
use crate::handlers::hover::find_member_origin_owner;
13+
1214
pub fn search_implementations(
1315
semantic_model: &SemanticModel,
1416
compilation: &LuaCompilation,
@@ -45,15 +47,17 @@ pub fn search_member_implementations(
4547
.get_db()
4648
.get_member_index()
4749
.get_member(&member_id)?;
48-
let key = member.get_key();
50+
let member_key = member.get_key();
51+
4952
let index_references = semantic_model
5053
.get_db()
5154
.get_reference_index()
52-
.get_index_references(&key)?;
55+
.get_index_references(&member_key)?;
5356

5457
let mut semantic_cache = HashMap::new();
5558

56-
let property_owner = LuaSemanticDeclId::Member(member_id);
59+
let property_owner = find_member_origin_owner(semantic_model, member_id)
60+
.unwrap_or(LuaSemanticDeclId::Member(member_id));
5761
for in_filed_syntax_id in index_references {
5862
let semantic_model =
5963
if let Some(semantic_model) = semantic_cache.get_mut(&in_filed_syntax_id.file_id) {
@@ -94,6 +98,21 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) -
9498
let prefix_type = semantic_model
9599
.infer_expr(expr.get_prefix_expr()?.into())
96100
.ok()?;
101+
// TODO: 需要实现更复杂的逻辑, 即当为`Ref`时, 针对指定的实例定义到其实现
102+
/*
103+
---@class A
104+
---@field a number -- 这里寻找实现只匹配到`A.a`, 不能穿透到`a.a`与`b.a`
105+
local A = {}
106+
A.a = 1
107+
108+
---@type A
109+
local a = {}
110+
a.a = 1 -- 这里寻找实现不能匹配到`b.a`
111+
112+
---@type A
113+
local b = a
114+
b.a = 2 -- 这里寻找实现不能匹配到`a.a`
115+
*/
97116
match prefix_type {
98117
LuaType::Ref(_) => {
99118
return None;
@@ -128,6 +147,14 @@ fn check_member_reference(semantic_model: &SemanticModel, node: LuaSyntaxNode) -
128147
tag_field_node if LuaDocTagField::can_cast(tag_field_node.kind().into()) => {
129148
return Some(());
130149
}
150+
table_field_node if LuaTableField::can_cast(table_field_node.kind().into()) => {
151+
let table_field = LuaTableField::cast(table_field_node.clone())?;
152+
if table_field.is_assign_field() {
153+
return Some(());
154+
} else {
155+
return None;
156+
}
157+
}
131158
_ => {}
132159
}
133160

crates/emmylua_ls/src/handlers/test/implementation_test.rs

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ mod tests {
2020
delete()
2121
"#,
2222
);
23-
24-
ws.check_implementation(
23+
assert!(ws.check_implementation(
2524
r#"
2625
local M = {}
2726
function M.de<??>lete(a)
2827
end
2928
return M
3029
"#,
31-
);
30+
1,
31+
));
3232
}
3333

3434
#[test]
@@ -57,11 +57,12 @@ mod tests {
5757
local a = test.a
5858
"#,
5959
);
60-
ws.check_implementation(
60+
assert!(ws.check_implementation(
6161
r#"
6262
t<??>est
6363
"#,
64-
);
64+
3,
65+
));
6566
}
6667

6768
#[test]
@@ -82,10 +83,49 @@ mod tests {
8283
8384
"#,
8485
);
85-
ws.check_implementation(
86+
assert!(ws.check_implementation(
8687
r#"
8788
yyy.<??>a = 2
8889
"#,
89-
);
90+
3,
91+
));
92+
}
93+
94+
#[test]
95+
fn test_table_field_definition_1() {
96+
let mut ws = ProviderVirtualWorkspace::new();
97+
assert!(ws.check_implementation(
98+
r#"
99+
---@class T
100+
---@field func fun(self: T) 注释注释
101+
102+
---@type T
103+
local t = {
104+
func = function(self)
105+
end,
106+
}
107+
108+
t:fun<??>c()
109+
"#,
110+
2,
111+
));
112+
}
113+
114+
#[test]
115+
fn test_table_field_definition_2() {
116+
let mut ws = ProviderVirtualWorkspace::new();
117+
assert!(ws.check_implementation(
118+
r#"
119+
---@class T
120+
---@field func fun(self: T) 注释注释
121+
122+
---@type T
123+
local t = {
124+
f<??>unc = function(self)
125+
end,
126+
}
127+
"#,
128+
2,
129+
));
90130
}
91131
}

crates/emmylua_ls/src/handlers/test_lib/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ impl ProviderVirtualWorkspace {
236236
true
237237
}
238238

239-
pub fn check_implementation(&mut self, block_str: &str) -> bool {
239+
pub fn check_implementation(&mut self, block_str: &str, len: usize) -> bool {
240240
let content = Self::handle_file_content(block_str);
241241
let Some((content, position)) = content else {
242242
return false;
@@ -251,7 +251,10 @@ impl ProviderVirtualWorkspace {
251251
return false;
252252
};
253253
dbg!(&implementations.len());
254-
true
254+
if implementations.len() == len {
255+
return true;
256+
}
257+
false
255258
}
256259

257260
pub fn check_definition(&mut self, block_str: &str) -> bool {

0 commit comments

Comments
 (0)