diff --git a/src/main.rs b/src/main.rs index 4f4da4e..42c996e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use std::time::Duration; use std::{path::PathBuf, process::Command, time::Instant}; use wait_timeout::ChildExt; -use anyhow::Result; +use anyhow::{anyhow, Result}; use clap::Parser; use git2::{DiffOptions, Oid, Repository}; use std::fs::{File, OpenOptions}; @@ -44,6 +44,12 @@ struct Args { #[arg(long)] cgroup_path: String, + + // Where is the directory that should contain correctness data. + // + // data should be tuples of files named {num}.in and answer is in the same directory {num}.out + #[arg(long)] + correctness_data: PathBuf, } // Move the pid to the cgroup @@ -116,11 +122,20 @@ struct ExecResult { impl ExecResult { pub fn avg_time(&self) -> Duration { - if self.times.is_empty() { + let mut times = self.times.clone(); + times.sort(); + if times.is_empty() { return Duration::from_millis(0); } - let sum: Duration = self.times.iter().sum(); - sum / (self.times.len() as u32) + if times.len() > 2 { + // Remove the slowest and fastest solutions to further remove outliers + let new_times = ×[1..self.times.len() - 1]; + let sum: Duration = new_times.iter().sum(); + sum / (new_times.len() as u32) + } else { + let sum: Duration = times.iter().sum(); + sum / (self.times.len() as u32) + } } pub fn median(&self) -> Duration { @@ -129,7 +144,13 @@ impl ExecResult { } let mut times = self.times.clone(); times.sort(); - times[times.len() / 2] + if times.len() > 2 { + // Remove the slowest and fastest solutions to further remove outliers + let new_times = ×[1..times.len() - 1]; + new_times[new_times.len() / 2] + } else { + times[times.len() / 2] + } } } @@ -142,19 +163,40 @@ fn compute_verdict(output: &Path, expected_output: &Path) -> Result { output.read_to_string(&mut output_str)?; expected_output.read_to_string(&mut ex_output_str)?; if output_str.trim() != ex_output_str.trim() { - info!("Solution comparison failed, output for the solution is: {output_str}"); + info!("Solution comparison failed, output for the solution is: {output_str}\n expected output is {ex_output_str}"); return Ok(Verdict::Wa); } Ok(Verdict::Ac) } +fn drop_file_cache() -> Result<()> { + info!("Dropping the file cache"); + // Execute the sync command to flush file system buffers + let _ = Command::new("sync").status()?; + + let output = Command::new("sh") + .arg("-c") + .arg("echo 3 > /proc/sys/vm/drop_caches") + .output()?; + + // Check if the command was successful + if output.status.success() { + Ok(()) + } else { + Err(anyhow!( + "Failed to drop linux file cache: {}", + output.status + )) + } +} + fn run_rust(context: &RunContext) -> Result { let tmp_dir = tempfile::tempdir()?; let mut output_path = tmp_dir.path().to_path_buf(); output_path.push("sol"); - let mut res = Command::new("/usr/local/bin/rustc") + let mut res = Command::new("rustc") .args(vec![ "-O", context.abs_solution().to_str().unwrap(), @@ -177,6 +219,21 @@ fn run_rust(context: &RunContext) -> Result { Err(e) => return Err(anyhow::anyhow!("Failed to compile solution file: {e:?}")), } + let verdict = ensure_correct( + context, + vec![output_path.to_str().unwrap().to_string()], + None, + )?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + let mut input_file = tmp_dir.path().to_path_buf(); input_file.push("input.txt"); let mut output_file = tmp_dir.path().to_path_buf(); @@ -256,6 +313,197 @@ fn run_rust(context: &RunContext) -> Result { times, }) } + +fn extract_files(directory_path: &Path, suf: &str) -> Vec { + let mut in_files = Vec::new(); + if let Ok(entries) = std::fs::read_dir(directory_path) { + for entry in entries { + if let Ok(entry) = entry { + if let Some(file_name) = entry.file_name().to_str() { + if entry.path().is_file() && file_name.ends_with(suf) { + in_files.push(entry.path()); + } + } + } + } + } + + in_files.sort(); + in_files +} + +fn copy_to_input_txt(input_path: &Path, file_path: &Path) -> Result<()> { + // Create or open the input.txt file for writing + let mut input_txt_file = OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(input_path)?; + + // Open the input file for reading + let mut file = std::fs::File::open(file_path)?; + + // Read the content of the input file + let mut content = String::new(); + file.read_to_string(&mut content)?; + + input_txt_file.write_all(content.as_bytes())?; + + Ok(()) +} + +fn get_matching_output(input_file: &Path, output_files: &[PathBuf]) -> Option { + // Extract the filename stem of the input file + let input_filename_stem = input_file + .file_stem() + .and_then(|stem| stem.to_str()) + .unwrap_or(""); + + // Iterate over the output files + for output_file in output_files { + // Extract the filename stem of the output file + let output_filename_stem = output_file + .file_stem() + .and_then(|stem| stem.to_str()) + .unwrap_or(""); + + // Check if the filename stem of the input file matches the output file + if input_filename_stem == output_filename_stem { + return Some(output_file.clone()); + } + } + + None +} + +fn ensure_correct( + context: &RunContext, + cmd: Vec, + src_name: Option<&str>, +) -> Result { + info!("Starting correctness checks"); + let tmp_dir = tempfile::tempdir()?; + let tmp = tmp_dir.path(); + let in_files = extract_files(context.correctness_data, ".in"); + let out_files = extract_files(context.correctness_data, ".out"); + + if let Some(src) = src_name { + let mut sol_file = tmp.to_path_buf(); + sol_file.push(src); + debug!( + "Copying solution source from: {:?} to {:?}", + context.abs_solution(), + sol_file + ); + std::fs::copy(context.abs_solution(), &sol_file)?; + if context.lang()? == "java" { + debug!("Recompiling Java solution for correctness checks"); + // Java could generate more than one file after compilation, let's recompile to ensure + // that '.class' files are getting properly generated. + let mut res = Command::new("javac") + .args(vec![sol_file.to_str().unwrap()]) + .current_dir(tmp) + .spawn()?; + + match res.wait() { + Ok(status) => { + debug!( + "Finished compilation of the target: {sol_file:?} with exit code: {status}" + ); + if let Some(code) = status.code() { + if code != 0 { + return Err(anyhow::anyhow!( + "Failed to compile solution file: none-zero exit code" + )); + } + } + } + Err(e) => return Err(anyhow::anyhow!("Failed to compile solution file: {e:?}")), + } + } + } + + if in_files.len() != out_files.len() { + return Err(anyhow!( + "Unexpected in_files vs out_files length, in files length: {}, out files length {}", + in_files.len(), + out_files.len() + )); + } + + let mut input_path = tmp.to_path_buf(); + input_path.push("input.txt"); + let mut output_path = tmp.to_path_buf(); + output_path.push("output.txt"); + + for file_path in in_files { + copy_to_input_txt(&input_path, &file_path)?; + + let start_time = Instant::now(); + let mut cmd_builder = Command::new(cmd[0].clone()); + cmd_builder.current_dir(tmp); + if cmd.len() > 1 { + cmd_builder.args(&cmd[1..]); + } + let mut child = cmd_builder.spawn()?; + + let pid = child.id(); + debug!("The current started process has pid: {pid}"); + move_to_cgroup(pid, context)?; + debug!( + "Moved process to cgroup {} successfully", + context.cgroup_path + ); + match child.wait_timeout(context.timeout) { + Ok(Some(result)) => { + info!( + "Finished execution of the solution with status: {:?}", + result.code() + ); + if let Some(code) = result.code() { + if code != 0 { + return Err(anyhow::anyhow!( + "Failed to run the solution file: none-zero exit code" + )); + } + } + } + Ok(None) => { + debug!("Child process timed out"); + child.kill().unwrap(); + let code = child.wait()?.code(); + debug!("Killed with exit code: {code:?}"); + return Ok(Verdict::Tle); + } + Err(e) => { + return Err(anyhow::anyhow!("Failed to run the solution file: {e:?}")); + } + } + + let elapsed = Instant::now() - start_time; + debug!("Correctness execution finished in: {elapsed:?}"); + + if let Some(expected_output) = get_matching_output(&file_path, &out_files) { + debug!( + "Found matching output_file: input: {:?}, output: {:?}", + file_path, expected_output + ); + let verdict = compute_verdict(&output_path, &expected_output)?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution didn't pass correctness check, aborting early!"); + return Ok(verdict); + } + } else { + return Err(anyhow!( + "Failed to find corresponding output file for {:?}", + file_path + )); + } + } + + Ok(Verdict::Ac) +} + fn run_cpp(context: &RunContext) -> Result { let tmp_dir = tempfile::tempdir()?; @@ -287,6 +535,23 @@ fn run_cpp(context: &RunContext) -> Result { Err(e) => return Err(anyhow::anyhow!("Failed to compile solution file: {e:?}")), } + let verdict = ensure_correct( + context, + vec![output_path.to_str().unwrap().to_string()], + None, + )?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + + drop_file_cache()?; + let mut input_file = tmp_dir.path().to_path_buf(); input_file.push("input.txt"); let mut output_file = tmp_dir.path().to_path_buf(); @@ -394,6 +659,23 @@ fn run_java(context: &RunContext) -> Result { Err(e) => return Err(anyhow::anyhow!("Failed to compile solution file: {e:?}")), } + let verdict = ensure_correct( + context, + vec!["java".to_string(), "Main".to_string()], + Some("Main.java"), + )?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + + drop_file_cache()?; + let mut input_file = tmp_dir.path().to_path_buf(); input_file.push("input.txt"); @@ -479,6 +761,24 @@ fn run_java(context: &RunContext) -> Result { } fn run_php(context: &RunContext) -> Result { + let verdict = ensure_correct( + context, + vec!["php".to_string(), "sol.php".to_string()], + Some("sol.php"), + )?; + + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + + drop_file_cache()?; + let tmp_dir = tempfile::tempdir()?; let mut src_path = tmp_dir.path().to_path_buf(); @@ -505,13 +805,20 @@ fn run_php(context: &RunContext) -> Result { let start_time = Instant::now(); let mut child = Command::new("php") .args(vec![ - "-d", "opcache.memory_consumption=128", - "-d", "opcache.interned_strings_buffer=8", - "-d", "opcache.revalidate_freq=60", - "-d", "opcache.enable_cli=1", - "-d", "opcache.jit=function", - "-d", "opcache.jit_buffer_size=128M", - "-f", "sol.php" + "-d", + "opcache.memory_consumption=128", + "-d", + "opcache.interned_strings_buffer=8", + "-d", + "opcache.revalidate_freq=60", + "-d", + "opcache.enable_cli=1", + "-d", + "opcache.jit=function", + "-d", + "opcache.jit_buffer_size=128M", + "-f", + "sol.php", ]) .current_dir(&tmp_dir) .spawn()?; @@ -584,6 +891,23 @@ fn run_python(context: &RunContext) -> Result { // Copy the source file to the compilation directory std::fs::copy(context.abs_solution(), &src_path)?; + let verdict = ensure_correct( + context, + vec!["python3".to_string(), "main.py".to_string()], + Some("main.py"), + )?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + + drop_file_cache()?; + let mut input_file = tmp_dir.path().to_path_buf(); input_file.push("input.txt"); @@ -663,6 +987,21 @@ fn run_python(context: &RunContext) -> Result { } fn run_node(context: &RunContext) -> Result { + let verdict = ensure_correct( + context, + vec!["node".to_string(), "main.js".to_string()], + Some("main.js"), + )?; + + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); let tmp_dir = tempfile::tempdir()?; let mut src_path = tmp_dir.path().to_path_buf(); @@ -754,7 +1093,7 @@ fn run_golang(context: &RunContext) -> Result { let mut output_path = tmp_dir.path().to_path_buf(); output_path.push("sol"); - let mut res = Command::new("/usr/local/bin/go") + let mut res = Command::new("go") .args(vec![ "build", "-o", @@ -777,6 +1116,23 @@ fn run_golang(context: &RunContext) -> Result { Err(e) => return Err(anyhow::anyhow!("Failed to compile solution file: {e:?}")), } + let verdict = ensure_correct( + context, + vec![output_path.to_str().unwrap().to_string()], + None, + )?; + if !matches!(verdict, Verdict::Ac) { + info!("Solution failed correctness checks, exiting early"); + return Ok(ExecResult { + verdict, + times: vec![], + }); + } + + info!("Solution passed correctness checks"); + + drop_file_cache()?; + let mut input_file = tmp_dir.path().to_path_buf(); input_file.push("input.txt"); let mut output_file = tmp_dir.path().to_path_buf(); @@ -868,14 +1224,7 @@ fn extract_language(context: &RunContext) -> String { } fn run_file(context: &RunContext) -> Result { - let extension = context - .solution_file - .extension() - .ok_or(anyhow::anyhow!("failed to get file extension"))?; - - let mut source_file = context.root.to_path_buf(); - source_file.push(context.solution_file); - match extension.to_str().unwrap() { + match context.lang()? { "cpp" => run_cpp(context), "cc" => run_cpp(context), "c" => run_cpp(context), @@ -897,6 +1246,7 @@ struct RunContext<'a> { input_file: &'a Path, expected_output: &'a Path, solution_file: &'a Path, + correctness_data: &'a Path, timeout: Duration, cgroup_path: String, times: u32, @@ -910,6 +1260,14 @@ impl<'a> RunContext<'a> { result.push(self.solution_file); result } + + fn lang(&self) -> Result<&str> { + let extension = self + .solution_file + .extension() + .ok_or(anyhow::anyhow!("failed to get file extension"))?; + Ok(extension.to_str().unwrap()) + } } #[serde_as] @@ -1030,6 +1388,7 @@ fn main() { timeout: Duration::from_secs(args.timeout_sec), times: args.times_to_run, cgroup_path: args.cgroup_path.clone(), + correctness_data: &args.correctness_data, }; match run_file(&run_context) {