4 changes: 2 additions & 2 deletions examples/classification.py
@@ -96,7 +96,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
gm.data.FormatText(
key=_INPUT_FIELD,
template="""<start_of_turn>user
-Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
+Please classify whether the following sentence is grammatically correct, please answer only with Yes or No.
Sentence: {text}<end_of_turn>
<start_of_turn>model""",
),
@@ -113,7 +113,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
gm.data.MapInts(
key=_LABEL_FIELD,
# Rather than predicting the token 0 and 1, we are using the
-# token 1294 and 3553 which respectivelly correspond to "No" and
+# token 1294 and 3553 which respectively correspond to "No" and
# "Yes". We do this because those token already contain semantic
# information, so even zero-shot prediction without any
# finetuning has better than random performances.
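The comment in the second hunk documents a neat trick: rather than having the model predict the literal tokens 0 and 1, the labels are remapped to the vocabulary ids of "No" (1294) and "Yes" (3553), whose embeddings already carry the right semantics. A minimal sketch of the idea, taking those ids from the comment; the helper below is illustrative, not part of the gm.data API:

# Illustrative only: map integer class labels to the token ids of
# semantically meaningful words (ids taken from the comment above).
_LABEL_TO_TOKEN = {0: 1294, 1: 3553}  # 0 -> "No", 1 -> "Yes"

def remap_label(label: int) -> int:
  # Mirrors what gm.data.MapInts is configured to do in this pipeline.
  return _LABEL_TO_TOKEN[label]

# At inference, classification then reduces to comparing two logits at
# the last position: predict "Yes" iff logits[3553] > logits[1294].
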
2 changes: 1 addition & 1 deletion gemma/gm/nn/_transformer.py
@@ -175,7 +175,7 @@ def vision_encoder(self) -> gemma_vision.SigLiPFromPatches | None:
'return_hidden_states',
),
)
-# The function accepts/returns aribtrary batch shape, but inside the
+# The function accepts/returns arbitrary batch shape, but inside the
# function, the batch dimension is flattened to a single dimension.
@_jax_utils.flatten_unflatten_batch_dim()
@typechecked
2 changes: 1 addition & 1 deletion gemma/gm/nn/gemma3n/_transformer.py
@@ -218,7 +218,7 @@ def vision_encoder(self) -> gemma_vision.SigLiPFromPatches | None:
'return_hidden_states',
),
)
-# The function accepts/returns aribtrary batch shape, but inside the
+# The function accepts/returns arbitrary batch shape, but inside the
# function, the batch dimension is flattened to a single dimension.
@_jax_utils.flatten_unflatten_batch_dim()
@typechecked
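Both _transformer.py hunks above fix the same comment, which describes a common JAX pattern: the public function accepts any leading batch shape, but internally the batch dimensions are collapsed to a single axis and restored on the way out. A simplified, hypothetical version of such a decorator (the real _jax_utils.flatten_unflatten_batch_dim also handles pytrees of arguments), assuming a single array input whose trailing axes hold per-example features:

import functools

import jax.numpy as jnp


def flatten_unflatten_batch_dim(num_feature_dims: int = 1):
  """Sketch: flatten leading batch dims to one axis, restore them after."""

  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(x, *args, **kwargs):
      batch_shape = x.shape[: x.ndim - num_feature_dims]
      # Collapse (*batch_shape, *features) -> (flat_batch, *features).
      flat = jnp.reshape(x, (-1, *x.shape[x.ndim - num_feature_dims:]))
      out = fn(flat, *args, **kwargs)
      # Restore the original leading batch shape on the output.
      return jnp.reshape(out, (*batch_shape, *out.shape[1:]))

    return wrapper

  return decorator

With this shape contract, the transformer body only ever sees inputs of shape [B, ...], while callers may pass, say, [B1, B2, seq_len] token arrays.
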
2 changes: 1 addition & 1 deletion gemma/gm/text/_chat_sampler.py
@@ -68,7 +68,7 @@ class ChatSampler:
conversation can have (prompts, answers, images for all turns). Setting
this to a fixed value avoids re-compilation between turns.
max_out_length: Length of the output buffer for a single turn. Static value
-used to avoid trigering a jit recompilation. Shouldn't be changed unless
+used to avoid triggering a jit recompilation. Shouldn't be changed unless
you have a task where the model generates really long outputs.
last_state: Last state of the sampler, automatically handled by the sampler,
but exposed for power users to access the logits, cache, ... or initialize
2 changes: 1 addition & 1 deletion gemma/gm/text/_sampler.py
@@ -108,7 +108,7 @@ class Sampler:
conversation can have (prompts, answers, images for all turns). Setting
this to a fixed value avoids re-compilation between turns.
max_out_length: Length of the output buffer for a single turn. Static value
-used to avoid trigering a jit recompilation. Shouldn't be changed unless
+used to avoid triggering a jit recompilation. Shouldn't be changed unless
you have a task where the model generates really long outputs.
pad_length: If provided, pad the prompt to this length. This ensure the
prompt is always the same length, to avoid jit re-compilation.
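The two sampler docstrings above make the same point: max_out_length (and, for Sampler, pad_length) are static shapes, so keeping them fixed lets jit reuse one compiled program across calls. An illustrative usage sketch; the model/params keywords and the values are assumptions, not copied from the library's docs:

# Illustrative only: pin the static buffer sizes once, then sample
# repeatedly without triggering recompilation.
sampler = gm.text.Sampler(
    model=model,    # assumed setup, elided here
    params=params,  # assumed setup, elided here
    max_out_length=512,  # output buffer; changing it forces a recompile
    pad_length=256,      # fixed prompt length, for the same reason
)
out = sampler.sample("Sentence: The cat sat on the mat.")
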
2 changes: 1 addition & 1 deletion gemma/gm/vision/_token_utils.py
@@ -219,7 +219,7 @@ def _get_new_mm_tokens(
offset_by: int,
length_with_mm: int,
) -> Int['B max_num_images num_tokens_per_image+4']:
-# Jax vmap does not support positional argiments, so need the
+# Jax vmap does not support positional arguments, so need the
# _get_new_mm_tokens_inner indirection.
return jax.vmap(
_get_new_mm_tokens_inner, in_axes=(0, None, None, None, None)
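The comment fixed in this last hunk points at a standard vmap idiom: in_axes=(0, None, None, None, None) maps over the batch axis of the first argument only and broadcasts the rest, which requires a helper with a plain positional signature, hence the _get_new_mm_tokens_inner indirection. A self-contained toy version (the inner function below is a hypothetical stand-in, not the real one):

import jax
import jax.numpy as jnp


def _inner(tokens, offset, length):
  # Hypothetical stand-in: shift one row of token ids by a shared offset
  # and keep a shared, static number of them.
  return (tokens + offset)[:length]


# Map over axis 0 of `tokens` only; `offset` and `length` are passed
# through unchanged to every row.
batched = jax.vmap(_inner, in_axes=(0, None, None))

tokens = jnp.arange(12).reshape(3, 4)  # [batch=3, num_tokens=4]
out = batched(tokens, 100, 3)          # -> shape [3, 3]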