diff --git a/databuild/graph/BUILD.bazel b/databuild/graph/BUILD.bazel
index 6c4ce9b..bb500b0 100644
--- a/databuild/graph/BUILD.bazel
+++ b/databuild/graph/BUILD.bazel
@@ -13,6 +13,21 @@ go_binary(
     visibility = ["//visibility:public"],
 )
 
+rust_binary(
+    name = "execute_rs",
+    srcs = ["execute.rs"],
+    edition = "2021",
+    deps = [
+        "//databuild:structs",
+        "@crates//:serde",
+        "@crates//:serde_json",
+        "@crates//:log",
+        "@crates//:simple_logger",
+        "@crates//:crossbeam-channel",
+    ],
+    visibility = ["//visibility:public"],
+)
+
 rust_binary(
     name = "analyze",
     srcs = ["analyze.rs"],
diff --git a/databuild/graph/execute.rs b/databuild/graph/execute.rs
new file mode 100644
index 0000000..269bc3e
--- /dev/null
+++ b/databuild/graph/execute.rs
@@ -0,0 +1,430 @@
+use structs::{DataDepType, JobConfig, JobGraph, Task};
+use crossbeam_channel::{Receiver, Sender};
+use log::{debug, error, info, warn};
+use serde::{Deserialize, Serialize};
+use std::collections::{HashMap, HashSet};
+use std::io::{Read, Write};
+use std::path::{Path, PathBuf};
+use std::process::{Command, Stdio};
+use std::sync::Arc;
+use std::thread;
+use std::time::{Duration, Instant};
+
+const NUM_WORKERS: usize = 4;
+const LOG_INTERVAL: Duration = Duration::from_secs(5);
+const FAIL_FAST: bool = true; // Same default as the Go version
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum TaskState {
+    Pending,
+    Running,
+    Succeeded,
+    Failed,
+}
+
+#[derive(Debug, Clone)]
+struct TaskExecutionResult {
+    task_key: String,
+    job_label: String, // For logging
+    success: bool,
+    stdout: String,
+    stderr: String,
+    duration: Duration,
+    error_message: Option<String>,
+}
+
+// Generates a unique key for a task based on its JobLabel, input and output references.
+// Mirrors the Go implementation's getTaskKey.
+fn get_task_key(task: &Task) -> String {
+    let mut key_parts = Vec::new();
+    key_parts.push(task.job_label.clone());
+
+    for input_dep in &task.config.inputs {
+        key_parts.push(format!("input:{}", input_dep.reference));
+    }
+    for output_ref in &task.config.outputs {
+        key_parts.push(format!("output:{}", output_ref));
+    }
+    key_parts.join("|")
+}
+
+// Resolves the executable path from runfiles.
+// Mirrors the Go implementation's resolveExecutableFromRunfiles.
+fn resolve_executable_from_runfiles(job_label: &str) -> PathBuf {
+    let mut target_name = job_label.to_string();
+    if let Some(colon_index) = job_label.rfind(':') {
+        target_name = job_label[colon_index + 1..].to_string();
+    } else if let Some(name) = Path::new(job_label).file_name().and_then(|n| n.to_str()) {
+        target_name = name.to_string();
+    }
+
+    let exec_name = format!("{}.exec", target_name);
+
+    if let Ok(runfiles_dir_str) = std::env::var("RUNFILES_DIR") {
+        let path = PathBuf::from(runfiles_dir_str).join("_main").join(&exec_name);
+        debug!("Resolved executable path (RUNFILES_DIR): {}", path.display());
+        return path;
+    }
+
+    if let Ok(current_exe) = std::env::current_exe() {
+        let mut runfiles_dir_path = PathBuf::from(format!("{}.runfiles", current_exe.display()));
+        if !runfiles_dir_path.is_dir() { // Bazel often puts it next to the binary
+            if let Some(parent) = current_exe.parent() {
+                runfiles_dir_path = parent.join(format!("{}.runfiles", current_exe.file_name().unwrap_or_default().to_string_lossy()));
+            }
+        }
+
+        if runfiles_dir_path.is_dir() {
+            let path = runfiles_dir_path.join("_main").join(&exec_name);
+            debug!("Resolved executable path (derived RUNFILES_DIR): {}", path.display());
+            return path;
+        } else {
+            warn!("Warning: RUNFILES_DIR not found or invalid, and derived path {} is not a directory.", runfiles_dir_path.display());
+        }
+    } else {
+        warn!("Warning: Could not determine current executable path.");
+    }
+
+    let fallback_path = PathBuf::from(format!("{}.exec", job_label));
+    warn!("Falling back to direct executable path: {}", fallback_path.display());
+    fallback_path
+}
+
+fn worker(
+    task_rx: Receiver<Arc<Task>>,
+    result_tx: Sender<TaskExecutionResult>,
+    worker_id: usize,
+) {
+    info!("[Worker {}] Starting", worker_id);
+    while let Ok(task) = task_rx.recv() {
+        let task_key = get_task_key(&task);
+        info!("[Worker {}] Starting job: {} (Key: {})", worker_id, task.job_label, task_key);
+        let start_time = Instant::now();
+
+        let exec_path = resolve_executable_from_runfiles(&task.job_label);
+
+        let config_json = match serde_json::to_string(&task.config) {
+            Ok(json) => json,
+            Err(e) => {
+                let err_msg = format!("Failed to serialize task config for {}: {}", task.job_label, e);
+                error!("[Worker {}] {}", worker_id, err_msg);
+                result_tx
+                    .send(TaskExecutionResult {
+                        task_key,
+                        job_label: task.job_label.clone(),
+                        success: false,
+                        stdout: String::new(),
+                        stderr: err_msg.clone(),
+                        duration: start_time.elapsed(),
+                        error_message: Some(err_msg),
+                    })
+                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
+                continue;
+            }
+        };
+
+        let mut cmd = Command::new(&exec_path);
+        cmd.stdin(Stdio::piped())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped());
+
+        // Set environment variables from the current process's environment.
+        // This mirrors the Go `cmd.Env = os.Environ()` behavior.
+        // Task-specific env vars from task.config.env are passed via JSON through stdin.
+        cmd.env_clear(); // Start with no environment variables
+        for (key, value) in std::env::vars() {
+            cmd.env(key, value); // Add current process's environment variables
+        }
+
+        match cmd.spawn() {
+            Ok(mut child) => {
+                if let Some(mut child_stdin) = child.stdin.take() {
+                    if let Err(e) = child_stdin.write_all(config_json.as_bytes()) {
+                        let err_msg = format!("[Worker {}] Failed to write to stdin for {}: {}", worker_id, task.job_label, e);
+                        error!("{}", err_msg);
+                        // Ensure child is killed if stdin write fails before wait
+                        let _ = child.kill();
+                        let _ = child.wait(); // Reap the child
+                        result_tx
+                            .send(TaskExecutionResult {
+                                task_key,
+                                job_label: task.job_label.clone(),
+                                success: false,
+                                stdout: String::new(),
+                                stderr: err_msg.clone(),
+                                duration: start_time.elapsed(),
+                                error_message: Some(err_msg),
+                            })
+                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
+                        continue;
+                    }
+                    drop(child_stdin); // Close stdin to signal EOF to the child
+                } else {
+                    let err_msg = format!("[Worker {}] Failed to get stdin for {}", worker_id, task.job_label);
+                    error!("{}", err_msg);
+                    result_tx
+                        .send(TaskExecutionResult {
+                            task_key,
+                            job_label: task.job_label.clone(),
+                            success: false,
+                            stdout: String::new(),
+                            stderr: err_msg.clone(),
+                            duration: start_time.elapsed(),
+                            error_message: Some(err_msg),
+                        })
+                        .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
+                    continue;
+                }
+
+                match child.wait_with_output() {
+                    Ok(output) => {
+                        let duration = start_time.elapsed();
+                        let success = output.status.success();
+                        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
+                        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
+
+                        if success {
+                            info!(
+                                "[Worker {}] Job succeeded: {} (Duration: {:?})",
+                                worker_id, task.job_label, duration
+                            );
+                        } else {
+                            error!(
+                                "[Worker {}] Job failed: {} (Duration: {:?}, Status: {:?})\nStdout: {}\nStderr: {}",
+                                worker_id, task.job_label, duration, output.status, stdout, stderr
+                            );
+                        }
+                        result_tx
+                            .send(TaskExecutionResult {
+                                task_key,
+                                job_label: task.job_label.clone(),
+                                success,
+                                stdout,
+                                stderr,
+                                duration,
+                                error_message: if success { None } else { Some(format!("Exited with status: {:?}", output.status)) },
+                            })
+                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send result: {}", worker_id, e));
+                    }
+                    Err(e) => {
+                        let err_msg = format!("[Worker {}] Failed to execute or wait for {}: {}", worker_id, task.job_label, e);
+                        error!("{}", err_msg);
+                        result_tx
+                            .send(TaskExecutionResult {
+                                task_key,
+                                job_label: task.job_label.clone(),
+                                success: false,
+                                stdout: String::new(),
+                                stderr: err_msg.clone(),
+                                duration: start_time.elapsed(),
+                                error_message: Some(err_msg),
+                            })
+                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send execution error result: {}", worker_id, e));
+                    }
+                }
+            }
+            Err(e) => {
+                let err_msg = format!("[Worker {}] Failed to spawn command for {}: {} (Path: {:?})", worker_id, task.job_label, e, exec_path);
+                error!("{}", err_msg);
+                result_tx
+                    .send(TaskExecutionResult {
+                        task_key,
+                        job_label: task.job_label.clone(),
+                        success: false,
+                        stdout: String::new(),
+                        stderr: err_msg.clone(),
+                        duration: start_time.elapsed(),
+                        error_message: Some(err_msg),
+                    })
+                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send spawn error result: {}", worker_id, e));
+            }
+        }
+    }
+    info!("[Worker {}] Exiting", worker_id);
+}
+
+fn is_task_ready(task: &Task, completed_outputs: &HashSet<String>) -> bool {
+    for dep in &task.config.inputs {
+        if dep.dep_type == DataDepType::Materialize {
+            if !completed_outputs.contains(&dep.reference) {
+                return false;
+            }
+        }
+    }
+    true
+}
+
+fn log_status_summary(
+    task_states: &HashMap<String, TaskState>,
+    original_tasks_by_key: &HashMap<String, Arc<Task>>,
+) {
+    let mut pending_tasks = Vec::new();
+    let mut running_tasks = Vec::new();
+    let mut succeeded_tasks = Vec::new();
+    let mut failed_tasks = Vec::new();
+
+    for (key, state) in task_states {
+        let label = original_tasks_by_key.get(key).map_or_else(|| key.as_str(), |t| t.job_label.as_str());
+        match state {
+            TaskState::Pending => pending_tasks.push(label),
+            TaskState::Running => running_tasks.push(label),
+            TaskState::Succeeded => succeeded_tasks.push(label),
+            TaskState::Failed => failed_tasks.push(label),
+        }
+    }
+
+    info!("Task Status Summary:");
+    info!("  Pending ({}): {:?}", pending_tasks.len(), pending_tasks);
+    info!("  Running ({}): {:?}", running_tasks.len(), running_tasks);
+    info!("  Succeeded ({}): {:?}", succeeded_tasks.len(), succeeded_tasks);
+    info!("  Failed ({}): {:?}", failed_tasks.len(), failed_tasks);
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init()?;
+
+    let mut buffer = String::new();
+    std::io::stdin().read_to_string(&mut buffer)?;
+    let graph: JobGraph = serde_json::from_str(&buffer)?;
+
+    info!("Executing job graph with {} nodes", graph.nodes.len());
+
+    let mut task_states: HashMap<String, TaskState> = HashMap::new();
+    let mut original_tasks_by_key: HashMap<String, Arc<Task>> = HashMap::new();
+    let graph_nodes_arc: Vec<Arc<Task>> = graph.nodes.into_iter().map(Arc::new).collect();
+
+    for task_node in &graph_nodes_arc {
+        let key = get_task_key(task_node);
+        task_states.insert(key.clone(), TaskState::Pending);
+        original_tasks_by_key.insert(key, task_node.clone());
+    }
+
+    let mut completed_outputs: HashSet<String> = HashSet::new();
+    let mut job_results: Vec<TaskExecutionResult> = Vec::new();
+
+    let (task_tx, task_rx): (Sender<Arc<Task>>, Receiver<Arc<Task>>) = crossbeam_channel::unbounded();
+    let (result_tx, result_rx): (Sender<TaskExecutionResult>, Receiver<TaskExecutionResult>) = crossbeam_channel::unbounded();
+
+    let mut worker_handles = Vec::new();
+    for i in 0..NUM_WORKERS {
+        let task_rx_clone = task_rx.clone();
+        let result_tx_clone = result_tx.clone();
+        worker_handles.push(thread::spawn(move || {
+            worker(task_rx_clone, result_tx_clone, i + 1);
+        }));
+    }
+    // Drop the original result_tx so the result channel disconnects once every
+    // worker's clone has been dropped.
+    drop(result_tx);
+
+    let mut last_log_time = Instant::now();
+    let mut active_tasks_count = 0;
+    let mut fail_fast_triggered = false;
+
+    loop {
+        // 1. Process results
+        while let Ok(result) = result_rx.try_recv() {
+            active_tasks_count -= 1;
+            info!(
+                "Received result for task {}: Success: {}",
+                result.job_label, result.success
+            );
+
+            let current_state = if result.success {
+                TaskState::Succeeded
+            } else {
+                TaskState::Failed
+            };
+            task_states.insert(result.task_key.clone(), current_state);
+
+            if result.success {
+                if let Some(original_task) = original_tasks_by_key.get(&result.task_key) {
+                    for output_ref in &original_task.config.outputs {
+                        completed_outputs.insert(output_ref.clone());
+                    }
+                }
+            } else {
+                if FAIL_FAST {
+                    warn!("Fail-fast enabled and task {} failed. Shutting down.", result.job_label);
+                    fail_fast_triggered = true;
+                }
+            }
+            job_results.push(result);
+        }
+
+        // 2. Check for fail-fast break
+        if fail_fast_triggered && active_tasks_count == 0 { // Wait for running tasks to finish if fail fast
+            info!("All active tasks completed after fail-fast trigger.");
+            break;
+        }
+        if fail_fast_triggered && active_tasks_count > 0 {
+            // Don't schedule new tasks, just wait for active ones or log
+        } else if !fail_fast_triggered { // Only dispatch if not in fail-fast shutdown
+            // 3. Dispatch ready tasks
+            for task_node in &graph_nodes_arc {
+                let task_key = get_task_key(task_node);
+                if task_states.get(&task_key) == Some(&TaskState::Pending) {
+                    if is_task_ready(task_node, &completed_outputs) {
+                        info!("Dispatching task: {}", task_node.job_label);
+                        task_states.insert(task_key.clone(), TaskState::Running);
+                        task_tx.send(task_node.clone())?;
+                        active_tasks_count += 1;
+                    }
+                }
+            }
+        }
+
+        // 4. Periodic logging
+        if last_log_time.elapsed() >= LOG_INTERVAL {
+            log_status_summary(&task_states, &original_tasks_by_key);
+            last_log_time = Instant::now();
+        }
+
+        // 5. Check completion
+        let all_done = task_states.values().all(|s| *s == TaskState::Succeeded || *s == TaskState::Failed);
+        if active_tasks_count == 0 && all_done {
+            info!("All tasks are in a terminal state and no tasks are active.");
+            break;
+        }
+
+        // Avoid busy-waiting if no events, give channels time.
+        // Select would be better here, but for simplicity:
+        thread::sleep(Duration::from_millis(50));
+    }
+
+    info!("Shutting down workers...");
+    drop(task_tx); // Signal workers to stop by closing the task channel
+
+    for handle in worker_handles {
+        handle.join().expect("Failed to join worker thread");
+    }
+    info!("All workers finished.");
+
+    // Final processing of any remaining results (should be minimal if loop logic is correct)
+    while let Ok(result) = result_rx.try_recv() {
+        active_tasks_count -= 1; // Should be 0
+        info!(
+            "Received late result for task {}: Success: {}",
+            result.job_label, result.success
+        );
+        // Update state for completeness, though it might not affect overall outcome now
+        let current_state = if result.success { TaskState::Succeeded } else { TaskState::Failed };
+        task_states.insert(result.task_key.clone(), current_state);
+        job_results.push(result);
+    }
+
+    let success_count = job_results.iter().filter(|r| r.success).count();
+    let failure_count = job_results.len() - success_count;
+
+    info!("Execution complete: {} succeeded, {} failed", success_count, failure_count);
+
+    if failure_count > 0 || fail_fast_triggered {
+        error!("Execution finished with errors.");
+        std::process::exit(1);
+    }
+
+    Ok(())
+}