Skip to content

Commit ecfabfb

Browse files
committed
refactor: extract ts parse logic into a single function
1 parent 80eff03 commit ecfabfb

File tree

8 files changed

+238
-257
lines changed

8 files changed

+238
-257
lines changed

lua/refactoring/debug/cleanup.lua

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ end
3030
function M.cleanup(range_type, config)
3131
local get_extracted_range = require("refactoring.utils").get_extracted_range
3232
local apply_text_edits = require("refactoring.utils").apply_text_edits
33+
local get_ts_info = require("refactoring.utils").get_ts_info
3334

3435
local opts = config.debug.cleanup
3536

@@ -58,16 +59,8 @@ function M.cleanup(range_type, config)
5859
return
5960
end
6061

61-
local comments = {} ---@type TSNode[]
62-
for _, tree in ipairs(nested_lang_tree:trees()) do
63-
for _, match, _ in query:iter_matches(tree:root(), buf) do
64-
for capture_id, nodes in pairs(match) do
65-
local name = query.captures[capture_id]
66-
67-
if name == "comment" then table.insert(comments, nodes[1]) end
68-
end
69-
end
70-
end
62+
local ts_info = get_ts_info(buf, nested_lang_tree, query)
63+
local comments = ts_info.comments
7164

7265
table.sort(comments, node_comp_asc)
7366
---@type vim.Range[]

lua/refactoring/debug/print_loc.lua

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ function M.print_loc(range_type, config)
3232
local code_gen_error = require("refactoring.utils").code_gen_error
3333
local indent = require("refactoring.utils").indent
3434
local apply_text_edits = require("refactoring.utils").apply_text_edits
35+
local get_ts_info = require("refactoring.utils").get_ts_info
3536

3637
local opts = config.debug.print_loc
3738
local code_generation = opts.code_generation
@@ -64,33 +65,9 @@ function M.print_loc(range_type, config)
6465
local get_print_loc = code_generation.print_loc[lang]
6566
if not get_print_loc then return code_gen_error("print_loc", lang) end
6667

67-
-- TODO: change to a better name everywhere (debug_path_element?)
68-
local debug_paths = {} ---@type refactor.DebugPath[]
69-
local output_statements = {} ---@type refactor.OutputStatement[]
70-
for _, tree in ipairs(nested_lang_tree:trees()) do
71-
for _, match, metadata in query:iter_matches(tree:root(), buf) do
72-
local output_statement ---@type nil|refactor.OutputStatement
73-
for capture_id, nodes in pairs(match) do
74-
local name = query.captures[capture_id]
75-
if name == "debug_path" then
76-
for i, node in ipairs(nodes) do
77-
local text = type(metadata.text) == "string" and metadata.text
78-
or ts.get_node_text(match[metadata.text][i], buf)
79-
table.insert(debug_paths, { debug_path = node, text = text })
80-
end
81-
end
82-
83-
if name == "output_statement" then
84-
output_statement = output_statement or {}
85-
output_statement.output_statement = nodes[1]
86-
elseif name == "output_statement.inside" then
87-
output_statement = output_statement or {}
88-
output_statement.inside = nodes[1]
89-
end
90-
end
91-
if output_statement then table.insert(output_statements, output_statement) end
92-
end
93-
end
68+
local ts_info = get_ts_info(buf, nested_lang_tree, query)
69+
local debug_paths = ts_info.debug_paths
70+
local output_statements = ts_info.output_statements
9471

9572
local extracted_range_api = { extracted_range:to_extmark() }
9673
-- NOTE: treesitter nodes usualy do not include leading whitespace

lua/refactoring/debug/print_var.lua

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ function M.print_var(range_type, config)
4444
local scopes_for_range = require("refactoring.utils").scopes_for_range
4545
local get_declaration_scope = require("refactoring.utils").get_declaration_scope
4646
local indent = require("refactoring.utils").indent
47+
local get_ts_info = require("refactoring.utils").get_ts_info
4748

4849
local opts = config.debug.print_var
4950
local code_generation = opts.code_generation
@@ -79,41 +80,20 @@ function M.print_var(range_type, config)
7980
local get_print_var = code_generation.print_var[lang]
8081
if not get_print_var then return code_gen_error("print_var", lang) end
8182

82-
local references = {} ---@type refactor.ReferenceInfo[]
83-
local output_statements = {} ---@type refactor.OutputStatement[]
84-
local scopes = {} ---@type TSNode[]
85-
for _, tree in ipairs(nested_lang_tree:trees()) do
86-
for _, match, metadata in query:iter_matches(tree:root(), buf) do
87-
local output_statement ---@type nil|refactor.OutputStatement
88-
for capture_id, nodes in pairs(match) do
89-
local name = query.captures[capture_id]
90-
if name == "reference.identifier" then
91-
for i, node in ipairs(nodes) do
92-
table.insert(references, {
93-
identifier = node,
94-
reference_type = metadata.reference_type,
95-
type = metadata.types and metadata.types[i],
96-
declaration = metadata.declaration ~= nil,
97-
})
98-
end
99-
end
100-
101-
if name == "output_statement" then
102-
output_statement = output_statement or {}
103-
output_statement.output_statement = nodes[1]
104-
elseif name == "output_statement.inside" then
105-
output_statement = output_statement or {}
106-
output_statement.inside = nodes[1]
107-
end
108-
if name == "scope" then
109-
for _, node in ipairs(nodes) do
110-
table.insert(scopes, node)
111-
end
112-
end
83+
local ts_info = get_ts_info(buf, nested_lang_tree, query)
84+
local references = ts_info.references
85+
local output_statements = ts_info.output_statements
86+
-- TODO: modify the util functions that use `scopes` as TSNode[] to use
87+
-- refactor.Scope[] instead?
88+
---@type TSNode[]
89+
local scopes = iter(ts_info.scopes)
90+
:map(
91+
---@param scope refactor.Scope
92+
function(scope)
93+
return scope.scope
11394
end
114-
if output_statement then table.insert(output_statements, output_statement) end
115-
end
116-
end
95+
)
96+
:totable()
11797

11898
local extracted_range_api = { extracted_range:to_extmark() }
11999
-- NOTE: treesitter nodes usualy do not include leading whitespace
@@ -187,7 +167,6 @@ function M.print_var(range_type, config)
187167
end
188168
end
189169

190-
local output_start_pos = pos(output_range.start_row, output_range.start_col, { buf = output_range.buf })
191170
-- TODO: I also compute `declarations_before_output_range` in
192171
-- `extract_func`. Is there a cleaner wat to do all this in both places?
193172
local declarations_by_scope = get_declarations_by_scope(references, scopes, buf)

lua/refactoring/refactor/extract_func.lua

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -59,33 +59,10 @@ end
5959
---@return {method: boolean?, singleton: boolean?, struct_name: string?, struct_var_name: string?}
6060
local function get_output_node(nested_lang_tree, query, buf, extracted_range)
6161
local is_first_closer = require("refactoring.utils").is_first_closer
62+
local get_ts_info = require("refactoring.utils").get_ts_info
6263

63-
local outputs = {} ---@type refactor.Output[]
64-
for _, tree in ipairs(nested_lang_tree:trees()) do
65-
for _, match, metadata in query:iter_matches(tree:root(), buf) do
66-
local output ---@type table|refactor.Output|nil
67-
for capture_id, nodes in pairs(match) do
68-
local name = query.captures[capture_id]
69-
70-
-- TODO: split input.info and output location
71-
if name == "output.comment" then
72-
output = output or {}
73-
output.comment = nodes
74-
elseif name == "output.function" then
75-
output = output or {}
76-
output.fn = nodes[1]
77-
output.method = metadata.method ~= nil
78-
output.singleton = metadata.singleton ~= nil
79-
80-
local struct_name = metadata.struct_name
81-
if struct_name then output.struct_name = ts.get_node_text(match[struct_name][1], buf) end
82-
local struct_var_name = metadata.struct_var_name
83-
if struct_var_name then output.struct_var_name = ts.get_node_text(match[struct_var_name][1], buf) end
84-
end
85-
end
86-
if output then table.insert(outputs, output) end
87-
end
88-
end
64+
local ts_info = get_ts_info(buf, nested_lang_tree, query)
65+
local outputs = ts_info.outputs
8966

9067
local extracted_start_pos = pos(extracted_range.start_row, extracted_range.start_col, { buf = extracted_range.buf })
9168
---@type refactor.Output|nil
@@ -180,6 +157,7 @@ local function extract_func(opts)
180157
local get_declarations_by_scope = require("refactoring.utils").get_declarations_by_scope
181158
local scopes_for_range = require("refactoring.utils").scopes_for_range
182159
local get_declaration_scope = require("refactoring.utils").get_declaration_scope
160+
local get_ts_info = require("refactoring.utils").get_ts_info
183161

184162
local code_generation = opts.config_opts.code_generation
185163

@@ -220,30 +198,19 @@ local function extract_func(opts)
220198
output_range = range.extmark(0, 0, 0, 0, { buf = out_buf })
221199
end
222200

223-
local references_info = {} ---@type refactor.ReferenceInfo[]
224-
local scopes = {} ---@type TSNode[]
225-
for _, tree in ipairs(nested_lang_tree:trees()) do
226-
for _, match, metadata in in_query:iter_matches(tree:root(), in_buf) do
227-
for capture_id, nodes in pairs(match) do
228-
local name = in_query.captures[capture_id]
229-
if name == "reference.identifier" then
230-
for i, node in ipairs(nodes) do
231-
table.insert(references_info, {
232-
identifier = node,
233-
reference_type = metadata.reference_type,
234-
type = metadata.types and metadata.types[i],
235-
declaration = metadata.declaration ~= nil,
236-
})
237-
end
238-
elseif name == "scope" then
239-
for _, node in ipairs(nodes) do
240-
table.insert(scopes, node)
241-
end
242-
end
201+
local in_ts_info = get_ts_info(in_buf, nested_lang_tree, in_query)
202+
local references_info = in_ts_info.references
203+
-- TODO: modify the util functions that use `scopes` as TSNode[] to use
204+
-- refactor.Scope[] instead?
205+
---@type TSNode[]
206+
local scopes = iter(in_ts_info.scopes)
207+
:map(
208+
---@param scope refactor.Scope
209+
function(scope)
210+
return scope.scope
243211
end
244-
end
245-
end
246-
-- TODO: maybe check that all the treesitter captures are not empty(?
212+
)
213+
:totable()
247214

248215
local scopes_for_extracted_range = scopes_for_range(in_buf, scopes, extracted_range)
249216

@@ -272,7 +239,6 @@ local function extract_func(opts)
272239
---@param a refactor.ReferenceInfo
273240
---@param b refactor.ReferenceInfo
274241
function(a, b)
275-
-- TODO: don't I already have a function to sort nodes in utils?
276242
local a_srow, a_scol, a_erow, a_ecol = a.identifier:range()
277243
local a_range = range(a_srow, a_scol, a_erow, a_ecol, { buf = in_buf })
278244
local b_srow, b_scol, b_erow, b_ecol = b.identifier:range()

lua/refactoring/refactor/extract_var.lua

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ function M.extract_var(range_type, config)
6565
local apply_text_edits = require("refactoring.utils").apply_text_edits
6666
local input = require("refactoring.utils").input
6767
local code_gen_error = require("refactoring.utils").code_gen_error
68+
local get_ts_info = require("refactoring.utils").get_ts_info
6869

6970
local opts = config.refactor.extract_var
7071
local code_generation = opts.code_generation
@@ -118,30 +119,14 @@ function M.extract_var(range_type, config)
118119

119120
local extracted_significant_text = significant_text(encompassing_node, buf)
120121
local matching_nodes = {} ---@type TSNode[]
121-
local scopes = {} ---@type refactor.Scope[]
122122
for _, tree in ipairs(nested_lang_tree:trees()) do
123123
for _, node in encompasing_query:iter_captures(tree:root(), buf) do
124124
local node_significant_text = significant_text(node, buf)
125125
if node_significant_text == extracted_significant_text then table.insert(matching_nodes, node) end
126126
end
127-
for _, match in query:iter_matches(tree:root(), buf) do
128-
local match_info ---@type refactor.Scope|nil
129-
for capture_id, nodes in pairs(match) do
130-
local name = query.captures[capture_id]
131-
if name == "scope" then
132-
match_info = match_info or {}
133-
match_info.scope = nodes[1]
134-
elseif name == "scope.inside" then
135-
match_info = match_info or {}
136-
match_info.inside = nodes[1]
137-
elseif name == "scope.outside" then
138-
match_info = match_info or {}
139-
match_info.outside = nodes[1]
140-
end
141-
end
142-
if match_info then table.insert(scopes, match_info) end
143-
end
144127
end
128+
local ts_info = get_ts_info(buf, nested_lang_tree, query)
129+
local scopes = ts_info.scopes
145130

146131
---@type {[integer]: refactor.TextEdit[]}
147132
local text_edits_by_buf = {}

0 commit comments

Comments
 (0)