Skip to content

Commit 9755d3b

Browse files
iraurruariwalkerwilliambdean
authored
Add root saturation function (issue #702) (#858)
* feat: adding root_saturation to transformers.py * feat: adding RootSaturation class to saturation.py * chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS * feat: adding root_saturation to transformers.py * feat: adding RootSaturation class to saturation.py * chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS * chore: linting edits * chore: adding coefficient to function * chore: linting corrections * chore: removed empty References section of docstring * chore: produce visual examples of root saturation * chore: adding root to test_saturation.py * chore: adding RootSaturation to init file --------- Co-authored-by: ruari.walker <[email protected]> Co-authored-by: Will Dean <[email protected]>
1 parent 19aea61 commit 9755d3b

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

pymc_marketing/mmm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
InverseScaledLogisticSaturation,
2727
LogisticSaturation,
2828
MichaelisMentenSaturation,
29+
RootSaturation,
2930
SaturationTransformation,
3031
TanhSaturation,
3132
TanhSaturationBaselined,
@@ -51,6 +52,7 @@
5152
"MMMModelBuilder",
5253
"MichaelisMentenSaturation",
5354
"MonthlyFourier",
55+
"RootSaturation",
5456
"SaturationTransformation",
5557
"TanhSaturation",
5658
"TanhSaturationBaselined",

pymc_marketing/mmm/components/saturation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def function(self, x, b):
7979
inverse_scaled_logistic_saturation,
8080
logistic_saturation,
8181
michaelis_menten,
82+
root_saturation,
8283
tanh_saturation,
8384
tanh_saturation_baselined,
8485
)
@@ -369,6 +370,39 @@ class HillSaturation(SaturationTransformation):
369370
}
370371

371372

373+
class RootSaturation(SaturationTransformation):
374+
"""Wrapper around Root saturation function.
375+
376+
For more information, see :func:`pymc_marketing.mmm.transformers.root_saturation`.
377+
378+
.. plot::
379+
:context: close-figs
380+
381+
import matplotlib.pyplot as plt
382+
import numpy as np
383+
from pymc_marketing.mmm import RootSaturation
384+
385+
rng = np.random.default_rng(0)
386+
387+
saturation = RootSaturation()
388+
prior = saturation.sample_prior(random_seed=rng)
389+
curve = saturation.sample_curve(prior)
390+
saturation.plot_curve(curve, sample_kwargs={"rng": rng})
391+
plt.show()
392+
393+
"""
394+
395+
lookup_name = "root"
396+
397+
def function(self, x, alpha, beta):
398+
return beta * root_saturation(x, alpha)
399+
400+
default_priors = {
401+
"alpha": Prior("Beta", alpha=1, beta=2),
402+
"beta": Prior("Gamma", mu=1, sigma=1),
403+
}
404+
405+
372406
SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
373407
cls.lookup_name: cls
374408
for cls in [
@@ -378,6 +412,7 @@ class HillSaturation(SaturationTransformation):
378412
TanhSaturationBaselined,
379413
MichaelisMentenSaturation,
380414
HillSaturation,
415+
RootSaturation,
381416
]
382417
}
383418

pymc_marketing/mmm/transformers.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,3 +988,49 @@ def hill_saturation(
988988
The value of the Hill function for each input value of x.
989989
"""
990990
return sigma / (1 + pt.exp(-beta * (x - lam)))
991+
992+
993+
def root_saturation(
994+
x: pt.TensorLike,
995+
alpha: pt.TensorLike,
996+
) -> pt.TensorVariable:
997+
r"""Root saturation transformation.
998+
999+
.. math::
1000+
f(x) = x^{\alpha}
1001+
1002+
.. plot::
1003+
:context: close-figs
1004+
1005+
import matplotlib.pyplot as plt
1006+
import numpy as np
1007+
import arviz as az
1008+
from pymc_marketing.mmm.transformers import root_saturation
1009+
plt.style.use('arviz-darkgrid')
1010+
alpha = np.array([0.1, 0.3, 0.5, 0.7])
1011+
x = np.linspace(0, 5, 100)
1012+
ax = plt.subplot(111)
1013+
for a in alpha:
1014+
y = root_saturation(x, alpha=a)
1015+
plt.plot(x, y, label=f'alpha = {a}')
1016+
plt.xlabel('spend', fontsize=12)
1017+
plt.ylabel('f(spend)', fontsize=12)
1018+
box = ax.get_position()
1019+
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
1020+
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
1021+
plt.show()
1022+
1023+
Parameters
1024+
----------
1025+
x : tensor
1026+
Input tensor.
1027+
alpha : float
1028+
Exponent for the root transformation. Must be non-negative.
1029+
1030+
Returns
1031+
-------
1032+
tensor
1033+
Transformed tensor.
1034+
1035+
"""
1036+
return x**alpha

tests/mmm/components/test_saturation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
InverseScaledLogisticSaturation,
2626
LogisticSaturation,
2727
MichaelisMentenSaturation,
28+
RootSaturation,
2829
TanhSaturation,
2930
TanhSaturationBaselined,
3031
_get_saturation_function,
@@ -46,6 +47,7 @@ def saturation_functions():
4647
TanhSaturationBaselined(),
4748
MichaelisMentenSaturation(),
4849
HillSaturation(),
50+
RootSaturation(),
4951
]
5052

5153

@@ -101,6 +103,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
101103
("tanh_baselined", TanhSaturationBaselined),
102104
("michaelis_menten", MichaelisMentenSaturation),
103105
("hill", HillSaturation),
106+
("root", RootSaturation),
104107
],
105108
)
106109
def test_get_saturation_function(name, saturation_cls) -> None:

0 commit comments

Comments
 (0)