@@ -80,10 +80,11 @@ function M.print_var(range_type, config)
8080 if not get_print_var then return code_gen_error (" print_var" , lang ) end
8181
8282 local references = {} --- @type refactor.ReferenceInfo[]
83- local statements = {} --- @type TSNode []
83+ local output_statements = {} --- @type refactor.OutputStatement []
8484 local scopes = {} --- @type TSNode[]
8585 for _ , tree in ipairs (nested_lang_tree :trees ()) do
8686 for _ , match , metadata in query :iter_matches (tree :root (), buf ) do
87+ local output_statement --- @type nil | refactor.OutputStatement
8788 for capture_id , nodes in pairs (match ) do
8889 local name = query .captures [capture_id ]
8990 if name == " reference.identifier" then
@@ -97,13 +98,20 @@ function M.print_var(range_type, config)
9798 end
9899 end
99100
100- if name == " statement" then table.insert (statements , nodes [1 ]) end
101+ if name == " output_statement" then
102+ output_statement = output_statement or {}
103+ output_statement .output_statement = nodes [1 ]
104+ elseif name == " output_statement.inside" then
105+ output_statement = output_statement or {}
106+ output_statement .inside = nodes [1 ]
107+ end
101108 if name == " scope" then
102109 for _ , node in ipairs (nodes ) do
103110 table.insert (scopes , node )
104111 end
105112 end
106113 end
114+ if output_statement then table.insert (output_statements , output_statement ) end
107115 end
108116 end
109117
@@ -116,36 +124,68 @@ function M.print_var(range_type, config)
116124 local extracted_reference_pos = opts .output_location == " below"
117125 and pos (extracted_range .end_row , extracted_range .end_col )
118126 or pos (extracted_range .start_row , extracted_range_start_line_first_non_white )
119- --- @type TSNode | nil
120- local statement_for_range = iter (statements )
127+ --- @type refactor.OutputStatement | nil
128+ local statement_for_range = iter (output_statements )
121129 :filter (
122- --- @param s TSNode
123- function (s )
124- local srow , scol , erow , ecol = s :range ()
125- local s_range = range (srow , scol , erow , ecol , { buf = buf })
126- return s_range :has (extracted_reference_pos )
130+ --- @param os refactor.OutputStatement
131+ function (os )
132+ local os_srow , os_scol , os_erow , os_ecol = os . output_statement :range ()
133+ local os_range = range (os_srow , os_scol , os_erow , os_ecol , { buf = buf })
134+ return os_range :has (extracted_reference_pos )
127135 end
128136 )
129137 :fold (
130138 nil ,
131- --- @param acc nil | TSNode
132- --- @param s TSNode
133- function (acc , s )
134- if not acc then return s end
135- if s :byte_length () < acc :byte_length () then return s end
139+ --- @param acc nil | refactor.OutputStatement
140+ --- @param os refactor.OutputStatement
141+ function (acc , os )
142+ if not acc then return os end
143+ if os . output_statement :byte_length () < acc . output_statement :byte_length () then return os end
136144 return acc
137145 end
138146 )
139147 if not statement_for_range then
140148 return vim .notify (" Couldn't find statement for extracted range using Treesitter" , vim .log .levels .ERROR )
141149 end
142150
143- local srow , scol , erow , ecol = statement_for_range :range ()
144- local statement_range = range (srow , scol , erow , ecol , { buf = buf })
145- local statement_srow , statement_scol , statement_erow , statement_ecol = statement_range :to_extmark ()
146- local output_range = opts .output_location == " below"
147- and range .extmark (statement_erow , statement_ecol , statement_erow , statement_ecol , { buf = buf })
148- or range .extmark (statement_srow , statement_scol , statement_srow , statement_scol , { buf = buf })
151+ local o_srow , o_scol , o_erow , o_ecol = statement_for_range .output_statement :range ()
152+ local before_range = range (o_srow , o_scol , o_srow , o_scol , { buf = buf })
153+ local after_range = range (o_erow , o_ecol , o_erow , o_ecol , { buf = buf })
154+ local output_range --- @type vim.Range
155+ local inserted_at --- @type ' start' | ' end'
156+ if statement_for_range .inside and opts .output_location == " above" then
157+ local i_srow , i_scol , i_erow , i_ecol = statement_for_range .inside :range ()
158+ local inside_range = range (i_srow , i_scol , i_erow , i_ecol , { buf = buf })
159+
160+ if extracted_range > inside_range then
161+ local _ , _ , inside_erow , inside_ecol = inside_range :to_extmark ()
162+ output_range = range .extmark (inside_erow , inside_ecol , inside_erow , inside_ecol , { buf = buf })
163+ inserted_at = " end"
164+ else
165+ output_range = before_range
166+ inserted_at = " start"
167+ end
168+ elseif statement_for_range .inside and opts .output_location == " below" then
169+ local i_srow , i_scol , i_erow , i_ecol = statement_for_range .inside :range ()
170+ local inside_range = range (i_srow , i_scol , i_erow , i_ecol , { buf = buf })
171+
172+ if extracted_range < inside_range then
173+ local inside_srow , inside_scol = inside_range :to_extmark ()
174+ output_range = range .extmark (inside_srow , inside_scol , inside_srow , inside_scol , { buf = buf })
175+ inserted_at = " start"
176+ else
177+ output_range = after_range
178+ inserted_at = " end"
179+ end
180+ else
181+ if opts .output_location == " above" then
182+ output_range = before_range
183+ inserted_at = " start"
184+ elseif opts .output_location == " below" then
185+ output_range = after_range
186+ inserted_at = " end"
187+ end
188+ end
149189
150190 local output_start_pos = pos (output_range .start_row , output_range .start_col , { buf = output_range .buf })
151191 -- TODO: I also compute `declarations_before_output_range` in
@@ -179,16 +219,8 @@ function M.print_var(range_type, config)
179219 or false
180220
181221 local r_srow , r_scol , r_erow , r_ecol = r .identifier :range ()
182- local node_start = pos (r_srow , r_scol , { buf = buf })
183- -- NOTE: I need to make end inclusive to be able to compare it with start
184- if r_ecol == 0 then
185- r_erow = r_erow - 1
186- r_ecol = # api .nvim_buf_get_lines (buf , r_erow , r_erow + 1 , true )[1 ]
187- else
188- r_ecol = r_ecol - 1
189- end
190- local node_end = pos (r_erow , r_ecol , { buf = buf })
191- return node_start <= output_start_pos and node_end <= output_start_pos and is_in_scope
222+ local r_range = range (r_srow , r_scol , r_erow , r_ecol , { buf = buf })
223+ return r_range < output_range and is_in_scope
192224 end
193225 )
194226 :map (reference_to_text )
@@ -281,8 +313,8 @@ function M.print_var(range_type, config)
281313 local print_text = table.concat (print_lines , " \n " )
282314 print_text = indent (expandtab , indent_amount , print_text )
283315 print_lines = vim .split (print_text , " \n " )
284- if opts . output_location == " below " then table.insert (print_lines , 1 , " " ) end
285- if opts . output_location == " above " then
316+ if inserted_at == " end " then table.insert (print_lines , 1 , " " ) end
317+ if inserted_at == " start " then
286318 print_lines [1 ] = indent (expandtab , 0 , print_lines [1 ])
287319 table.insert (print_lines , (expandtab and " " or " \t " ):rep (indent_amount ))
288320 end
0 commit comments