Add rust execute impl

This commit is contained in:
Stuart Axelbrooke 2025-05-07 20:01:10 -07:00
parent f2567f7567
commit 91d5fd26bc
No known key found for this signature in database
GPG key ID: 1B0A848C29D46A35
2 changed files with 445 additions and 0 deletions

View file

@ -13,6 +13,21 @@ go_binary(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Binary target for the Rust implementation of the graph executor (execute.rs).
# Third-party crates are taken from the @crates repository; //databuild:structs
# is the in-repo dependency.
rust_binary(
name = "execute_rs",
srcs = ["execute.rs"],
edition = "2021",
deps = [
"//databuild:structs",
"@crates//:serde",
"@crates//:serde_json",
"@crates//:log",
"@crates//:simple_logger",
"@crates//:crossbeam-channel",
],
visibility = ["//visibility:public"],
)
rust_binary( rust_binary(
name = "analyze", name = "analyze",
srcs = ["analyze.rs"], srcs = ["analyze.rs"],

430
databuild/graph/execute.rs Normal file
View file

@ -0,0 +1,430 @@
use structs::{DataDepType, JobConfig, JobGraph, Task};
use crossbeam_channel::{Receiver, Sender};
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
// Number of worker threads spawned by `main` to execute tasks concurrently.
const NUM_WORKERS: usize = 4;
// Minimum time between two periodic task-status summaries in the scheduling loop.
const LOG_INTERVAL: Duration = Duration::from_secs(5);
// When true, the first failed task stops the dispatch of new tasks; tasks that are
// already running are still waited for before the executor exits.
const FAIL_FAST: bool = true; // Same default as the Go version
/// Lifecycle state of a task as tracked by the scheduling loop in `main`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskState {
/// Registered from the graph but not yet handed to a worker.
Pending,
/// Sent to the worker pool; its result has not been received yet.
Running,
/// A result with `success == true` was received.
Succeeded,
/// A result with `success == false` was received.
Failed,
}
/// Outcome of one task execution, sent from a worker thread back to `main`.
#[derive(Debug, Clone)]
struct TaskExecutionResult {
// Key produced by `get_task_key`; used by `main` to look the task up again.
task_key: String,
job_label: String, // For logging
// True only when the child process ran and exited with a success status.
success: bool,
// Captured standard output of the child (lossy UTF-8); empty when the task
// failed before the process produced an exit status.
stdout: String,
// Captured standard error of the child (lossy UTF-8); for failures that
// happen before an exit status is available this holds the error message instead.
stderr: String,
// Time elapsed since the worker started handling the task.
duration: Duration,
// `None` on success, otherwise a description of what went wrong.
error_message: Option<String>,
}
// Builds the identifying key of a task: its job label, followed by one
// `input:<reference>` entry per input and one `output:<reference>` entry per
// output, all joined with `|`.
// Mirrors the Go implementation's getTaskKey.
fn get_task_key(task: &Task) -> String {
    let input_parts = task
        .config
        .inputs
        .iter()
        .map(|dep| format!("input:{}", dep.reference));
    let output_parts = task
        .config
        .outputs
        .iter()
        .map(|out| format!("output:{}", out));

    std::iter::once(task.job_label.clone())
        .chain(input_parts)
        .chain(output_parts)
        .collect::<Vec<String>>()
        .join("|")
}
// Resolves the path of the `<target>.exec` executable for a job label.
// Mirrors the Go implementation's resolveExecutableFromRunfiles.
//
// The target name is the text after the last `:` in `job_label`, or, when the
// label has no colon, its final path component. Lookup order:
//   1. `$RUNFILES_DIR/_main/<target>.exec` when RUNFILES_DIR is set (the path
//      is returned without checking that it exists);
//   2. `<current_exe>.runfiles/_main/<target>.exec` when that runfiles
//      directory exists;
//   3. `<job_label>.exec` as a last-resort fallback.
fn resolve_executable_from_runfiles(job_label: &str) -> PathBuf {
    let target_name = match job_label.rfind(':') {
        // ':' is a single ASCII byte, so slicing right after it is always valid.
        Some(colon_index) => &job_label[colon_index + 1..],
        None => Path::new(job_label)
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or(job_label),
    };
    let exec_name = format!("{}.exec", target_name);

    // `var_os` keeps the value byte-for-byte, so a runfiles directory whose
    // path is not valid Unicode is still honored instead of silently ignored.
    if let Some(runfiles_dir) = std::env::var_os("RUNFILES_DIR") {
        let path = PathBuf::from(runfiles_dir).join("_main").join(&exec_name);
        debug!("Resolved executable path (RUNFILES_DIR): {}", path.display());
        return path;
    }

    match std::env::current_exe() {
        Ok(current_exe) => {
            // Append ".runfiles" to the executable path losslessly. The former
            // second attempt (parent dir + "<file name>.runfiles") produced this
            // very same path, so a single check is sufficient.
            let mut runfiles_os = current_exe.into_os_string();
            runfiles_os.push(".runfiles");
            let runfiles_dir_path = PathBuf::from(runfiles_os);

            if runfiles_dir_path.is_dir() {
                let path = runfiles_dir_path.join("_main").join(&exec_name);
                debug!("Resolved executable path (derived RUNFILES_DIR): {}", path.display());
                return path;
            }
            warn!("Warning: RUNFILES_DIR not found or invalid, and derived path {} is not a directory.", runfiles_dir_path.display());
        }
        Err(_) => warn!("Warning: Could not determine current executable path."),
    }

    let fallback_path = PathBuf::from(format!("{}.exec", job_label));
    warn!("Falling back to direct executable path: {}", fallback_path.display());
    fallback_path
}
fn worker(
task_rx: Receiver<Arc<Task>>,
result_tx: Sender<TaskExecutionResult>,
worker_id: usize,
) {
info!("[Worker {}] Starting", worker_id);
while let Ok(task) = task_rx.recv() {
let task_key = get_task_key(&task);
info!("[Worker {}] Starting job: {} (Key: {})", worker_id, task.job_label, task_key);
let start_time = Instant::now();
let exec_path = resolve_executable_from_runfiles(&task.job_label);
let config_json = match serde_json::to_string(&task.config) {
Ok(json) => json,
Err(e) => {
let err_msg = format!("Failed to serialize task config for {}: {}", task.job_label, e);
error!("[Worker {}] {}", worker_id, err_msg);
result_tx
.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success: false,
stdout: String::new(),
stderr: err_msg.clone(),
duration: start_time.elapsed(),
error_message: Some(err_msg),
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
continue;
}
};
let mut cmd = Command::new(&exec_path);
cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
// Set environment variables from the current process's environment
// This mirrors the Go `cmd.Env = os.Environ()` behavior.
// Task-specific env vars from task.config.env are passed via JSON through stdin.
cmd.env_clear(); // Start with no environment variables
for (key, value) in std::env::vars() {
cmd.env(key, value); // Add current process's environment variables
}
match cmd.spawn() {
Ok(mut child) => {
if let Some(mut child_stdin) = child.stdin.take() {
if let Err(e) = child_stdin.write_all(config_json.as_bytes()) {
let err_msg = format!("[Worker {}] Failed to write to stdin for {}: {}", worker_id, task.job_label, e);
error!("{}", err_msg);
// Ensure child is killed if stdin write fails before wait
let _ = child.kill();
let _ = child.wait(); // Reap the child
result_tx.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success: false,
stdout: String::new(),
stderr: err_msg.clone(),
duration: start_time.elapsed(),
error_message: Some(err_msg),
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
continue;
}
drop(child_stdin); // Close stdin to signal EOF to the child
} else {
let err_msg = format!("[Worker {}] Failed to get stdin for {}", worker_id, task.job_label);
error!("{}", err_msg);
result_tx.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success: false,
stdout: String::new(),
stderr: err_msg.clone(),
duration: start_time.elapsed(),
error_message: Some(err_msg),
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
continue;
}
match child.wait_with_output() {
Ok(output) => {
let duration = start_time.elapsed();
let success = output.status.success();
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
if success {
info!(
"[Worker {}] Job succeeded: {} (Duration: {:?})",
worker_id, task.job_label, duration
);
} else {
error!(
"[Worker {}] Job failed: {} (Duration: {:?}, Status: {:?})\nStdout: {}\nStderr: {}",
worker_id, task.job_label, duration, output.status, stdout, stderr
);
}
result_tx
.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success,
stdout,
stderr,
duration,
error_message: if success { None } else { Some(format!("Exited with status: {:?}", output.status)) },
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send result: {}", worker_id, e));
}
Err(e) => {
let err_msg = format!("[Worker {}] Failed to execute or wait for {}: {}", worker_id, task.job_label, e);
error!("{}", err_msg);
result_tx
.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success: false,
stdout: String::new(),
stderr: err_msg.clone(),
duration: start_time.elapsed(),
error_message: Some(err_msg),
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send execution error result: {}", worker_id, e));
}
}
}
Err(e) => {
let err_msg = format!("[Worker {}] Failed to spawn command for {}: {} (Path: {:?})", worker_id, task.job_label, e, exec_path);
error!("{}", err_msg);
result_tx
.send(TaskExecutionResult {
task_key,
job_label: task.job_label.clone(),
success: false,
stdout: String::new(),
stderr: err_msg.clone(),
duration: start_time.elapsed(),
error_message: Some(err_msg),
})
.unwrap_or_else(|e| error!("[Worker {}] Failed to send spawn error result: {}", worker_id, e));
}
}
}
info!("[Worker {}] Exiting", worker_id);
}
/// Returns true when every `Materialize` input of `task` is already present in
/// `completed_outputs`. Inputs of any other dependency type never block the task.
fn is_task_ready(task: &Task, completed_outputs: &HashSet<String>) -> bool {
    task.config
        .inputs
        .iter()
        .filter(|input| input.dep_type == DataDepType::Materialize)
        .all(|input| completed_outputs.contains(&input.reference))
}
/// Logs, at info level, how many tasks are in each state together with their labels.
///
/// A task is shown by its job label when its key is found in `original_tasks_by_key`,
/// and by the raw key otherwise.
fn log_status_summary(
    task_states: &HashMap<String, TaskState>,
    original_tasks_by_key: &HashMap<String, Arc<Task>>,
) {
    let mut pending = Vec::new();
    let mut running = Vec::new();
    let mut succeeded = Vec::new();
    let mut failed = Vec::new();

    for (task_key, state) in task_states.iter() {
        let label: &str = match original_tasks_by_key.get(task_key) {
            Some(task) => task.job_label.as_str(),
            None => task_key.as_str(),
        };
        // Pick the list that matches the task's state, then record the label in it.
        let bucket = match state {
            TaskState::Pending => &mut pending,
            TaskState::Running => &mut running,
            TaskState::Succeeded => &mut succeeded,
            TaskState::Failed => &mut failed,
        };
        bucket.push(label);
    }

    info!("Task Status Summary:");
    info!(" Pending ({}): {:?}", pending.len(), pending);
    info!(" Running ({}): {:?}", running.len(), running);
    info!(" Succeeded ({}): {:?}", succeeded.len(), succeeded);
    info!(" Failed ({}): {:?}", failed.len(), failed);
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init()?;
let mut buffer = String::new();
std::io::stdin().read_to_string(&mut buffer)?;
let graph: JobGraph = serde_json::from_str(&buffer)?;
info!("Executing job graph with {} nodes", graph.nodes.len());
let mut task_states: HashMap<String, TaskState> = HashMap::new();
let mut original_tasks_by_key: HashMap<String, Arc<Task>> = HashMap::new();
let graph_nodes_arc: Vec<Arc<Task>> = graph.nodes.into_iter().map(Arc::new).collect();
for task_node in &graph_nodes_arc {
let key = get_task_key(task_node);
task_states.insert(key.clone(), TaskState::Pending);
original_tasks_by_key.insert(key, task_node.clone());
}
let mut completed_outputs: HashSet<String> = HashSet::new();
let mut job_results: Vec<TaskExecutionResult> = Vec::new();
let (task_tx, task_rx): (Sender<Arc<Task>>, Receiver<Arc<Task>>) = crossbeam_channel::unbounded();
let (result_tx, result_rx): (Sender<TaskExecutionResult>, Receiver<TaskExecutionResult>) = crossbeam_channel::unbounded();
let mut worker_handles = Vec::new();
for i in 0..NUM_WORKERS {
let task_rx_clone = task_rx.clone();
let result_tx_clone = result_tx.clone();
worker_handles.push(thread::spawn(move || {
worker(task_rx_clone, result_tx_clone, i + 1);
}));
}
// Drop the original result_tx so the channel closes when all workers are done
// if result_rx is the only remaining receiver.
drop(result_tx);
let mut last_log_time = Instant::now();
let mut active_tasks_count = 0;
let mut fail_fast_triggered = false;
loop {
// 1. Process results
while let Ok(result) = result_rx.try_recv() {
active_tasks_count -= 1;
info!(
"Received result for task {}: Success: {}",
result.job_label, result.success
);
let current_state = if result.success {
TaskState::Succeeded
} else {
TaskState::Failed
};
task_states.insert(result.task_key.clone(), current_state);
if result.success {
if let Some(original_task) = original_tasks_by_key.get(&result.task_key) {
for output_ref in &original_task.config.outputs {
completed_outputs.insert(output_ref.clone());
}
}
} else {
if FAIL_FAST {
warn!("Fail-fast enabled and task {} failed. Shutting down.", result.job_label);
fail_fast_triggered = true;
}
}
job_results.push(result);
}
// 2. Check for fail-fast break
if fail_fast_triggered && active_tasks_count == 0 { // Wait for running tasks to finish if fail fast
info!("All active tasks completed after fail-fast trigger.");
break;
}
if fail_fast_triggered && active_tasks_count > 0 {
// Don't schedule new tasks, just wait for active ones or log
} else if !fail_fast_triggered { // Only dispatch if not in fail-fast shutdown
// 3. Dispatch ready tasks
for task_node in &graph_nodes_arc {
let task_key = get_task_key(task_node);
if task_states.get(&task_key) == Some(&TaskState::Pending) {
if is_task_ready(task_node, &completed_outputs) {
info!("Dispatching task: {}", task_node.job_label);
task_states.insert(task_key.clone(), TaskState::Running);
task_tx.send(task_node.clone())?;
active_tasks_count += 1;
}
}
}
}
// 4. Periodic logging
if last_log_time.elapsed() >= LOG_INTERVAL {
log_status_summary(&task_states, &original_tasks_by_key);
last_log_time = Instant::now();
}
// 5. Check completion
let all_done = task_states.values().all(|s| *s == TaskState::Succeeded || *s == TaskState::Failed);
if active_tasks_count == 0 && all_done {
info!("All tasks are in a terminal state and no tasks are active.");
break;
}
// Avoid busy-waiting if no events, give channels time
// Select would be better here, but for simplicity:
thread::sleep(Duration::from_millis(50));
}
info!("Shutting down workers...");
drop(task_tx); // Signal workers to stop by closing the task channel
for handle in worker_handles {
handle.join().expect("Failed to join worker thread");
}
info!("All workers finished.");
// Final processing of any remaining results (should be minimal if loop logic is correct)
while let Ok(result) = result_rx.try_recv() {
active_tasks_count -= 1; // Should be 0
info!(
"Received late result for task {}: Success: {}",
result.job_label, result.success
);
// Update state for completeness, though it might not affect overall outcome now
let current_state = if result.success { TaskState::Succeeded } else { TaskState::Failed };
task_states.insert(result.task_key.clone(), current_state);
job_results.push(result);
}
let success_count = job_results.iter().filter(|r| r.success).count();
let failure_count = job_results.len() - success_count;
info!("Execution complete: {} succeeded, {} failed", success_count, failure_count);
if failure_count > 0 || fail_fast_triggered {
error!("Execution finished with errors.");
std::process::exit(1);
}
Ok(())
}