Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::compilation::analyzer::doc::tags::report_orphan_tag;
use emmylua_parser::{
LuaAst, LuaAstNode, LuaDocDescriptionOwner, LuaDocTagAsync, LuaDocTagDeprecated,
LuaDocTagExport, LuaDocTagNodiscard, LuaDocTagReadonly, LuaDocTagSource, LuaDocTagVersion,
LuaDocTagVisibility, LuaTableExpr,
LuaDocTagVisibility, LuaExpr,
};

pub fn analyze_visibility(
Expand Down Expand Up @@ -121,11 +121,27 @@ pub fn analyze_export(analyzer: &mut DocAnalyzer, tag: LuaDocTagExport) -> Optio
};
let owner_id = match owner {
LuaAst::LuaReturnStat(return_stat) => {
let return_table_expr = return_stat.child::<LuaTableExpr>()?;
LuaSemanticDeclId::LuaDecl(LuaDeclId::new(
analyzer.file_id,
return_table_expr.get_position(),
))
let expr = return_stat.child::<LuaExpr>()?;
match expr {
LuaExpr::NameExpr(name_expr) => {
let name = name_expr.get_name_text()?;
let tree = analyzer
.db
.get_decl_index()
.get_decl_tree(&analyzer.file_id)?;
let decl = tree.find_local_decl(&name, name_expr.get_position())?;

Some(LuaSemanticDeclId::LuaDecl(decl.get_id()))
}
LuaExpr::ClosureExpr(closure) => Some(LuaSemanticDeclId::Signature(
LuaSignatureId::from_closure(analyzer.file_id, &closure),
)),
LuaExpr::TableExpr(table_expr) => Some(LuaSemanticDeclId::LuaDecl(LuaDeclId::new(
analyzer.file_id,
table_expr.get_position(),
))),
_ => None,
}?
}
_ => get_owner_id_or_report(analyzer, &tag)?,
};
Expand All @@ -134,10 +150,10 @@ pub fn analyze_export(analyzer: &mut DocAnalyzer, tag: LuaDocTagExport) -> Optio
match scope_text.as_str() {
"namespace" => LuaExportScope::Namespace,
"global" => LuaExportScope::Global,
_ => LuaExportScope::Global, // 默认为 global
_ => LuaExportScope::Default,
}
} else {
LuaExportScope::Global // 没有参数时默认为 global
LuaExportScope::Default
};

let export = LuaExport {
Expand Down
18 changes: 18 additions & 0 deletions crates/emmylua_code_analysis/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use regex::Regex;
use rowan::NodeCache;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::process::Command;

#[derive(Serialize, Deserialize, Debug, JsonSchema, Default, Clone)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -197,9 +198,26 @@ fn replace_placeholders(input: &str, workspace_folder: &str) -> String {
workspace_folder.to_string()
} else if let Some(env_name) = key.strip_prefix("env:") {
std::env::var(env_name).unwrap_or_default()
} else if key == "luarocks" {
get_luarocks_deploy_dir()
} else {
caps[0].to_string()
}
})
.to_string()
}

fn get_luarocks_deploy_dir() -> String {
Command::new("luarocks")
.args(["config", "deploy_lua_dir"])
.output()
.ok()
.and_then(|output| {
if output.status.success() {
Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
} else {
None
}
})
.unwrap_or_default()
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub enum LuaDeprecated {

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LuaExportScope {
Default, // 默认声明, 会根据配置文件作不同的处理.
Global,
Namespace,
}
Expand Down
39 changes: 35 additions & 4 deletions crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use emmylua_parser::{
};

use crate::{
DiagnosticCode, InferFailReason, LuaMemberKey, LuaSemanticDeclId, LuaType, ModuleInfo,
SemanticDeclLevel, SemanticModel, enum_variable_is_param, parse_require_module_info,
DbIndex, DiagnosticCode, InferFailReason, LuaAliasCallKind, LuaAliasCallType, LuaMemberKey,
LuaSemanticDeclId, LuaType, ModuleInfo, SemanticDeclLevel, SemanticModel,
enum_variable_is_param, get_keyof_members, parse_require_module_info,
};

use super::{Checker, DiagnosticContext, humanize_lint_type};
Expand Down Expand Up @@ -262,7 +263,7 @@ fn is_valid_member(
local field
local a = Class[field]
*/
let key_types = get_key_types(&key_type);
let key_types = get_key_types(&semantic_model.get_db(), &key_type);
if key_types.is_empty() {
return None;
}
Expand Down Expand Up @@ -358,7 +359,7 @@ fn get_prefix_types(prefix_typ: &LuaType) -> HashSet<LuaType> {
type_set
}

fn get_key_types(typ: &LuaType) -> HashSet<LuaType> {
fn get_key_types(db: &DbIndex, typ: &LuaType) -> HashSet<LuaType> {
let mut type_set = HashSet::new();
let mut stack = vec![typ.clone()];
let mut visited = HashSet::new();
Expand All @@ -383,6 +384,16 @@ fn get_key_types(typ: &LuaType) -> HashSet<LuaType> {
LuaType::StrTplRef(_) | LuaType::Ref(_) => {
type_set.insert(current_type);
}
LuaType::DocStringConst(_) | LuaType::DocIntegerConst(_) => {
type_set.insert(current_type);
}
LuaType::Call(alias_call) => {
if let Some(key_types) = get_key_of_keys(db, alias_call) {
for t in key_types {
stack.push(t.clone());
}
}
}
_ => {}
}
}
Expand Down Expand Up @@ -535,3 +546,23 @@ pub fn parse_require_expr_module_info<'a>(
.get_module_index()
.find_module(&module_path)
}

fn get_key_of_keys(db: &DbIndex, alias_call: &LuaAliasCallType) -> Option<Vec<LuaType>> {
if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf {
return None;
}
let source_operands = alias_call.get_operands().iter().collect::<Vec<_>>();
if source_operands.len() != 1 {
return None;
}
let members = get_keyof_members(db, &source_operands[0]).unwrap_or_default();
let key_types = members
.iter()
.filter_map(|m| match &m.key {
LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)),
LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())),
_ => None,
})
.collect::<Vec<_>>();
Some(key_types)
}
Original file line number Diff line number Diff line change
Expand Up @@ -1418,4 +1418,40 @@ mod test {
"#
));
}

#[test]
fn test_key_of() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@class SuiteHooks
---@field beforeAll string
---@field afterAll string

---@param name keyof SuiteHooks
function test(name)
end
"#,
);
assert!(!ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
test("a")
"#,
));
assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
test("beforeAll")
"#,
));
assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"
---@type keyof SuiteHooks
local name
test(name)
"#,
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -736,31 +736,54 @@ mod test {
));
}

// #[test]
// fn test_export() {
// let mut ws = VirtualWorkspace::new();
// ws.def_file(
// "a.lua",
// r#"
// ---@export
// local export = {}

// return export
// "#,
// );
// assert!(!ws.check_code_for(
// DiagnosticCode::UndefinedField,
// r#"
// local a = require("a")
// a.func()
// "#,
// ));

// assert!(!ws.check_code_for(
// DiagnosticCode::UndefinedField,
// r#"
// local a = require("a").ABC
// "#,
// ));
// }
#[test]
fn test_export() {
let mut ws = VirtualWorkspace::new();
ws.def_file(
"a.lua",
r#"
---@export
local export = {}

return export
"#,
);
assert!(!ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
local a = require("a")
a.func()
"#,
));

assert!(!ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
local a = require("a").ABC
"#,
));
}

#[test]
fn test_keyof_type() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@class SuiteHooks
---@field beforeAll string

---@type SuiteHooks
hooks = {}

---@type keyof SuiteHooks
name = "beforeAll"
"#,
);
assert!(ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
local a = hooks[name]
"#
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ pub fn instantiate_func_generic(
if !generic_tpls.is_empty() {
context.substitutor.add_need_infer_tpls(generic_tpls);

// 判断是否指定了泛型
if let Some(type_list) = call_expr.get_call_generic_type_list() {
// 如果使用了`obj:abc--[[@<string>]]("abc")`强制指定了泛型, 那么我们只需要直接应用
apply_call_generic_type_list(db, file_id, &mut context, &type_list);
} else {
// 没有指定泛型, 从调用参数中推断
// 如果没有指定泛型, 则需要从调用参数中推断
infer_generic_types_from_call(
db,
&mut context,
Expand Down Expand Up @@ -155,12 +155,16 @@ fn infer_generic_types_from_call(
if !func_param_type.is_variadic()
&& check_expr_can_later_infer(context, func_param_type, call_arg_expr)?
{
// If the argument cannot be inferred later, we will handle it later.
// 如果参数不能被后续推断, 那么我们先不处理
unresolve_tpls.push((func_param_type.clone(), call_arg_expr.clone()));
continue;
}

let arg_type = infer_expr(db, context.cache, call_arg_expr.clone())?;
let arg_type = match infer_expr(db, context.cache, call_arg_expr.clone()) {
Ok(t) => t,
Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断
Err(e) => return Err(e),
};
match (func_param_type, &arg_type) {
(LuaType::Variadic(variadic), _) => {
let mut arg_types = vec![];
Expand Down
58 changes: 58 additions & 0 deletions crates/emmylua_code_analysis/src/semantic/generic/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,62 @@ result = {
"#,
));
}

#[test]
fn test_generic_alias_instantiation() {
let mut ws = crate::VirtualWorkspace::new();
ws.def(
r#"
---@alias Arrayable<T> T | T[]

---@class Suite

---@generic T
---@param value Arrayable<T>
---@return T[]
function toArray(value)
end
"#,
);

ws.def(
r#"
---@type Arrayable<Suite>
local suite

arraySuites = toArray(suite)
"#,
);

let a = ws.expr_ty("arraySuites");
let expected = ws.ty("Suite[]");
assert_eq!(a, expected);
}

#[test]
fn test_generic_alias_instantiation2() {
let mut ws = crate::VirtualWorkspace::new();
ws.def(
r#"
---@alias Arrayable<T> T | T[]

---@class Suite

---@param value Arrayable<Suite>
function toArray(value)

end
"#,
);
assert!(ws.check_code_for(
DiagnosticCode::ParamTypeMismatch,
r#"

---@type Suite
local suite

local arraySuites = toArray(suite)
"#
));
}
}
Loading