from databuild.proto import JobConfig, PartitionRef, DataDep, DepType from typing import Self, Protocol, get_type_hints, get_origin, get_args from dataclasses import fields, is_dataclass, dataclass, field import re class PartitionPattern: _raw_pattern: str @property def _pattern(self) -> re.Pattern: return re.compile(self._raw_pattern) def _validate_pattern(self): """Checks that both conditions are met: 1. All fields from the PartitionFields type are present in the pattern 2. All fields from the pattern are present in the PartitionFields type """ # TODO how do I get this to be called? assert is_dataclass(self), "Should be a dataclass also (for partition fields)" pattern_fields = set(self._pattern.groupindex.keys()) partition_fields = {field.name for field in fields(self)} if pattern_fields != partition_fields: raise ValueError(f"Pattern fields {pattern_fields} do not match partition fields {partition_fields}") @classmethod def deserialize(cls, raw_value: str) -> Self: """Parses a partition from a string based on the defined pattern.""" # Create a temporary instance to access the compiled pattern # We need to compile the pattern to match against it pattern = re.compile(cls._raw_pattern) # Match the raw value against the pattern match = pattern.match(raw_value) if not match: raise ValueError(f"String '{raw_value}' does not match pattern '{cls._pattern}'") # Extract the field values from the match field_values = match.groupdict() # Create and return a new instance with the extracted values return cls(**field_values) def serialize(self) -> str: """Returns a string representation by filling in the pattern template with field values.""" # Start with the pattern result = self._raw_pattern # Replace each named group in the pattern with its corresponding field value for field in fields(self): # Find the named group pattern and replace it with the actual value # We need to replace the regex pattern with the actual value # Look for the pattern (?P...) and replace with the field value pattern_to_replace = rf'\(\?P<{field.name}>[^)]+\)' actual_value = getattr(self, field.name) result = re.sub(pattern_to_replace, actual_value, result) return result class DataBuildJob(Protocol): # The types of partitions that this job produces output_types: list[type[PartitionPattern]] def config(self, outputs: list[PartitionPattern]) -> list[JobConfig]: ... def exec(self, *args: str) -> None: ... class DataBuildGraph: def __init__(self, label: str): self.label = label self.lookup = {} def job(self, cls: type[DataBuildJob]) -> None: """Register a job with the graph.""" for partition in cls.output_types: assert partition not in self.lookup, f"Partition `{partition}` already registered" self.lookup[partition] = cls return cls def generate_bazel_module(self): """Generates a complete databuild application, packaging up referenced jobs and this graph via bazel targets""" raise NotImplementedError def generate_bazel_package(self, name: str, output_dir: str, deps: list = None) -> None: """Generate BUILD.bazel and binaries into a generated/ subdirectory. Args: name: Base name for the generated graph (without .generate suffix) output_dir: Directory to write generated files to (will create generated/ subdir) deps: List of Bazel dependency labels to use in generated BUILD.bazel """ import os import shutil # Create generated/ subdirectory generated_dir = os.path.join(output_dir, "generated") os.makedirs(generated_dir, exist_ok=True) # Generate BUILD.bazel with job and graph targets self._generate_build_bazel(generated_dir, name, deps or []) # Generate individual job scripts (instead of shared wrapper) self._generate_job_scripts(generated_dir) # Generate job lookup binary self._generate_job_lookup(generated_dir, name) package_name = self._get_package_name() print(f"Generated DataBuild package '{name}' in {generated_dir}") if package_name != "UNKNOWN_PACKAGE": print(f"Run 'bazel build \"@databuild//{package_name}/generated:{name}_graph.analyze\"' to use the generated graph") else: print(f"Run 'bazel build generated:{name}_graph.analyze' to use the generated graph") def _generate_build_bazel(self, output_dir: str, name: str, deps: list) -> None: """Generate BUILD.bazel with databuild_job and databuild_graph targets.""" import os # Get job classes from the lookup table job_classes = sorted(set(self.lookup.values()), key=lambda cls: cls.__name__) # Format deps for BUILD.bazel if deps: deps_str = ", ".join([f'"{dep}"' for dep in deps]) else: # Fallback to parent package if no deps provided parent_package = self._get_package_name() deps_str = f'"//{parent_package}:dsl_src"' # Generate py_binary targets for each job job_binaries = [] job_targets = [] for job_class in job_classes: job_name = self._snake_case(job_class.__name__) binary_name = f"{job_name}_binary" job_targets.append(f'"{job_name}"') job_script_name = f"{job_name}.py" job_binaries.append(f'''py_binary( name = "{binary_name}", srcs = ["{job_script_name}"], main = "{job_script_name}", deps = [{deps_str}], ) databuild_job( name = "{job_name}", binary = ":{binary_name}", )''') # Generate the complete BUILD.bazel content build_content = f'''load("@databuild//databuild:rules.bzl", "databuild_job", "databuild_graph") # Generated by DataBuild DSL - do not edit manually # This file is generated in a subdirectory to avoid overwriting the original BUILD.bazel {chr(10).join(job_binaries)} py_binary( name = "{name}_job_lookup", srcs = ["{name}_job_lookup.py"], deps = [{deps_str}], ) databuild_graph( name = "{name}_graph", jobs = [{", ".join(job_targets)}], lookup = ":{name}_job_lookup", visibility = ["//visibility:public"], ) # Create tar archive of generated files for testing genrule( name = "existing_generated", srcs = glob(["*.py", "BUILD.bazel"]), outs = ["existing_generated.tar"], cmd = "mkdir -p temp && cp $(SRCS) temp/ && find temp -exec touch -t 197001010000 {{}} + && tar -cf $@ -C temp .", visibility = ["//visibility:public"], ) ''' with open(os.path.join(output_dir, "BUILD.bazel"), "w") as f: f.write(build_content) def _generate_job_scripts(self, output_dir: str) -> None: """Generate individual Python scripts for each job class.""" import os # Get job classes and generate a script for each one job_classes = list(set(self.lookup.values())) graph_module_path = self._get_graph_module_path() for job_class in job_classes: job_name = self._snake_case(job_class.__name__) script_name = f"{job_name}.py" script_content = f'''#!/usr/bin/env python3 """ Generated job script for {job_class.__name__}. """ import sys import json from {graph_module_path} import {job_class.__name__} from databuild.proto import PartitionRef, JobConfigureResponse, to_dict def parse_outputs_from_args(args: list[str]) -> list: """Parse partition output references from command line arguments.""" outputs = [] for arg in args: # Find which output type can deserialize this partition reference for output_type in {job_class.__name__}.output_types: try: partition = output_type.deserialize(arg) outputs.append(partition) break except ValueError: continue else: raise ValueError(f"No output type in {job_class.__name__} can deserialize partition ref: {{arg}}") return outputs if __name__ == "__main__": if len(sys.argv) < 2: raise Exception(f"Invalid command usage") command = sys.argv[1] job_instance = {job_class.__name__}() if command == "config": # Parse output partition references as PartitionRef objects (for Rust wrapper) output_refs = [PartitionRef(str=raw_ref) for raw_ref in sys.argv[2:]] # Also parse them into DSL partition objects (for DSL job.config()) outputs = parse_outputs_from_args(sys.argv[2:]) # Call job's config method - returns list[JobConfig] configs = job_instance.config(outputs) # Wrap in JobConfigureResponse and serialize using to_dict() response = JobConfigureResponse(configs=configs) print(json.dumps(to_dict(response))) elif command == "exec": # The exec method expects a JobConfig but the Rust wrapper passes args # For now, let the DSL job handle the args directly # TODO: This needs to be refined based on actual Rust wrapper interface job_instance.exec(*sys.argv[2:]) else: raise Exception(f"Invalid command `{{sys.argv[1]}}`") ''' script_path = os.path.join(output_dir, script_name) with open(script_path, "w") as f: f.write(script_content) # Make it executable os.chmod(script_path, 0o755) def _generate_job_lookup(self, output_dir: str, name: str) -> None: """Generate job lookup binary that maps partition patterns to job targets.""" import os # Build the job lookup mappings with full package paths package_name = self._get_package_name() lookup_mappings = [] for partition_type, job_class in self.lookup.items(): job_name = self._snake_case(job_class.__name__) pattern = partition_type._raw_pattern full_target = f"//{package_name}/generated:{job_name}" lookup_mappings.append(f' r"{pattern}": "{full_target}",') lookup_content = f'''#!/usr/bin/env python3 """ Generated job lookup for DataBuild DSL graph. Maps partition patterns to job targets. """ import sys import re import json from collections import defaultdict # Mapping from partition patterns to job targets JOB_MAPPINGS = {{ {chr(10).join(lookup_mappings)} }} def lookup_job_for_partition(partition_ref: str) -> str: """Look up which job can build the given partition reference.""" for pattern, job_target in JOB_MAPPINGS.items(): if re.match(pattern, partition_ref): return job_target raise ValueError(f"No job found for partition: {{partition_ref}}") def main(): if len(sys.argv) < 2: print("Usage: job_lookup.py [partition_ref...]", file=sys.stderr) sys.exit(1) results = defaultdict(list) try: for partition_ref in sys.argv[1:]: job_target = lookup_job_for_partition(partition_ref) results[job_target].append(partition_ref) # Output the results as JSON (matching existing lookup format) print(json.dumps(dict(results))) except ValueError as e: print(f"ERROR: {{e}}", file=sys.stderr) sys.exit(1) if __name__ == "__main__": main() ''' lookup_file = os.path.join(output_dir, f"{name}_job_lookup.py") with open(lookup_file, "w") as f: f.write(lookup_content) # Make it executable os.chmod(lookup_file, 0o755) def _snake_case(self, name: str) -> str: """Convert CamelCase to snake_case.""" import re s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() def _get_graph_module_path(self) -> str: """Get the module path for the graph containing this instance.""" # Try to find the module by looking at where the graph object is defined import inspect import sys # Look through all loaded modules to find where this graph instance is defined for module_name, module in sys.modules.items(): if hasattr(module, 'graph') and getattr(module, 'graph') is self: if module_name != '__main__': return module_name # Look through the call stack to find the module that imported us for frame_info in inspect.stack(): frame_globals = frame_info.frame.f_globals module_name = frame_globals.get('__name__') if module_name and module_name != '__main__' and 'graph' in frame_globals: # Check if this frame has our graph if frame_globals.get('graph') is self: return module_name # Last resort fallback - this will need to be manually configured return "UNKNOWN_MODULE" def _get_package_name(self) -> str: """Get the Bazel package name where the DSL source files are located.""" # Extract package from the graph label if available if hasattr(self, 'label') and self.label.startswith('//'): # Extract package from label like "//databuild/test/app:dsl_graph" package_part = self.label.split(':')[0] return package_part[2:] # Remove "//" prefix # Fallback to trying to infer from module path module_path = self._get_graph_module_path() if module_path != "UNKNOWN_MODULE": # Convert module path to package path # e.g., "databuild.test.app.dsl.graph" -> "databuild/test/app/dsl" parts = module_path.split('.') if parts[-1] in ['graph', 'main']: parts = parts[:-1] return '/'.join(parts) return "UNKNOWN_PACKAGE" @dataclass class JobConfigBuilder: outputs: list[PartitionRef] = field(default_factory=list) inputs: list[DataDep] = field(default_factory=list) args: list[str] = field(default_factory=list) env: dict[str, str] = field(default_factory=dict) def build(self) -> JobConfig: return JobConfig( outputs=self.outputs, inputs=self.inputs, args=self.args, env=self.env, ) def add_inputs(self, *partitions: PartitionPattern, dep_type: DepType=DepType.MATERIALIZE) -> Self: for p in partitions: dep_type_name = "materialize" if dep_type == DepType.MATERIALIZE else "query" self.inputs.append(DataDep(dep_type_code=dep_type, dep_type_name=dep_type_name, partition_ref=PartitionRef(str=p.serialize()))) return self def add_outputs(self, *partitions: PartitionPattern) -> Self: for p in partitions: self.outputs.append(PartitionRef(str=p.serialize())) return self def add_args(self, *args: str) -> Self: self.args.extend(args) return self def set_args(self, args: list[str]) -> Self: self.args = args return self def set_env(self, env: dict[str, str]) -> Self: self.env = env return self def add_env(self, **kwargs) -> Self: for k, v in kwargs.items(): assert isinstance(k, str), f"Expected a string key, got `{k}`" assert isinstance(v, str), f"Expected a string key, got `{v}`" self.env[k] = v return self