Add rust execute impl
parent f2567f7567
commit 91d5fd26bc
2 changed files with 445 additions and 0 deletions
@@ -13,6 +13,21 @@ go_binary(
    visibility = ["//visibility:public"],
)

rust_binary(
    name = "execute_rs",
    srcs = ["execute.rs"],
    edition = "2021",
    deps = [
        "//databuild:structs",
        "@crates//:serde",
        "@crates//:serde_json",
        "@crates//:log",
        "@crates//:simple_logger",
        "@crates//:crossbeam-channel",
    ],
    visibility = ["//visibility:public"],
)

rust_binary(
    name = "analyze",
    srcs = ["analyze.rs"],
430  databuild/graph/execute.rs  Normal file
@@ -0,0 +1,430 @@
use structs::{DataDepType, JobConfig, JobGraph, Task};
use crossbeam_channel::{Receiver, Sender};
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};

const NUM_WORKERS: usize = 4;
const LOG_INTERVAL: Duration = Duration::from_secs(5);
const FAIL_FAST: bool = true; // Same default as the Go version

#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskState {
    Pending,
    Running,
    Succeeded,
    Failed,
}

#[derive(Debug, Clone)]
struct TaskExecutionResult {
    task_key: String,
    job_label: String, // For logging
    success: bool,
    stdout: String,
    stderr: String,
    duration: Duration,
    error_message: Option<String>,
}

// Generates a unique key for a task based on its JobLabel, input and output references.
// Mirrors the Go implementation's getTaskKey.
fn get_task_key(task: &Task) -> String {
    let mut key_parts = Vec::new();
    key_parts.push(task.job_label.clone());

    for input_dep in &task.config.inputs {
        key_parts.push(format!("input:{}", input_dep.reference));
    }
    for output_ref in &task.config.outputs {
        key_parts.push(format!("output:{}", output_ref));
    }
    key_parts.join("|")
}

// Resolves the executable path from runfiles.
// Mirrors the Go implementation's resolveExecutableFromRunfiles.
fn resolve_executable_from_runfiles(job_label: &str) -> PathBuf {
    let mut target_name = job_label.to_string();
    if let Some(colon_index) = job_label.rfind(':') {
        target_name = job_label[colon_index + 1..].to_string();
    } else if let Some(name) = Path::new(job_label).file_name().and_then(|n| n.to_str()) {
        target_name = name.to_string();
    }

    let exec_name = format!("{}.exec", target_name);

    if let Ok(runfiles_dir_str) = std::env::var("RUNFILES_DIR") {
        let path = PathBuf::from(runfiles_dir_str).join("_main").join(&exec_name);
        debug!("Resolved executable path (RUNFILES_DIR): {}", path.display());
        return path;
    }

    if let Ok(current_exe) = std::env::current_exe() {
        let mut runfiles_dir_path = PathBuf::from(format!("{}.runfiles", current_exe.display()));
        if !runfiles_dir_path.is_dir() { // Bazel often puts it next to the binary
            if let Some(parent) = current_exe.parent() {
                runfiles_dir_path = parent.join(format!("{}.runfiles", current_exe.file_name().unwrap_or_default().to_string_lossy()));
            }
        }

        if runfiles_dir_path.is_dir() {
            let path = runfiles_dir_path.join("_main").join(&exec_name);
            debug!("Resolved executable path (derived RUNFILES_DIR): {}", path.display());
            return path;
        } else {
            warn!("Warning: RUNFILES_DIR not found or invalid, and derived path {} is not a directory.", runfiles_dir_path.display());
        }
    } else {
        warn!("Warning: Could not determine current executable path.");
    }

    let fallback_path = PathBuf::from(format!("{}.exec", job_label));
    warn!("Falling back to direct executable path: {}", fallback_path.display());
    fallback_path
}

fn worker(
    task_rx: Receiver<Arc<Task>>,
    result_tx: Sender<TaskExecutionResult>,
    worker_id: usize,
) {
    info!("[Worker {}] Starting", worker_id);
    while let Ok(task) = task_rx.recv() {
        let task_key = get_task_key(&task);
        info!("[Worker {}] Starting job: {} (Key: {})", worker_id, task.job_label, task_key);
        let start_time = Instant::now();

        let exec_path = resolve_executable_from_runfiles(&task.job_label);

        let config_json = match serde_json::to_string(&task.config) {
            Ok(json) => json,
            Err(e) => {
                let err_msg = format!("Failed to serialize task config for {}: {}", task.job_label, e);
                error!("[Worker {}] {}", worker_id, err_msg);
                result_tx
                    .send(TaskExecutionResult {
                        task_key,
                        job_label: task.job_label.clone(),
                        success: false,
                        stdout: String::new(),
                        stderr: err_msg.clone(),
                        duration: start_time.elapsed(),
                        error_message: Some(err_msg),
                    })
                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                continue;
            }
        };

        let mut cmd = Command::new(&exec_path);
        cmd.stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Set environment variables from the current process's environment
        // This mirrors the Go `cmd.Env = os.Environ()` behavior.
        // Task-specific env vars from task.config.env are passed via JSON through stdin.
        cmd.env_clear(); // Start with no environment variables
        for (key, value) in std::env::vars() {
            cmd.env(key, value); // Add current process's environment variables
        }

        match cmd.spawn() {
            Ok(mut child) => {
                if let Some(mut child_stdin) = child.stdin.take() {
                    if let Err(e) = child_stdin.write_all(config_json.as_bytes()) {
                        let err_msg = format!("[Worker {}] Failed to write to stdin for {}: {}", worker_id, task.job_label, e);
                        error!("{}", err_msg);
                        // Ensure child is killed if stdin write fails before wait
                        let _ = child.kill();
                        let _ = child.wait(); // Reap the child
                        result_tx
                            .send(TaskExecutionResult {
                                task_key,
                                job_label: task.job_label.clone(),
                                success: false,
                                stdout: String::new(),
                                stderr: err_msg.clone(),
                                duration: start_time.elapsed(),
                                error_message: Some(err_msg),
                            })
                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                        continue;
                    }
                    drop(child_stdin); // Close stdin to signal EOF to the child
                } else {
                    let err_msg = format!("[Worker {}] Failed to get stdin for {}", worker_id, task.job_label);
                    error!("{}", err_msg);
                    result_tx
                        .send(TaskExecutionResult {
                            task_key,
                            job_label: task.job_label.clone(),
                            success: false,
                            stdout: String::new(),
                            stderr: err_msg.clone(),
                            duration: start_time.elapsed(),
                            error_message: Some(err_msg),
                        })
                        .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                    continue;
                }

                match child.wait_with_output() {
                    Ok(output) => {
                        let duration = start_time.elapsed();
                        let success = output.status.success();
                        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
                        let stderr = String::from_utf8_lossy(&output.stderr).to_string();

                        if success {
                            info!(
                                "[Worker {}] Job succeeded: {} (Duration: {:?})",
                                worker_id, task.job_label, duration
                            );
                        } else {
                            error!(
                                "[Worker {}] Job failed: {} (Duration: {:?}, Status: {:?})\nStdout: {}\nStderr: {}",
                                worker_id, task.job_label, duration, output.status, stdout, stderr
                            );
                        }
                        result_tx
                            .send(TaskExecutionResult {
                                task_key,
                                job_label: task.job_label.clone(),
                                success,
                                stdout,
                                stderr,
                                duration,
                                error_message: if success { None } else { Some(format!("Exited with status: {:?}", output.status)) },
                            })
                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send result: {}", worker_id, e));
                    }
                    Err(e) => {
                        let err_msg = format!("[Worker {}] Failed to execute or wait for {}: {}", worker_id, task.job_label, e);
                        error!("{}", err_msg);
                        result_tx
                            .send(TaskExecutionResult {
                                task_key,
                                job_label: task.job_label.clone(),
                                success: false,
                                stdout: String::new(),
                                stderr: err_msg.clone(),
                                duration: start_time.elapsed(),
                                error_message: Some(err_msg),
                            })
                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send execution error result: {}", worker_id, e));
                    }
                }
            }
            Err(e) => {
                let err_msg = format!("[Worker {}] Failed to spawn command for {}: {} (Path: {:?})", worker_id, task.job_label, e, exec_path);
                error!("{}", err_msg);
                result_tx
                    .send(TaskExecutionResult {
                        task_key,
                        job_label: task.job_label.clone(),
                        success: false,
                        stdout: String::new(),
                        stderr: err_msg.clone(),
                        duration: start_time.elapsed(),
                        error_message: Some(err_msg),
                    })
                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send spawn error result: {}", worker_id, e));
            }
        }
    }
    info!("[Worker {}] Exiting", worker_id);
}

fn is_task_ready(task: &Task, completed_outputs: &HashSet<String>) -> bool {
    for dep in &task.config.inputs {
        if dep.dep_type == DataDepType::Materialize {
            if !completed_outputs.contains(&dep.reference) {
                return false;
            }
        }
    }
    true
}

fn log_status_summary(
    task_states: &HashMap<String, TaskState>,
    original_tasks_by_key: &HashMap<String, Arc<Task>>,
) {
    let mut pending_tasks = Vec::new();
    let mut running_tasks = Vec::new();
    let mut succeeded_tasks = Vec::new();
    let mut failed_tasks = Vec::new();

    for (key, state) in task_states {
        let label = original_tasks_by_key.get(key).map_or_else(|| key.as_str(), |t| t.job_label.as_str());
        match state {
            TaskState::Pending => pending_tasks.push(label),
            TaskState::Running => running_tasks.push(label),
            TaskState::Succeeded => succeeded_tasks.push(label),
            TaskState::Failed => failed_tasks.push(label),
        }
    }

    info!("Task Status Summary:");
    info!(" Pending ({}): {:?}", pending_tasks.len(), pending_tasks);
    info!(" Running ({}): {:?}", running_tasks.len(), running_tasks);
    info!(" Succeeded ({}): {:?}", succeeded_tasks.len(), succeeded_tasks);
    info!(" Failed ({}): {:?}", failed_tasks.len(), failed_tasks);
}


fn main() -> Result<(), Box<dyn std::error::Error>> {
    simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init()?;

    let mut buffer = String::new();
    std::io::stdin().read_to_string(&mut buffer)?;
    let graph: JobGraph = serde_json::from_str(&buffer)?;

    info!("Executing job graph with {} nodes", graph.nodes.len());

    let mut task_states: HashMap<String, TaskState> = HashMap::new();
    let mut original_tasks_by_key: HashMap<String, Arc<Task>> = HashMap::new();
    let graph_nodes_arc: Vec<Arc<Task>> = graph.nodes.into_iter().map(Arc::new).collect();


    for task_node in &graph_nodes_arc {
        let key = get_task_key(task_node);
        task_states.insert(key.clone(), TaskState::Pending);
        original_tasks_by_key.insert(key, task_node.clone());
    }

    let mut completed_outputs: HashSet<String> = HashSet::new();
    let mut job_results: Vec<TaskExecutionResult> = Vec::new();

    let (task_tx, task_rx): (Sender<Arc<Task>>, Receiver<Arc<Task>>) = crossbeam_channel::unbounded();
    let (result_tx, result_rx): (Sender<TaskExecutionResult>, Receiver<TaskExecutionResult>) = crossbeam_channel::unbounded();

    let mut worker_handles = Vec::new();
    for i in 0..NUM_WORKERS {
        let task_rx_clone = task_rx.clone();
        let result_tx_clone = result_tx.clone();
        worker_handles.push(thread::spawn(move || {
            worker(task_rx_clone, result_tx_clone, i + 1);
        }));
    }
    // Drop the original result_tx so the channel closes when all workers are done
    // if result_rx is the only remaining receiver.
    drop(result_tx);


    let mut last_log_time = Instant::now();
    let mut active_tasks_count = 0;
    let mut fail_fast_triggered = false;

    loop {
        // 1. Process results
        while let Ok(result) = result_rx.try_recv() {
            active_tasks_count -= 1;
            info!(
                "Received result for task {}: Success: {}",
                result.job_label, result.success
            );

            let current_state = if result.success {
                TaskState::Succeeded
            } else {
                TaskState::Failed
            };
            task_states.insert(result.task_key.clone(), current_state);

            if result.success {
                if let Some(original_task) = original_tasks_by_key.get(&result.task_key) {
                    for output_ref in &original_task.config.outputs {
                        completed_outputs.insert(output_ref.clone());
                    }
                }
            } else {
                if FAIL_FAST {
                    warn!("Fail-fast enabled and task {} failed. Shutting down.", result.job_label);
                    fail_fast_triggered = true;
                }
            }
            job_results.push(result);
        }

        // 2. Check for fail-fast break
        if fail_fast_triggered && active_tasks_count == 0 { // Wait for running tasks to finish if fail fast
            info!("All active tasks completed after fail-fast trigger.");
            break;
        }
        if fail_fast_triggered && active_tasks_count > 0 {
            // Don't schedule new tasks, just wait for active ones or log
        } else if !fail_fast_triggered { // Only dispatch if not in fail-fast shutdown
            // 3. Dispatch ready tasks
            for task_node in &graph_nodes_arc {
                let task_key = get_task_key(task_node);
                if task_states.get(&task_key) == Some(&TaskState::Pending) {
                    if is_task_ready(task_node, &completed_outputs) {
                        info!("Dispatching task: {}", task_node.job_label);
                        task_states.insert(task_key.clone(), TaskState::Running);
                        task_tx.send(task_node.clone())?;
                        active_tasks_count += 1;
                    }
                }
            }
        }


        // 4. Periodic logging
        if last_log_time.elapsed() >= LOG_INTERVAL {
            log_status_summary(&task_states, &original_tasks_by_key);
            last_log_time = Instant::now();
        }

        // 5. Check completion
        let all_done = task_states.values().all(|s| *s == TaskState::Succeeded || *s == TaskState::Failed);
        if active_tasks_count == 0 && all_done {
            info!("All tasks are in a terminal state and no tasks are active.");
            break;
        }

        // Avoid busy-waiting if no events, give channels time
        // Select would be better here, but for simplicity:
        thread::sleep(Duration::from_millis(50));
    }

    info!("Shutting down workers...");
    drop(task_tx); // Signal workers to stop by closing the task channel

    for handle in worker_handles {
        handle.join().expect("Failed to join worker thread");
    }
    info!("All workers finished.");

    // Final processing of any remaining results (should be minimal if loop logic is correct)
    while let Ok(result) = result_rx.try_recv() {
        active_tasks_count -= 1; // Should be 0
        info!(
            "Received late result for task {}: Success: {}",
            result.job_label, result.success
        );
        // Update state for completeness, though it might not affect overall outcome now
        let current_state = if result.success { TaskState::Succeeded } else { TaskState::Failed };
        task_states.insert(result.task_key.clone(), current_state);
        job_results.push(result);
    }


    let success_count = job_results.iter().filter(|r| r.success).count();
    let failure_count = job_results.len() - success_count;

    info!("Execution complete: {} succeeded, {} failed", success_count, failure_count);

    if failure_count > 0 || fail_fast_triggered {
        error!("Execution finished with errors.");
        std::process::exit(1);
    }

    Ok(())
}
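main() drives everything from a JobGraph read as JSON on stdin. As a minimal sketch, a one-node input document could look like the following, assuming the struct shapes sketched above and default serde field naming; the label and paths are purely illustrative, and the real wire format is defined by //databuild:structs.

// Illustrative only: a one-node graph in the JSON shape main() appears to expect.
fn example_graph_json() -> &'static str {
    r#"{
      "nodes": [
        {
          "job_label": "//jobs:hello",
          "config": {
            "inputs": [],
            "outputs": ["out/hello.txt"],
            "env": {}
          }
        }
      ]
    }"#
}

Fed such a document on stdin, the binary would mark the single task ready immediately (it has no materialize inputs), dispatch it to one of the four workers, resolve "hello.exec" from runfiles, and exit 0 if that executable succeeds.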
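One more note for reading the logs: the task key that shows up as "Key: ..." in the worker start lines is built by get_task_key from the job label plus the input and output references, joined with a vertical bar. A hypothetical task (labels and paths invented) illustrates the format:

// Hypothetical example of the key format produced by get_task_key:
// a task labeled "//jobs:transform" with one input and one output yields
//
//   //jobs:transform|input:data/raw.csv|output:data/clean.csv
//
// Duplicate keys would collapse into a single entry in the task_states map,
// so label plus references is assumed to be unique per task in the graph.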