130 lines
4.8 KiB
Python
130 lines
4.8 KiB
Python
|
|
from databuild.proto import JobConfig, PartitionRef, DataDep, DepType
|
|
from typing import Self, Protocol, get_type_hints, get_origin, get_args
|
|
from dataclasses import fields, is_dataclass, dataclass, field
|
|
import re
|
|
|
|
|
|
class PartitionPattern:
|
|
_raw_pattern: str
|
|
|
|
@property
|
|
def _pattern(self) -> re.Pattern:
|
|
return re.compile(self._raw_pattern)
|
|
|
|
def _validate_pattern(self):
|
|
"""Checks that both conditions are met:
|
|
1. All fields from the PartitionFields type are present in the pattern
|
|
2. All fields from the pattern are present in the PartitionFields type
|
|
"""
|
|
# TODO how do I get this to be called?
|
|
assert is_dataclass(self), "Should be a dataclass also (for partition fields)"
|
|
pattern_fields = set(self._pattern.groupindex.keys())
|
|
partition_fields = {field.name for field in fields(self)}
|
|
if pattern_fields != partition_fields:
|
|
raise ValueError(f"Pattern fields {pattern_fields} do not match partition fields {partition_fields}")
|
|
|
|
@classmethod
|
|
def deserialize(cls, raw_value: str) -> Self:
|
|
"""Parses a partition from a string based on the defined pattern."""
|
|
# Create a temporary instance to access the compiled pattern
|
|
# We need to compile the pattern to match against it
|
|
pattern = re.compile(cls._raw_pattern)
|
|
|
|
# Match the raw value against the pattern
|
|
match = pattern.match(raw_value)
|
|
if not match:
|
|
raise ValueError(f"String '{raw_value}' does not match pattern '{cls._pattern}'")
|
|
|
|
# Extract the field values from the match
|
|
field_values = match.groupdict()
|
|
|
|
# Create and return a new instance with the extracted values
|
|
return cls(**field_values)
|
|
|
|
def serialize(self) -> str:
|
|
"""Returns a string representation by filling in the pattern template with field values."""
|
|
# Start with the pattern
|
|
result = self._raw_pattern
|
|
|
|
# Replace each named group in the pattern with its corresponding field value
|
|
for field in fields(self):
|
|
# Find the named group pattern and replace it with the actual value
|
|
# We need to replace the regex pattern with the actual value
|
|
# Look for the pattern (?P<field_name>...) and replace with the field value
|
|
pattern_to_replace = rf'\(\?P<{field.name}>[^)]+\)'
|
|
actual_value = getattr(self, field.name)
|
|
result = re.sub(pattern_to_replace, actual_value, result)
|
|
|
|
return result
|
|
|
|
|
|
class DataBuildJob(Protocol):
    """Structural interface that databuild job classes are expected to satisfy."""

    # PartitionPattern subclasses that this job is able to produce.
    output_types: list[type[PartitionPattern]]

    def config(self, outputs: list[PartitionPattern]) -> list[JobConfig]:
        """Map the requested output partitions to a list of ``JobConfig``s."""

    def exec(self, config: JobConfig) -> None:
        """Execute the job described by ``config``."""
|
|
|
|
|
|
class DataBuildGraph:
    """A labelled registry mapping each partition type to the job class producing it."""

    def __init__(self, label: str):
        self.label = label
        # partition type -> job class registered as its producer
        self.lookup: dict[type[PartitionPattern], type[DataBuildJob]] = {}

    def job(self, cls: type[DataBuildJob]) -> type[DataBuildJob]:
        """Register ``cls`` as the producer of each of its ``output_types``.

        Returns ``cls`` unchanged, so this method can also be applied as a
        class decorator (``@graph.job``) without replacing the class.

        Raises:
            AssertionError: if an output type is already registered or is
                listed twice by ``cls``. In that case nothing from ``cls`` is
                registered.
        """
        # Validate every output type before touching ``lookup`` so a collision
        # never leaves the graph half-updated.
        pending = set()
        for partition in cls.output_types:
            assert partition not in self.lookup and partition not in pending, f"Partition `{partition}` already registered"
            pending.add(partition)
        for partition in cls.output_types:
            self.lookup[partition] = cls
        return cls

    def generate_bazel_module(self):
        """Generates a complete databuild application, packaging up referenced jobs and this graph via bazel targets"""
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class JobConfigBuilder:
    """Fluent builder for a ``JobConfig``.

    Every ``add_*``/``set_*`` method mutates the builder and returns it so
    calls can be chained; ``build`` produces the ``JobConfig``.
    """

    outputs: list[PartitionRef] = field(default_factory=list)
    inputs: list[DataDep] = field(default_factory=list)
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)

    def build(self) -> JobConfig:
        """Create a ``JobConfig`` from the current builder state.

        The containers are shallow-copied so the built config never shares
        list/dict objects with this builder; later builder calls therefore
        cannot alter a config that was already built.
        """
        return JobConfig(
            outputs=list(self.outputs),
            inputs=list(self.inputs),
            args=list(self.args),
            env=dict(self.env),
        )

    def add_inputs(self, *partitions: PartitionPattern, dep_type: DepType = DepType.MATERIALIZE) -> Self:
        """Append each of ``partitions`` as an input ``DataDep``.

        Args:
            partitions: partitions whose ``serialize()`` output becomes the
                dependency's ``PartitionRef``.
            dep_type: dependency kind recorded on every added ``DataDep``. Its
                name is ``"materialize"`` for ``DepType.MATERIALIZE`` and
                ``"query"`` for any other value.
        """
        # Must compare against the same member as the default above
        # (``MATERIALIZE``); the name is identical for every partition, so
        # compute it once.
        dep_type_name = "materialize" if dep_type == DepType.MATERIALIZE else "query"
        for p in partitions:
            self.inputs.append(DataDep(dep_type_code=dep_type, dep_type_name=dep_type_name, partition_ref=PartitionRef(str=p.serialize())))
        return self

    def add_outputs(self, *partitions: PartitionPattern) -> Self:
        """Append each of ``partitions`` (serialized) as an output ``PartitionRef``."""
        for p in partitions:
            self.outputs.append(PartitionRef(str=p.serialize()))
        return self

    def add_args(self, *args: str) -> Self:
        """Append ``args`` to the argument list."""
        self.args.extend(args)
        return self

    def set_args(self, args: list[str]) -> Self:
        """Replace the argument list with a copy of ``args``.

        Copying keeps later ``add_args`` calls from mutating the caller's list.
        """
        self.args = list(args)
        return self

    def set_env(self, env: dict[str, str]) -> Self:
        """Replace the environment with a copy of ``env``.

        Copying keeps later ``add_env`` calls from mutating the caller's dict.
        """
        self.env = dict(env)
        return self

    def add_env(self, **kwargs: str) -> Self:
        """Merge the keyword arguments into the environment.

        Raises:
            AssertionError: if any value is not a ``str``. The environment is
                left unchanged in that case.
        """
        # Keys need no check: Python only accepts string keywords in ``**kwargs``.
        for k, v in kwargs.items():
            assert isinstance(v, str), f"Expected a string value for `{k}`, got `{v}`"
        self.env.update(kwargs)
        return self