Skip to content

Commit 3c0fbb5

Browse files
committed
fix: unify range handling by bundling WIP vim.range implementation
1 parent 9dfd9e7 commit 3c0fbb5

File tree

9 files changed

+1021
-387
lines changed

9 files changed

+1021
-387
lines changed

lua/refactoring/debug/cleanup.lua

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
local api = vim.api
22
local iter = vim.iter
33
local async = require "async"
4+
local range = require "refactoring.range"
5+
local pos = require "refactoring.pos"
46
local ts = vim.treesitter
57

68
-- TODO: Search inside strings (using treesitter) on `printf` (and
@@ -26,15 +28,14 @@ end
2628
---@param range_type 'v' | 'V' | ''
2729
---@param config refactor.Config
2830
function M.cleanup(range_type, config)
29-
local get_extracted_range = require("refactoring.range").get_extracted_range
31+
local get_extracted_range = require("refactoring.utils").get_extracted_range
3032
local apply_text_edits = require("refactoring.utils").apply_text_edits
31-
local contains_range = require("refactoring.range").contains_range
3233

3334
local opts = config.debug.cleanup
3435

3536
local buf = api.nvim_get_current_buf()
3637
local last_line = vim.fn.line "$"
37-
local extracted_range = get_extracted_range(range_type)
38+
local extracted_range = get_extracted_range(buf, range_type)
3839

3940
local task = async.run(function()
4041
local lang_tree, err1 = ts.get_parser(buf, nil, { error = false })
@@ -44,7 +45,7 @@ function M.cleanup(range_type, config)
4445
end
4546
-- TODO: use async parsing
4647
lang_tree:parse(true)
47-
local nested_lang_tree = lang_tree:language_for_range(extracted_range)
48+
local nested_lang_tree = lang_tree:language_for_range { extracted_range:to_treesitter() }
4849
local lang = nested_lang_tree:lang()
4950
local query = ts.query.get(lang, "refactor")
5051
if not query then
@@ -64,13 +65,13 @@ function M.cleanup(range_type, config)
6465
end
6566

6667
table.sort(comments, node_comp_asc)
67-
---@type Range4[]
68+
---@type vim.Range[]
6869
local ranges_to_cleanup = iter(comments)
6970
:filter(
7071
---@param comment TSNode
7172
function(comment)
72-
local comment_range = { comment:range() }
73-
return contains_range(extracted_range, comment_range)
73+
local comment_range = range.treesitter(buf, comment:range())
74+
return extracted_range:has(comment_range)
7475
end
7576
)
7677
:map(
@@ -84,21 +85,14 @@ function M.cleanup(range_type, config)
8485
return text:find(opts.markers[name].start) ~= nil
8586
end
8687
)
87-
if is_start then return "start", { comment:start() } end
88-
local comment_end = { comment:end_() }
89-
if comment_end[1] ~= last_line - 1 then
90-
comment_end[1], comment_end[2] = comment_end[1] + 1, 0
91-
end
88+
if is_start then return "start", pos.treesitter(buf, "start", comment:start()) end
9289
local is_end = iter(opts.types):any(
9390
---@param name 'print_var'|'print_loc'|'print_exp'
9491
function(name)
9592
return text:find(opts.markers[name]["end"]) ~= nil
9693
end
9794
)
98-
if is_end then return "end", comment_end end
99-
-- TODO: I'll need to generalize the handling of 0-based/1-based
100-
-- end-exclusive/end-inclusive/end_row-inclusive_col-exclusive/end_row_exclusive-_col-0
101-
-- ranges everywhere x2
95+
if is_end then return "end", pos.treesitter(buf, "end", comment:end_()) end
10296
end
10397
)
10498
:filter(
@@ -109,13 +103,13 @@ function M.cleanup(range_type, config)
109103
)
110104
:fold(
111105
{},
112-
---@param acc Range4[]|{current_start: Range2}
106+
---@param acc vim.Range|{current_start: vim.Pos}
113107
---@param kind 'start'|'end'
114-
---@param range Range2
115-
function(acc, kind, range)
116-
if kind == "start" then acc.current_start = range end
108+
---@param position vim.Pos
109+
function(acc, kind, position)
110+
if kind == "start" then acc.current_start = position end
117111
if kind == "end" and acc.current_start ~= nil then
118-
table.insert(acc, { acc.current_start[1], acc.current_start[2], range[1], range[2] })
112+
table.insert(acc, range(acc.current_start, position))
119113
acc.current_start = nil
120114
end
121115

@@ -127,9 +121,9 @@ function M.cleanup(range_type, config)
127121
local text_edits_by_buf = {}
128122
text_edits_by_buf[buf] = {}
129123
iter(ipairs(ranges_to_cleanup)):each(
130-
---@param range Range4
131-
function(_, range)
132-
table.insert(text_edits_by_buf[buf], { range = range, lines = {} })
124+
---@param r vim.Range
125+
function(_, r)
126+
table.insert(text_edits_by_buf[buf], { range = r, lines = {} })
133127
end
134128
)
135129

lua/refactoring/debug/print_var.lua

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ local api = vim.api
22
local ts = vim.treesitter
33
local iter = vim.iter
44
local async = require "async"
5+
local range = require "refactoring.range"
6+
local pos = require "refactoring.pos"
57

68
local M = {}
79

@@ -35,22 +37,19 @@ local M = {}
3537
---@param range_type 'v' | 'V' | ''
3638
---@param config refactor.Config
3739
function M.print_var(range_type, config)
38-
local get_extracted_range = require("refactoring.range").get_extracted_range
39-
local contains = require("refactoring.range").contains
40+
local get_extracted_range = require("refactoring.utils").get_extracted_range
4041
local is_unique = require("refactoring.utils").is_unique
4142
local code_gen_error = require("refactoring.utils").code_gen_error
4243
local apply_text_edits = require("refactoring.utils").apply_text_edits
4344
local get_declarations_by_scope = require("refactoring.utils").get_declarations_by_scope
4445
local scopes_for_range = require("refactoring.utils").scopes_for_range
4546
local get_declaration_scope = require("refactoring.utils").get_declaration_scope
46-
local compare = require("refactoring.range").compare
47-
local comp_non_overlaping_ranges_asc = require("refactoring.range").comp_non_overlaping_ranges_asc
4847

4948
local opts = config.debug.print_var
5049
local code_generation = opts.code_generation
5150

5251
local buf = api.nvim_get_current_buf()
53-
local extracted_range = get_extracted_range(range_type)
52+
local extracted_range = get_extracted_range(buf, range_type)
5453

5554
local task = async.run(function()
5655
local lang_tree, err1 = ts.get_parser(buf, nil, { error = false })
@@ -60,7 +59,7 @@ function M.print_var(range_type, config)
6059
end
6160
-- TODO: use async parsing
6261
lang_tree:parse(true)
63-
local nested_lang_tree = lang_tree:language_for_range(extracted_range)
62+
local nested_lang_tree = lang_tree:language_for_range { extracted_range:to_treesitter() }
6463
local lang = nested_lang_tree:lang()
6564
local query = ts.query.get(lang, "refactor")
6665
if not query then
@@ -99,19 +98,21 @@ function M.print_var(range_type, config)
9998
end
10099
end
101100

101+
local extracted_range_api = { extracted_range:to_api() }
102102
-- NOTE: treesitter nodes usualy do not include leading whitespace
103-
local extracted_range_start_line = api.nvim_buf_get_lines(buf, extracted_range[1], extracted_range[1] + 1, true)[1]
103+
local extracted_range_start_line =
104+
api.nvim_buf_get_lines(buf, extracted_range_api[1], extracted_range_api[1] + 1, true)[1]
104105
local _, extracted_range_start_line_first_non_white = extracted_range_start_line:find "^%s*"
105106
extracted_range_start_line_first_non_white = extracted_range_start_line_first_non_white or 0
106-
local extracted_reference_point = opts.output_location == "below" and { extracted_range[3], extracted_range[4] }
107-
or { extracted_range[1], extracted_range_start_line_first_non_white }
107+
local extracted_reference_pos = opts.output_location == "below" and extracted_range.end_
108+
or pos(extracted_range.start.row, extracted_range_start_line_first_non_white)
108109
---@type TSNode|nil
109110
local statement_for_range = iter(statements)
110111
:filter(
111112
---@param s TSNode
112113
function(s)
113-
local s_range = { s:range() }
114-
return contains(s_range, extracted_reference_point)
114+
local s_range = range.treesitter(buf, s:range())
115+
return s_range:has_pos(extracted_reference_pos)
115116
end
116117
)
117118
:fold(
@@ -128,16 +129,16 @@ function M.print_var(range_type, config)
128129
return vim.notify("Couldn't find statement for extracted range using Treesitter", vim.log.levels.ERROR)
129130
end
130131

131-
local statement_range = { statement_for_range:range() }
132+
local statement_range = range.treesitter(buf, statement_for_range:range())
133+
local statement_srow, statement_scol, statement_erow, statement_ecol = statement_range:to_api()
132134
local output_range = opts.output_location == "below"
133-
and { statement_range[3], statement_range[4], statement_range[3], statement_range[4] }
134-
or { statement_range[1], statement_range[2], statement_range[1], statement_range[2] }
135-
local output_start = { output_range[1], output_range[2] }
135+
and range.api(buf, statement_erow, statement_ecol, statement_erow, statement_ecol)
136+
or range.api(buf, statement_srow, statement_scol, statement_srow, statement_scol)
136137

137138
-- TODO: I also compute `declarations_before_output_range` in
138139
-- `extract_func`. Is there a cleaner wat to do all this in both places?
139140
local declarations_by_scope = get_declarations_by_scope(references, scopes, buf)
140-
local scopes_for_extracted_range = scopes_for_range(scopes, extracted_range)
141+
local scopes_for_extracted_range = scopes_for_range(buf, scopes, extracted_range)
141142
local reference_to_text =
142143
---@param reference refactor.Reference
143144
function(reference)
@@ -153,9 +154,6 @@ function M.print_var(range_type, config)
153154
:filter(
154155
---@param r refactor.Reference
155156
function(r)
156-
local start_node = { r.identifier:start() }
157-
local end_node = { r.identifier:end_() }
158-
159157
local declaration_scope = get_declaration_scope(declarations_by_scope, scopes, r, buf)
160158

161159
local is_in_scope = declaration_scope
@@ -167,25 +165,24 @@ function M.print_var(range_type, config)
167165
)
168166
or false
169167

170-
return compare(start_node, output_start) ~= 1 and compare(end_node, output_start) ~= 1 and is_in_scope
168+
local node_start = pos.treesitter(buf, "start", r.identifier:start())
169+
local node_end = pos.treesitter(buf, "end", r.identifier:end_())
170+
return node_start <= output_range.start and node_end <= output_range.start and is_in_scope
171171
end
172172
)
173173
:map(reference_to_text)
174174
:totable()
175175

176-
-- TODO: should I generalize/unify how ranges and range transformations are
177-
-- handled everywhere? Some ranges are having a 1-off error because of no standar
178-
-- handling of ranges (e.g. `extract_func` for a single word). Take a look
179-
-- at how `vim.treesitter.get_node_text` handles treesitter ranges when doing this
180-
local extracted_range_ts = { extracted_range[1], extracted_range[2], extracted_range[3], extracted_range[4] + 1 }
176+
-- TODO: Some ranges are having a 1-off error because of no standar
177+
-- handling of ranges (e.g. `extract_func` for a single word).
178+
181179
---@type {[string]: refactor.Reference}
182180
local selected_references_by_start = iter(references)
183181
:filter(
184182
---@param r refactor.Reference
185183
function(r)
186-
local r_start = { r.identifier:start() }
187-
local r_end = { r.identifier:end_() }
188-
return contains(extracted_range_ts, r_start) and contains(extracted_range_ts, r_end)
184+
local r_range = range.treesitter(buf, r.identifier:range())
185+
return extracted_range:has(r_range)
189186
end
190187
)
191188
:fold(
@@ -209,19 +206,18 @@ function M.print_var(range_type, config)
209206
end)
210207
:totable()
211208
table.sort(selected_references, function(a, b)
212-
local a_range = { a.identifier:range() }
213-
local b_range = { b.identifier:range() }
209+
local a_range = range.treesitter(buf, a.identifier:range())
210+
local b_range = range.treesitter(buf, b.identifier:range())
214211

215-
return comp_non_overlaping_ranges_asc(a_range, b_range)
212+
return a_range < b_range
216213
end)
217214
---@type string[]
218215
local print_lines = iter(selected_references)
219216
:filter(
220217
---@param r refactor.Reference
221218
function(r)
222-
local r_start = { r.identifier:start() }
223-
local r_end = { r.identifier:end_() }
224-
return contains(extracted_range_ts, r_start) and contains(extracted_range_ts, r_end)
219+
local r_range = range.treesitter(buf, r.identifier:range())
220+
return extracted_range:has(r_range)
225221
end
226222
)
227223
:map(

0 commit comments

Comments
 (0)