Add rust execute impl
parent f2567f7567
commit 91d5fd26bc

2 changed files with 445 additions and 0 deletions
@@ -13,6 +13,21 @@ go_binary(
     visibility = ["//visibility:public"],
 )
 
+rust_binary(
+    name = "execute_rs",
+    srcs = ["execute.rs"],
+    edition = "2021",
+    deps = [
+        "//databuild:structs",
+        "@crates//:serde",
+        "@crates//:serde_json",
+        "@crates//:log",
+        "@crates//:simple_logger",
+        "@crates//:crossbeam-channel",
+    ],
+    visibility = ["//visibility:public"],
+)
+
 rust_binary(
     name = "analyze",
     srcs = ["analyze.rs"],
databuild/graph/execute.rs (new file, 430 lines)

@@ -0,0 +1,430 @@
use structs::{DataDepType, JobConfig, JobGraph, Task};
use crossbeam_channel::{Receiver, Sender};
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};

const NUM_WORKERS: usize = 4;
const LOG_INTERVAL: Duration = Duration::from_secs(5);
const FAIL_FAST: bool = true; // Same default as the Go version

#[derive(Debug, Clone, PartialEq, Eq)]
enum TaskState {
    Pending,
    Running,
    Succeeded,
    Failed,
}

#[derive(Debug, Clone)]
struct TaskExecutionResult {
    task_key: String,
    job_label: String, // For logging
    success: bool,
    stdout: String,
    stderr: String,
    duration: Duration,
    error_message: Option<String>,
}

// Generates a unique key for a task based on its JobLabel, input and output references.
// Mirrors the Go implementation's getTaskKey.
fn get_task_key(task: &Task) -> String {
    let mut key_parts = Vec::new();
    key_parts.push(task.job_label.clone());

    for input_dep in &task.config.inputs {
        key_parts.push(format!("input:{}", input_dep.reference));
    }
    for output_ref in &task.config.outputs {
        key_parts.push(format!("output:{}", output_ref));
    }
    key_parts.join("|")
}
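
// Illustrative example of the key format produced above (hypothetical label and
// references, not taken from this repository): a task with job_label
// "//jobs:daily_report", one input reference "data/raw" and one output reference
// "data/report" yields the key
//   "//jobs:daily_report|input:data/raw|output:data/report"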

// Resolves the executable path from runfiles.
// Mirrors the Go implementation's resolveExecutableFromRunfiles.
fn resolve_executable_from_runfiles(job_label: &str) -> PathBuf {
    let mut target_name = job_label.to_string();
    if let Some(colon_index) = job_label.rfind(':') {
        target_name = job_label[colon_index + 1..].to_string();
    } else if let Some(name) = Path::new(job_label).file_name().and_then(|n| n.to_str()) {
        target_name = name.to_string();
    }

    let exec_name = format!("{}.exec", target_name);

    if let Ok(runfiles_dir_str) = std::env::var("RUNFILES_DIR") {
        let path = PathBuf::from(runfiles_dir_str).join("_main").join(&exec_name);
        debug!("Resolved executable path (RUNFILES_DIR): {}", path.display());
        return path;
    }

    if let Ok(current_exe) = std::env::current_exe() {
        let mut runfiles_dir_path = PathBuf::from(format!("{}.runfiles", current_exe.display()));
        if !runfiles_dir_path.is_dir() { // Bazel often puts it next to the binary
            if let Some(parent) = current_exe.parent() {
                runfiles_dir_path = parent.join(format!("{}.runfiles", current_exe.file_name().unwrap_or_default().to_string_lossy()));
            }
        }

        if runfiles_dir_path.is_dir() {
            let path = runfiles_dir_path.join("_main").join(&exec_name);
            debug!("Resolved executable path (derived RUNFILES_DIR): {}", path.display());
            return path;
        } else {
            warn!("Warning: RUNFILES_DIR not found or invalid, and derived path {} is not a directory.", runfiles_dir_path.display());
        }
    } else {
        warn!("Warning: Could not determine current executable path.");
    }

    let fallback_path = PathBuf::from(format!("{}.exec", job_label));
    warn!("Falling back to direct executable path: {}", fallback_path.display());
    fallback_path
}
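
// Illustrative resolution sketch (hypothetical label, assuming a standard Bazel
// runfiles layout): for job_label "//jobs:daily_report" the target name becomes
// "daily_report" and the expected runfile is "daily_report.exec", so with
// RUNFILES_DIR=/tmp/run the resolved path is /tmp/run/_main/daily_report.exec.
// If neither RUNFILES_DIR nor a "<current_exe>.runfiles" directory can be found,
// the function falls back to "<job_label>.exec".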

fn worker(
    task_rx: Receiver<Arc<Task>>,
    result_tx: Sender<TaskExecutionResult>,
    worker_id: usize,
) {
    info!("[Worker {}] Starting", worker_id);
    while let Ok(task) = task_rx.recv() {
        let task_key = get_task_key(&task);
        info!("[Worker {}] Starting job: {} (Key: {})", worker_id, task.job_label, task_key);
        let start_time = Instant::now();

        let exec_path = resolve_executable_from_runfiles(&task.job_label);

        let config_json = match serde_json::to_string(&task.config) {
            Ok(json) => json,
            Err(e) => {
                let err_msg = format!("Failed to serialize task config for {}: {}", task.job_label, e);
                error!("[Worker {}] {}", worker_id, err_msg);
                result_tx
                    .send(TaskExecutionResult {
                        task_key,
                        job_label: task.job_label.clone(),
                        success: false,
                        stdout: String::new(),
                        stderr: err_msg.clone(),
                        duration: start_time.elapsed(),
                        error_message: Some(err_msg),
                    })
                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                continue;
            }
        };

        let mut cmd = Command::new(&exec_path);
        cmd.stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Set environment variables from the current process's environment
        // This mirrors the Go `cmd.Env = os.Environ()` behavior.
        // Task-specific env vars from task.config.env are passed via JSON through stdin.
        cmd.env_clear(); // Start with no environment variables
        for (key, value) in std::env::vars() {
            cmd.env(key, value); // Add current process's environment variables
        }

        match cmd.spawn() {
            Ok(mut child) => {
                if let Some(mut child_stdin) = child.stdin.take() {
                    if let Err(e) = child_stdin.write_all(config_json.as_bytes()) {
                        let err_msg = format!("[Worker {}] Failed to write to stdin for {}: {}", worker_id, task.job_label, e);
                        error!("{}", err_msg);
                        // Ensure child is killed if stdin write fails before wait
                        let _ = child.kill();
                        let _ = child.wait(); // Reap the child
                        result_tx.send(TaskExecutionResult {
                            task_key,
                            job_label: task.job_label.clone(),
                            success: false,
                            stdout: String::new(),
                            stderr: err_msg.clone(),
                            duration: start_time.elapsed(),
                            error_message: Some(err_msg),
                        })
                        .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                        continue;
                    }
                    drop(child_stdin); // Close stdin to signal EOF to the child
                } else {
                    let err_msg = format!("[Worker {}] Failed to get stdin for {}", worker_id, task.job_label);
                    error!("{}", err_msg);
                    result_tx.send(TaskExecutionResult {
                        task_key,
                        job_label: task.job_label.clone(),
                        success: false,
                        stdout: String::new(),
                        stderr: err_msg.clone(),
                        duration: start_time.elapsed(),
                        error_message: Some(err_msg),
                    })
                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send error result: {}", worker_id, e));
                    continue;
                }

                match child.wait_with_output() {
                    Ok(output) => {
                        let duration = start_time.elapsed();
                        let success = output.status.success();
                        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
                        let stderr = String::from_utf8_lossy(&output.stderr).to_string();

                        if success {
                            info!(
                                "[Worker {}] Job succeeded: {} (Duration: {:?})",
                                worker_id, task.job_label, duration
                            );
                        } else {
                            error!(
                                "[Worker {}] Job failed: {} (Duration: {:?}, Status: {:?})\nStdout: {}\nStderr: {}",
                                worker_id, task.job_label, duration, output.status, stdout, stderr
                            );
                        }
                        result_tx
                            .send(TaskExecutionResult {
                                task_key,
                                job_label: task.job_label.clone(),
                                success,
                                stdout,
                                stderr,
                                duration,
                                error_message: if success { None } else { Some(format!("Exited with status: {:?}", output.status)) },
                            })
                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send result: {}", worker_id, e));
                    }
                    Err(e) => {
                        let err_msg = format!("[Worker {}] Failed to execute or wait for {}: {}", worker_id, task.job_label, e);
                        error!("{}", err_msg);
                        result_tx
                            .send(TaskExecutionResult {
                                task_key,
                                job_label: task.job_label.clone(),
                                success: false,
                                stdout: String::new(),
                                stderr: err_msg.clone(),
                                duration: start_time.elapsed(),
                                error_message: Some(err_msg),
                            })
                            .unwrap_or_else(|e| error!("[Worker {}] Failed to send execution error result: {}", worker_id, e));
                    }
                }
            }
            Err(e) => {
                let err_msg = format!("[Worker {}] Failed to spawn command for {}: {} (Path: {:?})", worker_id, task.job_label, e, exec_path);
                error!("{}", err_msg);
                result_tx
                    .send(TaskExecutionResult {
                        task_key,
                        job_label: task.job_label.clone(),
                        success: false,
                        stdout: String::new(),
                        stderr: err_msg.clone(),
                        duration: start_time.elapsed(),
                        error_message: Some(err_msg),
                    })
                    .unwrap_or_else(|e| error!("[Worker {}] Failed to send spawn error result: {}", worker_id, e));
            }
        }
    }
    info!("[Worker {}] Exiting", worker_id);
}
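
// Protocol sketch (a summary of the loop above, not additional behavior): each
// worker pulls Arc<Task> values off task_rx, serializes the task's config to JSON,
// writes it to the child process's stdin and closes it, then waits for the child
// and reports a TaskExecutionResult on result_tx. The invoked job binary is
// presumably expected to read its JobConfig as JSON from stdin.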

fn is_task_ready(task: &Task, completed_outputs: &HashSet<String>) -> bool {
    for dep in &task.config.inputs {
        if dep.dep_type == DataDepType::Materialize {
            if !completed_outputs.contains(&dep.reference) {
                return false;
            }
        }
    }
    true
}
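
// Readiness sketch (hypothetical references): a task whose inputs include a
// Materialize dependency on "data/clean" stays Pending until some other task
// listing "data/clean" among its outputs has succeeded, at which point main()
// adds "data/clean" to completed_outputs and the task can be dispatched.
// Inputs with other dependency types are not gated here.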

fn log_status_summary(
    task_states: &HashMap<String, TaskState>,
    original_tasks_by_key: &HashMap<String, Arc<Task>>,
) {
    let mut pending_tasks = Vec::new();
    let mut running_tasks = Vec::new();
    let mut succeeded_tasks = Vec::new();
    let mut failed_tasks = Vec::new();

    for (key, state) in task_states {
        let label = original_tasks_by_key.get(key).map_or_else(|| key.as_str(), |t| t.job_label.as_str());
        match state {
            TaskState::Pending => pending_tasks.push(label),
            TaskState::Running => running_tasks.push(label),
            TaskState::Succeeded => succeeded_tasks.push(label),
            TaskState::Failed => failed_tasks.push(label),
        }
    }

    info!("Task Status Summary:");
    info!(" Pending ({}): {:?}", pending_tasks.len(), pending_tasks);
    info!(" Running ({}): {:?}", running_tasks.len(), running_tasks);
    info!(" Succeeded ({}): {:?}", succeeded_tasks.len(), succeeded_tasks);
    info!(" Failed ({}): {:?}", failed_tasks.len(), failed_tasks);
}


fn main() -> Result<(), Box<dyn std::error::Error>> {
    simple_logger::SimpleLogger::new().with_level(log::LevelFilter::Info).init()?;

    let mut buffer = String::new();
    std::io::stdin().read_to_string(&mut buffer)?;
    let graph: JobGraph = serde_json::from_str(&buffer)?;

    info!("Executing job graph with {} nodes", graph.nodes.len());

    let mut task_states: HashMap<String, TaskState> = HashMap::new();
    let mut original_tasks_by_key: HashMap<String, Arc<Task>> = HashMap::new();
    let graph_nodes_arc: Vec<Arc<Task>> = graph.nodes.into_iter().map(Arc::new).collect();


    for task_node in &graph_nodes_arc {
        let key = get_task_key(task_node);
        task_states.insert(key.clone(), TaskState::Pending);
        original_tasks_by_key.insert(key, task_node.clone());
    }

    let mut completed_outputs: HashSet<String> = HashSet::new();
    let mut job_results: Vec<TaskExecutionResult> = Vec::new();

    let (task_tx, task_rx): (Sender<Arc<Task>>, Receiver<Arc<Task>>) = crossbeam_channel::unbounded();
    let (result_tx, result_rx): (Sender<TaskExecutionResult>, Receiver<TaskExecutionResult>) = crossbeam_channel::unbounded();

    let mut worker_handles = Vec::new();
    for i in 0..NUM_WORKERS {
        let task_rx_clone = task_rx.clone();
        let result_tx_clone = result_tx.clone();
        worker_handles.push(thread::spawn(move || {
            worker(task_rx_clone, result_tx_clone, i + 1);
        }));
    }
    // Drop the original result_tx so the channel closes when all workers are done
    // if result_rx is the only remaining receiver.
    drop(result_tx);


    let mut last_log_time = Instant::now();
    let mut active_tasks_count = 0;
    let mut fail_fast_triggered = false;

    loop {
        // 1. Process results
        while let Ok(result) = result_rx.try_recv() {
            active_tasks_count -= 1;
            info!(
                "Received result for task {}: Success: {}",
                result.job_label, result.success
            );

            let current_state = if result.success {
                TaskState::Succeeded
            } else {
                TaskState::Failed
            };
            task_states.insert(result.task_key.clone(), current_state);

            if result.success {
                if let Some(original_task) = original_tasks_by_key.get(&result.task_key) {
                    for output_ref in &original_task.config.outputs {
                        completed_outputs.insert(output_ref.clone());
                    }
                }
            } else {
                if FAIL_FAST {
                    warn!("Fail-fast enabled and task {} failed. Shutting down.", result.job_label);
                    fail_fast_triggered = true;
                }
            }
            job_results.push(result);
        }

        // 2. Check for fail-fast break
        if fail_fast_triggered && active_tasks_count == 0 { // Wait for running tasks to finish if fail fast
            info!("All active tasks completed after fail-fast trigger.");
            break;
        }
        if fail_fast_triggered && active_tasks_count > 0 {
            // Don't schedule new tasks, just wait for active ones or log
        } else if !fail_fast_triggered { // Only dispatch if not in fail-fast shutdown
            // 3. Dispatch ready tasks
            for task_node in &graph_nodes_arc {
                let task_key = get_task_key(task_node);
                if task_states.get(&task_key) == Some(&TaskState::Pending) {
                    if is_task_ready(task_node, &completed_outputs) {
                        info!("Dispatching task: {}", task_node.job_label);
                        task_states.insert(task_key.clone(), TaskState::Running);
                        task_tx.send(task_node.clone())?;
                        active_tasks_count += 1;
                    }
                }
            }
        }


        // 4. Periodic logging
        if last_log_time.elapsed() >= LOG_INTERVAL {
            log_status_summary(&task_states, &original_tasks_by_key);
            last_log_time = Instant::now();
        }

        // 5. Check completion
        let all_done = task_states.values().all(|s| *s == TaskState::Succeeded || *s == TaskState::Failed);
        if active_tasks_count == 0 && all_done {
            info!("All tasks are in a terminal state and no tasks are active.");
            break;
        }

        // Avoid busy-waiting if no events, give channels time
        // Select would be better here, but for simplicity:
        thread::sleep(Duration::from_millis(50));
    }

    info!("Shutting down workers...");
    drop(task_tx); // Signal workers to stop by closing the task channel

    for handle in worker_handles {
        handle.join().expect("Failed to join worker thread");
    }
    info!("All workers finished.");

    // Final processing of any remaining results (should be minimal if loop logic is correct)
    while let Ok(result) = result_rx.try_recv() {
        active_tasks_count -= 1; // Should be 0
        info!(
            "Received late result for task {}: Success: {}",
            result.job_label, result.success
        );
        // Update state for completeness, though it might not affect overall outcome now
        let current_state = if result.success { TaskState::Succeeded } else { TaskState::Failed };
        task_states.insert(result.task_key.clone(), current_state);
        job_results.push(result);
    }


    let success_count = job_results.iter().filter(|r| r.success).count();
    let failure_count = job_results.len() - success_count;

    info!("Execution complete: {} succeeded, {} failed", success_count, failure_count);

    if failure_count > 0 || fail_fast_triggered {
        error!("Execution finished with errors.");
        std::process::exit(1);
    }

    Ok(())
}
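
// Usage sketch (an assumption for illustration, not part of this commit): the
// binary reads a JSON-serialized JobGraph on stdin, so an invocation along the
// lines of `bazel run //<package>:execute_rs < graph.json` is the expected shape,
// where graph.json might come from the neighboring `analyze` target. The exact
// package path and JSON field names are defined by //databuild:structs and are
// not shown in this diff.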