LLM-Powered Operators

docetl.operations.map.MapOperation

Bases: BaseOperation

Source code in docetl/operations/map.py
class MapOperation(BaseOperation):
    def syntax_check(self) -> None:
        """
        Checks the configuration of the MapOperation for required keys and valid structure.

        Raises:
            ValueError: If required keys are missing or invalid in the configuration.
            TypeError: If configuration values have incorrect types.
        """
        if "drop_keys" in self.config:
            if not isinstance(self.config["drop_keys"], list):
                raise TypeError(
                    "'drop_keys' in configuration must be a list of strings"
                )
            for key in self.config["drop_keys"]:
                if not isinstance(key, str):
                    raise TypeError("All items in 'drop_keys' must be strings")
        else:
            if "prompt" not in self.config or "output" not in self.config:
                raise ValueError(
                    "If 'drop_keys' is not specified, both 'prompt' and 'output' must be present in the configuration"
                )

        if "prompt" in self.config or "output" in self.config:
            required_keys = ["prompt", "output"]
            for key in required_keys:
                if key not in self.config:
                    raise ValueError(
                        f"Missing required key '{key}' in MapOperation configuration"
                    )

            if "schema" not in self.config["output"]:
                raise ValueError("Missing 'schema' in 'output' configuration")

            if not isinstance(self.config["output"]["schema"], dict):
                raise TypeError(
                    "'schema' in 'output' configuration must be a dictionary"
                )

            if not self.config["output"]["schema"]:
                raise ValueError("'schema' in 'output' configuration cannot be empty")

            # Check if the prompt is a valid Jinja2 template
            try:
                Template(self.config["prompt"])
            except Exception as e:
                raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

            # Check if the model is specified (optional)
            if "model" in self.config and not isinstance(self.config["model"], str):
                raise TypeError("'model' in configuration must be a string")

            # Check if tools are specified and validate their structure
            if "tools" in self.config:
                if not isinstance(self.config["tools"], list):
                    raise TypeError("'tools' in configuration must be a list")

                for i, tool in enumerate(self.config["tools"]):
                    if not isinstance(tool, dict):
                        raise TypeError(f"Tool {i} in 'tools' must be a dictionary")

                    if "code" not in tool or "function" not in tool:
                        raise ValueError(
                            f"Tool {i} is missing required 'code' or 'function' key"
                        )

                    function = tool.get("function", {})
                    if not isinstance(function, dict):
                        raise TypeError(f"'function' in tool {i} must be a dictionary")

                    required_function_keys = ["name", "description", "parameters"]
                    for key in required_function_keys:
                        if key not in function:
                            raise ValueError(
                                f"Tool {i} is missing required '{key}' in 'function'"
                            )

            self.gleaning_check()

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Executes the map operation on the provided input data.

        Args:
            input_data (List[Dict]): The input data to process.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

        This method performs the following steps:
        1. If a prompt is specified, it processes each input item using the specified prompt and LLM model
        2. Applies gleaning if configured
        3. Validates the output
        4. If drop_keys is specified, it drops the specified keys from each document
        5. Aggregates results and calculates total cost

        The method uses parallel processing to improve performance.
        """
        # Check if there's no prompt and only drop_keys
        if "prompt" not in self.config and "drop_keys" in self.config:
            # If only drop_keys is specified, simply drop the keys and return
            dropped_results = []
            for item in input_data:
                new_item = {
                    k: v for k, v in item.items() if k not in self.config["drop_keys"]
                }
                dropped_results.append(new_item)
            return dropped_results, 0.0  # Return the modified data with no cost

        def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]:
            prompt_template = Template(self.config["prompt"])
            prompt = prompt_template.render(input=item)

            def validation_fn(response: Dict[str, Any]):
                output = parse_llm_response(
                    response, tools=self.config.get("tools", None)
                )[0]
                for key, value in item.items():
                    if key not in self.config["output"]["schema"]:
                        output[key] = value
                if validate_output(self.config, output, self.console):
                    return output, True
                return output, False

            if "gleaning" in self.config:
                output, cost, success = call_llm_with_validation(
                    [{"role": "user", "content": prompt}],
                    llm_call_fn=lambda messages: call_llm_with_gleaning(
                        self.config.get("model", self.default_model),
                        "map",
                        messages,
                        self.config["output"]["schema"],
                        self.config["gleaning"]["validation_prompt"],
                        self.config["gleaning"]["num_rounds"],
                        self.console,
                    ),
                    validation_fn=validation_fn,
                    val_rule=self.config.get("validate", []),
                    num_retries=self.num_retries_on_validate_failure,
                    console=self.console,
                )
            else:
                output, cost, success = call_llm_with_validation(
                    [{"role": "user", "content": prompt}],
                    llm_call_fn=lambda messages: call_llm(
                        self.config.get("model", self.default_model),
                        "map",
                        messages,
                        self.config["output"]["schema"],
                        tools=self.config.get("tools", None),
                        console=self.console,
                    ),
                    validation_fn=validation_fn,
                    val_rule=self.config.get("validate", []),
                    num_retries=self.num_retries_on_validate_failure,
                    console=self.console,
                )

            if success:
                return output, cost

            return None, cost

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = [executor.submit(_process_map_item, item) for item in input_data]
            results = []
            total_cost = 0
            pbar = RichLoopBar(
                range(len(futures)),
                desc="Processing map items",
                console=self.console,
            )
            for i in pbar:
                result, item_cost = futures[i].result()
                if result is not None:
                    if "drop_keys" in self.config:
                        result = {
                            k: v
                            for k, v in result.items()
                            if k not in self.config["drop_keys"]
                        }
                    results.append(result)
                total_cost += item_cost
                pbar.update(i)

        return results, total_cost

    def validate_output(self, output: Dict) -> bool:
        """
        Validates the output of a single map operation against the specified schema.

        Args:
            output (Dict): The output to validate.

        Returns:
            bool: True if the output is valid, False otherwise.
        """
        schema = self.config["output"]["schema"]
        for key in schema:
            if key not in output:
                self.console.log(f"[red]Error: Missing key '{key}' in output[/red]")
                return False
        return True

execute(input_data)

Executes the map operation on the provided input data.

Parameters:

    input_data (List[Dict]): The input data to process. Required.

Returns:

    Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

This method performs the following steps:

1. If a prompt is specified, it processes each input item using the specified prompt and LLM model
2. Applies gleaning if configured
3. Validates the output
4. If drop_keys is specified, it drops the specified keys from each document
5. Aggregates results and calculates total cost

The method uses parallel processing to improve performance.
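
For illustration, here is a minimal sketch of a configuration this operation could consume, written as a plain Python dict. The key names mirror the lookups in the source above (prompt, output.schema, model, gleaning, drop_keys); the concrete values, such as the schema fields and the model name, are hypothetical.

# Hypothetical MapOperation configuration, shaped after the keys that
# syntax_check() and execute() read from self.config in the source above.
map_config = {
    # Jinja2 template; execute() renders it with `input` bound to each item.
    "prompt": "Summarize the following document:\n\n{{ input.text }}",
    # Must be a non-empty dict (see syntax_check()).
    "output": {"schema": {"summary": "string"}},
    # Optional; must be a string if present. This model name is an assumption.
    "model": "gpt-4o-mini",
    # Optional gleaning; execute() reads exactly these two keys when present.
    "gleaning": {
        "validation_prompt": "Does the summary cover every main point?",
        "num_rounds": 2,
    },
    # Optional; these keys are dropped from each result after the LLM call.
    "drop_keys": ["text"],
}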


syntax_check()

Checks the configuration of the MapOperation for required keys and valid structure.

Raises:

    ValueError: If required keys are missing or invalid in the configuration.
    TypeError: If configuration values have incorrect types.
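
Because of the branching above, a configuration may be LLM-backed (both 'prompt' and 'output' present) or a pure key-dropping pass ('drop_keys' alone). A hedged sketch of the outcomes, using hypothetical configs:

# Accepted: 'drop_keys' alone is valid; execute() then skips the LLM entirely.
config_drop_only = {"drop_keys": ["raw_html"]}

# Rejected with ValueError: without 'drop_keys', both 'prompt' and 'output'
# are required.
config_missing_output = {"prompt": "Summarize {{ input.text }}"}

# Rejected with TypeError: 'drop_keys' must be a list of strings.
config_bad_type = {"drop_keys": "raw_html"}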


validate_output(output)

Validates the output of a single map operation against the specified schema.

Parameters:

    output (Dict): The output to validate. Required.

Returns:

    bool: True if the output is valid, False otherwise.


docetl.operations.resolve.ResolveOperation

Bases: BaseOperation

Source code in docetl/operations/resolve.py
class ResolveOperation(BaseOperation):
    def syntax_check(self) -> None:
        """
        Checks the configuration of the ResolveOperation for required keys and valid structure.

        This method performs the following checks:
        1. Verifies the presence of required keys: 'comparison_prompt' and 'output'.
        2. Ensures 'output' contains a 'schema' key.
        3. Validates that 'schema' in 'output' is a non-empty dictionary.
        4. Checks if 'comparison_prompt' is a valid Jinja2 template with 'input1' and 'input2' variables.
        5. If 'resolution_prompt' is present, verifies it as a valid Jinja2 template with 'inputs' variable.
        6. Optionally checks if 'model' is a string (if present).
        7. Optionally checks 'blocking_keys' (if present, further checks are performed).

        Raises:
            ValueError: If required keys are missing, if templates are invalid or missing required variables,
                        or if any other configuration aspect is incorrect or inconsistent.
            TypeError: If the types of configuration values are incorrect, such as 'schema' not being a dict
                       or 'model' not being a string.
        """
        required_keys = ["comparison_prompt", "output"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in ResolveOperation configuration"
                )

        if "schema" not in self.config["output"]:
            raise ValueError("Missing 'schema' in 'output' configuration")

        if not isinstance(self.config["output"]["schema"], dict):
            raise TypeError("'schema' in 'output' configuration must be a dictionary")

        if not self.config["output"]["schema"]:
            raise ValueError("'schema' in 'output' configuration cannot be empty")

        # Check if the comparison_prompt is a valid Jinja2 template
        try:
            comparison_template = Template(self.config["comparison_prompt"])
            comparison_vars = comparison_template.environment.parse(
                self.config["comparison_prompt"]
            ).find_all(jinja2.nodes.Name)
            comparison_var_names = {var.name for var in comparison_vars}
            if (
                "input1" not in comparison_var_names
                or "input2" not in comparison_var_names
            ):
                raise ValueError(
                    "'comparison_prompt' must contain both 'input1' and 'input2' variables"
                )

            if "resolution_prompt" in self.config:
                reduction_template = Template(self.config["resolution_prompt"])
                reduction_vars = reduction_template.environment.parse(
                    self.config["resolution_prompt"]
                ).find_all(jinja2.nodes.Name)
                reduction_var_names = {var.name for var in reduction_vars}
                if "inputs" not in reduction_var_names:
                    raise ValueError(
                        "'resolution_prompt' must contain 'inputs' variable"
                    )
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template: {str(e)}")

        # Check if the model is specified (optional)
        if "model" in self.config and not isinstance(self.config["model"], str):
            raise TypeError("'model' in configuration must be a string")

        # Check blocking_keys (optional)
        if "blocking_keys" in self.config:
            if not isinstance(self.config["blocking_keys"], list):
                raise TypeError("'blocking_keys' must be a list")
            if not all(isinstance(key, str) for key in self.config["blocking_keys"]):
                raise TypeError("All items in 'blocking_keys' must be strings")

        # Check blocking_threshold (optional)
        if "blocking_threshold" in self.config:
            if not isinstance(self.config["blocking_threshold"], (int, float)):
                raise TypeError("'blocking_threshold' must be a number")
            if not 0 <= self.config["blocking_threshold"] <= 1:
                raise ValueError("'blocking_threshold' must be between 0 and 1")

        # Check blocking_conditions (optional)
        if "blocking_conditions" in self.config:
            if not isinstance(self.config["blocking_conditions"], list):
                raise TypeError("'blocking_conditions' must be a list")
            if not all(
                isinstance(cond, str) for cond in self.config["blocking_conditions"]
            ):
                raise TypeError("All items in 'blocking_conditions' must be strings")

        # Check if input schema is provided and valid (optional)
        if "input" in self.config:
            if "schema" not in self.config["input"]:
                raise ValueError("Missing 'schema' in 'input' configuration")
            if not isinstance(self.config["input"]["schema"], dict):
                raise TypeError(
                    "'schema' in 'input' configuration must be a dictionary"
                )

        # Check limit_comparisons (optional)
        if "limit_comparisons" in self.config:
            if not isinstance(self.config["limit_comparisons"], int):
                raise TypeError("'limit_comparisons' must be an integer")
            if self.config["limit_comparisons"] <= 0:
                raise ValueError("'limit_comparisons' must be a positive integer")

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Executes the resolve operation on the provided dataset.

        Args:
            input_data (List[Dict]): The dataset to resolve.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the resolved results and the total cost of the operation.

        This method performs the following steps:
        1. Initial blocking based on specified conditions and/or embedding similarity
        2. Pairwise comparison of potentially matching entries using LLM
        3. Clustering of matched entries
        4. Resolution of each cluster into a single entry (if applicable)
        5. Result aggregation and validation

        The method also calculates and logs statistics such as comparisons saved by blocking and self-join selectivity.
        """
        if len(input_data) == 0:
            return [], 0

        blocking_keys = self.config.get("blocking_keys", [])
        blocking_threshold = self.config.get("blocking_threshold")
        blocking_conditions = self.config.get("blocking_conditions", [])

        if not blocking_threshold and not blocking_conditions:
            # Prompt the user for confirmation
            if self.status:
                self.status.stop()
            if not Confirm.ask(
                f"[yellow]Warning: No blocking keys or conditions specified. "
                f"This may result in a large number of comparisons. "
                f"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
                f"Do you want to continue without blocking?[/yellow]",
            ):
                raise ValueError("Operation cancelled by user.")

            if self.status:
                self.status.start()

        input_schema = self.config.get("input", {}).get("schema", {})
        if not blocking_keys:
            # Set them to all keys in the input data
            blocking_keys = list(input_data[0].keys())
        limit_comparisons = self.config.get("limit_comparisons")
        total_cost = 0

        def is_match(item1: Dict[str, Any], item2: Dict[str, Any]) -> bool:
            return any(
                eval(condition, {"input1": item1, "input2": item2})
                for condition in blocking_conditions
            )

        # Calculate embeddings if blocking_threshold is set
        embeddings = None
        if blocking_threshold is not None:
            embedding_model = self.config.get("embedding_model", self.default_model)

            def get_embeddings_batch(
                items: List[Dict[str, Any]]
            ) -> List[Tuple[List[float], float]]:
                texts = [
                    " ".join(str(item[key]) for key in blocking_keys if key in item)
                    for item in items
                ]
                response = gen_embedding(model=embedding_model, input=texts)
                return [
                    (data["embedding"], completion_cost(response))
                    for data in response["data"]
                ]

            embeddings = []
            costs = []
            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
                for i in range(
                    0, len(input_data), self.config.get("embedding_batch_size", 1000)
                ):
                    batch = input_data[
                        i : i + self.config.get("embedding_batch_size", 1000)
                    ]
                    batch_results = list(executor.map(get_embeddings_batch, [batch]))

                    for result in batch_results:
                        embeddings.extend([r[0] for r in result])
                        costs.extend([r[1] for r in result])

                total_cost += sum(costs)

        # Initialize clusters
        clusters = [{i} for i in range(len(input_data))]
        cluster_map = {i: i for i in range(len(input_data))}

        def find_cluster(item):
            while item != cluster_map[item]:
                cluster_map[item] = cluster_map[cluster_map[item]]
                item = cluster_map[item]
            return item

        def merge_clusters(item1, item2):
            root1, root2 = find_cluster(item1), find_cluster(item2)
            if root1 != root2:
                if len(clusters[root1]) < len(clusters[root2]):
                    root1, root2 = root2, root1
                clusters[root1] |= clusters[root2]
                cluster_map[root2] = root1
                clusters[root2] = set()

        # Generate all pairs to compare
        # TODO: virtualize this if possible
        all_pairs = [
            (i, j)
            for i in range(len(input_data))
            for j in range(i + 1, len(input_data))
        ]

        # Filter pairs based on blocking conditions
        def meets_blocking_conditions(pair):
            i, j = pair
            return (
                is_match(input_data[i], input_data[j]) if blocking_conditions else False
            )

        blocked_pairs = list(filter(meets_blocking_conditions, all_pairs))

        # Apply limit_comparisons to blocked pairs
        if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
            self.console.log(
                f"Randomly sampling {limit_comparisons} pairs out of {len(blocked_pairs)} blocked pairs."
            )
            blocked_pairs = random.sample(blocked_pairs, limit_comparisons)

        # If there are remaining comparisons, fill with highest cosine similarities
        remaining_comparisons = (
            limit_comparisons - len(blocked_pairs)
            if limit_comparisons is not None
            else float("inf")
        )
        if remaining_comparisons > 0 and blocking_threshold is not None:
            # Compute cosine similarity for all pairs efficiently
            similarity_matrix = cosine_similarity(embeddings)

            cosine_pairs = []
            for i, j in all_pairs:
                if (i, j) not in blocked_pairs and find_cluster(i) != find_cluster(j):
                    similarity = similarity_matrix[i, j]
                    if similarity >= blocking_threshold:
                        cosine_pairs.append((i, j, similarity))

            if remaining_comparisons != float("inf"):
                cosine_pairs.sort(key=lambda x: x[2], reverse=True)
                additional_pairs = [
                    (i, j) for i, j, _ in cosine_pairs[: int(remaining_comparisons)]
                ]
                blocked_pairs.extend(additional_pairs)
            else:
                blocked_pairs.extend((i, j) for i, j, _ in cosine_pairs)

        filtered_pairs = blocked_pairs

        # Calculate and print statistics
        total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
        comparisons_made = len(filtered_pairs)
        comparisons_saved = total_possible_comparisons - comparisons_made
        self.console.log(
            f"[green]Comparisons saved by blocking: {comparisons_saved} "
            f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
        )

        # Compare pairs and update clusters in real-time
        batch_size = self.config.get("compare_batch_size", 100)
        pair_costs = 0

        pbar = RichLoopBar(
            range(0, len(filtered_pairs), batch_size),
            desc=f"Processing batches of {batch_size} LLM comparisons",
            console=self.console,
        )
        for i in pbar:
            batch = filtered_pairs[i : i + batch_size]

            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
                future_to_pair = {
                    executor.submit(
                        compare_pair,
                        self.config["comparison_prompt"],
                        self.config.get("comparison_model", self.default_model),
                        input_data[pair[0]],
                        input_data[pair[1]],
                        blocking_keys,
                    ): pair
                    for pair in batch
                }

                for future in as_completed(future_to_pair):
                    pair = future_to_pair[future]
                    is_match_result, cost = future.result()
                    pair_costs += cost
                    if is_match_result:
                        merge_clusters(pair[0], pair[1])

                    pbar.update(i)

        total_cost += pair_costs

        # Collect final clusters
        final_clusters = [cluster for cluster in clusters if cluster]

        # Process each cluster
        results = []

        def process_cluster(cluster):
            if len(cluster) > 1:
                cluster_items = [input_data[i] for i in cluster]
                reduction_template = Template(self.config["resolution_prompt"])
                if input_schema:
                    cluster_items = [
                        {k: item[k] for k in input_schema.keys() if k in item}
                        for item in cluster_items
                    ]

                resolution_prompt = reduction_template.render(inputs=cluster_items)
                reduction_response = call_llm(
                    self.config.get("resolution_model", self.default_model),
                    "reduce",
                    [{"role": "user", "content": resolution_prompt}],
                    self.config["output"]["schema"],
                    console=self.console,
                )
                reduction_output = parse_llm_response(reduction_response)[0]
                reduction_cost = completion_cost(reduction_response)

                if validate_output(self.config, reduction_output, self.console):
                    return (
                        [
                            {
                                **item,
                                **{
                                    k: reduction_output[k]
                                    for k in self.config["output"]["schema"]
                                },
                            }
                            for item in [input_data[i] for i in cluster]
                        ],
                        reduction_cost,
                    )
                return [], reduction_cost
            else:
                return [input_data[list(cluster)[0]]], 0

        # Calculate the number of records before and clusters after
        num_records_before = len(input_data)
        num_clusters_after = len(final_clusters)
        self.console.log(f"Number of keys before resolution: {num_records_before}")
        self.console.log(
            f"Number of distinct keys after resolution: {num_clusters_after}"
        )

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = [
                executor.submit(process_cluster, cluster) for cluster in final_clusters
            ]
            for future in rich_as_completed(
                futures,
                total=len(futures),
                desc="Determining resolved key for each group of equivalent keys",
                console=self.console,
            ):
                cluster_results, cluster_cost = future.result()
                results.extend(cluster_results)
                total_cost += cluster_cost

        total_pairs = len(input_data) * (len(input_data) - 1) // 2
        true_match_count = sum(
            len(cluster) * (len(cluster) - 1) // 2
            for cluster in final_clusters
            if len(cluster) > 1
        )
        true_match_selectivity = (
            true_match_count / total_pairs if total_pairs > 0 else 0
        )
        self.console.log(f"Self-join selectivity: {true_match_selectivity:.4f}")

        return results, total_cost

execute(input_data)

Executes the resolve operation on the provided dataset.

Parameters:

    input_data (List[Dict]): The dataset to resolve. Required.

Returns:

    Tuple[List[Dict], float]: A tuple containing the resolved results and the total cost of the operation.

This method performs the following steps:

1. Initial blocking based on specified conditions and/or embedding similarity
2. Pairwise comparison of potentially matching entries using LLM
3. Clustering of matched entries
4. Resolution of each cluster into a single entry (if applicable)
5. Result aggregation and validation

The method also calculates and logs statistics such as comparisons saved by blocking and self-join selectivity.
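
As a sketch, the configuration below exercises both blocking mechanisms described above. The key names follow the self.config lookups in the source; the prompts, keys, and numbers are hypothetical. Note that each entry of blocking_conditions is a Python expression evaluated with input1 and input2 in scope, as the is_match() helper in the source shows.

resolve_config = {
    # Jinja2 template; must reference both input1 and input2.
    "comparison_prompt": (
        "Do these records refer to the same person?\n"
        "Record 1: {{ input1.name }}\n"
        "Record 2: {{ input2.name }}\n"
        "Answer true or false."
    ),
    # Jinja2 template; must reference `inputs` (the cluster being resolved).
    "resolution_prompt": (
        "Pick one canonical name for these records:\n"
        "{% for item in inputs %}- {{ item.name }}\n{% endfor %}"
    ),
    "output": {"schema": {"name": "string"}},
    # Embedding-based blocking: a pair is compared when the cosine similarity
    # of its blocking-key embeddings is at least this threshold.
    "blocking_keys": ["name"],
    "blocking_threshold": 0.8,
    # Rule-based blocking: Python expressions evaluated per pair.
    "blocking_conditions": ["input1['name'][:1] == input2['name'][:1]"],
    # Optional cap on the number of LLM comparisons.
    "limit_comparisons": 1000,
}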


syntax_check()

Checks the configuration of the ResolveOperation for required keys and valid structure.

This method performs the following checks:

1. Verifies the presence of required keys: 'comparison_prompt' and 'output'.
2. Ensures 'output' contains a 'schema' key.
3. Validates that 'schema' in 'output' is a non-empty dictionary.
4. Checks if 'comparison_prompt' is a valid Jinja2 template with 'input1' and 'input2' variables.
5. If 'resolution_prompt' is present, verifies it as a valid Jinja2 template with 'inputs' variable.
6. Optionally checks if 'model' is a string (if present).
7. Optionally checks 'blocking_keys' (if present, further checks are performed).

Raises:

    ValueError: If required keys are missing, if templates are invalid or missing required variables, or if any other configuration aspect is incorrect or inconsistent.
    TypeError: If the types of configuration values are incorrect, such as 'schema' not being a dict or 'model' not being a string.
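
The template checks above rely on Jinja2's AST: the prompt is parsed and every variable name it references is collected. A minimal standalone sketch of that mechanism, using a hypothetical template:

import jinja2
from jinja2 import Environment

env = Environment()
template_src = "Compare {{ input1.title }} with {{ input2.title }}"

# Same mechanism as syntax_check() above: parse the template and collect
# every referenced variable name from the AST.
names = {node.name for node in env.parse(template_src).find_all(jinja2.nodes.Name)}
assert {"input1", "input2"} <= names  # this template would pass the check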


docetl.operations.reduce.ReduceOperation

Bases: BaseOperation

A class that implements a reduce operation on input data using language models.

This class extends BaseOperation to provide functionality for reducing grouped data using various strategies including batch reduce, incremental reduce, and parallel fold and merge.

Source code in docetl/operations/reduce.py
class ReduceOperation(BaseOperation):
    """
    A class that implements a reduce operation on input data using language models.

    This class extends BaseOperation to provide functionality for reducing grouped data
    using various strategies including batch reduce, incremental reduce, and parallel fold and merge.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the ReduceOperation.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.min_samples = 5
        self.max_samples = 1000
        self.fold_times = deque(maxlen=self.max_samples)
        self.merge_times = deque(maxlen=self.max_samples)
        self.lock = Lock()
        self.config["reduce_key"] = (
            [self.config["reduce_key"]]
            if isinstance(self.config["reduce_key"], str)
            else self.config["reduce_key"]
        )
        self.intermediates = {}

    def syntax_check(self) -> None:
        """
        Perform comprehensive syntax checks on the configuration of the ReduceOperation.

        This method validates the presence and correctness of all required configuration keys
        and Jinja2 templates, and ensures the correct structure and types of the entire configuration.

        The method performs the following checks:
        1. Verifies the presence of all required keys in the configuration.
        2. Validates the structure and content of the 'output' configuration, including its 'schema'.
        3. Checks if the main 'prompt' is a valid Jinja2 template and contains the required 'inputs' variable.
        4. If 'merge_prompt' is specified, ensures that 'fold_prompt' is also present.
        5. If 'fold_prompt' is present, verifies the existence of 'fold_batch_size'.
        6. Validates the 'fold_prompt' as a Jinja2 template with required variables 'inputs' and 'output'.
        7. If present, checks 'merge_prompt' as a valid Jinja2 template with required 'outputs' variable.
        8. Verifies types of various configuration inputs (e.g., 'fold_batch_size' as int).
        9. Checks for the presence and validity of optional configurations like 'model'.

        Raises:
            ValueError: If any required configuration is missing, if templates are invalid or missing required
                        variables, or if any other configuration aspect is incorrect or inconsistent.
            TypeError: If any configuration value has an incorrect type, such as 'schema' not being a dict
                       or 'fold_batch_size' not being an integer.
        """
        required_keys = ["reduce_key", "prompt", "output"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in ReduceOperation configuration"
                )

        if "schema" not in self.config["output"]:
            raise ValueError("Missing 'schema' in 'output' configuration")

        if not isinstance(self.config["output"]["schema"], dict):
            raise TypeError("'schema' in 'output' configuration must be a dictionary")

        if not self.config["output"]["schema"]:
            raise ValueError("'schema' in 'output' configuration cannot be empty")

        # Check if the prompt is a valid Jinja2 template
        try:
            template = Template(self.config["prompt"])
            template_vars = template.environment.parse(self.config["prompt"]).find_all(
                jinja2.nodes.Name
            )
            template_var_names = {var.name for var in template_vars}
            if "inputs" not in template_var_names:
                raise ValueError("Template must include the 'inputs' variable")
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

        # Check if fold_prompt is a valid Jinja2 template (now required if merge exists)
        if "merge_prompt" in self.config:
            if "fold_prompt" not in self.config:
                raise ValueError(
                    "'fold_prompt' is required when 'merge_prompt' is specified"
                )

        if "fold_prompt" in self.config:
            if "fold_batch_size" not in self.config:
                raise ValueError(
                    "'fold_batch_size' is required when 'fold_prompt' is specified"
                )

            try:
                fold_template = Template(self.config["fold_prompt"])
                fold_template_vars = fold_template.environment.parse(
                    self.config["fold_prompt"]
                ).find_all(jinja2.nodes.Name)
                fold_template_var_names = {var.name for var in fold_template_vars}
                required_vars = {"inputs", "output"}
                if not required_vars.issubset(fold_template_var_names):
                    raise ValueError(
                        f"Fold template must include variables: {required_vars}. Current template includes: {fold_template_var_names}"
                    )
            except Exception as e:
                raise ValueError(f"Invalid Jinja2 template in 'fold_prompt': {str(e)}")

        # Check merge_prompt and merge_batch_size
        if "merge_prompt" in self.config:
            if "merge_batch_size" not in self.config:
                raise ValueError(
                    "'merge_batch_size' is required when 'merge_prompt' is specified"
                )

            try:
                merge_template = Template(self.config["merge_prompt"])
                merge_template_vars = merge_template.environment.parse(
                    self.config["merge_prompt"]
                ).find_all(jinja2.nodes.Name)
                merge_template_var_names = {var.name for var in merge_template_vars}
                if "outputs" not in merge_template_var_names:
                    raise ValueError(
                        "Merge template must include the 'outputs' variable"
                    )
            except Exception as e:
                raise ValueError(f"Invalid Jinja2 template in 'merge_prompt': {str(e)}")

        # Check if the model is specified (optional)
        if "model" in self.config and not isinstance(self.config["model"], str):
            raise TypeError("'model' in configuration must be a string")

        # Check if reduce_key is a string or a list of strings
        if not isinstance(self.config["reduce_key"], (str, list)):
            raise TypeError("'reduce_key' must be a string or a list of strings")
        if isinstance(self.config["reduce_key"], list):
            if not all(isinstance(key, str) for key in self.config["reduce_key"]):
                raise TypeError("All elements in 'reduce_key' list must be strings")

        # Check if input schema is provided and valid (optional)
        if "input" in self.config:
            if "schema" not in self.config["input"]:
                raise ValueError("Missing 'schema' in 'input' configuration")
            if not isinstance(self.config["input"]["schema"], dict):
                raise TypeError(
                    "'schema' in 'input' configuration must be a dictionary"
                )

        # Check if fold_batch_size and merge_batch_size are positive integers
        for key in ["fold_batch_size", "merge_batch_size"]:
            if key in self.config:
                if not isinstance(self.config[key], int) or self.config[key] <= 0:
                    raise ValueError(f"'{key}' must be a positive integer")

        if "value_sampling" in self.config:
            sampling = self.config["value_sampling"]
            if not isinstance(sampling, dict):
                raise TypeError("'value_sampling' must be a dictionary")

            if "enabled" not in sampling:
                raise ValueError(
                    "'enabled' is required in 'value_sampling' configuration"
                )
            if not isinstance(sampling["enabled"], bool):
                raise TypeError("'enabled' in 'value_sampling' must be a boolean")

            if sampling["enabled"]:
                if "sample_size" not in sampling:
                    raise ValueError(
                        "'sample_size' is required when value_sampling is enabled"
                    )
                if (
                    not isinstance(sampling["sample_size"], int)
                    or sampling["sample_size"] <= 0
                ):
                    raise ValueError("'sample_size' must be a positive integer")

                if "method" not in sampling:
                    raise ValueError(
                        "'method' is required when value_sampling is enabled"
                    )
                if sampling["method"] not in [
                    "random",
                    "first_n",
                    "cluster",
                    "sem_sim",
                ]:
                    raise ValueError(
                        "Invalid 'method'. Must be 'random', 'first_n', or 'embedding'"
                    )

                if sampling["method"] == "embedding":
                    if "embedding_model" not in sampling:
                        raise ValueError(
                            "'embedding_model' is required when using embedding-based sampling"
                        )
                    if "embedding_keys" not in sampling:
                        raise ValueError(
                            "'embedding_keys' is required when using embedding-based sampling"
                        )

        self.gleaning_check()

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Execute the reduce operation on the provided input data.

        This method groups the input data by the reduce key(s) while preserving input order, then
        processes each group using either parallel fold and merge, incremental reduce, or batch reduce strategies.

        Args:
            input_data (List[Dict]): The input data to process.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.
        """
        reduce_keys = self.config["reduce_key"]
        if isinstance(reduce_keys, str):
            reduce_keys = [reduce_keys]
        input_schema = self.config.get("input", {}).get("schema", {})

        # Check if we need to group everything into one group
        if reduce_keys == ["_all"] or reduce_keys == "_all":
            grouped_data = [("_all", input_data)]
        else:
            # Group the input data by the reduce key(s) while maintaining original order
            def get_group_key(item):
                return tuple(item[key] for key in reduce_keys)

            grouped_data = {}
            for item in input_data:
                key = get_group_key(item)
                if key not in grouped_data:
                    grouped_data[key] = []
                grouped_data[key].append(item)

            # Convert the grouped data to a list of tuples
            grouped_data = list(grouped_data.items())

        def process_group(
            key: Tuple, group_elems: List[Dict]
        ) -> Tuple[Optional[Dict], float]:
            if input_schema:
                group_list = [
                    {k: item[k] for k in input_schema.keys() if k in item}
                    for item in group_elems
                ]
            else:
                group_list = group_elems

            total_cost = 0.0

            # Apply value sampling if enabled
            value_sampling = self.config.get("value_sampling", {})
            if value_sampling.get("enabled", False):
                sample_size = min(value_sampling["sample_size"], len(group_list))
                method = value_sampling["method"]

                if method == "random":
                    group_sample = random.sample(group_list, sample_size)
                    group_sample.sort(key=lambda x: group_list.index(x))
                elif method == "first_n":
                    group_sample = group_list[:sample_size]
                elif method == "cluster":
                    group_sample, embedding_cost = self._cluster_based_sampling(
                        group_list, value_sampling, sample_size
                    )
                    group_sample.sort(key=lambda x: group_list.index(x))
                    total_cost += embedding_cost
                elif method == "sem_sim":
                    group_sample, embedding_cost = self._semantic_similarity_sampling(
                        key, group_list, value_sampling, sample_size
                    )
                    group_sample.sort(key=lambda x: group_list.index(x))
                    total_cost += embedding_cost

                group_list = group_sample

            # Only execute merge-based plans if associative = True
            if "merge_prompt" in self.config and self.config.get("associative", True):
                result, cost = self._parallel_fold_and_merge(key, group_list)
            elif "fold_prompt" in self.config:
                result, cost = self._incremental_reduce(key, group_list)
            else:
                result, cost = self._batch_reduce(key, group_list)

            total_cost += cost

            # Apply pass-through at the group level
            if (
                result is not None
                and self.config.get("pass_through", False)
                and group_elems
            ):
                for k, v in group_elems[0].items():
                    if k not in self.config["output"]["schema"] and k not in result:
                        result[k] = v

            return result, total_cost

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = [
                executor.submit(process_group, key, group)
                for key, group in grouped_data
            ]
            results = []
            total_cost = 0
            for future in rich_as_completed(
                futures,
                total=len(futures),
                desc="Processing reduce items",
                leave=True,
                console=self.console,
            ):
                output, item_cost = future.result()
                total_cost += item_cost
                if output is not None:
                    results.append(output)

        if self.config.get("persist_intermediates", False):
            for result in results:
                key = tuple(result[k] for k in self.config["reduce_key"])
                if key in self.intermediates:
                    result[f"_{self.config['name']}_intermediates"] = (
                        self.intermediates[key]
                    )

        return results, total_cost

    def _get_embeddings(
        self, items: List[Dict], value_sampling: Dict
    ) -> Tuple[List[List[float]], float]:
        embedding_model = value_sampling["embedding_model"]
        embedding_keys = value_sampling["embedding_keys"]
        if not embedding_keys:
            embedding_keys = list(items[0].keys())
        embeddings = []
        cost = 0
        batch_size = 1000
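        # Each item's text is built from its embedding keys and truncated to
        # 10k characters, presumably to stay within embedding input limits.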

        for i in range(0, len(items), batch_size):
            batch = items[i : i + batch_size]
            texts = [
                " ".join(str(item[key]) for key in embedding_keys if key in item)[
                    :10000
                ]
                for item in batch
            ]
            response = gen_embedding(embedding_model, texts)
            embeddings.extend([data["embedding"] for data in response["data"]])
            cost += completion_cost(response)

        return embeddings, cost

    def _cluster_based_sampling(
        self, group_list: List[Dict], value_sampling: Dict, sample_size: int
    ) -> Tuple[List[Dict], float]:
        embeddings, cost = self._get_embeddings(group_list, value_sampling)

        kmeans = KMeans(n_clusters=sample_size, random_state=42)
        cluster_labels = kmeans.fit_predict(embeddings)

        sampled_items = []
        for i in range(sample_size):
            cluster_items = [
                item for item, label in zip(group_list, cluster_labels) if label == i
            ]
            if cluster_items:
                sampled_items.append(random.choice(cluster_items))

        return sampled_items, cost

    def _semantic_similarity_sampling(
        self, key: Tuple, group_list: List[Dict], value_sampling: Dict, sample_size: int
    ) -> Tuple[List[Dict], float]:
        embedding_model = value_sampling["embedding_model"]
        query_text_template = Template(value_sampling["query_text"])
        query_text = query_text_template.render(
            reduce_key=dict(zip(self.config["reduce_key"], key))
        )

        embeddings, cost = self._get_embeddings(group_list, value_sampling)

        query_response = gen_embedding(embedding_model, [query_text])
        query_embedding = query_response["data"][0]["embedding"]
        cost += completion_cost(query_response)

        similarities = cosine_similarity([query_embedding], embeddings)[0]

        top_k_indices = np.argsort(similarities)[-sample_size:]
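        # argsort is ascending, so the last `sample_size` indices correspond
        # to the items most similar to the rendered query text.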

        return [group_list[i] for i in top_k_indices], cost

    def _parallel_fold_and_merge(
        self, key: Tuple, group_list: List[Dict]
    ) -> Tuple[Optional[Dict], float]:
        """
        Perform parallel folding and merging on a group of items.

        This method implements a strategy that combines parallel folding of input items
        and merging of intermediate results to efficiently process large groups. It works as follows:
        1. The input group is initially divided into smaller batches for efficient processing.
        2. The method performs an initial round of folding operations on these batches.
        3. After the first round of folds, a few merges are performed to estimate the merge runtime.
        4. Based on the estimated merge runtime and observed fold runtime, it calculates the optimal number of parallel folds. Subsequent rounds of folding are then performed concurrently, with the number of parallel folds determined by the runtime estimates.
        5. The folding process repeats in rounds, progressively reducing the number of items to be processed.
        6. Once all folding operations are complete, the method recursively performs final merges on the fold results to combine them into a final result.
        7. Throughout this process, the method may adjust the number of parallel folds based on updated performance metrics (i.e., fold and merge runtimes) to maintain efficiency.

        Args:
            key (Tuple): The reduce key tuple for the group.
            group_list (List[Dict]): The list of items in the group to be processed.

        Returns:
            Tuple[Optional[Dict], float]: A tuple containing the final merged result (or None if processing failed)
            and the total cost of the operation.
        """
        fold_batch_size = self.config["fold_batch_size"]
        merge_batch_size = self.config["merge_batch_size"]
        total_cost = 0

        def calculate_num_parallel_folds():
            fold_time, fold_default = self.get_fold_time()
            merge_time, merge_default = self.get_merge_time()
            num_group_items = len(group_list)
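            # Heuristic: run more folds concurrently when folding dominates
            # merging. E.g. (illustrative numbers) fold_time=2s, merge_time=1s,
            # 100 items, fold_batch_size=10, merge_batch_size=4 gives
            # int((2 * 100 * ln 4) / (10 * 1)) = 27 parallel folds.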
            return (
                max(
                    1,
                    int(
                        (fold_time * num_group_items * math.log(merge_batch_size))
                        / (fold_batch_size * merge_time)
                    ),
                ),
                fold_default or merge_default,
            )

        num_parallel_folds, used_default_times = calculate_num_parallel_folds()
        fold_results = []
        remaining_items = group_list

        if self.config.get("persist_intermediates", False):
            self.intermediates[key] = []
            iter_count = 0

        # Parallel folding and merging
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            while remaining_items:
                # Folding phase
                fold_futures = []
                for i in range(min(num_parallel_folds, len(remaining_items))):
                    batch = remaining_items[:fold_batch_size]
                    remaining_items = remaining_items[fold_batch_size:]
                    current_output = fold_results[i] if i < len(fold_results) else None
                    fold_futures.append(
                        executor.submit(
                            self._increment_fold, key, batch, current_output
                        )
                    )

                new_fold_results = []
                for future in as_completed(fold_futures):
                    result, cost = future.result()
                    total_cost += cost
                    if result is not None:
                        new_fold_results.append(result)
                        if self.config.get("persist_intermediates", False):
                            self.intermediates[key].append(
                                {
                                    "iter": iter_count,
                                    "intermediate": result,
                                    "scratchpad": result["updated_scratchpad"],
                                }
                            )
                            iter_count += 1

                # Update fold_results with new results
                fold_results = new_fold_results + fold_results[len(new_fold_results) :]
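                # (The first len(new_fold_results) accumulators are replaced;
                # any remaining accumulators carry over to the next round.)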

                # Single pass merging phase
                if (
                    len(self.merge_times) < self.min_samples
                    and len(fold_results) >= merge_batch_size
                ):
                    merge_futures = []
                    for i in range(0, len(fold_results), merge_batch_size):
                        batch = fold_results[i : i + merge_batch_size]
                        merge_futures.append(
                            executor.submit(self._merge_results, key, batch)
                        )

                    new_results = []
                    for future in as_completed(merge_futures):
                        result, cost = future.result()
                        total_cost += cost
                        if result is not None:
                            new_results.append(result)
                            if self.config.get("persist_intermediates", False):
                                self.intermediates[key].append(
                                    {
                                        "iter": iter_count,
                                        "intermediate": result,
                                        "scratchpad": None,
                                    }
                                )
                                iter_count += 1

                    fold_results = new_results

                # Recalculate num_parallel_folds if we used default times
                if used_default_times:
                    new_num_parallel_folds, used_default_times = (
                        calculate_num_parallel_folds()
                    )
                    if not used_default_times:
                        self.console.log(
                            f"Recalculated num_parallel_folds from {num_parallel_folds} to {new_num_parallel_folds}"
                        )
                        num_parallel_folds = new_num_parallel_folds

        # Final merging if needed
        while len(fold_results) > 1:
            self.console.log(f"Finished folding! Merging {len(fold_results)} items.")
            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
                merge_futures = []
                for i in range(0, len(fold_results), merge_batch_size):
                    batch = fold_results[i : i + merge_batch_size]
                    merge_futures.append(
                        executor.submit(self._merge_results, key, batch)
                    )

                new_results = []
                for future in as_completed(merge_futures):
                    result, cost = future.result()
                    total_cost += cost
                    if result is not None:
                        new_results.append(result)
                        if self.config.get("persist_intermediates", False):
                            self.intermediates[key].append(
                                {
                                    "iter": iter_count,
                                    "intermediate": result,
                                    "scratchpad": None,
                                }
                            )
                            iter_count += 1

                fold_results = new_results

        return (fold_results[0], total_cost) if fold_results else (None, total_cost)

    def _incremental_reduce(
        self, key: Tuple, group_list: List[Dict]
    ) -> Tuple[Optional[Dict], float]:
        """
        Perform an incremental reduce operation on a group of items.

        This method processes the group in batches, incrementally folding the results.

        Args:
            key (Tuple): The reduce key tuple for the group.
            group_list (List[Dict]): The list of items in the group to be processed.

        Returns:
            Tuple[Optional[Dict], float]: A tuple containing the final reduced result (or None if processing failed)
            and the total cost of the operation.
        """
        fold_batch_size = self.config["fold_batch_size"]
        total_cost = 0
        current_output = None

        # Calculate and log the number of folds to be performed
        num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size
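        # Ceiling division, e.g. 23 items with fold_batch_size=5 -> 5 folds.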

        scratchpad = ""
        if self.config.get("persist_intermediates", False):
            self.intermediates[key] = []
            iter_count = 0

        for i in range(0, len(group_list), fold_batch_size):
            # Log the current iteration and total number of folds
            current_fold = i // fold_batch_size + 1
            if self.config.get("verbose", False):
                self.console.log(
                    f"Processing fold {current_fold} of {num_folds} for group with key {key}"
                )
            batch = group_list[i : i + fold_batch_size]

            folded_output, fold_cost = self._increment_fold(
                key, batch, current_output, scratchpad
            )
            total_cost += fold_cost

            if folded_output is None:
                continue

            if self.config.get("persist_intermediates", False):
                self.intermediates[key].append(
                    {
                        "iter": iter_count,
                        "intermediate": folded_output,
                        "scratchpad": folded_output["updated_scratchpad"],
                    }
                )
                iter_count += 1

            # Pop off updated_scratchpad
            if "updated_scratchpad" in folded_output:
                scratchpad = folded_output["updated_scratchpad"]
                if self.config.get("verbose", False):
                    self.console.log(
                        f"Updated scratchpad for fold {current_fold}: {scratchpad}"
                    )
                del folded_output["updated_scratchpad"]

            current_output = folded_output

        return current_output, total_cost

    def _increment_fold(
        self,
        key: Tuple,
        batch: List[Dict],
        current_output: Optional[Dict],
        scratchpad: Optional[str] = None,
    ) -> Tuple[Optional[Dict], float]:
        """
        Perform an incremental fold operation on a batch of items.

        This method folds a batch of items into the current output using the fold prompt.

        Args:
            key (Tuple): The reduce key tuple for the group.
            batch (List[Dict]): The batch of items to be folded.
            current_output (Optional[Dict]): The current accumulated output, if any.
            scratchpad (Optional[str]): The scratchpad to use for the fold operation.
        Returns:
            Tuple[Optional[Dict], float]: A tuple containing the folded output (or None if processing failed)
            and the cost of the fold operation.
        """
        if current_output is None:
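            # No accumulator yet for this group: seed the fold chain with a
            # plain batch reduce over the first batch.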
            return self._batch_reduce(key, batch, scratchpad)

        start_time = time.time()
        fold_prompt_template = Template(self.config["fold_prompt"])
        fold_prompt = fold_prompt_template.render(
            inputs=batch,
            output=current_output,
            reduce_key=dict(zip(self.config["reduce_key"], key)),
        )
        response = call_llm(
            self.config.get("model", self.default_model),
            "reduce",
            [{"role": "user", "content": fold_prompt}],
            self.config["output"]["schema"],
            scratchpad=scratchpad,
            console=self.console,
        )
        folded_output = parse_llm_response(response)[0]

        folded_output.update(dict(zip(self.config["reduce_key"], key)))
        fold_cost = completion_cost(response)
        end_time = time.time()
        self._update_fold_time(end_time - start_time)

        if validate_output(self.config, folded_output, self.console):
            return folded_output, fold_cost
        return None, fold_cost

    def _merge_results(
        self, key: Tuple, outputs: List[Dict]
    ) -> Tuple[Optional[Dict], float]:
        """
        Merge multiple outputs into a single result.

        This method merges a list of outputs using the merge prompt.

        Args:
            key (Tuple): The reduce key tuple for the group.
            outputs (List[Dict]): The list of outputs to be merged.

        Returns:
            Tuple[Optional[Dict], float]: A tuple containing the merged output (or None if processing failed)
            and the cost of the merge operation.
        """
        start_time = time.time()
        merge_prompt_template = Template(self.config["merge_prompt"])
        merge_prompt = merge_prompt_template.render(
            outputs=outputs, reduce_key=dict(zip(self.config["reduce_key"], key))
        )
        response = call_llm(
            self.config.get("model", self.default_model),
            "merge",
            [{"role": "user", "content": merge_prompt}],
            self.config["output"]["schema"],
            console=self.console,
        )
        merged_output = parse_llm_response(response)[0]
        merged_output.update(dict(zip(self.config["reduce_key"], key)))
        merge_cost = completion_cost(response)
        end_time = time.time()
        self._update_merge_time(end_time - start_time)

        if validate_output(self.config, merged_output, self.console):
            return merged_output, merge_cost
        return None, merge_cost

    def get_fold_time(self) -> Tuple[float, bool]:
        """
        Get the average fold time or a default value.

        Returns:
            Tuple[float, bool]: A tuple containing the average fold time (or default) and a boolean
            indicating whether the default value was used.
        """
        if "fold_time" in self.config:
            return self.config["fold_time"], False
        with self.lock:
            if len(self.fold_times) >= self.min_samples:
                return sum(self.fold_times) / len(self.fold_times), False
        return 1.0, True  # Default to 1 second if no data is available

    def get_merge_time(self) -> Tuple[float, bool]:
        """
        Get the average merge time or a default value.

        Returns:
            Tuple[float, bool]: A tuple containing the average merge time (or default) and a boolean
            indicating whether the default value was used.
        """
        if "merge_time" in self.config:
            return self.config["merge_time"], False
        with self.lock:
            if len(self.merge_times) >= self.min_samples:
                return sum(self.merge_times) / len(self.merge_times), False
        return 1.0, True  # Default to 1 second if no data is available

    def _update_fold_time(self, time: float) -> None:
        """
        Update the fold time statistics.

        Args:
            time (float): The time taken for a fold operation.
        """
        with self.lock:
            self.fold_times.append(time)

    def _update_merge_time(self, time: float) -> None:
        """
        Update the merge time statistics.

        Args:
            time (float): The time taken for a merge operation.
        """
        with self.lock:
            self.merge_times.append(time)

    def _batch_reduce(
        self, key: Tuple, group_list: List[Dict], scratchpad: Optional[str] = None
    ) -> Tuple[Optional[Dict], float]:
        """
        Perform a batch reduce operation on a group of items.

        This method reduces a group of items into a single output using the reduce prompt.

        Args:
            key (Tuple): The reduce key tuple for the group.
            group_list (List[Dict]): The list of items to be reduced.
            scratchpad (Optional[str]): The scratchpad to use for the reduce operation.
        Returns:
            Tuple[Optional[Dict], float]: A tuple containing the reduced output (or None if processing failed)
            and the cost of the reduce operation.
        """
        prompt_template = Template(self.config["prompt"])
        prompt = prompt_template.render(
            reduce_key=dict(zip(self.config["reduce_key"], key)), inputs=group_list
        )
        item_cost = 0

        if "gleaning" in self.config:
            response, gleaning_cost = call_llm_with_gleaning(
                self.config.get("model", self.default_model),
                "reduce",
                [{"role": "user", "content": prompt}],
                self.config["output"]["schema"],
                self.config["gleaning"]["validation_prompt"],
                self.config["gleaning"]["num_rounds"],
                console=self.console,
            )
            item_cost += gleaning_cost
        else:
            response = call_llm(
                self.config.get("model", self.default_model),
                "reduce",
                [{"role": "user", "content": prompt}],
                self.config["output"]["schema"],
                console=self.console,
                scratchpad=scratchpad,
            )

        item_cost += completion_cost(response)

        output = parse_llm_response(response)[0]
        output.update(dict(zip(self.config["reduce_key"], key)))

        if validate_output(self.config, output, self.console):
            return output, item_cost
        return None, item_cost

__init__(*args, **kwargs)

Initialize the ReduceOperation.

Parameters:

    *args: Variable length argument list. Default: ()
    **kwargs: Arbitrary keyword arguments. Default: {}

execute(input_data)

Execute the reduce operation on the provided input data.

This method groups the input data by the reduce key(s) while preserving input order, then processes each group using either parallel fold and merge, incremental reduce, or batch reduce strategies.

Parameters:

    input_data (List[Dict]): The input data to process. Required.

Returns:

    Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

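As a usage sketch (the operation instance and records here are illustrative), records sharing the same reduce key values are grouped under one key tuple before reduction:

input_data = [
    {"category": "a", "text": "first"},
    {"category": "b", "text": "second"},
    {"category": "a", "text": "third"},
]
# The groups become [(("a",), [two items]), (("b",), [one item])]; each
# group is reduced and the key values are written back onto its result:
# results, total_cost = op.execute(input_data)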

get_fold_time()

Get the average fold time or a default value.

Returns:

    Tuple[float, bool]: A tuple containing the average fold time (or default) and a boolean indicating whether the default value was used.


get_merge_time()

Get the average merge time or a default value.

Returns:

    Tuple[float, bool]: A tuple containing the average merge time (or default) and a boolean indicating whether the default value was used.

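Both timing getters check the configuration for an explicit override before falling back to sampled averages or the 1-second default. A sketch (these keys are the ones read by the source; the values are illustrative):

config["fold_time"] = 2.5     # seconds; bypasses the rolling fold average
config["merge_time"] = 0.8    # seconds; bypasses the rolling merge average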

syntax_check()

Perform comprehensive syntax checks on the configuration of the ReduceOperation.

This method validates the presence and correctness of all required configuration keys and Jinja2 templates, and ensures the correct structure and types of the entire configuration.

The method performs the following checks:

1. Verifies the presence of all required keys in the configuration.
2. Validates the structure and content of the 'output' configuration, including its 'schema'.
3. Checks if the main 'prompt' is a valid Jinja2 template and contains the required 'inputs' variable.
4. If 'merge_prompt' is specified, ensures that 'fold_prompt' is also present.
5. If 'fold_prompt' is present, verifies the existence of 'fold_batch_size'.
6. Validates the 'fold_prompt' as a Jinja2 template with required variables 'inputs' and 'output'.
7. If present, checks 'merge_prompt' as a valid Jinja2 template with required 'outputs' variable.
8. Verifies types of various configuration inputs (e.g., 'fold_batch_size' as int).
9. Checks for the presence and validity of optional configurations like 'model'.

Raises:

    ValueError: If any required configuration is missing, if templates are invalid or missing required variables, or if any other configuration aspect is incorrect or inconsistent.
    TypeError: If any configuration value has an incorrect type, such as 'schema' not being a dict or 'fold_batch_size' not being an integer.

Source code in docetl/operations/reduce.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def syntax_check(self) -> None:
    """
    Perform comprehensive syntax checks on the configuration of the ReduceOperation.

    This method validates the presence and correctness of all required configuration keys, Jinja2 templates, and ensures the correct
    structure and types of the entire configuration.

    The method performs the following checks:
    1. Verifies the presence of all required keys in the configuration.
    2. Validates the structure and content of the 'output' configuration, including its 'schema'.
    3. Checks if the main 'prompt' is a valid Jinja2 template and contains the required 'inputs' variable.
    4. If 'merge_prompt' is specified, ensures that 'fold_prompt' is also present.
    5. If 'fold_prompt' is present, verifies the existence of 'fold_batch_size'.
    6. Validates the 'fold_prompt' as a Jinja2 template with required variables 'inputs' and 'output'.
    7. If present, checks 'merge_prompt' as a valid Jinja2 template with required 'outputs' variable.
    8. Verifies types of various configuration inputs (e.g., 'fold_batch_size' as int).
    9. Checks for the presence and validity of optional configurations like 'model'.

    Raises:
        ValueError: If any required configuration is missing, if templates are invalid or missing required
                    variables, or if any other configuration aspect is incorrect or inconsistent.
        TypeError: If any configuration value has an incorrect type, such as 'schema' not being a dict
                   or 'fold_batch_size' not being an integer.
    """
    required_keys = ["reduce_key", "prompt", "output"]
    for key in required_keys:
        if key not in self.config:
            raise ValueError(
                f"Missing required key '{key}' in ReduceOperation configuration"
            )

    if "schema" not in self.config["output"]:
        raise ValueError("Missing 'schema' in 'output' configuration")

    if not isinstance(self.config["output"]["schema"], dict):
        raise TypeError("'schema' in 'output' configuration must be a dictionary")

    if not self.config["output"]["schema"]:
        raise ValueError("'schema' in 'output' configuration cannot be empty")

    # Check if the prompt is a valid Jinja2 template
    try:
        template = Template(self.config["prompt"])
        template_vars = template.environment.parse(self.config["prompt"]).find_all(
            jinja2.nodes.Name
        )
        template_var_names = {var.name for var in template_vars}
        if "inputs" not in template_var_names:
            raise ValueError("Template must include the 'inputs' variable")
    except Exception as e:
        raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

    # Check if fold_prompt is a valid Jinja2 template (now required if merge exists)
    if "merge_prompt" in self.config:
        if "fold_prompt" not in self.config:
            raise ValueError(
                "'fold_prompt' is required when 'merge_prompt' is specified"
            )

    if "fold_prompt" in self.config:
        if "fold_batch_size" not in self.config:
            raise ValueError(
                "'fold_batch_size' is required when 'fold_prompt' is specified"
            )

        try:
            fold_template = Template(self.config["fold_prompt"])
            fold_template_vars = fold_template.environment.parse(
                self.config["fold_prompt"]
            ).find_all(jinja2.nodes.Name)
            fold_template_var_names = {var.name for var in fold_template_vars}
            required_vars = {"inputs", "output"}
            if not required_vars.issubset(fold_template_var_names):
                raise ValueError(
                    f"Fold template must include variables: {required_vars}. Current template includes: {fold_template_var_names}"
                )
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template in 'fold_prompt': {str(e)}")

    # Check merge_prompt and merge_batch_size
    if "merge_prompt" in self.config:
        if "merge_batch_size" not in self.config:
            raise ValueError(
                "'merge_batch_size' is required when 'merge_prompt' is specified"
            )

        try:
            merge_template = Template(self.config["merge_prompt"])
            merge_template_vars = merge_template.environment.parse(
                self.config["merge_prompt"]
            ).find_all(jinja2.nodes.Name)
            merge_template_var_names = {var.name for var in merge_template_vars}
            if "outputs" not in merge_template_var_names:
                raise ValueError(
                    "Merge template must include the 'outputs' variable"
                )
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template in 'merge_prompt': {str(e)}")

    # Check if the model is specified (optional)
    if "model" in self.config and not isinstance(self.config["model"], str):
        raise TypeError("'model' in configuration must be a string")

    # Check if reduce_key is a string or a list of strings
    if not isinstance(self.config["reduce_key"], (str, list)):
        raise TypeError("'reduce_key' must be a string or a list of strings")
    if isinstance(self.config["reduce_key"], list):
        if not all(isinstance(key, str) for key in self.config["reduce_key"]):
            raise TypeError("All elements in 'reduce_key' list must be strings")

    # Check if input schema is provided and valid (optional)
    if "input" in self.config:
        if "schema" not in self.config["input"]:
            raise ValueError("Missing 'schema' in 'input' configuration")
        if not isinstance(self.config["input"]["schema"], dict):
            raise TypeError(
                "'schema' in 'input' configuration must be a dictionary"
            )

    # Check if fold_batch_size and merge_batch_size are positive integers
    for key in ["fold_batch_size", "merge_batch_size"]:
        if key in self.config:
            if not isinstance(self.config[key], int) or self.config[key] <= 0:
                raise ValueError(f"'{key}' must be a positive integer")

    if "value_sampling" in self.config:
        sampling = self.config["value_sampling"]
        if not isinstance(sampling, dict):
            raise TypeError("'value_sampling' must be a dictionary")

        if "enabled" not in sampling:
            raise ValueError(
                "'enabled' is required in 'value_sampling' configuration"
            )
        if not isinstance(sampling["enabled"], bool):
            raise TypeError("'enabled' in 'value_sampling' must be a boolean")

        if sampling["enabled"]:
            if "sample_size" not in sampling:
                raise ValueError(
                    "'sample_size' is required when value_sampling is enabled"
                )
            if (
                not isinstance(sampling["sample_size"], int)
                or sampling["sample_size"] <= 0
            ):
                raise ValueError("'sample_size' must be a positive integer")

            if "method" not in sampling:
                raise ValueError(
                    "'method' is required when value_sampling is enabled"
                )
            if sampling["method"] not in [
                "random",
                "first_n",
                "cluster",
                "sem_sim",
            ]:
                raise ValueError(
                    "Invalid 'method'. Must be 'random', 'first_n', 'cluster', or 'sem_sim'"
                )

            # 'cluster' and 'sem_sim' are the embedding-based sampling methods
            if sampling["method"] in ("cluster", "sem_sim"):
                if "embedding_model" not in sampling:
                    raise ValueError(
                        "'embedding_model' is required when using embedding-based sampling"
                    )
                if "embedding_keys" not in sampling:
                    raise ValueError(
                        "'embedding_keys' is required when using embedding-based sampling"
                    )

    self.gleaning_check()

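To make these constraints concrete, here is a hypothetical reduce configuration sketch that would pass the checks above. The prompts, key names, and embedding model are illustrative assumptions, not values from the source:

```python
# Hypothetical config sketch: fold_prompt references both `inputs` and
# `output`, merge_prompt references `outputs`, and both batch sizes are
# positive integers, as the checks above require.
config = {
    "reduce_key": ["category"],          # a string or a list of strings
    "fold_prompt": "Fold {{ inputs }} into the running {{ output }}.",
    "fold_batch_size": 50,
    "merge_prompt": "Merge these partial results: {{ outputs }}.",
    "merge_batch_size": 5,
    "input": {"schema": {"text": "string"}},   # optional
    "value_sampling": {
        "enabled": True,
        "sample_size": 100,
        "method": "sem_sim",             # embedding-based, so the two keys below are required
        "embedding_model": "text-embedding-3-small",  # assumed model name
        "embedding_keys": ["text"],
    },
}
```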
docetl.operations.map.ParallelMapOperation

Bases: BaseOperation

Source code in docetl/operations/map.py
class ParallelMapOperation(BaseOperation):
    def syntax_check(self) -> None:
        """
        Checks the configuration of the ParallelMapOperation for required keys and valid structure.

        Raises:
            ValueError: If required keys are missing or if the configuration structure is invalid.
            TypeError: If the configuration values have incorrect types.
        """
        if "drop_keys" in self.config:
            if not isinstance(self.config["drop_keys"], list):
                raise TypeError(
                    "'drop_keys' in configuration must be a list of strings"
                )
            for key in self.config["drop_keys"]:
                if not isinstance(key, str):
                    raise TypeError("All items in 'drop_keys' must be strings")
        else:
            if "prompts" not in self.config:
                raise ValueError(
                    "If 'drop_keys' is not specified, 'prompts' must be present in the configuration"
                )

        if "prompts" in self.config:
            if not isinstance(self.config["prompts"], list):
                raise ValueError(
                    "ParallelMapOperation requires a 'prompts' list in the configuration"
                )

            if not self.config["prompts"]:
                raise ValueError("The 'prompts' list cannot be empty")

            for i, prompt_config in enumerate(self.config["prompts"]):
                if not isinstance(prompt_config, dict):
                    raise TypeError(f"Prompt configuration {i} must be a dictionary")

                required_keys = ["name", "prompt", "output_keys"]
                for key in required_keys:
                    if key not in prompt_config:
                        raise ValueError(
                            f"Missing required key '{key}' in prompt configuration {i}"
                        )

                if not isinstance(prompt_config["name"], str):
                    raise TypeError(
                        f"'name' in prompt configuration {i} must be a string"
                    )

                if not isinstance(prompt_config["prompt"], str):
                    raise TypeError(
                        f"'prompt' in prompt configuration {i} must be a string"
                    )

                if not isinstance(prompt_config["output_keys"], list):
                    raise TypeError(
                        f"'output_keys' in prompt configuration {i} must be a list"
                    )

                if not prompt_config["output_keys"]:
                    raise ValueError(
                        f"'output_keys' list in prompt configuration {i} cannot be empty"
                    )

                # Check if the prompt is a valid Jinja2 template
                try:
                    Template(prompt_config["prompt"])
                except Exception as e:
                    raise ValueError(
                        f"Invalid Jinja2 template in prompt configuration {i}: {str(e)}"
                    )

                # Check if the model is specified (optional)
                if "model" in prompt_config and not isinstance(
                    prompt_config["model"], str
                ):
                    raise TypeError(
                        f"'model' in prompt configuration {i} must be a string"
                    )

            # Check if all output schema keys are covered by the prompts
            output_schema = self.config["output"]["schema"]
            output_keys_covered = set()
            for prompt_config in self.config["prompts"]:
                output_keys_covered.update(prompt_config["output_keys"])

            missing_keys = set(output_schema.keys()) - output_keys_covered
            if missing_keys:
                raise ValueError(
                    f"The following output schema keys are not covered by any prompt: {missing_keys}"
                )

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Executes the parallel map operation on the provided input data.

        Args:
            input_data (List[Dict]): The input data to process.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

        This method performs the following steps:
        1. If prompts are specified, it processes each input item using multiple prompts in parallel
        2. Aggregates results from different prompts for each input item
        3. Validates the combined output for each item
        4. If drop_keys is specified, it drops the specified keys from each document
        5. Calculates total cost of the operation
        """
        results = {}
        total_cost = 0
        output_schema = self.config.get("output", {}).get("schema", {})

        # Check if there's no prompt and only drop_keys
        if "prompts" not in self.config and "drop_keys" in self.config:
            # If only drop_keys is specified, simply drop the keys and return
            dropped_results = []
            for item in input_data:
                new_item = {
                    k: v for k, v in item.items() if k not in self.config["drop_keys"]
                }
                dropped_results.append(new_item)
            return dropped_results, 0.0  # Return the modified data with no cost

        def process_prompt(item, prompt_config):
            prompt_template = Template(prompt_config["prompt"])
            prompt = prompt_template.render(input=item)
            local_output_schema = {
                key: output_schema[key] for key in prompt_config["output_keys"]
            }

            # If there are tools, we need to pass in the tools
            response = call_llm(
                prompt_config.get("model", self.default_model),
                "parallel_map",
                [{"role": "user", "content": prompt}],
                local_output_schema,
                tools=prompt_config.get("tools", None),
                console=self.console,
            )
            output = parse_llm_response(
                response, tools=prompt_config.get("tools", None)
            )[0]
            return output, completion_cost(response)

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            if "prompts" in self.config:
                # Create all futures at once
                all_futures = [
                    executor.submit(process_prompt, item, prompt_config)
                    for item in input_data
                    for prompt_config in self.config["prompts"]
                ]

                # Process results in order
                pbar = RichLoopBar(
                    range(len(all_futures)),
                    desc="Processing parallel map items",
                    console=self.console,
                )
                for i in pbar:
                    future = all_futures[i]
                    output, cost = future.result()
                    total_cost += cost

                    # Determine which item this future corresponds to
                    item_index = i // len(self.config["prompts"])
                    prompt_index = i % len(self.config["prompts"])

                    # Initialize or update the item_result
                    if prompt_index == 0:
                        item_result = input_data[item_index].copy()
                        results[item_index] = item_result

                    # Fetch the item_result
                    item_result = results[item_index]

                    # Update the item_result with the output
                    item_result.update(output)

                    pbar.update(i)
            else:
                results = {i: item.copy() for i, item in enumerate(input_data)}

        # Apply drop_keys if specified
        if "drop_keys" in self.config:
            drop_keys = self.config["drop_keys"]
            for item in results.values():
                for key in drop_keys:
                    item.pop(key, None)

        # Return the results in order
        return [results[i] for i in range(len(input_data)) if i in results], total_cost

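To make the coverage check concrete, here is a hypothetical prompts configuration sketch (field values are illustrative, not from the source). Every key in output.schema must appear in some prompt's output_keys, or syntax_check raises a ValueError:

```python
# Hypothetical config sketch: the two prompts together cover both output
# schema keys, so the coverage check at the end of syntax_check passes.
config = {
    "prompts": [
        {
            "name": "extract_title",
            "prompt": "Extract the title from: {{ input.text }}",
            "output_keys": ["title"],
        },
        {
            "name": "classify_sentiment",
            "prompt": "Classify the sentiment of: {{ input.text }}",
            "output_keys": ["sentiment"],
            "model": "gpt-4o-mini",  # optional per-prompt model override
        },
    ],
    "output": {"schema": {"title": "string", "sentiment": "string"}},
}
```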
execute(input_data)

Executes the parallel map operation on the provided input data.

Parameters:

    input_data (List[Dict]): The input data to process. Required.

Returns:

    Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

This method performs the following steps:

1. If prompts are specified, it processes each input item using multiple prompts in parallel
2. Aggregates results from different prompts for each input item
3. Validates the combined output for each item
4. If drop_keys is specified, it drops the specified keys from each document
5. Calculates total cost of the operation

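Because the futures are created item-major (all prompts for item 0, then item 1, and so on), the loop recovers each future's item and prompt with integer division and modulo. A quick sanity check of that arithmetic:

```python
# With 3 input items and 2 prompts, flattened future index 5 is the
# second prompt of the third item.
num_prompts = 2
i = 5
item_index = i // num_prompts
prompt_index = i % num_prompts
assert (item_index, prompt_index) == (2, 1)
```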
Source code in docetl/operations/map.py
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
    """
    Executes the parallel map operation on the provided input data.

    Args:
        input_data (List[Dict]): The input data to process.

    Returns:
        Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.

    This method performs the following steps:
    1. If prompts are specified, it processes each input item using multiple prompts in parallel
    2. Aggregates results from different prompts for each input item
    3. Validates the combined output for each item
    4. If drop_keys is specified, it drops the specified keys from each document
    5. Calculates total cost of the operation
    """
    results = {}
    total_cost = 0
    output_schema = self.config.get("output", {}).get("schema", {})

    # Check if there's no prompt and only drop_keys
    if "prompts" not in self.config and "drop_keys" in self.config:
        # If only drop_keys is specified, simply drop the keys and return
        dropped_results = []
        for item in input_data:
            new_item = {
                k: v for k, v in item.items() if k not in self.config["drop_keys"]
            }
            dropped_results.append(new_item)
        return dropped_results, 0.0  # Return the modified data with no cost

    def process_prompt(item, prompt_config):
        prompt_template = Template(prompt_config["prompt"])
        prompt = prompt_template.render(input=item)
        local_output_schema = {
            key: output_schema[key] for key in prompt_config["output_keys"]
        }

        # If there are tools, we need to pass in the tools
        response = call_llm(
            prompt_config.get("model", self.default_model),
            "parallel_map",
            [{"role": "user", "content": prompt}],
            local_output_schema,
            tools=prompt_config.get("tools", None),
            console=self.console,
        )
        output = parse_llm_response(
            response, tools=prompt_config.get("tools", None)
        )[0]
        return output, completion_cost(response)

    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        if "prompts" in self.config:
            # Create all futures at once
            all_futures = [
                executor.submit(process_prompt, item, prompt_config)
                for item in input_data
                for prompt_config in self.config["prompts"]
            ]

            # Process results in order
            pbar = RichLoopBar(
                range(len(all_futures)),
                desc="Processing parallel map items",
                console=self.console,
            )
            for i in pbar:
                future = all_futures[i]
                output, cost = future.result()
                total_cost += cost

                # Determine which item this future corresponds to
                item_index = i // len(self.config["prompts"])
                prompt_index = i % len(self.config["prompts"])

                # Initialize or update the item_result
                if prompt_index == 0:
                    item_result = input_data[item_index].copy()
                    results[item_index] = item_result

                # Fetch the item_result
                item_result = results[item_index]

                # Update the item_result with the output
                item_result.update(output)

                pbar.update(i)
        else:
            results = {i: item.copy() for i, item in enumerate(input_data)}

    # Apply drop_keys if specified
    if "drop_keys" in self.config:
        drop_keys = self.config["drop_keys"]
        for item in results.values():
            for key in drop_keys:
                item.pop(key, None)

    # Return the results in order
    return [results[i] for i in range(len(input_data)) if i in results], total_cost

syntax_check()

Checks the configuration of the ParallelMapOperation for required keys and valid structure.

Raises:

    ValueError: If required keys are missing or if the configuration structure is invalid.
    TypeError: If the configuration values have incorrect types.

Source code in docetl/operations/map.py
def syntax_check(self) -> None:
    """
    Checks the configuration of the ParallelMapOperation for required keys and valid structure.

    Raises:
        ValueError: If required keys are missing or if the configuration structure is invalid.
        TypeError: If the configuration values have incorrect types.
    """
    if "drop_keys" in self.config:
        if not isinstance(self.config["drop_keys"], list):
            raise TypeError(
                "'drop_keys' in configuration must be a list of strings"
            )
        for key in self.config["drop_keys"]:
            if not isinstance(key, str):
                raise TypeError("All items in 'drop_keys' must be strings")
    else:
        if "prompts" not in self.config:
            raise ValueError(
                "If 'drop_keys' is not specified, 'prompts' must be present in the configuration"
            )

    if "prompts" in self.config:
        if not isinstance(self.config["prompts"], list):
            raise ValueError(
                "ParallelMapOperation requires a 'prompts' list in the configuration"
            )

        if not self.config["prompts"]:
            raise ValueError("The 'prompts' list cannot be empty")

        for i, prompt_config in enumerate(self.config["prompts"]):
            if not isinstance(prompt_config, dict):
                raise TypeError(f"Prompt configuration {i} must be a dictionary")

            required_keys = ["name", "prompt", "output_keys"]
            for key in required_keys:
                if key not in prompt_config:
                    raise ValueError(
                        f"Missing required key '{key}' in prompt configuration {i}"
                    )

            if not isinstance(prompt_config["name"], str):
                raise TypeError(
                    f"'name' in prompt configuration {i} must be a string"
                )

            if not isinstance(prompt_config["prompt"], str):
                raise TypeError(
                    f"'prompt' in prompt configuration {i} must be a string"
                )

            if not isinstance(prompt_config["output_keys"], list):
                raise TypeError(
                    f"'output_keys' in prompt configuration {i} must be a list"
                )

            if not prompt_config["output_keys"]:
                raise ValueError(
                    f"'output_keys' list in prompt configuration {i} cannot be empty"
                )

            # Check if the prompt is a valid Jinja2 template
            try:
                Template(prompt_config["prompt"])
            except Exception as e:
                raise ValueError(
                    f"Invalid Jinja2 template in prompt configuration {i}: {str(e)}"
                )

            # Check if the model is specified (optional)
            if "model" in prompt_config and not isinstance(
                prompt_config["model"], str
            ):
                raise TypeError(
                    f"'model' in prompt configuration {i} must be a string"
                )

        # Check if all output schema keys are covered by the prompts
        output_schema = self.config["output"]["schema"]
        output_keys_covered = set()
        for prompt_config in self.config["prompts"]:
            output_keys_covered.update(prompt_config["output_keys"])

        missing_keys = set(output_schema.keys()) - output_keys_covered
        if missing_keys:
            raise ValueError(
                f"The following output schema keys are not covered by any prompt: {missing_keys}"
            )

docetl.operations.filter.FilterOperation

Bases: BaseOperation

Source code in docetl/operations/filter.py
class FilterOperation(BaseOperation):
    def syntax_check(self) -> None:
        """
        Checks the configuration of the FilterOperation for required keys and valid structure.

        Raises:
            ValueError: If required keys are missing or if the output schema structure is invalid.
            TypeError: If the schema in the output configuration is not a dictionary or if the schema value is not of type bool.

        This method checks for the following:
        - Presence of required keys: 'prompt' and 'output'
        - Presence of 'schema' in the 'output' configuration
        - The 'schema' is a non-empty dictionary with exactly one key-value pair
        - The value in the schema is of type bool
        """
        required_keys = ["prompt", "output"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in FilterOperation configuration"
                )

        if "schema" not in self.config["output"]:
            raise ValueError("Missing 'schema' in 'output' configuration")

        if not isinstance(self.config["output"]["schema"], dict):
            raise TypeError("'schema' in 'output' configuration must be a dictionary")

        if not self.config["output"]["schema"]:
            raise ValueError("'schema' in 'output' configuration cannot be empty")

        schema = self.config["output"]["schema"]
        if "_short_explanation" in schema:
            schema = {k: v for k, v in schema.items() if k != "_short_explanation"}
        if len(schema) != 1:
            raise ValueError(
                "The 'schema' in 'output' configuration must have exactly one key-value pair that maps to a boolean value"
            )

        key, value = next(iter(schema.items()))
        if value not in ["bool", "boolean"]:
            raise TypeError(
                f"The value in the 'schema' must be of type bool, got {value}"
            )

    def execute(
        self, input_data: List[Dict], is_build: bool = False
    ) -> Tuple[List[Dict], float]:
        """
        Executes the filter operation on the input data.

        Args:
            input_data (List[Dict]): A list of dictionaries to process.
            is_build (bool): Whether the operation is being executed in the build phase. Defaults to False.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the filtered list of dictionaries
            and the total cost of the operation.

        This method performs the following steps:
        1. Processes each input item using an LLM model
        2. Validates the output
        3. Filters the results based on the specified filter key
        4. Calculates the total cost of the operation

        The method uses multi-threading to process items in parallel, improving performance
        for large datasets.

        Usage:
        ```python
        from docetl.operations import FilterOperation

        config = {
            "prompt": "Determine if the following item is important: {{input}}",
            "output": {
                "schema": {"is_important": "bool"}
            },
            "model": "gpt-3.5-turbo"
        }
        filter_op = FilterOperation(config)
        input_data = [
            {"id": 1, "text": "Critical update"},
            {"id": 2, "text": "Regular maintenance"}
        ]
        results, cost = filter_op.execute(input_data)
        print(f"Filtered results: {results}")
        print(f"Total cost: {cost}")
        ```
        """
        filter_key = next(
            iter(
                [
                    k
                    for k in self.config["output"]["schema"].keys()
                    if k != "_short_explanation"
                ]
            )
        )

        def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
            prompt_template = Template(self.config["prompt"])
            prompt = prompt_template.render(input=item)

            def validation_fn(response: Dict[str, Any]):
                output = parse_llm_response(response)[0]
                for key, value in item.items():
                    if key not in self.config["output"]["schema"]:
                        output[key] = value
                if validate_output(self.config, output, self.console):
                    return output, True
                return output, False

            output, cost, is_valid = call_llm_with_validation(
                [{"role": "user", "content": prompt}],
                llm_call_fn=lambda messages: call_llm(
                    self.config.get("model", self.default_model),
                    "filter",
                    messages,
                    self.config["output"]["schema"],
                    console=self.console,
                ),
                validation_fn=validation_fn,
                val_rule=self.config.get("validate", []),
                num_retries=self.num_retries_on_validate_failure,
                console=self.console,
            )

            if is_valid:
                return output, cost

            return None, cost

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = [
                executor.submit(_process_filter_item, item) for item in input_data
            ]
            results = []
            total_cost = 0
            pbar = RichLoopBar(
                range(len(futures)),
                desc="Processing filter items",
                console=self.console,
            )
            for i in pbar:
                future = futures[i]
                result, item_cost = future.result()
                total_cost += item_cost
                if result is not None:
                    if is_build:
                        results.append(result)
                    else:
                        if result.get(filter_key, False):
                            results.append(result)
                pbar.update(1)

        return results, total_cost

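Note that the output schema may also carry a `_short_explanation` key; it is excluded from the exactly-one-key check and from filter-key selection. A sketch of such a configuration (prompt and key names are illustrative):

```python
# The boolean key ("is_important") drives filtering; "_short_explanation"
# is ignored when the filter key is selected.
config = {
    "prompt": "Is the following item important? {{ input.text }}",
    "output": {
        "schema": {
            "is_important": "bool",
            "_short_explanation": "string",  # optional rationale field
        }
    },
}
```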
execute(input_data, is_build=False)

Executes the filter operation on the input data.

Parameters:

    input_data (List[Dict]): A list of dictionaries to process. Required.
    is_build (bool): Whether the operation is being executed in the build phase. Defaults to False.

Returns:

    Tuple[List[Dict], float]: A tuple containing the filtered list of dictionaries and the total cost of the operation.

This method performs the following steps:

1. Processes each input item using an LLM model
2. Validates the output
3. Filters the results based on the specified filter key
4. Calculates the total cost of the operation

The method uses multi-threading to process items in parallel, improving performance for large datasets.

Usage:

```python
from docetl.operations import FilterOperation

config = {
    "prompt": "Determine if the following item is important: {{input}}",
    "output": {
        "schema": {"is_important": "bool"}
    },
    "model": "gpt-3.5-turbo"
}
filter_op = FilterOperation(config)
input_data = [
    {"id": 1, "text": "Critical update"},
    {"id": 2, "text": "Regular maintenance"}
]
results, cost = filter_op.execute(input_data)
print(f"Filtered results: {results}")
print(f"Total cost: {cost}")
```

Source code in docetl/operations/filter.py
def execute(
    self, input_data: List[Dict], is_build: bool = False
) -> Tuple[List[Dict], float]:
    """
    Executes the filter operation on the input data.

    Args:
        input_data (List[Dict]): A list of dictionaries to process.
        is_build (bool): Whether the operation is being executed in the build phase. Defaults to False.

    Returns:
        Tuple[List[Dict], float]: A tuple containing the filtered list of dictionaries
        and the total cost of the operation.

    This method performs the following steps:
    1. Processes each input item using an LLM model
    2. Validates the output
    3. Filters the results based on the specified filter key
    4. Calculates the total cost of the operation

    The method uses multi-threading to process items in parallel, improving performance
    for large datasets.

    Usage:
    ```python
    from docetl.operations import FilterOperation

    config = {
        "prompt": "Determine if the following item is important: {{input}}",
        "output": {
            "schema": {"is_important": "bool"}
        },
        "model": "gpt-3.5-turbo"
    }
    filter_op = FilterOperation(config)
    input_data = [
        {"id": 1, "text": "Critical update"},
        {"id": 2, "text": "Regular maintenance"}
    ]
    results, cost = filter_op.execute(input_data)
    print(f"Filtered results: {results}")
    print(f"Total cost: {cost}")
    ```
    """
    filter_key = next(
        iter(
            [
                k
                for k in self.config["output"]["schema"].keys()
                if k != "_short_explanation"
            ]
        )
    )

    def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
        prompt_template = Template(self.config["prompt"])
        prompt = prompt_template.render(input=item)

        def validation_fn(response: Dict[str, Any]):
            output = parse_llm_response(response)[0]
            for key, value in item.items():
                if key not in self.config["output"]["schema"]:
                    output[key] = value
            if validate_output(self.config, output, self.console):
                return output, True
            return output, False

        output, cost, is_valid = call_llm_with_validation(
            [{"role": "user", "content": prompt}],
            llm_call_fn=lambda messages: call_llm(
                self.config.get("model", self.default_model),
                "filter",
                messages,
                self.config["output"]["schema"],
                console=self.console,
            ),
            validation_fn=validation_fn,
            val_rule=self.config.get("validate", []),
            num_retries=self.num_retries_on_validate_failure,
            console=self.console,
        )

        if is_valid:
            return output, cost

        return None, cost

    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        futures = [
            executor.submit(_process_filter_item, item) for item in input_data
        ]
        results = []
        total_cost = 0
        pbar = RichLoopBar(
            range(len(futures)),
            desc="Processing filter items",
            console=self.console,
        )
        for i in pbar:
            future = futures[i]
            result, item_cost = future.result()
            total_cost += item_cost
            if result is not None:
                if is_build:
                    results.append(result)
                else:
                    if result.get(filter_key, False):
                        results.append(result)
            pbar.update(1)

    return results, total_cost

syntax_check()

Checks the configuration of the FilterOperation for required keys and valid structure.

Raises:

    ValueError: If required keys are missing or if the output schema structure is invalid.
    TypeError: If the schema in the output configuration is not a dictionary or if the schema value is not of type bool.

This method checks for the following:

- Presence of required keys: 'prompt' and 'output'
- Presence of 'schema' in the 'output' configuration
- The 'schema' is a non-empty dictionary with exactly one key-value pair
- The value in the schema is of type bool

Source code in docetl/operations/filter.py
def syntax_check(self) -> None:
    """
    Checks the configuration of the FilterOperation for required keys and valid structure.

    Raises:
        ValueError: If required keys are missing or if the output schema structure is invalid.
        TypeError: If the schema in the output configuration is not a dictionary or if the schema value is not of type bool.

    This method checks for the following:
    - Presence of required keys: 'prompt' and 'output'
    - Presence of 'schema' in the 'output' configuration
    - The 'schema' is a non-empty dictionary with exactly one key-value pair
    - The value in the schema is of type bool
    """
    required_keys = ["prompt", "output"]
    for key in required_keys:
        if key not in self.config:
            raise ValueError(
                f"Missing required key '{key}' in FilterOperation configuration"
            )

    if "schema" not in self.config["output"]:
        raise ValueError("Missing 'schema' in 'output' configuration")

    if not isinstance(self.config["output"]["schema"], dict):
        raise TypeError("'schema' in 'output' configuration must be a dictionary")

    if not self.config["output"]["schema"]:
        raise ValueError("'schema' in 'output' configuration cannot be empty")

    schema = self.config["output"]["schema"]
    if "_short_explanation" in schema:
        schema = {k: v for k, v in schema.items() if k != "_short_explanation"}
    if len(schema) != 1:
        raise ValueError(
            "The 'schema' in 'output' configuration must have exactly one key-value pair that maps to a boolean value"
        )

    key, value = next(iter(schema.items()))
    if value not in ["bool", "boolean"]:
        raise TypeError(
            f"The value in the 'schema' must be of type bool, got {value}"
        )

docetl.operations.equijoin.EquijoinOperation

Bases: BaseOperation

Source code in docetl/operations/equijoin.py
class EquijoinOperation(BaseOperation):
    def syntax_check(self) -> None:
        """
        Checks the configuration of the EquijoinOperation for required keys and valid structure.

        Raises:
            ValueError: If required keys are missing or if the blocking_keys structure is invalid.
            Specifically:
            - Raises if 'comparison_prompt' is missing from the config.
            - Raises if 'left' or 'right' are missing from the 'blocking_keys' structure (if present).
            - Raises if 'left' or 'right' are missing from the 'limits' structure (if present).
        """
        if "comparison_prompt" not in self.config:
            raise ValueError(
                "Missing required key 'comparison_prompt' in EquijoinOperation configuration"
            )

        if "blocking_keys" in self.config:
            if (
                "left" not in self.config["blocking_keys"]
                or "right" not in self.config["blocking_keys"]
            ):
                raise ValueError(
                    "Both 'left' and 'right' must be specified in 'blocking_keys'"
                )

        if "limits" in self.config:
            if (
                "left" not in self.config["limits"]
                or "right" not in self.config["limits"]
            ):
                raise ValueError(
                    "Both 'left' and 'right' must be specified in 'limits'"
                )

        if "limit_comparisons" in self.config:
            if not isinstance(self.config["limit_comparisons"], int):
                raise ValueError("limit_comparisons must be an integer")

    def execute(
        self, left_data: List[Dict], right_data: List[Dict]
    ) -> Tuple[List[Dict], float]:
        """
        Executes the equijoin operation on the provided datasets.

        Args:
            left_data (List[Dict]): The left dataset to join.
            right_data (List[Dict]): The right dataset to join.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the joined results and the total cost of the operation.

        Usage:
        ```python
        from docetl.operations import EquijoinOperation

        config = {
            "blocking_keys": {
                "left": ["id"],
                "right": ["user_id"]
            },
            "limits": {
                "left": 1,
                "right": 1
            },
            "comparison_prompt": "Compare {{left}} and {{right}} and determine if they match.",
            "blocking_threshold": 0.8,
            "blocking_conditions": ["left['id'] == right['user_id']"],
            "limit_comparisons": 1000
        }
        equijoin_op = EquijoinOperation(config)
        left_data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        right_data = [{"user_id": 1, "age": 30}, {"user_id": 2, "age": 25}]
        results, cost = equijoin_op.execute(left_data, right_data)
        print(f"Joined results: {results}")
        print(f"Total cost: {cost}")
        ```

        This method performs the following steps:
        1. Initial blocking based on specified conditions (if any)
        2. Embedding-based blocking (if threshold is provided)
        3. LLM-based comparison for blocked pairs
        4. Result aggregation and validation

        The method also calculates and logs statistics such as comparisons saved by blocking and join selectivity.
        """

        blocking_keys = self.config.get("blocking_keys", {})
        left_keys = blocking_keys.get(
            "left", list(left_data[0].keys()) if left_data else []
        )
        right_keys = blocking_keys.get(
            "right", list(right_data[0].keys()) if right_data else []
        )
        limits = self.config.get(
            "limits", {"left": float("inf"), "right": float("inf")}
        )
        left_limit = limits["left"]
        right_limit = limits["right"]
        blocking_threshold = self.config.get("blocking_threshold")
        blocking_conditions = self.config.get("blocking_conditions", [])
        limit_comparisons = self.config.get("limit_comparisons")
        total_cost = 0

        # LLM-based comparison for blocked pairs
        def get_hashable_key(item: Dict) -> str:
            return json.dumps(item, sort_keys=True)

        if len(left_data) == 0 or len(right_data) == 0:
            return [], 0

        # Initial blocking using multiprocessing
        num_processes = min(cpu_count(), len(left_data))

        self.console.log(
            f"Starting to run code-based blocking rules for {len(left_data)} left and {len(right_data)} right rows ({len(left_data) * len(right_data)} total pairs) with {num_processes} processes..."
        )

        with Pool(
            processes=num_processes,
            initializer=init_worker,
            initargs=(right_data, blocking_conditions),
        ) as pool:
            blocked_pairs_nested = pool.map(process_left_item, left_data)

        # Flatten the nested list of blocked pairs
        blocked_pairs = [pair for sublist in blocked_pairs_nested for pair in sublist]

        # Check if we have exceeded the pairwise comparison limit
        if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
            # Sample pairs randomly
            sampled_pairs = random.sample(blocked_pairs, limit_comparisons)

            # Calculate number of dropped pairs
            dropped_pairs = len(blocked_pairs) - limit_comparisons

            # Prompt the user for confirmation
            if self.status:
                self.status.stop()
            if not Confirm.ask(
                f"[yellow]Warning: {dropped_pairs} pairs will be dropped due to the comparison limit. "
                f"Proceeding with {limit_comparisons} randomly sampled pairs. "
                f"Do you want to continue?[/yellow]",
            ):
                raise ValueError("Operation cancelled by user due to pair limit.")

            if self.status:
                self.status.start()

            blocked_pairs = sampled_pairs

        self.console.log(
            f"Number of blocked pairs after initial blocking: {len(blocked_pairs)}"
        )

        if blocking_threshold is not None:
            embedding_model = self.config.get("embedding_model", self.default_model)
            model_input_context_length = model_cost.get(embedding_model, {}).get(
                "max_input_tokens", 8192
            )

            def get_embeddings(
                input_data: List[Dict[str, Any]], keys: List[str], name: str
            ) -> Tuple[List[List[float]], float]:
                texts = [
                    " ".join(str(item[key]) for key in keys if key in item)[
                        : model_input_context_length * 4
                    ]
                    for item in input_data
                ]

                embeddings = []
                total_cost = 0
                batch_size = 2000
                for i in range(0, len(texts), batch_size):
                    batch = texts[i : i + batch_size]
                    self.console.log(
                        f"Creating embeddings for {name} data: batch starting at index {i}"
                    )
                    response = gen_embedding(
                        model=embedding_model,
                        input=batch,
                    )
                    embeddings.extend([data["embedding"] for data in response["data"]])
                    total_cost += completion_cost(response)
                return embeddings, total_cost

            left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
            right_embeddings, right_cost = get_embeddings(
                right_data, right_keys, "right"
            )
            total_cost += left_cost + right_cost
            self.console.log(
                f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
            )

            # Compute all cosine similarities in one call
            similarities = cosine_similarity(left_embeddings, right_embeddings)

            # Additional blocking based on embeddings
            # Find indices where similarity is above threshold
            above_threshold = np.argwhere(similarities >= blocking_threshold)
            self.console.log(
                f"There are {above_threshold.shape[0]} pairs above the threshold."
            )
            block_pair_set = set(
                (get_hashable_key(left_item), get_hashable_key(right_item))
                for left_item, right_item in blocked_pairs
            )

            # If limit_comparisons is set, take only the top pairs
            if limit_comparisons is not None:
                # First, get all pairs above threshold
                above_threshold_pairs = [(int(i), int(j)) for i, j in above_threshold]

                # Sort these pairs by their similarity scores
                sorted_pairs = sorted(
                    above_threshold_pairs,
                    key=lambda pair: similarities[pair[0], pair[1]],
                    reverse=True,
                )

                # Take the top 'limit_comparisons' pairs
                top_pairs = sorted_pairs[:limit_comparisons]

                # Create new blocked_pairs based on top similarities and existing blocked pairs
                new_blocked_pairs = []
                remaining_limit = limit_comparisons - len(blocked_pairs)

                # First, include all existing blocked pairs
                final_blocked_pairs = blocked_pairs.copy()

                # Then, add new pairs from top similarities until we reach the limit
                for i, j in top_pairs:
                    if remaining_limit <= 0:
                        break
                    left_item, right_item = left_data[i], right_data[j]
                    left_key = get_hashable_key(left_item)
                    right_key = get_hashable_key(right_item)
                    if (left_key, right_key) not in block_pair_set:
                        new_blocked_pairs.append((left_item, right_item))
                        block_pair_set.add((left_key, right_key))
                        remaining_limit -= 1

                final_blocked_pairs.extend(new_blocked_pairs)
                blocked_pairs = final_blocked_pairs

                self.console.log(
                    f"Limited comparisons to top {limit_comparisons} pairs, including {len(blocked_pairs) - len(new_blocked_pairs)} from code-based blocking and {len(new_blocked_pairs)} based on cosine similarity. Lowest cosine similarity included: {similarities[top_pairs[-1]]:.4f}"
                )
            else:
                # Add new pairs to blocked_pairs
                for i, j in above_threshold:
                    left_item, right_item = left_data[i], right_data[j]
                    left_key = get_hashable_key(left_item)
                    right_key = get_hashable_key(right_item)
                    if (left_key, right_key) not in block_pair_set:
                        blocked_pairs.append((left_item, right_item))
                        block_pair_set.add((left_key, right_key))

        # If there are no blocking conditions or embedding threshold, use all pairs
        if not blocking_conditions and blocking_threshold is None:
            blocked_pairs = [
                (left_item, right_item)
                for left_item in left_data
                for right_item in right_data
            ]

        # If there's a limit on the number of comparisons, randomly sample pairs
        if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
            self.console.log(
                f"Randomly sampling {limit_comparisons} pairs out of {len(blocked_pairs)} blocked pairs."
            )
            blocked_pairs = random.sample(blocked_pairs, limit_comparisons)

        self.console.log(
            f"Total pairs to compare after blocking and sampling: {len(blocked_pairs)}"
        )

        # Calculate and print statistics
        total_possible_comparisons = len(left_data) * len(right_data)
        comparisons_made = len(blocked_pairs)
        comparisons_saved = total_possible_comparisons - comparisons_made
        self.console.log(
            f"[green]Comparisons saved by blocking: {comparisons_saved} "
            f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
        )

        left_match_counts = defaultdict(int)
        right_match_counts = defaultdict(int)
        results = []
        comparison_costs = 0

        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            future_to_pair = {
                executor.submit(
                    compare_pair,
                    self.config["comparison_prompt"],
                    self.config.get("comparison_model", self.default_model),
                    left,
                    right,
                ): (left, right)
                for left, right in blocked_pairs
            }

            for future in rich_as_completed(
                future_to_pair,
                total=len(future_to_pair),
                desc="Comparing pairs",
                console=self.console,
            ):
                pair = future_to_pair[future]
                is_match, cost = future.result()
                comparison_costs += cost

                if is_match:
                    joined_item = {}
                    left_item, right_item = pair
                    left_key_hash = get_hashable_key(left_item)
                    right_key_hash = get_hashable_key(right_item)
                    if (
                        left_match_counts[left_key_hash] >= left_limit
                        or right_match_counts[right_key_hash] >= right_limit
                    ):
                        continue

                    for key, value in left_item.items():
                        joined_item[f"{key}_left" if key in right_item else key] = value
                    for key, value in right_item.items():
                        joined_item[f"{key}_right" if key in left_item else key] = value
                    if validate_output(self.config, joined_item, self.console):
                        results.append(joined_item)
                        left_match_counts[left_key_hash] += 1
                        right_match_counts[right_key_hash] += 1

                    # TODO: support retry in validation failure

        total_cost += comparison_costs

        # Calculate and print the join selectivity
        join_selectivity = (
            len(results) / (len(left_data) * len(right_data))
            if len(left_data) * len(right_data) > 0
            else 0
        )
        self.console.log(f"Equijoin selectivity: {join_selectivity:.4f}")

        return results, total_cost

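When a matched pair shares key names, the join suffixes the colliding keys with `_left` and `_right` rather than overwriting one side. A small worked example of that merge logic, taken directly from the loop above:

```python
# Colliding keys ("id") are suffixed; unique keys pass through unchanged.
left_item = {"id": 1, "name": "Alice"}
right_item = {"id": 7, "age": 30}

joined_item = {}
for key, value in left_item.items():
    joined_item[f"{key}_left" if key in right_item else key] = value
for key, value in right_item.items():
    joined_item[f"{key}_right" if key in left_item else key] = value

assert joined_item == {"id_left": 1, "name": "Alice", "id_right": 7, "age": 30}
```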
execute(left_data, right_data)

Executes the equijoin operation on the provided datasets.

Parameters:

    left_data (List[Dict]): The left dataset to join. Required.
    right_data (List[Dict]): The right dataset to join. Required.

Returns:

    Tuple[List[Dict], float]: A tuple containing the joined results and the total cost of the operation.

Usage:

```python
from docetl.operations import EquijoinOperation

config = {
    "blocking_keys": {
        "left": ["id"],
        "right": ["user_id"]
    },
    "limits": {
        "left": 1,
        "right": 1
    },
    "comparison_prompt": "Compare {{left}} and {{right}} and determine if they match.",
    "blocking_threshold": 0.8,
    "blocking_conditions": ["left['id'] == right['user_id']"],
    "limit_comparisons": 1000
}
equijoin_op = EquijoinOperation(config)
left_data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
right_data = [{"user_id": 1, "age": 30}, {"user_id": 2, "age": 25}]
results, cost = equijoin_op.execute(left_data, right_data)
print(f"Joined results: {results}")
print(f"Total cost: {cost}")
```

This method performs the following steps:

1. Initial blocking based on specified conditions (if any)
2. Embedding-based blocking (if threshold is provided)
3. LLM-based comparison for blocked pairs
4. Result aggregation and validation

The method also calculates and logs statistics such as comparisons saved by blocking and join selectivity.

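Step 2 reduces to thresholding a cosine-similarity matrix between the left and right embeddings. A minimal runnable sketch of that computation, mirroring the `cosine_similarity` and `np.argwhere` calls in the source (toy vectors and an illustrative threshold, not values from the source):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy embeddings: 2 left rows and 3 right rows give a (2, 3) similarity matrix.
left_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])
right_embeddings = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]])

similarities = cosine_similarity(left_embeddings, right_embeddings)
above_threshold = np.argwhere(similarities >= 0.8)
# Each row is a (left_index, right_index) pair that survives blocking.
print(above_threshold)  # [[0 0], [1 2]]
```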
Source code in docetl/operations/equijoin.py
def execute(
    self, left_data: List[Dict], right_data: List[Dict]
) -> Tuple[List[Dict], float]:
    """
    Executes the equijoin operation on the provided datasets.

    Args:
        left_data (List[Dict]): The left dataset to join.
        right_data (List[Dict]): The right dataset to join.

    Returns:
        Tuple[List[Dict], float]: A tuple containing the joined results and the total cost of the operation.

    Usage:
    ```python
    from docetl.operations import EquijoinOperation

    config = {
        "blocking_keys": {
            "left": ["id"],
            "right": ["user_id"]
        },
        "limits": {
            "left": 1,
            "right": 1
        },
        "comparison_prompt": "Compare {{left}} and {{right}} and determine if they match.",
        "blocking_threshold": 0.8,
        "blocking_conditions": ["left['id'] == right['user_id']"],
        "limit_comparisons": 1000
    }
    equijoin_op = EquijoinOperation(config)
    left_data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
    right_data = [{"user_id": 1, "age": 30}, {"user_id": 2, "age": 25}]
    results, cost = equijoin_op.execute(left_data, right_data)
    print(f"Joined results: {results}")
    print(f"Total cost: {cost}")
    ```

    This method performs the following steps:
    1. Initial blocking based on specified conditions (if any)
    2. Embedding-based blocking (if threshold is provided)
    3. LLM-based comparison for blocked pairs
    4. Result aggregation and validation

    The method also calculates and logs statistics such as comparisons saved by blocking and join selectivity.
    """

    blocking_keys = self.config.get("blocking_keys", {})
    left_keys = blocking_keys.get(
        "left", list(left_data[0].keys()) if left_data else []
    )
    right_keys = blocking_keys.get(
        "right", list(right_data[0].keys()) if right_data else []
    )
    limits = self.config.get(
        "limits", {"left": float("inf"), "right": float("inf")}
    )
    left_limit = limits["left"]
    right_limit = limits["right"]
    blocking_threshold = self.config.get("blocking_threshold")
    blocking_conditions = self.config.get("blocking_conditions", [])
    limit_comparisons = self.config.get("limit_comparisons")
    total_cost = 0

    # LLM-based comparison for blocked pairs
    def get_hashable_key(item: Dict) -> str:
        return json.dumps(item, sort_keys=True)

    if len(left_data) == 0 or len(right_data) == 0:
        return [], 0

    # Initial blocking using multiprocessing
    num_processes = min(cpu_count(), len(left_data))

    self.console.log(
        f"Starting to run code-based blocking rules for {len(left_data)} left and {len(right_data)} right rows ({len(left_data) * len(right_data)} total pairs) with {num_processes} processes..."
    )

    with Pool(
        processes=num_processes,
        initializer=init_worker,
        initargs=(right_data, blocking_conditions),
    ) as pool:
        blocked_pairs_nested = pool.map(process_left_item, left_data)

    # Flatten the nested list of blocked pairs
    blocked_pairs = [pair for sublist in blocked_pairs_nested for pair in sublist]

    # Check if we have exceeded the pairwise comparison limit
    if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
        # Sample pairs randomly
        sampled_pairs = random.sample(blocked_pairs, limit_comparisons)

        # Calculate number of dropped pairs
        dropped_pairs = len(blocked_pairs) - limit_comparisons

        # Prompt the user for confirmation
        if self.status:
            self.status.stop()
        if not Confirm.ask(
            f"[yellow]Warning: {dropped_pairs} pairs will be dropped due to the comparison limit. "
            f"Proceeding with {limit_comparisons} randomly sampled pairs. "
            f"Do you want to continue?[/yellow]",
        ):
            raise ValueError("Operation cancelled by user due to pair limit.")

        if self.status:
            self.status.start()

        blocked_pairs = sampled_pairs

    self.console.log(
        f"Number of blocked pairs after initial blocking: {len(blocked_pairs)}"
    )

    if blocking_threshold is not None:
        embedding_model = self.config.get("embedding_model", self.default_model)
        model_input_context_length = model_cost.get(embedding_model, {}).get(
            "max_input_tokens", 8192
        )

        def get_embeddings(
            input_data: List[Dict[str, Any]], keys: List[str], name: str
        ) -> Tuple[List[List[float]], float]:
            texts = [
                " ".join(str(item[key]) for key in keys if key in item)[
                    : model_input_context_length * 4
                ]
                for item in input_data
            ]

            embeddings = []
            total_cost = 0
            batch_size = 2000
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                self.console.log(
                    f"On iteration {i} for creating embeddings for {name} data"
                )
                response = gen_embedding(
                    model=embedding_model,
                    input=batch,
                )
                embeddings.extend([data["embedding"] for data in response["data"]])
                total_cost += completion_cost(response)
            return embeddings, total_cost

        left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
        right_embeddings, right_cost = get_embeddings(
            right_data, right_keys, "right"
        )
        total_cost += left_cost + right_cost
        self.console.log(
            f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
        )

        # Compute all cosine similarities in one call
        similarities = cosine_similarity(left_embeddings, right_embeddings)

        # Additional blocking based on embeddings
        # Find indices where similarity is above threshold
        above_threshold = np.argwhere(similarities >= blocking_threshold)
        self.console.log(
            f"There are {above_threshold.shape[0]} pairs above the threshold."
        )
        block_pair_set = set(
            (get_hashable_key(left_item), get_hashable_key(right_item))
            for left_item, right_item in blocked_pairs
        )

        # If limit_comparisons is set, take only the top pairs
        if limit_comparisons is not None:
            # First, get all pairs above threshold
            above_threshold_pairs = [(int(i), int(j)) for i, j in above_threshold]

            # Sort these pairs by their similarity scores
            sorted_pairs = sorted(
                above_threshold_pairs,
                key=lambda pair: similarities[pair[0], pair[1]],
                reverse=True,
            )

            # Take the top 'limit_comparisons' pairs
            top_pairs = sorted_pairs[:limit_comparisons]

            # Create new blocked_pairs based on top similarities and existing blocked pairs
            new_blocked_pairs = []
            remaining_limit = limit_comparisons - len(blocked_pairs)

            # First, include all existing blocked pairs
            final_blocked_pairs = blocked_pairs.copy()

            # Then, add new pairs from top similarities until we reach the limit
            for i, j in top_pairs:
                if remaining_limit <= 0:
                    break
                left_item, right_item = left_data[i], right_data[j]
                left_key = get_hashable_key(left_item)
                right_key = get_hashable_key(right_item)
                if (left_key, right_key) not in block_pair_set:
                    new_blocked_pairs.append((left_item, right_item))
                    block_pair_set.add((left_key, right_key))
                    remaining_limit -= 1

            final_blocked_pairs.extend(new_blocked_pairs)
            blocked_pairs = final_blocked_pairs

            self.console.log(
                f"Limited comparisons to top {limit_comparisons} pairs, including {len(blocked_pairs) - len(new_blocked_pairs)} from code-based blocking and {len(new_blocked_pairs)} based on cosine similarity. Lowest cosine similarity included: {similarities[top_pairs[-1]]:.4f}"
            )
        else:
            # Add new pairs to blocked_pairs
            for i, j in above_threshold:
                left_item, right_item = left_data[i], right_data[j]
                left_key = get_hashable_key(left_item)
                right_key = get_hashable_key(right_item)
                if (left_key, right_key) not in block_pair_set:
                    blocked_pairs.append((left_item, right_item))
                    block_pair_set.add((left_key, right_key))

    # If there are no blocking conditions or embedding threshold, use all pairs
    if not blocking_conditions and blocking_threshold is None:
        blocked_pairs = [
            (left_item, right_item)
            for left_item in left_data
            for right_item in right_data
        ]

    # If there's a limit on the number of comparisons, randomly sample pairs
    if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
        self.console.log(
            f"Randomly sampling {limit_comparisons} pairs out of {len(blocked_pairs)} blocked pairs."
        )
        blocked_pairs = random.sample(blocked_pairs, limit_comparisons)

    self.console.log(
        f"Total pairs to compare after blocking and sampling: {len(blocked_pairs)}"
    )

    # Calculate and print statistics
    total_possible_comparisons = len(left_data) * len(right_data)
    comparisons_made = len(blocked_pairs)
    comparisons_saved = total_possible_comparisons - comparisons_made
    self.console.log(
        f"[green]Comparisons saved by blocking: {comparisons_saved} "
        f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
    )

    left_match_counts = defaultdict(int)
    right_match_counts = defaultdict(int)
    results = []
    comparison_costs = 0

    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        future_to_pair = {
            executor.submit(
                compare_pair,
                self.config["comparison_prompt"],
                self.config.get("comparison_model", self.default_model),
                left,
                right,
            ): (left, right)
            for left, right in blocked_pairs
        }

        for future in rich_as_completed(
            future_to_pair,
            total=len(future_to_pair),
            desc="Comparing pairs",
            console=self.console,
        ):
            pair = future_to_pair[future]
            is_match, cost = future.result()
            comparison_costs += cost

            if is_match:
                joined_item = {}
                left_item, right_item = pair
                left_key_hash = get_hashable_key(left_item)
                right_key_hash = get_hashable_key(right_item)
                if (
                    left_match_counts[left_key_hash] >= left_limit
                    or right_match_counts[right_key_hash] >= right_limit
                ):
                    continue

                for key, value in left_item.items():
                    joined_item[f"{key}_left" if key in right_item else key] = value
                for key, value in right_item.items():
                    joined_item[f"{key}_right" if key in left_item else key] = value
                if validate_output(self.config, joined_item, self.console):
                    results.append(joined_item)
                    left_match_counts[left_key_hash] += 1
                    right_match_counts[right_key_hash] += 1

                # TODO: support retry in validation failure

    total_cost += comparison_costs

    # Calculate and print the join selectivity
    join_selectivity = (
        len(results) / (len(left_data) * len(right_data))
        if len(left_data) * len(right_data) > 0
        else 0
    )
    self.console.log(f"Equijoin selectivity: {join_selectivity:.4f}")

    return results, total_cost
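
A note on the merge step above: when a key exists on both sides of the join, the joined record keeps both values under `{key}_left` and `{key}_right` suffixes instead of overwriting one side. A minimal sketch of that renaming in isolation (the dict literals are invented for illustration):

```python
left_item = {"id": 1, "name": "Alice"}
right_item = {"id": 1, "name": "Smith", "age": 30}

joined_item = {}
for key, value in left_item.items():
    # Suffix keys that also appear on the right side to avoid collisions
    joined_item[f"{key}_left" if key in right_item else key] = value
for key, value in right_item.items():
    joined_item[f"{key}_right" if key in left_item else key] = value

# joined_item == {"id_left": 1, "name_left": "Alice",
#                 "id_right": 1, "name_right": "Smith", "age": 30}
```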

syntax_check()

Checks the configuration of the EquijoinOperation for required keys and valid structure.

Raises:

- ValueError: If required keys are missing or if the blocking_keys structure is invalid. Specifically, if 'comparison_prompt' is absent from the config, or if 'left' or 'right' are missing from 'blocking_keys' or 'limits' when those sections are present.
Source code in docetl/operations/equijoin.py
def syntax_check(self) -> None:
    """
    Checks the configuration of the EquijoinOperation for required keys and valid structure.

    Raises:
        ValueError: If required keys are missing or if the blocking_keys structure is invalid.
        Specifically:
        - Raises if 'comparison_prompt' is missing from the config.
        - Raises if 'left' or 'right' are missing from the 'blocking_keys' structure (if present).
        - Raises if 'left' or 'right' are missing from the 'limits' structure (if present).
    """
    if "comparison_prompt" not in self.config:
        raise ValueError(
            "Missing required key 'comparison_prompt' in EquijoinOperation configuration"
        )

    if "blocking_keys" in self.config:
        if (
            "left" not in self.config["blocking_keys"]
            or "right" not in self.config["blocking_keys"]
        ):
            raise ValueError(
                "Both 'left' and 'right' must be specified in 'blocking_keys'"
            )

    if "limits" in self.config:
        if (
            "left" not in self.config["limits"]
            or "right" not in self.config["limits"]
        ):
            raise ValueError(
                "Both 'left' and 'right' must be specified in 'limits'"
            )

    if "limit_comparisons" in self.config:
        if not isinstance(self.config["limit_comparisons"], int):
            raise ValueError("limit_comparisons must be an integer")
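
To make these checks concrete, here is a hedged sketch (all config values invented) of one configuration that fails validation and one that passes:

```python
# Missing 'comparison_prompt' -> syntax_check raises ValueError
bad_config = {"blocking_keys": {"left": ["id"], "right": ["user_id"]}}

# Minimal config that satisfies every check above
good_config = {
    "comparison_prompt": "Compare {{left}} and {{right}}. Do they match?",
    "blocking_keys": {"left": ["id"], "right": ["user_id"]},
    "limits": {"left": 1, "right": 1},
    "limit_comparisons": 1000,
}
```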

Auxiliary Operators

docetl.operations.split.SplitOperation

Bases: BaseOperation

A class that implements a split operation on input data, dividing it into manageable chunks.

This class extends BaseOperation to:

1. Split input data into chunks of specified size based on the 'split_key' and 'token_count' configuration.
2. Assign unique identifiers to each original document and number chunks sequentially.
3. Return results containing:
   - {split_key}_chunk: The content of the split chunk.
   - {name}_id: A unique identifier for each original document.
   - {name}_chunk_num: The sequential number of the chunk within its original document.
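
As a minimal sketch of how these pieces fit together (the name and sizes below are invented, but consistent with the syntax_check in the source):

```python
config = {
    "name": "split_report",
    "split_key": "content",                # field holding the long text
    "method": "token_count",
    "method_kwargs": {"num_tokens": 500},  # roughly 500 tokens per chunk
}
# Each output row would then carry content_chunk, split_report_id,
# and split_report_chunk_num alongside the original fields.
```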

Source code in docetl/operations/split.py
class SplitOperation(BaseOperation):
    """
    A class that implements a split operation on input data, dividing it into manageable chunks.

    This class extends BaseOperation to:
    1. Split input data into chunks of specified size based on the 'split_key' and 'token_count' configuration.
    2. Assign unique identifiers to each original document and number chunks sequentially.
    3. Return results containing:
       - {split_key}_chunk: The content of the split chunk.
       - {name}_id: A unique identifier for each original document.
       - {name}_chunk_num: The sequential number of the chunk within its original document.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = self.config["name"]

    def syntax_check(self) -> None:
        required_keys = ["split_key", "method", "method_kwargs"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in SplitOperation configuration"
                )

        if not isinstance(self.config["split_key"], str):
            raise TypeError("'split_key' must be a string")

        if self.config["method"] not in ["token_count", "delimiter"]:
            raise ValueError(f"Invalid method '{self.config['method']}'")

        if self.config["method"] == "token_count":
            if (
                not isinstance(self.config["method_kwargs"]["num_tokens"], int)
                or self.config["method_kwargs"]["num_tokens"] <= 0
            ):
                raise ValueError("'num_tokens' must be a positive integer")
        elif self.config["method"] == "delimiter":
            if not isinstance(self.config["method_kwargs"]["delimiter"], str):
                raise ValueError("'delimiter' must be a string")

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        split_key = self.config["split_key"]
        method = self.config["method"]
        method_kwargs = self.config["method_kwargs"]
        encoder = tiktoken.encoding_for_model(
            self.config["method_kwargs"].get("model", self.default_model)
        )
        results = []
        cost = 0.0

        for item in input_data:
            if split_key not in item:
                raise KeyError(f"Split key '{split_key}' not found in item")

            content = item[split_key]
            doc_id = str(uuid.uuid4())

            if method == "token_count":
                token_count = method_kwargs["num_tokens"]
                tokens = encoder.encode(content)

                for chunk_num, i in enumerate(
                    range(0, len(tokens), token_count), start=1
                ):
                    chunk_tokens = tokens[i : i + token_count]
                    chunk = encoder.decode(chunk_tokens)

                    result = item.copy()
                    result.update(
                        {
                            f"{split_key}_chunk": chunk,
                            f"{self.name}_id": doc_id,
                            f"{self.name}_chunk_num": chunk_num,
                        }
                    )
                    results.append(result)

            elif method == "delimiter":
                delimiter = method_kwargs["delimiter"]
                num_splits_to_group = method_kwargs.get("num_splits_to_group", 1)
                chunks = content.split(delimiter)

                # Get rid of empty chunks
                chunks = [chunk for chunk in chunks if chunk.strip()]

                for chunk_num, i in enumerate(
                    range(0, len(chunks), num_splits_to_group), start=1
                ):
                    grouped_chunks = chunks[i : i + num_splits_to_group]
                    joined_chunk = delimiter.join(grouped_chunks).strip()

                    result = item.copy()
                    result.update(
                        {
                            f"{split_key}_chunk": joined_chunk,
                            f"{self.name}_id": doc_id,
                            f"{self.name}_chunk_num": chunk_num,
                        }
                    )
                    results.append(result)

        return results, cost
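
A short usage sketch of the delimiter method (construction mirrors the equijoin usage example above; the values are invented):

```python
split_op = SplitOperation({
    "name": "split_para",
    "split_key": "text",
    "method": "delimiter",
    "method_kwargs": {"delimiter": "\n\n", "num_splits_to_group": 2},
})
results, cost = split_op.execute([{"text": "a\n\nb\n\nc"}])
# Three paragraphs grouped two at a time yield two chunks:
#   text_chunk == "a\n\nb"  (split_para_chunk_num == 1)
#   text_chunk == "c"       (split_para_chunk_num == 2)
```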

docetl.operations.gather.GatherOperation

Bases: BaseOperation

A class that implements a gather operation on input data, adding contextual information from surrounding chunks.

This class extends BaseOperation to:

1. Group chunks by their document ID.
2. Order chunks within each group.
3. Add peripheral context to each chunk based on the configuration.
4. Include headers for each chunk and its upward hierarchy.
5. Return results containing the rendered chunks with added context, including information about skipped characters and headers.
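
A minimal sketch of a gather configuration consistent with the syntax_check in the source (the key values are invented; head and tail sections require a count, middle does not):

```python
config = {
    "content_key": "text_chunk",
    "doc_id_key": "split_report_id",
    "order_key": "split_report_chunk_num",
    "peripheral_chunks": {
        "previous": {"head": {"count": 1}, "tail": {"count": 2}},
        "next": {"head": {"count": 1}},
    },
    # Optional: "doc_header_key", "main_chunk_start", "main_chunk_end"
}
```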

Source code in docetl/operations/gather.py
class GatherOperation(BaseOperation):
    """
    A class that implements a gather operation on input data, adding contextual information from surrounding chunks.

    This class extends BaseOperation to:
    1. Group chunks by their document ID.
    2. Order chunks within each group.
    3. Add peripheral context to each chunk based on the configuration.
    4. Include headers for each chunk and its upward hierarchy.
    5. Return results containing the rendered chunks with added context, including information about skipped characters and headers.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the GatherOperation.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)

    def syntax_check(self) -> None:
        """
        Perform a syntax check on the operation configuration.

        Raises:
            ValueError: If required keys are missing or if there are configuration errors.
            TypeError: If main_chunk_start or main_chunk_end are not strings.
        """
        required_keys = ["content_key", "doc_id_key", "order_key"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in GatherOperation configuration"
                )

        if "peripheral_chunks" not in self.config:
            raise ValueError(
                "Missing 'peripheral_chunks' configuration in GatherOperation"
            )

        peripheral_config = self.config["peripheral_chunks"]
        for direction in ["previous", "next"]:
            if direction not in peripheral_config:
                continue
            for section in ["head", "middle", "tail"]:
                if section in peripheral_config[direction]:
                    section_config = peripheral_config[direction][section]
                    if section != "middle" and "count" not in section_config:
                        raise ValueError(
                            f"Missing 'count' in {direction}.{section} configuration"
                        )

        if "main_chunk_start" in self.config and not isinstance(
            self.config["main_chunk_start"], str
        ):
            raise TypeError("'main_chunk_start' must be a string")
        if "main_chunk_end" in self.config and not isinstance(
            self.config["main_chunk_end"], str
        ):
            raise TypeError("'main_chunk_end' must be a string")

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Execute the gather operation on the input data.

        Args:
            input_data (List[Dict]): The input data to process.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the processed results and the cost of the operation.
        """
        content_key = self.config["content_key"]
        doc_id_key = self.config["doc_id_key"]
        order_key = self.config["order_key"]
        peripheral_config = self.config["peripheral_chunks"]
        main_chunk_start = self.config.get(
            "main_chunk_start", "--- Begin Main Chunk ---"
        )
        main_chunk_end = self.config.get("main_chunk_end", "--- End Main Chunk ---")
        doc_header_key = self.config.get("doc_header_key", None)
        results = []
        cost = 0.0

        # Group chunks by document ID
        grouped_chunks = {}
        for item in input_data:
            doc_id = item[doc_id_key]
            if doc_id not in grouped_chunks:
                grouped_chunks[doc_id] = []
            grouped_chunks[doc_id].append(item)

        # Process each group of chunks
        for doc_id, chunks in grouped_chunks.items():
            # Sort chunks by their order within the document
            chunks.sort(key=lambda x: x[order_key])

            # Process each chunk with its peripheral context and headers
            for i, chunk in enumerate(chunks):
                rendered_chunk = self.render_chunk_with_context(
                    chunks,
                    i,
                    peripheral_config,
                    content_key,
                    order_key,
                    main_chunk_start,
                    main_chunk_end,
                    doc_header_key,
                )

                result = chunk.copy()
                result[f"{content_key}_rendered"] = rendered_chunk
                results.append(result)

        return results, cost

    def render_chunk_with_context(
        self,
        chunks: List[Dict],
        current_index: int,
        peripheral_config: Dict,
        content_key: str,
        order_key: str,
        main_chunk_start: str,
        main_chunk_end: str,
        doc_header_key: str,
    ) -> str:
        """
        Render a chunk with its peripheral context and headers.

        Args:
            chunks (List[Dict]): List of all chunks in the document.
            current_index (int): Index of the current chunk being processed.
            peripheral_config (Dict): Configuration for peripheral chunks.
            content_key (str): Key for the content in each chunk.
            order_key (str): Key for the order of each chunk.
            main_chunk_start (str): String to mark the start of the main chunk.
            main_chunk_end (str): String to mark the end of the main chunk.
            doc_header_key (str): The key for the headers in the current chunk.

        Returns:
            str: Rendered chunk with context and headers.
        """
        combined_parts = []

        # Process previous chunks
        combined_parts.append("--- Previous Context ---")
        combined_parts.extend(
            self.process_peripheral_chunks(
                chunks[:current_index],
                peripheral_config.get("previous", {}),
                content_key,
                order_key,
            )
        )
        combined_parts.append("--- End Previous Context ---\n")

        # Process main chunk
        main_chunk = chunks[current_index]
        headers = self.render_hierarchy_headers(
            main_chunk, chunks[: current_index + 1], doc_header_key
        )
        if headers:
            combined_parts.append(headers)
        combined_parts.append(f"{main_chunk_start}")
        combined_parts.append(f"{main_chunk[content_key]}")
        combined_parts.append(f"{main_chunk_end}")

        # Process next chunks
        combined_parts.append("\n--- Next Context ---")
        combined_parts.extend(
            self.process_peripheral_chunks(
                chunks[current_index + 1 :],
                peripheral_config.get("next", {}),
                content_key,
                order_key,
            )
        )
        combined_parts.append("--- End Next Context ---")

        return "\n".join(combined_parts)

    def process_peripheral_chunks(
        self,
        chunks: List[Dict],
        config: Dict,
        content_key: str,
        order_key: str,
        reverse: bool = False,
    ) -> List[str]:
        """
        Process peripheral chunks according to the configuration.

        Args:
            chunks (List[Dict]): List of chunks to process.
            config (Dict): Configuration for processing peripheral chunks.
            content_key (str): Key for the content in each chunk.
            order_key (str): Key for the order of each chunk.
            reverse (bool, optional): Whether to process chunks in reverse order. Defaults to False.

        Returns:
            List[str]: List of processed chunk strings.
        """
        if reverse:
            chunks = list(reversed(chunks))

        processed_parts = []
        included_chunks = []
        total_chunks = len(chunks)

        head_config = config.get("head", {})
        tail_config = config.get("tail", {})

        head_count = int(head_config.get("count", 0))
        tail_count = int(tail_config.get("count", 0))
        in_skip = False
        skip_char_count = 0

        for i, chunk in enumerate(chunks):
            if i < head_count:
                section = "head"
            elif i >= total_chunks - tail_count:
                section = "tail"
            elif "middle" in config:
                section = "middle"
            else:
                # Show number of characters skipped
                skipped_chars = len(chunk[content_key])
                if not in_skip:
                    skip_char_count = skipped_chars
                    in_skip = True
                else:
                    skip_char_count += skipped_chars

                continue

            if in_skip:
                processed_parts.append(
                    f"[... {skip_char_count} characters skipped ...]"
                )
                in_skip = False
                skip_char_count = 0

            section_config = config.get(section, {})
            section_content_key = section_config.get("content_key", content_key)

            is_summary = section_content_key != content_key
            summary_suffix = " (Summary)" if is_summary else ""

            chunk_prefix = f"[Chunk {chunk[order_key]}{summary_suffix}]"
            processed_parts.append(chunk_prefix)
            processed_parts.append(f"{chunk[section_content_key]}")
            included_chunks.append(chunk)

        if in_skip:
            processed_parts.append(f"[... {skip_char_count} characters skipped ...]")

        if reverse:
            processed_parts = list(reversed(processed_parts))

        return processed_parts

    def render_hierarchy_headers(
        self,
        current_chunk: Dict,
        chunks: List[Dict],
        doc_header_key: str,
    ) -> str:
        """
        Render headers for the current chunk's hierarchy.

        Args:
            current_chunk (Dict): The current chunk being processed.
            chunks (List[Dict]): List of chunks up to and including the current chunk.
            doc_header_key (str): The key for the headers in the current chunk.
        Returns:
            str: Rendered headers in the current chunk's hierarchy.
        """
        rendered_headers = []
        current_hierarchy = {}

        if doc_header_key is None:
            return ""

        # Find the largest/highest level in the current chunk
        current_chunk_headers = current_chunk.get(doc_header_key, [])
        highest_level = float("inf")  # Initialize with positive infinity
        for header_info in current_chunk_headers:
            level = header_info.get("level")
            if level is not None and level < highest_level:
                highest_level = level

        # If no headers found in the current chunk, set highest_level to None
        if highest_level == float("inf"):
            highest_level = None

        for chunk in chunks:
            for header_info in chunk.get(doc_header_key, []):
                header = header_info["header"]
                level = header_info["level"]
                if header and level:
                    current_hierarchy[level] = header
                    # Clear lower levels when a higher level header is found
                    for lower_level in range(level + 1, len(current_hierarchy) + 1):
                        if lower_level in current_hierarchy:
                            current_hierarchy[lower_level] = None

        # Render the headers in the current hierarchy, everything above the highest level in the current chunk (if the highest level in the current chunk is None, render everything)
        for level, header in sorted(current_hierarchy.items()):
            if header is not None and (highest_level is None or level < highest_level):
                rendered_headers.append(f"{'#' * level} {header}")

        rendered_headers = " > ".join(rendered_headers)
        return f"_Current Section:_ {rendered_headers}" if rendered_headers else ""

__init__(*args, **kwargs)

Initialize the GatherOperation.

Parameters:

- *args (Any): Variable length argument list. Default: ().
- **kwargs (Any): Arbitrary keyword arguments. Default: {}.
Source code in docetl/operations/gather.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Initialize the GatherOperation.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)

execute(input_data)

Execute the gather operation on the input data.

Parameters:

- input_data (List[Dict]): The input data to process. Required.

Returns:

- Tuple[List[Dict], float]: A tuple containing the processed results and the cost of the operation.

Source code in docetl/operations/gather.py
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
    """
    Execute the gather operation on the input data.

    Args:
        input_data (List[Dict]): The input data to process.

    Returns:
        Tuple[List[Dict], float]: A tuple containing the processed results and the cost of the operation.
    """
    content_key = self.config["content_key"]
    doc_id_key = self.config["doc_id_key"]
    order_key = self.config["order_key"]
    peripheral_config = self.config["peripheral_chunks"]
    main_chunk_start = self.config.get(
        "main_chunk_start", "--- Begin Main Chunk ---"
    )
    main_chunk_end = self.config.get("main_chunk_end", "--- End Main Chunk ---")
    doc_header_key = self.config.get("doc_header_key", None)
    results = []
    cost = 0.0

    # Group chunks by document ID
    grouped_chunks = {}
    for item in input_data:
        doc_id = item[doc_id_key]
        if doc_id not in grouped_chunks:
            grouped_chunks[doc_id] = []
        grouped_chunks[doc_id].append(item)

    # Process each group of chunks
    for doc_id, chunks in grouped_chunks.items():
        # Sort chunks by their order within the document
        chunks.sort(key=lambda x: x[order_key])

        # Process each chunk with its peripheral context and headers
        for i, chunk in enumerate(chunks):
            rendered_chunk = self.render_chunk_with_context(
                chunks,
                i,
                peripheral_config,
                content_key,
                order_key,
                main_chunk_start,
                main_chunk_end,
                doc_header_key,
            )

            result = chunk.copy()
            result[f"{content_key}_rendered"] = rendered_chunk
            results.append(result)

    return results, cost

process_peripheral_chunks(chunks, config, content_key, order_key, reverse=False)

Process peripheral chunks according to the configuration.

Parameters:

- chunks (List[Dict]): List of chunks to process. Required.
- config (Dict): Configuration for processing peripheral chunks. Required.
- content_key (str): Key for the content in each chunk. Required.
- order_key (str): Key for the order of each chunk. Required.
- reverse (bool): Whether to process chunks in reverse order. Defaults to False.

Returns:

- List[str]: List of processed chunk strings.

Source code in docetl/operations/gather.py
def process_peripheral_chunks(
    self,
    chunks: List[Dict],
    config: Dict,
    content_key: str,
    order_key: str,
    reverse: bool = False,
) -> List[str]:
    """
    Process peripheral chunks according to the configuration.

    Args:
        chunks (List[Dict]): List of chunks to process.
        config (Dict): Configuration for processing peripheral chunks.
        content_key (str): Key for the content in each chunk.
        order_key (str): Key for the order of each chunk.
        reverse (bool, optional): Whether to process chunks in reverse order. Defaults to False.

    Returns:
        List[str]: List of processed chunk strings.
    """
    if reverse:
        chunks = list(reversed(chunks))

    processed_parts = []
    included_chunks = []
    total_chunks = len(chunks)

    head_config = config.get("head", {})
    tail_config = config.get("tail", {})

    head_count = int(head_config.get("count", 0))
    tail_count = int(tail_config.get("count", 0))
    in_skip = False
    skip_char_count = 0

    for i, chunk in enumerate(chunks):
        if i < head_count:
            section = "head"
        elif i >= total_chunks - tail_count:
            section = "tail"
        elif "middle" in config:
            section = "middle"
        else:
            # Show number of characters skipped
            skipped_chars = len(chunk[content_key])
            if not in_skip:
                skip_char_count = skipped_chars
                in_skip = True
            else:
                skip_char_count += skipped_chars

            continue

        if in_skip:
            processed_parts.append(
                f"[... {skip_char_count} characters skipped ...]"
            )
            in_skip = False
            skip_char_count = 0

        section_config = config.get(section, {})
        section_content_key = section_config.get("content_key", content_key)

        is_summary = section_content_key != content_key
        summary_suffix = " (Summary)" if is_summary else ""

        chunk_prefix = f"[Chunk {chunk[order_key]}{summary_suffix}]"
        processed_parts.append(chunk_prefix)
        processed_parts.append(f"{chunk[section_content_key]}")
        included_chunks.append(chunk)

    if in_skip:
        processed_parts.append(f"[... {skip_char_count} characters skipped ...]")

    if reverse:
        processed_parts = list(reversed(processed_parts))

    return processed_parts
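
To see how the head/tail windows and the skip markers interact, consider a small worked example (chunk contents invented). With a head count of 1 and a tail count of 1 over four chunks, the two middle chunks collapse into a single skipped-characters marker:

```python
chunks = [
    {"order": 1, "text": "aaaa"},
    {"order": 2, "text": "bbbb"},
    {"order": 3, "text": "cccc"},
    {"order": 4, "text": "dddd"},
]
config = {"head": {"count": 1}, "tail": {"count": 1}}
# process_peripheral_chunks(chunks, config, "text", "order") would yield:
# ["[Chunk 1]", "aaaa",
#  "[... 8 characters skipped ...]",   # chunks 2 and 3 (4 + 4 characters)
#  "[Chunk 4]", "dddd"]
```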

render_chunk_with_context(chunks, current_index, peripheral_config, content_key, order_key, main_chunk_start, main_chunk_end, doc_header_key)

Render a chunk with its peripheral context and headers.

Parameters:

- chunks (List[Dict]): List of all chunks in the document. Required.
- current_index (int): Index of the current chunk being processed. Required.
- peripheral_config (Dict): Configuration for peripheral chunks. Required.
- content_key (str): Key for the content in each chunk. Required.
- order_key (str): Key for the order of each chunk. Required.
- main_chunk_start (str): String to mark the start of the main chunk. Required.
- main_chunk_end (str): String to mark the end of the main chunk. Required.
- doc_header_key (str): The key for the headers in the current chunk. Required.

Returns:

- str: Rendered chunk with context and headers.

Source code in docetl/operations/gather.py
def render_chunk_with_context(
    self,
    chunks: List[Dict],
    current_index: int,
    peripheral_config: Dict,
    content_key: str,
    order_key: str,
    main_chunk_start: str,
    main_chunk_end: str,
    doc_header_key: str,
) -> str:
    """
    Render a chunk with its peripheral context and headers.

    Args:
        chunks (List[Dict]): List of all chunks in the document.
        current_index (int): Index of the current chunk being processed.
        peripheral_config (Dict): Configuration for peripheral chunks.
        content_key (str): Key for the content in each chunk.
        order_key (str): Key for the order of each chunk.
        main_chunk_start (str): String to mark the start of the main chunk.
        main_chunk_end (str): String to mark the end of the main chunk.
        doc_header_key (str): The key for the headers in the current chunk.

    Returns:
        str: Rendered chunk with context and headers.
    """
    combined_parts = []

    # Process previous chunks
    combined_parts.append("--- Previous Context ---")
    combined_parts.extend(
        self.process_peripheral_chunks(
            chunks[:current_index],
            peripheral_config.get("previous", {}),
            content_key,
            order_key,
        )
    )
    combined_parts.append("--- End Previous Context ---\n")

    # Process main chunk
    main_chunk = chunks[current_index]
    headers = self.render_hierarchy_headers(
        main_chunk, chunks[: current_index + 1], doc_header_key
    )
    if headers:
        combined_parts.append(headers)
    combined_parts.append(f"{main_chunk_start}")
    combined_parts.append(f"{main_chunk[content_key]}")
    combined_parts.append(f"{main_chunk_end}")

    # Process next chunks
    combined_parts.append("\n--- Next Context ---")
    combined_parts.extend(
        self.process_peripheral_chunks(
            chunks[current_index + 1 :],
            peripheral_config.get("next", {}),
            content_key,
            order_key,
        )
    )
    combined_parts.append("--- End Next Context ---")

    return "\n".join(combined_parts)

render_hierarchy_headers(current_chunk, chunks, doc_header_key)

Render headers for the current chunk's hierarchy.

Parameters:

- current_chunk (Dict): The current chunk being processed. Required.
- chunks (List[Dict]): List of chunks up to and including the current chunk. Required.
- doc_header_key (str): The key for the headers in the current chunk. Required.

Returns:

- str: Rendered headers in the current chunk's hierarchy.

Source code in docetl/operations/gather.py
def render_hierarchy_headers(
    self,
    current_chunk: Dict,
    chunks: List[Dict],
    doc_header_key: str,
) -> str:
    """
    Render headers for the current chunk's hierarchy.

    Args:
        current_chunk (Dict): The current chunk being processed.
        chunks (List[Dict]): List of chunks up to and including the current chunk.
        doc_header_key (str): The key for the headers in the current chunk.
    Returns:
        str: Rendered headers in the current chunk's hierarchy.
    """
    rendered_headers = []
    current_hierarchy = {}

    if doc_header_key is None:
        return ""

    # Find the largest/highest level in the current chunk
    current_chunk_headers = current_chunk.get(doc_header_key, [])
    highest_level = float("inf")  # Initialize with positive infinity
    for header_info in current_chunk_headers:
        level = header_info.get("level")
        if level is not None and level < highest_level:
            highest_level = level

    # If no headers found in the current chunk, set highest_level to None
    if highest_level == float("inf"):
        highest_level = None

    for chunk in chunks:
        for header_info in chunk.get(doc_header_key, []):
            header = header_info["header"]
            level = header_info["level"]
            if header and level:
                current_hierarchy[level] = header
                # Clear lower levels when a higher level header is found
                for lower_level in range(level + 1, len(current_hierarchy) + 1):
                    if lower_level in current_hierarchy:
                        current_hierarchy[lower_level] = None

    # Render the headers in the current hierarchy, everything above the highest level in the current chunk (if the highest level in the current chunk is None, render everything)
    for level, header in sorted(current_hierarchy.items()):
        if header is not None and (highest_level is None or level < highest_level):
            rendered_headers.append(f"{'#' * level} {header}")

    rendered_headers = " > ".join(rendered_headers)
    return f"_Current Section:_ {rendered_headers}" if rendered_headers else ""

syntax_check()

Perform a syntax check on the operation configuration.

Raises:

- ValueError: If required keys are missing or if there are configuration errors.
- TypeError: If main_chunk_start or main_chunk_end are not strings.

Source code in docetl/operations/gather.py
def syntax_check(self) -> None:
    """
    Perform a syntax check on the operation configuration.

    Raises:
        ValueError: If required keys are missing or if there are configuration errors.
        TypeError: If main_chunk_start or main_chunk_end are not strings.
    """
    required_keys = ["content_key", "doc_id_key", "order_key"]
    for key in required_keys:
        if key not in self.config:
            raise ValueError(
                f"Missing required key '{key}' in GatherOperation configuration"
            )

    if "peripheral_chunks" not in self.config:
        raise ValueError(
            "Missing 'peripheral_chunks' configuration in GatherOperation"
        )

    peripheral_config = self.config["peripheral_chunks"]
    for direction in ["previous", "next"]:
        if direction not in peripheral_config:
            continue
        for section in ["head", "middle", "tail"]:
            if section in peripheral_config[direction]:
                section_config = peripheral_config[direction][section]
                if section != "middle" and "count" not in section_config:
                    raise ValueError(
                        f"Missing 'count' in {direction}.{section} configuration"
                    )

    if "main_chunk_start" in self.config and not isinstance(
        self.config["main_chunk_start"], str
    ):
        raise TypeError("'main_chunk_start' must be a string")
    if "main_chunk_end" in self.config and not isinstance(
        self.config["main_chunk_end"], str
    ):
        raise TypeError("'main_chunk_end' must be a string")

docetl.operations.unnest.UnnestOperation

Bases: BaseOperation

A class that represents an operation to unnest a list-like or dictionary value in a dictionary into multiple dictionaries.

This operation takes a list of dictionaries and a specified key, and creates new dictionaries based on the value type:

- For list-like values: Creates a new dictionary for each element in the list, copying all other key-value pairs.
- For dictionary values: Expands specified fields from the nested dictionary into the parent dictionary.

Inherits from: BaseOperation

Usage:

from docetl.operations import UnnestOperation

# Unnesting a list
config_list = {"unnest_key": "tags"}
input_data_list = [
    {"id": 1, "tags": ["a", "b", "c"]},
    {"id": 2, "tags": ["d", "e"]}
]

unnest_op_list = UnnestOperation(config_list)
result_list, _ = unnest_op_list.execute(input_data_list)

# Result will be:
# [
#     {"id": 1, "tags": "a"},
#     {"id": 1, "tags": "b"},
#     {"id": 1, "tags": "c"},
#     {"id": 2, "tags": "d"},
#     {"id": 2, "tags": "e"}
# ]

# Unnesting a dictionary
config_dict = {"unnest_key": "user", "expand_fields": ["name", "age"]}
input_data_dict = [
    {"id": 1, "user": {"name": "Alice", "age": 30, "email": "alice@example.com"}},
    {"id": 2, "user": {"name": "Bob", "age": 25, "email": "bob@example.com"}}
]

unnest_op_dict = UnnestOperation(config_dict)
result_dict, _ = unnest_op_dict.execute(input_data_dict)

# Result will be:
# [
#     {"id": 1, "name": "Alice", "age": 30, "user": {"name": "Alice", "age": 30, "email": "alice@example.com"}},
#     {"id": 2, "name": "Bob", "age": 25, "user": {"name": "Bob", "age": 25, "email": "bob@example.com"}}
# ]

Source code in docetl/operations/unnest.py
class UnnestOperation(BaseOperation):
    """
    A class that represents an operation to unnest a list-like or dictionary value in a dictionary into multiple dictionaries.

    This operation takes a list of dictionaries and a specified key, and creates new dictionaries based on the value type:
    - For list-like values: Creates a new dictionary for each element in the list, copying all other key-value pairs.
    - For dictionary values: Expands specified fields from the nested dictionary into the parent dictionary.

    Inherits from:
        BaseOperation

    Usage:
    ```python
    from docetl.operations import UnnestOperation

    # Unnesting a list
    config_list = {"unnest_key": "tags"}
    input_data_list = [
        {"id": 1, "tags": ["a", "b", "c"]},
        {"id": 2, "tags": ["d", "e"]}
    ]

    unnest_op_list = UnnestOperation(config_list)
    result_list, _ = unnest_op_list.execute(input_data_list)

    # Result will be:
    # [
    #     {"id": 1, "tags": "a"},
    #     {"id": 1, "tags": "b"},
    #     {"id": 1, "tags": "c"},
    #     {"id": 2, "tags": "d"},
    #     {"id": 2, "tags": "e"}
    # ]

    # Unnesting a dictionary
    config_dict = {"unnest_key": "user", "expand_fields": ["name", "age"]}
    input_data_dict = [
        {"id": 1, "user": {"name": "Alice", "age": 30, "email": "alice@example.com"}},
        {"id": 2, "user": {"name": "Bob", "age": 25, "email": "bob@example.com"}}
    ]

    unnest_op_dict = UnnestOperation(config_dict)
    result_dict, _ = unnest_op_dict.execute(input_data_dict)

    # Result will be:
    # [
    #     {"id": 1, "name": "Alice", "age": 30, "user": {"name": "Alice", "age": 30, "email": "alice@example.com"}},
    #     {"id": 2, "name": "Bob", "age": 25, "user": {"name": "Bob", "age": 25, "email": "bob@example.com"}}
    # ]
    ```
    """

    def syntax_check(self) -> None:
        """
        Checks if the required configuration key is present in the operation's config.

        Raises:
            ValueError: If the required 'unnest_key' is missing from the configuration.
        """

        required_keys = ["unnest_key"]
        for key in required_keys:
            if key not in self.config:
                raise ValueError(
                    f"Missing required key '{key}' in UnnestOperation configuration"
                )

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """
        Executes the unnest operation on the input data.

        Args:
            input_data (List[Dict]): A list of dictionaries to process.

        Returns:
            Tuple[List[Dict], float]: A tuple containing the processed list of dictionaries
            and a float value (always 0 in this implementation).

        Raises:
            KeyError: If the specified unnest_key is not found in an input dictionary.
            TypeError: If the value of the unnest_key is not iterable (list, tuple, set, or dict).
            ValueError: If unnesting a dictionary and 'expand_fields' is not provided in the config.

        The operation supports unnesting of both list-like values and dictionary values:

        1. For list-like values (list, tuple, set):
           Each element in the list becomes a separate dictionary in the output.

        2. For dictionary values:
           The operation expands specified fields from the nested dictionary into the parent dictionary.
           The 'expand_fields' config parameter must be provided to specify which fields to expand.

        Examples:
        ```python
        # Unnesting a list
        unnest_op = UnnestOperation({"unnest_key": "colors"})
        input_data = [
            {"id": 1, "colors": ["red", "blue"]},
            {"id": 2, "colors": ["green"]}
        ]
        result, _ = unnest_op.execute(input_data)
        # Result will be:
        # [
        #     {"id": 1, "colors": "red"},
        #     {"id": 1, "colors": "blue"},
        #     {"id": 2, "colors": "green"}
        # ]

        # Unnesting a dictionary
        unnest_op = UnnestOperation({"unnest_key": "details", "expand_fields": ["color", "size"]})
        input_data = [
            {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}},
            {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}}
        ]
        result, _ = unnest_op.execute(input_data)
        # Result will be:
        # [
        #     {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}, "color": "red", "size": "large"},
        #     {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}, "color": "blue", "size": "medium"}
        # ]
        ```

        Note: When unnesting dictionaries, the original nested dictionary is preserved in the output,
        and the specified fields are expanded into the parent dictionary.
        """

        unnest_key = self.config["unnest_key"]
        recursive = self.config.get("recursive", False)
        depth = self.config.get("depth", None)
        if not depth:
            depth = 1 if not recursive else float("inf")
        results = []

        def unnest_recursive(item, key, level=0):
            if level == 0 and not isinstance(item[key], (list, tuple, set, dict)):
                raise TypeError(f"Value of unnest key '{key}' is not iterable")

            if level > 0 and not isinstance(item[key], (list, tuple, set, dict)):
                return [item]

            if level >= depth:
                return [item]

            if isinstance(item[key], dict):
                expand_fields = self.config.get("expand_fields")
                if expand_fields is None:
                    expand_fields = item[key].keys()
                new_item = copy.deepcopy(item)
                for field in expand_fields:
                    if field in new_item[key]:
                        new_item[field] = new_item[key][field]
                    else:
                        new_item[field] = None
                return [new_item]
            else:
                nested_results = []
                for value in item[key]:
                    new_item = copy.deepcopy(item)
                    new_item[key] = value
                    if recursive and isinstance(value, (list, tuple, set, dict)):
                        nested_results.extend(
                            unnest_recursive(new_item, key, level + 1)
                        )
                    else:
                        nested_results.append(new_item)
                return nested_results

        for item in input_data:
            if unnest_key not in item:
                raise KeyError(
                    f"Unnest key '{unnest_key}' not found in item. Other keys are {item.keys()}"
                )

            results.extend(unnest_recursive(item, unnest_key))

            if not item[unnest_key] and self.config.get("keep_empty", False):
                expand_fields = self.config.get("expand_fields")
                new_item = copy.deepcopy(item)
                if isinstance(item[unnest_key], dict):
                    if expand_fields is None:
                        expand_fields = item[unnest_key].keys()
                    for field in expand_fields:
                        new_item[field] = None
                else:
                    new_item[unnest_key] = None
                results.append(new_item)

        # Assert that no keys are missing after the operation
        if results:
            original_keys = set(input_data[0].keys())
            assert original_keys.issubset(
                set(results[0].keys())
            ), "Keys lost during unnest operation"

        return results, 0
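
The source above also honors optional recursive, depth, and keep_empty settings that the class docstring does not mention. A minimal sketch of recursive unnesting, assuming those keys behave as the code reads:

```python
unnest_op = UnnestOperation({"unnest_key": "matrix", "recursive": True})
input_data = [{"id": 1, "matrix": [[1, 2], [3]]}]
result, _ = unnest_op.execute(input_data)
# Nested lists are flattened all the way down (set "depth" to stop earlier):
# [{"id": 1, "matrix": 1}, {"id": 1, "matrix": 2}, {"id": 1, "matrix": 3}]
```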

execute(input_data)

Executes the unnest operation on the input data.

Parameters:

- input_data (List[Dict]): A list of dictionaries to process. Required.

Returns:

- Tuple[List[Dict], float]: A tuple containing the processed list of dictionaries and a float value (always 0 in this implementation).

Raises:

- KeyError: If the specified unnest_key is not found in an input dictionary.
- TypeError: If the value of the unnest_key is not iterable (list, tuple, set, or dict).
- ValueError: If unnesting a dictionary and 'expand_fields' is not provided in the config.

The operation supports unnesting of both list-like values and dictionary values:

1. For list-like values (list, tuple, set): Each element in the list becomes a separate dictionary in the output.
2. For dictionary values: The operation expands specified fields from the nested dictionary into the parent dictionary. The 'expand_fields' config parameter must be provided to specify which fields to expand.

Examples:

# Unnesting a list
unnest_op = UnnestOperation({"unnest_key": "colors"})
input_data = [
    {"id": 1, "colors": ["red", "blue"]},
    {"id": 2, "colors": ["green"]}
]
result, _ = unnest_op.execute(input_data)
# Result will be:
# [
#     {"id": 1, "colors": "red"},
#     {"id": 1, "colors": "blue"},
#     {"id": 2, "colors": "green"}
# ]

# Unnesting a dictionary
unnest_op = UnnestOperation({"unnest_key": "details", "expand_fields": ["color", "size"]})
input_data = [
    {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}},
    {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}}
]
result, _ = unnest_op.execute(input_data)
# Result will be:
# [
#     {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}, "color": "red", "size": "large"},
#     {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}, "color": "blue", "size": "medium"}
# ]

Note: When unnesting dictionaries, the original nested dictionary is preserved in the output, and the specified fields are expanded into the parent dictionary.
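The source below also implements two options the docstring does not cover: 'recursive', which keeps unnesting values that are themselves iterable, and 'depth', which caps how many levels are flattened (unbounded when 'recursive' is set and no 'depth' is given). A minimal sketch of the recursive case, assuming the operation can be constructed from a bare config dict as in the examples above (the real constructor may require additional runtime arguments):

```python
# Recursive unnest: flatten nested lists all the way down.
# Config-dict construction is an assumption carried over from the
# docstring examples above.
unnest_op = UnnestOperation({"unnest_key": "vals", "recursive": True})
input_data = [{"id": 1, "vals": [[1, 2], [3]]}]
result, _ = unnest_op.execute(input_data)
# Expected result:
# [
#     {"id": 1, "vals": 1},
#     {"id": 1, "vals": 2},
#     {"id": 1, "vals": 3},
# ]
# With {"recursive": True, "depth": 1}, flattening stops after one level:
# [{"id": 1, "vals": [1, 2]}, {"id": 1, "vals": [3]}]
```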

Source code in docetl/operations/unnest.py
def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
    """
    Executes the unnest operation on the input data.

    Args:
        input_data (List[Dict]): A list of dictionaries to process.

    Returns:
        Tuple[List[Dict], float]: A tuple containing the processed list of dictionaries
        and a float value (always 0 in this implementation).

    Raises:
        KeyError: If the specified unnest_key is not found in an input dictionary.
        TypeError: If the value of the unnest_key is not iterable (list, tuple, set, or dict).

    The operation supports unnesting of both list-like values and dictionary values:

    1. For list-like values (list, tuple, set):
       Each element in the list becomes a separate dictionary in the output.

    2. For dictionary values:
       The operation expands fields from the nested dictionary into the parent dictionary.
       The 'expand_fields' config parameter specifies which fields to expand; if omitted, all fields are expanded.

    Examples:
    ```python
    # Unnesting a list
    unnest_op = UnnestOperation({"unnest_key": "colors"})
    input_data = [
        {"id": 1, "colors": ["red", "blue"]},
        {"id": 2, "colors": ["green"]}
    ]
    result, _ = unnest_op.execute(input_data)
    # Result will be:
    # [
    #     {"id": 1, "colors": "red"},
    #     {"id": 1, "colors": "blue"},
    #     {"id": 2, "colors": "green"}
    # ]

    # Unnesting a dictionary
    unnest_op = UnnestOperation({"unnest_key": "details", "expand_fields": ["color", "size"]})
    input_data = [
        {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}},
        {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}}
    ]
    result, _ = unnest_op.execute(input_data)
    # Result will be:
    # [
    #     {"id": 1, "details": {"color": "red", "size": "large", "stock": 5}, "color": "red", "size": "large"},
    #     {"id": 2, "details": {"color": "blue", "size": "medium", "stock": 3}, "color": "blue", "size": "medium"}
    # ]
    ```

    Note: When unnesting dictionaries, the original nested dictionary is preserved in the output,
    and the specified fields are expanded into the parent dictionary.
    """

    unnest_key = self.config["unnest_key"]
    recursive = self.config.get("recursive", False)
    depth = self.config.get("depth", None)
    if not depth:
        depth = 1 if not recursive else float("inf")
    results = []

    def unnest_recursive(item, key, level=0):
        if level == 0 and not isinstance(item[key], (list, tuple, set, dict)):
            raise TypeError(f"Value of unnest key '{key}' is not iterable")

        if level > 0 and not isinstance(item[key], (list, tuple, set, dict)):
            return [item]

        if level >= depth:
            return [item]

        if isinstance(item[key], dict):
            expand_fields = self.config.get("expand_fields")
            if expand_fields is None:
                expand_fields = item[key].keys()
            new_item = copy.deepcopy(item)
            for field in expand_fields:
                if field in new_item[key]:
                    new_item[field] = new_item[key][field]
                else:
                    new_item[field] = None
            return [new_item]
        else:
            nested_results = []
            for value in item[key]:
                new_item = copy.deepcopy(item)
                new_item[key] = value
                if recursive and isinstance(value, (list, tuple, set, dict)):
                    nested_results.extend(
                        unnest_recursive(new_item, key, level + 1)
                    )
                else:
                    nested_results.append(new_item)
            return nested_results

    for item in input_data:
        if unnest_key not in item:
            raise KeyError(
                f"Unnest key '{unnest_key}' not found in item. Other keys are {item.keys()}"
            )

        results.extend(unnest_recursive(item, unnest_key))

        if not item[unnest_key] and self.config.get("keep_empty", False):
            expand_fields = self.config.get("expand_fields")
            new_item = copy.deepcopy(item)
            if isinstance(item[unnest_key], dict):
                if expand_fields is None:
                    expand_fields = item[unnest_key].keys()
                for field in expand_fields:
                    new_item[field] = None
            else:
                new_item[unnest_key] = None
            results.append(new_item)

    # Assert that no keys are missing after the operation
    if results:
        original_keys = set(input_data[0].keys())
        assert original_keys.issubset(
            set(results[0].keys())
        ), "Keys lost during unnest operation"

    return results, 0

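The implementation above also honors a 'keep_empty' flag: by default, an item whose unnest value is empty produces no output rows, but with 'keep_empty' set the item is retained with the unnested field(s) set to None. A minimal sketch, under the same config-dict construction assumption as the examples above:

```python
# keep_empty: retain items whose unnest value is empty instead of dropping them.
# Config-dict construction is an assumption carried over from the docstring examples.
unnest_op = UnnestOperation({"unnest_key": "colors", "keep_empty": True})
input_data = [
    {"id": 1, "colors": ["red"]},
    {"id": 2, "colors": []},  # yields no rows without keep_empty
]
result, _ = unnest_op.execute(input_data)
# Expected result:
# [
#     {"id": 1, "colors": "red"},
#     {"id": 2, "colors": None},
# ]
```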
syntax_check()

Checks if the required configuration key is present in the operation's config.

Raises:

    ValueError: If the required 'unnest_key' is missing from the configuration.

Source code in docetl/operations/unnest.py
def syntax_check(self) -> None:
    """
    Checks if the required configuration key is present in the operation's config.

    Raises:
        ValueError: If the required 'unnest_key' is missing from the configuration.
    """

    required_keys = ["unnest_key"]
    for key in required_keys:
        if key not in self.config:
            raise ValueError(
                f"Missing required key '{key}' in UnnestOperation configuration"
            )
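
As the source shows, the only structural requirement is the 'unnest_key' entry, so a config without it fails before any data is touched. A minimal sketch, assuming the same bare config-dict construction as in the examples above:

```python
# A config missing 'unnest_key' is rejected up front.
# Config-dict construction is an assumption carried over from the docstring examples.
try:
    UnnestOperation({"type": "unnest"}).syntax_check()
except ValueError as e:
    print(e)  # Missing required key 'unnest_key' in UnnestOperation configuration
```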