Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions containers/BasicTerm_ME_python/term_me_iterative_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0)
mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0)
model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0)
# model_point_table = model_point_table = model_point_table.iloc[[0]]
premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1])

class ModelPointsEqx(eqx.Module):
Expand All @@ -23,6 +24,7 @@ class ModelPointsEqx(eqx.Module):
def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1):
table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True)
table.sort_values(by="policy_id", inplace=True)
print(table)
self.premium_pp = jnp.round(jnp.array(np.tile(table["sum_assured"].to_numpy() * table["premium_rate"].to_numpy(), size_multiplier)),decimals=2)
self.duration_mth = jnp.array(jnp.tile(table["duration_mth"].to_numpy(), size_multiplier))
self.age_at_entry = jnp.array(jnp.tile(table["age_at_entry"].to_numpy(), size_multiplier))
Expand All @@ -39,6 +41,11 @@ class AssumptionsEqx(eqx.Module):

def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame):
self.disc_rate_ann = jnp.array(disc_rate_ann["zero_spot"].to_numpy())
# Get the shape of the original data from the "zero_spot" column.
zero_spot_shape = disc_rate_ann["zero_spot"].to_numpy().shape

# Create a JAX array of zeros with the same shape.
# self.disc_rate_ann = jnp.zeros(zero_spot_shape, dtype=jnp.float64)
self.mort_table = jnp.array(mort_table.to_numpy())
self.expense_acq = jnp.array(300)
self.expense_maint = jnp.array(60)
Expand Down Expand Up @@ -77,18 +84,45 @@ def iterative_core(ls: LoopState, _):
pols_if_at_BEF_NB = pols_if_at_BEF_MAT - pols_maturity
pols_new_biz = jnp.where(duration_month_t == 0, self.mp.policy_count, 0)
pols_if_at_BEF_DECR = pols_if_at_BEF_NB + pols_new_biz
mort_rate = self.assume.mort_table[age_t-18, jnp.clip(duration_t, a_max=5)]
mort_rate = self.assume.mort_table[age_t-18 - jnp.clip(duration_t, a_max=5), jnp.clip(duration_t, a_max=5)]
mort_rate_mth = 1 - (1 - mort_rate) ** (1/12)
pols_death = pols_if_at_BEF_DECR * mort_rate_mth
claims = self.mp.sum_assured * pols_death
premiums = self.mp.premium_pp * pols_if_at_BEF_DECR
commissions = (duration_t == 0) * premiums
discount = (1 + self.assume.disc_rate_ann[ls.t//12]) ** (-ls.t/12)
commissions = (duration_month_t == 0) * premiums
discount = (1 + self.assume.disc_rate_ann[ls.t//12+1]) ** (-ls.t/12)
inflation_factor = (1 + 0.01) ** (ls.t/12)
expenses = self.assume.expense_acq * pols_new_biz + pols_if_at_BEF_DECR * self.assume.expense_maint/12 * inflation_factor
lapse_rate = jnp.clip(0.1 - 0.02 * duration_t, a_min=0.02)
net_cf = premiums - claims - expenses - commissions
discounted_net_cf = jnp.sum(net_cf) * discount
undiscounted_net_cf = jnp.sum(net_cf)

# # Debug printing each variable
# jax.debug.print("duration_month_t = {}", duration_month_t)
# jax.debug.print("duration_t = {}", duration_t)
# jax.debug.print("age_t = {}", age_t)
# jax.debug.print("pols_if_init = {}", pols_if_init)
# jax.debug.print("pols_if_at_BEF_MAT = {}", pols_if_at_BEF_MAT)
# jax.debug.print("pols_maturity = {}", pols_maturity)
# jax.debug.print("pols_if_at_BEF_NB = {}", pols_if_at_BEF_NB)
# jax.debug.print("pols_new_biz = {}", pols_new_biz)
# jax.debug.print("pols_if_at_BEF_DECR = {}", pols_if_at_BEF_DECR)
# jax.debug.print("mort_rate = {}", mort_rate)
# jax.debug.print("mort_rate_mth = {}", mort_rate_mth)
# jax.debug.print("pols_death = {}", pols_death)
# jax.debug.print("claims = {}", claims)
# jax.debug.print("premiums = {}", premiums)
# jax.debug.print("commissions = {}", commissions)
# jax.debug.print("discount = {}", discount)
# jax.debug.print("inflation_factor = {}", inflation_factor)
# jax.debug.print("expenses = {}", expenses)
# jax.debug.print("lapse_rate = {}", lapse_rate)
# jax.debug.print("lapses = {}", (pols_if_at_BEF_DECR - pols_death) * (1 - (1 - lapse_rate) ** (1/12)))
# jax.debug.print("net_cf = {}", net_cf)
# jax.debug.print("undiscounted_net_cf = {}", undiscounted_net_cf)
# jax.debug.print("---")

discounted_net_cf = undiscounted_net_cf * discount
nxt_ls = LoopState(
t=ls.t+1,
tot = ls.tot + discounted_net_cf,
Expand All @@ -97,6 +131,7 @@ def iterative_core(ls: LoopState, _):
pols_if_at_BEF_DECR_prev=pols_if_at_BEF_DECR
)
return nxt_ls, None

return jax.lax.scan(iterative_core, self.init_ls, xs=None, length=277)[0].tot


Expand All @@ -112,6 +147,7 @@ def time_jax_func(mp, assume, func):
result = func(term_me).block_until_ready()
end = timeit.default_timer()
elapsed_time = end - start # Time in seconds
print(result)
return float(result), elapsed_time

def time_iterative_jax(multiplier: int):
Expand All @@ -124,4 +160,4 @@ def time_iterative_jax(multiplier: int):
print(f"{time_in_seconds=}")

if __name__ == "__main__":
time_iterative_jax(100)
time_iterative_jax(1)
138 changes: 138 additions & 0 deletions github-runners-benchmarks/Julia/GPU.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
using CUDA
using DataFrames
using XLSX
using BenchmarkTools

# Load assumption data
disc_rate_ann = DataFrame(XLSX.readtable("BasicTerm_ME/disc_rate_ann.xlsx", "Sheet1")...)
mort_table = DataFrame(XLSX.readtable("BasicTerm_ME/mort_table.xlsx", "Sheet1")...)
model_point_table = DataFrame(XLSX.readtable("BasicTerm_ME/model_point_table.xlsx", "Sheet1")...)
premium_table = DataFrame(XLSX.readtable("BasicTerm_ME/premium_table.xlsx", "Sheet1")...)

# Define model point struct
struct ModelPoints
premium_pp::CuArray{Float64}
duration_mth::CuArray{Int}
age_at_entry::CuArray{Int}
sum_assured::CuArray{Float64}
policy_count::CuArray{Float64}
policy_term::CuArray{Int}
max_proj_len::Int
end

function ModelPoints(model_point_table, premium_table; size_multiplier=1)
# Join and preprocess model point and premium data
mp_data = innerjoin(model_point_table, premium_table, on=[:age_at_entry, :policy_term])
sort!(mp_data, :policy_id)

# Initialize model point struct
premium_pp = CuArray(repeat(mp_data.sum_assured .* mp_data.premium_rate, size_multiplier))
duration_mth = CuArray(repeat(mp_data.duration_mth, size_multiplier))
age_at_entry = CuArray(repeat(mp_data.age_at_entry, size_multiplier))
sum_assured = CuArray(repeat(mp_data.sum_assured, size_multiplier))
policy_count = CuArray(repeat(mp_data.policy_count, size_multiplier))
policy_term = CuArray(repeat(mp_data.policy_term, size_multiplier))
max_proj_len = maximum(12 .* policy_term .- duration_mth) + 1

ModelPoints(premium_pp, duration_mth, age_at_entry, sum_assured, policy_count, policy_term, max_proj_len)
end

# Define assumptions struct
struct Assumptions
disc_rate_ann::CuArray{Float64}
mort_table::CuArray{Float64}
expense_acq::Float64
expense_maint::Float64
end

function Assumptions(disc_rate_ann, mort_table)
disc_rate_ann_cu = CuArray(disc_rate_ann.zero_spot)
mort_table_cu = CuArray(Matrix(mort_table))
expense_acq = 300.0
expense_maint = 60.0
Assumptions(disc_rate_ann_cu, mort_table_cu, expense_acq, expense_maint)
end

# Define projection loop state struct
struct LoopState
t::Int
tot::Float64
pols_lapse_prev::CuArray{Float64}
pols_death_prev::CuArray{Float64}
pols_if_at_BEF_DECR_prev::CuArray{Float64}
end

# Define main projection struct
struct TermME
mp::ModelPoints
assume::Assumptions
init_ls::LoopState
end

function TermME(mp, assume)
init_ls = LoopState(
0,
0.0,
CUDA.zeros(Float64, length(mp.duration_mth)),
CUDA.zeros(Float64, length(mp.duration_mth)),
CuArray((mp.duration_mth .> 0) .* mp.policy_count)
)
TermME(mp, assume, init_ls)
end

function run_term_ME(tm::TermME)
function iterative_core(ls, _)
duration_month_t = tm.mp.duration_mth .+ ls.t
duration_t = duration_month_t .÷ 12
age_t = tm.mp.age_at_entry .+ duration_t
pols_if_init = ls.pols_if_at_BEF_DECR_prev .- ls.pols_lapse_prev .- ls.pols_death_prev
pols_if_at_BEF_MAT = pols_if_init
pols_maturity = (duration_month_t .== tm.mp.policy_term .* 12) .* pols_if_at_BEF_MAT
pols_if_at_BEF_NB = pols_if_at_BEF_MAT .- pols_maturity
pols_new_biz = (duration_month_t .== 0) .* tm.mp.policy_count
pols_if_at_BEF_DECR = pols_if_at_BEF_NB .+ pols_new_biz
mort_rate = tm.assume.mort_table[age_t.-17, min.(duration_t, 5)]
mort_rate_mth = 1 .- (1 .- mort_rate) .^ (1 / 12)
pols_death = pols_if_at_BEF_DECR .* mort_rate_mth
claims = tm.mp.sum_assured .* pols_death
premiums = tm.mp.premium_pp .* pols_if_at_BEF_DECR
commissions = (duration_t .== 0) .* premiums
discount = (1 .+ tm.assume.disc_rate_ann[ls.t.÷12]) .^ (-ls.t ./ 12)
inflation_factor = (1 + 0.01) .^ (ls.t ./ 12)
expenses = tm.assume.expense_acq .* pols_new_biz .+ pols_if_at_BEF_DECR .* tm.assume.expense_maint ./ 12 .* inflation_factor
lapse_rate = max.(0.1 .- 0.02 .* duration_t, 0.02)
net_cf = premiums .- claims .- expenses .- commissions
discounted_net_cf = sum(net_cf .* discount)
nxt_ls = LoopState(
ls.t + 1,
ls.tot + discounted_net_cf,
(pols_if_at_BEF_DECR .- pols_death) .* (1 .- (1 .- lapse_rate) .^ (1 / 12)),
pols_death,
pols_if_at_BEF_DECR
)
return nxt_ls, nothing
end

result, _ = foldl(iterative_core, 1:tm.mp.max_proj_len, init=tm.init_ls)
result.tot
end

function time_term_ME(mp, assume)
tm = TermME(mp, assume)
run_term_ME(tm) # warmup
tot = @btime run_term_ME($tm)
tot, @elapsed run_term_ME(tm)
end

function main(multiplier)
mp = ModelPoints(model_point_table, premium_table, size_multiplier=multiplier)
assume = Assumptions(disc_rate_ann, mort_table)
tot, elapsed = time_term_ME(mp, assume)

println("CUDA.jl Term ME Model")
println("Number modelpoints: $(length(mp.duration_mth))")
println("Total: $tot")
println("Elapsed time: $elapsed seconds")
end

main(1000)
Loading
Loading