diff --git a/src/model_fit_api/app.py b/src/model_fit_api/app.py index f6e3b3b..348d6cc 100644 --- a/src/model_fit_api/app.py +++ b/src/model_fit_api/app.py @@ -64,7 +64,11 @@ class Result(BaseModel): def fit(data, name_model, ebv, redshift): dust = sncosmo.CCM89Dust() - model = sncosmo.Model(source=name_model, effects=[dust], effect_names=["mw"], effect_frames=["obs"]) + if name_model == "salt2": + source = sncosmo.get_source("salt2", version="2.4") + model = sncosmo.Model(source=source, effects=[dust], effect_names=["mw"], effect_frames=["obs"]) + else: + model = sncosmo.Model(source=name_model, effects=[dust], effect_names=["mw"], effect_frames=["obs"]) model.set(mwebv=ebv) summary, fitted_model = sncosmo.fit_lc( data, model, model.param_names, bounds={"z": (redshift[0], redshift[1])}