Skip to content

Commit 288cb6b

Browse files
committed
impl auto require
1 parent b8cdeb0 commit 288cb6b

File tree

6 files changed

+245
-70
lines changed

6 files changed

+245
-70
lines changed
Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,158 @@
1+
use std::{collections::HashMap, time::Duration};
2+
3+
use code_analysis::FileId;
4+
use emmylua_parser::{LuaAstNode, LuaExpr, LuaStat};
5+
use lsp_types::{ApplyWorkspaceEditParams, Command, Position, TextEdit, WorkspaceEdit};
16
use serde_json::Value;
27

3-
use crate::context::ServerContextSnapshot;
8+
use crate::{
9+
context::ServerContextSnapshot,
10+
util::{module_name_convert, time_cancel_token},
11+
};
412

513
pub const COMMAND: &str = "emmy.auto.require";
614

7-
#[allow(unused)]
815
pub async fn handle(context: ServerContextSnapshot, args: Vec<Value>) -> Option<()> {
16+
let add_to: FileId = serde_json::from_value(args.get(0)?.clone()).ok()?;
17+
let need_require_file_id: FileId = serde_json::from_value(args.get(1)?.clone()).ok()?;
18+
let position: Position = serde_json::from_value(args.get(2)?.clone()).ok()?;
19+
20+
let analysis = context.analysis.read().await;
21+
let semantic_model = analysis.compilation.get_semantic_model(add_to)?;
22+
let module_info = semantic_model
23+
.get_db()
24+
.get_module_index()
25+
.get_module(need_require_file_id)?;
26+
let emmyrc = semantic_model.get_emmyrc();
27+
let require_like_func = &emmyrc.runtime.require_like_function;
28+
let auto_require_func = emmyrc.completion.auto_require_function.clone();
29+
let file_convension = emmyrc.completion.auto_require_naming_convention;
30+
let local_name = module_name_convert(&module_info.name, file_convension);
31+
let require_str = format!(
32+
"local {} = {}(\"{}\")",
33+
local_name, auto_require_func, module_info.full_module_name
34+
);
35+
let document = semantic_model.get_document();
36+
let offset = document.get_offset(position.line as usize, position.character as usize)?;
37+
let root_block = semantic_model.get_root().get_block()?;
38+
let mut last_require_stat: Option<LuaStat> = None;
39+
for stat in root_block.get_stats() {
40+
if stat.get_position() > offset {
41+
break;
42+
}
43+
44+
if is_require_stat(stat.clone(), &require_like_func).unwrap_or(false) {
45+
last_require_stat = Some(stat);
46+
}
47+
}
48+
49+
let line = if let Some(last_require_stat) = last_require_stat {
50+
let last_require_stat_end = last_require_stat.get_range().end();
51+
document.get_line(last_require_stat_end)? + 1
52+
} else {
53+
0
54+
};
55+
56+
let text_edit = TextEdit {
57+
range: lsp_types::Range {
58+
start: Position {
59+
line: line as u32,
60+
character: 0,
61+
},
62+
end: Position {
63+
line: line as u32,
64+
character: 0,
65+
},
66+
},
67+
new_text: format!("{}\n", require_str),
68+
};
69+
70+
let uri = document.get_uri();
71+
let mut changes = HashMap::new();
72+
changes.insert(uri.clone(), vec![text_edit.clone()]);
73+
74+
let client = context.client;
75+
let cancel_token = time_cancel_token(Duration::from_secs(5));
76+
let apply_edit_params = ApplyWorkspaceEditParams {
77+
label: None,
78+
edit: WorkspaceEdit {
79+
changes: Some(changes),
80+
document_changes: None,
81+
change_annotations: None,
82+
},
83+
};
84+
85+
tokio::spawn(async move {
86+
let res = client.apply_edit(apply_edit_params, cancel_token).await;
87+
if let Some(res) = res {
88+
if !res.applied {
89+
log::error!("Failed to apply edit: {:?}", res.failure_reason);
90+
}
91+
}
92+
});
93+
994
Some(())
1095
}
96+
97+
fn is_require_stat(stat: LuaStat, require_like_func: &Vec<String>) -> Option<bool> {
98+
match stat {
99+
LuaStat::LocalStat(local_stat) => {
100+
let exprs = local_stat.get_value_exprs();
101+
for expr in exprs {
102+
if is_require_expr(expr, require_like_func).unwrap_or(false) {
103+
return Some(true);
104+
}
105+
}
106+
}
107+
LuaStat::AssignStat(assign_stat) => {
108+
let (_, exprs) = assign_stat.get_var_and_expr_list();
109+
for expr in exprs {
110+
if is_require_expr(expr, require_like_func).unwrap_or(false) {
111+
return Some(true);
112+
}
113+
}
114+
}
115+
LuaStat::CallExprStat(call_expr_stat) => {
116+
let expr = call_expr_stat.get_call_expr()?;
117+
if is_require_expr(expr.into(), require_like_func).unwrap_or(false) {
118+
return Some(true);
119+
}
120+
}
121+
_ => {}
122+
}
123+
124+
Some(false)
125+
}
126+
127+
fn is_require_expr(expr: LuaExpr, require_like_func: &Vec<String>) -> Option<bool> {
128+
if let LuaExpr::CallExpr(call_expr) = expr {
129+
let name = call_expr.get_prefix_expr()?;
130+
if let LuaExpr::NameExpr(name_expr) = name {
131+
let name = name_expr.get_name_text()?;
132+
if require_like_func.contains(&name.to_string()) || name == "require" {
133+
return Some(true);
134+
}
135+
}
136+
}
137+
138+
Some(false)
139+
}
140+
141+
pub fn make_auto_require(
142+
title: &str,
143+
add_to: FileId,
144+
need_require_file_id: FileId,
145+
position: Position,
146+
) -> Command {
147+
let args = vec![
148+
serde_json::to_value(add_to).unwrap(),
149+
serde_json::to_value(need_require_file_id).unwrap(),
150+
serde_json::to_value(position).unwrap(),
151+
];
152+
153+
Command {
154+
title: title.to_string(),
155+
command: COMMAND.to_string(),
156+
arguments: Some(args),
157+
}
158+
}

crates/emmylua_ls/src/handlers/command/commands/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod emmy_disable_code;
77
mod emmy_fix_format;
88

99
pub use emmy_disable_code::{make_disable_code_command, DisableAction};
10+
pub use emmy_auto_require::make_auto_require;
1011

1112
pub fn get_commands_list() -> Vec<String> {
1213
let mut commands = Vec::new();
Lines changed: 21 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
use code_analysis::{EmmyrcFilenameConvention, ModuleInfo};
22
use emmylua_parser::{LuaAstNode, LuaNameExpr};
3-
use lsp_types::CompletionItem;
3+
use lsp_types::{CompletionItem, Position};
44

5-
use crate::handlers::completion::completion_builder::CompletionBuilder;
5+
use crate::{
6+
handlers::{command::make_auto_require, completion::completion_builder::CompletionBuilder},
7+
util::module_name_convert,
8+
};
69

710
pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> {
811
if builder.is_cancelled() {
912
return None;
1013
}
1114

15+
let enable = builder.semantic_model.get_emmyrc().completion.auto_require;
16+
if !enable {
17+
return None;
18+
}
19+
1220
let name_expr = LuaNameExpr::cast(builder.trigger_token.parent()?)?;
1321
// optimize for large project
1422
let prefix = name_expr.get_name_text()?.to_lowercase();
@@ -23,6 +31,9 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> {
2331
.get_db()
2432
.get_module_index()
2533
.get_module_infos();
34+
let range = builder.trigger_token.text_range();
35+
let document = builder.semantic_model.get_document();
36+
let lsp_position = document.to_lsp_range(range)?.start;
2637

2738
let mut completions = Vec::new();
2839
for module_info in module_infos {
@@ -35,6 +46,7 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> {
3546
&prefix,
3647
&module_info,
3748
file_convension,
49+
lsp_position,
3850
&mut completions,
3951
);
4052
}
@@ -52,6 +64,7 @@ fn add_module_completion_item(
5264
prefix: &str,
5365
module_info: &ModuleInfo,
5466
file_convension: EmmyrcFilenameConvention,
67+
position: Position,
5568
completions: &mut Vec<CompletionItem>,
5669
) -> Option<()> {
5770
let completion_name = module_name_convert(&module_info.name, file_convension);
@@ -70,76 +83,16 @@ fn add_module_completion_item(
7083
detail: Some(format!(" (in {})", module_info.full_module_name)),
7184
..Default::default()
7285
}),
86+
command: Some(make_auto_require(
87+
"",
88+
builder.semantic_model.get_file_id(),
89+
module_info.file_id,
90+
position
91+
)),
7392
..Default::default()
7493
};
7594

7695
completions.push(completion_item);
7796

7897
Some(())
7998
}
80-
81-
fn module_name_convert(name: &str, file_convension: EmmyrcFilenameConvention) -> String {
82-
let mut module_name = name.to_string();
83-
84-
match file_convension {
85-
EmmyrcFilenameConvention::SnakeCase => {
86-
module_name = to_snake_case(&module_name);
87-
}
88-
EmmyrcFilenameConvention::CamelCase => {
89-
module_name = to_camel_case(&module_name);
90-
}
91-
EmmyrcFilenameConvention::PascalCase => {
92-
module_name = to_pascal_case(&module_name);
93-
}
94-
EmmyrcFilenameConvention::Keep => {}
95-
}
96-
97-
module_name
98-
}
99-
100-
fn to_snake_case(s: &str) -> String {
101-
let mut result = String::new();
102-
for (i, ch) in s.chars().enumerate() {
103-
if ch.is_uppercase() && i != 0 {
104-
result.push('_');
105-
result.push(ch.to_ascii_lowercase());
106-
} else {
107-
result.push(ch.to_ascii_lowercase());
108-
}
109-
}
110-
result
111-
}
112-
113-
fn to_camel_case(s: &str) -> String {
114-
let mut result = String::new();
115-
let mut next_uppercase = false;
116-
for (i, ch) in s.chars().enumerate() {
117-
if ch == '_' || ch == '-' || ch == '.' {
118-
next_uppercase = true;
119-
} else if next_uppercase {
120-
result.push(ch.to_ascii_uppercase());
121-
next_uppercase = false;
122-
} else if i == 0 {
123-
result.push(ch.to_ascii_lowercase());
124-
} else {
125-
result.push(ch);
126-
}
127-
}
128-
result
129-
}
130-
131-
fn to_pascal_case(s: &str) -> String {
132-
let mut result = String::new();
133-
let mut next_uppercase = true;
134-
for ch in s.chars() {
135-
if ch == '_' || ch == '-' || ch == '.' {
136-
next_uppercase = true;
137-
} else if next_uppercase {
138-
result.push(ch.to_ascii_uppercase());
139-
next_uppercase = false;
140-
} else {
141-
result.push(ch.to_ascii_lowercase());
142-
}
143-
}
144-
result
145-
}

crates/emmylua_ls/src/handlers/semantic_token/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ pub async fn on_semantic_token_handler(
2626
let file_id = analysis.get_file_id(&uri)?;
2727
let mut semantic_model = analysis.compilation.get_semantic_model(file_id)?;
2828

29+
if !semantic_model.get_emmyrc().semantic_tokens.enable {
30+
return None;
31+
}
32+
2933
let result = build_semantic_tokens(
3034
&mut semantic_model,
3135
unsafe { SEMANTIC_MULTILINE_SUPPORT },

crates/emmylua_ls/src/util/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod time_cancel_token;
22
mod humanize_type;
3+
mod module_name_convert;
34

45
pub use time_cancel_token::time_cancel_token;
56
pub use humanize_type::humanize_type;
7+
pub use module_name_convert::module_name_convert;
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use code_analysis::EmmyrcFilenameConvention;
2+
3+
pub fn module_name_convert(name: &str, file_convension: EmmyrcFilenameConvention) -> String {
4+
let mut module_name = name.to_string();
5+
6+
match file_convension {
7+
EmmyrcFilenameConvention::SnakeCase => {
8+
module_name = to_snake_case(&module_name);
9+
}
10+
EmmyrcFilenameConvention::CamelCase => {
11+
module_name = to_camel_case(&module_name);
12+
}
13+
EmmyrcFilenameConvention::PascalCase => {
14+
module_name = to_pascal_case(&module_name);
15+
}
16+
EmmyrcFilenameConvention::Keep => {}
17+
}
18+
19+
module_name
20+
}
21+
22+
fn to_snake_case(s: &str) -> String {
23+
let mut result = String::new();
24+
for (i, ch) in s.chars().enumerate() {
25+
if ch.is_uppercase() && i != 0 {
26+
result.push('_');
27+
result.push(ch.to_ascii_lowercase());
28+
} else {
29+
result.push(ch.to_ascii_lowercase());
30+
}
31+
}
32+
result
33+
}
34+
35+
fn to_camel_case(s: &str) -> String {
36+
let mut result = String::new();
37+
let mut next_uppercase = false;
38+
for (i, ch) in s.chars().enumerate() {
39+
if ch == '_' || ch == '-' || ch == '.' {
40+
next_uppercase = true;
41+
} else if next_uppercase {
42+
result.push(ch.to_ascii_uppercase());
43+
next_uppercase = false;
44+
} else if i == 0 {
45+
result.push(ch.to_ascii_lowercase());
46+
} else {
47+
result.push(ch);
48+
}
49+
}
50+
result
51+
}
52+
53+
fn to_pascal_case(s: &str) -> String {
54+
let mut result = String::new();
55+
let mut next_uppercase = true;
56+
for ch in s.chars() {
57+
if ch == '_' || ch == '-' || ch == '.' {
58+
next_uppercase = true;
59+
} else if next_uppercase {
60+
result.push(ch.to_ascii_uppercase());
61+
next_uppercase = false;
62+
} else {
63+
result.push(ch.to_ascii_lowercase());
64+
}
65+
}
66+
result
67+
}

0 commit comments

Comments
 (0)