Skip to content

Commit 5a788c2

Browse files
committed
feat(print_var): statement -> output_statement and improvements from print_loc
1 parent 52f7d23 commit 5a788c2

File tree

10 files changed

+115
-51
lines changed

10 files changed

+115
-51
lines changed

lua/refactoring/debug/print_loc.lua

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ function M.print_loc(range_type, config)
114114
:fold(
115115
nil,
116116
---@param acc nil|refactor.OutputStatement
117-
---@param s refactor.OutputStatement
118-
function(acc, s)
119-
if not acc then return s end
120-
if s.output_statement:byte_length() < acc.output_statement:byte_length() then return s end
117+
---@param os refactor.OutputStatement
118+
function(acc, os)
119+
if not acc then return os end
120+
if os.output_statement:byte_length() < acc.output_statement:byte_length() then return os end
121121
return acc
122122
end
123123
)
@@ -126,10 +126,8 @@ function M.print_loc(range_type, config)
126126
end
127127

128128
local o_srow, o_scol, o_erow, o_ecol = statement_for_range.output_statement:range()
129-
local statement_range = range(o_srow, o_scol, o_erow, o_ecol, { buf = buf })
130-
local statement_srow, statement_scol, statement_erow, statement_ecol = statement_range:to_extmark()
131-
local before_range = range.extmark(statement_srow, statement_scol, statement_srow, statement_scol, { buf = buf })
132-
local after_range = range.extmark(statement_erow, statement_ecol, statement_erow, statement_ecol, { buf = buf })
129+
local before_range = range(o_srow, o_scol, o_srow, o_scol, { buf = buf })
130+
local after_range = range(o_erow, o_ecol, o_erow, o_ecol, { buf = buf })
133131
local output_range ---@type vim.Range
134132
local inserted_at ---@type 'start'|'end'
135133
if statement_for_range.inside and opts.output_location == "above" then

lua/refactoring/debug/print_var.lua

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ function M.print_var(range_type, config)
8080
if not get_print_var then return code_gen_error("print_var", lang) end
8181

8282
local references = {} ---@type refactor.ReferenceInfo[]
83-
local statements = {} ---@type TSNode[]
83+
local output_statements = {} ---@type refactor.OutputStatement[]
8484
local scopes = {} ---@type TSNode[]
8585
for _, tree in ipairs(nested_lang_tree:trees()) do
8686
for _, match, metadata in query:iter_matches(tree:root(), buf) do
87+
local output_statement ---@type nil|refactor.OutputStatement
8788
for capture_id, nodes in pairs(match) do
8889
local name = query.captures[capture_id]
8990
if name == "reference.identifier" then
@@ -97,13 +98,20 @@ function M.print_var(range_type, config)
9798
end
9899
end
99100

100-
if name == "statement" then table.insert(statements, nodes[1]) end
101+
if name == "output_statement" then
102+
output_statement = output_statement or {}
103+
output_statement.output_statement = nodes[1]
104+
elseif name == "output_statement.inside" then
105+
output_statement = output_statement or {}
106+
output_statement.inside = nodes[1]
107+
end
101108
if name == "scope" then
102109
for _, node in ipairs(nodes) do
103110
table.insert(scopes, node)
104111
end
105112
end
106113
end
114+
if output_statement then table.insert(output_statements, output_statement) end
107115
end
108116
end
109117

@@ -116,36 +124,68 @@ function M.print_var(range_type, config)
116124
local extracted_reference_pos = opts.output_location == "below"
117125
and pos(extracted_range.end_row, extracted_range.end_col)
118126
or pos(extracted_range.start_row, extracted_range_start_line_first_non_white)
119-
---@type TSNode|nil
120-
local statement_for_range = iter(statements)
127+
---@type refactor.OutputStatement|nil
128+
local statement_for_range = iter(output_statements)
121129
:filter(
122-
---@param s TSNode
123-
function(s)
124-
local srow, scol, erow, ecol = s:range()
125-
local s_range = range(srow, scol, erow, ecol, { buf = buf })
126-
return s_range:has(extracted_reference_pos)
130+
---@param os refactor.OutputStatement
131+
function(os)
132+
local os_srow, os_scol, os_erow, os_ecol = os.output_statement:range()
133+
local os_range = range(os_srow, os_scol, os_erow, os_ecol, { buf = buf })
134+
return os_range:has(extracted_reference_pos)
127135
end
128136
)
129137
:fold(
130138
nil,
131-
---@param acc nil|TSNode
132-
---@param s TSNode
133-
function(acc, s)
134-
if not acc then return s end
135-
if s:byte_length() < acc:byte_length() then return s end
139+
---@param acc nil|refactor.OutputStatement
140+
---@param os refactor.OutputStatement
141+
function(acc, os)
142+
if not acc then return os end
143+
if os.output_statement:byte_length() < acc.output_statement:byte_length() then return os end
136144
return acc
137145
end
138146
)
139147
if not statement_for_range then
140148
return vim.notify("Couldn't find statement for extracted range using Treesitter", vim.log.levels.ERROR)
141149
end
142150

143-
local srow, scol, erow, ecol = statement_for_range:range()
144-
local statement_range = range(srow, scol, erow, ecol, { buf = buf })
145-
local statement_srow, statement_scol, statement_erow, statement_ecol = statement_range:to_extmark()
146-
local output_range = opts.output_location == "below"
147-
and range.extmark(statement_erow, statement_ecol, statement_erow, statement_ecol, { buf = buf })
148-
or range.extmark(statement_srow, statement_scol, statement_srow, statement_scol, { buf = buf })
151+
local o_srow, o_scol, o_erow, o_ecol = statement_for_range.output_statement:range()
152+
local before_range = range(o_srow, o_scol, o_srow, o_scol, { buf = buf })
153+
local after_range = range(o_erow, o_ecol, o_erow, o_ecol, { buf = buf })
154+
local output_range ---@type vim.Range
155+
local inserted_at ---@type 'start'|'end'
156+
if statement_for_range.inside and opts.output_location == "above" then
157+
local i_srow, i_scol, i_erow, i_ecol = statement_for_range.inside:range()
158+
local inside_range = range(i_srow, i_scol, i_erow, i_ecol, { buf = buf })
159+
160+
if extracted_range > inside_range then
161+
local _, _, inside_erow, inside_ecol = inside_range:to_extmark()
162+
output_range = range.extmark(inside_erow, inside_ecol, inside_erow, inside_ecol, { buf = buf })
163+
inserted_at = "end"
164+
else
165+
output_range = before_range
166+
inserted_at = "start"
167+
end
168+
elseif statement_for_range.inside and opts.output_location == "below" then
169+
local i_srow, i_scol, i_erow, i_ecol = statement_for_range.inside:range()
170+
local inside_range = range(i_srow, i_scol, i_erow, i_ecol, { buf = buf })
171+
172+
if extracted_range < inside_range then
173+
local inside_srow, inside_scol = inside_range:to_extmark()
174+
output_range = range.extmark(inside_srow, inside_scol, inside_srow, inside_scol, { buf = buf })
175+
inserted_at = "start"
176+
else
177+
output_range = after_range
178+
inserted_at = "end"
179+
end
180+
else
181+
if opts.output_location == "above" then
182+
output_range = before_range
183+
inserted_at = "start"
184+
elseif opts.output_location == "below" then
185+
output_range = after_range
186+
inserted_at = "end"
187+
end
188+
end
149189

150190
local output_start_pos = pos(output_range.start_row, output_range.start_col, { buf = output_range.buf })
151191
-- TODO: I also compute `declarations_before_output_range` in
@@ -179,16 +219,8 @@ function M.print_var(range_type, config)
179219
or false
180220

181221
local r_srow, r_scol, r_erow, r_ecol = r.identifier:range()
182-
local node_start = pos(r_srow, r_scol, { buf = buf })
183-
-- NOTE: I need to make end inclusive to be able to compare it with start
184-
if r_ecol == 0 then
185-
r_erow = r_erow - 1
186-
r_ecol = #api.nvim_buf_get_lines(buf, r_erow, r_erow + 1, true)[1]
187-
else
188-
r_ecol = r_ecol - 1
189-
end
190-
local node_end = pos(r_erow, r_ecol, { buf = buf })
191-
return node_start <= output_start_pos and node_end <= output_start_pos and is_in_scope
222+
local r_range = range(r_srow, r_scol, r_erow, r_ecol, { buf = buf })
223+
return r_range < output_range and is_in_scope
192224
end
193225
)
194226
:map(reference_to_text)
@@ -281,8 +313,8 @@ function M.print_var(range_type, config)
281313
local print_text = table.concat(print_lines, "\n")
282314
print_text = indent(expandtab, indent_amount, print_text)
283315
print_lines = vim.split(print_text, "\n")
284-
if opts.output_location == "below" then table.insert(print_lines, 1, "") end
285-
if opts.output_location == "above" then
316+
if inserted_at == "end" then table.insert(print_lines, 1, "") end
317+
if inserted_at == "start" then
286318
print_lines[1] = indent(expandtab, 0, print_lines[1])
287319
table.insert(print_lines, (expandtab and " " or "\t"):rep(indent_amount))
288320
end

queries/c/print_var.scm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
(statement)
6262
(function_definition)
6363
(declaration)
64-
] @statement
64+
] @output_statement
6565

6666
(struct_specifier) @scope
6767

queries/ecma/print_var.scm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
object: (identifier) @reference.identifier
3131
(#set! reference_type read))
3232

33-
(statement) @statement
33+
(statement) @output_statement
3434

3535
(do_statement
3636
body: (statement_block

queries/lua/print_loc.scm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
(#not-has-parent? @debug_path expression_list)
3333
(#set! text "(anon)"))
3434

35-
; TODO: move these statements (and the changes to how statements are processed) to the other implementations
3635
[
3736
(empty_statement)
3837
(assignment_statement)

queries/lua/print_var.scm

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,42 @@
169169
(#set! reference_type read))
170170

171171
[
172-
(statement)
172+
(empty_statement)
173+
(assignment_statement)
174+
(function_call)
175+
(label_statement)
176+
(break_statement)
177+
(goto_statement)
173178
(return_statement)
174-
] @statement
179+
(variable_declaration)
180+
] @output_statement
181+
182+
(do_statement
183+
body: (_) @output_statement.inside) @output_statement
184+
185+
(while_statement
186+
body: (_) @output_statement.inside) @output_statement
187+
188+
(repeat_statement
189+
body: (_) @output_statement.inside) @output_statement
190+
191+
(if_statement
192+
consequence: (_) @output_statement.inside) @output_statement
193+
194+
(elseif_statement
195+
consequence: (_) @output_statement.inside) @output_statement
196+
197+
(else_statement
198+
body: (_) @output_statement.inside) @output_statement
199+
200+
(for_statement
201+
body: (_) @output_statement.inside) @output_statement
202+
203+
(function_declaration
204+
body: (_) @output_statement.inside) @output_statement
205+
206+
(function_definition
207+
body: (_) @output_statement.inside) @output_statement
175208

176209
; table.sort(function() end)
177210
((function_definition

queries/powershell/print_var.scm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
(#set! reference_type read))
110110

111111
(statement_list
112-
(_) @statement)
112+
(_) @output_statement)
113113

114114
(class_statement
115115
(class_method_definition) @scope.inside) @scope

queries/python/print_var.scm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
(nonlocal_statement)
9797
(exec_statement)
9898
(type_alias_statement)
99-
] @statement
99+
] @output_statement
100100

101101
; _compund_statement
102102
[
@@ -109,7 +109,7 @@
109109
(class_definition)
110110
(decorated_definition)
111111
(match_statement)
112-
] @statement
112+
] @output_statement
113113

114114
(module) @scope
115115

queries/vim/print_var.scm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@
189189
(eval_statement)
190190
(substitute_statement)
191191
(user_command)
192-
] @statement
192+
] @output_statement
193193

194194
(script_file) @scope
195195

tests/test_print_var.lua

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,19 @@ T["python"]["works"] = function()
171171
local lines = [[
172172
def foo():
173173
i = 3
174+
foo = i
174175
return i]]
175176
local expected_lines = [[
176177
def foo():
178+
i = 3
177179
# __PRINT_VAR_START
178180
print(f"i: {str(i)}")# __PRINT_VAR_END
179-
i = 3
181+
foo = i
180182
return i]]
181183
child.cmd "edit tmp.py"
182184
child.bo.expandtab = true
183185
child.bo.shiftwidth = 4
184-
validate(lines, { 2, 4 }, expected_lines, " pViw")
186+
validate(lines, { 3, 4 }, expected_lines, " pV_")
185187
end
186188

187189
T["vimscript"] = MiniTest.new_set()

0 commit comments

Comments
 (0)