Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub(crate) fn compile_codegen_unit(
.unstable_opts
.offload
.iter()
.any(|o| matches!(o, Offload::Host(_) | Offload::Test));
.any(|o| matches!(o, Offload::Host(_) | Offload::Test | Offload::Args));
if has_host_offload && !cx.sess().target.is_like_gpu {
cx.offload_globals.replace(Some(OffloadGlobals::declare(&cx)));
}
Expand Down
148 changes: 94 additions & 54 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::common::{AsCCharPtr, CodegenCx};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};
Expand Down Expand Up @@ -346,8 +346,8 @@ impl KernelArgsTy {
pub(crate) struct OffloadKernelGlobals<'ll> {
pub offload_sizes: &'ll llvm::Value,
pub memtransfer_types: &'ll llvm::Value,
pub region_id: &'ll llvm::Value,
pub offload_entry: &'ll llvm::Value,
pub region_id: Option<&'ll llvm::Value>,
pub offload_entry: Option<&'ll llvm::Value>,
}

fn gen_tgt_data_mappers<'ll>(
Expand Down Expand Up @@ -417,6 +417,7 @@ pub(crate) fn gen_define_handling<'ll>(
metadata: &[OffloadMetadata],
symbol: String,
offload_globals: &OffloadGlobals<'ll>,
host: bool,
) -> OffloadKernelGlobals<'ll> {
if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
return *entry;
Expand All @@ -440,33 +441,38 @@ pub(crate) fn gen_define_handling<'ll>(
// Next: For each function, generate these three entries: a weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value.

let name = format!(".{symbol}.region_id");
let initializer = cx.get_const_i8(0);
let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);

let c_entry_name = CString::new(symbol.clone()).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
let offload_entry_name = format!(".offloading.entry_name.{symbol}");

let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
llvm::set_alignment(llglobal, Align::ONE);
llvm::set_section(llglobal, c".llvm.rodata.offloading");

let name = format!(".offloading.entry.{symbol}");

// See the __tgt_offload_entry documentation above.
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

let initializer = crate::common::named_struct(offload_entry_ty, &elems);
let c_name = CString::new(name).unwrap();
let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
llvm::set_global_constant(offload_entry, true);
llvm::set_linkage(offload_entry, WeakAnyLinkage);
llvm::set_initializer(offload_entry, initializer);
llvm::set_alignment(offload_entry, Align::EIGHT);
let c_section_name = CString::new("llvm_offload_entries").unwrap();
llvm::set_section(offload_entry, &c_section_name);
let (offload_entry, region_id) = if !host {
let name = format!(".{symbol}.region_id");
let initializer = cx.get_const_i8(0);
let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);

let c_entry_name = CString::new(symbol.clone()).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
let offload_entry_name = format!(".offloading.entry_name.{symbol}");

let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
llvm::set_alignment(llglobal, Align::ONE);
llvm::set_section(llglobal, c".llvm.rodata.offloading");

let name = format!(".offloading.entry.{symbol}");

// See the __tgt_offload_entry documentation above.
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

let initializer = crate::common::named_struct(offload_entry_ty, &elems);
let c_name = CString::new(name).unwrap();
let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
llvm::set_global_constant(offload_entry, true);
llvm::set_linkage(offload_entry, WeakAnyLinkage);
llvm::set_initializer(offload_entry, initializer);
llvm::set_alignment(offload_entry, Align::EIGHT);
let c_section_name = CString::new("llvm_offload_entries").unwrap();
llvm::set_section(offload_entry, &c_section_name);
(Some(offload_entry), Some(region_id))
} else {
(None, None)
};

let result =
OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
Expand Down Expand Up @@ -529,13 +535,16 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
types: &[&Type],
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
offload_dims: &OffloadKernelDims<'ll>,
offload_dims: Option<&OffloadKernelDims<'ll>>,
host: bool,
host_llfn: &'ll Value,
host_llty: &'ll Type,
) {
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;
//let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
// offload_dims;

let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;
Expand All @@ -550,7 +559,12 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(

// FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these
// variables from being optimized away
for val in [offload_sizes, offload_entry] {
let to_keep: &[&llvm::Value] = if let Some(offload_entry) = offload_entry {
&[offload_sizes, offload_entry]
} else {
&[offload_sizes]
};
for val in to_keep {
unsafe {
let dummy = llvm::LLVMBuildLoad2(
&builder.llbuilder,
Expand Down Expand Up @@ -686,27 +700,53 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
num_args,
s_ident_t,
);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
for (i, value) in values.iter().enumerate() {
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
builder.store(value.1, ptr, value.0);
}

let args = vec![
s_ident_t,
// FIXME(offload) give users a way to select which GPU to use.
cx.get_const_i64(u64::MAX), // MAX == -1.
num_workgroups,
threads_per_block,
region_id,
a5,
];
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
if host {
let fn_name = "omp_get_mapped_ptr";
let ty2: &'ll Type = cx.type_func(&[cx.type_ptr(), cx.type_i32()], cx.type_ptr());
let mapper_fn = unsafe {
llvm::LLVMRustGetOrInsertFunction(
builder.llmod,
fn_name.as_c_char_ptr(),
fn_name.len(),
ty2,
)
};

let mut device_vals = Vec::with_capacity(vals.len());
let device_num = cx.get_const_i32(0);
for arg in vals {
let device_arg =
builder.call(ty2, None, None, mapper_fn, &[arg, device_num], None, None);
device_vals.push(device_arg);
}
builder.call(host_llty, None, None, host_llfn, &device_vals, None, None);
} else {
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims.unwrap();
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
for (i, value) in values.iter().enumerate() {
let ptr =
builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
builder.store(value.1, ptr, value.0);
}
// `region_id` is only `None` when `host` is true, so on this (`!host`) path it is set by construction.
let args = vec![
s_ident_t,
// FIXME(offload) give users a way to select which GPU to use.
cx.get_const_i64(u64::MAX), // MAX == -1.
num_workgroups,
threads_per_block,
region_id.unwrap(),
a5,
];
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
}

// Step 4)
let geps = get_geps(builder, ty, ty2, a1, a2, a4);
Expand Down
49 changes: 44 additions & 5 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,19 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO);
}

codegen_offload(self, tcx, instance, args);
codegen_offload(self, tcx, instance, args, false);
return Ok(());
}
sym::offload_args => {
if tcx.sess.opts.unstable_opts.offload.is_empty() {
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutEnable);
}

if tcx.sess.lto() != rustc_session::config::Lto::Fat {
let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO);
}

codegen_offload(self, tcx, instance, args, true);
return Ok(());
}
sym::is_val_statically_known => {
Expand Down Expand Up @@ -1362,6 +1374,7 @@ fn codegen_offload<'ll, 'tcx>(
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
host: bool,
) {
let cx = bx.cx;
let fn_args = instance.args;
Expand All @@ -1384,8 +1397,18 @@ fn codegen_offload<'ll, 'tcx>(
}
};

let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
let llfn = cx.get_fn(fn_target);
let (offload_dims, args) = if host {
// If we only map arguments to the gpu and otherwise work on host code, there is no need to
// handle block or thread dimensions.
let args = get_args_from_tuple(bx, args[1], fn_target);
(None, args)
} else {
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
(Some(offload_dims), args)
};

let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);

let sig = tcx.fn_sig(fn_target.def_id()).skip_binder();
Expand All @@ -1405,8 +1428,24 @@ fn codegen_offload<'ll, 'tcx>(
}
};
register_offload(cx);
let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
let instance = rustc_middle::ty::Instance::mono(tcx, fn_target.def_id());
let fn_abi = cx.fn_abi_of_instance(instance, tcx.mk_type_list(&[]));
let host_fn_ty = fn_abi.llvm_type(cx);

let offload_data =
gen_define_handling(&cx, &metadata, target_symbol.clone(), offload_globals, host);
gen_call_handling(
bx,
&offload_data,
&args,
&types,
&metadata,
offload_globals,
offload_dims.as_ref(),
host,
llfn,
host_fn_ty,
);
}

fn get_args_from_tuple<'ll, 'tcx>(
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::mul_with_overflow
| sym::needs_drop
| sym::offload
| sym::offload_args
| sym::offset_of
| sym::overflow_checks
| sym::powf16
Expand Down Expand Up @@ -339,6 +340,7 @@ pub(crate) fn check_intrinsic_type(
],
param(2),
),
sym::offload_args => (3, 0, vec![param(0), param(1)], param(2)),
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::arith_offset => (
1,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_monomorphize/src/collector/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ pub(crate) fn collect_autodiff_fn<'tcx>(
intrinsic: IntrinsicDef,
output: &mut MonoItems<'tcx>,
) {
if intrinsic.name != rustc_span::sym::autodiff {
if intrinsic.name != rustc_span::sym::autodiff
&& intrinsic.name != rustc_span::sym::offload
&& intrinsic.name != rustc_span::sym::offload_args
{
return;
};

Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ pub enum Offload {
Device,
/// Second step in the offload pipeline, generates the host code to call kernels.
Host(String),
/// We only map arguments, but still call host (=CPU) code.
Args,
/// Test is similar to Host, but allows testing without a device artifact.
Test,
}
Expand Down
19 changes: 14 additions & 5 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ mod desc {
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`";
pub(crate) const parse_offload: &str =
"a comma separated list of settings: `Host=<Absolute-Path>`, `Device`, `Test`";
"a comma separated list of settings: `Host=<Absolute-Path>`, `Device`, `Test`, `Args`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
Expand Down Expand Up @@ -1480,6 +1480,13 @@ pub mod parse {
}
Offload::Test
}
"Args" => {
if let Some(_) = arg {
// Args does not accept a value
return false;
}
Offload::Args
}
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
return false;
Expand Down Expand Up @@ -2526,10 +2533,12 @@ options! {
normalize_docs: bool = (false, parse_bool, [TRACKED],
"normalize associated items in rustdoc when generating documentation"),
offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED],
"a list of offload flags to enable
Mandatory setting:
`=Enable`
Currently the only option available"),
"a list of offload flags to enable:
`=Device`
`=Host(path)`
`=Test`
`=Args`
Multiple options can be combined with commas."),
on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED],
"behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"),
osx_rpath_install_name: bool = (false, parse_bool, [TRACKED],
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,7 @@ symbols! {
of,
off,
offload,
offload_args,
offset,
offset_of,
offset_of_enum,
Expand Down
Loading
Loading