Skip to content

Commit 75bb03e

Browse files
committed
Add backend specific TPU tests for DistributedEmbedding.
Under `keras_rs/src/layers/embedding`.
1 parent 5713afe commit 75bb03e

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
run: python3 -c "import jax; print('JAX devices:', jax.devices())"
9191

9292
- name: Test with pytest
93-
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
93+
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py keras_rs/src/layers/embedding/${{ matrix.backend }}
9494

9595
check_format:
9696
name: Check the code format

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""JAX implementation of the TPU embedding layer."""
22

3+
import collections
34
import math
45
import typing
56
from typing import Any, Mapping, Sequence, Union
@@ -445,6 +446,33 @@ def sparsecore_build(
445446
table_specs = embedding.get_table_specs(feature_specs)
446447
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
447448

449+
# Compute max ids for stacked tables
450+
stack_max_ids_per_partition: dict[str, int] = collections.defaultdict(
451+
int
452+
)
453+
stack_max_unique_ids_per_partition: dict[str, int] = (
454+
collections.defaultdict(int)
455+
)
456+
457+
for stack_name, table_specs in table_stacks.items():
458+
for table_spec in table_specs:
459+
stack_max_ids_per_partition[stack_name] += (
460+
table_spec.max_ids_per_partition
461+
)
462+
stack_max_unique_ids_per_partition[stack_name] += (
463+
table_spec.max_unique_ids_per_partition
464+
)
465+
466+
# stack name -> StackedTableSpec
467+
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
468+
for stack_name, stacked_table_spec in stacked_table_specs.items():
469+
stacked_table_spec.max_ids_per_partition = (
470+
stack_max_ids_per_partition[stack_name]
471+
)
472+
stacked_table_spec.max_unique_ids_per_partition = (
473+
stack_max_unique_ids_per_partition[stack_name]
474+
)
475+
448476
# Create variables for all stacked tables and slot variables.
449477
with sparsecore_distribution.scope():
450478
self._table_and_slot_variables = {

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,10 @@ def test_call(
327327
jit: bool,
328328
):
329329
table_configs = keras_test_utils.create_random_table_configs(
330-
combiner=combiner, seed=10
330+
combiner=combiner,
331+
seed=10,
332+
max_ids_per_partition=512,
333+
max_unique_ids_per_partition=512,
331334
)
332335
feature_configs = keras_test_utils.create_random_feature_configs(
333336
table_configs=table_configs, seed=20

0 commit comments

Comments
 (0)