Skip to content

Commit d13c207

Browse files
ez96The Meridian Authors
authored andcommitted
Add ranked_geos property to InputData.
PiperOrigin-RevId: 822204516
1 parent c005d8f commit d13c207

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

meridian/data/input_data.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,45 @@ def media_time(self) -> xr.DataArray:
363363
else:
364364
return self.reach[constants.MEDIA_TIME]
365365

366+
@functools.cached_property
367+
def ranked_geos(self) -> list[str]:
368+
"""Ranks geos by total spend then by total KPI."""
369+
n_geos = len(self.geo)
370+
if n_geos == 1:
371+
return [self.geo[0]]
372+
else:
373+
if self.media_spend is not None and self.rf_spend is not None:
374+
total_spend = (
375+
self.media_spend.to_dataframe()
376+
.unstack(level="media_channel")
377+
.merge(
378+
self.rf_spend.to_dataframe().unstack(level="rf_channel"),
379+
on=[constants.GEO, constants.TIME],
380+
how="inner",
381+
)
382+
)
383+
elif self.media_spend is not None:
384+
total_spend = self.media_spend.to_dataframe()
385+
elif self.rf_spend is not None:
386+
total_spend = self.rf_spend.to_dataframe()
387+
else:
388+
raise ValueError(
389+
"It is required to have at least one of media or reach + frequency."
390+
)
391+
392+
df_spend_sum = (
393+
total_spend.sum(axis=1).rename("spend", inplace=True).to_frame()
394+
)
395+
396+
return (
397+
df_spend_sum.merge(
398+
self.kpi.to_dataframe(), on=["geo", "time"], how="inner"
399+
)
400+
.groupby(constants.GEO)
401+
.sum()
402+
.sort_values(by=["spend", "kpi"], ascending=False)
403+
).index.values.tolist()
404+
366405
@functools.cached_property
367406
def media_time_coordinates(self) -> tc.TimeCoordinates:
368407
"""Returns the media time dimension in a `TimeCoordinates` wrapper."""

meridian/data/input_data_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,82 @@ def test_scaled_centered_kpi_supports_dtype_int(self):
16421642
data.population = data.population.astype(int)
16431643
self.assertNotEmpty(data.scaled_centered_kpi)
16441644

1645+
def test_rank_geos_only_media_spend(self):
1646+
data = input_data.InputData(
1647+
kpi=self.not_lagged_kpi,
1648+
kpi_type=constants.NON_REVENUE,
1649+
population=self.population,
1650+
media=self.lagged_media,
1651+
media_spend=self.media_spend,
1652+
)
1653+
self.assertListEqual(
1654+
data.ranked_geos,
1655+
[
1656+
"geo_1",
1657+
"geo_5",
1658+
"geo_4",
1659+
"geo_6",
1660+
"geo_7",
1661+
"geo_8",
1662+
"geo_2",
1663+
"geo_3",
1664+
"geo_9",
1665+
"geo_0",
1666+
],
1667+
)
1668+
1669+
def test_rank_geos_only_rf_spend(self):
1670+
data = input_data.InputData(
1671+
kpi=self.not_lagged_kpi,
1672+
kpi_type=constants.NON_REVENUE,
1673+
population=self.population,
1674+
reach=self.lagged_reach,
1675+
frequency=self.lagged_frequency,
1676+
rf_spend=self.rf_spend,
1677+
)
1678+
self.assertListEqual(
1679+
data.ranked_geos,
1680+
[
1681+
"geo_3",
1682+
"geo_5",
1683+
"geo_0",
1684+
"geo_4",
1685+
"geo_8",
1686+
"geo_6",
1687+
"geo_7",
1688+
"geo_1",
1689+
"geo_9",
1690+
"geo_2",
1691+
],
1692+
)
1693+
1694+
def test_rank_geos_media_spend_and_rf_spend(self):
1695+
data = input_data.InputData(
1696+
kpi=self.not_lagged_kpi,
1697+
kpi_type=constants.NON_REVENUE,
1698+
population=self.population,
1699+
media=self.lagged_media,
1700+
media_spend=self.media_spend,
1701+
reach=self.lagged_reach,
1702+
frequency=self.lagged_frequency,
1703+
rf_spend=self.rf_spend,
1704+
)
1705+
self.assertListEqual(
1706+
data.ranked_geos,
1707+
[
1708+
"geo_5",
1709+
"geo_4",
1710+
"geo_1",
1711+
"geo_6",
1712+
"geo_8",
1713+
"geo_3",
1714+
"geo_7",
1715+
"geo_0",
1716+
"geo_2",
1717+
"geo_9",
1718+
],
1719+
)
1720+
16451721

16461722
class NonpaidInputDataTest(parameterized.TestCase):
16471723
"""Tests for non-paid InputData."""

0 commit comments

Comments
 (0)