from databuild.proto import JobConfig, PartitionRef, DataDep, DepType
from typing import Self, Protocol, get_type_hints, get_origin, get_args
from dataclasses import fields, is_dataclass, dataclass, field
import re


class PartitionPattern:
    _raw_pattern: str

    @property
    def _pattern(self) -> re.Pattern:
        return re.compile(self._raw_pattern)

    def _validate_pattern(self):
        """Checks that both conditions are met:

        1. All fields of the partition dataclass appear as named groups in the pattern
        2. All named groups in the pattern correspond to fields of the partition dataclass
        """
        # TODO: how do I get this to be called?
        assert is_dataclass(self), "PartitionPattern subclasses must also be dataclasses (for partition fields)"
        pattern_fields = set(self._pattern.groupindex.keys())
        partition_fields = {field.name for field in fields(self)}
        if pattern_fields != partition_fields:
            raise ValueError(f"Pattern fields {pattern_fields} do not match partition fields {partition_fields}")

    @classmethod
    def deserialize(cls, raw_value: str) -> Self:
        """Parses a partition from a string based on the defined pattern."""
        # Compile the raw pattern directly; the _pattern property is only usable on instances
        pattern = re.compile(cls._raw_pattern)

        # Match the raw value against the pattern
        match = pattern.match(raw_value)
        if not match:
            raise ValueError(f"String '{raw_value}' does not match pattern '{cls._raw_pattern}'")

        # Extract the field values from the match
        field_values = match.groupdict()

        # Create and return a new instance with the extracted values
        return cls(**field_values)

    def serialize(self) -> str:
        """Returns a string representation by filling in the pattern template with field values."""
        # Start with the raw pattern
        result = self._raw_pattern

        # Replace each named group in the pattern with its corresponding field value.
        # Assumes group bodies contain no nested parentheses and field values contain
        # no regex replacement escapes (e.g. backslashes).
        for field in fields(self):
            # Find the named group (?P<field_name>...) and substitute the actual value
            pattern_to_replace = rf'\(\?P<{field.name}>[^)]+\)'
            actual_value = getattr(self, field.name)
            result = re.sub(pattern_to_replace, actual_value, result)

        return result

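# Example (illustrative only): a hypothetical partition type showing how
# PartitionPattern is meant to be subclassed and round-tripped. The name
# `DatePartition`, its fields, and its pattern are assumptions for this sketch,
# not part of the DSL.
#
#     @dataclass
#     class DatePartition(PartitionPattern):
#         _raw_pattern = r"events/(?P<dataset>[a-z_]+)/(?P<date>\d{4}-\d{2}-\d{2})"
#         dataset: str
#         date: str
#
#     p = DatePartition.deserialize("events/clicks/2024-01-01")
#     assert p.dataset == "clicks" and p.date == "2024-01-01"
#     assert p.serialize() == "events/clicks/2024-01-01"
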
class DataBuildJob(Protocol):
    # The types of partitions that this job produces
    output_types: list[type[PartitionPattern]]

    def config(self, outputs: list[PartitionPattern]) -> list[JobConfig]: ...

    def exec(self, *args: str) -> None: ...

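# Example (illustrative only): a minimal job satisfying the DataBuildJob protocol.
# `DatePartition` is the hypothetical partition type sketched above; the job body
# is an assumption, not part of the DSL.
#
#     class IngestClicksJob:
#         output_types = [DatePartition]
#
#         def config(self, outputs: list[PartitionPattern]) -> list[JobConfig]:
#             return [JobConfigBuilder().add_outputs(*outputs).build()]
#
#         def exec(self, *args: str) -> None:
#             ...  # do the actual work for one configured invocation
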
class DataBuildGraph:
    def __init__(self, label: str):
        self.label = label
        self.lookup: dict[type[PartitionPattern], type[DataBuildJob]] = {}

    def job(self, cls: type[DataBuildJob]) -> type[DataBuildJob]:
        """Register a job with the graph and return it, so this can be used as a decorator."""
        for partition in cls.output_types:
            assert partition not in self.lookup, f"Partition `{partition}` already registered"
            self.lookup[partition] = cls
        return cls

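    # Example (illustrative only): building a graph and registering a job via the
    # decorator form. `graph`, `DatePartition`, and `IngestClicksJob` are the
    # hypothetical names used in the sketches above.
    #
    #     graph = DataBuildGraph("//example/pipeline:graph")
    #
    #     @graph.job
    #     class IngestClicksJob:
    #         output_types = [DatePartition]
    #         ...
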
    def generate_bazel_module(self):
        """Generates a complete databuild application, packaging up referenced jobs and this graph via bazel targets"""
        raise NotImplementedError

    def generate_bazel_package(self, name: str, output_dir: str, deps: list | None = None) -> None:
        """Generate BUILD.bazel and binaries into a generated/ subdirectory.

        Args:
            name: Base name for the generated graph (without .generate suffix)
            output_dir: Directory to write generated files to (will create generated/ subdir)
            deps: List of Bazel dependency labels to use in the generated BUILD.bazel
        """
        import os

        # Create the generated/ subdirectory
        generated_dir = os.path.join(output_dir, "generated")
        os.makedirs(generated_dir, exist_ok=True)

        # Generate BUILD.bazel with job and graph targets
        self._generate_build_bazel(generated_dir, name, deps or [])

        # Generate individual job scripts (instead of a shared wrapper)
        self._generate_job_scripts(generated_dir)

        # Generate the job lookup binary
        self._generate_job_lookup(generated_dir, name)

        package_name = self._get_package_name()
        print(f"Generated DataBuild package '{name}' in {generated_dir}")
        if package_name != "UNKNOWN_PACKAGE":
            print(f"Run 'bazel build \"@databuild//{package_name}/generated:{name}_graph.analyze\"' to use the generated graph")
        else:
            print(f"Run 'bazel build generated:{name}_graph.analyze' to use the generated graph")

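    # Example (illustrative only): generating the Bazel package for a graph. The
    # name, directory, and dep label below are placeholders.
    #
    #     graph.generate_bazel_package(
    #         name="pipeline",
    #         output_dir="databuild/example/pipeline",
    #         deps=["//databuild/example/pipeline:dsl_src"],
    #     )
    #
    # This writes generated/BUILD.bazel, one <job_name>.py per registered job, and
    # generated/pipeline_job_lookup.py.
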
    def _generate_build_bazel(self, output_dir: str, name: str, deps: list) -> None:
        """Generate BUILD.bazel with databuild_job and databuild_graph targets."""
        import os

        # Get job classes from the lookup table
        job_classes = sorted(set(self.lookup.values()), key=lambda cls: cls.__name__)

        # Format deps for BUILD.bazel
        if deps:
            deps_str = ", ".join([f'"{dep}"' for dep in deps])
        else:
            # Fallback to parent package if no deps provided
            parent_package = self._get_package_name()
            deps_str = f'"//{parent_package}:dsl_src"'

        # Generate py_binary targets for each job
        job_binaries = []
        job_targets = []

        for job_class in job_classes:
            job_name = self._snake_case(job_class.__name__)
            binary_name = f"{job_name}_binary"
            job_targets.append(f'"{job_name}"')

            job_script_name = f"{job_name}.py"
            job_binaries.append(f'''py_binary(
    name = "{binary_name}",
    srcs = ["{job_script_name}"],
    main = "{job_script_name}",
    deps = [{deps_str}],
)

databuild_job(
    name = "{job_name}",
    binary = ":{binary_name}",
)''')

        # Generate the complete BUILD.bazel content
        build_content = f'''load("@databuild//databuild:rules.bzl", "databuild_job", "databuild_graph")

# Generated by DataBuild DSL - do not edit manually
# This file is generated in a subdirectory to avoid overwriting the original BUILD.bazel

{chr(10).join(job_binaries)}

py_binary(
    name = "{name}_job_lookup",
    srcs = ["{name}_job_lookup.py"],
    deps = [{deps_str}],
)

databuild_graph(
    name = "{name}_graph",
    jobs = [{", ".join(job_targets)}],
    lookup = ":{name}_job_lookup",
    visibility = ["//visibility:public"],
)

# Create tar archive of generated files for testing
genrule(
    name = "existing_generated",
    srcs = glob(["*.py", "BUILD.bazel"]),
    outs = ["existing_generated.tar"],
    cmd = "mkdir -p temp && cp $(SRCS) temp/ && find temp -exec touch -t 197001010000 {{}} + && tar -cf $@ -C temp .",
    visibility = ["//visibility:public"],
)
'''

        with open(os.path.join(output_dir, "BUILD.bazel"), "w") as f:
            f.write(build_content)

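    # For a single hypothetical job class `IngestClicksJob` and name "pipeline",
    # the generated BUILD.bazel would contain targets along these lines
    # (illustrative only):
    #
    #     py_binary(name = "ingest_clicks_job_binary", srcs = ["ingest_clicks_job.py"], ...)
    #     databuild_job(name = "ingest_clicks_job", binary = ":ingest_clicks_job_binary")
    #     py_binary(name = "pipeline_job_lookup", srcs = ["pipeline_job_lookup.py"], ...)
    #     databuild_graph(name = "pipeline_graph", jobs = ["ingest_clicks_job"], lookup = ":pipeline_job_lookup")
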
    def _generate_job_scripts(self, output_dir: str) -> None:
        """Generate individual Python scripts for each job class."""
        import os

        # Get job classes and generate a script for each one
        job_classes = list(set(self.lookup.values()))
        graph_module_path = self._get_graph_module_path()

        for job_class in job_classes:
            job_name = self._snake_case(job_class.__name__)
            script_name = f"{job_name}.py"

            script_content = f'''#!/usr/bin/env python3
"""
Generated job script for {job_class.__name__}.
"""

import sys
import json
from {graph_module_path} import {job_class.__name__}
from databuild.proto import PartitionRef, JobConfigureResponse, to_dict


def parse_outputs_from_args(args: list[str]) -> list:
    """Parse partition output references from command line arguments."""
    outputs = []
    for arg in args:
        # Find which output type can deserialize this partition reference
        for output_type in {job_class.__name__}.output_types:
            try:
                partition = output_type.deserialize(arg)
                outputs.append(partition)
                break
            except ValueError:
                continue
        else:
            raise ValueError(f"No output type in {job_class.__name__} can deserialize partition ref: {{arg}}")

    return outputs


if __name__ == "__main__":
    if len(sys.argv) < 2:
        raise Exception("Invalid command usage")

    command = sys.argv[1]
    job_instance = {job_class.__name__}()

    if command == "config":
        # Parse output partition references as PartitionRef objects (for the Rust wrapper)
        output_refs = [PartitionRef(str=raw_ref) for raw_ref in sys.argv[2:]]

        # Also parse them into DSL partition objects (for the DSL job.config())
        outputs = parse_outputs_from_args(sys.argv[2:])

        # Call the job's config method - returns list[JobConfig]
        configs = job_instance.config(outputs)

        # Wrap in a JobConfigureResponse and serialize using to_dict()
        response = JobConfigureResponse(configs=configs)
        print(json.dumps(to_dict(response)))

    elif command == "exec":
        # The exec method expects a JobConfig but the Rust wrapper passes args.
        # For now, let the DSL job handle the args directly.
        # TODO: This needs to be refined based on the actual Rust wrapper interface
        job_instance.exec(*sys.argv[2:])

    else:
        raise Exception(f"Invalid command `{{sys.argv[1]}}`")
'''

            script_path = os.path.join(output_dir, script_name)
            with open(script_path, "w") as f:
                f.write(script_content)

            # Make it executable
            os.chmod(script_path, 0o755)

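    # The generated script is invoked by the wrapper with a subcommand, e.g.
    # (illustrative only; the script name and partition ref are placeholders):
    #
    #     ./ingest_clicks_job.py config "events/clicks/2024-01-01"
    #         -> prints a JobConfigureResponse as JSON on stdout
    #     ./ingest_clicks_job.py exec <args...>
    #         -> runs the job's exec() with the given args
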
    def _generate_job_lookup(self, output_dir: str, name: str) -> None:
        """Generate job lookup binary that maps partition patterns to job targets."""
        import os

        # Build the job lookup mappings with full package paths
        package_name = self._get_package_name()
        lookup_mappings = []
        for partition_type, job_class in self.lookup.items():
            job_name = self._snake_case(job_class.__name__)
            pattern = partition_type._raw_pattern
            full_target = f"//{package_name}/generated:{job_name}"
            lookup_mappings.append(f'    r"{pattern}": "{full_target}",')

        lookup_content = f'''#!/usr/bin/env python3
"""
Generated job lookup for DataBuild DSL graph.
Maps partition patterns to job targets.
"""

import sys
import re
import json
from collections import defaultdict


# Mapping from partition patterns to job targets
JOB_MAPPINGS = {{
{chr(10).join(lookup_mappings)}
}}


def lookup_job_for_partition(partition_ref: str) -> str:
    """Look up which job can build the given partition reference."""
    for pattern, job_target in JOB_MAPPINGS.items():
        if re.match(pattern, partition_ref):
            return job_target

    raise ValueError(f"No job found for partition: {{partition_ref}}")


def main():
    if len(sys.argv) < 2:
        print("Usage: job_lookup.py <partition_ref> [partition_ref...]", file=sys.stderr)
        sys.exit(1)

    results = defaultdict(list)
    try:
        for partition_ref in sys.argv[1:]:
            job_target = lookup_job_for_partition(partition_ref)
            results[job_target].append(partition_ref)

        # Output the results as JSON (matching existing lookup format)
        print(json.dumps(dict(results)))
    except ValueError as e:
        print(f"ERROR: {{e}}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
'''

        lookup_file = os.path.join(output_dir, f"{name}_job_lookup.py")
        with open(lookup_file, "w") as f:
            f.write(lookup_content)

        # Make it executable
        os.chmod(lookup_file, 0o755)

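    # The generated lookup maps partition refs to the jobs that build them, e.g.
    # (illustrative only; the target and ref are placeholders):
    #
    #     ./pipeline_job_lookup.py "events/clicks/2024-01-01"
    #     {"//example/pipeline/generated:ingest_clicks_job": ["events/clicks/2024-01-01"]}
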
    def _snake_case(self, name: str) -> str:
        """Convert CamelCase to snake_case, e.g. 'IngestClicksJob' -> 'ingest_clicks_job'."""
        import re
        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

    def _get_graph_module_path(self) -> str:
        """Get the module path of the module that defines this graph instance."""
        # Try to find the module by looking at where the graph object is defined
        import inspect
        import sys

        # Look through all loaded modules to find where this graph instance is defined
        for module_name, module in sys.modules.items():
            if hasattr(module, 'graph') and getattr(module, 'graph') is self:
                if module_name != '__main__':
                    return module_name

        # Look through the call stack to find the module that imported us
        for frame_info in inspect.stack():
            frame_globals = frame_info.frame.f_globals
            module_name = frame_globals.get('__name__')
            if module_name and module_name != '__main__' and 'graph' in frame_globals:
                # Check if this frame has our graph
                if frame_globals.get('graph') is self:
                    return module_name

        # Last resort fallback - this will need to be manually configured
        return "UNKNOWN_MODULE"

    def _get_package_name(self) -> str:
        """Get the Bazel package name where the DSL source files are located."""
        # Extract the package from the graph label if available
        if hasattr(self, 'label') and self.label.startswith('//'):
            # Extract package from a label like "//databuild/test/app:dsl_graph"
            package_part = self.label.split(':')[0]
            return package_part[2:]  # Remove the "//" prefix

        # Fall back to inferring the package from the module path
        module_path = self._get_graph_module_path()
        if module_path != "UNKNOWN_MODULE":
            # Convert the module path to a package path,
            # e.g. "databuild.test.app.dsl.graph" -> "databuild/test/app/dsl"
            parts = module_path.split('.')
            if parts[-1] in ['graph', 'main']:
                parts = parts[:-1]
            return '/'.join(parts)

        return "UNKNOWN_PACKAGE"


@dataclass
class JobConfigBuilder:
    outputs: list[PartitionRef] = field(default_factory=list)
    inputs: list[DataDep] = field(default_factory=list)
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)

    def build(self) -> JobConfig:
        return JobConfig(
            outputs=self.outputs,
            inputs=self.inputs,
            args=self.args,
            env=self.env,
        )

    def add_inputs(self, *partitions: PartitionPattern, dep_type: DepType = DepType.MATERIALIZE) -> Self:
        for p in partitions:
            dep_type_name = "materialize" if dep_type == DepType.MATERIALIZE else "query"
            self.inputs.append(DataDep(dep_type_code=dep_type, dep_type_name=dep_type_name, partition_ref=PartitionRef(str=p.serialize())))
        return self

    def add_outputs(self, *partitions: PartitionPattern) -> Self:
        for p in partitions:
            self.outputs.append(PartitionRef(str=p.serialize()))
        return self

    def add_args(self, *args: str) -> Self:
        self.args.extend(args)
        return self

    def set_args(self, args: list[str]) -> Self:
        self.args = args
        return self

    def set_env(self, env: dict[str, str]) -> Self:
        self.env = env
        return self

    def add_env(self, **kwargs) -> Self:
        for k, v in kwargs.items():
            assert isinstance(k, str), f"Expected a string key, got `{k}`"
            assert isinstance(v, str), f"Expected a string value, got `{v}`"
            self.env[k] = v
        return self
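
# Example (illustrative only): assembling a JobConfig with the builder. The
# partition instances are hypothetical (see the DatePartition sketch above).
#
#     config = (
#         JobConfigBuilder()
#         .add_inputs(DatePartition(dataset="raw_clicks", date="2024-01-01"))
#         .add_outputs(DatePartition(dataset="clicks", date="2024-01-01"))
#         .add_args("--date", "2024-01-01")
#         .add_env(ENV="prod")
#         .build()
#     )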