Commit 1819ecb

hangfei authored and copybara-github committed

fix: Improve handling of partial and complete transcriptions in live calls

In `gemini_llm_connection.py`, accumulate partial transcription texts and emit an `LlmResponse` with `partial=True` for each chunk. When the transcription is marked as `finished`, emit a final `LlmResponse` with the full accumulated text and `partial=False`.

In `runners.py`, modify `_should_append_event` so that transcription events are added to the session history only when they are fully finished, preventing partial transcriptions from being appended.

Co-authored-by: Hangfei Lin <[email protected]>
PiperOrigin-RevId: 829029715

1 parent 44d45fe · commit 1819ecb
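
The core pattern, in isolation: buffer each transcription fragment as it arrives, surface it immediately as a partial response, and flush the accumulated text as one final response when the fragment marked `finished` arrives. Below is a minimal self-contained sketch of that pattern; `Chunk` and `emit_transcriptions` are illustrative stand-ins, not the ADK API (the real implementation is in `GeminiLlmConnection.receive`, diffed further down).

    # Illustrative sketch only; names here are hypothetical, not ADK types.
    from dataclasses import dataclass
    from typing import Iterator, Optional


    @dataclass
    class Chunk:
      """A transcription fragment as it might arrive from a live session."""

      text: Optional[str]
      finished: bool


    def emit_transcriptions(chunks: Iterator[Chunk]) -> Iterator[tuple[str, bool]]:
      """Yields (text, is_partial): each fragment first, then the full text."""
      accumulated = ''
      for chunk in chunks:
        if chunk.text:
          accumulated += chunk.text
          yield chunk.text, True  # partial fragment, streamed immediately
        # A fragment and finished=True may arrive in the same message,
        # so this is a second `if`, not an `elif`.
        if chunk.finished:
          yield accumulated, False  # complete transcription
          accumulated = ''


    # Prints [('Hello', True), (' world', True), ('Hello world', False)]
    print(list(emit_transcriptions(
        iter([Chunk('Hello', False), Chunk(' world', False), Chunk(None, True)])
    )))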

File tree

5 files changed: +278 -22 lines changed

  src/google/adk/flows/llm_flows/base_llm_flow.py
  src/google/adk/models/gemini_llm_connection.py
  src/google/adk/runners.py
  tests/unittests/models/test_gemini_llm_connection.py
  tests/unittests/test_runners.py

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 8 additions & 10 deletions

@@ -587,21 +587,19 @@ async def _postprocess_live(
 
     # Handle transcription events ONCE per llm_response, outside the event loop
     if llm_response.input_transcription:
-      input_transcription_event = (
-          await self.transcription_manager.handle_input_transcription(
-              invocation_context, llm_response.input_transcription
-          )
+      model_response_event.input_transcription = (
+          llm_response.input_transcription
       )
-      yield input_transcription_event
+      model_response_event.partial = llm_response.partial
+      yield model_response_event
       return
 
     if llm_response.output_transcription:
-      output_transcription_event = (
-          await self.transcription_manager.handle_output_transcription(
-              invocation_context, llm_response.output_transcription
-          )
+      model_response_event.output_transcription = (
+          llm_response.output_transcription
       )
-      yield output_transcription_event
+      model_response_event.partial = llm_response.partial
+      yield model_response_event
       return
 
     # Flush audio caches based on control events using configurable settings

src/google/adk/models/gemini_llm_connection.py

Lines changed: 50 additions & 12 deletions

@@ -35,6 +35,8 @@ class GeminiLlmConnection(BaseLlmConnection):
 
   def __init__(self, gemini_session: live.AsyncSession):
     self._gemini_session = gemini_session
+    self._input_transcription_text: str = ''
+    self._output_transcription_text: str = ''
 
   async def send_history(self, history: list[types.Content]):
     """Sends the conversation history to the gemini model.
@@ -166,15 +168,49 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             text = ''
           yield llm_response
        if message.server_content.input_transcription:
-          llm_response = LlmResponse(
-              input_transcription=message.server_content.input_transcription,
-          )
-          yield llm_response
+          if message.server_content.input_transcription.text:
+            self._input_transcription_text += (
+                message.server_content.input_transcription.text
+            )
+            yield LlmResponse(
+                input_transcription=types.Transcription(
+                    text=message.server_content.input_transcription.text,
+                    finished=False,
+                ),
+                partial=True,
+            )
+          # finished=True and partial transcription may happen in the same
+          # message.
+          if message.server_content.input_transcription.finished:
+            yield LlmResponse(
+                input_transcription=types.Transcription(
+                    text=self._input_transcription_text,
+                    finished=True,
+                ),
+                partial=False,
+            )
+            self._input_transcription_text = ''
        if message.server_content.output_transcription:
-          llm_response = LlmResponse(
-              output_transcription=message.server_content.output_transcription
-          )
-          yield llm_response
+          if message.server_content.output_transcription.text:
+            self._output_transcription_text += (
+                message.server_content.output_transcription.text
+            )
+            yield LlmResponse(
+                output_transcription=types.Transcription(
+                    text=message.server_content.output_transcription.text,
+                    finished=False,
+                ),
+                partial=True,
+            )
+          if message.server_content.output_transcription.finished:
+            yield LlmResponse(
+                output_transcription=types.Transcription(
+                    text=self._output_transcription_text,
+                    finished=True,
+                ),
+                partial=False,
+            )
+            self._output_transcription_text = ''
        if message.server_content.turn_complete:
          if text:
            yield self.__build_full_text_response(text)
@@ -188,10 +224,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
        # in case it's an interrupted message, we merge the previous partial
        # text. Other we don't merge. because content can be none when model
        # safety threshold is triggered
-        if message.server_content.interrupted and text:
-          yield self.__build_full_text_response(text)
-          text = ''
-        yield LlmResponse(interrupted=message.server_content.interrupted)
+        if message.server_content.interrupted:
+          if text:
+            yield self.__build_full_text_response(text)
+            text = ''
+          else:
+            yield LlmResponse(interrupted=message.server_content.interrupted)
      if message.tool_call:
        if text:
          yield self.__build_full_text_response(text)
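
With this change, callers of `receive()` see one `LlmResponse` per fragment (with `partial=True`) followed by a single complete response (with `partial=False`). A hedged sketch of how a consumer might use that stream, assuming an already-connected `GeminiLlmConnection`; `print_live_transcript` is a hypothetical helper, not part of this diff:

    # Hypothetical consumer; `connection` stands in for a GeminiLlmConnection.
    async def print_live_transcript(connection) -> None:
      async for llm_response in connection.receive():
        transcription = (
            llm_response.input_transcription or llm_response.output_transcription
        )
        if transcription is None:
          continue  # not a transcription event
        if llm_response.partial:
          # Streaming fragment: render incrementally, don't persist.
          print(transcription.text, end='', flush=True)
        else:
          # finished=True: full accumulated text, safe to store in history.
          print(f'\n[final] {transcription.text}')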

src/google/adk/runners.py

Lines changed: 4 additions & 0 deletions

@@ -588,6 +588,10 @@ def _should_append_event(self, event: Event, is_live_call: bool) -> bool:
     # Don't append audio response from model in live mode to session.
     # The data is appended to artifacts with a reference in file_data in the
     # event.
+    # We should append non-partial events only. For example, non-finished
+    # (partial) transcription events should not be appended.
+    # Function call and function response events should be appended.
+    # Other control events should be appended.
     if is_live_call and contents._is_live_model_audio_event(event):
       return False
     return True
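
Note the comments above document a policy more than they implement it: the committed check still only filters live model audio events, and the tests below confirm that unfinished transcription events currently return `True`. If one were to enforce the stated policy directly, a hypothetical predicate might look like the following, assuming `Event` exposes `input_transcription`/`output_transcription` as the tests do; this is an illustration, not the shipped code:

    def should_persist(event) -> bool:
      """Illustrative filter for the documented policy; not the shipped code."""
      for transcription in (event.input_transcription, event.output_transcription):
        # Drop partial (non-finished) transcription fragments from history.
        if transcription is not None and not transcription.finished:
          return False
      # Finished transcriptions, function calls/responses, and other control
      # events are kept.
      return True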

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 132 additions & 0 deletions

@@ -219,3 +219,135 @@ async def mock_receive_generator():
   )
   assert usage_response.usage_metadata == expected_usage
   assert content_response.content == mock_content
+
+
+@pytest.mark.asyncio
+async def test_receive_handles_input_transcription_fragments(
+    gemini_connection, mock_gemini_session
+):
+  """Test receive handles input transcription fragments correctly."""
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock.Mock()
+  message1.server_content.model_turn = None
+  message1.server_content.interrupted = False
+  message1.server_content.input_transcription = types.Transcription(
+      text='Hello', finished=False
+  )
+  message1.server_content.output_transcription = None
+  message1.server_content.turn_complete = False
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock.Mock()
+  message2.server_content.model_turn = None
+  message2.server_content.interrupted = False
+  message2.server_content.input_transcription = types.Transcription(
+      text=' world', finished=False
+  )
+  message2.server_content.output_transcription = None
+  message2.server_content.turn_complete = False
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock.Mock()
+  message3.server_content.model_turn = None
+  message3.server_content.interrupted = False
+  message3.server_content.input_transcription = types.Transcription(
+      text=None, finished=True
+  )
+  message3.server_content.output_transcription = None
+  message3.server_content.turn_complete = False
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 3
+  assert responses[0].input_transcription.text == 'Hello'
+  assert responses[0].input_transcription.finished is False
+  assert responses[0].partial is True
+  assert responses[1].input_transcription.text == ' world'
+  assert responses[1].input_transcription.finished is False
+  assert responses[1].partial is True
+  assert responses[2].input_transcription.text == 'Hello world'
+  assert responses[2].input_transcription.finished is True
+  assert responses[2].partial is False
+
+
+@pytest.mark.asyncio
+async def test_receive_handles_output_transcription_fragments(
+    gemini_connection, mock_gemini_session
+):
+  """Test receive handles output transcription fragments correctly."""
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock.Mock()
+  message1.server_content.model_turn = None
+  message1.server_content.interrupted = False
+  message1.server_content.input_transcription = None
+  message1.server_content.output_transcription = types.Transcription(
+      text='How can', finished=False
+  )
+  message1.server_content.turn_complete = False
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock.Mock()
+  message2.server_content.model_turn = None
+  message2.server_content.interrupted = False
+  message2.server_content.input_transcription = None
+  message2.server_content.output_transcription = types.Transcription(
+      text=' I help?', finished=False
+  )
+  message2.server_content.turn_complete = False
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock.Mock()
+  message3.server_content.model_turn = None
+  message3.server_content.interrupted = False
+  message3.server_content.input_transcription = None
+  message3.server_content.output_transcription = types.Transcription(
+      text=None, finished=True
+  )
+  message3.server_content.turn_complete = False
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 3
+  assert responses[0].output_transcription.text == 'How can'
+  assert responses[0].output_transcription.finished is False
+  assert responses[0].partial is True
+  assert responses[1].output_transcription.text == ' I help?'
+  assert responses[1].output_transcription.finished is False
+  assert responses[1].partial is True
+  assert responses[2].output_transcription.text == 'How can I help?'
+  assert responses[2].output_transcription.finished is True
+  assert responses[2].partial is False

tests/unittests/test_runners.py

Lines changed: 84 additions & 0 deletions

@@ -775,5 +775,89 @@ def test_runner_realistic_cache_config_scenario(self):
     assert str(runner.context_cache_config) == expected_str
 
 
+class TestRunnerShouldAppendEvent:
+  """Tests for Runner._should_append_event method."""
+
+  def setup_method(self):
+    """Set up test fixtures."""
+    self.session_service = InMemorySessionService()
+    self.artifact_service = InMemoryArtifactService()
+    self.root_agent = MockLlmAgent("root_agent")
+    self.runner = Runner(
+        app_name="test_app",
+        agent=self.root_agent,
+        session_service=self.session_service,
+        artifact_service=self.artifact_service,
+    )
+
+  def test_should_append_event_finished_input_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="user",
+        input_transcription=types.Transcription(text="hello", finished=True),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_unfinished_input_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="user",
+        input_transcription=types.Transcription(text="hello", finished=False),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_finished_output_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        output_transcription=types.Transcription(text="world", finished=True),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_append_event_unfinished_output_transcription(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        output_transcription=types.Transcription(text="world", finished=False),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+  def test_should_not_append_event_live_model_audio(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(
+            parts=[
+                types.Part(
+                    inline_data=types.Blob(data=b"123", mime_type="audio/pcm")
+                )
+            ]
+        ),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is False
+
+  def test_should_append_event_non_live_model_audio(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(
+            parts=[
+                types.Part(
+                    inline_data=types.Blob(data=b"123", mime_type="audio/pcm")
+                )
+            ]
+        ),
+    )
+    assert self.runner._should_append_event(event, is_live_call=False) is True
+
+  def test_should_append_event_other_event(self):
+    event = Event(
+        invocation_id="inv1",
+        author="model",
+        content=types.Content(parts=[types.Part(text="text")]),
+    )
+    assert self.runner._should_append_event(event, is_live_call=True) is True
+
+
 if __name__ == "__main__":
   pytest.main([__file__])
