Skip to content
205 changes: 185 additions & 20 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,28 @@ Base.@nospecializeinfer function traced_type_inner(
return PT
end

function collect_tvars_in_type!(dependencies, @nospecialize(t))
if t isa TypeVar
push!(dependencies, t)
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return
return nothing

end
if t isa DataType
for p in t.parameters
collect_tvars_in_type!(dependencies, p)
end
elseif t isa Union
collect_tvars_in_type!(dependencies, t.a)
collect_tvars_in_type!(dependencies, t.b)
elseif t isa UnionAll
collect_tvars_in_type!(dependencies, t.var.lb)
collect_tvars_in_type!(dependencies, t.var.ub)
collect_tvars_in_type!(dependencies, t.body)
elseif t isa Core.TypeofVararg
collect_tvars_in_type!(dependencies, t.T)
collect_tvars_in_type!(dependencies, t.N)
end
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type),
seen,
Expand Down Expand Up @@ -710,11 +732,16 @@ Base.@nospecializeinfer function traced_type_inner(
return T
end

@debug "traced_type_inner: Processing type with field changes" T=T subTys=subTys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "traced_type_inner: Processing type with field changes" T=T subTys=subTys
@debug "traced_type_inner: Processing type with field changes" T = T subTys = subTys


wrapped_cpjrt_array = T <: AbstractArray && ancestor(T) <: ConcretePJRTArray
wrapped_cifrt_array = T <: AbstractArray && ancestor(T) <: ConcreteIFRTArray
wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray

@debug "wrapped flags" wrapped_cpjrt_array=wrapped_cpjrt_array wrapped_cifrt_array=wrapped_cifrt_array wrapped_tracedarray=wrapped_tracedarray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "wrapped flags" wrapped_cpjrt_array=wrapped_cpjrt_array wrapped_cifrt_array=wrapped_cifrt_array wrapped_tracedarray=wrapped_tracedarray
@debug "wrapped flags" wrapped_cpjrt_array = wrapped_cpjrt_array wrapped_cifrt_array =
wrapped_cifrt_array wrapped_tracedarray = wrapped_tracedarray


subParms = []
@debug "Tracing type parameters" num_params=length(T.parameters) T_parameters=T.parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Tracing type parameters" num_params=length(T.parameters) T_parameters=T.parameters
@debug "Tracing type parameters" num_params = length(T.parameters) T_parameters =
T.parameters

for (i, SST) in enumerate(T.parameters)
if wrapped_cpjrt_array && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(
Expand Down Expand Up @@ -746,30 +773,62 @@ Base.@nospecializeinfer function traced_type_inner(
end
end

@debug "Built subParms" subParms=subParms
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Built subParms" subParms=subParms
@debug "Built subParms" subParms = subParms


if !isempty(subParms)
TT2 = Core.apply_type(T.name.wrapper, subParms...)
@debug "Calling apply_type_with_promotion" wrapper=T.name.wrapper subParms=subParms num_params=length(T.parameters)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Calling apply_type_with_promotion" wrapper=T.name.wrapper subParms=subParms num_params=length(T.parameters)
@debug "Calling apply_type_with_promotion" wrapper = T.name.wrapper subParms =
subParms num_params = length(T.parameters)

TT2, changed_params = apply_type_with_promotion(T.name.wrapper, subParms)
@debug "apply_type_with_promotion succeeded" TT2=TT2 result_fieldcount=fieldcount(TT2) changed_params=changed_params
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "apply_type_with_promotion succeeded" TT2=TT2 result_fieldcount=fieldcount(TT2) changed_params=changed_params
@debug "apply_type_with_promotion succeeded" TT2 = TT2 result_fieldcount = fieldcount(
TT2
) changed_params = changed_params

else
TT2 = T
@debug "subParms is empty, using T as-is"
TT2, changed_params = T, nothing
end
seen3 = copy(seen)
seen3[T] = TT2
@debug "Validating reconstructed type" T=T TT2=TT2 fieldcount_match=(fieldcount(T) == fieldcount(TT2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Validating reconstructed type" T=T TT2=TT2 fieldcount_match=(fieldcount(T) == fieldcount(TT2))
@debug "Validating reconstructed type" T = T TT2 = TT2 fieldcount_match = (
fieldcount(T) == fieldcount(TT2)
)


generic_T = Base.unwrap_unionall(T.name.wrapper)
param_map = typevar_dict(T.name.wrapper)

if fieldcount(T) == fieldcount(TT2)
legal = true

skipfield = false
for f in 1:fieldcount(T)
def_ft = fieldtype(generic_T, f)
field_tvars = Base.IdSet{TypeVar}()
collect_tvars_in_type!(field_tvars, def_ft)
# field_tvars now contains all typevars the field type directly depends on.
@debug "Collected field tvars" field_tvars
for tvar in field_tvars
idx = get(param_map, tvar, nothing)
isnothing(idx) && continue
if changed_params[idx]
skipfield = true
break
end
end
skipfield && continue

subT = fieldtype(T, f)
subT2 = fieldtype(TT2, f)
subTT = traced_type_inner(subT, seen3, mode, track_numbers, sharding, runtime)
@debug "Field validation" f=f subT=subT subT2=subT2 subTT=subTT match=(subT2==subTT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Field validation" f=f subT=subT subT2=subT2 subTT=subTT match=(subT2==subTT)
@debug "Field validation" f = f subT = subT subT2 = subT2 subTT = subTT match = (
subT2 == subTT
)

if subT2 != subTT
@debug "Field mismatch detected" f=f expected=subTT got=subT2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Field mismatch detected" f=f expected=subTT got=subT2
@debug "Field mismatch detected" f = f expected = subTT got = subT2

legal = false
break
end
end
if legal
@debug "All field checks passed, returning TT2"
for (k, v) in seen3
seen[k] = v
end
return TT2
end
else
@debug "Field count mismatch" fieldcount_T=fieldcount(T) fieldcount_TT2=fieldcount(TT2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@debug "Field count mismatch" fieldcount_T=fieldcount(T) fieldcount_TT2=fieldcount(TT2)
@debug "Field count mismatch" fieldcount_T = fieldcount(T) fieldcount_TT2 = fieldcount(
TT2
)

end

throw(NoFieldMatchError(T, TT2, subTys))
Expand All @@ -782,27 +841,27 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}()
# T = T.parameters[1]
# mode = mode.parameters[1]::TraceMode
# track_numbers = track_numbers.parameters[1]
#
#
#
#
# min_world = Ref{UInt}(typemin(UInt))
# max_world = Ref{UInt}(typemax(UInt))
#
#
# sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, Type{track_numbers}}
#
#
# lookup_result = lookup_world(
# sig, world, nothing, min_world, max_world
# )
# if lookup_result === nothing
# stub = Core.GeneratedFunctionStub(identity, Core.svec(:traced_type, :T, :mode, :track_numbers), Core.svec())
# return stub(world, source, method_error)
# return stub(world, source, method_error)
# end
# match = lookup_result::Core.MethodMatch
#
#
# mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
# (Any, Any, Any), match.method, match.spec_types, match.sparams)::Core.MethodInstance
#
#
# ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo
#
#
# cache = nothing
# cache_key = (mode, track_numbers)
# if haskey(traced_type_cache, cache_key)
Expand All @@ -811,8 +870,8 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}()
# cache = Dict{Type, Type}()
# traced_type_cache[cache_key] = cache
# end
#
#
#
#
# # prepare a new code info
# new_ci = copy(ci)
# empty!(new_ci.code)
Expand All @@ -830,21 +889,21 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}()
# gensig = Tuple{typeof(traced_type_inner), Type, Dict{Type, Type}, TraceMode, Type{track_numbers}}
# push!(edges, ccall(:jl_method_table_for, Any, (Any,), gensig))
# push!(edges, gensig)
#
#
# new_ci.edges = edges
#
#
# # XXX: setting this edge does not give us proper method invalidation, see
# # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# # invoking `code_llvm` also does the necessary codegen, as does calling the
# # underlying C methods -- which GPUCompiler does, so everything Just Works.
#
#
# # prepare the slots
# new_ci.slotnames = Symbol[Symbol("#self#"), :T, :mode, :track_numbers]
# new_ci.slotflags = UInt8[0x00 for i = 1:4]
#
#
# # return the codegen world age
# res1 = call_with_reactant(traced_type_inner, T, cache, mode, track_numbers)
#
#
# res0 = Base.invoke_in_world(world, traced_type_inner, T, cache, mode, track_numbers)
# res = Base.invokelatest(traced_type_inner, T, cache, mode, track_numbers)
# push!(new_ci.code, Core.Compiler.ReturnNode(res))
Expand All @@ -854,20 +913,93 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}()
# push!(new_ci.codelocs, 1) # see note below
# end
# new_ci.ssavaluetypes += 1
#
#
# # NOTE: we keep the first entry of the original linetable, and use it for location info
# # on the call to check_cache. we can't not have a codeloc (using 0 causes
# # corruption of the back trace), and reusing the target function's info
# # has as advantage that we see the name of the kernel in the backtraces.
#
#
# return new_ci
# end
#
#
# @eval Base.@assume_effects :removable :foldable :nothrow @inline function traced_type_old(T::Type, mode::Val, track_numbers::Type)
# $(Expr(:meta, :generated_only))
# $(Expr(:meta, :generated, traced_type_generator))
# end

"""
This function tries to apply the param types to the wrapper type.
When there's a constraint conflict, it tries to resolve it by promoting the conflicting types. The new param type is then propagated in any param type that depends on it.
"""
function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_dict(wrapper))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some tests?

unwrapped = Base.unwrap_unionall(wrapper) # remove all the typevars
original_params = copy(params)
params = [params...]

changed = true
iter = 0
while changed && iter < 100
changed = false
for (i, param) in enumerate(params)
# Add back the typevars to only one of the parameters:
rewrapped = Base.rewrap_unionall(unwrapped.parameters[i], wrapper)

sz = @ccall jl_subtype_env_size(rewrapped::Any)::Cint
arr = Array{Any}(undef, sz)

# Verify that the currently selected parameter subtypes the param in the wrapper type.
# In the process, `arr` is filled with with the required types for each parameter used by the current parameter:
is_subtype =
(@ccall jl_subtype_env(
params[i]::Any, rewrapped::Any, arr::Ptr{Any}, sz::Cint
)::Cint) == 1
!is_subtype && error(
"Failed to find a valid type for typevar $i ($(params[i]) <: $(rewrapped) == false)",
)

# Check whether the required types are supertypes of all the parameter types we currently have:
current_unionall = rewrapped
for value in arr
# Peel open the unionall to figure out which typevar each `value` corresponds to:
typevar = current_unionall.var
current_unionall = current_unionall.body

# `param` might have other typevars that don't occur in `wrapper`,
# here we first check if the typevar is actually relevant:
if haskey(relevant_typevars, typevar)
param_i = relevant_typevars[typevar]
(!(value isa Type) || value <: params[param_i]) && continue

# Found a conflict! Figure out a new param type by promoting:
promoted = promote_type(value, params[param_i])
params[param_i] = promoted

if value != promoted
# This happens when `value` lost the promotion battle.
# At this point, we need to update the problematic parameter in`value`.
d = typevar_dict(rewrapped)
v = [param.parameters...]
v[d[typevar]] = promoted
params[i], _changed_params = apply_type_with_promotion(rewrapped, v)
end
changed = true
end
end
end
iter += 1
end
changed_params = original_params .!= params
return Core.apply_type(wrapper, params...), changed_params
end

function typevar_dict(t)
d = Dict()
for (i, name) in enumerate(Base.unwrap_unionall(t).parameters)
d[name] = i
end
return d
end

Base.@assume_effects :total @inline function traced_type(
T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime
) where {mode}
Expand Down Expand Up @@ -1123,6 +1255,39 @@ Base.@nospecializeinfer function make_tracer_unknown(
xi2 = Core.Typeof(xi2)((newpath,), xi2.mlir_data)
seen[xi2] = xi2
changed = true
elseif !ismutabletype(FT) && !ismutabletype(Core.Typeof(xi2)) && fieldcount(FT) == fieldcount(Core.Typeof(xi2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
elseif !ismutabletype(FT) && !ismutabletype(Core.Typeof(xi2)) && fieldcount(FT) == fieldcount(Core.Typeof(xi2))
elseif !ismutabletype(FT) &&
!ismutabletype(Core.Typeof(xi2)) &&
fieldcount(FT) == fieldcount(Core.Typeof(xi2))

# Attempt to reconcile struct mismatch (e.g. Foo{Float64} -> Foo{TracedRNumber})
# arising from parent type constraints overriding local inference.
local flds_sub = Vector{Any}(undef, fieldcount(FT))
local success = true
for j in 1:fieldcount(FT)
val_j = getfield(xi2, j)
ft_j = fieldtype(FT, j)
if val_j isa ft_j
flds_sub[j] = val_j
elseif is_traced_number(ft_j) && val_j isa unwrapped_eltype(ft_j)
val_wrapped = ft_j(val_j)
# Correct the path for the wrapped scalar
sub_path = append_path(newpath, j)
val_wrapped = Core.Typeof(val_wrapped)((sub_path,), val_wrapped.mlir_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
val_wrapped = Core.Typeof(val_wrapped)((sub_path,), val_wrapped.mlir_data)
val_wrapped = Core.Typeof(val_wrapped)(
(sub_path,), val_wrapped.mlir_data
)

seen[val_wrapped] = val_wrapped
flds_sub[j] = val_wrapped
else
success = false
break
end
end

if success
xi2 = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), FT, flds_sub, fieldcount(FT))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
xi2 = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), FT, flds_sub, fieldcount(FT))
xi2 = ccall(
:jl_new_structv,
Any,
(Any, Ptr{Any}, UInt32),
FT,
flds_sub,
fieldcount(FT),
)

changed = true
else
throw(
AssertionError(
"Could not recursively make tracer of object of type $RT into $TT at field $i (named $(fieldname(TT, i))), need object of type $(fieldtype(TT, i)) found object of type $(Core.Typeof(xi2)) ",
),
)
end
else
throw(
AssertionError(
Expand Down
16 changes: 16 additions & 0 deletions test/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,22 @@ end
Reactant.XLA.runtime(),
)
end
@testset "apply_type_with_promotion" begin
struct Bar{T}
b::T
end
struct Foo{T,B<:Bar{T},AT<:AbstractArray{T}}
a::AT
b::B
end
@test Reactant.apply_type_with_promotion(
Foo, [Float64, Bar{Float64}, Reactant.TracedRArray{Float64,1}]
) == (Foo{
TracedRNumber{Float64},
Bar{TracedRNumber{Float64}},
Reactant.TracedRArray{Float64,1},
}, [true, true, false])

This comment was marked as outdated.

end
end

@testset "specialized dispatches" begin
Expand Down
Loading