From 59420f58d0b20e804c3995df5f92bc19900f0e67 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Thu, 8 Jan 2026 21:59:19 +0100 Subject: [PATCH 01/19] wip: setup base pipeline with blocks --- .../pipeline-editor/BlockConfigPanel.tsx | 64 +++++ .../components/pipeline-editor/BlockNode.tsx | 52 +++- lib/blocks/builtin/duplicate_remover.py | 159 ++++++++++ lib/blocks/builtin/semantic_infiller.py | 221 ++++++++++++++ lib/blocks/builtin/structure_sampler.py | 272 ++++++++++++++++++ lib/templates/data_augmentation.yaml | 23 ++ .../seeds/seed_data_augmentation.json | 23 ++ pyproject.toml | 1 + tests/integration/test_data_augmentation.py | 261 +++++++++++++++++ 9 files changed, 1069 insertions(+), 7 deletions(-) create mode 100644 lib/blocks/builtin/duplicate_remover.py create mode 100644 lib/blocks/builtin/semantic_infiller.py create mode 100644 lib/blocks/builtin/structure_sampler.py create mode 100644 lib/templates/data_augmentation.yaml create mode 100644 lib/templates/seeds/seed_data_augmentation.json create mode 100644 tests/integration/test_data_augmentation.py diff --git a/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx b/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx index c8cd3dd..0e8b4b8 100644 --- a/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx +++ b/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx @@ -35,6 +35,8 @@ export default function BlockConfigPanel({ const [errors, setErrors] = useState>({}); const [panelWidth, setPanelWidth] = useState(400); const [isResizing, setIsResizing] = useState(false); + const [llmModels, setLlmModels] = useState([]); + const [embeddingModels, setEmbeddingModels] = useState([]); // sync formData with parent config changes // this ensures that saved config persists when panel is reopened @@ -75,6 +77,32 @@ export default function BlockConfigPanel({ }; }, []); + // fetch available LLM and embedding models + useEffect(() => { + const fetchModels = async () => { + try { + const 
[llmResponse, embeddingResponse] = await Promise.all([ + fetch("/api/llm-models"), + fetch("/api/embedding-models"), + ]); + + if (llmResponse.ok) { + const llmData = await llmResponse.json(); + setLlmModels(llmData.map((m: any) => m.name)); + } + + if (embeddingResponse.ok) { + const embeddingData = await embeddingResponse.json(); + setEmbeddingModels(embeddingData.map((m: any) => m.name)); + } + } catch (error) { + console.error("Failed to fetch models:", error); + } + }; + + fetchModels(); + }, []); + // handle resize useEffect(() => { if (!isResizing) return; @@ -206,6 +234,42 @@ export default function BlockConfigPanel({ ); } + // llm model dropdown + if (key === "model" && llmModels.length > 0) { + return ( + + ); + } + + // embedding model dropdown + if (key === "embedding_model" && embeddingModels.length > 0) { + return ( + + ); + } + // field reference dropdown (references to accumulated_state fields) if (schema.isFieldReference) { if (availableFields.length > 0) { diff --git a/frontend/src/components/pipeline-editor/BlockNode.tsx b/frontend/src/components/pipeline-editor/BlockNode.tsx index 92e7922..da30bb9 100644 --- a/frontend/src/components/pipeline-editor/BlockNode.tsx +++ b/frontend/src/components/pipeline-editor/BlockNode.tsx @@ -61,12 +61,43 @@ function getPreviewFields(blockType: string, config: Record): Array // priority fields based on block type let priorityKeys: string[] = []; - if (type.includes("generator")) { + console.log(type); + + // data augmentation blocks + if (type.includes("sampler")) { + priorityKeys = ["target_count", "categorical_fields"]; + } else if (type.includes("infiller")) { + priorityKeys = ["fields_to_generate", "model", "temperature"]; + } else if (type.includes("remover")) { + priorityKeys = ["similarity_threshold", "comparison_fields", "embedding_model"]; + } + // multiplier blocks + else if (type.includes("multiplier")) { + priorityKeys = ["parser_type", "chunk_size"]; + } + // langfuse integration + else if 
(type.includes("langfuse")) { + priorityKeys = ["dataset_name"]; + } + // field mapper + else if (type.includes("mapper")) { + priorityKeys = ["mappings"]; + } + // ragas metrics + else if (type.includes("ragas")) { + priorityKeys = ["metrics", "model", "score_threshold"]; + } + // generators (text/structured) + else if (type.includes("generator")) { priorityKeys = ["model", "temperature", "max_tokens"]; - } else if (type.includes("validator")) { - priorityKeys = ["min_length", "max_length", "required_fields"]; - } else if (type.includes("score")) { - priorityKeys = ["generated_field", "reference_field", "metric"]; + } + // validators + else if (type.includes("validator")) { + priorityKeys = ["min_length", "max_length", "required_fields", "field_name"]; + } + // score blocks + else if (type.includes("score")) { + priorityKeys = ["generated_field", "reference_field", "field_name", "metric"]; } // find up to 2 configured values from priority keys @@ -76,9 +107,16 @@ function getPreviewFields(blockType: string, config: Record): Array if (config[key] !== undefined && config[key] !== null && config[key] !== "") { let displayValue = String(config[key]); + // special formatting for arrays/objects + if (Array.isArray(config[key])) { + displayValue = `[${config[key].length} items]`; + } else if (typeof config[key] === "object") { + displayValue = `{${Object.keys(config[key]).length} keys}`; + } + // truncate long values - if (displayValue.length > 20) { - displayValue = displayValue.slice(0, 20) + "..."; + if (displayValue.length > 25) { + displayValue = displayValue.slice(0, 25) + "..."; } preview.push([key, displayValue]); diff --git a/lib/blocks/builtin/duplicate_remover.py b/lib/blocks/builtin/duplicate_remover.py new file mode 100644 index 0000000..abf2481 --- /dev/null +++ b/lib/blocks/builtin/duplicate_remover.py @@ -0,0 +1,159 @@ +import logging +from typing import Any + +import litellm +from sklearn.metrics.pairwise import cosine_similarity + +from lib.blocks.base 
import BaseBlock +from lib.entities.block_execution_context import BlockExecutionContext + +logger = logging.getLogger(__name__) + + +class DuplicateRemover(BaseBlock): + name = "Duplicate Remover" + description = "Flag records similar to reference dataset using embedding-based similarity" + category = "validators" + inputs = ["*"] + outputs = ["*", "is_duplicate", "similarity_score"] + + _config_descriptions = { + "similarity_threshold": "Similarity threshold (0.0-1.0). Above = duplicate.", + "comparison_fields": "Fields to compare (leave empty to compare all text fields)", + "embedding_model": "Embedding model to use (leave empty for default). Skips check if no model configured.", + } + + def __init__( + self, + similarity_threshold: float = 0.85, + comparison_fields: list[str] | None = None, + embedding_model: str | None = None, + ): + self.similarity_threshold = similarity_threshold + self.comparison_fields = comparison_fields + self.embedding_model_name = embedding_model + + # cache reference embeddings (shared across records in same job) + self._reference_embeddings: list[list[float]] = [] + self._embeddings_initialized = False + + def _extract_text(self, record: dict[str, Any], fields: list[str] | None) -> str: + """ + extract text from specified fields or all string fields + joins with spaces for embedding + """ + if fields: + texts = [] + for field in fields: + value = record.get(field, "") + if value is not None: + texts.append(str(value)) + else: + # auto-detect string fields + texts = [] + for value in record.values(): + if isinstance(value, str) and value: + texts.append(value) + + return " ".join(texts) + + async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager + + # get current record from context + current_record = context.accumulated_state.copy() + current_record.pop("_usage", None) # remove internal fields + current_record.pop("_hints", None) + + # get reference samples from initial state + 
samples = context.get_state("samples", []) + + if not samples: + logger.warning("No samples found for duplicate checking, marking as not duplicate") + return { + **current_record, + "is_duplicate": False, + "similarity_score": 0.0, + } + + # extract text for comparison + current_text = self._extract_text(current_record, self.comparison_fields) + + if not current_text: + logger.warning("No text found in record for comparison, skipping check") + return { + **current_record, + "is_duplicate": False, + "similarity_score": 0.0, + } + + try: + # get embedding model + embedding_config = await llm_config_manager.get_embedding_model( + self.embedding_model_name + ) + + # build reference embeddings (lazy, once per pipeline run) + if not self._embeddings_initialized: + logger.info(f"Building reference embeddings for {len(samples)} samples") + + sample_texts = [ + self._extract_text(s, self.comparison_fields) for s in samples + ] + + # filter empty texts + sample_texts = [t for t in sample_texts if t] + + if not sample_texts: + logger.warning("No valid sample texts for embedding, skipping check") + return { + **current_record, + "is_duplicate": False, + "similarity_score": 0.0, + } + + # embed all sample texts + embedding_params = llm_config_manager._prepare_embedding_call( + embedding_config, input_text=sample_texts + ) + response = await litellm.aembedding(**embedding_params) + + self._reference_embeddings = [item["embedding"] for item in response.data] + self._embeddings_initialized = True + + logger.info(f"Initialized {len(self._reference_embeddings)} reference embeddings") + + # embed current text + embedding_params = llm_config_manager._prepare_embedding_call( + embedding_config, input_text=current_text + ) + response = await litellm.aembedding(**embedding_params) + current_embedding = response.data[0]["embedding"] + + # compute cosine similarities + similarities = cosine_similarity( + [current_embedding], self._reference_embeddings + )[0] + + max_similarity = 
float(max(similarities)) if len(similarities) > 0 else 0.0 + is_duplicate = max_similarity >= self.similarity_threshold + + if is_duplicate: + logger.warning( + f"Duplicate detected: similarity={max_similarity:.4f} >= {self.similarity_threshold}" + ) + + except Exception as e: + # no embedding model configured or error - skip check + logger.warning( + f"Embedding check failed or no model configured: {e}. " + f"Skipping similarity check." + ) + is_duplicate = False + max_similarity = 0.0 + + return { + **current_record, + "is_duplicate": is_duplicate, + "similarity_score": round(max_similarity, 4), + } diff --git a/lib/blocks/builtin/semantic_infiller.py b/lib/blocks/builtin/semantic_infiller.py new file mode 100644 index 0000000..a6a4b2c --- /dev/null +++ b/lib/blocks/builtin/semantic_infiller.py @@ -0,0 +1,221 @@ +import json +import logging +import re +from typing import Any + +import litellm + +from lib.blocks.base import BaseBlock +from lib.entities import pipeline +from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError + +logger = logging.getLogger(__name__) + + +class SemanticInfiller(BaseBlock): + name = "Semantic Infiller" + description = "Complete skeleton records using LLM to generate free-text fields" + category = "generators" + inputs = ["*"] # accepts any skeleton fields + outputs = ["*"] # returns merged skeleton + generated fields + + _config_descriptions = { + "fields_to_generate": "List of field names for LLM to generate (e.g., ['bio', 'description'])", + "model": "Select LLM model to use (leave empty for default)", + "temperature": "Sampling temperature (0.0 = deterministic, 1.0 = creative)", + "max_tokens": "Maximum tokens for generated response", + "system_prompt": "Custom system prompt (optional, overrides default)", + } + + def __init__( + self, + fields_to_generate: list[str], + model: str | None = None, + temperature: float = 0.8, + max_tokens: int = 500, + system_prompt: str = "", 
+ ): + self.fields_to_generate = fields_to_generate + self.model_name = model + self.temperature = temperature + self.max_tokens = max_tokens + self.system_prompt = system_prompt + + def _build_generation_prompt( + self, skeleton: dict[str, Any], hints: dict[str, Any] + ) -> str: + """ + construct LLM prompt with constraints and hints + + format: + - specify fields to generate + - lock categorical constraints from skeleton + - provide numeric hints and exemplars + """ + fields_str = ", ".join(f'"{field}"' for field in self.fields_to_generate) + + # extract constraints (non-hint fields) + constraints = [] + for key, value in skeleton.items(): + constraints.append(f' - {key}: "{value}" (FIXED)') + + constraints_str = "\n".join(constraints) if constraints else " (none)" + + # extract hints + hint_lines = [] + for key, value in hints.items(): + if key.endswith("_range") and isinstance(value, list) and len(value) == 2: + field_name = key.replace("_range", "") + hint_lines.append(f" - {field_name} should be between {value[0]}-{value[1]}") + elif key == "exemplars" and isinstance(value, list): + hint_lines.append(" - Example records for reference:") + for ex in value[:2]: # show max 2 exemplars + # only show generated fields from exemplar + ex_fields = { + f: ex.get(f, "") + for f in self.fields_to_generate + if f in ex + } + hint_lines.append(f" {json.dumps(ex_fields)}") + + hints_str = "\n".join(hint_lines) if hint_lines else " (none)" + + prompt = f"""You are a synthetic data generator. 
+ +Generate a JSON object with the following fields: {fields_str} + +CONSTRAINTS (must follow exactly): +{constraints_str} + +HINTS (use as guidance): +{hints_str} + +Return ONLY valid JSON with the requested fields, no markdown formatting or explanations.""" + + return prompt + + def _parse_json_safely(self, content: str) -> dict[str, Any]: + """ + parse JSON from LLM response + handles markdown code blocks and other common patterns + """ + # first try direct parsing + try: + return json.loads(content) + except json.JSONDecodeError: + pass + + # try extracting from markdown code block + json_match = re.search(r"```(?:json)?\s*\n(.*?)\n```", content, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + + # try extracting anything that looks like JSON + json_match = re.search(r"\{.*\}", content, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(0)) + except json.JSONDecodeError: + pass + + raise BlockExecutionError( + "LLM returned invalid JSON", + detail={ + "content": content[:500], # first 500 chars + "hint": "LLM should return pure JSON without markdown or explanations", + }, + ) + + async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager + + # extract skeleton from context + skeleton = context.accumulated_state.copy() + hints = skeleton.pop("_hints", {}) + skeleton.pop("_usage", None) # remove internal fields + + # build generation prompt + prompt = self._build_generation_prompt(skeleton, hints) + + # prepare system prompt + system_content = ( + self.system_prompt + if self.system_prompt + else "You are a synthetic data generator that produces realistic, diverse records." 
+ ) + + messages = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": prompt}, + ] + + # get LLM config + llm_config = await llm_config_manager.get_llm_model(self.model_name) + llm_params = llm_config_manager.prepare_llm_call( + llm_config, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + # add trace metadata + llm_params["metadata"] = { + "trace_id": context.trace_id, + "tags": ["datagenflow", "semantic-infiller"], + } + + logger.info( + f"Generating fields {self.fields_to_generate} with model={llm_params.get('model')}" + ) + + try: + response = await litellm.acompletion(**llm_params) + except Exception as e: + raise BlockExecutionError( + f"LLM call failed: {str(e)}", + detail={ + "skeleton": skeleton, + "prompt_preview": prompt[:200], + "error": str(e), + }, + ) + + # parse response + content = response.choices[0].message.content + try: + generated = self._parse_json_safely(content) + except BlockExecutionError as e: + logger.error(f"Failed to parse JSON: {e.message}") + raise + + # validate that LLM didn't modify skeleton fields + for field, value in skeleton.items(): + if field in generated and generated[field] != value: + logger.warning( + f"LLM modified locked field '{field}': expected {value}, got {generated[field]}. " + f"Restoring original value." 
+ ) + generated[field] = value + + # merge skeleton + generated + result = {**skeleton, **generated} + + # extract usage + usage_info = pipeline.Usage( + input_tokens=response.usage.prompt_tokens or 0, + output_tokens=response.usage.completion_tokens or 0, + cached_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0, + ) + + result["_usage"] = usage_info.model_dump() + + logger.info( + f"Generated {len(generated)} fields " + f"(tokens: {usage_info.input_tokens}+{usage_info.output_tokens})" + ) + + return result diff --git a/lib/blocks/builtin/structure_sampler.py b/lib/blocks/builtin/structure_sampler.py new file mode 100644 index 0000000..eabb6ef --- /dev/null +++ b/lib/blocks/builtin/structure_sampler.py @@ -0,0 +1,272 @@ +import logging +import random +from collections import Counter, defaultdict +from typing import Any + +from lib.blocks.base import BaseMultiplierBlock +from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import ValidationError + +logger = logging.getLogger(__name__) + + +class StructureSampler(BaseMultiplierBlock): + name = "Structure Sampler" + description = "Learn distributions from samples and generate skeleton records" + category = "seeders" + inputs = [] # reads from initial state + outputs = ["*"] # dynamic based on categorical fields + + _config_descriptions = { + "target_count": "Number of skeleton records to generate", + "categorical_fields": "List of categorical field names to sample (e.g., ['plan', 'role'])", + "numeric_fields": "List of numeric field names for hint generation (e.g., ['storage'])", + "dependencies": "Field dependencies as {child: [parent1]} (e.g., {'role': ['plan']})", + "seed": "Random seed for reproducibility (optional)", + } + + def __init__( + self, + target_count: int, + categorical_fields: list[str], + numeric_fields: list[str] = [], + dependencies: dict[str, list[str]] = {}, + seed: int | None = None, + ): + self.target_count = target_count + 
self.categorical_fields = categorical_fields + self.numeric_fields = numeric_fields + self.dependencies = dependencies + self.seed = seed + + if seed is not None: + random.seed(seed) + + def _validate_samples(self, samples: list[dict[str, Any]]) -> None: + """validate samples meet minimum requirements""" + if not samples: + raise ValidationError( + "No samples provided in metadata", + detail={ + "required_field": "samples", + "hint": "Add 'samples' array to seed metadata", + }, + ) + + if len(samples) < 10: + logger.warning( + f"Only {len(samples)} samples provided - statistical accuracy may be low. " + f"Recommend at least 20 samples for better distribution modeling." + ) + + def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: + """ + extract statistical patterns from samples + + returns: + { + "categorical_probs": {"field": {"value": prob, ...}}, + "conditional_probs": {"field|parent=val": {"value": prob, ...}}, + "numeric_stats": {"field": {"min": x, "max": y, "mean": z}}, + "exemplars": [sample1, sample2, ...] 
+ } + """ + profile: dict[str, Any] = { + "categorical_probs": {}, + "conditional_probs": {}, + "numeric_stats": {}, + "exemplars": [], + } + + # categorical field distributions + for field in self.categorical_fields: + values = [sample.get(field) for sample in samples] + counts = Counter(values) + total = sum(counts.values()) + profile["categorical_probs"][field] = { + value: count / total for value, count in counts.items() + } + + # conditional probabilities for dependencies + for child_field, parent_fields in self.dependencies.items(): + if child_field not in self.categorical_fields: + continue + + # group samples by parent values + grouped: dict[tuple, list[Any]] = defaultdict(list) + for sample in samples: + parent_key = tuple(sample.get(p) for p in parent_fields) + child_value = sample.get(child_field) + grouped[parent_key].append(child_value) + + # compute conditional probabilities + for parent_key, child_values in grouped.items(): + counts = Counter(child_values) + total = sum(counts.values()) + probs = {value: count / total for value, count in counts.items()} + + # build key: "child|parent1=val1,parent2=val2" + parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_key)) + key = f"{child_field}|{parent_str}" + profile["conditional_probs"][key] = probs + + # numeric field statistics + for field in self.numeric_fields: + values = [sample.get(field) for sample in samples if sample.get(field) is not None] + if values: + # filter non-numeric + numeric_values = [] + for v in values: + try: + numeric_values.append(float(v)) + except (ValueError, TypeError): + logger.warning( + f"Non-numeric value {v} in numeric field {field}, skipping" + ) + + if numeric_values: + profile["numeric_stats"][field] = { + "min": min(numeric_values), + "max": max(numeric_values), + "mean": sum(numeric_values) / len(numeric_values), + } + + # select random exemplars + num_exemplars = min(5, len(samples)) + profile["exemplars"] = random.sample(samples, num_exemplars) + + 
return profile + + def _topological_sort(self, fields: list[str]) -> list[str]: + """ + sort fields by dependency order (parents before children) + uses simple algorithm for flat dependencies + """ + # build in-degree map + in_degree = {field: 0 for field in fields} + for child_field, parent_fields in self.dependencies.items(): + if child_field in in_degree: + in_degree[child_field] = len(parent_fields) + + # collect fields with no dependencies first + result = [] + remaining = set(fields) + + while remaining: + # find fields with no remaining dependencies + no_deps = [f for f in remaining if in_degree[f] == 0] + + if not no_deps: + raise ValidationError( + "Circular dependency detected in field dependencies", + detail={"dependencies": self.dependencies}, + ) + + # add to result + result.extend(sorted(no_deps)) # sort for determinism + remaining -= set(no_deps) + + # decrease in-degree for children + for field in no_deps: + for child_field, parent_fields in self.dependencies.items(): + if field in parent_fields and child_field in remaining: + in_degree[child_field] -= 1 + + return result + + def _sample_from_distribution(self, probs: dict[str, float]) -> Any: + """weighted random choice from probability distribution""" + if not probs: + return None + + values = list(probs.keys()) + weights = list(probs.values()) + return random.choices(values, weights=weights, k=1)[0] + + def _generate_skeletons( + self, profile: dict[str, Any], count: int + ) -> list[dict[str, Any]]: + """ + generate N skeleton records by sampling from learned distributions + + each skeleton contains: + - all categorical fields (sampled values) + - _hints field (numeric ranges, exemplars for LLM) + """ + results = [] + field_order = self._topological_sort(self.categorical_fields) + + for _ in range(count): + skeleton: dict[str, Any] = {} + + # sample categorical values in dependency order + for field in field_order: + if field in self.dependencies: + # conditional sampling + parent_fields = 
self.dependencies[field] + parent_values = tuple(skeleton.get(p) for p in parent_fields) + parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_values)) + key = f"{field}|{parent_str}" + + if key in profile["conditional_probs"]: + probs = profile["conditional_probs"][key] + else: + # fallback to marginal distribution + logger.warning( + f"Unseen combination {key}, using marginal distribution for {field}" + ) + probs = profile["categorical_probs"].get(field, {}) + + else: + # independent sampling + probs = profile["categorical_probs"].get(field, {}) + + skeleton[field] = self._sample_from_distribution(probs) + + # generate hints for numeric fields + hints: dict[str, Any] = {} + + for field in self.numeric_fields: + if field in profile["numeric_stats"]: + stats = profile["numeric_stats"][field] + hints[f"{field}_range"] = [stats["min"], stats["max"]] + + # add exemplars that match current categorical values + matching_exemplars = [ + ex + for ex in profile["exemplars"] + if all(ex.get(f) == skeleton.get(f) for f in self.categorical_fields) + ] + + if not matching_exemplars: + # use any exemplars + matching_exemplars = profile["exemplars"][:3] + + hints["exemplars"] = matching_exemplars + + skeleton["_hints"] = hints + results.append(skeleton) + + return results + + async def execute(self, context: BlockExecutionContext) -> list[dict[str, Any]]: # type: ignore[override] + # read samples from initial state + samples = context.get_state("samples", []) + + # validate samples + self._validate_samples(samples) + + # analyze samples (internal stats modeling) + logger.info(f"Analyzing {len(samples)} samples for distribution patterns") + profile = self._analyze_samples(samples) + + # generate skeletons + logger.info(f"Generating {self.target_count} skeleton records") + skeletons = self._generate_skeletons(profile, self.target_count) + + logger.info( + f"Successfully generated {len(skeletons)} skeletons with " + f"{len(self.categorical_fields)} categorical 
fields" + ) + + return skeletons diff --git a/lib/templates/data_augmentation.yaml b/lib/templates/data_augmentation.yaml new file mode 100644 index 0000000..e142e27 --- /dev/null +++ b/lib/templates/data_augmentation.yaml @@ -0,0 +1,23 @@ +name: Data Augmentation +description: Generate synthetic records preserving statistical distributions from sample data +blocks: + - type: StructureSampler + config: + target_count: "{{ target_count }}" + categorical_fields: "{{ categorical_fields }}" + numeric_fields: "{{ numeric_fields }}" + dependencies: "{{ dependencies }}" + seed: 42 + + - type: SemanticInfiller + config: + fields_to_generate: "{{ fields_to_generate }}" + temperature: 0.8 + max_tokens: 200 + model: null + + - type: DuplicateRemover + config: + similarity_threshold: 0.85 + comparison_fields: "{{ comparison_fields }}" + embedding_model: null diff --git a/lib/templates/seeds/seed_data_augmentation.json b/lib/templates/seeds/seed_data_augmentation.json new file mode 100644 index 0000000..cb4f6c6 --- /dev/null +++ b/lib/templates/seeds/seed_data_augmentation.json @@ -0,0 +1,23 @@ +[ + { + "repetitions": 1, + "metadata": { + "samples": [ + {"plan": "Free", "role": "Viewer", "storage": 1, "bio": "Student learning web development"}, + {"plan": "Free", "role": "Viewer", "storage": 2, "bio": "Just exploring the platform"}, + {"plan": "Pro", "role": "Editor", "storage": 50, "bio": "Freelance designer managing projects"}, + {"plan": "Pro", "role": "Editor", "storage": 75, "bio": "Small agency owner"}, + {"plan": "Pro", "role": "Admin", "storage": 100, "bio": "Team lead overseeing projects"}, + {"plan": "Enterprise", "role": "Admin", "storage": 500, "bio": "CTO managing infrastructure"} + ], + "target_count": 20, + "categorical_fields": ["plan", "role"], + "numeric_fields": ["storage"], + "fields_to_generate": ["bio", "storage"], + "dependencies": { + "role": ["plan"] + }, + "comparison_fields": ["bio"] + } + } +] diff --git a/pyproject.toml b/pyproject.toml index 
a618afc..700b299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "pyyaml>=6.0.3", "litellm>=1.78.5", "rouge-score>=0.1.2", + "scikit-learn>=1.3.0", "llama-index-core>=0.14.7", "anthropic>=0.73.0", "google-generativeai>=0.8.5", diff --git a/tests/integration/test_data_augmentation.py b/tests/integration/test_data_augmentation.py new file mode 100644 index 0000000..1ba0061 --- /dev/null +++ b/tests/integration/test_data_augmentation.py @@ -0,0 +1,261 @@ +"""integration test for data augmentation pipeline""" +import json +import pytest + +from lib.storage import Storage +from lib.workflow import Pipeline + + +@pytest.mark.asyncio +async def test_data_augmentation_pipeline(tmp_path): + """test complete data augmentation pipeline with all 3 blocks""" + + # setup test database + db_path = tmp_path / "test.db" + storage = Storage(str(db_path)) + await storage.init_db() + + try: + # define pipeline + pipeline_def = { + "blocks": [ + { + "type": "StructureSampler", + "config": { + "target_count": 5, + "categorical_fields": ["plan", "role"], + "numeric_fields": ["storage"], + "dependencies": {"role": ["plan"]}, + "seed": 42, + }, + }, + { + "type": "SemanticInfiller", + "config": { + "fields_to_generate": ["bio", "storage"], + "temperature": 0.8, + "max_tokens": 200, + "model": None, + }, + }, + { + "type": "DuplicateRemover", + "config": { + "similarity_threshold": 0.85, + "comparison_fields": ["bio"], + "embedding_model": None, + }, + }, + ] + } + + # save pipeline to database + pipeline_id = await storage.save_pipeline("test_augmentation", json.dumps(pipeline_def)) + + # create pipeline instance + pipeline = Pipeline("test_augmentation", pipeline_def["blocks"]) + + # prepare seed data + initial_data = { + "samples": [ + { + "plan": "Free", + "role": "Viewer", + "storage": 1, + "bio": "Student learning", + }, + { + "plan": "Free", + "role": "Viewer", + "storage": 2, + "bio": "Just exploring", + }, + { + "plan": "Pro", + "role": "Editor", 
+ "storage": 50, + "bio": "Freelancer", + }, + { + "plan": "Pro", + "role": "Admin", + "storage": 100, + "bio": "Team lead", + }, + ] + } + + # execute pipeline + results = await pipeline.execute(initial_data) + + # verify results + assert isinstance(results, list), "Multiplier pipeline should return list" + assert len(results) == 5, f"Expected 5 results, got {len(results)}" + + # verify each result + for exec_result in results: + result = exec_result.result + trace = exec_result.trace + trace_id = exec_result.trace_id + # check required fields + assert "plan" in result, "Missing plan field" + assert "role" in result, "Missing role field" + assert "storage" in result, "Missing storage field" + assert "bio" in result, "Missing bio field" + + # check duplicate check fields + assert "is_duplicate" in result, "Missing is_duplicate field" + assert "similarity_score" in result, "Missing similarity_score field" + assert isinstance(result["is_duplicate"], bool) + assert isinstance(result["similarity_score"], float) + + # check plan values are valid + assert result["plan"] in ["Free", "Pro"], f"Invalid plan: {result['plan']}" + + # check role values are valid + assert result["role"] in ["Viewer", "Editor", "Admin"], f"Invalid role: {result['role']}" + + # check dependencies: Free -> Viewer + if result["plan"] == "Free": + assert result["role"] == "Viewer", "Free plan should have Viewer role" + + # check trace has 3 steps + assert len(trace) == 3, f"Expected 3 trace steps, got {len(trace)}" + + step_types = [step["block_type"] for step in trace] + assert step_types == [ + "StructureSampler", + "SemanticInfiller", + "DuplicateRemover", + ], f"Unexpected trace steps: {step_types}" + + # verify trace_id is valid + assert isinstance(trace_id, str) + assert len(trace_id) > 0 + + print("\n✅ All integration tests passed!") + print(f"Generated {len(results)} records successfully") + + # print sample result for inspection + sample = results[0].result + print(f"\nSample result:") + 
print(f" plan: {sample['plan']}") + print(f" role: {sample['role']}") + print(f" storage: {sample['storage']}") + print(f" bio: {sample['bio']}") + print(f" is_duplicate: {sample['is_duplicate']}") + print(f" similarity_score: {sample['similarity_score']}") + + finally: + await storage.close() + + +@pytest.mark.asyncio +async def test_structure_sampler_alone(tmp_path): + """test StructureSampler block in isolation""" + + db_path = tmp_path / "test.db" + storage = Storage(str(db_path)) + await storage.init_db() + + try: + pipeline_def = { + "blocks": [ + { + "type": "StructureSampler", + "config": { + "target_count": 10, + "categorical_fields": ["plan"], + "numeric_fields": [], + "dependencies": {}, + "seed": 42, + }, + } + ] + } + + pipeline_id = await storage.save_pipeline("test_sampler", json.dumps(pipeline_def)) + pipeline = Pipeline("test_sampler", pipeline_def["blocks"]) + + initial_data = { + "samples": [ + {"plan": "Free"}, + {"plan": "Free"}, + {"plan": "Pro"}, + ] + } + + results = await pipeline.execute(initial_data) + + assert isinstance(results, list) + assert len(results) == 10 + + # check distribution approximately matches input (2 Free, 1 Pro = 67% Free, 33% Pro) + plan_counts = {"Free": 0, "Pro": 0} + for exec_result in results: + plan_counts[exec_result.result["plan"]] += 1 + + # expect approximately 6-7 Free, 3-4 Pro (with seed=42, should be deterministic) + assert 5 <= plan_counts["Free"] <= 8, f"Free count out of range: {plan_counts['Free']}" + assert 2 <= plan_counts["Pro"] <= 5, f"Pro count out of range: {plan_counts['Pro']}" + + print(f"\n✅ StructureSampler test passed! 
Distribution: {plan_counts}") + + finally: + await storage.close() + + +@pytest.mark.asyncio +async def test_data_augmentation_with_no_embedding_model(tmp_path): + """test that DuplicateRemover gracefully handles missing embedding model""" + + db_path = tmp_path / "test.db" + storage = Storage(str(db_path)) + await storage.init_db() + + try: + pipeline_def = { + "blocks": [ + { + "type": "StructureSampler", + "config": { + "target_count": 3, + "categorical_fields": ["plan"], + "numeric_fields": [], + "dependencies": {}, + "seed": 42, + }, + }, + { + "type": "DuplicateRemover", + "config": { + "similarity_threshold": 0.85, + "comparison_fields": ["plan"], + "embedding_model": "non_existent_model", + }, + }, + ] + } + + pipeline_id = await storage.save_pipeline( + "test_no_embedding", json.dumps(pipeline_def) + ) + pipeline = Pipeline("test_no_embedding", pipeline_def["blocks"]) + + initial_data = {"samples": [{"plan": "Free"}]} + + # should not raise error, just skip similarity check + results = await pipeline.execute(initial_data) + + assert isinstance(results, list) + assert len(results) == 3 + + for exec_result in results: + # should have is_duplicate = False when embedding check fails + assert exec_result.result["is_duplicate"] is False + assert exec_result.result["similarity_score"] == 0.0 + + print("\n✅ No embedding model test passed!") + + finally: + await storage.close() From a9166d6efd619be96e04b9a9413d5b45eeee917b Mon Sep 17 00:00:00 2001 From: nicofretti Date: Fri, 9 Jan 2026 23:15:33 +0100 Subject: [PATCH 02/19] add: skill for creating blocks --- .../implementing-datagenflow-blocks/SKILL.md | 622 ++++++++++++++++++ .gitignore | 3 +- 2 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 .claude/skills/implementing-datagenflow-blocks/SKILL.md diff --git a/.claude/skills/implementing-datagenflow-blocks/SKILL.md b/.claude/skills/implementing-datagenflow-blocks/SKILL.md new file mode 100644 index 0000000..9d4e730 --- /dev/null +++ 
b/.claude/skills/implementing-datagenflow-blocks/SKILL.md @@ -0,0 +1,622 @@ +--- +name: implementing-datagenflow-blocks +description: Use when creating new blocks for DataGenFlow pipeline system or modifying existing blocks to ensure consistency with established patterns +--- + +# Implementing DataGenFlow Blocks + +## Overview + +DataGenFlow blocks are composable pipeline components. Follow KISS principles: write minimal functions, make code self-explanatory, keep it simple. + +## When to Use + +- Creating a new block +- Modifying existing block behavior +- Reviewing block implementations +- Debugging block execution issues + +**When NOT to use:** +- General backend code (use llm/rules-backend.md) +- Frontend development (use llm/rules-frontend.md) + +## Block Structure + +```python +import logging +from typing import Any + +import litellm # if using LLM + +from lib.blocks.base import BaseBlock +from lib.entities import pipeline +from lib.entities.block_execution_context import BlockExecutionContext +from lib.template_renderer import render_template # if using templates + +logger = logging.getLogger(__name__) + + +class MyBlock(BaseBlock): + name = "My Block" + description = "Short description of what this block does" + category = "generators" # generators|transformers|validators|utilities + inputs = ["field1"] # or ["*"] for any input fields + outputs = ["field2"] # or ["*"] for dynamic outputs + + _config_descriptions = { + "param_name": "Help text shown in UI", + } + + def __init__( + self, + param1: str, + model: str | None = None, # EXACTLY "model" for LLM selection UI + temperature: float = 0.7, + ): + self.param1 = param1 + self.model_name = model # store as model_name + self.temperature = temperature + + async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager # import inside execute + + # your logic here + + return {"field": value, "_usage": usage_info} +``` + +## UI Integration Patterns + +The 
frontend automatically renders different UI controls based on parameter names, types, and class attributes. + +### Model Dropdown (LLM) + +**Parameter MUST be named exactly `model`** for automatic dropdown: + +```python +def __init__( + self, + model: str | None = None, # MUST be "model" and str|None + temperature: float = 0.7, + max_tokens: int = 2048, +): + self.model_name = model # store as model_name +``` + +**Config description:** +```python +_config_descriptions = { + "model": "Select LLM model to use (leave empty for default)", +} +``` + +**Usage in execute:** +```python +async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager + + llm_config = await llm_config_manager.get_llm_model(self.model_name) + llm_params = llm_config_manager.prepare_llm_call( + llm_config, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) +``` + +### Embedding Model Dropdown + +**Parameter MUST be named exactly `embedding_model`**: + +```python +def __init__( + self, + embedding_model: str | None = None, # MUST be "embedding_model" +): + self.embedding_model_name = embedding_model +``` + +**Config description:** +```python +_config_descriptions = { + "embedding_model": "Embedding model to use (leave empty for default)", +} +``` + +**Usage:** +```python +embedding_config = await llm_config_manager.get_embedding_model( + self.embedding_model_name +) +``` + +### Enum Dropdown + +Use `_config_enums` class attribute to create dropdown with predefined options: + +```python +class MyBlock(BaseBlock): + _config_enums = { + "mode": ["strict", "lenient", "auto"], + "format": ["json", "yaml", "xml"], + } + + def __init__( + self, + mode: str = "auto", + format: str = "json", + ): + self.mode = mode + self.format = format +``` + +### Multi-Select Checkboxes + +For array parameters with enum values: + +```python +class MyBlock(BaseBlock): + _config_enums = { + "features": ["feature_a", "feature_b", 
"feature_c"], + } + + def __init__( + self, + features: list[str] | None = None, + ): + self.features = features or [] +``` + +### Field Reference Dropdown + +Use `_field_references` to create dropdown showing available fields from pipeline: + +```python +class MyBlock(BaseBlock): + _field_references = ["source_field", "target_field"] + + _config_descriptions = { + "source_field": "Field to read from", + "target_field": "Field to write to", + } + + def __init__( + self, + source_field: str, + target_field: str, + ): + self.source_field = source_field + self.target_field = target_field +``` + +### Template Fields (Monaco Editor) + +Parameters with these patterns automatically get Monaco editor: +- Name contains "prompt", "template", or "instruction" +- Or set `schema.format = "jinja2"` via config + +```python +def __init__( + self, + user_prompt: str = "", # automatically gets editor + system_prompt: str = "", # automatically gets editor + custom_template: str = "", # automatically gets editor +): + self.user_prompt = user_prompt +``` + +**Config description should mention Jinja2:** +```python +_config_descriptions = { + "user_prompt": ( + "Jinja2 template. 
Reference fields with {{ field_name }} or " + "{{ metadata.field_name }}" + ), +} +``` + +**Rendering:** +```python +from lib.template_renderer import render_template + +rendered = render_template(self.user_prompt, context.accumulated_state) +``` + +### JSON Object/Array (Monaco Editor) + +Parameters typed as `dict` or `list` get JSON Monaco editor: + +```python +def __init__( + self, + json_schema: dict[str, Any], # JSON editor + field_list: list[str], # JSON editor +): + self.json_schema = json_schema + self.field_list = field_list +``` + +### Number Input + +Parameters typed as `int` or `float` get number input: + +```python +def __init__( + self, + temperature: float = 0.7, # number input + max_tokens: int = 2048, # number input +): + self.temperature = temperature +``` + +### Textarea + +Parameters with these patterns get multi-line textarea: +- String length > 100 characters +- Name contains "description" +- Type has long content + +```python +def __init__( + self, + description: str = "", # automatically gets textarea +): + self.description = description +``` + +### Text Input (Default) + +Short string parameters get single-line text input: + +```python +def __init__( + self, + name: str, + label: str = "", +): + self.name = name +``` + +## JSON Array as String Pattern + +For parameters that should accept either JSON array or Jinja template (like `fields_to_generate`): + +```python +def __init__( + self, + fields_to_generate: str, # str, not list[str] +): + self.fields_to_generate_template = fields_to_generate + +_config_descriptions = { + "fields_to_generate": ( + 'JSON array or Jinja template. 
Examples: ["bio", "storage"] or ' + '{{ fields_to_generate | tojson }}' + ), +} +``` + +**Parsing in execute:** +```python +import json + +fields_rendered = render_template( + self.fields_to_generate_template, + context.accumulated_state +) +try: + fields_list = json.loads(fields_rendered) + if not isinstance(fields_list, list): + raise BlockExecutionError("Must be JSON array") +except json.JSONDecodeError as e: + raise BlockExecutionError(f"Invalid JSON: {str(e)}") +``` + +**Template usage:** +```yaml +fields_to_generate: "{{ fields_to_generate | tojson }}" +``` + +## LLM Integration Pattern + +Full pattern for blocks that call LLM: + +```python +async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from app import llm_config_manager + + # prepare messages + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # get llm config + llm_config = await llm_config_manager.get_llm_model(self.model_name) + llm_params = llm_config_manager.prepare_llm_call( + llm_config, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + # add trace metadata for langfuse grouping + llm_params["metadata"] = { + "trace_id": context.trace_id, + "tags": ["datagenflow"], + } + + logger.info(f"Calling LiteLLM with model={llm_params.get('model')}") + + try: + response = await litellm.acompletion(**llm_params) + except Exception as e: + logger.error(f"LLM call failed for {self.name}: {e}") + raise + + content = response.choices[0].message.content + + # extract usage info + usage_info = pipeline.Usage( + input_tokens=response.usage.prompt_tokens or 0, + output_tokens=response.usage.completion_tokens or 0, + cached_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0, + ) + + return { + "generated": content, + "_usage": usage_info.model_dump(), + } +``` + +## State Management + +### Reading State + +```python +async def execute(self, context: BlockExecutionContext) 
-> dict[str, Any]: + # get current record + current = context.accumulated_state.copy() + + # remove internal fields + current.pop("_usage", None) + current.pop("_hints", None) + + # get reference data from initial state + samples = context.get_state("samples", []) +``` + +### Caching Per Execution + +**Never use instance-level state that persists across jobs.** Use trace_id-keyed caching: + +```python +def __init__(self): + # cache per trace_id (one cache per pipeline execution) + self._embeddings_cache: dict[str, list[list[float]]] = {} + +async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + trace_id = context.trace_id + + # build cache once per pipeline execution + if trace_id not in self._embeddings_cache: + # compute embeddings + self._embeddings_cache[trace_id] = embeddings + + # use cached data + cached_embeddings = self._embeddings_cache[trace_id] +``` + +## Multiplier Blocks + +Blocks that generate multiple items from one input: + +```python +from lib.blocks.base import BaseMultiplierBlock + +class StructureSampler(BaseMultiplierBlock): + name = "Structure Sampler" + category = "generators" + + async def execute( + self, + initial_data: dict[str, Any] + ) -> list[dict[str, Any]]: + # return list of records + return [record1, record2, record3] +``` + +## Code Quality + +### KISS Principle + +Write minimal number of functions, make code self-explanatory: + +```python +# ✅ good - simple and clear +def _prepare_prompts(self, data: dict[str, Any]) -> tuple[str, str]: + """render jinja2 templates with data context""" + system_template = self.system_prompt or data.get("system", "") + user_template = self.user_prompt or data.get("user", "") + + system = render_template(system_template, data) if system_template else "" + user = render_template(user_template, data) if user_template else "" + + return system, user + +# ❌ bad - over-engineered with too many tiny functions +def _get_system(self, data): ... +def _get_user(self, data): ... 
+def _render_system(self, template, data): ... +def _render_user(self, template, data): ... +``` + +### Comments + +Comments in lowercase, explain WHY not WHAT: + +```python +# ✅ good - explains why +def _extract_text(self, record: dict[str, Any]) -> str: + """ + extract text from specified fields or all string fields + joins with spaces for embedding + """ + +# ❌ bad - just describes what code does +def _extract_text(self, record: dict[str, Any]) -> str: + """Extract text from record fields""" + # Loop through fields and get string values +``` + +### Imports + +All imports at top of file, not inside functions (except `from app import llm_config_manager`): + +```python +# ✅ good +import json +import logging +from typing import Any + +import litellm + +from lib.blocks.base import BaseBlock + +# ❌ bad +def execute(self, context): + import json # wrong place +``` + +**Exception:** `from app import llm_config_manager` goes inside `execute()` to avoid circular imports. + +## Testing + +### Unit Tests + +Create `tests/blocks/test_<block_name>.py`: + +```python +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +from lib.blocks.builtin.my_block import MyBlock +from lib.entities.block_execution_context import BlockExecutionContext + + +def make_context(state: dict) -> BlockExecutionContext: + """helper to create test context""" + return BlockExecutionContext( + trace_id="test-trace", + pipeline_id=1, + accumulated_state=state, + ) + + +class TestMyBlockInit: + def test_init_basic(self): + block = MyBlock(param="value") + assert block.param == "value" + + +class TestMyBlockExecution: + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_basic(self, mock_config_manager, mock_completion): + # setup mocks + mock_config_manager.get_llm_model = AsyncMock(...) + mock_completion.return_value = MagicMock(...)
+ + block = MyBlock(param="value") + context = make_context({"field": "value"}) + + result = await block.execute(context) + + assert result["field"] == "expected" +``` + +### Integration Tests + +Add to `tests/integration/test_data_augmentation.py`. + +## Documentation Updates + +**Always update after implementing:** + +1. **llm/state-project.md** - block count, description +2. **llm/state-backend.md** - block count, details +3. **lib/templates/** - template YAML if applicable + +## Common Mistakes + +| Mistake | Problem | Fix | +|---------|---------|-----| +| Parameter named `model_name` | No dropdown UI | Name it exactly `model` | +| Parameter named `embedding` | No dropdown UI | Name it exactly `embedding_model` | +| `list[str]` for JSON arrays | Can't use templates | Use `str`, render + parse | +| Instance-level cache | Data leaks between jobs | Use `dict[str, T]` keyed by `trace_id` | +| Imports inside functions | Not the codebase style | Move to top (except llm_config_manager) | +| Over-engineering | Too many tiny functions | KISS - keep it simple | +| Comments describe what | Obvious from code | Explain WHY, lowercase | +| Forgot `_usage` | Usage not tracked | Always return `_usage` from LLM | +| Missing `_config_descriptions` | No help text in UI | Add descriptions for all params | +| Wrong enum format | UI doesn't render dropdown | Use `_config_enums` class attribute | + +## Implementation Checklist + +**Design:** +- [ ] Choose block type (BaseBlock vs BaseMultiplierBlock) +- [ ] Define inputs/outputs +- [ ] Identify parameters and their types +- [ ] Name model parameters correctly (`model`, `embedding_model`) +- [ ] Decide which params need enum dropdowns or field references + +**Implementation:** +- [ ] Add all imports at top (except llm_config_manager) +- [ ] Create class with `name`, `description`, `category`, `inputs`, `outputs` +- [ ] Add `_config_descriptions` with helpful UI text +- [ ] Add `_config_enums` if using dropdowns +- [ ] Add 
`_field_references` if using field selection +- [ ] Implement `__init__` with correct parameter types +- [ ] Implement `execute()` method +- [ ] Add template rendering if needed +- [ ] Use `llm_config_manager.get_llm_model()` for LLM +- [ ] Use `llm_config_manager.get_embedding_model()` for embeddings +- [ ] Add trace metadata to `llm_params["metadata"]` +- [ ] Track usage with `pipeline.Usage()` and return `_usage` +- [ ] Use trace_id-keyed caching if needed +- [ ] Write lowercase comments explaining WHY + +**Testing:** +- [ ] Create unit test file `tests/blocks/test_<block_name>.py` +- [ ] Test initialization variants +- [ ] Test execution with mocked LLM config +- [ ] Test edge cases and error handling +- [ ] Add integration test +- [ ] Run `pytest tests/` - all pass + +**Documentation:** +- [ ] Update `llm/state-project.md` +- [ ] Update `llm/state-backend.md` +- [ ] Create template YAML if applicable + +**Review:** +- [ ] Model parameters named exactly right +- [ ] Imports at top (except llm_config_manager) +- [ ] No instance-level state that persists across jobs (key caches by `trace_id`) +- [ ] KISS principle followed +- [ ] `_usage` returned if using LLM +- [ ] All UI integrations correct (enums, field refs, descriptions) + +## Reference Examples + +**Simple:** `lib/blocks/builtin/field_mapper.py` + +**LLM:** `lib/blocks/builtin/text_generator.py` + +**Structured:** `lib/blocks/builtin/structured_generator.py` + +**Multiplier:** `lib/blocks/builtin/structure_sampler.py` + +**Embedding:** `lib/blocks/builtin/duplicate_remover.py` diff --git a/.gitignore b/.gitignore index cf68d8c..737f956 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,8 @@ data/*.db-journal # ide .vscode/ .idea/ -.claude/ +.claude/* +!.claude/skills/ .worktrees/ # cache From 84ce41e78747d95b6e5077496a19dc50cc178f88 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 00:20:04 +0100 Subject: [PATCH 03/19] wip: fixing blocks --- lib/blocks/builtin/duplicate_remover.py | 29 +- lib/blocks/builtin/semantic_infiller.py | 40 ++-
lib/blocks/builtin/structure_sampler.py | 182 +++++++----- lib/templates/data_augmentation.yaml | 2 +- llm/state-backend.md | 29 +- llm/state-project.md | 20 +- tests/blocks/test_duplicate_remover.py | 290 +++++++++++++++++++ tests/blocks/test_semantic_infiller.py | 305 ++++++++++++++++++++ tests/blocks/test_structure_sampler.py | 244 ++++++++++++++++ tests/integration/test_data_augmentation.py | 40 ++- 10 files changed, 1074 insertions(+), 107 deletions(-) create mode 100644 tests/blocks/test_duplicate_remover.py create mode 100644 tests/blocks/test_semantic_infiller.py create mode 100644 tests/blocks/test_structure_sampler.py diff --git a/lib/blocks/builtin/duplicate_remover.py b/lib/blocks/builtin/duplicate_remover.py index abf2481..3c8fb6e 100644 --- a/lib/blocks/builtin/duplicate_remover.py +++ b/lib/blocks/builtin/duplicate_remover.py @@ -33,9 +33,8 @@ def __init__( self.comparison_fields = comparison_fields self.embedding_model_name = embedding_model - # cache reference embeddings (shared across records in same job) - self._reference_embeddings: list[list[float]] = [] - self._embeddings_initialized = False + # cache reference embeddings per trace_id (one cache per pipeline execution) + self._embeddings_cache: dict[str, list[list[float]]] = {} def _extract_text(self, record: dict[str, Any], fields: list[str] | None) -> str: """ @@ -93,8 +92,11 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: self.embedding_model_name ) - # build reference embeddings (lazy, once per pipeline run) - if not self._embeddings_initialized: + # get trace_id for cache key + trace_id = context.trace_id + + # build reference embeddings (lazy, once per pipeline execution) + if trace_id not in self._embeddings_cache: logger.info(f"Building reference embeddings for {len(samples)} samples") sample_texts = [ @@ -118,10 +120,14 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: ) response = await litellm.aembedding(**embedding_params) - 
self._reference_embeddings = [item["embedding"] for item in response.data] - self._embeddings_initialized = True + self._embeddings_cache[trace_id] = [ + item["embedding"] for item in response.data + ] - logger.info(f"Initialized {len(self._reference_embeddings)} reference embeddings") + logger.info( + f"Initialized {len(self._embeddings_cache[trace_id])} reference embeddings " + f"for trace_id={trace_id}" + ) # embed current text embedding_params = llm_config_manager._prepare_embedding_call( @@ -130,10 +136,9 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: response = await litellm.aembedding(**embedding_params) current_embedding = response.data[0]["embedding"] - # compute cosine similarities - similarities = cosine_similarity( - [current_embedding], self._reference_embeddings - )[0] + # compute cosine similarities against cached embeddings + reference_embeddings = self._embeddings_cache[trace_id] + similarities = cosine_similarity([current_embedding], reference_embeddings)[0] max_similarity = float(max(similarities)) if len(similarities) > 0 else 0.0 is_duplicate = max_similarity >= self.similarity_threshold diff --git a/lib/blocks/builtin/semantic_infiller.py b/lib/blocks/builtin/semantic_infiller.py index a6a4b2c..1a10a24 100644 --- a/lib/blocks/builtin/semantic_infiller.py +++ b/lib/blocks/builtin/semantic_infiller.py @@ -20,8 +20,13 @@ class SemanticInfiller(BaseBlock): inputs = ["*"] # accepts any skeleton fields outputs = ["*"] # returns merged skeleton + generated fields + # constants for prompt generation + MAX_EXEMPLARS_IN_PROMPT = 2 + _config_descriptions = { - "fields_to_generate": "List of field names for LLM to generate (e.g., ['bio', 'description'])", + "fields_to_generate": ( + 'JSON array or Jinja template. 
Examples: ["bio", "storage"] or {{ fields_to_generate | tojson }}' + ), "model": "Select LLM model to use (leave empty for default)", "temperature": "Sampling temperature (0.0 = deterministic, 1.0 = creative)", "max_tokens": "Maximum tokens for generated response", @@ -30,13 +35,13 @@ class SemanticInfiller(BaseBlock): def __init__( self, - fields_to_generate: list[str], + fields_to_generate: str, model: str | None = None, temperature: float = 0.8, max_tokens: int = 500, system_prompt: str = "", ): - self.fields_to_generate = fields_to_generate + self.fields_to_generate_template = fields_to_generate self.model_name = model self.temperature = temperature self.max_tokens = max_tokens @@ -70,7 +75,7 @@ def _build_generation_prompt( hint_lines.append(f" - {field_name} should be between {value[0]}-{value[1]}") elif key == "exemplars" and isinstance(value, list): hint_lines.append(" - Example records for reference:") - for ex in value[:2]: # show max 2 exemplars + for ex in value[: self.MAX_EXEMPLARS_IN_PROMPT]: # only show generated fields from exemplar ex_fields = { f: ex.get(f, "") @@ -131,6 +136,8 @@ def _parse_json_safely(self, content: str) -> dict[str, Any]: ) async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + from lib.template_renderer import render_template + from app import llm_config_manager # extract skeleton from context @@ -138,6 +145,31 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: hints = skeleton.pop("_hints", {}) skeleton.pop("_usage", None) # remove internal fields + # render fields_to_generate template and parse as JSON + fields_template_rendered = render_template( + self.fields_to_generate_template, context.accumulated_state + ) + try: + fields_to_generate = json.loads(fields_template_rendered) + if not isinstance(fields_to_generate, list): + raise BlockExecutionError( + "fields_to_generate must be a JSON array", + detail={"rendered_value": fields_template_rendered}, + ) + if not 
all(isinstance(f, str) for f in fields_to_generate): + raise BlockExecutionError( + "All items in fields_to_generate must be strings", + detail={"fields_to_generate": fields_to_generate}, + ) + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"fields_to_generate must be valid JSON: {str(e)}", + detail={"template": self.fields_to_generate_template, "rendered": fields_template_rendered}, + ) + + # temporarily set for prompt building + self.fields_to_generate = fields_to_generate + # build generation prompt prompt = self._build_generation_prompt(skeleton, hints) diff --git a/lib/blocks/builtin/structure_sampler.py b/lib/blocks/builtin/structure_sampler.py index eabb6ef..ac0cada 100644 --- a/lib/blocks/builtin/structure_sampler.py +++ b/lib/blocks/builtin/structure_sampler.py @@ -17,6 +17,10 @@ class StructureSampler(BaseMultiplierBlock): inputs = [] # reads from initial state outputs = ["*"] # dynamic based on categorical fields + # constants for sampling configuration + MAX_EXEMPLARS = 5 + MAX_MATCHING_EXEMPLARS = 3 + _config_descriptions = { "target_count": "Number of skeleton records to generate", "categorical_fields": "List of categorical field names to sample (e.g., ['plan', 'role'])", @@ -59,35 +63,23 @@ def _validate_samples(self, samples: list[dict[str, Any]]) -> None: f"Recommend at least 20 samples for better distribution modeling." ) - def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: - """ - extract statistical patterns from samples - - returns: - { - "categorical_probs": {"field": {"value": prob, ...}}, - "conditional_probs": {"field|parent=val": {"value": prob, ...}}, - "numeric_stats": {"field": {"min": x, "max": y, "mean": z}}, - "exemplars": [sample1, sample2, ...] 
- } - """ - profile: dict[str, Any] = { - "categorical_probs": {}, - "conditional_probs": {}, - "numeric_stats": {}, - "exemplars": [], - } - - # categorical field distributions + def _compute_categorical_distributions( + self, samples: list[dict[str, Any]] + ) -> dict[str, dict[str, float]]: + """compute probability distributions for categorical fields""" + distributions = {} for field in self.categorical_fields: values = [sample.get(field) for sample in samples] counts = Counter(values) total = sum(counts.values()) - profile["categorical_probs"][field] = { - value: count / total for value, count in counts.items() - } - - # conditional probabilities for dependencies + distributions[field] = {value: count / total for value, count in counts.items()} + return distributions + + def _compute_conditional_probabilities( + self, samples: list[dict[str, Any]] + ) -> dict[str, dict[str, float]]: + """compute conditional probabilities for dependent fields""" + conditional_probs = {} for child_field, parent_fields in self.dependencies.items(): if child_field not in self.categorical_fields: continue @@ -108,9 +100,15 @@ def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: # build key: "child|parent1=val1,parent2=val2" parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_key)) key = f"{child_field}|{parent_str}" - profile["conditional_probs"][key] = probs + conditional_probs[key] = probs + + return conditional_probs - # numeric field statistics + def _compute_numeric_statistics( + self, samples: list[dict[str, Any]] + ) -> dict[str, dict[str, float]]: + """compute min/max/mean statistics for numeric fields""" + numeric_stats = {} for field in self.numeric_fields: values = [sample.get(field) for sample in samples if sample.get(field) is not None] if values: @@ -125,17 +123,40 @@ def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: ) if numeric_values: - profile["numeric_stats"][field] = { + numeric_stats[field] = 
{ "min": min(numeric_values), "max": max(numeric_values), "mean": sum(numeric_values) / len(numeric_values), } + return numeric_stats + + def _select_exemplars( + self, samples: list[dict[str, Any]], max_count: int | None = None + ) -> list[dict]: + """randomly select exemplar samples for reference""" + if max_count is None: + max_count = self.MAX_EXEMPLARS + num_exemplars = min(max_count, len(samples)) + return random.sample(samples, num_exemplars) - # select random exemplars - num_exemplars = min(5, len(samples)) - profile["exemplars"] = random.sample(samples, num_exemplars) + def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: + """ + extract statistical patterns from samples - return profile + returns: + { + "categorical_probs": {"field": {"value": prob, ...}}, + "conditional_probs": {"field|parent=val": {"value": prob, ...}}, + "numeric_stats": {"field": {"min": x, "max": y, "mean": z}}, + "exemplars": [sample1, sample2, ...] + } + """ + return { + "categorical_probs": self._compute_categorical_distributions(samples), + "conditional_probs": self._compute_conditional_probabilities(samples), + "numeric_stats": self._compute_numeric_statistics(samples), + "exemplars": self._select_exemplars(samples), + } def _topological_sort(self, fields: list[str]) -> list[str]: """ @@ -183,6 +204,57 @@ def _sample_from_distribution(self, probs: dict[str, float]) -> Any: weights = list(probs.values()) return random.choices(values, weights=weights, k=1)[0] + def _sample_categorical_field( + self, field: str, skeleton: dict[str, Any], profile: dict[str, Any] + ) -> Any: + """sample value for a single categorical field, respecting dependencies""" + if field in self.dependencies: + # conditional sampling based on parent values + parent_fields = self.dependencies[field] + parent_values = tuple(skeleton.get(p) for p in parent_fields) + parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_values)) + key = f"{field}|{parent_str}" + + if key in 
profile["conditional_probs"]: + probs = profile["conditional_probs"][key] + else: + # fallback to marginal distribution + logger.warning( + f"Unseen combination {key}, using marginal distribution for {field}" + ) + probs = profile["categorical_probs"].get(field, {}) + else: + # independent sampling + probs = profile["categorical_probs"].get(field, {}) + + return self._sample_from_distribution(probs) + + def _generate_hints( + self, skeleton: dict[str, Any], profile: dict[str, Any] + ) -> dict[str, Any]: + """generate hints for numeric fields and matching exemplars""" + hints: dict[str, Any] = {} + + # add numeric field ranges + for field in self.numeric_fields: + if field in profile["numeric_stats"]: + stats = profile["numeric_stats"][field] + hints[f"{field}_range"] = [stats["min"], stats["max"]] + + # add exemplars that match current categorical values + matching_exemplars = [ + ex + for ex in profile["exemplars"] + if all(ex.get(f) == skeleton.get(f) for f in self.categorical_fields) + ] + + if not matching_exemplars: + # use any exemplars from the full set + matching_exemplars = profile["exemplars"][: self.MAX_MATCHING_EXEMPLARS] + + hints["exemplars"] = matching_exemplars + return hints + def _generate_skeletons( self, profile: dict[str, Any], count: int ) -> list[dict[str, Any]]: @@ -201,50 +273,10 @@ def _generate_skeletons( # sample categorical values in dependency order for field in field_order: - if field in self.dependencies: - # conditional sampling - parent_fields = self.dependencies[field] - parent_values = tuple(skeleton.get(p) for p in parent_fields) - parent_str = ",".join(f"{p}={v}" for p, v in zip(parent_fields, parent_values)) - key = f"{field}|{parent_str}" - - if key in profile["conditional_probs"]: - probs = profile["conditional_probs"][key] - else: - # fallback to marginal distribution - logger.warning( - f"Unseen combination {key}, using marginal distribution for {field}" - ) - probs = profile["categorical_probs"].get(field, {}) - - else: - 
# independent sampling - probs = profile["categorical_probs"].get(field, {}) - - skeleton[field] = self._sample_from_distribution(probs) - - # generate hints for numeric fields - hints: dict[str, Any] = {} - - for field in self.numeric_fields: - if field in profile["numeric_stats"]: - stats = profile["numeric_stats"][field] - hints[f"{field}_range"] = [stats["min"], stats["max"]] - - # add exemplars that match current categorical values - matching_exemplars = [ - ex - for ex in profile["exemplars"] - if all(ex.get(f) == skeleton.get(f) for f in self.categorical_fields) - ] - - if not matching_exemplars: - # use any exemplars - matching_exemplars = profile["exemplars"][:3] - - hints["exemplars"] = matching_exemplars + skeleton[field] = self._sample_categorical_field(field, skeleton, profile) - skeleton["_hints"] = hints + # add hints for LLM generation + skeleton["_hints"] = self._generate_hints(skeleton, profile) results.append(skeleton) return results diff --git a/lib/templates/data_augmentation.yaml b/lib/templates/data_augmentation.yaml index e142e27..0a650fe 100644 --- a/lib/templates/data_augmentation.yaml +++ b/lib/templates/data_augmentation.yaml @@ -11,7 +11,7 @@ blocks: - type: SemanticInfiller config: - fields_to_generate: "{{ fields_to_generate }}" + fields_to_generate: "{{ fields_to_generate | tojson }}" temperature: 0.8 max_tokens: 200 model: null diff --git a/llm/state-backend.md b/llm/state-backend.md index a31744b..1f96424 100644 --- a/llm/state-backend.md +++ b/llm/state-backend.md @@ -11,10 +11,11 @@ fastapi + aiosqlite + pydantic + jinja2 + pyyaml + litellm + rouge-score ``` lib/ blocks/ - builtin/ # 11 blocks: text_generator, structured_generator, validator, - # json_validator, diversity_score, coherence_score, - # rouge_score, markdown_multiplier, langfuse, - # field_mapper, ragas_metrics + builtin/ # 14 blocks: text_generator, structured_generator, + # semantic_infiller, validator, json_validator, + # duplicate_remover, diversity_score, 
coherence_score, + # rouge_score, markdown_multiplier, structure_sampler, + # langfuse, field_mapper, ragas_metrics commons/ # shared utilities (usage_tracker) custom/ # user experimental blocks base.py # BaseBlock interface @@ -414,28 +415,48 @@ class BaseBlock: - `_config_descriptions` → description fields in schema ### builtin blocks +- **StructureSampler**: statistical sampler (target_count, categorical_fields, numeric_fields, dependencies, seed) + - outputs: * (dynamic skeletons + hints) + - category: seeders - **TextGenerator**: text via litellm (system_prompt, user_prompt, model, temperature, max_tokens) - outputs: assistant, system, user + - category: generators - **StructuredGenerator**: json via litellm (json_schema, user_prompt, model, temperature, max_tokens) - outputs: generated + - category: generators +- **SemanticInfiller**: complete skeletons with llm (fields_to_generate, model, temperature, max_tokens) + - outputs: * (merged skeleton + generated fields) + - category: generators - **MarkdownMultiplierBlock**: split markdown into chunks (is_multiplier: true, must be first) - outputs: content (per chunk) + - category: multipliers - **ValidatorBlock**: validate text (min_length, max_length, forbidden_words) - outputs: text, valid, assistant + - category: validators - **JSONValidatorBlock**: parse json from field (field_name, required_fields, strict) - outputs: valid, parsed_json + - category: validators +- **DuplicateRemover**: embedding-based similarity check (similarity_threshold, comparison_fields, embedding_model) + - outputs: *, is_duplicate, similarity_score + - category: validators - **DiversityScore**: lexical diversity (field_name) - outputs: diversity_score + - category: metrics - **CoherenceScore**: text coherence (field_name) - outputs: coherence_score + - category: metrics - **RougeScore**: rouge comparison (generated_field, reference_field, rouge_type) - outputs: rouge_score + - category: metrics - **LangfuseBlock**: observability 
logging (public_key, secret_key, host, session_id) - outputs: langfuse_trace_url + - category: observability - **FieldMapper**: create fields from Jinja2 expressions (mappings) - outputs: dynamic (keys from mappings config) + - category: utilities - **RagasMetrics**: evaluate QA using RAGAS metrics (question_field, answer_field, etc.) - outputs: ragas_scores + - category: metrics ### block discovery - registry scans: lib/blocks/builtin/, lib/blocks/custom/, user_blocks/ diff --git a/llm/state-project.md b/llm/state-project.md index 0e693c5..f4462d1 100644 --- a/llm/state-project.md +++ b/llm/state-project.md @@ -28,7 +28,7 @@ tools: uv (python), yarn (js) ``` lib/ blocks/ - builtin/ # 9 blocks (text/structured gen, multiplier, validators, metrics, langfuse) + builtin/ # 12 blocks (generators, multiplier, validators, metrics, seeders, observability) custom/ # experimental base.py # BaseBlock interface config.py # schema extraction @@ -47,6 +47,10 @@ frontend/src/ pages/ # Pipelines, Generator, Review, Settings components/ # GlobalJobIndicator, pipeline-editor/, settings/, ui/ +.claude/ + skills/ + implementing-datagenflow-blocks/ # guide for creating new blocks + tests/ conftest.py # test db setup blocks/ # block unit tests @@ -94,11 +98,15 @@ class BaseBlock: pass ``` -### builtin blocks (9 total) +### builtin blocks (12 total) + +**seeders:** +- StructureSampler: statistical sampler (target_count, categorical_fields, numeric_fields, dependencies, seed) → * (skeletons + hints) **generators:** - TextGenerator: litellm text (system_prompt, user_prompt, model, temp, max_tokens) → assistant, system, user - StructuredGenerator: litellm json (json_schema, user_prompt, model, temp, max_tokens) → generated +- SemanticInfiller: complete skeletons (fields_to_generate, model, temperature, max_tokens) → * (merged skeleton + generated) **multipliers:** - MarkdownMultiplierBlock: split markdown (file_content required, is_multiplier: true) → content (per chunk) @@ -106,6 +114,7 
@@ class BaseBlock: **validators:** - ValidatorBlock: text rules (min_length, max_length, forbidden_words) → text, valid, assistant - JSONValidatorBlock: parse json (field_name, required_fields, strict) → valid, parsed_json +- DuplicateRemover: embedding similarity (similarity_threshold, comparison_fields, embedding_model) → *, is_duplicate, similarity_score **metrics:** - DiversityScore: lexical diversity (field_name) → diversity_score @@ -207,10 +216,11 @@ blocks: temperature: 0.7 ``` -### built-in (3 templates) +### built-in (4 templates) - **json_generation**: extract title/description (StructuredGenerator + JSONValidator) - **text_classification**: classify with confidence (StructuredGenerator + JSONValidator) - **qa_generation**: generate Q&A pairs (TextGenerator + StructuredGenerator + JSONValidator) +- **data_augmentation**: synthetic records from samples (StructureSampler + SemanticInfiller + DuplicateRemover) ## storage @@ -360,10 +370,10 @@ blocks/, integration/, test_api.py, test_workflow.py, test_storage.py, test_cons production-ready full-stack data generation platform ### features -- 9 blocks (generators, multiplier, validators, metrics, observability) +- 12 blocks (seeders, generators, multiplier, validators, metrics, observability) - auto-discovery from builtin/custom/user_blocks - reactflow visual editor with drag-drop -- jinja2 templates + 3 yaml templates +- jinja2 templates + 4 yaml templates - background jobs with real-time progress - incremental record visibility - job-scoped delete/export/filter diff --git a/tests/blocks/test_duplicate_remover.py b/tests/blocks/test_duplicate_remover.py new file mode 100644 index 0000000..ce7d140 --- /dev/null +++ b/tests/blocks/test_duplicate_remover.py @@ -0,0 +1,290 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from lib.blocks.builtin.duplicate_remover import DuplicateRemover +from lib.entities.block_execution_context import BlockExecutionContext + + +def 
make_context(state: dict, initial_state: dict | None = None) -> BlockExecutionContext: + """helper to create test context""" + if initial_state: + state = {**state} # don't mutate + context = BlockExecutionContext( + trace_id="test-trace", + pipeline_id=1, + accumulated_state=state, + ) + if initial_state: + # add initial state items to accumulated_state + context.accumulated_state.update(initial_state) + return context + + +class TestDuplicateRemoverInit: + def test_init_basic(self): + block = DuplicateRemover() + assert block.similarity_threshold == 0.85 + assert block.comparison_fields is None + assert block.embedding_model_name is None + + def test_init_with_params(self): + block = DuplicateRemover( + similarity_threshold=0.9, + comparison_fields=["bio", "description"], + embedding_model="text-embedding-ada-002", + ) + assert block.similarity_threshold == 0.9 + assert block.comparison_fields == ["bio", "description"] + assert block.embedding_model_name == "text-embedding-ada-002" + + +class TestDuplicateRemoverTextExtraction: + def test_extract_text_specific_fields(self): + block = DuplicateRemover(comparison_fields=["bio"]) + + record = {"bio": "Test bio", "other": "Ignored"} + text = block._extract_text(record, ["bio"]) + + assert text == "Test bio" + + def test_extract_text_multiple_fields(self): + block = DuplicateRemover(comparison_fields=["bio", "description"]) + + record = {"bio": "Bio text", "description": "Description text"} + text = block._extract_text(record, ["bio", "description"]) + + assert text == "Bio text Description text" + + def test_extract_text_auto_detect(self): + block = DuplicateRemover() + + record = {"bio": "Bio text", "plan": "Free", "count": 123} + text = block._extract_text(record, None) + + # should only include string fields + assert "Bio text" in text + assert "Free" in text + assert "123" not in text + + def test_extract_text_handles_none(self): + block = DuplicateRemover(comparison_fields=["bio"]) + + record = {"bio": None, 
"other": "text"} + text = block._extract_text(record, ["bio"]) + + # None should be converted to empty string + assert text == "" + + +class TestDuplicateRemoverNoSamples: + @pytest.mark.asyncio + async def test_no_samples_returns_not_duplicate(self): + block = DuplicateRemover() + + context = make_context({"bio": "Test bio"}) + + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] == 0.0 + + @pytest.mark.asyncio + async def test_empty_samples_returns_not_duplicate(self): + block = DuplicateRemover() + + context = make_context({"bio": "Test bio"}, {"samples": []}) + + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] == 0.0 + + +class TestDuplicateRemoverNoText: + @pytest.mark.asyncio + async def test_no_text_returns_not_duplicate(self): + block = DuplicateRemover(comparison_fields=["bio"]) + + context = make_context({}, {"samples": [{"bio": "Sample"}]}) + + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] == 0.0 + + +class TestDuplicateRemoverWithEmbeddings: + @pytest.mark.asyncio + @patch("litellm.aembedding") + @patch("app.llm_config_manager") + async def test_duplicate_detection_below_threshold( + self, mock_config_manager, mock_embedding + ): + # setup mocks + mock_config_manager.get_embedding_model = AsyncMock( + return_value={"model": "text-embedding-ada-002"} + ) + mock_config_manager._prepare_embedding_call = MagicMock( + return_value={"model": "text-embedding-ada-002"} + ) + + # mock embeddings - different vectors (low similarity) + mock_embedding.side_effect = [ + # reference embeddings + MagicMock(data=[{"embedding": [1.0, 0.0, 0.0]}]), + # current embedding + MagicMock(data=[{"embedding": [0.0, 1.0, 0.0]}]), + ] + + block = DuplicateRemover( + similarity_threshold=0.85, + comparison_fields=["bio"], + ) + + context = make_context( + {"bio": "New unique 
bio"}, + {"samples": [{"bio": "Reference bio"}]}, + ) + + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] < 0.85 + + @pytest.mark.asyncio + @patch("litellm.aembedding") + @patch("app.llm_config_manager") + async def test_duplicate_detection_above_threshold( + self, mock_config_manager, mock_embedding + ): + # setup mocks + mock_config_manager.get_embedding_model = AsyncMock( + return_value={"model": "text-embedding-ada-002"} + ) + mock_config_manager._prepare_embedding_call = MagicMock( + return_value={"model": "text-embedding-ada-002"} + ) + + # mock embeddings - very similar vectors (high similarity) + mock_embedding.side_effect = [ + # reference embeddings + MagicMock(data=[{"embedding": [1.0, 0.1, 0.0]}]), + # current embedding (very similar) + MagicMock(data=[{"embedding": [0.99, 0.11, 0.01]}]), + ] + + block = DuplicateRemover( + similarity_threshold=0.85, + comparison_fields=["bio"], + ) + + context = make_context( + {"bio": "Very similar bio"}, + {"samples": [{"bio": "Similar bio"}]}, + ) + + result = await block.execute(context) + + assert result["is_duplicate"] is True + assert result["similarity_score"] >= 0.85 + + @pytest.mark.asyncio + @patch("litellm.aembedding") + @patch("app.llm_config_manager") + async def test_embedding_cache_by_trace_id( + self, mock_config_manager, mock_embedding + ): + """test that embeddings are cached per trace_id""" + mock_config_manager.get_embedding_model = AsyncMock( + return_value={"model": "text-embedding-ada-002"} + ) + mock_config_manager._prepare_embedding_call = MagicMock( + return_value={"model": "text-embedding-ada-002"} + ) + + mock_embedding.side_effect = [ + # first call - build reference embeddings + MagicMock(data=[{"embedding": [1.0, 0.0, 0.0]}]), + # second call - embed current text + MagicMock(data=[{"embedding": [0.5, 0.5, 0.0]}]), + # third call - embed second current text (reuses cache, so no reference embedding call) + 
MagicMock(data=[{"embedding": [0.6, 0.4, 0.0]}]), + ] + + block = DuplicateRemover(comparison_fields=["bio"]) + + # first execution + context1 = make_context( + {"bio": "First bio"}, + {"samples": [{"bio": "Reference"}]}, + ) + await block.execute(context1) + + # second execution with same trace_id - should reuse cache + context2 = make_context( + {"bio": "Second bio"}, + {"samples": [{"bio": "Reference"}]}, + ) + context2.trace_id = "test-trace" # same trace_id + await block.execute(context2) + + # embedding should be called 3 times total (1 ref + 2 current) + assert mock_embedding.call_count == 3 + + +class TestDuplicateRemoverErrorHandling: + @pytest.mark.asyncio + async def test_no_embedding_model_skips_check(self): + """test that missing embedding model gracefully skips check""" + block = DuplicateRemover() + + context = make_context( + {"bio": "Test bio"}, + {"samples": [{"bio": "Reference"}]}, + ) + + # should not raise error + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] == 0.0 + + @pytest.mark.asyncio + @patch("app.llm_config_manager") + async def test_embedding_error_skips_check(self, mock_config_manager): + """test that embedding errors are caught and check is skipped""" + mock_config_manager.get_embedding_model = AsyncMock( + side_effect=Exception("Embedding model not found") + ) + + block = DuplicateRemover(embedding_model="invalid-model") + + context = make_context( + {"bio": "Test bio"}, + {"samples": [{"bio": "Reference"}]}, + ) + + # should not raise error + result = await block.execute(context) + + assert result["is_duplicate"] is False + assert result["similarity_score"] == 0.0 + + +class TestDuplicateRemoverSchema: + def test_schema_structure(self): + schema = DuplicateRemover.get_schema() + assert schema["name"] == "Duplicate Remover" + assert schema["category"] == "validators" + assert schema["inputs"] == ["*"] + assert "*" in schema["outputs"] + assert "is_duplicate" in 
schema["outputs"] + assert "similarity_score" in schema["outputs"] + + def test_schema_has_required_configs(self): + schema = DuplicateRemover.get_schema() + config_props = schema["config_schema"]["properties"] + assert "similarity_threshold" in config_props + assert "comparison_fields" in config_props + assert "embedding_model" in config_props diff --git a/tests/blocks/test_semantic_infiller.py b/tests/blocks/test_semantic_infiller.py new file mode 100644 index 0000000..d715933 --- /dev/null +++ b/tests/blocks/test_semantic_infiller.py @@ -0,0 +1,305 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from lib.blocks.builtin.semantic_infiller import SemanticInfiller +from lib.entities import LLMModelConfig, LLMProvider +from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError + + +def make_context(state: dict) -> BlockExecutionContext: + """helper to create test context""" + return BlockExecutionContext( + trace_id="test-trace", + pipeline_id=1, + accumulated_state=state, + ) + + +class TestSemanticInfillerInit: + def test_init_basic(self): + block = SemanticInfiller(fields_to_generate='["bio"]') + assert block.fields_to_generate_template == '["bio"]' + assert block.model_name is None + assert block.temperature == 0.8 + assert block.max_tokens == 500 + + def test_init_with_all_params(self): + block = SemanticInfiller( + fields_to_generate='["bio", "description"]', + model="gpt-4", + temperature=0.9, + max_tokens=1000, + system_prompt="Custom prompt", + ) + assert block.fields_to_generate_template == '["bio", "description"]' + assert block.model_name == "gpt-4" + assert block.temperature == 0.9 + assert block.max_tokens == 1000 + assert block.system_prompt == "Custom prompt" + + def test_init_with_template(self): + block = SemanticInfiller(fields_to_generate="{{ fields_to_generate }}") + assert block.fields_to_generate_template == "{{ fields_to_generate }}" + + +class 
TestSemanticInfillerPromptBuilding: + def test_build_prompt_with_constraints(self): + block = SemanticInfiller(fields_to_generate='["bio"]') + # Set the parsed fields for prompt building + block.fields_to_generate = ["bio"] + + skeleton = {"plan": "Free", "role": "Viewer"} + hints = {} + + prompt = block._build_generation_prompt(skeleton, hints) + + assert '"bio"' in prompt + assert 'plan: "Free" (FIXED)' in prompt + assert 'role: "Viewer" (FIXED)' in prompt + + def test_build_prompt_with_numeric_hints(self): + block = SemanticInfiller(fields_to_generate='["storage"]') + block.fields_to_generate = ["storage"] + + skeleton = {"plan": "Pro"} + hints = {"storage_range": [10, 100]} + + prompt = block._build_generation_prompt(skeleton, hints) + + assert "storage should be between 10-100" in prompt + + def test_build_prompt_with_exemplars(self): + block = SemanticInfiller(fields_to_generate='["bio"]') + block.fields_to_generate = ["bio"] + + skeleton = {"plan": "Free"} + hints = { + "exemplars": [ + {"plan": "Free", "bio": "Student learning"}, + {"plan": "Free", "bio": "Just exploring"}, + ] + } + + prompt = block._build_generation_prompt(skeleton, hints) + + assert "Example records" in prompt + assert "Student learning" in prompt + assert "Just exploring" in prompt + + +class TestSemanticInfillerJSONParsing: + def test_parse_valid_json(self): + block = SemanticInfiller(fields_to_generate=["bio"]) + + content = '{"bio": "Test bio"}' + result = block._parse_json_safely(content) + + assert result == {"bio": "Test bio"} + + def test_parse_json_with_markdown(self): + block = SemanticInfiller(fields_to_generate=["bio"]) + + content = '```json\n{"bio": "Test bio"}\n```' + result = block._parse_json_safely(content) + + assert result == {"bio": "Test bio"} + + def test_parse_json_embedded_in_text(self): + block = SemanticInfiller(fields_to_generate=["bio"]) + + content = 'Here is the result: {"bio": "Test bio"} done' + result = block._parse_json_safely(content) + + assert result 
== {"bio": "Test bio"} + + def test_parse_invalid_json_raises_error(self): + block = SemanticInfiller(fields_to_generate=["bio"]) + + content = "not json at all" + + with pytest.raises(BlockExecutionError, match="invalid JSON"): + block._parse_json_safely(content) + + +class TestSemanticInfillerExecution: + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_basic(self, mock_config_manager, mock_completion): + # setup mocks + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + mock_completion.return_value = MagicMock( + choices=[ + MagicMock(message=MagicMock(content='{"bio": "Generated bio"}')) + ], + usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), + ) + + block = SemanticInfiller(fields_to_generate='["bio"]') + context = make_context({"plan": "Free", "role": "Viewer"}) + + result = await block.execute(context) + + assert result["plan"] == "Free" + assert result["role"] == "Viewer" + assert result["bio"] == "Generated bio" + assert "_usage" in result + + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_with_hints(self, mock_config_manager, mock_completion): + # setup mocks + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + mock_completion.return_value = MagicMock( + choices=[ + MagicMock(message=MagicMock(content='{"bio": "Generated bio", "storage": 50}')) + ], + usage=MagicMock(prompt_tokens=100, completion_tokens=50, 
cache_read_input_tokens=0), + ) + + block = SemanticInfiller(fields_to_generate='["bio", "storage"]') + context = make_context({ + "plan": "Pro", + "_hints": {"storage_range": [10, 100]} + }) + + result = await block.execute(context) + + assert result["bio"] == "Generated bio" + assert result["storage"] == 50 + # hints should be removed from result + assert "_hints" not in result + + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_restores_locked_fields(self, mock_config_manager, mock_completion): + # LLM tries to modify a locked field + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + mock_completion.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='{"plan": "Modified", "bio": "Generated bio"}' + ) + ) + ], + usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), + ) + + block = SemanticInfiller(fields_to_generate='["bio"]') + context = make_context({"plan": "Free"}) + + result = await block.execute(context) + + # plan should be restored to original value + assert result["plan"] == "Free" + assert result["bio"] == "Generated bio" + + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_llm_error_raises(self, mock_config_manager, mock_completion): + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + mock_completion.side_effect = Exception("LLM API error") + + block = SemanticInfiller(fields_to_generate='["bio"]') 
+ context = make_context({"plan": "Free"}) + + with pytest.raises(BlockExecutionError, match="LLM call failed"): + await block.execute(context) + + @pytest.mark.asyncio + @patch("litellm.acompletion") + @patch("app.llm_config_manager") + async def test_execute_with_template(self, mock_config_manager, mock_completion): + """Test that Jinja templates work for fields_to_generate""" + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + mock_completion.return_value = MagicMock( + choices=[ + MagicMock(message=MagicMock(content='{"bio": "Generated bio"}')) + ], + usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), + ) + + # Use tojson filter to properly serialize the list as JSON + block = SemanticInfiller(fields_to_generate="{{ fields_to_generate | tojson }}") + # Provide fields_to_generate in the accumulated state (from metadata) + context = make_context({ + "plan": "Free", + "fields_to_generate": ["bio"] + }) + + result = await block.execute(context) + + assert result["bio"] == "Generated bio" + + +class TestSemanticInfillerSchema: + def test_schema_structure(self): + schema = SemanticInfiller.get_schema() + assert schema["name"] == "Semantic Infiller" + assert schema["category"] == "generators" + assert schema["inputs"] == ["*"] + assert schema["outputs"] == ["*"] + + def test_schema_has_required_configs(self): + schema = SemanticInfiller.get_schema() + config_props = schema["config_schema"]["properties"] + assert "fields_to_generate" in config_props + assert "model" in config_props + assert "temperature" in config_props + assert "max_tokens" in config_props diff --git a/tests/blocks/test_structure_sampler.py b/tests/blocks/test_structure_sampler.py new file mode 100644 index 0000000..65eedd1 --- 
/dev/null +++ b/tests/blocks/test_structure_sampler.py @@ -0,0 +1,244 @@ +import pytest + +from lib.blocks.builtin.structure_sampler import StructureSampler +from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import ValidationError + + +def make_context(state: dict) -> BlockExecutionContext: + """helper to create test context""" + return BlockExecutionContext( + trace_id="test-trace", + pipeline_id=1, + accumulated_state=state, + ) + + +class TestStructureSamplerInit: + def test_init_basic(self): + block = StructureSampler( + target_count=10, + categorical_fields=["plan"], + ) + assert block.target_count == 10 + assert block.categorical_fields == ["plan"] + assert block.numeric_fields == [] + assert block.dependencies == {} + + def test_init_with_all_params(self): + block = StructureSampler( + target_count=5, + categorical_fields=["plan", "role"], + numeric_fields=["storage"], + dependencies={"role": ["plan"]}, + seed=42, + ) + assert block.target_count == 5 + assert block.categorical_fields == ["plan", "role"] + assert block.numeric_fields == ["storage"] + assert block.dependencies == {"role": ["plan"]} + assert block.seed == 42 + + +class TestStructureSamplerDistributions: + @pytest.mark.asyncio + async def test_categorical_distribution(self): + block = StructureSampler( + target_count=10, + categorical_fields=["plan"], + seed=42, + ) + samples = [ + {"plan": "Free"}, + {"plan": "Free"}, + {"plan": "Pro"}, + ] + + profile = block._compute_categorical_distributions(samples) + + # check probabilities sum to 1 + assert abs(sum(profile["plan"].values()) - 1.0) < 0.001 + # check Free is ~67% (2/3) and Pro is ~33% (1/3) + assert abs(profile["plan"]["Free"] - 0.667) < 0.01 + assert abs(profile["plan"]["Pro"] - 0.333) < 0.01 + + @pytest.mark.asyncio + async def test_conditional_probabilities(self): + block = StructureSampler( + target_count=10, + categorical_fields=["plan", "role"], + dependencies={"role": ["plan"]}, + seed=42, + ) + 
samples = [ + {"plan": "Free", "role": "Viewer"}, + {"plan": "Free", "role": "Viewer"}, + {"plan": "Pro", "role": "Editor"}, + {"plan": "Pro", "role": "Admin"}, + ] + + profile = block._compute_conditional_probabilities(samples) + + # check conditional probability for role given plan + assert "role|plan=Free" in profile + assert profile["role|plan=Free"]["Viewer"] == 1.0 + + assert "role|plan=Pro" in profile + assert profile["role|plan=Pro"]["Editor"] == 0.5 + assert profile["role|plan=Pro"]["Admin"] == 0.5 + + @pytest.mark.asyncio + async def test_numeric_statistics(self): + block = StructureSampler( + target_count=10, + numeric_fields=["storage"], + categorical_fields=[], + seed=42, + ) + samples = [ + {"storage": 1}, + {"storage": 2}, + {"storage": 3}, + ] + + stats = block._compute_numeric_statistics(samples) + + assert stats["storage"]["min"] == 1 + assert stats["storage"]["max"] == 3 + assert stats["storage"]["mean"] == 2.0 + + +class TestStructureSamplerGeneration: + @pytest.mark.asyncio + async def test_generate_skeletons_basic(self): + block = StructureSampler( + target_count=5, + categorical_fields=["plan"], + seed=42, + ) + + context = make_context({ + "samples": [ + {"plan": "Free"}, + {"plan": "Free"}, + {"plan": "Pro"}, + ] + }) + + results = await block.execute(context) + + # check we got 5 results + assert len(results) == 5 + # check all have plan field + for result in results: + assert "plan" in result + assert result["plan"] in ["Free", "Pro"] + + @pytest.mark.asyncio + async def test_generate_skeletons_with_dependencies(self): + block = StructureSampler( + target_count=10, + categorical_fields=["plan", "role"], + dependencies={"role": ["plan"]}, + seed=42, + ) + + context = make_context({ + "samples": [ + {"plan": "Free", "role": "Viewer"}, + {"plan": "Free", "role": "Viewer"}, + {"plan": "Pro", "role": "Editor"}, + ] + }) + + results = await block.execute(context) + + # check all Free plans have Viewer role (100% in samples) + for result in 
results: + if result["plan"] == "Free": + assert result["role"] == "Viewer" + + @pytest.mark.asyncio + async def test_generate_skeletons_with_hints(self): + block = StructureSampler( + target_count=3, + categorical_fields=["plan"], + numeric_fields=["storage"], + seed=42, + ) + + context = make_context({ + "samples": [ + {"plan": "Free", "storage": 1}, + {"plan": "Free", "storage": 2}, + {"plan": "Pro", "storage": 50}, + ] + }) + + results = await block.execute(context) + + # check hints are included + for result in results: + assert "_hints" in result + assert "storage_range" in result["_hints"] + assert "exemplars" in result["_hints"] + # check storage range is [1, 50] + assert result["_hints"]["storage_range"] == [1, 50] + + +class TestStructureSamplerEdgeCases: + @pytest.mark.asyncio + async def test_empty_samples_raises_error(self): + block = StructureSampler( + target_count=5, + categorical_fields=["plan"], + ) + + context = make_context({"samples": []}) + + with pytest.raises(ValidationError, match="No samples provided"): + await block.execute(context) + + @pytest.mark.asyncio + async def test_missing_samples_raises_error(self): + block = StructureSampler( + target_count=5, + categorical_fields=["plan"], + ) + + context = make_context({}) + + with pytest.raises(ValidationError, match="No samples provided"): + await block.execute(context) + + @pytest.mark.asyncio + async def test_circular_dependency_detection(self): + block = StructureSampler( + target_count=5, + categorical_fields=["a", "b"], + dependencies={"a": ["b"], "b": ["a"]}, + ) + + context = make_context({ + "samples": [{"a": "1", "b": "2"}] + }) + + with pytest.raises(ValidationError, match="Circular dependency"): + await block.execute(context) + + +class TestStructureSamplerSchema: + def test_schema_structure(self): + schema = StructureSampler.get_schema() + assert schema["name"] == "Structure Sampler" + assert schema["category"] == "seeders" + assert schema["outputs"] == ["*"] + + def 
test_schema_has_required_configs(self): + schema = StructureSampler.get_schema() + config_props = schema["config_schema"]["properties"] + assert "target_count" in config_props + assert "categorical_fields" in config_props + assert "numeric_fields" in config_props + assert "dependencies" in config_props + assert "seed" in config_props diff --git a/tests/integration/test_data_augmentation.py b/tests/integration/test_data_augmentation.py index 1ba0061..2112082 100644 --- a/tests/integration/test_data_augmentation.py +++ b/tests/integration/test_data_augmentation.py @@ -1,15 +1,44 @@ """integration test for data augmentation pipeline""" import json +from unittest.mock import AsyncMock, MagicMock, patch + import pytest +from lib.entities import LLMModelConfig, LLMProvider from lib.storage import Storage from lib.workflow import Pipeline @pytest.mark.asyncio -async def test_data_augmentation_pipeline(tmp_path): +@patch("litellm.acompletion") +@patch("app.llm_config_manager") +async def test_data_augmentation_pipeline(mock_config_manager, mock_completion, tmp_path): """test complete data augmentation pipeline with all 3 blocks""" + # setup mocks for LLM calls + mock_config_manager.get_llm_model = AsyncMock( + return_value=LLMModelConfig( + name="test", + provider=LLMProvider.OPENAI, + endpoint="http://test", + model_name="gpt-4", + ) + ) + mock_config_manager.prepare_llm_call = MagicMock( + return_value={"model": "gpt-4", "messages": []} + ) + # mock LLM response with realistic generated fields + mock_completion.return_value = MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='{"bio": "Generated bio text", "storage": 10}' + ) + ) + ], + usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), + ) + # setup test database db_path = tmp_path / "test.db" storage = Storage(str(db_path)) @@ -32,7 +61,7 @@ async def test_data_augmentation_pipeline(tmp_path): { "type": "SemanticInfiller", "config": { - "fields_to_generate": ["bio", 
"storage"], + "fields_to_generate": '["bio", "storage"]', "temperature": 0.8, "max_tokens": 200, "model": None, @@ -119,12 +148,11 @@ async def test_data_augmentation_pipeline(tmp_path): if result["plan"] == "Free": assert result["role"] == "Viewer", "Free plan should have Viewer role" - # check trace has 3 steps - assert len(trace) == 3, f"Expected 3 trace steps, got {len(trace)}" + # check trace has 2 steps (StructureSampler is multiplier, doesn't appear in per-item trace) + assert len(trace) == 2, f"Expected 2 trace steps, got {len(trace)}" - step_types = [step["block_type"] for step in trace] + step_types = [step.block_type for step in trace] assert step_types == [ - "StructureSampler", "SemanticInfiller", "DuplicateRemover", ], f"Unexpected trace steps: {step_types}" From 91aa398bd1e54285785478f92e2da9166727e749 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 19:31:01 +0100 Subject: [PATCH 04/19] fix: blocks data --- frontend/src/components/pipeline-editor/BlockNode.tsx | 2 -- lib/blocks/builtin/semantic_infiller.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/frontend/src/components/pipeline-editor/BlockNode.tsx b/frontend/src/components/pipeline-editor/BlockNode.tsx index da30bb9..6005f2b 100644 --- a/frontend/src/components/pipeline-editor/BlockNode.tsx +++ b/frontend/src/components/pipeline-editor/BlockNode.tsx @@ -61,8 +61,6 @@ function getPreviewFields(blockType: string, config: Record): Array // priority fields based on block type let priorityKeys: string[] = []; - console.log(type); - // data augmentation blocks if (type.includes("sampler")) { priorityKeys = ["target_count", "categorical_fields"]; diff --git a/lib/blocks/builtin/semantic_infiller.py b/lib/blocks/builtin/semantic_infiller.py index 1a10a24..99e0c14 100644 --- a/lib/blocks/builtin/semantic_infiller.py +++ b/lib/blocks/builtin/semantic_infiller.py @@ -9,6 +9,7 @@ from lib.entities import pipeline from lib.entities.block_execution_context import 
BlockExecutionContext from lib.errors import BlockExecutionError +from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -136,8 +137,6 @@ def _parse_json_safely(self, content: str) -> dict[str, Any]: ) async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: - from lib.template_renderer import render_template - from app import llm_config_manager # extract skeleton from context From 4de867c22ffaab8c821427297ea1d616ebad49dd Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 20:48:11 +0100 Subject: [PATCH 05/19] add: docs template --- docs/template_data_augmentation.md | 412 +++++++++++++++++++++++++++++ 1 file changed, 412 insertions(+) create mode 100644 docs/template_data_augmentation.md diff --git a/docs/template_data_augmentation.md b/docs/template_data_augmentation.md new file mode 100644 index 0000000..e12a053 --- /dev/null +++ b/docs/template_data_augmentation.md @@ -0,0 +1,412 @@ +--- +title: Data Augmentation Template +description: Generate synthetic records preserving statistical distributions from sample data +--- + +# Data Augmentation Template + +## Table of Contents +- [Overview](#overview) +- [Pipeline Architecture](#pipeline-architecture) +- [Seed Format](#seed-format) +- [Output Format](#output-format) +- [How It Works](#how-it-works) +- [Use Cases](#use-cases) +- [Customization](#customization) +- [Filtering Duplicates](#filtering-duplicates) +- [Tuning Parameters](#tuning-parameters) +- [Common Issues](#common-issues) +- [Example Workflow](#example-workflow) +- [Related Documentation](#related-documentation) + +## Overview + +**Complexity:** Advanced (3 blocks with multiplier) +**Use Case:** Generate synthetic data that preserves statistical patterns from samples + +This template creates realistic synthetic records from sample data while maintaining: +- Statistical distributions (e.g., "Free" plan appears 40% of the time) +- Field dependencies (e.g., "Admin" role only with "Pro" or 
"Enterprise" plans) +- Semantic coherence (LLM-generated fields match context) +- Output diversity (duplicate detection via embeddings) + +**Special Features:** +- Statistical sampling preserves distributions +- LLM-powered semantic field generation +- Embedding-based duplicate detection +- Supports field dependencies + +## Pipeline Architecture + +``` +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ Structure │──►│ Semantic │──►│ Duplicate │ +│ Sampler │ │ Infiller │ │ Remover │ +└─────────────┘ └─────────────┘ └─────────────┘ + +Input: samples array + ↓ ++ plan, role, storage, generation_hints (multiplies: 1 seed → N skeletons) + ↓ ++ bio (LLM-generated semantic field) + ↓ ++ is_duplicate, similarity_score +``` + +**Blocks:** +1. **StructureSampler** - Learns distributions from samples, generates statistical skeletons +2. **SemanticInfiller** - Completes skeletons with LLM-generated semantic fields +3. **DuplicateRemover** - Filters similar records using embedding similarity + +**Key Concept:** The StructureSampler is a multiplier block that generates N skeletons from one seed. Each skeleton flows through the remaining blocks to create one record. 
+ +## Seed Format + +**Required fields:** +- `samples` - Array of example records (minimum 3 recommended) +- `target_count` - Number of synthetic records to generate +- `categorical_fields` - Fields to preserve distribution +- `fields_to_generate` - Fields for LLM to generate + +**Optional fields:** +- `numeric_fields` - Numeric distributions to preserve +- `dependencies` - Field relationships (e.g., role depends on plan) +- `comparison_fields` - Fields for duplicate detection + +**Example seed:** +```json +[ + { + "repetitions": 1, + "metadata": { + "samples": [ + {"plan": "Free", "role": "Viewer", "storage": 1, "bio": "Student learning web development"}, + {"plan": "Free", "role": "Viewer", "storage": 2, "bio": "Just exploring the platform"}, + {"plan": "Pro", "role": "Editor", "storage": 50, "bio": "Freelance designer managing projects"}, + {"plan": "Pro", "role": "Editor", "storage": 75, "bio": "Small agency owner"}, + {"plan": "Pro", "role": "Admin", "storage": 100, "bio": "Team lead overseeing projects"}, + {"plan": "Enterprise", "role": "Admin", "storage": 500, "bio": "CTO managing infrastructure"} + ], + "target_count": 20, + "categorical_fields": ["plan", "role"], + "numeric_fields": ["storage"], + "fields_to_generate": ["bio", "storage"], + "dependencies": { + "role": ["plan"] + }, + "comparison_fields": ["bio"] + } + } +] +``` + +> **Tip:** Use 5-10 diverse samples for best results. More samples = better distribution learning. 
+ +## Output Format + +Each generated record contains: +- Sampled categorical fields (preserving distribution) +- Sampled or generated numeric fields +- LLM-generated semantic fields +- Duplicate detection metadata + +**Example output:** +```json +{ + "plan": "Pro", + "role": "Editor", + "storage": 75, + "bio": "Product designer with 5 years experience managing client projects", + "is_duplicate": false, + "similarity_score": 0.72, + "generation_hints": { + "numeric_ranges": { + "storage": {"mean": 50.5, "std": 30.2, "min": 1, "max": 500} + }, + "matching_exemplars": [ + {"plan": "Pro", "role": "Editor", "storage": 50, "bio": "Freelance designer"} + ] + } +} +``` + +## How It Works + +### Stage 1: StructureSampler (Statistical Skeleton Generation) + +**What it does:** +- Analyzes sample data to learn categorical frequencies +- Computes numeric statistics (mean, std, min, max) +- Detects field dependencies (e.g., role depends on plan) +- Generates N skeletons respecting learned distributions + +**Example:** If samples show "Free" plan 40% and "Pro" 30%, generated skeletons maintain these ratios. + +**Output per skeleton:** +```json +{ + "plan": "Pro", + "role": "Editor", + "storage": 75, + "generation_hints": { + "numeric_ranges": {"storage": {"mean": 50.5, "std": 30.2}}, + "matching_exemplars": [ + {"plan": "Pro", "role": "Editor", "storage": 50, "bio": "Freelance designer"} + ] + } +} +``` + +### Stage 2: SemanticInfiller (LLM-Powered Field Completion) + +**What it does:** +- Receives skeleton with locked statistical fields +- Builds contextual prompt with numeric hints and exemplar examples +- Calls LLM to generate semantic fields (bio, description, etc.) +- Restores locked fields if LLM overwrites them + +**Prompt structure:** +```text +You are a data generator. Complete the following record skeleton. 
+ +Skeleton: {plan: "Pro", role: "Editor", storage: 75} + +Numeric hints: +- storage: mean=50.5, std=30.2, range=[1,500] + +Matching examples: +- {plan: "Pro", role: "Editor", storage: 50, bio: "Freelance designer"} + +Generate: ["bio", "storage"] +Return JSON: {"bio": "...", "storage": 75} +``` + +**Locked fields behavior:** If skeleton has `storage: 75` but `fields_to_generate` includes "storage", the LLM generates it but SemanticInfiller restores the original value to preserve statistical sampling. + +### Stage 3: DuplicateRemover (Similarity Filtering) + +**What it does:** +- Extracts text from comparison fields +- Generates embeddings via embedding model +- Computes cosine similarity with cached embeddings +- Marks records as duplicates if similarity > threshold + +**Output:** +```json +{ + "plan": "Pro", + "role": "Editor", + "storage": 75, + "bio": "Product designer with 5 years experience", + "is_duplicate": false, + "similarity_score": 0.72 +} +``` + +> **Note:** DuplicateRemover gracefully degrades if embedding model is unavailable - marks all records as `is_duplicate: false`. 
+ +## Use Cases + +**Perfect for:** +- Expanding training datasets while maintaining patterns +- Creating realistic test data for applications +- Generating synthetic user profiles with distributions +- Data augmentation for ML training sets +- Privacy-preserving data generation (learn from real, generate synthetic) + +**Not ideal for:** +- Time-series data (no temporal modeling) +- Graph/network data (no relationship modeling) +- Highly correlated numeric fields (limited correlation preservation) + +## Customization + +Modify the template in `lib/templates/data_augmentation.yaml`: + +**Adjust generation count:** +```yaml +blocks: + - type: StructureSampler + config: + target_count: 100 # Generate 100 records +``` + +**Change LLM creativity:** +```yaml + - type: SemanticInfiller + config: + temperature: 0.9 # Higher = more creative (0.7-0.9 recommended) + max_tokens: 300 # Longer outputs +``` + +**Adjust duplicate threshold:** +```yaml + - type: DuplicateRemover + config: + similarity_threshold: 0.9 # Higher = only near-identical records are flagged (0.8-0.9 recommended) +``` + +**Add more dependencies:** +```json +{ + "dependencies": { + "role": ["plan"], + "storage": ["plan"] + } +} +``` + +## Filtering Duplicates + +Records marked as `is_duplicate: true` should be filtered post-generation: + +**Via API:** +```python +records = await storage.get_all(job_id=job_id) +unique_records = [r for r in records if not r.output.get("is_duplicate")] +``` + +**Via export (manual filter):** +```bash +# Export all records +curl "http://localhost:8000/api/export?job_id=1" > output.jsonl + +# Filter duplicates +jq 'select(.is_duplicate == false)' output.jsonl > unique.jsonl +``` + +> **Note:** Keeping duplicates in the trace allows adjusting the threshold post-generation and analyzing similarity score distribution. 
+ +## Tuning Parameters + +### Quality vs Speed + +**High quality (slower):** +```yaml +target_count: 100 +temperature: 0.9 +max_tokens: 300 +similarity_threshold: 0.9 +``` + +**Fast iteration (lower quality):** +```yaml +target_count: 20 +temperature: 0.7 +max_tokens: 150 +similarity_threshold: 0.75 +``` + +### Diversity vs Fidelity + +**Preserve distributions (higher fidelity):** +- Include all important `categorical_fields` +- Specify `dependencies` accurately +- Include `numeric_fields` with tight ranges + +**Increase diversity (creative generation):** +- Omit some `categorical_fields` (LLM generates freely) +- Higher temperature (0.8-0.9) +- Lower `similarity_threshold` (0.75-0.8) + +## Common Issues + +### Low diversity (many duplicates) + +**Causes:** +- Too few samples (<5) +- Temperature too low (<0.5) +- Fields too restrictive + +**Fixes:** +- Add more diverse samples +- Increase temperature to 0.8-0.9 +- Generate more semantic fields +- Increase similarity_threshold to 0.85-0.9 + +### Unrealistic outputs + +**Causes:** +- Dependencies not specified +- Numeric hints too broad +- Temperature too high (>0.95) + +**Fixes:** +- Add dependencies config +- Provide numeric_fields for constraints +- Reduce temperature to 0.7-0.8 +- Include exemplar samples matching target patterns + +### LLM errors (invalid JSON) + +**Causes:** +- max_tokens too low (truncated JSON) +- Complex nested structures + +**Fixes:** +- Increase max_tokens to 200-300 +- Simplify fields (fewer nested objects) +- SemanticInfiller handles markdown wrappers automatically + +### Missing embeddings + +**Cause:** Embedding model not configured + +**Behavior:** DuplicateRemover marks all as `is_duplicate: false` + +**Fix:** Configure default embedding model in Settings page + +## Example Workflow + +**Goal:** Generate 100 synthetic user profiles + +**Step 1: Prepare samples (6 examples)** +```json +[ + {"plan": "Free", "role": "Viewer", "storage": 1, "bio": "Student learning"}, + {"plan": "Free", 
"role": "Viewer", "storage": 2, "bio": "Just exploring"}, + {"plan": "Pro", "role": "Editor", "storage": 50, "bio": "Freelance designer"}, + {"plan": "Pro", "role": "Editor", "storage": 75, "bio": "Agency owner"}, + {"plan": "Pro", "role": "Admin", "storage": 100, "bio": "Team lead"}, + {"plan": "Enterprise", "role": "Admin", "storage": 500, "bio": "CTO"} +] +``` + +**Step 2: Create pipeline from template** +```bash +curl -X POST http://localhost:8000/api/pipelines/from_template/data_augmentation \ + -H "Content-Type: application/json" \ + -d '{"name": "User Profile Augmentation"}' +``` + +**Step 3: Start generation** +```bash +curl -X POST http://localhost:8000/api/generate \ + -F "file=@seed_data_augmentation.json" \ + -F "pipeline_id=1" +``` + +**Step 4: Monitor progress** +```bash +# Poll job status +curl http://localhost:8000/api/jobs/1 +``` + +**Step 5: Review and export** +```bash +# Export unique records only +curl http://localhost:8000/api/export?job_id=1 | jq 'select(.is_duplicate == false)' > unique_users.jsonl +``` + +**Result:** 100 synthetic user profiles preserving original distributions + +> **Tip:** For large datasets, start with 20 records to verify quality before scaling up. 
+ +## Related Documentation + +- [Templates Overview](templates) - All available templates +- [How to Use](how_to_use) - Running pipelines with templates +- [Custom Blocks](how_to_create_blocks) - Understanding multiplier blocks +- [StructureSampler Block](how_to_create_blocks#structuresampler) +- [SemanticInfiller Block](how_to_create_blocks#semanticinfiller) +- [DuplicateRemover Block](how_to_create_blocks#duplicateremover) From 36f1827d3b32eae935a99ad28a7256a80ee6e46b Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 21:28:49 +0100 Subject: [PATCH 06/19] wip: setup tests --- Makefile | 6 + pyproject.toml | 5 + scripts/with_server.py | 165 ++++++++++ tests/e2e/README.md | 317 ++++++++++++++++++++ tests/e2e/__init__.py | 0 tests/e2e/fixtures/classification_seed.json | 16 + tests/e2e/fixtures/qa_seed.json | 18 ++ tests/e2e/fixtures/sample_markdown.md | 13 + tests/e2e/fixtures/simple_seed.json | 20 ++ tests/e2e/run_all_tests.sh | 84 ++++++ tests/e2e/test_generator_e2e.py | 272 +++++++++++++++++ tests/e2e/test_helpers.py | 67 +++++ tests/e2e/test_pipelines_e2e.py | 246 +++++++++++++++ tests/e2e/test_review_e2e.py | 311 +++++++++++++++++++ 14 files changed, 1540 insertions(+) create mode 100644 scripts/with_server.py create mode 100644 tests/e2e/README.md create mode 100644 tests/e2e/__init__.py create mode 100644 tests/e2e/fixtures/classification_seed.json create mode 100644 tests/e2e/fixtures/qa_seed.json create mode 100644 tests/e2e/fixtures/sample_markdown.md create mode 100644 tests/e2e/fixtures/simple_seed.json create mode 100644 tests/e2e/run_all_tests.sh create mode 100644 tests/e2e/test_generator_e2e.py create mode 100644 tests/e2e/test_helpers.py create mode 100644 tests/e2e/test_pipelines_e2e.py create mode 100644 tests/e2e/test_review_e2e.py diff --git a/Makefile b/Makefile index 7a875a9..8233cf5 100644 --- a/Makefile +++ b/Makefile @@ -91,6 +91,12 @@ test: test-integration: uv run pytest -m integration -v +test-e2e: + 
./tests/e2e/run_all_tests.sh + +test-e2e-ui: + ./tests/e2e/run_all_tests.sh --ui + pre-merge: format-all lint-all typecheck-all test @echo "✅ Pre-merge checks completed successfully. Ready to merge!" diff --git a/pyproject.toml b/pyproject.toml index 700b299..3769038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,8 @@ asyncio_mode = "strict" markers = [ "integration: integration tests requiring external services (ollama, etc) - only run when explicitly called", ] + +[dependency-groups] +dev = [ + "playwright>=1.57.0", +] diff --git a/scripts/with_server.py b/scripts/with_server.py new file mode 100644 index 0000000..46c3c74 --- /dev/null +++ b/scripts/with_server.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Server lifecycle manager for e2e testing. +Starts servers, waits for readiness, runs tests, and cleans up. + +Usage: + python scripts/with_server.py --server "backend command" --port 8000 \\ + --server "frontend command" --port 5173 \\ + -- python test_script.py +""" + +import argparse +import subprocess +import sys +import time +import signal +import urllib.request +import urllib.error +from typing import List, Tuple +import os + + +class ServerManager: + def __init__(self, servers: List[Tuple[str, int]], max_wait: int = 60): + self.servers = servers + self.max_wait = max_wait + self.processes = [] + + def start_servers(self): + """start all servers""" + print("starting servers...") + for cmd, port in self.servers: + print(f" starting: {cmd} (port {port})") + # use shell=True to support commands with cd and && + proc = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid, # create process group for cleanup + ) + self.processes.append((proc, port)) + + def wait_for_ready(self): + """wait for all servers to be ready""" + print("waiting for servers to be ready...") + for proc, port in self.processes: + url = f"http://localhost:{port}" + if port == 8000: + url = f"{url}/health" # 
backend health endpoint + + start_time = time.time() + while time.time() - start_time < self.max_wait: + try: + with urllib.request.urlopen(url, timeout=2) as response: + if response.status == 200: + print(f" server on port {port} is ready") + break + except (urllib.error.URLError, TimeoutError): + time.sleep(1) + else: + print(f" timeout waiting for server on port {port}") + self.cleanup() + sys.exit(1) + + def cleanup(self): + """stop all servers""" + print("stopping servers...") + for proc, port in self.processes: + try: + # kill process group to clean up child processes + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + proc.wait(timeout=5) + print(f" stopped server on port {port}") + except Exception as e: + print(f" error stopping server on port {port}: {e}") + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except: + pass + + def run_command(self, command: List[str]) -> int: + """run test command and return exit code""" + print(f"running: {' '.join(command)}") + try: + result = subprocess.run(command) + return result.returncode + except KeyboardInterrupt: + print("\ninterrupted by user") + return 130 + + +def main(): + parser = argparse.ArgumentParser( + description="Start servers, run command, and cleanup", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # single server + python scripts/with_server.py --server "npm run dev" --port 5173 -- python test.py + + # multiple servers + python scripts/with_server.py \\ + --server "cd backend && python server.py" --port 3000 \\ + --server "cd frontend && npm run dev" --port 5173 \\ + -- python test.py + """, + ) + + parser.add_argument( + "--server", + action="append", + dest="servers", + help="server command (can be specified multiple times)", + ) + parser.add_argument( + "--port", + action="append", + dest="ports", + type=int, + help="server port (must match --server order)", + ) + parser.add_argument( + "--max-wait", + type=int, + default=60, + help="max seconds to wait 
for servers (default: 60)", + ) + parser.add_argument("command", nargs=argparse.REMAINDER, help="command to run") + + args = parser.parse_args() + + # validate arguments + if not args.servers or not args.ports: + parser.error("at least one --server and --port required") + + if len(args.servers) != len(args.ports): + parser.error("number of --server and --port must match") + + # strip leading '--' from command if present + command = args.command + if command and command[0] == "--": + command = command[1:] + + if not command: + parser.error("command to run is required after --") + + # create server list + servers = list(zip(args.servers, args.ports)) + + # run with server lifecycle management + manager = ServerManager(servers, max_wait=args.max_wait) + + try: + manager.start_servers() + manager.wait_for_ready() + exit_code = manager.run_command(command) + finally: + manager.cleanup() + + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/README.md b/tests/e2e/README.md new file mode 100644 index 0000000..81e0c7e --- /dev/null +++ b/tests/e2e/README.md @@ -0,0 +1,317 @@ +# DataGenFlow E2E Tests + +end-to-end tests for the DataGenFlow application using Playwright. + +## Overview + +these tests verify the full application stack (backend + frontend) by simulating real user interactions in a browser. they cover the main user workflows: + +- **pipelines**: create, edit, delete pipelines +- **generator**: upload seeds, start jobs, monitor progress +- **review**: view records, update status, export data + +## Setup + +### 1. Install dependencies + +```bash +# install dev dependencies (includes playwright) +uv sync --dev + +# install chromium browser for playwright +uv run playwright install chromium +``` + +### 2. 
Verify servers can start + +make sure both backend and frontend can start: + +```bash +# test backend (port 8000) +uv run uvicorn app:app --reload --host 0.0.0.0 --port 8000 + +# test frontend (port 5173, in another terminal) +cd frontend && yarn dev +``` + +## Running Tests + +### Quick start + +```bash +# using make (recommended) +make test-e2e # run all tests (headless mode) +make test-e2e-ui # run all tests with visible browser UI + +# or directly +./tests/e2e/run_all_tests.sh # headless mode +./tests/e2e/run_all_tests.sh --ui # visible browser UI +``` + +### Using the server helper (recommended) + +the `scripts/with_server.py` helper automatically manages server lifecycle: + +```bash +# run all e2e tests with server management (headless) +python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_pipelines_e2e.py + +# run with visible browser UI +E2E_HEADLESS=false python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_pipelines_e2e.py +``` + +### Run specific test suites + +```bash +# pipelines tests +python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_pipelines_e2e.py + +# generator tests +python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_generator_e2e.py + +# review tests +python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_review_e2e.py +``` + +### Manual testing (servers already running) + +if you 
already have servers running, you can run tests directly: + +```bash +# start servers in separate terminals first +# terminal 1 +make dev-backend + +# terminal 2 +make dev-ui + +# terminal 3 - run tests +python tests/e2e/test_pipelines_e2e.py +python tests/e2e/test_generator_e2e.py +python tests/e2e/test_review_e2e.py +``` + +## Test Structure + +``` +tests/e2e/ +├── README.md # this file +├── test_helpers.py # database cleanup utilities +├── fixtures/ # test data +│ ├── simple_seed.json # basic seed file +│ ├── qa_seed.json # qa generation seed +│ ├── classification_seed.json # classification seed +│ └── sample_markdown.md # markdown multiplier test +├── test_pipelines_e2e.py # pipeline workflows (with cleanup) +├── test_generator_e2e.py # generation workflows +└── test_review_e2e.py # review workflows +``` + +## Database Cleanup + +the **pipelines tests** automatically clean the database before and after running to ensure test isolation: + +- **before tests**: deletes all pipelines, jobs, and records +- **after tests**: cleans up any created data + +this ensures each test run starts with a clean state and doesn't interfere with your production data. + +## Test Coverage + +### test_pipelines_e2e.py +- ✓ pipelines page loads +- ✓ view templates +- ✓ create pipeline from template +- ✓ delete pipeline +- ✓ pipeline editor opens + +### test_generator_e2e.py +- ✓ generator page loads +- ✓ select pipeline +- ✓ upload seed file +- ✓ start generation job +- ✓ job progress monitoring + +### test_review_e2e.py +- ✓ review page loads +- ✓ select job +- ✓ view records +- ✓ update record status +- ✓ expand trace +- ✓ delete records +- ✓ export records + +## Debugging + +### Screenshots + +all tests save screenshots to `/tmp/` for debugging: +- `/tmp/pipelines_page.png` +- `/tmp/templates_view.png` +- `/tmp/pipeline_created.png` +- `/tmp/generator_page.png` +- `/tmp/job_started.png` +- etc. 
+ +### Browser visibility + +to see the browser during tests: + +```bash +# using run script +./tests/e2e/run_all_tests.sh --ui + +# using environment variable +E2E_HEADLESS=false python tests/e2e/test_pipelines_e2e.py + +# or export it for the session +export E2E_HEADLESS=false +python tests/e2e/test_pipelines_e2e.py +``` + +the tests will automatically detect the `E2E_HEADLESS` environment variable: +- `E2E_HEADLESS=false` → visible browser (chromium UI) +- `E2E_HEADLESS=true` or unset → headless mode (default) + +### Slow down execution + +add delays to observe actions: + +```python +import time +time.sleep(2) # wait 2 seconds +``` + +## Writing New Tests + +follow the webapp-testing skill patterns: + +1. **wait for networkidle** after page load: +```python +page.goto("http://localhost:5173") +page.wait_for_load_state("networkidle") +``` + +2. **use descriptive selectors**: +```python +# good - semantic selectors +page.get_by_role("button").filter(has_text="Create") +page.get_by_text("Pipeline", exact=False) + +# avoid - fragile css +page.locator("#btn-123") +``` + +3. **take screenshots** for debugging: +```python +page.screenshot(path="/tmp/debug.png", full_page=True) +``` + +4. 
**add appropriate waits**: +```python +time.sleep(1) # wait for animation +page.wait_for_selector(".record-card") # wait for element +``` + +## Fixtures + +test fixtures are in `fixtures/`: + +- `simple_seed.json`: basic text generation (2 variations) +- `qa_seed.json`: question-answer generation (5 total) +- `classification_seed.json`: text classification (2 samples) +- `sample_markdown.md`: markdown multiplier test + +use fixtures in tests: + +```python +seed_path = "tests/e2e/fixtures/simple_seed.json" +file_input.set_input_files(seed_path) +``` + +## Troubleshooting + +### servers don't start +- check ports 8000 and 5173 are not in use +- verify `uv` and `yarn` are installed +- check backend/frontend dependencies installed + +### tests fail with timeout +- increase `max_wait` in with_server.py +- add longer waits in tests +- check browser console for errors + +### elements not found +- take screenshots to see actual page state +- use browser devtools to find correct selectors +- add wait time for dynamic content + +### cleanup issues +- servers may not stop cleanly, use `killall uvicorn` or `killall node` +- remove test database: `rm data/qa_records.db` + +## CI/CD Integration + +example github actions workflow: + +```yaml +name: E2E Tests + +on: [push, pull_request] + +jobs: + e2e: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Install dependencies + run: | + uv venv && uv sync + cd frontend && yarn install + + - name: Install Playwright + run: | + uv pip install playwright + uv run playwright install chromium + + - name: Run E2E tests + run: | + python scripts/with_server.py \ + --server "uv run uvicorn app:app --host 0.0.0.0 --port 8000" --port 8000 \ + --server "cd frontend && yarn dev" --port 5173 \ + -- python tests/e2e/test_pipelines_e2e.py +``` + +## Best Practices + +1. **keep tests independent**: each test should work standalone +2. 
**clean up state**: delete created pipelines/jobs after tests +3. **use fixtures**: reuse seed files from `fixtures/` +4. **handle async**: wait for network requests to complete +5. **screenshot failures**: capture state when tests fail +6. **descriptive names**: test names should describe what they verify + +## Resources + +- [Playwright Documentation](https://playwright.dev/python/) +- [webapp-testing skill](/.claude/plugins/cache/anthropic-agent-skills/document-skills/.../webapp-testing/) +- [DataGenFlow API docs](/DEVELOPERS.md) diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/e2e/fixtures/classification_seed.json b/tests/e2e/fixtures/classification_seed.json new file mode 100644 index 0000000..70241dc --- /dev/null +++ b/tests/e2e/fixtures/classification_seed.json @@ -0,0 +1,16 @@ +[ + { + "repetitions": 1, + "metadata": { + "text": "This is a positive review of a great product.", + "categories": ["positive", "negative", "neutral"] + } + }, + { + "repetitions": 1, + "metadata": { + "text": "The service was terrible and disappointing.", + "categories": ["positive", "negative", "neutral"] + } + } +] diff --git a/tests/e2e/fixtures/qa_seed.json b/tests/e2e/fixtures/qa_seed.json new file mode 100644 index 0000000..8768303 --- /dev/null +++ b/tests/e2e/fixtures/qa_seed.json @@ -0,0 +1,18 @@ +[ + { + "repetitions": 3, + "metadata": { + "domain": "science", + "difficulty": "medium", + "question_type": "factual" + } + }, + { + "repetitions": 2, + "metadata": { + "domain": "history", + "difficulty": "easy", + "question_type": "conceptual" + } + } +] diff --git a/tests/e2e/fixtures/sample_markdown.md b/tests/e2e/fixtures/sample_markdown.md new file mode 100644 index 0000000..93f56d4 --- /dev/null +++ b/tests/e2e/fixtures/sample_markdown.md @@ -0,0 +1,13 @@ +# Sample Document + +## Section 1 +This is the first section of the document. +It contains some text that will be processed. 
+ +## Section 2 +This is the second section. +It has different content for variety. + +## Section 3 +Final section with concluding remarks. +Testing markdown multiplier functionality. diff --git a/tests/e2e/fixtures/simple_seed.json b/tests/e2e/fixtures/simple_seed.json new file mode 100644 index 0000000..2469fb4 --- /dev/null +++ b/tests/e2e/fixtures/simple_seed.json @@ -0,0 +1,20 @@ +[ + { + "repetitions": 2, + "metadata": { + "topic": "artificial intelligence", + "role": "teacher", + "system": "You are a {{ role }}.", + "user": "Explain {{ topic }} in simple terms." + } + }, + { + "repetitions": 1, + "metadata": { + "topic": "machine learning", + "role": "expert", + "system": "You are a {{ role }}.", + "user": "Describe {{ topic }} with examples." + } + } +] diff --git a/tests/e2e/run_all_tests.sh b/tests/e2e/run_all_tests.sh new file mode 100644 index 0000000..e783cc6 --- /dev/null +++ b/tests/e2e/run_all_tests.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# run all e2e tests with server management + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" + +# parse arguments +HEADLESS=true +while [[ $# -gt 0 ]]; do + case $1 in + --ui) + HEADLESS=false + shift + ;; + *) + echo "Unknown option: $1" + echo "Usage: $0 [--ui]" + echo " --ui Run tests with visible browser (chromium UI)" + exit 1 + ;; + esac +done + +# set headless mode +if [ "$HEADLESS" = "false" ]; then + export E2E_HEADLESS=false + echo "🖥️ Running tests with visible browser UI" +else + export E2E_HEADLESS=true + echo "🤖 Running tests in headless mode" +fi + +echo "🧪 Running DataGenFlow E2E Tests" +echo "================================" +echo "" + +# check if playwright is installed +if ! 
uv run python -c "import playwright" 2>/dev/null; then + echo "❌ Playwright not installed" + echo "Install with: uv pip install playwright && uv run playwright install chromium" + exit 1 +fi + +echo "✓ Playwright installed" +echo "" + +# define server commands +BACKEND_CMD="uv run uvicorn app:app --host 0.0.0.0 --port 8000" +FRONTEND_CMD="cd frontend && yarn dev" + +# run each test suite +echo "📋 Test Suite 1: Pipelines" +echo "-------------------------" +uv run python "$PROJECT_ROOT/scripts/with_server.py" \ + --server "$BACKEND_CMD" --port 8000 \ + --server "$FRONTEND_CMD" --port 5173 \ + -- uv run python "$SCRIPT_DIR/test_pipelines_e2e.py" +echo "" + +echo "📋 Test Suite 2: Generator" +echo "-------------------------" +uv run python "$PROJECT_ROOT/scripts/with_server.py" \ + --server "$BACKEND_CMD" --port 8000 \ + --server "$FRONTEND_CMD" --port 5173 \ + -- uv run python "$SCRIPT_DIR/test_generator_e2e.py" +echo "" + +echo "📋 Test Suite 3: Review" +echo "-------------------------" +uv run python "$PROJECT_ROOT/scripts/with_server.py" \ + --server "$BACKEND_CMD" --port 8000 \ + --server "$FRONTEND_CMD" --port 5173 \ + -- uv run python "$SCRIPT_DIR/test_review_e2e.py" +echo "" + +echo "✅ All E2E tests completed!" +echo "" +echo "📸 Screenshots saved to /tmp/" +echo " - /tmp/pipelines_page.png" +echo " - /tmp/generator_page.png" +echo " - /tmp/review_page.png" +echo " - ... and more" diff --git a/tests/e2e/test_generator_e2e.py b/tests/e2e/test_generator_e2e.py new file mode 100644 index 0000000..9bbca2e --- /dev/null +++ b/tests/e2e/test_generator_e2e.py @@ -0,0 +1,272 @@ +""" +e2e tests for generator page. +tests job creation, file upload, and progress monitoring workflows. 
+""" + +from playwright.sync_api import sync_playwright, expect +import time +import json +import os +from test_helpers import cleanup_database, wait_for_server, get_headless_mode + + +def test_generator_page_loads(): + """verify generator page loads successfully""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # navigate to generator page (default route) + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # verify we're on generator page by checking heading + heading = page.get_by_role("heading", name="Generate Records") + expect(heading).to_be_visible() + + # take screenshot + page.screenshot(path="/tmp/generator_page.png", full_page=True) + + browser.close() + + +def test_select_pipeline(): + """test selecting a pipeline from dropdown""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # go to generator page (default route) + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + time.sleep(1) + + # find pipeline selector (dropdown or select) + selectors = page.locator('select, [role="combobox"]').all() + + if len(selectors) > 0: + # click first selector + selectors[0].click() + time.sleep(0.5) + + # select first option (if it's a select element) + if selectors[0].evaluate("el => el.tagName") == "SELECT": + options = selectors[0].locator("option").all() + if len(options) > 1: # skip "select pipeline" placeholder + selectors[0].select_option(index=1) + else: + # for custom dropdowns, click first item + items = page.locator('[role="option"]').all() + if len(items) > 0: + items[0].click() + + time.sleep(1) + + # take screenshot + page.screenshot(path="/tmp/pipeline_selected.png", full_page=True) + + browser.close() + + +def test_upload_seed_file(): + """test uploading a seed JSON file""" + with sync_playwright() as p: + browser = 
p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # create test seed file + seed_data = [ + { + "repetitions": 1, + "metadata": { + "topic": "artificial intelligence", + "role": "teacher", + }, + } + ] + + seed_path = "/tmp/test_seed.json" + with open(seed_path, "w") as f: + json.dump(seed_data, f) + + # go to generator page + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + time.sleep(1) + + # select pipeline + selectors = page.locator("select").all() + if len(selectors) > 0: + selectors[0].select_option(index=1) + time.sleep(1) + + # find file input + file_inputs = page.locator('input[type="file"]').all() + + if len(file_inputs) > 0: + # upload file + file_inputs[0].set_input_files(seed_path) + time.sleep(1) + + # verify file name appears or upload succeeds + page.screenshot(path="/tmp/file_uploaded.png", full_page=True) + + # cleanup + os.remove(seed_path) + + browser.close() + + +def test_start_generation_job(): + """test starting a generation job""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # create test seed file + seed_data = [ + { + "repetitions": 1, + "metadata": { + "topic": "machine learning", + "role": "expert", + }, + } + ] + + seed_path = "/tmp/test_seed_job.json" + with open(seed_path, "w") as f: + json.dump(seed_data, f) + + # go to generator page + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + time.sleep(1) + + # select pipeline + selectors = page.locator("select").all() + if len(selectors) > 0: + selectors[0].select_option(index=1) + time.sleep(1) + + # upload file + file_inputs = page.locator('input[type="file"]').all() + if len(file_inputs) > 0: + file_inputs[0].set_input_files(seed_path) + time.sleep(1) + + # find and click generate/start button + generate_buttons = ( + page.get_by_role("button") + .filter(has_text="Generate") + 
.or_(page.get_by_role("button").filter(has_text="Start")) + ) + + if generate_buttons.count() > 0: + generate_buttons.first.click() + + # wait for job to start + time.sleep(3) + page.wait_for_load_state("networkidle") + + # verify job progress appears + # look for progress indicators + progress_elements = page.get_by_text("Progress", exact=False).or_( + page.get_by_text("Generated", exact=False) + ) + + # take screenshot + page.screenshot(path="/tmp/job_started.png", full_page=True) + + # cleanup + os.remove(seed_path) + + browser.close() + + +def test_job_progress_monitoring(): + """test that job progress updates are visible""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # look for job progress section (may or may not be running) + # check for progress bar, percentage, or status indicators + progress_indicators = ( + page.locator('[role="progressbar"]') + .or_(page.locator(".progress, .Progress")) + .or_(page.get_by_text("%", exact=False)) + ) + + # if a job is running, progress should be visible + # otherwise, the page should show upload/generate UI + page.screenshot(path="/tmp/job_progress.png", full_page=True) + + browser.close() + + +def setup_test_pipeline(): + """create a pipeline from template for tests""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # navigate to pipelines page + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(2) + + # create pipeline from first template + create_buttons = page.get_by_role("button").filter(has_text="Use Template") + if create_buttons.count() > 0: + create_buttons.first.click() + time.sleep(2) + 
page.wait_for_load_state("networkidle") + + browser.close() + + +if __name__ == "__main__": + print("running generator e2e tests...") + + # setup: create a pipeline for generator tests + print("\nsetup: creating test pipeline...") + wait_for_server() + cleanup_database() + setup_test_pipeline() + print("✓ test pipeline created") + + print("\ntest 1: generator page loads") + test_generator_page_loads() + print("✓ passed") + + print("\ntest 2: select pipeline") + test_select_pipeline() + print("✓ passed") + + print("\ntest 3: upload seed file") + test_upload_seed_file() + print("✓ passed") + + print("\ntest 4: start generation job") + test_start_generation_job() + print("✓ passed") + + print("\ntest 5: job progress monitoring") + test_job_progress_monitoring() + print("✓ passed") + + # cleanup after tests + print("\ncleaning up...") + cleanup_database() + print("✓ cleanup complete") + + print("\n✅ all generator e2e tests passed!") diff --git a/tests/e2e/test_helpers.py b/tests/e2e/test_helpers.py new file mode 100644 index 0000000..e3a6666 --- /dev/null +++ b/tests/e2e/test_helpers.py @@ -0,0 +1,67 @@ +""" +helper functions for e2e tests. +handles database cleanup and initialization. 
+""" + +import httpx +import time +import os + + +def get_headless_mode(): + """get headless mode from environment variable""" + return os.getenv("E2E_HEADLESS", "true").lower() in ("true", "1", "yes") + + +def cleanup_database(): + """delete all pipelines, jobs, and records from the database""" + base_url = "http://localhost:8000" + + try: + # delete all records + httpx.delete(f"{base_url}/api/records", timeout=10.0) + + # get all pipelines + response = httpx.get(f"{base_url}/api/pipelines", timeout=10.0) + if response.status_code == 200: + pipelines = response.json() + + # delete each pipeline + for pipeline in pipelines: + httpx.delete( + f"{base_url}/api/pipelines/{pipeline['id']}", + timeout=10.0 + ) + + time.sleep(0.5) # wait for cleanup to complete + + except Exception as e: + print(f"cleanup warning: {e}") + + +def wait_for_server(url: str = "http://localhost:8000/health", timeout: int = 30): + """wait for server to be ready""" + import urllib.request + import urllib.error + + start_time = time.time() + while time.time() - start_time < timeout: + try: + with urllib.request.urlopen(url, timeout=2) as response: + if response.status == 200: + return True + except (urllib.error.URLError, TimeoutError): + time.sleep(1) + + return False + + +def get_pipeline_count(): + """get number of pipelines in database""" + try: + response = httpx.get("http://localhost:8000/api/pipelines", timeout=10.0) + if response.status_code == 200: + return len(response.json()) + except: + pass + return -1 diff --git a/tests/e2e/test_pipelines_e2e.py b/tests/e2e/test_pipelines_e2e.py new file mode 100644 index 0000000..7b552f1 --- /dev/null +++ b/tests/e2e/test_pipelines_e2e.py @@ -0,0 +1,246 @@ +""" +e2e tests for pipelines page. +tests pipeline creation, editing, and deletion workflows. 
+""" + +from playwright.sync_api import sync_playwright, expect +import time +from test_helpers import cleanup_database, wait_for_server, get_headless_mode + + +def test_pipelines_page_loads(): + """verify pipelines page loads successfully""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # navigate to pipelines page via sidebar + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # click pipelines in sidebar + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(1) + + # verify page title + expect(page).to_have_title("DataGenFlow") + + # take screenshot for debugging + page.screenshot(path="/tmp/pipelines_page.png", full_page=True) + + browser.close() + + +def test_view_templates(): + """verify pipeline templates are displayed""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # click pipelines in sidebar + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(2) + + # check for template-related content or buttons + # look for "Use Template" buttons or template names + use_template_buttons = page.get_by_role("button").filter( + has_text="Use Template" + ).or_(page.get_by_role("button").filter(has_text="Create from Template")) + + # take screenshot first for debugging + page.screenshot(path="/tmp/templates_view.png", full_page=True) + + # if templates exist, there should be use template buttons + # otherwise page should at least load without error + print(f"found {use_template_buttons.count()} template buttons") + + browser.close() + + +def test_create_pipeline_from_template(): + """test creating a pipeline from a template""" + with 
sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # navigate to pipelines page + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(2) + + # find and click the first template's create button + # look for buttons with text "Use Template" or similar + create_buttons = page.get_by_role("button").filter( + has_text="Use Template" + ).or_(page.get_by_role("button").filter(has_text="Create")) + + if create_buttons.count() > 0: + first_button = create_buttons.first + first_button.click() + + # wait for pipeline to be created (modal or redirect) + time.sleep(2) + page.wait_for_load_state("networkidle") + + # verify success - check for "My Pipelines" heading using role + pipelines_heading = page.get_by_role("heading", name="My Pipelines") + expect(pipelines_heading).to_be_visible() + + # take screenshot + page.screenshot(path="/tmp/pipeline_created.png", full_page=True) + else: + print("no template buttons found, skipping test") + + browser.close() + + +def test_delete_pipeline(): + """test deleting a pipeline""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # navigate to pipelines page + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(2) + + # first create a pipeline from template + create_buttons = page.get_by_role("button").filter(has_text="Use Template") + if create_buttons.count() > 0: + create_buttons.first.click() + time.sleep(2) + page.wait_for_load_state("networkidle") + + # find delete button (trash icon or delete text) + # might be in a pipeline card or row + delete_buttons = ( + 
page.get_by_role("button") + .filter(has_text="Delete") + .or_(page.locator('button[aria-label*="Delete"]')) + .or_(page.locator('button:has(svg)')) + ) + + initial_count = delete_buttons.count() + + if initial_count > 0: + # click first delete button + delete_buttons.first.click() + + # handle confirmation dialog if present + time.sleep(0.5) + + # look for confirm button in dialog + confirm_buttons = ( + page.get_by_role("button") + .filter(has_text="Confirm") + .or_(page.get_by_role("button").filter(has_text="Delete")) + ) + + if confirm_buttons.count() > 0: + confirm_buttons.first.click() + + # wait for deletion + time.sleep(1) + page.wait_for_load_state("networkidle") + + # take screenshot + page.screenshot(path="/tmp/pipeline_deleted.png", full_page=True) + + browser.close() + + +def test_pipeline_editor_opens(): + """test that pipeline editor modal opens""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173") + page.wait_for_load_state("networkidle") + + # navigate to pipelines page + pipelines_link = page.get_by_text("Pipelines", exact=True) + pipelines_link.click() + page.wait_for_load_state("networkidle") + time.sleep(2) + + # create a pipeline first + create_buttons = page.get_by_role("button").filter(has_text="Use Template") + if create_buttons.count() > 0: + create_buttons.first.click() + time.sleep(2) + page.wait_for_load_state("networkidle") + + # find edit button (pencil icon, edit text, or gear icon) + edit_buttons = ( + page.get_by_role("button") + .filter(has_text="Edit") + .or_(page.locator('button[aria-label*="Edit"]')) + ) + + if edit_buttons.count() > 0: + edit_buttons.first.click() + time.sleep(1) + + # verify modal/editor opened (reactflow canvas should be visible) + # look for reactflow container or canvas elements + canvas = page.locator(".react-flow, [data-reactflow], canvas").first + expect(canvas).to_be_visible(timeout=5000) + + # 
take screenshot + page.screenshot(path="/tmp/pipeline_editor.png", full_page=True) + + browser.close() + + +if __name__ == "__main__": + print("running pipelines e2e tests...") + + # clean database before tests + print("\ncleaning database...") + wait_for_server() + cleanup_database() + print("✓ database cleaned") + + print("\ntest 1: pipelines page loads") + test_pipelines_page_loads() + print("✓ passed") + + print("\ntest 2: view templates") + test_view_templates() + print("✓ passed") + + print("\ntest 3: create pipeline from template") + test_create_pipeline_from_template() + print("✓ passed") + + print("\ntest 4: delete pipeline") + test_delete_pipeline() + print("✓ passed") + + print("\ntest 5: pipeline editor opens") + test_pipeline_editor_opens() + print("✓ passed") + + # clean database after tests + print("\ncleaning up...") + cleanup_database() + print("✓ cleanup complete") + + print("\n✅ all pipelines e2e tests passed!") diff --git a/tests/e2e/test_review_e2e.py b/tests/e2e/test_review_e2e.py new file mode 100644 index 0000000..31df757 --- /dev/null +++ b/tests/e2e/test_review_e2e.py @@ -0,0 +1,311 @@ +""" +e2e tests for review page. +tests record viewing, status updates, deletion, and export workflows. 
+""" + +from playwright.sync_api import sync_playwright, expect +import time +from test_helpers import get_headless_mode, cleanup_database, wait_for_server + + +def test_review_page_loads(): + """verify review page loads successfully""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + # navigate to review page + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + + # verify we're on review page + # look for job selector or records section + elements = page.get_by_text("Select Job", exact=False).or_( + page.get_by_text("Records", exact=False) + ) + + # take screenshot + page.screenshot(path="/tmp/review_page.png", full_page=True) + + browser.close() + + +def test_select_job(): + """test selecting a job from dropdown""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # find job selector (dropdown or select) + selectors = page.locator('select, [role="combobox"]').all() + + if len(selectors) > 0: + # click first selector + selectors[0].click() + time.sleep(0.5) + + # select first option (if options exist) + if selectors[0].evaluate("el => el.tagName") == "SELECT": + options = selectors[0].locator("option").all() + if len(options) > 1: # skip placeholder + selectors[0].select_option(index=1) + time.sleep(1) + + # take screenshot + page.screenshot(path="/tmp/job_selected.png", full_page=True) + + browser.close() + + +def test_view_records(): + """test viewing generated records""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # select a job if selector exists + selectors = page.locator("select").all() + if 
len(selectors) > 0: + options = selectors[0].locator("option").all() + if len(options) > 1: + selectors[0].select_option(index=1) + time.sleep(2) + + # look for record cards or table rows + records = ( + page.locator(".record-card, [data-record]") + .or_(page.locator(".Box")) + .or_(page.locator("tr")) + ).all() + + # if records exist, verify they're visible + if len(records) > 0: + print(f"found {len(records)} record elements") + + # take screenshot + page.screenshot(path="/tmp/records_view.png", full_page=True) + + browser.close() + + +def test_update_record_status(): + """test updating a record's status""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # select job + selectors = page.locator("select").all() + if len(selectors) > 0: + options = selectors[0].locator("option").all() + if len(options) > 1: + selectors[0].select_option(index=1) + time.sleep(2) + + # find status dropdown in record card + # might be labeled as "pending", "accepted", "rejected" + status_dropdowns = page.locator('select').filter( + has_text="pending" + ).or_(page.locator('[aria-label*="status"]')) + + if status_dropdowns.count() > 0: + # click first status dropdown + status_dropdowns.first.click() + time.sleep(0.5) + + # select "accepted" or another status + status_options = status_dropdowns.first.locator("option").all() + if len(status_options) > 1: + # try to select "accepted" + for option in status_options: + text = option.text_content().lower() + if "accept" in text: + option.click() + break + + time.sleep(1) + + # take screenshot + page.screenshot(path="/tmp/status_updated.png", full_page=True) + + browser.close() + + +def test_expand_trace(): + """test expanding a record's execution trace""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = 
browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # select job + selectors = page.locator("select").all() + if len(selectors) > 0: + options = selectors[0].locator("option").all() + if len(options) > 1: + selectors[0].select_option(index=1) + time.sleep(2) + + # find trace toggle button (collapsible) + # might say "Show trace", "View details", or have a chevron icon + trace_buttons = ( + page.get_by_role("button") + .filter(has_text="Trace") + .or_(page.get_by_role("button").filter(has_text="Details")) + .or_(page.locator('button[aria-expanded]')) + ) + + if trace_buttons.count() > 0: + # click to expand + trace_buttons.first.click() + time.sleep(1) + + # verify trace content is visible + # look for block type, execution time, or trace data + trace_content = page.get_by_text("block_type", exact=False).or_( + page.get_by_text("execution_time", exact=False) + ) + + # take screenshot + page.screenshot(path="/tmp/trace_expanded.png", full_page=True) + + browser.close() + + +def test_delete_records(): + """test deleting records""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # select job + selectors = page.locator("select").all() + if len(selectors) > 0: + options = selectors[0].locator("option").all() + if len(options) > 1: + selectors[0].select_option(index=1) + time.sleep(2) + + # find delete button (might say "Delete All" or have trash icon) + delete_buttons = ( + page.get_by_role("button") + .filter(has_text="Delete") + .or_(page.locator('button[aria-label*="Delete"]')) + ) + + if delete_buttons.count() > 0: + # click delete + delete_buttons.first.click() + time.sleep(0.5) + + # handle confirmation dialog + confirm_buttons = ( + page.get_by_role("button") + .filter(has_text="Confirm") + 
.or_(page.get_by_role("button").filter(has_text="Delete")) + ) + + if confirm_buttons.count() > 0: + confirm_buttons.first.click() + time.sleep(1) + + # take screenshot + page.screenshot(path="/tmp/records_deleted.png", full_page=True) + + browser.close() + + +def test_export_records(): + """test exporting records""" + with sync_playwright() as p: + browser = p.chromium.launch(headless=get_headless_mode()) + page = browser.new_page() + + page.goto("http://localhost:5173/review") + page.wait_for_load_state("networkidle") + time.sleep(2) + + # select job + selectors = page.locator("select").all() + if len(selectors) > 0: + options = selectors[0].locator("option").all() + if len(options) > 1: + selectors[0].select_option(index=1) + time.sleep(2) + + # find export button + export_buttons = ( + page.get_by_role("button") + .filter(has_text="Export") + .or_(page.get_by_role("button").filter(has_text="Download")) + ) + + if export_buttons.count() > 0: + # setup download listener + with page.expect_download(timeout=5000) as download_info: + export_buttons.first.click() + + # verify download started (might timeout if no records exist) + try: + download = download_info.value + print(f"download started: {download.suggested_filename}") + except: + print("no download (may be no records)") + + # take screenshot + page.screenshot(path="/tmp/records_export.png", full_page=True) + + browser.close() + + +if __name__ == "__main__": + print("running review e2e tests...") + + print("\ntest 1: review page loads") + test_review_page_loads() + print("✓ passed") + + print("\ntest 2: select job") + test_select_job() + print("✓ passed") + + print("\ntest 3: view records") + test_view_records() + print("✓ passed") + + print("\ntest 4: update record status") + test_update_record_status() + print("✓ passed") + + print("\ntest 5: expand trace") + test_expand_trace() + print("✓ passed") + + print("\ntest 6: delete records") + test_delete_records() + print("✓ passed") + + print("\ntest 7: export 
records") + test_export_records() + print("✓ passed") + + print("\n✅ all review e2e tests passed!") From 86f46e38e779f8d1c6f072f92c6b663e08dc66ca Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 21:45:55 +0100 Subject: [PATCH 07/19] add: coderabbit instructions --- .coderabbit.yaml | 135 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 .coderabbit.yaml diff --git a/.coderabbit.yaml b/.coderabbit.yaml new file mode 100644 index 0000000..41b7936 --- /dev/null +++ b/.coderabbit.yaml @@ -0,0 +1,135 @@ +language: en-US +early_access: false +enable_free_tier: true + +reviews: + request_changes_workflow: true + + high_level_summary: true + poem: false + review_status: true + collapse_walkthrough: false + + auto_review: + enabled: true + auto_incremental_review: true + ignore_title_keywords: [] + + # Path-based instructions + path_instructions: + # Backend code + - path: "**/*.py" + instructions: | + Apply backend code review checklist from llm/rules-backend.md: + Identify which llm/*.md files need updates: + - New API endpoints → update llm/state-backend.md + - New blocks → update llm/state-backend.md and llm/state-project.md + - Changed patterns → update relevant llm/state-*.md + Identify if the docs needs updates. + Golden rule: if code cannot be explained in one sentence, it's too complex. + + # Frontend code + - path: "frontend/**/*.{ts,tsx,js,jsx}" + instructions: | + Apply frontend code review checklist from llm/rules-frontend.md: + Identify which llm/*.md files need updates: + - New pages/components → update llm/state-frontend.md + - Changed UI flow → update llm/state-frontend.md + - New patterns → update llm/state-frontend.md + Identify if the docs needs updates. + Golden rule: keep components focused and maintainable. 
+ + # Block implementations + - path: "lib/blocks/**/*.py" + instructions: | + Apply block implementation checklist from .claude/skills/implementing-datagenflow-blocks/SKILL.md: + Identify which llm/*.md files need updates: + - New blocks → update llm/state-backend.md and llm/state-project.md + - Changed block behavior → update relevant llm/state-*.md + Identify if the docs needs updates. + Golden rule: blocks should be single-responsibility and reusable. + # Tests + - path: "tests/**/*.py" + instructions: | + Review test quality: + - One behavior per test + - Test names: test___ + - Error cases tested (not just happy path) + - Proper use of fixtures + - Mocks used appropriately + - Tests are focused and maintainable + + # Documentation files + - path: "llm/**/*.md" + instructions: | + Review documentation updates: + - Changes reflect actual code (not aspirational designs) + - Updates are gradual and incremental (not complete rewrites) + - Technical and concise + - Explain what changed and why + - Note any breaking changes + + # Configuration files + - path: "**/*.{yaml,yml,json,toml}" + instructions: | + Review configuration changes: + - No secrets committed + - Valid syntax + - Changes documented if needed + - Backwards compatible or migration documented + +chat: + auto_reply: true + +knowledge_base: + learnings: + scope: "auto" + + opt_out: false + +tone_instructions: | + Be direct, technical, and concise. Focus on: + 1. Blocking issues (anti-patterns, security, broken tests) - must fix + 2. Code quality violations - should fix + 3. Documentation updates needed - identify which llm/*.md files + 4. 
Improvements - nice to have + + Use this structure: + + ### Anti-patterns Found + [list blocking issues with file:line, violation, why, fix] + + ### Security Issues + [list security vulnerabilities] + + ### Documentation Updates Required + [list llm/*.md files needing updates with specific sections and reasons] + + ### Code Quality Issues + [list with severity: critical|high|medium|low] + + ### Testing Gaps + [list missing tests] + + ### Recommendations + [optional improvements] + + ### Summary + - anti-patterns: ✓ none | ✗ found (count) + - security: ✓ clean | ✗ issues (count) + - documentation: ✓ current | ⚠ updates needed + - testing: ✓ covered | ⚠ gaps exist + - code quality: ✓ good | ⚠ issues exist + + ### Verdict + [block | request changes | approve] + + Reason: [brief explanation] + + Golden rules: + 1. Anti-patterns are blocking - always reject + 2. Security issues are blocking - always reject + 3. Broken tests are blocking - always reject + 4. llm/* updates required for architecture changes + 5. Simplicity wins - if code is complex, it's wrong + 6. Fail loudly - silent failures are never acceptable From 7dda2123d51bec823a84e67d71362dc796a43220 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 21:48:11 +0100 Subject: [PATCH 08/19] add: coderabbit instructions --- .coderabbit.yaml | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 41b7936..39bdc0b 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -94,38 +94,6 @@ tone_instructions: | 3. Documentation updates needed - identify which llm/*.md files 4. 
Improvements - nice to have - Use this structure: - - ### Anti-patterns Found - [list blocking issues with file:line, violation, why, fix] - - ### Security Issues - [list security vulnerabilities] - - ### Documentation Updates Required - [list llm/*.md files needing updates with specific sections and reasons] - - ### Code Quality Issues - [list with severity: critical|high|medium|low] - - ### Testing Gaps - [list missing tests] - - ### Recommendations - [optional improvements] - - ### Summary - - anti-patterns: ✓ none | ✗ found (count) - - security: ✓ clean | ✗ issues (count) - - documentation: ✓ current | ⚠ updates needed - - testing: ✓ covered | ⚠ gaps exist - - code quality: ✓ good | ⚠ issues exist - - ### Verdict - [block | request changes | approve] - - Reason: [brief explanation] - Golden rules: 1. Anti-patterns are blocking - always reject 2. Security issues are blocking - always reject From b9b743acef2c0c98c526c27259e4b46b1471dffa Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 21:50:24 +0100 Subject: [PATCH 09/19] add: coderabbit instructions --- .coderabbit.yaml | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 39bdc0b..1051cbc 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -88,16 +88,8 @@ knowledge_base: opt_out: false tone_instructions: | - Be direct, technical, and concise. Focus on: + Be direct, technical, and concise: 1. Blocking issues (anti-patterns, security, broken tests) - must fix 2. Code quality violations - should fix 3. Documentation updates needed - identify which llm/*.md files 4. Improvements - nice to have - - Golden rules: - 1. Anti-patterns are blocking - always reject - 2. Security issues are blocking - always reject - 3. Broken tests are blocking - always reject - 4. llm/* updates required for architecture changes - 5. Simplicity wins - if code is complex, it's wrong - 6. 
Fail loudly - silent failures are never acceptable From 5aa20e7ec3f8264ff0a8aa93ac8d0d3cc4270615 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sat, 10 Jan 2026 21:51:59 +0100 Subject: [PATCH 10/19] add: coderabbit instructions --- .coderabbit.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 1051cbc..1c529bb 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -91,5 +91,5 @@ tone_instructions: | Be direct, technical, and concise: 1. Blocking issues (anti-patterns, security, broken tests) - must fix 2. Code quality violations - should fix - 3. Documentation updates needed - identify which llm/*.md files + 3. Documentation updates needed 4. Improvements - nice to have From e75bae9d7f66cfac9924191c4864862cddd81c2b Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sun, 11 Jan 2026 00:22:40 +0100 Subject: [PATCH 11/19] fix: review --- .github/pull_request_template.md | 2 +- docs/template_data_augmentation.md | 12 +-- .../pipeline-editor/BlockConfigPanel.tsx | 101 +++++++++++++++++- lib/blocks/builtin/duplicate_remover.py | 18 ++-- lib/blocks/builtin/semantic_infiller.py | 27 +++-- lib/blocks/builtin/structure_sampler.py | 32 ++---- llm/state-project.md | 10 +- tests/blocks/test_duplicate_remover.py | 12 +-- tests/blocks/test_semantic_infiller.py | 32 ++---- tests/blocks/test_structure_sampler.py | 52 ++++----- tests/integration/test_data_augmentation.py | 20 ++-- 11 files changed, 191 insertions(+), 127 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 2ed1de9..a7e9c01 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -23,4 +23,4 @@ Title Format: copy the one of the issue keep this format - [ ] `make format` passes - [ ] `make pre-merge` passes - [ ] PR update from develop branch -- [ ] Copilot review run and addressed +- [ ] Ask CodeRabbit review addressed (comment `@coderabbitai review`) diff --git 
a/docs/template_data_augmentation.md b/docs/template_data_augmentation.md index e12a053..58f78fc 100644 --- a/docs/template_data_augmentation.md +++ b/docs/template_data_augmentation.md @@ -38,7 +38,7 @@ This template creates realistic synthetic records from sample data while maintai ## Pipeline Architecture -``` +```text ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ Structure │──►│ Semantic │──►│ Duplicate │ │ Sampler │ │ Infiller │ │ Remover │ @@ -360,7 +360,7 @@ similarity_threshold: 0.75 **Goal:** Generate 100 synthetic user profiles -**Step 1: Prepare samples (6 examples)** +### Step 1: Prepare samples (6 examples) ```json [ {"plan": "Free", "role": "Viewer", "storage": 1, "bio": "Student learning"}, @@ -372,27 +372,27 @@ similarity_threshold: 0.75 ] ``` -**Step 2: Create pipeline from template** +### Step 2: Create pipeline from template ```bash curl -X POST http://localhost:8000/api/pipelines/from_template/data_augmentation \ -H "Content-Type: application/json" \ -d '{"name": "User Profile Augmentation"}' ``` -**Step 3: Start generation** +### Step 3: Start generation ```bash curl -X POST http://localhost:8000/api/generate \ -F "file=@seed_data_augmentation.json" \ -F "pipeline_id=1" ``` -**Step 4: Monitor progress** +### Step 4: Monitor progress ```bash # Poll job status curl http://localhost:8000/api/jobs/1 ``` -**Step 5: Review and export** +### Step 5: Review and export ```bash # Export unique records only curl http://localhost:8000/api/export?job_id=1 | jq 'select(.is_duplicate == false)' > unique_users.jsonl diff --git a/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx b/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx index 0e8b4b8..f1bb56d 100644 --- a/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx +++ b/frontend/src/components/pipeline-editor/BlockConfigPanel.tsx @@ -32,6 +32,7 @@ export default function BlockConfigPanel({ const [formData, setFormData] = useState>(config || {}); const { resolvedColorScheme } 
= useTheme(); const [wordWrap, setWordWrap] = useState(false); + const [jsonMode, setJsonMode] = useState>({}); const [errors, setErrors] = useState>({}); const [panelWidth, setPanelWidth] = useState(400); const [isResizing, setIsResizing] = useState(false); @@ -79,28 +80,38 @@ export default function BlockConfigPanel({ // fetch available LLM and embedding models useEffect(() => { + const controller = new AbortController(); + const { signal } = controller; + const fetchModels = async () => { try { const [llmResponse, embeddingResponse] = await Promise.all([ - fetch("/api/llm-models"), - fetch("/api/embedding-models"), + fetch("/api/llm-models", { signal }), + fetch("/api/embedding-models", { signal }), ]); if (llmResponse.ok) { const llmData = await llmResponse.json(); - setLlmModels(llmData.map((m: any) => m.name)); + if (Array.isArray(llmData)) { + setLlmModels(llmData.map((m: any) => m.name).filter(Boolean)); + } } if (embeddingResponse.ok) { const embeddingData = await embeddingResponse.json(); - setEmbeddingModels(embeddingData.map((m: any) => m.name)); + if (Array.isArray(embeddingData)) { + setEmbeddingModels(embeddingData.map((m: any) => m.name).filter(Boolean)); + } } } catch (error) { - console.error("Failed to fetch models:", error); + if ((error as any)?.name !== "AbortError") { + console.error("Failed to fetch models:", error); + } } }; fetchModels(); + return () => controller.abort(); }, []); // handle resize @@ -143,6 +154,12 @@ export default function BlockConfigPanel({ Object.entries(schema).forEach(([key, fieldSchema]: [string, any]) => { const value = processedData[key]; + + // skip json-or-template fields - they stay as strings + if (fieldSchema.format === "json-or-template") { + return; + } + if ( (fieldSchema.type === "array" || fieldSchema.type === "object") && typeof value === "string" @@ -373,6 +390,80 @@ export default function BlockConfigPanel({ ); } + // json-or-template field - use monaco editor with toggle + if (schema.format === 
"json-or-template") { + const isJsonMode = jsonMode[key] ?? true; // default to JSON mode + const jsonValue = typeof value === "string" ? value : JSON.stringify(value, null, 2); + + return ( + + + setJsonMode((prev) => ({ ...prev, [key]: e.target.checked }))} + id={`jsonmode-${key}`} + sx={{ m: 0 }} + /> + + JSON mode + + + {isJsonMode ? "(JSON syntax)" : "(Jinja2 template)"} + + + + { + // keep as string during editing, will be parsed on save if needed + handleChange(key, newValue || ""); + }} + theme={resolvedColorScheme === "dark" ? "vs-dark" : "light"} + options={{ + minimap: { enabled: false }, + scrollbar: { + vertical: "auto", + horizontal: "auto", + verticalScrollbarSize: 10, + horizontalScrollbarSize: 10, + }, + lineNumbers: "on", + lineNumbersMinChars: 3, + glyphMargin: false, + folding: true, + lineDecorationsWidth: 5, + scrollBeyondLastLine: false, + renderLineHighlight: "none", + overviewRulerLanes: 0, + hideCursorInOverviewRuler: true, + overviewRulerBorder: false, + wordWrap: wordWrap ? "on" : "off", + fontSize: 13, + fontFamily: + "ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace", + tabSize: 2, + padding: { top: 8, bottom: 8 }, + }} + /> + + + ); + } + // object or array field - use monaco editor with JSON if (schema.type === "object" || schema.type === "array") { const jsonValue = typeof value === "string" ? value : JSON.stringify(value, null, 2); diff --git a/lib/blocks/builtin/duplicate_remover.py b/lib/blocks/builtin/duplicate_remover.py index 3c8fb6e..19d7fa9 100644 --- a/lib/blocks/builtin/duplicate_remover.py +++ b/lib/blocks/builtin/duplicate_remover.py @@ -20,7 +20,9 @@ class DuplicateRemover(BaseBlock): _config_descriptions = { "similarity_threshold": "Similarity threshold (0.0-1.0). Above = duplicate.", "comparison_fields": "Fields to compare (leave empty to compare all text fields)", - "embedding_model": "Embedding model to use (leave empty for default). 
Skips check if no model configured.", + "embedding_model": ( + "Embedding model to use (leave empty for default). Skips check if no model configured." + ), } def __init__( @@ -99,9 +101,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: if trace_id not in self._embeddings_cache: logger.info(f"Building reference embeddings for {len(samples)} samples") - sample_texts = [ - self._extract_text(s, self.comparison_fields) for s in samples - ] + sample_texts = [self._extract_text(s, self.comparison_fields) for s in samples] # filter empty texts sample_texts = [t for t in sample_texts if t] @@ -120,9 +120,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: ) response = await litellm.aembedding(**embedding_params) - self._embeddings_cache[trace_id] = [ - item["embedding"] for item in response.data - ] + self._embeddings_cache[trace_id] = [item["embedding"] for item in response.data] logger.info( f"Initialized {len(self._embeddings_cache[trace_id])} reference embeddings " @@ -145,14 +143,14 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: if is_duplicate: logger.warning( - f"Duplicate detected: similarity={max_similarity:.4f} >= {self.similarity_threshold}" + f"Duplicate detected: similarity={max_similarity:.4f} >= " + f"{self.similarity_threshold}" ) except Exception as e: # no embedding model configured or error - skip check logger.warning( - f"Embedding check failed or no model configured: {e}. " - f"Skipping similarity check." + f"Embedding check failed or no model configured: {e}. Skipping similarity check." 
) is_duplicate = False max_similarity = 0.0 diff --git a/lib/blocks/builtin/semantic_infiller.py b/lib/blocks/builtin/semantic_infiller.py index 99e0c14..5583e22 100644 --- a/lib/blocks/builtin/semantic_infiller.py +++ b/lib/blocks/builtin/semantic_infiller.py @@ -49,7 +49,10 @@ def __init__( self.system_prompt = system_prompt def _build_generation_prompt( - self, skeleton: dict[str, Any], hints: dict[str, Any] + self, + fields_to_generate: list[str], + skeleton: dict[str, Any], + hints: dict[str, Any], ) -> str: """ construct LLM prompt with constraints and hints @@ -59,7 +62,7 @@ def _build_generation_prompt( - lock categorical constraints from skeleton - provide numeric hints and exemplars """ - fields_str = ", ".join(f'"{field}"' for field in self.fields_to_generate) + fields_str = ", ".join(f'"{field}"' for field in fields_to_generate) # extract constraints (non-hint fields) constraints = [] @@ -78,11 +81,7 @@ def _build_generation_prompt( hint_lines.append(" - Example records for reference:") for ex in value[: self.MAX_EXEMPLARS_IN_PROMPT]: # only show generated fields from exemplar - ex_fields = { - f: ex.get(f, "") - for f in self.fields_to_generate - if f in ex - } + ex_fields = {f: ex.get(f, "") for f in fields_to_generate if f in ex} hint_lines.append(f" {json.dumps(ex_fields)}") hints_str = "\n".join(hint_lines) if hint_lines else " (none)" @@ -163,14 +162,14 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: except json.JSONDecodeError as e: raise BlockExecutionError( f"fields_to_generate must be valid JSON: {str(e)}", - detail={"template": self.fields_to_generate_template, "rendered": fields_template_rendered}, + detail={ + "template": self.fields_to_generate_template, + "rendered": fields_template_rendered, + }, ) - # temporarily set for prompt building - self.fields_to_generate = fields_to_generate - # build generation prompt - prompt = self._build_generation_prompt(skeleton, hints) + prompt = 
self._build_generation_prompt(fields_to_generate, skeleton, hints) # prepare system prompt system_content = ( @@ -199,9 +198,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: "tags": ["datagenflow", "semantic-infiller"], } - logger.info( - f"Generating fields {self.fields_to_generate} with model={llm_params.get('model')}" - ) + logger.info(f"Generating fields {fields_to_generate} with model={llm_params.get('model')}") try: response = await litellm.acompletion(**llm_params) diff --git a/lib/blocks/builtin/structure_sampler.py b/lib/blocks/builtin/structure_sampler.py index ac0cada..8a353e4 100644 --- a/lib/blocks/builtin/structure_sampler.py +++ b/lib/blocks/builtin/structure_sampler.py @@ -33,18 +33,16 @@ def __init__( self, target_count: int, categorical_fields: list[str], - numeric_fields: list[str] = [], - dependencies: dict[str, list[str]] = {}, + numeric_fields: list[str] | None = None, + dependencies: dict[str, list[str]] | None = None, seed: int | None = None, ): self.target_count = target_count self.categorical_fields = categorical_fields - self.numeric_fields = numeric_fields - self.dependencies = dependencies + self.numeric_fields = numeric_fields or [] + self.dependencies = dependencies or {} self.seed = seed - - if seed is not None: - random.seed(seed) + self._rng = random.Random(seed) def _validate_samples(self, samples: list[dict[str, Any]]) -> None: """validate samples meet minimum requirements""" @@ -118,9 +116,7 @@ def _compute_numeric_statistics( try: numeric_values.append(float(v)) except (ValueError, TypeError): - logger.warning( - f"Non-numeric value {v} in numeric field {field}, skipping" - ) + logger.warning(f"Non-numeric value {v} in numeric field {field}, skipping") if numeric_values: numeric_stats[field] = { @@ -137,7 +133,7 @@ def _select_exemplars( if max_count is None: max_count = self.MAX_EXEMPLARS num_exemplars = min(max_count, len(samples)) - return random.sample(samples, num_exemplars) + return 
self._rng.sample(samples, num_exemplars) def _analyze_samples(self, samples: list[dict[str, Any]]) -> dict[str, Any]: """ @@ -202,7 +198,7 @@ def _sample_from_distribution(self, probs: dict[str, float]) -> Any: values = list(probs.keys()) weights = list(probs.values()) - return random.choices(values, weights=weights, k=1)[0] + return self._rng.choices(values, weights=weights, k=1)[0] def _sample_categorical_field( self, field: str, skeleton: dict[str, Any], profile: dict[str, Any] @@ -219,9 +215,7 @@ def _sample_categorical_field( probs = profile["conditional_probs"][key] else: # fallback to marginal distribution - logger.warning( - f"Unseen combination {key}, using marginal distribution for {field}" - ) + logger.warning(f"Unseen combination {key}, using marginal distribution for {field}") probs = profile["categorical_probs"].get(field, {}) else: # independent sampling @@ -229,9 +223,7 @@ def _sample_categorical_field( return self._sample_from_distribution(probs) - def _generate_hints( - self, skeleton: dict[str, Any], profile: dict[str, Any] - ) -> dict[str, Any]: + def _generate_hints(self, skeleton: dict[str, Any], profile: dict[str, Any]) -> dict[str, Any]: """generate hints for numeric fields and matching exemplars""" hints: dict[str, Any] = {} @@ -255,9 +247,7 @@ def _generate_hints( hints["exemplars"] = matching_exemplars return hints - def _generate_skeletons( - self, profile: dict[str, Any], count: int - ) -> list[dict[str, Any]]: + def _generate_skeletons(self, profile: dict[str, Any], count: int) -> list[dict[str, Any]]: """ generate N skeleton records by sampling from learned distributions diff --git a/llm/state-project.md b/llm/state-project.md index f4462d1..ab608d6 100644 --- a/llm/state-project.md +++ b/llm/state-project.md @@ -28,7 +28,7 @@ tools: uv (python), yarn (js) ``` lib/ blocks/ - builtin/ # 12 blocks (generators, multiplier, validators, metrics, seeders, observability) + builtin/ # 14 blocks (generators, multiplier, validators, metrics, 
seeders, observability, utilities) custom/ # experimental base.py # BaseBlock interface config.py # schema extraction @@ -98,7 +98,7 @@ class BaseBlock: pass ``` -### builtin blocks (12 total) +### builtin blocks (14 total) **seeders:** - StructureSampler: statistical sampler (target_count, categorical_fields, numeric_fields, dependencies, seed) → * (skeletons + hints) @@ -120,6 +120,10 @@ class BaseBlock: - DiversityScore: lexical diversity (field_name) → diversity_score - CoherenceScore: text coherence (field_name) → coherence_score - RougeScore: rouge comparison (generated_field, reference_field, rouge_type) → rouge_score +- RagasMetrics: evaluate QA using RAGAS metrics (question_field, answer_field, etc.) → ragas_scores + +**utilities:** +- FieldMapper: create fields from Jinja2 expressions (mappings) → * (dynamic based on mappings) **observability:** - LangfuseBlock: logging (public_key, secret_key, host, session_id) → langfuse_trace_url @@ -370,7 +374,7 @@ blocks/, integration/, test_api.py, test_workflow.py, test_storage.py, test_cons production-ready full-stack data generation platform ### features -- 12 blocks (seeders, generators, multiplier, validators, metrics, observability) +- 14 blocks (seeders, generators, multiplier, validators, metrics, observability, utilities) - auto-discovery from builtin/custom/user_blocks - reactflow visual editor with drag-drop - jinja2 templates + 4 yaml templates diff --git a/tests/blocks/test_duplicate_remover.py b/tests/blocks/test_duplicate_remover.py index ce7d140..f316582 100644 --- a/tests/blocks/test_duplicate_remover.py +++ b/tests/blocks/test_duplicate_remover.py @@ -118,9 +118,7 @@ class TestDuplicateRemoverWithEmbeddings: @pytest.mark.asyncio @patch("litellm.aembedding") @patch("app.llm_config_manager") - async def test_duplicate_detection_below_threshold( - self, mock_config_manager, mock_embedding - ): + async def test_duplicate_detection_below_threshold(self, mock_config_manager, mock_embedding): # setup 
mocks mock_config_manager.get_embedding_model = AsyncMock( return_value={"model": "text-embedding-ada-002"} @@ -155,9 +153,7 @@ async def test_duplicate_detection_below_threshold( @pytest.mark.asyncio @patch("litellm.aembedding") @patch("app.llm_config_manager") - async def test_duplicate_detection_above_threshold( - self, mock_config_manager, mock_embedding - ): + async def test_duplicate_detection_above_threshold(self, mock_config_manager, mock_embedding): # setup mocks mock_config_manager.get_embedding_model = AsyncMock( return_value={"model": "text-embedding-ada-002"} @@ -192,9 +188,7 @@ async def test_duplicate_detection_above_threshold( @pytest.mark.asyncio @patch("litellm.aembedding") @patch("app.llm_config_manager") - async def test_embedding_cache_by_trace_id( - self, mock_config_manager, mock_embedding - ): + async def test_embedding_cache_by_trace_id(self, mock_config_manager, mock_embedding): """test that embeddings are cached per trace_id""" mock_config_manager.get_embedding_model = AsyncMock( return_value={"model": "text-embedding-ada-002"} diff --git a/tests/blocks/test_semantic_infiller.py b/tests/blocks/test_semantic_infiller.py index d715933..c6f7c83 100644 --- a/tests/blocks/test_semantic_infiller.py +++ b/tests/blocks/test_semantic_infiller.py @@ -91,7 +91,7 @@ def test_build_prompt_with_exemplars(self): class TestSemanticInfillerJSONParsing: def test_parse_valid_json(self): - block = SemanticInfiller(fields_to_generate=["bio"]) + block = SemanticInfiller(fields_to_generate='["bio"]') content = '{"bio": "Test bio"}' result = block._parse_json_safely(content) @@ -99,7 +99,7 @@ def test_parse_valid_json(self): assert result == {"bio": "Test bio"} def test_parse_json_with_markdown(self): - block = SemanticInfiller(fields_to_generate=["bio"]) + block = SemanticInfiller(fields_to_generate='["bio"]') content = '```json\n{"bio": "Test bio"}\n```' result = block._parse_json_safely(content) @@ -107,7 +107,7 @@ def test_parse_json_with_markdown(self): 
assert result == {"bio": "Test bio"} def test_parse_json_embedded_in_text(self): - block = SemanticInfiller(fields_to_generate=["bio"]) + block = SemanticInfiller(fields_to_generate='["bio"]') content = 'Here is the result: {"bio": "Test bio"} done' result = block._parse_json_safely(content) @@ -115,7 +115,7 @@ def test_parse_json_embedded_in_text(self): assert result == {"bio": "Test bio"} def test_parse_invalid_json_raises_error(self): - block = SemanticInfiller(fields_to_generate=["bio"]) + block = SemanticInfiller(fields_to_generate='["bio"]') content = "not json at all" @@ -141,9 +141,7 @@ async def test_execute_basic(self, mock_config_manager, mock_completion): return_value={"model": "gpt-4", "messages": []} ) mock_completion.return_value = MagicMock( - choices=[ - MagicMock(message=MagicMock(content='{"bio": "Generated bio"}')) - ], + choices=[MagicMock(message=MagicMock(content='{"bio": "Generated bio"}'))], usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), ) @@ -181,10 +179,7 @@ async def test_execute_with_hints(self, mock_config_manager, mock_completion): ) block = SemanticInfiller(fields_to_generate='["bio", "storage"]') - context = make_context({ - "plan": "Pro", - "_hints": {"storage_range": [10, 100]} - }) + context = make_context({"plan": "Pro", "_hints": {"storage_range": [10, 100]}}) result = await block.execute(context) @@ -211,11 +206,7 @@ async def test_execute_restores_locked_fields(self, mock_config_manager, mock_co ) mock_completion.return_value = MagicMock( choices=[ - MagicMock( - message=MagicMock( - content='{"plan": "Modified", "bio": "Generated bio"}' - ) - ) + MagicMock(message=MagicMock(content='{"plan": "Modified", "bio": "Generated bio"}')) ], usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), ) @@ -269,19 +260,14 @@ async def test_execute_with_template(self, mock_config_manager, mock_completion) return_value={"model": "gpt-4", "messages": []} ) 
mock_completion.return_value = MagicMock( - choices=[ - MagicMock(message=MagicMock(content='{"bio": "Generated bio"}')) - ], + choices=[MagicMock(message=MagicMock(content='{"bio": "Generated bio"}'))], usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), ) # Use tojson filter to properly serialize the list as JSON block = SemanticInfiller(fields_to_generate="{{ fields_to_generate | tojson }}") # Provide fields_to_generate in the accumulated state (from metadata) - context = make_context({ - "plan": "Free", - "fields_to_generate": ["bio"] - }) + context = make_context({"plan": "Free", "fields_to_generate": ["bio"]}) result = await block.execute(context) diff --git a/tests/blocks/test_structure_sampler.py b/tests/blocks/test_structure_sampler.py index 65eedd1..01540dd 100644 --- a/tests/blocks/test_structure_sampler.py +++ b/tests/blocks/test_structure_sampler.py @@ -117,13 +117,15 @@ async def test_generate_skeletons_basic(self): seed=42, ) - context = make_context({ - "samples": [ - {"plan": "Free"}, - {"plan": "Free"}, - {"plan": "Pro"}, - ] - }) + context = make_context( + { + "samples": [ + {"plan": "Free"}, + {"plan": "Free"}, + {"plan": "Pro"}, + ] + } + ) results = await block.execute(context) @@ -143,13 +145,15 @@ async def test_generate_skeletons_with_dependencies(self): seed=42, ) - context = make_context({ - "samples": [ - {"plan": "Free", "role": "Viewer"}, - {"plan": "Free", "role": "Viewer"}, - {"plan": "Pro", "role": "Editor"}, - ] - }) + context = make_context( + { + "samples": [ + {"plan": "Free", "role": "Viewer"}, + {"plan": "Free", "role": "Viewer"}, + {"plan": "Pro", "role": "Editor"}, + ] + } + ) results = await block.execute(context) @@ -167,13 +171,15 @@ async def test_generate_skeletons_with_hints(self): seed=42, ) - context = make_context({ - "samples": [ - {"plan": "Free", "storage": 1}, - {"plan": "Free", "storage": 2}, - {"plan": "Pro", "storage": 50}, - ] - }) + context = make_context( + { + "samples": 
[ + {"plan": "Free", "storage": 1}, + {"plan": "Free", "storage": 2}, + {"plan": "Pro", "storage": 50}, + ] + } + ) results = await block.execute(context) @@ -219,9 +225,7 @@ async def test_circular_dependency_detection(self): dependencies={"a": ["b"], "b": ["a"]}, ) - context = make_context({ - "samples": [{"a": "1", "b": "2"}] - }) + context = make_context({"samples": [{"a": "1", "b": "2"}]}) with pytest.raises(ValidationError, match="Circular dependency"): await block.execute(context) diff --git a/tests/integration/test_data_augmentation.py b/tests/integration/test_data_augmentation.py index 2112082..1a7ab13 100644 --- a/tests/integration/test_data_augmentation.py +++ b/tests/integration/test_data_augmentation.py @@ -1,4 +1,5 @@ """integration test for data augmentation pipeline""" + import json from unittest.mock import AsyncMock, MagicMock, patch @@ -30,11 +31,7 @@ async def test_data_augmentation_pipeline(mock_config_manager, mock_completion, # mock LLM response with realistic generated fields mock_completion.return_value = MagicMock( choices=[ - MagicMock( - message=MagicMock( - content='{"bio": "Generated bio text", "storage": 10}' - ) - ) + MagicMock(message=MagicMock(content='{"bio": "Generated bio text", "storage": 10}')) ], usage=MagicMock(prompt_tokens=100, completion_tokens=50, cache_read_input_tokens=0), ) @@ -80,6 +77,7 @@ async def test_data_augmentation_pipeline(mock_config_manager, mock_completion, # save pipeline to database pipeline_id = await storage.save_pipeline("test_augmentation", json.dumps(pipeline_def)) + assert pipeline_id > 0 # create pipeline instance pipeline = Pipeline("test_augmentation", pipeline_def["blocks"]) @@ -142,7 +140,9 @@ async def test_data_augmentation_pipeline(mock_config_manager, mock_completion, assert result["plan"] in ["Free", "Pro"], f"Invalid plan: {result['plan']}" # check role values are valid - assert result["role"] in ["Viewer", "Editor", "Admin"], f"Invalid role: {result['role']}" + assert result["role"] in 
["Viewer", "Editor", "Admin"], ( + f"Invalid role: {result['role']}" + ) # check dependencies: Free -> Viewer if result["plan"] == "Free": @@ -166,7 +166,7 @@ async def test_data_augmentation_pipeline(mock_config_manager, mock_completion, # print sample result for inspection sample = results[0].result - print(f"\nSample result:") + print("\nSample result:") print(f" plan: {sample['plan']}") print(f" role: {sample['role']}") print(f" storage: {sample['storage']}") @@ -203,6 +203,7 @@ async def test_structure_sampler_alone(tmp_path): } pipeline_id = await storage.save_pipeline("test_sampler", json.dumps(pipeline_def)) + assert pipeline_id > 0 pipeline = Pipeline("test_sampler", pipeline_def["blocks"]) initial_data = { @@ -265,9 +266,8 @@ async def test_data_augmentation_with_no_embedding_model(tmp_path): ] } - pipeline_id = await storage.save_pipeline( - "test_no_embedding", json.dumps(pipeline_def) - ) + pipeline_id = await storage.save_pipeline("test_no_embedding", json.dumps(pipeline_def)) + assert pipeline_id > 0 pipeline = Pipeline("test_no_embedding", pipeline_def["blocks"]) initial_data = {"samples": [{"plan": "Free"}]} From bbfea64be0cfe1b5cc1c7f6f11ae7322bff32a88 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sun, 11 Jan 2026 00:29:29 +0100 Subject: [PATCH 12/19] wip: fixing skill + add type json-or-template --- .claude/skills/address-pr-review/SKILL.md | 99 +++++++++++++ .../implementing-datagenflow-blocks/SKILL.md | 7 +- lib/blocks/builtin/duplicate_remover.py | 48 ++++++- lib/blocks/builtin/field_mapper.py | 46 ++++++- lib/blocks/builtin/json_validator.py | 51 ++++++- lib/blocks/builtin/ragas_metrics.py | 51 +++++-- lib/blocks/builtin/structure_sampler.py | 130 ++++++++++++++++-- lib/blocks/builtin/structured_generator.py | 40 +++++- lib/blocks/builtin/validator.py | 48 ++++++- lib/blocks/config.py | 6 +- 10 files changed, 472 insertions(+), 54 deletions(-) create mode 100644 .claude/skills/address-pr-review/SKILL.md diff --git 
a/.claude/skills/address-pr-review/SKILL.md b/.claude/skills/address-pr-review/SKILL.md new file mode 100644 index 0000000..53f0f11 --- /dev/null +++ b/.claude/skills/address-pr-review/SKILL.md @@ -0,0 +1,99 @@ +--- +name: address-pr-review +description: Use when you have PR review comments to address and want to evaluate each comment's validity before deciding to fix, reply, or skip +--- + +# Address PR Review Comments + +## Overview + +Interactive workflow: analyze PR review comment validity, recommend action, let user decide (fix/reply/skip). + +## When to Use + +- PR has review comments needing evaluation before action +- Reviewer feedback might be incorrect or needs discussion +- Comments require varied responses (fix/reply/skip) +- Need to balance code quality with respectful reviewer engagement + +## When NOT to Use + +- All comments are clearly valid and straightforward to fix +- No comments yet or doing pre-review self-review +- Comments only on non-code files without technical analysis needed + +## Workflow Overview + +```dot +digraph pr_review_flow { + "Fetch PR comments" [shape=box]; + "More comments?" [shape=diamond]; + "Show comment + file context" [shape=box]; + "Analyze validity" [shape=box]; + "Recommend action" [shape=box]; + "Ask user: Fix/Reply/Skip/Quit?" [shape=diamond]; + "Make code changes" [shape=box]; + "Draft reply" [shape=box]; + "Track as skipped" [shape=box]; + "Show summary" [shape=box]; + + "Fetch PR comments" -> "More comments?"; + "More comments?" -> "Show comment + file context" [label="yes"]; + "More comments?" -> "Show summary" [label="no"]; + "Show comment + file context" -> "Analyze validity"; + "Analyze validity" -> "Recommend action"; + "Recommend action" -> "Ask user: Fix/Reply/Skip/Quit?"; + "Ask user: Fix/Reply/Skip/Quit?" -> "Make code changes" [label="Fix"]; + "Ask user: Fix/Reply/Skip/Quit?" -> "Draft reply" [label="Reply"]; + "Ask user: Fix/Reply/Skip/Quit?" 
-> "Track as skipped" [label="Skip"]; + "Ask user: Fix/Reply/Skip/Quit?" -> "Show summary" [label="Quit"]; + "Make code changes" -> "More comments?"; + "Draft reply" -> "More comments?"; + "Track as skipped" -> "More comments?"; +} +``` + +## Quick Reference + +**Critical principle:** Reviewer may be wrong - analyze validity before recommending action. + +| Phase | Actions | +|-------|---------| +| **Fetch** | `gh api repos/{owner}/{repo}/pulls/$PR/comments`
Extract: path, line, body, user.login, id
Exit if no comments | +| **Per Comment** | Show: file:line, author, comment, ±10 lines context
Analyze: Valid/Nitpick/Disagree/Question
Recommend: Fix/Reply/Skip with reasoning | +| **Fix** | Minimal changes per llm/rules-*.md
Offer reply draft: `Fixed: [what]. [why]`
Show: `gh api --method POST repos/{owner}/{repo}/pulls/comments/$ID/replies -f body="..."` | +| **Reply** | Draft based on type: Question/Suggestion/Disagreement
Let user edit
Show gh command (never auto-post) | +| **Summary** | Processed X/N: Fixed Y, Replied Z, Skipped W
List: files modified, reply drafts, next steps | + +## Critical Principles + +| Principle | Violation Pattern | +|-----------|-------------------| +| **Analyze first** | Accepting all feedback as valid without critical analysis | +| **Never auto-post** | Posting replies automatically instead of showing gh command | +| **One at a time** | Batch processing all comments without individual analysis | +| **Show context** | Making changes without displaying ±10 lines around code | +| **Minimal changes** | Large refactors in response to small comments | +| **Follow standards** | Ignoring llm/rules-*.md when fixing | +| **Respectful honesty** | Being defensive/dismissive when reviewer is wrong | +| **User control** | Posting drafts without letting user edit first | + +## Reply Formats + +- Fix: `Fixed: [what]. [why]` +- Update: `Updated: [what]` +- Answer: `[explanation]` +- Acknowledge: `Good catch, [action/reason]` +- Disagree: `[respectful reasoning]` + +## Setup & Usage + +Requires: `gh` CLI authenticated, GitHub remote configured + +```bash +# Start session +"use address-pr-review for PR " + +# Or list PRs first +"use address-pr-review" +``` diff --git a/.claude/skills/implementing-datagenflow-blocks/SKILL.md b/.claude/skills/implementing-datagenflow-blocks/SKILL.md index 9d4e730..f9cbaca 100644 --- a/.claude/skills/implementing-datagenflow-blocks/SKILL.md +++ b/.claude/skills/implementing-datagenflow-blocks/SKILL.md @@ -413,16 +413,17 @@ Blocks that generate multiple items from one input: ```python from lib.blocks.base import BaseMultiplierBlock +from lib.entities.block_execution_context import BlockExecutionContext class StructureSampler(BaseMultiplierBlock): name = "Structure Sampler" - category = "generators" + category = "seeders" async def execute( self, - initial_data: dict[str, Any] + context: BlockExecutionContext ) -> list[dict[str, Any]]: - # return list of records + # read from context and return list of records return [record1, record2, record3] ``` diff 
--git a/lib/blocks/builtin/duplicate_remover.py b/lib/blocks/builtin/duplicate_remover.py index 19d7fa9..0b935e2 100644 --- a/lib/blocks/builtin/duplicate_remover.py +++ b/lib/blocks/builtin/duplicate_remover.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any @@ -6,6 +7,8 @@ from lib.blocks.base import BaseBlock from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError +from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -19,20 +22,27 @@ class DuplicateRemover(BaseBlock): _config_descriptions = { "similarity_threshold": "Similarity threshold (0.0-1.0). Above = duplicate.", - "comparison_fields": "Fields to compare (leave empty to compare all text fields)", + "comparison_fields": ( + 'JSON array or Jinja template. Examples: ["name", "bio"] or ' + '{{ comparison_fields | tojson }} (leave empty to compare all text fields)' + ), "embedding_model": ( "Embedding model to use (leave empty for default). Skips check if no model configured." 
), } + _config_formats = { + "comparison_fields": "json-or-template", + } + def __init__( self, similarity_threshold: float = 0.85, - comparison_fields: list[str] | None = None, + comparison_fields: str = "", embedding_model: str | None = None, ): self.similarity_threshold = similarity_threshold - self.comparison_fields = comparison_fields + self.comparison_fields_template = comparison_fields self.embedding_model_name = embedding_model # cache reference embeddings per trace_id (one cache per pipeline execution) @@ -66,6 +76,34 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: current_record.pop("_usage", None) # remove internal fields current_record.pop("_hints", None) + # parse comparison_fields from template + comparison_fields: list[str] | None = None + if self.comparison_fields_template: + fields_rendered = render_template( + self.comparison_fields_template, context.accumulated_state + ) + try: + fields_list = json.loads(fields_rendered) + if not isinstance(fields_list, list): + raise BlockExecutionError( + "comparison_fields must be a JSON array", + detail={"rendered_value": fields_rendered}, + ) + if not all(isinstance(f, str) for f in fields_list): + raise BlockExecutionError( + "All items in comparison_fields must be strings", + detail={"comparison_fields": fields_list}, + ) + comparison_fields = fields_list + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"comparison_fields must be valid JSON: {str(e)}", + detail={ + "template": self.comparison_fields_template, + "rendered": fields_rendered, + }, + ) + # get reference samples from initial state samples = context.get_state("samples", []) @@ -78,7 +116,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: } # extract text for comparison - current_text = self._extract_text(current_record, self.comparison_fields) + current_text = self._extract_text(current_record, comparison_fields) if not current_text: logger.warning("No text found in 
record for comparison, skipping check") @@ -101,7 +139,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: if trace_id not in self._embeddings_cache: logger.info(f"Building reference embeddings for {len(samples)} samples") - sample_texts = [self._extract_text(s, self.comparison_fields) for s in samples] + sample_texts = [self._extract_text(s, comparison_fields) for s in samples] # filter empty texts sample_texts = [t for t in sample_texts if t] diff --git a/lib/blocks/builtin/field_mapper.py b/lib/blocks/builtin/field_mapper.py index 9f27f38..26f66c9 100644 --- a/lib/blocks/builtin/field_mapper.py +++ b/lib/blocks/builtin/field_mapper.py @@ -4,6 +4,7 @@ from lib.blocks.base import BaseBlock from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -20,25 +21,56 @@ class FieldMapper(BaseBlock): _config_descriptions = { "mappings": ( - "Dict mapping new field names to Jinja2 expressions. " - 'Example: {"question": "{{ parsed_json.qa.q }}"}' + 'JSON object or Jinja template mapping field names to Jinja2 expressions. 
' + 'Example: {"question": "{{ parsed_json.qa.q }}"} or {{ mappings | tojson }}' ) } - def __init__(self, mappings: dict[str, str] | None = None): + _config_formats = { + "mappings": "json-or-template", + } + + def __init__(self, mappings: str = "{}"): """ Args: - mappings: {"field_name": "{{ jinja2.expression }}"} + mappings: JSON object or template of {"field_name": "{{ jinja2.expression }}"} """ - self.mappings = mappings or {} + self.mappings_template = mappings async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: - if not self.mappings: + # parse mappings from template + if not self.mappings_template or self.mappings_template == "{}": logger.warning("no mappings configured, returning empty result") return {} + mappings_rendered = render_template( + self.mappings_template, context.accumulated_state + ) + try: + mappings = json.loads(mappings_rendered) + if not isinstance(mappings, dict): + raise BlockExecutionError( + "mappings must be a JSON object", + detail={"rendered_value": mappings_rendered}, + ) + # validate all values are strings (Jinja2 templates) + for key, value in mappings.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise BlockExecutionError( + "All mappings keys and values must be strings", + detail={"mappings": mappings}, + ) + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"mappings must be valid JSON: {str(e)}", + detail={ + "template": self.mappings_template, + "rendered": mappings_rendered, + }, + ) + result = {} - for field_name, template in self.mappings.items(): + for field_name, template in mappings.items(): try: rendered = render_template(template, context.accumulated_state) result[field_name] = self._maybe_parse_json(rendered) diff --git a/lib/blocks/builtin/json_validator.py b/lib/blocks/builtin/json_validator.py index 5344ae3..24f974d 100644 --- a/lib/blocks/builtin/json_validator.py +++ b/lib/blocks/builtin/json_validator.py @@ -4,6 +4,8 @@ from lib.blocks.base 
import BaseBlock from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError +from lib.template_renderer import render_template class JSONValidatorBlock(BaseBlock): @@ -15,10 +17,21 @@ class JSONValidatorBlock(BaseBlock): _field_references = ["field_name"] + _config_descriptions = { + "required_fields": ( + 'JSON array or Jinja template. Examples: ["name", "email"] or ' + '{{ required_fields | tojson }} (leave empty for none)' + ) + } + + _config_formats = { + "required_fields": "json-or-template", + } + def __init__( self, field_name: str = "assistant", - required_fields: list[str] | None = None, + required_fields: str = "", strict: bool = False, ) -> None: """ @@ -26,14 +39,42 @@ def __init__( args: field_name: name of field in accumulated state to validate - required_fields: list of field names that must be present in the JSON + required_fields: JSON array or Jinja template of field names that must be present strict: if true, fail on parse errors; if false, mark as invalid but continue """ self.field_name = field_name - self.required_fields = required_fields or [] + self.required_fields_template = required_fields self.strict = strict async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + # parse required_fields from template (optional) + required_fields: list[str] = [] + if self.required_fields_template: + fields_rendered = render_template( + self.required_fields_template, context.accumulated_state + ) + try: + fields_list = json.loads(fields_rendered) + if not isinstance(fields_list, list): + raise BlockExecutionError( + "required_fields must be a JSON array", + detail={"rendered_value": fields_rendered}, + ) + if not all(isinstance(f, str) for f in fields_list): + raise BlockExecutionError( + "All items in required_fields must be strings", + detail={"required_fields": fields_list}, + ) + required_fields = fields_list + except json.JSONDecodeError as e: + raise BlockExecutionError( + 
f"required_fields must be valid JSON: {str(e)}", + detail={ + "template": self.required_fields_template, + "rendered": fields_rendered, + }, + ) + field_output = context.get_state(self.field_name, "") # if already parsed (e.g., from StructuredGenerator), use it directly @@ -60,8 +101,8 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: # validate parsed JSON # check if required fields are present - if self.required_fields: - missing_fields = [field for field in self.required_fields if field not in parsed] + if required_fields: + missing_fields = [field for field in required_fields if field not in parsed] if missing_fields: return { "valid": False, diff --git a/lib/blocks/builtin/ragas_metrics.py b/lib/blocks/builtin/ragas_metrics.py index edd064c..c375fdc 100644 --- a/lib/blocks/builtin/ragas_metrics.py +++ b/lib/blocks/builtin/ragas_metrics.py @@ -8,6 +8,8 @@ from lib.blocks.base import BaseBlock from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError +from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -36,15 +38,6 @@ class RagasMetrics(BaseBlock): "ground_truth_field", ] - _config_enums = { - "metrics": [ - "answer_relevancy", - "context_precision", - "context_recall", - "faithfulness", - ] - } - _config_descriptions = { "model": "LLM model for evaluation (leave empty for default)", "embedding_model": "Embedding model for answer_relevancy (leave empty for default)", @@ -52,17 +45,24 @@ class RagasMetrics(BaseBlock): "answer_field": "Field containing the answer", "contexts_field": "Field containing contexts (list of strings)", "ground_truth_field": "Field containing expected answer", - "metrics": "RAGAS metrics to calculate", + "metrics": ( + 'JSON array or Jinja template. Available: ["answer_relevancy", "context_precision", ' + '"context_recall", "faithfulness"]. 
Example: ["faithfulness"] or {{ metrics | tojson }}' + ), "score_threshold": "Minimum score (0.0-1.0) to pass", } + _config_formats = { + "metrics": "json-or-template", + } + def __init__( self, question_field: str = "question", answer_field: str = "answer", contexts_field: str = "contexts", ground_truth_field: str = "ground_truth", - metrics: list[str] | None = None, + metrics: str = '["faithfulness"]', score_threshold: float = 0.5, model: str | None = None, embedding_model: str | None = None, @@ -71,7 +71,7 @@ def __init__( self.answer_field = answer_field self.contexts_field = contexts_field self.ground_truth_field = ground_truth_field - self.metrics = metrics if isinstance(metrics, list) else ["faithfulness"] + self.metrics_template = metrics self.score_threshold = max(0.0, min(1.0, score_threshold)) self.model_name = model self.embedding_model_name = embedding_model @@ -79,6 +79,33 @@ def __init__( async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: from lib.blocks.commons import UsageTracker + # parse metrics from template + metrics_rendered = render_template(self.metrics_template, context.accumulated_state) + try: + metrics_list = json.loads(metrics_rendered) + if not isinstance(metrics_list, list): + raise BlockExecutionError( + "metrics must be a JSON array", + detail={"rendered_value": metrics_rendered}, + ) + if not all(isinstance(m, str) for m in metrics_list): + raise BlockExecutionError( + "All items in metrics must be strings", + detail={"metrics": metrics_list}, + ) + metrics = metrics_list + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"metrics must be valid JSON: {str(e)}", + detail={ + "template": self.metrics_template, + "rendered": metrics_rendered, + }, + ) + + # store parsed metrics for use in other methods + self.metrics = metrics + # 1. 
collect inputs from configured fields inputs = { "question": context.get_state(self.question_field, ""), diff --git a/lib/blocks/builtin/structure_sampler.py b/lib/blocks/builtin/structure_sampler.py index 8a353e4..06ce389 100644 --- a/lib/blocks/builtin/structure_sampler.py +++ b/lib/blocks/builtin/structure_sampler.py @@ -1,3 +1,4 @@ +import json import logging import random from collections import Counter, defaultdict @@ -5,7 +6,8 @@ from lib.blocks.base import BaseMultiplierBlock from lib.entities.block_execution_context import BlockExecutionContext -from lib.errors import ValidationError +from lib.errors import BlockExecutionError, ValidationError +from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -23,24 +25,39 @@ class StructureSampler(BaseMultiplierBlock): _config_descriptions = { "target_count": "Number of skeleton records to generate", - "categorical_fields": "List of categorical field names to sample (e.g., ['plan', 'role'])", - "numeric_fields": "List of numeric field names for hint generation (e.g., ['storage'])", - "dependencies": "Field dependencies as {child: [parent1]} (e.g., {'role': ['plan']})", + "categorical_fields": ( + 'JSON array or Jinja template. Examples: ["plan", "role"] or ' + '{{ categorical_fields | tojson }}' + ), + "numeric_fields": ( + 'JSON array or Jinja template. Examples: ["storage"] or ' + '{{ numeric_fields | tojson }} (leave empty for none)' + ), + "dependencies": ( + 'JSON object or Jinja template. 
Example: {"role": ["plan"]} or ' + '{{ dependencies | tojson }} (leave empty for none)' + ), "seed": "Random seed for reproducibility (optional)", } + _config_formats = { + "categorical_fields": "json-or-template", + "numeric_fields": "json-or-template", + "dependencies": "json-or-template", + } + def __init__( self, target_count: int, - categorical_fields: list[str], - numeric_fields: list[str] | None = None, - dependencies: dict[str, list[str]] | None = None, + categorical_fields: str, + numeric_fields: str = "", + dependencies: str = "", seed: int | None = None, ): self.target_count = target_count - self.categorical_fields = categorical_fields - self.numeric_fields = numeric_fields or [] - self.dependencies = dependencies or {} + self.categorical_fields_template = categorical_fields + self.numeric_fields_template = numeric_fields + self.dependencies_template = dependencies self.seed = seed self._rng = random.Random(seed) @@ -272,6 +289,99 @@ def _generate_skeletons(self, profile: dict[str, Any], count: int) -> list[dict[ return results async def execute(self, context: BlockExecutionContext) -> list[dict[str, Any]]: # type: ignore[override] + # parse categorical_fields from template + categorical_fields_rendered = render_template( + self.categorical_fields_template, context.accumulated_state + ) + try: + categorical_fields = json.loads(categorical_fields_rendered) + if not isinstance(categorical_fields, list): + raise BlockExecutionError( + "categorical_fields must be a JSON array", + detail={"rendered_value": categorical_fields_rendered}, + ) + if not all(isinstance(f, str) for f in categorical_fields): + raise BlockExecutionError( + "All items in categorical_fields must be strings", + detail={"categorical_fields": categorical_fields}, + ) + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"categorical_fields must be valid JSON: {str(e)}", + detail={ + "template": self.categorical_fields_template, + "rendered": categorical_fields_rendered, + }, 
+ ) + + # parse numeric_fields from template (optional) + numeric_fields: list[str] = [] + if self.numeric_fields_template: + numeric_fields_rendered = render_template( + self.numeric_fields_template, context.accumulated_state + ) + try: + numeric_fields_list = json.loads(numeric_fields_rendered) + if not isinstance(numeric_fields_list, list): + raise BlockExecutionError( + "numeric_fields must be a JSON array", + detail={"rendered_value": numeric_fields_rendered}, + ) + if not all(isinstance(f, str) for f in numeric_fields_list): + raise BlockExecutionError( + "All items in numeric_fields must be strings", + detail={"numeric_fields": numeric_fields_list}, + ) + numeric_fields = numeric_fields_list + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"numeric_fields must be valid JSON: {str(e)}", + detail={ + "template": self.numeric_fields_template, + "rendered": numeric_fields_rendered, + }, + ) + + # parse dependencies from template (optional) + dependencies: dict[str, list[str]] = {} + if self.dependencies_template: + dependencies_rendered = render_template( + self.dependencies_template, context.accumulated_state + ) + try: + dependencies_obj = json.loads(dependencies_rendered) + if not isinstance(dependencies_obj, dict): + raise BlockExecutionError( + "dependencies must be a JSON object", + detail={"rendered_value": dependencies_rendered}, + ) + # validate structure: dict[str, list[str]] + for key, value in dependencies_obj.items(): + if not isinstance(key, str): + raise BlockExecutionError( + "All dependency keys must be strings", + detail={"dependencies": dependencies_obj}, + ) + if not isinstance(value, list) or not all(isinstance(v, str) for v in value): + raise BlockExecutionError( + f"Dependency value for '{key}' must be a list of strings", + detail={"dependencies": dependencies_obj}, + ) + dependencies = dependencies_obj + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"dependencies must be valid JSON: {str(e)}", + 
detail={ + "template": self.dependencies_template, + "rendered": dependencies_rendered, + }, + ) + + # store parsed values for use in methods + self.categorical_fields = categorical_fields + self.numeric_fields = numeric_fields + self.dependencies = dependencies + # read samples from initial state samples = context.get_state("samples", []) diff --git a/lib/blocks/builtin/structured_generator.py b/lib/blocks/builtin/structured_generator.py index 6af607b..65d3902 100644 --- a/lib/blocks/builtin/structured_generator.py +++ b/lib/blocks/builtin/structured_generator.py @@ -9,6 +9,7 @@ from lib.blocks.base import BaseBlock from lib.entities import pipeline from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError from lib.template_renderer import render_template logger = logging.getLogger(__name__) @@ -27,18 +28,25 @@ class StructuredGenerator(BaseBlock): "Jinja2 template. Reference fields with {{ field_name }} or " "{{ metadata.field_name }}. Example: Generate data for {{ metadata.topic }}" ), - "json_schema": "JSON Schema defining the structure of generated data", + "json_schema": ( + 'JSON object or Jinja template. 
Example: {"type": "object", "properties": {...}} or ' + '{{ json_schema | tojson }}' + ), + } + + _config_formats = { + "json_schema": "json-or-template", } def __init__( self, - json_schema: dict[str, Any], + json_schema: str, model: str | None = None, temperature: float = 0.7, max_tokens: int = 2048, user_prompt: str = "", ): - self.json_schema = json_schema + self.json_schema_template = json_schema self.model_name = model # model name or None for default self.temperature = temperature self.max_tokens = max_tokens @@ -51,14 +59,14 @@ def _prepare_prompt(self, data: dict[str, Any]) -> str: ) return render_template(prompt_template, data) - def _prepare_response_format(self) -> dict[str, Any]: + def _prepare_response_format(self, json_schema: dict[str, Any]) -> dict[str, Any]: """prepare response format with schema enforcement""" - if self.json_schema: + if json_schema: return { "type": "json_schema", "json_schema": { "name": "response", - "schema": self.json_schema, + "schema": json_schema, "strict": True, }, } @@ -89,9 +97,27 @@ def _parse_json_response(self, content: str) -> dict[str, Any]: async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: from app import llm_config_manager + # parse json_schema from template + schema_rendered = render_template(self.json_schema_template, context.accumulated_state) + try: + json_schema = json.loads(schema_rendered) + if not isinstance(json_schema, dict): + raise BlockExecutionError( + "json_schema must be a JSON object", + detail={"rendered_value": schema_rendered}, + ) + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"json_schema must be valid JSON: {str(e)}", + detail={ + "template": self.json_schema_template, + "rendered": schema_rendered, + }, + ) + user_prompt = self._prepare_prompt(context.accumulated_state) messages = [{"role": "user", "content": user_prompt}] - response_format = self._prepare_response_format() + response_format = self._prepare_response_format(json_schema) 
llm_config = await llm_config_manager.get_llm_model(self.model_name) llm_params = llm_config_manager.prepare_llm_call( diff --git a/lib/blocks/builtin/validator.py b/lib/blocks/builtin/validator.py index d2e5914..2d588c1 100644 --- a/lib/blocks/builtin/validator.py +++ b/lib/blocks/builtin/validator.py @@ -1,7 +1,10 @@ +import json from typing import Any from lib.blocks.base import BaseBlock from lib.entities.block_execution_context import BlockExecutionContext +from lib.errors import BlockExecutionError +from lib.template_renderer import render_template class ValidatorBlock(BaseBlock): @@ -11,19 +14,56 @@ class ValidatorBlock(BaseBlock): inputs = ["text", "assistant"] outputs = ["text", "valid", "assistant"] - _config_descriptions = {"forbidden_words": "List of words that should not appear in the text"} + _config_descriptions = { + "forbidden_words": ( + 'JSON array or Jinja template. Examples: ["spam", "bad"] or ' + '{{ forbidden_words | tojson }} (leave empty for none)' + ) + } + + _config_formats = { + "forbidden_words": "json-or-template", + } def __init__( self, min_length: int = 0, max_length: int = 100000, - forbidden_words: list[str] | None = None, + forbidden_words: str = "", ) -> None: self.min_length = min_length self.max_length = max_length - self.forbidden_words = forbidden_words or [] + self.forbidden_words_template = forbidden_words async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: + # parse forbidden_words from template (optional) + forbidden_words: list[str] = [] + if self.forbidden_words_template: + words_rendered = render_template( + self.forbidden_words_template, context.accumulated_state + ) + try: + words_list = json.loads(words_rendered) + if not isinstance(words_list, list): + raise BlockExecutionError( + "forbidden_words must be a JSON array", + detail={"rendered_value": words_rendered}, + ) + if not all(isinstance(w, str) for w in words_list): + raise BlockExecutionError( + "All items in forbidden_words must be 
strings", + detail={"forbidden_words": words_list}, + ) + forbidden_words = words_list + except json.JSONDecodeError as e: + raise BlockExecutionError( + f"forbidden_words must be valid JSON: {str(e)}", + detail={ + "template": self.forbidden_words_template, + "rendered": words_rendered, + }, + ) + # validate either text or assistant field (prefer non-empty) text = context.get_state("text") or context.get_state("assistant", "") @@ -34,7 +74,7 @@ async def execute(self, context: BlockExecutionContext) -> dict[str, Any]: # check forbidden words text_lower = text.lower() valid = True - for word in self.forbidden_words: + for word in forbidden_words: if word.lower() in text_lower: valid = False break diff --git a/lib/blocks/config.py b/lib/blocks/config.py index 40a9c6d..0f6b3d6 100644 --- a/lib/blocks/config.py +++ b/lib/blocks/config.py @@ -11,6 +11,7 @@ def _build_property( enum_values: dict[str, Any], field_refs: list[str], field_descriptions: dict[str, str], + field_formats: dict[str, str], ) -> tuple[dict[str, Any], bool]: """build property definition for a single parameter""" property_def = BlockConfigSchema._get_property_def(param_type) @@ -31,6 +32,8 @@ def _build_property( property_def["isFieldReference"] = True if param_name in field_descriptions: property_def["description"] = field_descriptions[param_name] + if param_name in field_formats: + property_def["format"] = field_formats[param_name] return property_def, is_required @@ -43,6 +46,7 @@ def get_config_schema(block_class: Type[Any]) -> dict[str, Any]: enum_values = getattr(block_class, "_config_enums", {}) field_refs = getattr(block_class, "_field_references", []) field_descriptions = getattr(block_class, "_config_descriptions", {}) + field_formats = getattr(block_class, "_config_formats", {}) properties = {} required = [] @@ -53,7 +57,7 @@ def get_config_schema(block_class: Type[Any]) -> dict[str, Any]: param_type = type_hints.get(param_name, str) property_def, is_required = 
BlockConfigSchema._build_property( - param_name, param, param_type, enum_values, field_refs, field_descriptions + param_name, param, param_type, enum_values, field_refs, field_descriptions, field_formats ) properties[param_name] = property_def From cf66bbb1892f56912b80a303d45abcdac6e38481 Mon Sep 17 00:00:00 2001 From: nicofretti Date: Sun, 11 Jan 2026 11:30:13 +0100 Subject: [PATCH 13/19] wip: fixing fields in blocks --- app.py | 1 + .../components/pipeline-editor/BlockNode.tsx | 13 +- frontend/src/pages/Generator.tsx | 28 ++-- lib/blocks/builtin/duplicate_remover.py | 8 +- lib/blocks/builtin/field_mapper.py | 8 +- lib/blocks/builtin/json_validator.py | 8 +- lib/blocks/builtin/ragas_metrics.py | 8 +- lib/blocks/builtin/semantic_infiller.py | 12 +- lib/blocks/builtin/structure_sampler.py | 63 +++++++-- lib/blocks/builtin/structured_generator.py | 8 +- lib/blocks/builtin/validator.py | 8 +- lib/template_renderer.py | 15 +- lib/templates/data_augmentation.yaml | 10 +- tests/blocks/test_duplicate_remover.py | 18 +-- tests/blocks/test_structure_sampler.py | 26 ++-- tests/test_template_renderer.py | 132 ++++++++++++++++++ 16 files changed, 305 insertions(+), 61 deletions(-) create mode 100644 tests/test_template_renderer.py diff --git a/app.py b/app.py index bbf4c74..ee8e978 100644 --- a/app.py +++ b/app.py @@ -534,6 +534,7 @@ async def get_pipeline(pipeline_id: int) -> dict[str, Any]: blocks = pipeline.definition.get("blocks", []) pipeline_dict = pipeline.model_dump() pipeline_dict["first_block_is_multiplier"] = is_multiplier_pipeline(blocks) + pipeline_dict["first_block_type"] = blocks[0].get("type") if blocks else None return pipeline_dict diff --git a/frontend/src/components/pipeline-editor/BlockNode.tsx b/frontend/src/components/pipeline-editor/BlockNode.tsx index 6005f2b..82437b4 100644 --- a/frontend/src/components/pipeline-editor/BlockNode.tsx +++ b/frontend/src/components/pipeline-editor/BlockNode.tsx @@ -105,8 +105,19 @@ function getPreviewFields(blockType: 
string, config: Record): Array if (config[key] !== undefined && config[key] !== null && config[key] !== "") { let displayValue = String(config[key]); + // special handling for fields_to_generate (JSON string) + if (key === "fields_to_generate" && typeof config[key] === "string") { + try { + const parsed = JSON.parse(config[key]); + if (Array.isArray(parsed)) { + displayValue = `[${parsed.length} items]`; + } + } catch { + // if not valid JSON, treat as template string + } + } // special formatting for arrays/objects - if (Array.isArray(config[key])) { + else if (Array.isArray(config[key])) { displayValue = `[${config[key].length} items]`; } else if (typeof config[key] === "object") { displayValue = `{${Object.keys(config[key]).length} keys}`; diff --git a/frontend/src/pages/Generator.tsx b/frontend/src/pages/Generator.tsx index b62c2da..769a5f5 100644 --- a/frontend/src/pages/Generator.tsx +++ b/frontend/src/pages/Generator.tsx @@ -41,6 +41,7 @@ export default function Generator() { const [pipelines, setPipelines] = useState([]); const [selectedPipeline, setSelectedPipeline] = useState(null); const [isMultiplierPipeline, setIsMultiplierPipeline] = useState(false); + const [needsMarkdown, setNeedsMarkdown] = useState(false); const [validationResult, setValidationResult] = useState<{ valid: boolean; errors: string[]; @@ -112,6 +113,7 @@ export default function Generator() { if (!selectedPipeline) { if (mounted) { setIsMultiplierPipeline(false); + setNeedsMarkdown(false); setValidationResult(null); } return; @@ -123,6 +125,8 @@ export default function Generator() { }); const data = await res.json(); const isMultiplier = data.first_block_is_multiplier || false; + const firstBlockType = data.first_block_type || ""; + const needsMd = firstBlockType === "MarkdownMultiplierBlock"; if (!mounted) return; @@ -130,7 +134,7 @@ export default function Generator() { const isMarkdown = file.name.endsWith(".md"); const isJson = file.name.endsWith(".json"); - if ((isMultiplier && 
isJson) || (!isMultiplier && isMarkdown)) { + if ((needsMd && isJson) || (!needsMd && isMarkdown)) { setFile(null); setValidationResult(null); setValidated(false); @@ -138,10 +142,14 @@ export default function Generator() { } setIsMultiplierPipeline(isMultiplier); + setNeedsMarkdown(needsMd); } catch (err) { if (err instanceof Error && err.name !== "AbortError") { console.error("Failed to load pipeline details:", err); - if (mounted) setIsMultiplierPipeline(false); + if (mounted) { + setIsMultiplierPipeline(false); + setNeedsMarkdown(false); + } } } }; @@ -199,7 +207,7 @@ export default function Generator() { const isJson = droppedFile.type === "application/json" || droppedFile.name.endsWith(".json"); const isMarkdown = droppedFile.name.endsWith(".md"); - const isValidFile = isMultiplierPipeline ? isMarkdown : isJson; + const isValidFile = needsMarkdown ? isMarkdown : isJson; if (isValidFile) { const input = fileInputRef.current; @@ -210,7 +218,7 @@ export default function Generator() { input.dispatchEvent(new Event("change", { bubbles: true })); } } else { - const expected = isMultiplierPipeline ? "Markdown (.md) file" : "JSON (.json) file"; + const expected = needsMarkdown ? "Markdown (.md) file" : "JSON (.json) file"; toast.error(`Please drop a ${expected}`); } } @@ -223,12 +231,12 @@ export default function Generator() { const isMarkdown = selectedFile.name.endsWith(".md"); const isJson = selectedFile.name.endsWith(".json"); - if (isMultiplierPipeline && isJson) { + if (needsMarkdown && isJson) { toast.error("Please upload a Markdown (.md) file for this pipeline."); return; } - if (!isMultiplierPipeline && isMarkdown) { + if (!needsMarkdown && isMarkdown) { toast.error("Please upload a JSON (.json) file for this pipeline."); return; } @@ -650,7 +658,7 @@ export default function Generator() { @@ -663,7 +671,7 @@ export default function Generator() { ? "Select a pipeline first" : file ? file.name - : isMultiplierPipeline + : needsMarkdown ? 
"Drop Markdown file here or click to browse" : "Drop JSON seed file here or click to browse"} @@ -672,7 +680,7 @@ export default function Generator() { ? "Choose a pipeline from the configuration panel" : file ? `Size: ${(file.size / 1024).toFixed(2)} KB` - : isMultiplierPipeline + : needsMarkdown ? "Markdown (.md) format" : 'Format: {"repetitions": N, "metadata": {...}}'} @@ -725,7 +733,7 @@ export default function Generator() { {/* Verify Seeds Button */} - {file && selectedPipeline && !isMultiplierPipeline && file.name.endsWith(".json") && ( + {file && selectedPipeline && !needsMarkdown && file.name.endsWith(".json") && (