diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs index 247121cd2..defc66dc2 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/property_tags.rs @@ -11,7 +11,7 @@ use crate::compilation::analyzer::doc::tags::report_orphan_tag; use emmylua_parser::{ LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTagAsync, LuaDocTagDeprecated, LuaDocTagExport, LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion, - LuaDocTagVisibility, LuaTableExpr, + LuaDocTagVisibility, LuaExpr, }; pub fn analyze_visibility( @@ -121,11 +121,27 @@ pub fn analyze_export(analyzer: &mut DocAnalyzer, tag: LuaDocTagExport) -> Optio }; let owner_id = match owner { LuaAst::LuaReturnStat(return_stat) => { - let return_table_expr = return_stat.child::()?; - LuaSemanticDeclId::LuaDecl(LuaDeclId::new( - analyzer.file_id, - return_table_expr.get_position(), - )) + let expr = return_stat.child::()?; + match expr { + LuaExpr::NameExpr(name_expr) => { + let name = name_expr.get_name_text()?; + let tree = analyzer + .db + .get_decl_index() + .get_decl_tree(&analyzer.file_id)?; + let decl = tree.find_local_decl(&name, name_expr.get_position())?; + + Some(LuaSemanticDeclId::LuaDecl(decl.get_id())) + } + LuaExpr::ClosureExpr(closure) => Some(LuaSemanticDeclId::Signature( + LuaSignatureId::from_closure(analyzer.file_id, &closure), + )), + LuaExpr::TableExpr(table_expr) => Some(LuaSemanticDeclId::LuaDecl(LuaDeclId::new( + analyzer.file_id, + table_expr.get_position(), + ))), + _ => None, + }? } _ => get_owner_id_or_report(analyzer, &tag)?, }; @@ -134,10 +150,10 @@ pub fn analyze_export(analyzer: &mut DocAnalyzer, tag: LuaDocTagExport) -> Optio match scope_text.as_str() { "namespace" => LuaExportScope::Namespace, "global" => LuaExportScope::Global, - _ => LuaExportScope::Global, // 默认为 global + _ => LuaExportScope::Default, } } else { - LuaExportScope::Global // 没有参数时默认为 global + LuaExportScope::Default }; let export = LuaExport { diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index f70074880..0474e4630 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -102,4 +102,28 @@ mod test { assert_eq!(e_ty, LuaType::Integer); assert_eq!(f_ty, LuaType::Integer); } + + #[test] + fn test_keyof() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class SuiteHooks + ---@field beforeAll string + ---@field afterAll number + + ---@type SuiteHooks + local hooks = {} + + ---@type keyof SuiteHooks + local name = "beforeAll" + + A = hooks[name] + "#, + ); + + let ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(ty), "(number|string)"); + } } diff --git a/crates/emmylua_code_analysis/src/config/mod.rs b/crates/emmylua_code_analysis/src/config/mod.rs index be6d2ee2b..385e20efc 100644 --- a/crates/emmylua_code_analysis/src/config/mod.rs +++ b/crates/emmylua_code_analysis/src/config/mod.rs @@ -20,6 +20,7 @@ use regex::Regex; use rowan::NodeCache; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::process::Command; #[derive(Serialize, Deserialize, Debug, JsonSchema, Default, Clone)] #[serde(rename_all = "camelCase")] @@ -197,9 +198,26 @@ fn replace_placeholders(input: &str, workspace_folder: &str) -> String { workspace_folder.to_string() } else if let Some(env_name) = key.strip_prefix("env:") { std::env::var(env_name).unwrap_or_default() + } else if key == "luarocks" { + get_luarocks_deploy_dir() } else { caps[0].to_string() } }) .to_string() } + +fn get_luarocks_deploy_dir() -> String { + Command::new("luarocks") + .args(["config", "deploy_lua_dir"]) + .output() + .ok() + .and_then(|output| { + if output.status.success() { + Some(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + None + } + }) + .unwrap_or_default() +} diff --git a/crates/emmylua_code_analysis/src/db_index/property/property.rs b/crates/emmylua_code_analysis/src/db_index/property/property.rs index e94cd9bf5..ae51e9969 100644 --- a/crates/emmylua_code_analysis/src/db_index/property/property.rs +++ b/crates/emmylua_code_analysis/src/db_index/property/property.rs @@ -127,6 +127,7 @@ pub enum LuaDeprecated { #[derive(Debug, Clone, PartialEq, Eq)] pub enum LuaExportScope { + Default, // 默认声明, 会根据配置文件作不同的处理. Global, Namespace, } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_export.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_export.rs new file mode 100644 index 000000000..ea8f6b1ba --- /dev/null +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_export.rs @@ -0,0 +1,240 @@ +use std::collections::HashSet; + +use emmylua_parser::{LuaAst, LuaAstNode, LuaCallExpr, LuaIndexExpr, LuaVarExpr}; + +use crate::{ + DiagnosticCode, LuaSemanticDeclId, LuaType, ModuleInfo, SemanticDeclLevel, SemanticModel, + parse_require_module_info, +}; + +use super::{Checker, DiagnosticContext, check_field, humanize_lint_type}; + +pub struct CheckExportChecker; + +impl Checker for CheckExportChecker { + const CODES: &[DiagnosticCode] = &[DiagnosticCode::InjectField, DiagnosticCode::UndefinedField]; + + fn check(context: &mut DiagnosticContext, semantic_model: &SemanticModel) { + let root = semantic_model.get_root().clone(); + let mut checked_index_expr = HashSet::new(); + for node in root.descendants::() { + match node { + LuaAst::LuaAssignStat(assign) => { + let (vars, _) = assign.get_var_and_expr_list(); + for var in vars.iter() { + if let LuaVarExpr::IndexExpr(index_expr) = var { + checked_index_expr.insert(index_expr.syntax().clone()); + check_export_index_expr( + context, + semantic_model, + index_expr, + DiagnosticCode::InjectField, + ); + } + } + } + LuaAst::LuaIndexExpr(index_expr) => { + if checked_index_expr.contains(index_expr.syntax()) { + continue; + } + check_export_index_expr( + context, + semantic_model, + &index_expr, + DiagnosticCode::UndefinedField, + ); + } + _ => {} + } + } + } +} + +fn check_export_index_expr( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + index_expr: &LuaIndexExpr, + code: DiagnosticCode, +) -> Option<()> { + let db = context.db; + let prefix_expr = index_expr.get_prefix_expr()?; + let prefix_info = semantic_model.get_semantic_info(prefix_expr.syntax().clone().into())?; + let prefix_typ = prefix_info.typ.clone(); + + // `check_export` 仅需要处理 `TableConst, 其它类型由 `check_field` 负责. + let LuaType::TableConst(table_const) = &prefix_typ else { + return Some(()); + }; + + let index_key = index_expr.get_index_key()?; + + // 检查该表是否为导入的表. + if let Some(module_info) = check_require_table_const_with_export(semantic_model, index_expr) { + if code == DiagnosticCode::InjectField { + // 检查字段定义是否来自导入的表. + if let Some(info) = semantic_model.get_semantic_info(index_expr.syntax().clone().into()) + && is_cross_file_member_from_imported_export_table_const( + module_info, + info.semantic_decl, + ) + { + let index_name = index_key.get_path_part(); + context.add_diagnostic( + DiagnosticCode::InjectField, + index_key.get_range()?, + t!( + "Fields cannot be injected into the reference of `%{class}` for `%{field}`. ", + class = humanize_lint_type(db, &prefix_typ), + field = index_name, + ) + .to_string(), + None, + ); + return Some(()); + } + } + + if check_field::is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key, code) + .is_some() + { + return Some(()); + } + + let index_name = index_key.get_path_part(); + match code { + DiagnosticCode::InjectField => { + context.add_diagnostic( + DiagnosticCode::InjectField, + index_key.get_range()?, + t!( + "Fields cannot be injected into the reference of `%{class}` for `%{field}`. ", + class = humanize_lint_type(db, &prefix_typ), + field = index_name, + ) + .to_string(), + None, + ); + } + DiagnosticCode::UndefinedField => { + context.add_diagnostic( + DiagnosticCode::UndefinedField, + index_key.get_range()?, + t!("Undefined field `%{field}`. ", field = index_name,).to_string(), + None, + ); + } + _ => {} + } + + return Some(()); + } + + // 不是导入表, 且定义位于当前文件中, 则尝试检查本地表. + if code != DiagnosticCode::UndefinedField && table_const.file_id != semantic_model.get_file_id() + { + return Some(()); + } + + let Some(LuaSemanticDeclId::LuaDecl(decl_id)) = prefix_info.semantic_decl else { + return Some(()); + }; + // 必须为 local 声明 + let decl = semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id)?; + if !decl.is_local() { + return Some(()); + } + // 且该声明标记了 `export` + let property = semantic_model + .get_db() + .get_property_index() + .get_property(&decl_id.into())?; + if property.export().is_none() { + return Some(()); + } + + if check_field::is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key, code) + .is_some() + { + return Some(()); + } + + let index_name = index_key.get_path_part(); + context.add_diagnostic( + DiagnosticCode::UndefinedField, + index_key.get_range()?, + t!("Undefined field `%{field}`. ", field = index_name,).to_string(), + None, + ); + + Some(()) +} + +fn check_require_table_const_with_export<'a>( + semantic_model: &'a SemanticModel, + index_expr: &LuaIndexExpr, +) -> Option<&'a ModuleInfo> { + // 获取前缀表达式的语义信息 + let prefix_expr = index_expr.get_prefix_expr()?; + if let Some(call_expr) = LuaCallExpr::cast(prefix_expr.syntax().clone()) { + let module_info = parse_require_expr_module_info(semantic_model, &call_expr)?; + if module_info.is_export(semantic_model.get_db()) { + return Some(module_info); + } + } + + let semantic_decl_id = semantic_model.find_decl( + prefix_expr.syntax().clone().into(), + SemanticDeclLevel::NoTrace, + )?; + // 检查是否是声明引用 + let decl_id = match semantic_decl_id { + LuaSemanticDeclId::LuaDecl(decl_id) => decl_id, + _ => return None, + }; + + // 获取声明 + let decl = semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id)?; + + let module_info = parse_require_module_info(semantic_model, &decl)?; + if module_info.is_export(semantic_model.get_db()) { + return Some(module_info); + } + None +} + +fn parse_require_expr_module_info<'a>( + semantic_model: &'a SemanticModel, + call_expr: &LuaCallExpr, +) -> Option<&'a ModuleInfo> { + let arg_list = call_expr.get_args_list()?; + let first_arg = arg_list.get_args().next()?; + let require_path_type = semantic_model.infer_expr(first_arg.clone()).ok()?; + let module_path: String = match &require_path_type { + LuaType::StringConst(module_path) => module_path.as_ref().to_string(), + _ => return None, + }; + + semantic_model + .get_db() + .get_module_index() + .find_module(&module_path) +} + +fn is_cross_file_member_from_imported_export_table_const( + module_info: &ModuleInfo, + semantic_decl: Option, +) -> bool { + if let Some(LuaSemanticDeclId::Member(member_id)) = semantic_decl + && module_info.file_id != member_id.file_id + { + return true; + } + + false +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs index 6fff686ac..2c0799c61 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs @@ -1,14 +1,13 @@ use std::collections::HashSet; use emmylua_parser::{ - LuaAst, LuaAstNode, LuaCallExpr, LuaElseIfClauseStat, LuaForRangeStat, LuaForStat, LuaIfStat, - LuaIndexExpr, LuaIndexKey, LuaRepeatStat, LuaSyntaxKind, LuaTokenKind, LuaVarExpr, - LuaWhileStat, + LuaAst, LuaAstNode, LuaElseIfClauseStat, LuaForRangeStat, LuaForStat, LuaIfStat, LuaIndexExpr, + LuaIndexKey, LuaRepeatStat, LuaSyntaxKind, LuaTokenKind, LuaVarExpr, LuaWhileStat, }; use crate::{ - DiagnosticCode, InferFailReason, LuaMemberKey, LuaSemanticDeclId, LuaType, ModuleInfo, - SemanticDeclLevel, SemanticModel, enum_variable_is_param, parse_require_module_info, + DbIndex, DiagnosticCode, InferFailReason, LuaAliasCallKind, LuaAliasCallType, LuaMemberKey, + LuaType, SemanticModel, enum_variable_is_param, get_keyof_members, }; use super::{Checker, DiagnosticContext, humanize_lint_type}; @@ -64,32 +63,14 @@ fn check_index_expr( let prefix_typ = semantic_model .infer_expr(index_expr.get_prefix_expr()?) .unwrap_or(LuaType::Unknown); - let mut module_info = None; if is_invalid_prefix_type(&prefix_typ) { - if matches!(prefix_typ, LuaType::TableConst(_)) { - // 如果导入了被 @export 标记的表常量, 那么不应该跳过检查 - module_info = check_require_table_const_with_export(semantic_model, index_expr); - if module_info.is_none() { - return Some(()); - } - } else { - return Some(()); - } + return Some(()); } let index_key = index_expr.get_index_key()?; - if is_valid_member( - semantic_model, - &prefix_typ, - index_expr, - &index_key, - code, - module_info, - ) - .is_some() - { + if is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key, code).is_some() { return Some(()); } @@ -140,13 +121,12 @@ fn is_invalid_prefix_type(typ: &LuaType) -> bool { } } -fn is_valid_member( +pub(super) fn is_valid_member( semantic_model: &SemanticModel, prefix_typ: &LuaType, index_expr: &LuaIndexExpr, index_key: &LuaIndexKey, code: DiagnosticCode, - module_info: Option<&ModuleInfo>, ) -> Option<()> { match prefix_typ { LuaType::Global | LuaType::Userdata => return Some(()), @@ -200,16 +180,6 @@ fn is_valid_member( }; } - // TODO: 元组类型的检查或许需要独立出来 - if !need && code == DiagnosticCode::InjectField { - // 前缀是导入的表常量, 检查定义的文件是否与导入的表常量相同, 不同则认为是非法的 - if let Some(module_info) = module_info - && let Some(LuaSemanticDeclId::Member(member_id)) = info.semantic_decl - && module_info.file_id != member_id.file_id - { - return None; - } - } need } None => true, @@ -262,7 +232,7 @@ fn is_valid_member( local field local a = Class[field] */ - let key_types = get_key_types(&key_type); + let key_types = get_key_types(&semantic_model.get_db(), &key_type); if key_types.is_empty() { return None; } @@ -358,7 +328,7 @@ fn get_prefix_types(prefix_typ: &LuaType) -> HashSet { type_set } -fn get_key_types(typ: &LuaType) -> HashSet { +fn get_key_types(db: &DbIndex, typ: &LuaType) -> HashSet { let mut type_set = HashSet::new(); let mut stack = vec![typ.clone()]; let mut visited = HashSet::new(); @@ -383,6 +353,16 @@ fn get_key_types(typ: &LuaType) -> HashSet { LuaType::StrTplRef(_) | LuaType::Ref(_) => { type_set.insert(current_type); } + LuaType::DocStringConst(_) | LuaType::DocIntegerConst(_) => { + type_set.insert(current_type); + } + LuaType::Call(alias_call) => { + if let Some(key_types) = get_keyof_keys(db, alias_call) { + for t in key_types { + stack.push(t.clone()); + } + } + } _ => {} } } @@ -479,59 +459,22 @@ fn check_enum_is_param( ) } -/// 检查导入的表常量 -fn check_require_table_const_with_export<'a>( - semantic_model: &'a SemanticModel, - index_expr: &LuaIndexExpr, -) -> Option<&'a ModuleInfo> { - // 获取前缀表达式的语义信息 - let prefix_expr = index_expr.get_prefix_expr()?; - if let Some(call_expr) = LuaCallExpr::cast(prefix_expr.syntax().clone()) { - let module_info = parse_require_expr_module_info(semantic_model, &call_expr)?; - if module_info.is_export(semantic_model.get_db()) { - return Some(module_info); - } +fn get_keyof_keys(db: &DbIndex, alias_call: &LuaAliasCallType) -> Option> { + if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf { + return None; } - - let semantic_decl_id = semantic_model.find_decl( - prefix_expr.syntax().clone().into(), - SemanticDeclLevel::NoTrace, - )?; - // 检查是否是声明引用 - let decl_id = match semantic_decl_id { - LuaSemanticDeclId::LuaDecl(decl_id) => decl_id, - _ => return None, - }; - - // 获取声明 - let decl = semantic_model - .get_db() - .get_decl_index() - .get_decl(&decl_id)?; - - let module_info = parse_require_module_info(semantic_model, &decl)?; - if module_info.is_export(semantic_model.get_db()) { - return Some(module_info); + let source_operands = alias_call.get_operands().iter().collect::>(); + if source_operands.len() != 1 { + return None; } - None -} - -pub fn parse_require_expr_module_info<'a>( - semantic_model: &'a SemanticModel, - call_expr: &LuaCallExpr, -) -> Option<&'a ModuleInfo> { - let arg_list = call_expr.get_args_list()?; - let first_arg = arg_list.get_args().next()?; - let require_path_type = semantic_model.infer_expr(first_arg.clone()).ok()?; - let module_path: String = match &require_path_type { - LuaType::StringConst(module_path) => module_path.as_ref().to_string(), - _ => { - return None; - } - }; - - semantic_model - .get_db() - .get_module_index() - .find_module(&module_path) + let members = get_keyof_members(db, &source_operands[0]).unwrap_or_default(); + let key_types = members + .iter() + .filter_map(|m| match &m.key { + LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), + LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), + _ => None, + }) + .collect::>(); + Some(key_types) } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs index 449b28fdb..881e4054f 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/mod.rs @@ -4,6 +4,7 @@ mod assign_type_mismatch; mod attribute_check; mod await_in_sync; mod cast_type_mismatch; +mod check_export; mod check_field; mod check_param_count; mod check_return_count; @@ -93,6 +94,7 @@ pub fn check_file(context: &mut DiagnosticContext, semantic_model: &SemanticMode run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); + run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); run_check::(context, semantic_model); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs index 98b3ac755..e0194b9c5 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs @@ -1418,4 +1418,40 @@ mod test { "# )); } + + #[test] + fn test_key_of() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class SuiteHooks + ---@field beforeAll string + ---@field afterAll string + + ---@param name keyof SuiteHooks + function test(name) + end + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + test("a") + "#, + )); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + test("beforeAll") + "#, + )); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type keyof SuiteHooks + local name + test(name) + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs index 71ac69e67..ad15e83b5 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs @@ -736,31 +736,68 @@ mod test { )); } - // #[test] - // fn test_export() { - // let mut ws = VirtualWorkspace::new(); - // ws.def_file( - // "a.lua", - // r#" - // ---@export - // local export = {} - - // return export - // "#, - // ); - // assert!(!ws.check_code_for( - // DiagnosticCode::UndefinedField, - // r#" - // local a = require("a") - // a.func() - // "#, - // )); - - // assert!(!ws.check_code_for( - // DiagnosticCode::UndefinedField, - // r#" - // local a = require("a").ABC - // "#, - // )); - // } + #[test] + fn test_export() { + let mut ws = VirtualWorkspace::new(); + ws.def_file( + "a.lua", + r#" + ---@export + local export = {} + + return export + "#, + ); + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + local a = require("a") + a.func() + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + local a = require("a").ABC + "#, + )); + + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + + ---@export + local export = {} + + export.aaa() + + return export + + "#, + )); + } + + #[test] + fn test_keyof_type() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class SuiteHooks + ---@field beforeAll string + + ---@type SuiteHooks + hooks = {} + + ---@type keyof SuiteHooks + name = "beforeAll" + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::UndefinedField, + r#" + local a = hooks[name] + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index bd9ccf115..cbd3cd22f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -72,11 +72,11 @@ pub fn instantiate_func_generic( if !generic_tpls.is_empty() { context.substitutor.add_need_infer_tpls(generic_tpls); - // 判断是否指定了泛型 if let Some(type_list) = call_expr.get_call_generic_type_list() { + // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 apply_call_generic_type_list(db, file_id, &mut context, &type_list); } else { - // 没有指定泛型, 从调用参数中推断 + // 如果没有指定泛型, 则需要从调用参数中推断 infer_generic_types_from_call( db, &mut context, @@ -155,12 +155,16 @@ fn infer_generic_types_from_call( if !func_param_type.is_variadic() && check_expr_can_later_infer(context, func_param_type, call_arg_expr)? { - // If the argument cannot be inferred later, we will handle it later. + // 如果参数不能被后续推断, 那么我们先不处理 unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone())); continue; } - let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?; + let arg_type = match infer_expr(db, context.cache, call_arg_expr.clone()) { + Ok(t) => t, + Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 + Err(e) => return Err(e), + }; match (func_param_type, &arg_type) { (LuaType::Variadic(variadic), _) => { let mut arg_types = vec![]; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 83ae64786..eb2274367 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -227,4 +227,62 @@ result = { "#, )); } + + #[test] + fn test_generic_alias_instantiation() { + let mut ws = crate::VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Arrayable T | T[] + + ---@class Suite + + ---@generic T + ---@param value Arrayable + ---@return T[] + function toArray(value) + end + "#, + ); + + ws.def( + r#" + ---@type Arrayable + local suite + + arraySuites = toArray(suite) + "#, + ); + + let a = ws.expr_ty("arraySuites"); + let expected = ws.ty("Suite[]"); + assert_eq!(a, expected); + } + + #[test] + fn test_generic_alias_instantiation2() { + let mut ws = crate::VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Arrayable T | T[] + + ---@class Suite + + ---@param value Arrayable + function toArray(value) + + end + "#, + ); + assert!(ws.check_code_for( + DiagnosticCode::ParamTypeMismatch, + r#" + + ---@type Suite + local suite + + local arraySuites = toArray(suite) + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index eeffa113e..b67d06f64 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -24,6 +24,24 @@ fn generic_tpl_pattern_match_inner( LuaType::Generic(target_generic) => { let base = source_generic.get_base_type_id_ref(); let target_base = target_generic.get_base_type_id_ref(); + if base == target_base { + let params = source_generic.get_params(); + let target_params = target_generic.get_params(); + let min_len = params.len().min(target_params.len()); + for i in 0..min_len { + match (¶ms[i], &target_params[i]) { + (LuaType::Variadic(variadict), _) => { + variadic_tpl_pattern_match(context, variadict, &target_params[i..])?; + break; + } + _ => { + tpl_pattern_match(context, ¶ms[i], &target_params[i])?; + } + } + } + return Ok(()); + } + let target_decl = context .db .get_type_index() @@ -44,23 +62,6 @@ fn generic_tpl_pattern_match_inner( infer_guard, ); } - } - - if base == target_base { - let params = source_generic.get_params(); - let target_params = target_generic.get_params(); - let min_len = params.len().min(target_params.len()); - for i in 0..min_len { - match (¶ms[i], &target_params[i]) { - (LuaType::Variadic(variadict), _) => { - variadic_tpl_pattern_match(context, variadict, &target_params[i..])?; - break; - } - _ => { - tpl_pattern_match(context, ¶ms[i], &target_params[i])?; - } - } - } } else if let Some(super_types) = context.db.get_type_index().get_super_types(target_base) { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs index c56a6c689..35dee15bd 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs @@ -9,13 +9,13 @@ use rowan::TextRange; use smol_str::SmolStr; use crate::{ - CacheEntry, GenericTpl, InFiled, InferGuardRef, LuaArrayLen, LuaArrayType, LuaDeclOrMemberId, - LuaInferCache, LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps, + CacheEntry, GenericTpl, InFiled, InferGuardRef, LuaAliasCallKind, LuaArrayLen, LuaArrayType, + LuaDeclOrMemberId, LuaInferCache, LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps, db_index::{ DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType, LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, }, - enum_variable_is_param, get_tpl_ref_extend_type, + enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type, semantic::{ InferGuard, generic::{TypeSubstitutor, instantiate_type_generic}, @@ -419,26 +419,9 @@ fn infer_custom_type_member( return member_item.resolve_type(db); } - if type_decl.is_class() - && let Some(super_types) = type_index.get_super_types(&prefix_type_id) - { - for super_type in super_types { - let result = - infer_member_by_member_key(db, cache, &super_type, index_expr.clone(), infer_guard); - - match result { - Ok(member_type) => { - return Ok(member_type); - } - Err(InferFailReason::FieldNotFound) => {} - Err(err) => return Err(err), - } - } - } - // 解决`key`为表达式的情况 if let LuaIndexKey::Expr(expr) = index_key - && let Some(keys) = expr_to_member_key(db, cache, &expr) + && let Some(keys) = get_expr_member_key(db, cache, &expr) { let mut result_types = Vec::new(); for key in keys { @@ -461,6 +444,23 @@ fn infer_custom_type_member( } } + if type_decl.is_class() + && let Some(super_types) = type_index.get_super_types(&prefix_type_id) + { + for super_type in super_types { + let result = + infer_member_by_member_key(db, cache, &super_type, index_expr.clone(), infer_guard); + + match result { + Ok(member_type) => { + return Ok(member_type); + } + Err(InferFailReason::FieldNotFound) => {} + Err(err) => return Err(err), + } + } + } + Err(InferFailReason::FieldNotFound) } @@ -1226,46 +1226,76 @@ fn infer_namespace_member( )) } -fn expr_to_member_key( +fn get_expr_member_key( db: &DbIndex, cache: &mut LuaInferCache, expr: &LuaExpr, -) -> Option> { +) -> Option> { let expr_type = infer_expr(db, cache, expr.clone()).ok()?; let mut keys: HashSet = HashSet::new(); let mut stack = vec![expr_type.clone()]; let mut visited = HashSet::new(); while let Some(current_type) = stack.pop() { - if visited.contains(¤t_type) { + if !visited.insert(current_type.clone()) { continue; } - visited.insert(current_type.clone()); match ¤t_type { LuaType::StringConst(name) | LuaType::DocStringConst(name) => { - keys.insert(name.as_ref().to_string().into()); + keys.insert(LuaMemberKey::Name((**name).clone())); } LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i) => { - keys.insert((*i).into()); + keys.insert(LuaMemberKey::Integer(*i)); + } + LuaType::Call(alias_call) => { + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf { + let operands = alias_call.get_operands(); + if operands.len() == 1 { + if let Some(members) = get_keyof_members(db, &operands[0]) { + keys.extend(members.into_iter().map(|member| member.key)); + } + } + } + } + LuaType::MultiLineUnion(multi_union) => { + for (typ, _) in multi_union.get_unions() { + if !visited.contains(typ) { + stack.push(typ.clone()); + } + } } LuaType::Union(union_typ) => { for t in union_typ.into_vec() { - stack.push(t.clone()); + if !visited.contains(&t) { + stack.push(t.clone()); + } } } LuaType::TableConst(_) | LuaType::Tuple(_) => { - keys.insert(LuaMemberKey::ExprType(expr_type.clone())); + keys.insert(LuaMemberKey::ExprType(current_type.clone())); } LuaType::Ref(id) => { - if let Some(type_decl) = db.get_type_index().get_type_decl(id) - && (type_decl.is_enum() || type_decl.is_alias()) - { - keys.insert(LuaMemberKey::ExprType(current_type.clone())); + if let Some(type_decl) = db.get_type_index().get_type_decl(id) { + if type_decl.is_alias() { + if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + if !visited.contains(&origin_type) { + stack.push(origin_type); + } + continue; + } + } + if type_decl.is_enum() || type_decl.is_alias() { + keys.insert(LuaMemberKey::ExprType(current_type.clone())); + } } } _ => {} } } + + // 转换为 Vec 并排序以确保顺序确定性 + let mut keys: Vec<_> = keys.into_iter().collect(); + keys.sort(); Some(keys) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index 0ee0b7ca4..b78b19c0c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -6,7 +6,7 @@ mod var_ref_id; use crate::{ CacheEntry, DbIndex, FlowAntecedent, FlowId, FlowNode, FlowTree, InferFailReason, - LuaInferCache, LuaType, LuaTypeCache, TypeSubstitutor, infer_param, instantiate_type_generic, + LuaInferCache, LuaType, infer_param, semantic::infer::{ InferResult, infer_name::{find_decl_member_type, infer_global_type}, @@ -49,11 +49,7 @@ fn get_var_ref_type(db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRef } if let Some(type_cache) = db.get_type_index().get_type_cache(&decl.get_id().into()) { - if let LuaTypeCache::DocType(ty) = type_cache { - if matches!(ty, LuaType::Generic(_)) { - return Ok(instantiate_type_generic(db, ty, &TypeSubstitutor::new())); - } - } + // 不要在此阶段展开泛型别名, 必须让后续的泛型匹配阶段基于声明形态完成推断 return Ok(type_cache.as_type().clone()); } diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/call_type_check.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/call_type_check.rs new file mode 100644 index 000000000..c096813b1 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/call_type_check.rs @@ -0,0 +1,82 @@ +use std::sync::Arc; + +use crate::{ + LuaAliasCallKind, LuaAliasCallType, LuaMemberKey, LuaType, LuaUnionType, TypeCheckFailReason, + TypeCheckResult, get_keyof_members, + semantic::type_check::{ + check_general_type_compact, type_check_context::TypeCheckContext, + type_check_guard::TypeCheckGuard, + }, +}; + +pub fn check_call_type_compact( + context: &mut TypeCheckContext, + source_call: &LuaAliasCallType, + compact_type: &LuaType, + check_guard: TypeCheckGuard, +) -> TypeCheckResult { + if let LuaAliasCallKind::KeyOf = source_call.get_call_kind() { + let source_operands = source_call.get_operands().iter().collect::>(); + if source_operands.len() != 1 { + return Err(TypeCheckFailReason::TypeNotMatch); + } + match compact_type { + LuaType::Call(compact_call) => { + if compact_call.get_call_kind() == LuaAliasCallKind::KeyOf { + if compact_call.as_ref() == source_call { + return Ok(()); + } + let compact_operands = compact_call.get_operands().iter().collect::>(); + if compact_operands.len() != 1 { + return Err(TypeCheckFailReason::TypeNotMatch); + } + + let source_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec( + get_keyof_keys(context, &source_operands[0]), + ))); + let compact_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec( + get_keyof_keys(context, &compact_operands[0]), + ))); + return check_general_type_compact( + context, + &source_key_types, + &compact_key_types, + check_guard.next_level()?, + ); + } + } + _ => { + let key_types = get_keyof_keys(context, &source_operands[0]); + for key_type in &key_types { + match check_general_type_compact( + context, + &key_type, + compact_type, + check_guard.next_level()?, + ) { + Ok(_) => return Ok(()), + Err(e) if e.is_type_not_match() => {} + Err(e) => return Err(e), + } + } + return Err(TypeCheckFailReason::TypeNotMatch); + } + } + } + + // TODO: 实现其他 call 类型的检查 + Ok(()) +} + +fn get_keyof_keys(context: &TypeCheckContext, prefix_type: &LuaType) -> Vec { + let members = get_keyof_members(context.db, prefix_type).unwrap_or_default(); + let key_types = members + .iter() + .filter_map(|m| match &m.key { + LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), + LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), + _ => None, + }) + .collect::>(); + key_types +} diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index 6170b6aab..e9827448b 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -1,16 +1,21 @@ mod array_type_check; +mod call_type_check; mod intersection_type_check; mod object_type_check; mod table_generic_check; mod tuple_type_check; use array_type_check::check_array_type_compact; +use call_type_check::check_call_type_compact; use intersection_type_check::check_intersection_type_compact; use object_type_check::check_object_type_compact; use table_generic_check::check_table_generic_type_compact; use tuple_type_check::check_tuple_type_compact; -use crate::{LuaType, LuaUnionType, semantic::type_check::type_check_context::TypeCheckContext}; +use crate::{ + LuaType, LuaUnionType, TypeSubstitutor, + semantic::type_check::type_check_context::TypeCheckContext, +}; use super::{ TypeCheckResult, check_general_type_compact, type_check_fail_reason::TypeCheckFailReason, @@ -24,6 +29,28 @@ pub fn check_complex_type_compact( compact_type: &LuaType, check_guard: TypeCheckGuard, ) -> TypeCheckResult { + // TODO: 缓存以提高性能 + // 如果是泛型+不包含模板参数+alias, 那么尝试实例化再检查 + if let LuaType::Generic(generic) = compact_type { + if !generic.contain_tpl() { + let base_id = generic.get_base_type_id(); + if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) + && decl.is_alias() + { + let substitutor = + TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + return check_general_type_compact( + context, + source, + &alias_origin, + check_guard.next_level()?, + ); + } + } + } + } + match source { LuaType::Array(source_array_type) => { match check_array_type_compact( @@ -94,10 +121,12 @@ pub fn check_complex_type_compact( return Err(TypeCheckFailReason::TypeNotMatch); } - // need check later LuaType::Generic(_) => { return Ok(()); } + LuaType::Call(alias_call) => { + return check_call_type_compact(context, alias_call, compact_type, check_guard); + } LuaType::MultiLineUnion(multi_union) => { let union = multi_union.to_union(); return check_complex_type_compact( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 986e8e99c..b7307680d 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -17,29 +17,28 @@ pub fn check_generic_type_compact( compact_type: &LuaType, check_guard: TypeCheckGuard, ) -> TypeCheckResult { - // 不检查尚未实例化的泛型类 - let is_tpl = source_generic.contain_tpl(); - - let source_base_id = source_generic.get_base_type_id(); - let type_decl = context + let base_id = source_generic.get_base_type_id(); + if let Some(decl) = context .db .get_type_index() - .get_type_decl(&source_base_id) - .ok_or(TypeCheckFailReason::TypeNotMatch)?; - - if type_decl.is_alias() { - let type_params = source_generic.get_params(); - let substitutor = TypeSubstitutor::from_alias(type_params.clone(), source_base_id); - if let Some(origin_type) = type_decl.get_alias_origin(context.db, Some(&substitutor)) { + .get_type_decl(&source_generic.get_base_type_id()) + && decl.is_alias() + { + let substitutor = + TypeSubstitutor::from_alias(source_generic.get_params().clone(), base_id.clone()); + if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { return check_general_type_compact( context, - &origin_type, + &alias_origin, compact_type, check_guard.next_level()?, ); } } + // 不检查尚未实例化的泛型类 + let is_tpl = source_generic.contain_tpl(); + match compact_type { LuaType::Generic(compact_generic) => { if is_tpl { @@ -111,25 +110,6 @@ fn check_generic_type_compact_generic( ) -> TypeCheckResult { let source_base_id = source_generic.get_base_type_id(); let compact_base_id = compact_generic.get_base_type_id(); - let compact_decl = context - .db - .get_type_index() - .get_type_decl(&compact_base_id) - .ok_or(TypeCheckFailReason::TypeNotMatch)?; - if compact_decl.is_alias() { - let substitutor = TypeSubstitutor::from_alias( - compact_generic.get_params().clone(), - compact_base_id.clone(), - ); - if let Some(origin_type) = compact_decl.get_alias_origin(context.db, Some(&substitutor)) { - return check_generic_type_compact( - context, - source_generic, - &origin_type, - check_guard.next_level()?, - ); - } - } if compact_base_id != source_base_id { return Err(TypeCheckFailReason::TypeNotMatch); } @@ -167,7 +147,7 @@ fn check_generic_type_compact_table( }) .unwrap_or_default(); - // 获取泛型类型的成员,使用 find_members 来获取包括继承的所有成员 + // 获取泛型类型的成员, 使用 find_members 来获取包括继承的所有成员 let source_type = LuaType::Generic(Arc::new(source_generic.clone())); let Some(source_type_members) = find_members(context.db, &source_type) else { return Ok(()); // 空成员无需检查 diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs index 73eaa6689..628881794 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/mod.rs @@ -37,7 +37,6 @@ pub fn check_type_compact( check_general_type_compact(&mut context, source, compact_type, TypeCheckGuard::new()) } -#[allow(unused)] pub fn check_type_compact_detail( db: &DbIndex, source: &LuaType, @@ -109,26 +108,26 @@ fn check_general_type_compact( | LuaType::Namespace(_) | LuaType::Variadic(_) | LuaType::Language(_) => { - check_simple_type_compact(context, source, compact_type, check_guard) + check_simple_type_compact(context, &source, &compact_type, check_guard) } // type ref LuaType::Ref(type_decl_id) => { - check_ref_type_compact(context, type_decl_id, compact_type, check_guard) + check_ref_type_compact(context, type_decl_id, &compact_type, check_guard) } LuaType::Def(type_decl_id) => { - check_ref_type_compact(context, type_decl_id, compact_type, check_guard) + check_ref_type_compact(context, type_decl_id, &compact_type, check_guard) } // invaliad source type // LuaType::Module(arc_intern) => todo!(), // function type LuaType::DocFunction(doc_func) => { - check_doc_func_type_compact(context, doc_func, compact_type, check_guard) + check_doc_func_type_compact(context, doc_func, &compact_type, check_guard) } // signature type LuaType::Signature(sig_id) => { - check_sig_type_compact(context, sig_id, compact_type, check_guard) + check_sig_type_compact(context, sig_id, &compact_type, check_guard) } // complex type @@ -138,23 +137,21 @@ fn check_general_type_compact( | LuaType::Union(_) | LuaType::Intersection(_) | LuaType::TableGeneric(_) + | LuaType::Call(_) | LuaType::MultiLineUnion(_) => { - check_complex_type_compact(context, source, compact_type, check_guard) + check_complex_type_compact(context, &source, &compact_type, check_guard) } - // need think how to do that - LuaType::Call(_) => Ok(()), - // generic type LuaType::Generic(generic) => { - check_generic_type_compact(context, generic, compact_type, check_guard) + check_generic_type_compact(context, generic, &compact_type, check_guard) } // invalid source type // LuaType::MemberPathExist(_) | LuaType::Instance(instantiate) => check_general_type_compact( context, instantiate.get_base(), - compact_type, + &compact_type, check_guard.next_level()?, ), LuaType::TypeGuard(_) => { diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index d6bbcdda6..6f994cb0c 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use crate::{ - DbIndex, LuaType, LuaTypeDeclId, VariadicType, + DbIndex, LuaType, LuaTypeDeclId, TypeSubstitutor, VariadicType, semantic::type_check::{ is_sub_type_of, type_check_context::{TypeCheckCheckLevel, TypeCheckContext}, @@ -288,20 +288,44 @@ pub fn check_simple_type_compact( _ => {} } - if let LuaType::Union(union) = compact_type { - for sub_compact in union.into_vec() { - match check_simple_type_compact( - context, - source, - &sub_compact, - check_guard.next_level()?, - ) { - Ok(_) => {} - Err(err) => return Err(err), + match compact_type { + LuaType::Union(union) => { + for sub_compact in union.into_vec() { + match check_simple_type_compact( + context, + source, + &sub_compact, + check_guard.next_level()?, + ) { + Ok(_) => {} + Err(err) => return Err(err), + } } - } - return Ok(()); + return Ok(()); + } + LuaType::Generic(generic) => { + if !generic.contain_tpl() { + let base_id = generic.get_base_type_id(); + if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) + && decl.is_alias() + { + let substitutor = + TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + if let Some(alias_origin) = + decl.get_alias_origin(context.db, Some(&substitutor)) + { + return check_general_type_compact( + context, + source, + &alias_origin, + check_guard.next_level()?, + ); + } + } + } + } + _ => {} } // complex infer diff --git a/crates/emmylua_code_analysis/src/semantic/visibility/export.rs b/crates/emmylua_code_analysis/src/semantic/visibility/export.rs index 1c721ccaa..4ab9e2440 100644 --- a/crates/emmylua_code_analysis/src/semantic/visibility/export.rs +++ b/crates/emmylua_code_analysis/src/semantic/visibility/export.rs @@ -9,7 +9,7 @@ pub fn check_export_visibility( ) -> Option { // 检查模块是否有 export 标记 let Some(export) = module_info.get_export(semantic_model.get_db()) else { - return Some(true); + return check_default_export_visibility(semantic_model, module_info); }; match export.scope { @@ -35,10 +35,31 @@ pub fn check_export_visibility( return Some(true); } } - _ => { + LuaExportScope::Global => { return Some(true); } + LuaExportScope::Default => { + return check_default_export_visibility(semantic_model, module_info); + } } Some(false) } + +/// 检查默认导出作用域下的可见性 +/// +/// 默认情况下, 如果被声明为库文件, 则我们不认为是可见的. +/// 否则认为是可见的. +fn check_default_export_visibility( + semantic_model: &SemanticModel, + module_info: &ModuleInfo, +) -> Option { + if semantic_model + .db + .get_module_index() + .is_library(&module_info.file_id) + { + return Some(false); + } + Some(true) +} diff --git a/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs b/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs index a077a3622..d5eb53912 100644 --- a/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs +++ b/crates/emmylua_ls/src/handlers/command/commands/emmy_auto_require.rs @@ -5,10 +5,7 @@ use emmylua_parser::{LuaAstNode, LuaExpr, LuaStat}; use lsp_types::{ApplyWorkspaceEditParams, Command, Position, TextEdit, WorkspaceEdit}; use serde_json::Value; -use crate::{ - context::ServerContextSnapshot, - util::{module_name_convert, time_cancel_token}, -}; +use crate::{context::ServerContextSnapshot, util::time_cancel_token}; use super::CommandSpec; @@ -21,7 +18,8 @@ impl CommandSpec for AutoRequireCommand { let add_to: FileId = serde_json::from_value(args.first()?.clone()).ok()?; let need_require_file_id: FileId = serde_json::from_value(args.get(1)?.clone()).ok()?; let position: Position = serde_json::from_value(args.get(2)?.clone()).ok()?; - let member_name: String = serde_json::from_value(args.get(3)?.clone()).ok()?; + let local_name: String = serde_json::from_value(args.get(3)?.clone()).ok()?; + let member_name: String = serde_json::from_value(args.get(4)?.clone()).ok()?; let analysis = context.analysis().read().await; let semantic_model = analysis.compilation.get_semantic_model(add_to)?; @@ -32,8 +30,6 @@ impl CommandSpec for AutoRequireCommand { let emmyrc = semantic_model.get_emmyrc(); let require_like_func = &emmyrc.runtime.require_like_function; let auto_require_func = emmyrc.completion.auto_require_function.clone(); - let file_conversion = emmyrc.completion.auto_require_naming_convention; - let local_name = module_name_convert(module_info, file_conversion); let require_separator = emmyrc.completion.auto_require_separator.clone(); let full_module_path = match require_separator.as_str() { "." | "" => module_info.full_module_name.clone(), @@ -196,12 +192,14 @@ pub fn make_auto_require( add_to: FileId, need_require_file_id: FileId, position: Position, - member_name: Option, + local_name: String, // 导入时使用的名称 + member_name: Option, // 导入的成员名, 不要包含前缀`.`号, 它将拼接到 `require` 后面. 例如 require("a").member ) -> Command { let args = vec![ serde_json::to_value(add_to).unwrap(), serde_json::to_value(need_require_file_id).unwrap(), serde_json::to_value(position).unwrap(), + serde_json::to_value(local_name).unwrap(), serde_json::to_value(member_name.unwrap_or_default()).unwrap(), ]; diff --git a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs index 52b641a79..92abc319a 100644 --- a/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs +++ b/crates/emmylua_ls/src/handlers/completion/providers/auto_require_provider.rs @@ -77,7 +77,8 @@ fn add_module_completion_item( let completion_name = module_name_convert(module_info, file_conversion); if !completion_name.to_lowercase().starts_with(prefix) { - try_add_member_completion_items( + // 如果模块名不匹配, 则根据导出类型添加完成项 + add_completion_item_by_type( builder, prefix, module_info, @@ -98,7 +99,7 @@ fn add_module_completion_item( None }; let completion_item = CompletionItem { - label: completion_name, + label: completion_name.clone(), kind: Some(lsp_types::CompletionItemKind::MODULE), label_details: Some(lsp_types::CompletionItemLabelDetails { detail: Some(format!(" (in {})", module_info.full_module_name)), @@ -109,6 +110,7 @@ fn add_module_completion_item( builder.semantic_model.get_file_id(), module_info.file_id, position, + completion_name, None, )), data, @@ -120,7 +122,7 @@ fn add_module_completion_item( Some(()) } -fn try_add_member_completion_items( +fn add_completion_item_by_type( builder: &CompletionBuilder, prefix: &str, module_info: &ModuleInfo, @@ -192,7 +194,7 @@ fn try_add_member_completion_items( }; let completion_item = CompletionItem { - label: key_name, + label: key_name.clone(), kind: Some(get_completion_kind(&member_info.typ)), label_details: Some(lsp_types::CompletionItemLabelDetails { detail: Some(format!(" (in {})", module_info.full_module_name)), @@ -203,6 +205,7 @@ fn try_add_member_completion_items( builder.semantic_model.get_file_id(), module_info.file_id, position, + key_name, Some(member_info.key.to_path().to_string()), )), data, @@ -213,6 +216,42 @@ fn try_add_member_completion_items( } } } + LuaType::Signature(_) => { + let semantic_id = module_info.semantic_id.as_ref()?; + if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_id { + let decl = builder + .semantic_model + .get_db() + .get_decl_index() + .get_decl(&decl_id)?; + let name = decl.get_name(); + if name.to_lowercase().starts_with(prefix) { + if builder.env_duplicate_name.contains(name) { + return None; + } + + let completion_item = CompletionItem { + label: name.to_string(), + kind: Some(get_completion_kind(&export_type)), + label_details: Some(lsp_types::CompletionItemLabelDetails { + detail: Some(format!(" (in {})", module_info.full_module_name)), + ..Default::default() + }), + command: Some(make_auto_require( + "", + builder.semantic_model.get_file_id(), + module_info.file_id, + position, + name.to_string(), + None, + )), + ..Default::default() + }; + + completions.push(completion_item); + } + } + } _ => {} } } diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index 6cb8cdfbf..3511094ad 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -2339,4 +2339,67 @@ mod tests { Ok(()) } + + #[gtest] + fn test_function_generic_value_is_nil() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def( + r#" + ---@class Expect + ---@overload fun(actual: T): Assertion + + ---@class Assertion + ---@field toBe fun(self: self) + + ---@type table + GTable = {} + "#, + ); + + check!(ws.check_completion_with_kind( + r#" + ---@type Expect + local expect = {} + + expect(GTable["a"]): + "#, + vec![VirtualCompletionItem { + label: "toBe".to_string(), + kind: CompletionItemKind::FUNCTION, + label_detail: Some("()".to_string()), + },], + CompletionTriggerKind::TRIGGER_CHARACTER + )); + + Ok(()) + } + + #[gtest] + fn test_module_return_signature() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + ws.def_file( + "test.lua", + r#" + ---@export global + local function processError() + return 1 + end + return processError + "#, + ); + + check!(ws.check_completion_with_kind( + r#" + processError + "#, + vec![VirtualCompletionItem { + label: "processError".to_string(), + kind: CompletionItemKind::FUNCTION, + label_detail: Some(" (in test)".to_string()), + }], + CompletionTriggerKind::INVOKED + )); + + Ok(()) + } }