diff --git a/examples/classification.py b/examples/classification.py
index c9e9f5bf..c0b79926 100644
--- a/examples/classification.py
+++ b/examples/classification.py
@@ -96,7 +96,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
       gm.data.FormatText(
           key=_INPUT_FIELD,
           template="""<start_of_turn>user
-          Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No.
+          Please classify whether the following sentence is grammatically correct, please answer only with Yes or No.
           Sentence: {text}<end_of_turn>
           <start_of_turn>model""",
       ),
@@ -113,7 +113,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline:
       gm.data.MapInts(
           key=_LABEL_FIELD,
           # Rather than predicting the token 0 and 1, we are using the
-          # token 1294 and 3553 which respectivelly correspond to "No" and
+          # token 1294 and 3553 which respectively correspond to "No" and
           # "Yes". We do this because those token already contain semantic
           # information, so even zero-shot prediction without any
           # finetuning has better than random performances.
diff --git a/gemma/gm/nn/_transformer.py b/gemma/gm/nn/_transformer.py
index 4c158bbd..c2ccad1c 100644
--- a/gemma/gm/nn/_transformer.py
+++ b/gemma/gm/nn/_transformer.py
@@ -175,7 +175,7 @@ def vision_encoder(self) -> gemma_vision.SigLiPFromPatches | None:
           'return_hidden_states',
       ),
   )
-  # The function accepts/returns aribtrary batch shape, but inside the
+  # The function accepts/returns arbitrary batch shape, but inside the
   # function, the batch dimension is flattened to a single dimension.
   @_jax_utils.flatten_unflatten_batch_dim()
   @typechecked
diff --git a/gemma/gm/nn/gemma3n/_transformer.py b/gemma/gm/nn/gemma3n/_transformer.py
index 4940cc78..701102b5 100644
--- a/gemma/gm/nn/gemma3n/_transformer.py
+++ b/gemma/gm/nn/gemma3n/_transformer.py
@@ -218,7 +218,7 @@ def vision_encoder(self) -> gemma_vision.SigLiPFromPatches | None:
           'return_hidden_states',
       ),
   )
-  # The function accepts/returns aribtrary batch shape, but inside the
+  # The function accepts/returns arbitrary batch shape, but inside the
   # function, the batch dimension is flattened to a single dimension.
   @_jax_utils.flatten_unflatten_batch_dim()
   @typechecked
diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py
index 09db1bc9..198ac447 100644
--- a/gemma/gm/text/_chat_sampler.py
+++ b/gemma/gm/text/_chat_sampler.py
@@ -68,7 +68,7 @@ class ChatSampler:
       conversation can have (prompts, answers, images for all turns). Setting
       this to a fixed value avoids re-compilation between turns.
     max_out_length: Length of the output buffer for a single turn. Static value
-      used to avoid trigering a jit recompilation. Shouldn't be changed unless
+      used to avoid triggering a jit recompilation. Shouldn't be changed unless
       you have a task where the model generates really long outputs.
     last_state: Last state of the sampler, automatically handled by the sampler,
       but exposed for power users to access the logits, cache, ... or initialize
diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py
index 7c941a93..a31f3a8d 100644
--- a/gemma/gm/text/_sampler.py
+++ b/gemma/gm/text/_sampler.py
@@ -108,7 +108,7 @@ class Sampler:
       conversation can have (prompts, answers, images for all turns). Setting
       this to a fixed value avoids re-compilation between turns.
     max_out_length: Length of the output buffer for a single turn. Static value
-      used to avoid trigering a jit recompilation. Shouldn't be changed unless
+      used to avoid triggering a jit recompilation. Shouldn't be changed unless
       you have a task where the model generates really long outputs.
     pad_length: If provided, pad the prompt to this length. This ensure the
       prompt is always the same length, to avoid jit re-compilation.
diff --git a/gemma/gm/vision/_token_utils.py b/gemma/gm/vision/_token_utils.py
index 74170a16..adfe7693 100644
--- a/gemma/gm/vision/_token_utils.py
+++ b/gemma/gm/vision/_token_utils.py
@@ -219,7 +219,7 @@ def _get_new_mm_tokens(
     offset_by: int,
     length_with_mm: int,
 ) -> Int['B max_num_images num_tokens_per_image+4']:
-  # Jax vmap does not support positional argiments, so need the
+  # Jax vmap does not support positional arguments, so need the
   # _get_new_mm_tokens_inner indirection.
   return jax.vmap(
       _get_new_mm_tokens_inner, in_axes=(0, None, None, None, None)
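For context on the `gm.data.MapInts` hunk: the example remaps the integer labels 0/1 to vocabulary ids 1294 ("No") and 3553 ("Yes"), so classification reduces to comparing the model's next-token logits at two semantically loaded token ids. A minimal sketch of that decision rule, assuming a `[batch, vocab]` logits array at the final position; the function name is illustrative and not from the codebase:

```python
import jax.numpy as jnp

NO_TOKEN = 1294   # "No", per the comment in the hunk above
YES_TOKEN = 3553  # "Yes", per the comment in the hunk above

def classify_from_logits(logits: jnp.ndarray) -> jnp.ndarray:
    # logits: [batch, vocab] next-token logits at the last prompt position.
    # Restricting the comparison to the "No"/"Yes" ids turns the LM head
    # into a two-way classifier; because both tokens already carry semantic
    # meaning, this beats chance even before any finetuning.
    two_way = logits[:, jnp.array([NO_TOKEN, YES_TOKEN])]  # [batch, 2]
    return jnp.argmax(two_way, axis=-1)  # 0 -> "No", 1 -> "Yes"
```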
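The `flatten_unflatten_batch_dim` decorator named in both transformer hunks lets the wrapped method accept an arbitrary leading batch shape while its body only ever sees one flat batch dimension. A simplified single-array sketch of the pattern, assuming the method takes a `[..., L]` token array; the real `_jax_utils` helper presumably handles pytrees and multiple arguments:

```python
import functools
import jax.numpy as jnp

def flatten_unflatten_batch_dim():
    # Simplified stand-in for the real _jax_utils decorator, for illustration.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, tokens: jnp.ndarray, *args, **kwargs):
            batch_shape = tokens.shape[:-1]              # e.g. (B1, B2) or ()
            flat = tokens.reshape(-1, tokens.shape[-1])  # (B1*B2, L)
            out = fn(self, flat, *args, **kwargs)
            # Restore the caller's batch shape on the way out. An unbatched
            # (L,) input becomes (1, L) inside and is squeezed back here.
            return out.reshape(*batch_shape, *out.shape[1:])
        return wrapper
    return decorator
```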
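The `max_out_length` and `pad_length` docstrings touched above both exist for the same reason: `jax.jit` specializes the compiled program on input shapes, so a differently sized output buffer or prompt forces a retrace and recompile. A generic illustration of that behavior (not sampler code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_output(buffer: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for a decoding step; only the buffer's shape matters here.
    return buffer.at[0].set(1)

fill_output(jnp.zeros(512))   # first call: traces and compiles for shape (512,)
fill_output(jnp.zeros(512))   # same shape: reuses the compiled program
fill_output(jnp.zeros(1024))  # new shape: triggers a fresh trace/compile
```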
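On the `_token_utils.py` hunk: `jax.vmap`'s `in_axes` can only describe positional parameters (keyword arguments cannot be given per-argument axes), which is what the positional `_get_new_mm_tokens_inner` indirection works around; the fixed comment's phrasing aside, that is the constraint in play. A self-contained sketch of the same pattern with illustrative names:

```python
import jax
import jax.numpy as jnp

def _scale_row_inner(row, offset, scale):
    # Positional-only core: each parameter lines up with one in_axes entry.
    return row * scale + offset

def scale_rows(rows, *, offset, scale):
    # Keyword-friendly wrapper. in_axes=(0, None, None) maps over the rows'
    # leading axis and broadcasts the two scalars; since per-argument axes
    # can only be declared for positional parameters, the computation lives
    # in the positional inner function, mirroring _token_utils.py.
    return jax.vmap(_scale_row_inner, in_axes=(0, None, None))(
        rows, offset, scale
    )

print(scale_rows(jnp.arange(6.0).reshape(2, 3), offset=1.0, scale=2.0))
```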