[FLINK-36525][transform] Support for AI Model Integration for Data Processing

Kunni committed by GitHub
parent c969957f02
commit 06154e9674

@ -356,6 +356,77 @@ transform:
filter: inc(id) < 100
```
## Embedding AI Model
Embedding AI Models can be used in transform rules.
To use an Embedding AI Model, you need to download the jar of the built-in model, and then add `--jar {$BUILT_IN_MODEL_PATH}` to your flink-cdc.sh command.
How to define an Embedding AI Model:
```yaml
pipeline:
  model:
    - model-name: CHAT
      class-name: OpenAIChatModel
      openai.model: gpt-4o-mini
      openai.host: https://xxxx
      openai.apikey: abcd1234
      openai.chat.prompt: please summarize this
    - model-name: GET_EMBEDDING
      class-name: OpenAIEmbeddingModel
      openai.model: text-embedding-3-small
      openai.host: https://xxxx
      openai.apikey: abcd1234
```
Note:
* `model-name` is a common required parameter for all supported models; it is the function name to be called in `projection` or `filter`.
* `class-name` is a common required parameter for all supported models; available values can be found in [All Supported Models](#all-supported-models).
* `openai.model`, `openai.host`, `openai.apikey` and `openai.chat.prompt` are parameters defined by the specific model.
How to use an Embedding AI Model:
```yaml
transform:
  - source-table: db.\.*
    projection: "*, inc(inc(inc(id))) as inc_id, GET_EMBEDDING(page) as emb, CHAT(page) as summary"
    filter: inc(id) < 100
pipeline:
  model:
    - model-name: CHAT
      class-name: OpenAIChatModel
      openai.model: gpt-4o-mini
      openai.host: http://langchain4j.dev/demo/openai/v1
      openai.apikey: demo
      openai.chat.prompt: please summarize this
    - model-name: GET_EMBEDDING
      class-name: OpenAIEmbeddingModel
      openai.model: text-embedding-3-small
      openai.host: http://langchain4j.dev/demo/openai/v1
      openai.apikey: demo
```
Here, `GET_EMBEDDING` and `CHAT` are functions defined through `model-name` in `pipeline`.
### All Supported Models
The following built-in models are provided:
#### OpenAIChatModel
| parameter | type | optional/required | meaning |
|--------------------|--------|-------------------|--------------------------------------------------------------------------------------------------------------------------------------|
| openai.model       | STRING | required          | Name of the model to be called, for example "gpt-4o-mini". Available options are "gpt-4o-mini", "gpt-4o", "gpt-4-32k" and "gpt-3.5-turbo".            |
| openai.host        | STRING | required          | Host of the model server to connect to, for example `http://langchain4j.dev/demo/openai/v1`.                                                          |
| openai.apikey      | STRING | required          | API key for authenticating with the model server, for example "demo".                                                                                 |
| openai.chat.prompt | STRING | optional          | Prompt for chatting with OpenAI, for example "Please summarize this".                                                                                 |
#### OpenAIEmbeddingModel
| parameter | type | optional/required | meaning |
|---------------|--------|-------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| openai.model  | STRING | required          | Name of the model to be called, for example "text-embedding-3-small". Available options are "text-embedding-3-small", "text-embedding-3-large" and "text-embedding-ada-002". |
| openai.host   | STRING | required          | Host of the model server to connect to, for example `http://langchain4j.dev/demo/openai/v1`.                                                                                 |
| openai.apikey | STRING | required          | API key for authenticating with the model server, for example "demo".                                                                                                        |
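
The built-in models implement the same UDF interface that the transform pipeline invokes internally, so they can also be exercised directly. A minimal sketch in Java, mirroring the tests shipped with this change (the demo host and key come from the tables above):

```java
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.runtime.model.ModelOptions;
import org.apache.flink.cdc.runtime.model.OpenAIChatModel;

public class ChatModelSketch {
    public static void main(String[] args) throws Exception {
        Configuration configuration = new Configuration();
        configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
        configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
        configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini");

        OpenAIChatModel chat = new OpenAIChatModel();
        // UserDefinedFunctionContext is a single-method interface, so a lambda suffices.
        chat.open(() -> configuration);
        System.out.println(chat.eval("Who invented the electric light?"));
        chat.close();
    }
}
```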
# Known limitations
* Currently, transform doesn't work with route rules. It will be supported in future versions.
* Computed columns cannot reference trimmed columns that are not present in the final projection results. This will be fixed in future versions.

@ -21,7 +21,9 @@ import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.event.SchemaChangeEventType;
import org.apache.flink.cdc.common.event.SchemaChangeEventTypeFamily;
import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
import org.apache.flink.cdc.common.utils.Preconditions;
import org.apache.flink.cdc.common.utils.StringUtils;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.RouteDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
@ -57,6 +59,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
private static final String ROUTE_KEY = "route";
private static final String TRANSFORM_KEY = "transform";
private static final String PIPELINE_KEY = "pipeline";
private static final String MODEL_KEY = "model";
// Source / sink keys
private static final String TYPE_KEY = "type";
@ -81,6 +84,11 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
private static final String UDF_FUNCTION_NAME_KEY = "name";
private static final String UDF_CLASSPATH_KEY = "classpath";
// Model related keys
private static final String MODEL_NAME_KEY = "model-name";
private static final String MODEL_CLASS_NAME_KEY = "class-name";
public static final String TRANSFORM_PRIMARY_KEY_KEY = "primary-keys";
public static final String TRANSFORM_PARTITION_KEY_KEY = "partition-keys";
@ -108,10 +116,15 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
// UDFs are optional. We parse UDF first and remove it from the pipelineDefJsonNode since
// it's not of plain data types and must be removed before calling toPipelineConfig.
List<UdfDef> udfDefs = new ArrayList<>();
final List<ModelDef> modelDefs = new ArrayList<>();
if (pipelineDefJsonNode.get(PIPELINE_KEY) != null) {
Optional.ofNullable(
((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(UDF_KEY))
.ifPresent(node -> node.forEach(udf -> udfDefs.add(toUdfDef(udf))));
Optional.ofNullable(
((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(MODEL_KEY))
.ifPresent(node -> modelDefs.addAll(parseModels(node)));
}
// Pipeline configs are optional
@ -156,7 +169,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
pipelineConfig.addAll(userPipelineConfig);
return new PipelineDef(
sourceDef, sinkDef, routeDefs, transformDefs, udfDefs, pipelineConfig);
sourceDef, sinkDef, routeDefs, transformDefs, udfDefs, modelDefs, pipelineConfig);
}
private SourceDef toSourceDef(JsonNode sourceNode) {
@ -323,4 +336,34 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
pipelineConfigNode, new TypeReference<Map<String, String>>() {});
return Configuration.fromMap(pipelineConfigMap);
}
private List<ModelDef> parseModels(JsonNode modelsNode) {
List<ModelDef> modelDefs = new ArrayList<>();
Preconditions.checkNotNull(modelsNode, "`model` in `pipeline` should not be empty.");
if (modelsNode.isArray()) {
for (JsonNode modelNode : modelsNode) {
modelDefs.add(convertJsonNodeToModelDef(modelNode));
}
} else {
modelDefs.add(convertJsonNodeToModelDef(modelsNode));
}
return modelDefs;
}
private ModelDef convertJsonNodeToModelDef(JsonNode modelNode) {
String name =
checkNotNull(
modelNode.get(MODEL_NAME_KEY),
"Missing required field \"%s\" in `model`",
MODEL_NAME_KEY)
.asText();
String model =
checkNotNull(
modelNode.get(MODEL_CLASS_NAME_KEY),
"Missing required field \"%s\" in `model`",
MODEL_CLASS_NAME_KEY)
.asText();
Map<String, String> properties = mapper.convertValue(modelNode, Map.class);
return new ModelDef(name, model, properties);
}
}
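
For reference, `parseModels` accepts the `model` node either as a YAML list or as a single mapping. A standalone sketch of the two shapes (the `ModelNodeShapes` class is illustrative, assuming Jackson's YAML module is on the classpath, as it is for the parser itself):

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;

public class ModelNodeShapes {
    public static void main(String[] args) throws Exception {
        ObjectMapper yaml = new ObjectMapper(new YAMLFactory());

        // `model:` given as a list -- each entry becomes one ModelDef.
        JsonNode listForm =
                yaml.readTree(
                        "- model-name: GET_EMBEDDING\n"
                                + "  class-name: OpenAIEmbeddingModel\n");
        System.out.println(listForm.isArray()); // true

        // `model:` given as a single mapping -- wrapped into a one-element list.
        JsonNode singleForm =
                yaml.readTree("model-name: CHAT\nclass-name: OpenAIChatModel\n");
        System.out.println(singleForm.isArray()); // false
    }
}
```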

@ -20,6 +20,7 @@ package org.apache.flink.cdc.cli.parser;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.event.SchemaChangeEventType;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.RouteDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
@ -39,6 +40,7 @@ import java.time.Duration;
import java.time.ZoneId;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Set;
import static org.apache.flink.cdc.common.event.SchemaChangeEventType.ADD_COLUMN;
@ -344,6 +346,18 @@ class YamlPipelineDefinitionParserTest {
null,
"add new uniq_id for each row")),
Collections.emptyList(),
Collections.singletonList(
new ModelDef(
"GET_EMBEDDING",
"OpenAIEmbeddingModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("model-name", "GET_EMBEDDING")
.put("class-name", "OpenAIEmbeddingModel")
.put("openai.model", "text-embedding-3-small")
.put("openai.host", "https://xxxx")
.put("openai.apikey", "abcd1234")
.build()))),
Configuration.fromMap(
ImmutableMap.<String, String>builder()
.put("name", "source-database-sync-pipe")
@ -397,7 +411,13 @@ class YamlPipelineDefinitionParserTest {
+ " name: source-database-sync-pipe\n"
+ " parallelism: 4\n"
+ " schema.change.behavior: evolve\n"
+ " schema-operator.rpc-timeout: 1 h";
+ " schema-operator.rpc-timeout: 1 h\n"
+ " model:\n"
+ " - model-name: GET_EMBEDDING\n"
+ " class-name: OpenAIEmbeddingModel\n"
+ " openai.model: text-embedding-3-small\n"
+ " openai.host: https://xxxx\n"
+ " openai.apikey: abcd1234";
YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser();
PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration());
assertThat(pipelineDef).isEqualTo(fullDef);
@ -459,6 +479,18 @@ class YamlPipelineDefinitionParserTest {
null,
"add new uniq_id for each row")),
Collections.emptyList(),
Collections.singletonList(
new ModelDef(
"GET_EMBEDDING",
"OpenAIEmbeddingModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("model-name", "GET_EMBEDDING")
.put("class-name", "OpenAIEmbeddingModel")
.put("openai.model", "text-embedding-3-small")
.put("openai.host", "https://xxxx")
.put("openai.apikey", "abcd1234")
.build()))),
Configuration.fromMap(
ImmutableMap.<String, String>builder()
.put("name", "source-database-sync-pipe")

@ -57,3 +57,9 @@ pipeline:
  parallelism: 4
  schema.change.behavior: evolve
  schema-operator.rpc-timeout: 1 h
  model:
    model-name: GET_EMBEDDING
    class-name: OpenAIEmbeddingModel
    openai.model: text-embedding-3-small
    openai.host: https://xxxx
    openai.apikey: abcd1234

@ -30,9 +30,20 @@ public interface UserDefinedFunction {
return null;
}
/** This will be invoked every time when a UDF got created. */
/**
* This will be invoked every time a UDF is created.
*
* <p>This method is deprecated; please use {@link #open(UserDefinedFunctionContext)}
* instead.
*/
@Deprecated
default void open() throws Exception {}
/** This will be invoked every time a UDF is created. */
default void open(UserDefinedFunctionContext context) throws Exception {
open();
}
/** This will be invoked before a UDF is destroyed. */
default void close() throws Exception {}
}
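
For UDF authors, migrating to the parameterized lifecycle hook only requires overriding the new `open` variant. A minimal hypothetical sketch (the `PrefixFunction` class and its `prefix` option are illustrative, not part of the API):

```java
import org.apache.flink.cdc.common.configuration.ConfigOption;
import org.apache.flink.cdc.common.configuration.ConfigOptions;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;

// Hypothetical UDF that reads a configured parameter when it is opened.
public class PrefixFunction implements UserDefinedFunction {

    // Illustrative option key, not a real CDC option.
    private static final ConfigOption<String> PREFIX =
            ConfigOptions.key("prefix").stringType().defaultValue(">> ");

    private String prefix;

    @Override
    public void open(UserDefinedFunctionContext context) {
        // Parameters from the pipeline definition arrive as a Configuration.
        prefix = context.configuration().get(PREFIX);
    }

    public String eval(String input) {
        return prefix + input;
    }
}
```

Existing UDFs that only override the deprecated no-arg `open()` keep working, since the default parameterized variant delegates to it.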

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.common.udf;
import org.apache.flink.cdc.common.configuration.Configuration;
/** Context for initialization of {@link UserDefinedFunction}. */
public interface UserDefinedFunctionContext {
Configuration configuration();
}

@ -67,6 +67,12 @@ limitations under the License.
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-pipeline-model</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<!-- This is for testing Scala UDF.-->
<dependency>

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.composer.definition;
import java.util.Map;
import java.util.Objects;
/**
* Definition of a model.
*
* <p>A model definition contains:
*
* <ul>
* <li>modelName: The name of the function to be invoked in transform rules.
* <li>className: The class name of the model used to transform data.
* <li>parameters: The parameters used to configure the model.
* </ul>
*/
public class ModelDef {
private final String modelName;
private final String className;
private final Map<String, String> parameters;
public ModelDef(String modelName, String className, Map<String, String> parameters) {
this.modelName = modelName;
this.className = className;
this.parameters = parameters;
}
public String getModelName() {
return modelName;
}
public String getClassName() {
return className;
}
public Map<String, String> getParameters() {
return parameters;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ModelDef modelDef = (ModelDef) o;
return Objects.equals(modelName, modelDef.modelName)
&& Objects.equals(className, modelDef.className)
&& Objects.equals(parameters, modelDef.parameters);
}
@Override
public int hashCode() {
return Objects.hash(modelName, className, parameters);
}
@Override
public String toString() {
return "ModelDef{"
+ "name='"
+ modelName
+ '\''
+ ", model='"
+ className
+ '\''
+ ", parameters="
+ parameters
+ '}';
}
}
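
For illustration, the `GET_EMBEDDING` entry from the documentation examples corresponds to the following construction; as the parser converts the whole YAML node to a map, `model-name` and `class-name` also appear among the parameters (mirroring the expectations in the parser tests):

```java
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.flink.cdc.composer.definition.ModelDef;

public class ModelDefSketch {
    public static void main(String[] args) {
        Map<String, String> parameters = new LinkedHashMap<>();
        parameters.put("model-name", "GET_EMBEDDING");
        parameters.put("class-name", "OpenAIEmbeddingModel");
        parameters.put("openai.model", "text-embedding-3-small");
        parameters.put("openai.host", "https://xxxx");
        parameters.put("openai.apikey", "abcd1234");

        ModelDef def = new ModelDef("GET_EMBEDDING", "OpenAIEmbeddingModel", parameters);
        System.out.println(def);
    }
}
```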

@ -24,6 +24,7 @@ import org.apache.flink.cdc.composer.PipelineComposer;
import org.apache.flink.cdc.composer.PipelineExecution;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TimeZone;
@ -55,6 +56,7 @@ public class PipelineDef {
private final List<RouteDef> routes;
private final List<TransformDef> transforms;
private final List<UdfDef> udfs;
private final List<ModelDef> models;
private final Configuration config;
public PipelineDef(
@ -63,15 +65,27 @@ public class PipelineDef {
List<RouteDef> routes,
List<TransformDef> transforms,
List<UdfDef> udfs,
List<ModelDef> models,
Configuration config) {
this.source = source;
this.sink = sink;
this.routes = routes;
this.transforms = transforms;
this.udfs = udfs;
this.models = models;
this.config = evaluatePipelineTimeZone(config);
}
public PipelineDef(
SourceDef source,
SinkDef sink,
List<RouteDef> routes,
List<TransformDef> transforms,
List<UdfDef> udfs,
Configuration config) {
this(source, sink, routes, transforms, udfs, new ArrayList<>(), config);
}
public SourceDef getSource() {
return source;
}
@ -92,6 +106,10 @@ public class PipelineDef {
return udfs;
}
public List<ModelDef> getModels() {
return models;
}
public Configuration getConfig() {
return config;
}
@ -109,6 +127,8 @@ public class PipelineDef {
+ transforms
+ ", udfs="
+ udfs
+ ", models="
+ models
+ ", config="
+ config
+ '}';
@ -128,12 +148,13 @@ public class PipelineDef {
&& Objects.equals(routes, that.routes)
&& Objects.equals(transforms, that.transforms)
&& Objects.equals(udfs, that.udfs)
&& Objects.equals(models, that.models)
&& Objects.equals(config, that.config);
}
@Override
public int hashCode() {
return Objects.hash(source, sink, routes, transforms, udfs, config);
return Objects.hash(source, sink, routes, transforms, udfs, models, config);
}
// ------------------------------------------------------------------------

@ -107,7 +107,10 @@ public class FlinkPipelineComposer implements PipelineComposer {
TransformTranslator transformTranslator = new TransformTranslator();
stream =
transformTranslator.translatePreTransform(
stream, pipelineDef.getTransforms(), pipelineDef.getUdfs());
stream,
pipelineDef.getTransforms(),
pipelineDef.getUdfs(),
pipelineDef.getModels());
// Schema operator
SchemaOperatorTranslator schemaOperatorTranslator =
@ -124,8 +127,9 @@ public class FlinkPipelineComposer implements PipelineComposer {
transformTranslator.translatePostTransform(
stream,
pipelineDef.getTransforms(),
pipelineDefConfig.get(PipelineOptions.PIPELINE_LOCAL_TIME_ZONE),
pipelineDef.getUdfs());
pipelineDef.getConfig().get(PipelineOptions.PIPELINE_LOCAL_TIME_ZONE),
pipelineDef.getUdfs(),
pipelineDef.getModels());
// Build DataSink in advance as schema operator requires MetadataApplier
DataSinkTranslator sinkTranslator = new DataSinkTranslator();

@ -17,8 +17,9 @@
package org.apache.flink.cdc.composer.flink.translator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.event.Event;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.TransformDef;
import org.apache.flink.cdc.composer.definition.UdfDef;
import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperator;
@ -26,7 +27,9 @@ import org.apache.flink.cdc.runtime.operators.transform.PreTransformOperator;
import org.apache.flink.cdc.runtime.typeutils.EventTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
@ -35,8 +38,15 @@ import java.util.stream.Collectors;
*/
public class TransformTranslator {
/** Package prefix of built-in models. */
public static final String PREFIX_CLASSPATH_BUILT_IN_MODEL =
"org.apache.flink.cdc.runtime.model.";
public DataStream<Event> translatePreTransform(
DataStream<Event> input, List<TransformDef> transforms, List<UdfDef> udfFunctions) {
DataStream<Event> input,
List<TransformDef> transforms,
List<UdfDef> udfFunctions,
List<ModelDef> models) {
if (transforms.isEmpty()) {
return input;
}
@ -52,10 +62,11 @@ public class TransformTranslator {
transform.getPartitionKeys(),
transform.getTableOptions());
}
preTransformFunctionBuilder.addUdfFunctions(
udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
preTransformFunctionBuilder.addUdfFunctions(
udfFunctions.stream()
.map(udf -> Tuple2.of(udf.getName(), udf.getClasspath()))
.collect(Collectors.toList()));
models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
return input.transform(
"Transform:Schema", new EventTypeInfo(), preTransformFunctionBuilder.build());
}
@ -64,7 +75,8 @@ public class TransformTranslator {
DataStream<Event> input,
List<TransformDef> transforms,
String timezone,
List<UdfDef> udfFunctions) {
List<UdfDef> udfFunctions,
List<ModelDef> models) {
if (transforms.isEmpty()) {
return input;
}
@ -84,10 +96,21 @@ public class TransformTranslator {
}
postTransformFunctionBuilder.addTimezone(timezone);
postTransformFunctionBuilder.addUdfFunctions(
udfFunctions.stream()
.map(udf -> Tuple2.of(udf.getName(), udf.getClasspath()))
.collect(Collectors.toList()));
udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
postTransformFunctionBuilder.addUdfFunctions(
models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
return input.transform(
"Transform:Data", new EventTypeInfo(), postTransformFunctionBuilder.build());
}
private Tuple3<String, String, Map<String, String>> modelToUDFTuple(ModelDef model) {
return Tuple3.of(
model.getModelName(),
PREFIX_CLASSPATH_BUILT_IN_MODEL + model.getClassName(),
model.getParameters());
}
private Tuple3<String, String, Map<String, String>> udfDefToUDFTuple(UdfDef udf) {
return Tuple3.of(udf.getName(), udf.getClasspath(), new HashMap<>());
}
}
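
In effect, each model is registered like a UDF whose classpath is derived by prefixing the built-in package to its `class-name`. A small sketch of the resulting tuple (the `ModelTupleSketch` class is illustrative):

```java
import java.util.HashMap;
import java.util.Map;

import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.composer.definition.ModelDef;

public class ModelTupleSketch {
    public static void main(String[] args) {
        ModelDef model = new ModelDef("GET_EMBEDDING", "OpenAIEmbeddingModel", new HashMap<>());
        // Same mapping as TransformTranslator#modelToUDFTuple.
        Tuple3<String, String, Map<String, String>> udf =
                Tuple3.of(
                        model.getModelName(),
                        "org.apache.flink.cdc.runtime.model." + model.getClassName(),
                        model.getParameters());
        System.out.println(udf.f1); // org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel
    }
}
```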

@ -21,6 +21,7 @@ import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
import org.apache.flink.cdc.composer.PipelineExecution;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
import org.apache.flink.cdc.composer.definition.SourceDef;
@ -35,6 +36,8 @@ import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceOptions;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.test.junit5.MiniClusterExtension;
import org.apache.flink.shaded.guava31.com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
@ -44,8 +47,10 @@ import org.junit.jupiter.params.provider.MethodSource;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.stream.Stream;
import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL;
@ -822,6 +827,75 @@ public class FlinkPipelineUdfITCase {
"DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, , 4, Integer: 42, 2-42], after=[2, x, 4, Integer: 42, 2-42], op=UPDATE, meta=()}");
}
@ParameterizedTest
@MethodSource("testParams")
void testTransformWithModel(ValuesDataSink.SinkApi sinkApi) throws Exception {
FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster();
// Setup value source
Configuration sourceConfig = new Configuration();
sourceConfig.set(
ValuesDataSourceOptions.EVENT_SET_ID,
ValuesDataSourceHelper.EventSetId.TRANSFORM_TABLE);
SourceDef sourceDef =
new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig);
// Setup value sink
Configuration sinkConfig = new Configuration();
sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true);
sinkConfig.set(ValuesDataSinkOptions.SINK_API, sinkApi);
SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig);
// Setup transform
TransformDef transformDef =
new TransformDef(
"default_namespace.default_schema.table1",
"*, CHAT(col1) AS emb",
null,
"col1",
null,
"key1=value1",
"");
// Setup pipeline
Configuration pipelineConfig = new Configuration();
pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1);
pipelineConfig.set(
PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE);
PipelineDef pipelineDef =
new PipelineDef(
sourceDef,
sinkDef,
Collections.emptyList(),
Collections.singletonList(transformDef),
new ArrayList<>(),
Arrays.asList(
new ModelDef(
"CHAT",
"OpenAIChatModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("openai.model", "gpt-4o-mini")
.put(
"openai.host",
"http://langchain4j.dev/demo/openai/v1")
.put("openai.apikey", "demo")
.build()))),
pipelineConfig);
// Execute the pipeline
PipelineExecution execution = composer.compose(pipelineDef);
execution.execute();
// Check the order and content of all received events
String[] outputEvents = outCaptor.toString().trim().split("\n");
assertThat(outputEvents)
.contains(
"CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING,`col2` STRING,`emb` STRING}, primaryKeys=col1, options=({key1=value1})}")
// The result of the model transform is not deterministic, so only the event count is asserted.
.hasSize(9);
}
private static Stream<Arguments> testParams() {
return Stream.of(
arguments(ValuesDataSink.SinkApi.SINK_FUNCTION, "java"),

@ -0,0 +1,81 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>flink-cdc-parent</artifactId>
<groupId>org.apache.flink</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>flink-cdc-pipeline-model</artifactId>
<properties>
<langchain4j.version>0.23.0</langchain4j.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-common</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-test-utils-junit</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>0.12.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<executions>
<execution>
<id>test-jar</id>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.ConfigOption;
import org.apache.flink.cdc.common.configuration.ConfigOptions;
/** Options for built-in models. */
public class ModelOptions {
// Options for Open AI Model.
public static final ConfigOption<String> OPENAI_MODEL_NAME =
ConfigOptions.key("openai.model")
.stringType()
.noDefaultValue()
.withDescription("Name of model to be called.");
public static final ConfigOption<String> OPENAI_HOST =
ConfigOptions.key("openai.host")
.stringType()
.noDefaultValue()
.withDescription("Host of the Model server to be connected.");
public static final ConfigOption<String> OPENAI_API_KEY =
ConfigOptions.key("openai.apikey")
.stringType()
.noDefaultValue()
.withDescription("Api Key for verification of the Model server.");
public static final ConfigOption<String> OPENAI_CHAT_PROMPT =
ConfigOptions.key("openai.chat.prompt")
.stringType()
.noDefaultValue()
.withDescription("Prompt for chat using OpenAI.");
}

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.Preconditions;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collections;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_CHAT_PROMPT;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME;
/**
* A {@link UserDefinedFunction} that uses a model provided by OpenAI to generate text; refer to
* the <a href="https://docs.langchain4j.dev/integrations/language-models/open-ai/">langchain4j docs</a>.
*/
public class OpenAIChatModel implements UserDefinedFunction {
private static final Logger LOG = LoggerFactory.getLogger(OpenAIChatModel.class);
private OpenAiChatModel chatModel;
private String modelName;
private String host;
private String prompt;
public String eval(String input) {
return chat(input);
}
private String chat(String input) {
if (input == null || input.trim().isEmpty()) {
LOG.warn("Empty or null input provided for embedding.");
return "";
}
if (prompt != null) {
input = prompt + ": " + input;
}
return chatModel
.generate(Collections.singletonList(new UserMessage(input)))
.content()
.text();
}
@Override
public DataType getReturnType() {
return DataTypes.STRING();
}
@Override
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
Configuration modelOptions = userDefinedFunctionContext.configuration();
this.modelName = modelOptions.get(OPENAI_MODEL_NAME);
Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty.");
this.host = modelOptions.get(OPENAI_HOST);
Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty.");
String apiKey = modelOptions.get(OPENAI_API_KEY);
Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty.");
this.prompt = modelOptions.get(OPENAI_CHAT_PROMPT);
LOG.info("Opening OpenAIChatModel " + modelName + " " + host);
this.chatModel =
OpenAiChatModel.builder().apiKey(apiKey).baseUrl(host).modelName(modelName).build();
}
@Override
public void close() {
LOG.info("Closed OpenAIChatModel " + modelName + " " + host);
}
}

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.data.GenericArrayData;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.Preconditions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collections;
import java.util.List;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME;
/**
* A {@link UserDefinedFunction} that uses a model provided by OpenAI to generate vector data;
* refer to the <a href="https://docs.langchain4j.dev/integrations/language-models/open-ai/">langchain4j docs</a>.
*/
public class OpenAIEmbeddingModel implements UserDefinedFunction {
private static final Logger LOG = LoggerFactory.getLogger(OpenAIEmbeddingModel.class);
private String modelName;
private String host;
private OpenAiEmbeddingModel embeddingModel;
public ArrayData eval(String input) {
return getEmbedding(input);
}
private ArrayData getEmbedding(String input) {
if (input == null || input.trim().isEmpty()) {
LOG.debug("Empty or null input provided for embedding.");
return new GenericArrayData(new Float[0]);
}
TextSegment textSegment = new TextSegment(input, new Metadata());
List<Embedding> embeddings =
embeddingModel.embedAll(Collections.singletonList(textSegment)).content();
if (embeddings != null && !embeddings.isEmpty()) {
List<Float> embeddingList = embeddings.get(0).vectorAsList();
Float[] embeddingArray = embeddingList.toArray(new Float[0]);
return new GenericArrayData(embeddingArray);
} else {
LOG.warn("No embedding results returned for input: {}", input);
return new GenericArrayData(new Float[0]);
}
}
@Override
public DataType getReturnType() {
return DataTypes.ARRAY(DataTypes.FLOAT());
}
@Override
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
Configuration modelOptions = userDefinedFunctionContext.configuration();
this.modelName = modelOptions.get(OPENAI_MODEL_NAME);
Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty.");
this.host = modelOptions.get(OPENAI_HOST);
Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty.");
String apiKey = modelOptions.get(OPENAI_API_KEY);
Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty.");
LOG.info("Opening OpenAIEmbeddingModel " + modelName + " " + host);
this.embeddingModel =
OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(host)
.modelName(modelName)
.build();
}
@Override
public void close() {
LOG.info("Closed OpenAIEmbeddingModel " + modelName + " " + host);
}
}

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/** A test for {@link OpenAIChatModel}. */
public class TestOpenAIChatModel {
@Test
public void testEval() {
OpenAIChatModel openAIChatModel = new OpenAIChatModel();
Configuration configuration = new Configuration();
configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini");
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
openAIChatModel.open(userDefinedFunctionContext);
String response = openAIChatModel.eval("Who invented the electric light?");
Assertions.assertFalse(response.isEmpty());
}
}

@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/** A test for {@link OpenAIEmbeddingModel}. */
public class TestOpenAIEmbeddingModel {
@Test
public void testEval() {
OpenAIEmbeddingModel openAIEmbeddingModel = new OpenAIEmbeddingModel();
Configuration configuration = new Configuration();
configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
configuration.set(ModelOptions.OPENAI_MODEL_NAME, "text-embedding-3-small");
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
openAIEmbeddingModel.open(userDefinedFunctionContext);
ArrayData arrayData =
openAIEmbeddingModel.eval("Flink CDC is a streaming data integration tool");
Assertions.assertNotNull(arrayData);
}
}

@ -89,6 +89,12 @@ limitations under the License.
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-pipeline-model</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

@ -18,6 +18,8 @@
package org.apache.flink.cdc.runtime.operators.transform;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.RecordData;
import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
import org.apache.flink.cdc.common.event.CreateTableEvent;
@ -29,6 +31,7 @@ import org.apache.flink.cdc.common.event.TableId;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.common.schema.Schema;
import org.apache.flink.cdc.common.schema.Selectors;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.SchemaUtils;
import org.apache.flink.cdc.runtime.parser.TransformParser;
import org.apache.flink.streaming.api.graph.StreamConfig;
@ -70,7 +73,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
/** keep the relationship of TableId and table information. */
private final Map<TableId, PostTransformChangeInfo> postTransformChangeInfoMap;
private final List<Tuple2<String, String>> udfFunctions;
private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
private List<UserDefinedFunctionDescriptor> udfDescriptors;
private transient Map<String, Object> udfFunctionInstances;
@ -89,7 +92,8 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
public static class Builder {
private final List<TransformRule> transformRules = new ArrayList<>();
private String timezone;
private final List<Tuple2<String, String>> udfFunctions = new ArrayList<>();
private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
new ArrayList<>();
public PostTransformOperator.Builder addTransform(
String tableInclusions,
@ -125,7 +129,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
}
public PostTransformOperator.Builder addUdfFunctions(
List<Tuple2<String, String>> udfFunctions) {
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.udfFunctions.addAll(udfFunctions);
return this;
}
@ -138,7 +142,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
private PostTransformOperator(
List<TransformRule> transformRules,
String timezone,
List<Tuple2<String, String>> udfFunctions) {
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.transformRules = transformRules;
this.timezone = timezone;
this.postTransformChangeInfoMap = new ConcurrentHashMap<>();
@ -158,10 +162,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
super.setup(containingTask, config, output);
udfDescriptors =
udfFunctions.stream()
.map(
udf -> {
return new UserDefinedFunctionDescriptor(udf.f0, udf.f1);
})
.map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
.collect(Collectors.toList());
}
@ -539,7 +540,12 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
// into UserDefinedFunction interface, thus the provided UDF classes
// might not be compatible with the interface definition in CDC common.
Object udfInstance = udfFunctionInstances.get(udf.getName());
udfInstance.getClass().getMethod("open").invoke(udfInstance);
UserDefinedFunctionContext userDefinedFunctionContext =
() -> Configuration.fromMap(udf.getParameters());
udfInstance
.getClass()
.getMethod("open", UserDefinedFunctionContext.class)
.invoke(udfInstance, userDefinedFunctionContext);
} else {
// Do nothing, Flink-style UDF lifecycle hooks are not supported
}

@ -21,6 +21,7 @@ import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
import org.apache.flink.cdc.common.event.CreateTableEvent;
import org.apache.flink.cdc.common.event.DataChangeEvent;
@ -69,7 +70,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
private final Map<TableId, PreTransformChangeInfo> preTransformChangeInfoMap;
private final List<Tuple2<Selectors, SchemaMetadataTransform>> schemaMetadataTransformers;
private transient ListState<byte[]> state;
private final List<Tuple2<String, String>> udfFunctions;
private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
private List<UserDefinedFunctionDescriptor> udfDescriptors;
private Map<TableId, PreTransformProcessor> preTransformProcessorMap;
private Map<TableId, Boolean> hasAsteriskMap;
@ -82,7 +83,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
public static class Builder {
private final List<TransformRule> transformRules = new ArrayList<>();
private final List<Tuple2<String, String>> udfFunctions = new ArrayList<>();
private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
new ArrayList<>();
public PreTransformOperator.Builder addTransform(
String tableInclusions, @Nullable String projection, @Nullable String filter) {
@ -109,7 +111,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
}
public PreTransformOperator.Builder addUdfFunctions(
List<Tuple2<String, String>> udfFunctions) {
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.udfFunctions.addAll(udfFunctions);
return this;
}
@ -120,7 +122,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
}
private PreTransformOperator(
List<TransformRule> transformRules, List<Tuple2<String, String>> udfFunctions) {
List<TransformRule> transformRules,
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.transformRules = transformRules;
this.preTransformChangeInfoMap = new ConcurrentHashMap<>();
this.preTransformProcessorMap = new ConcurrentHashMap<>();
@ -137,7 +140,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
super.setup(containingTask, config, output);
this.udfDescriptors =
this.udfFunctions.stream()
.map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1))
.map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
.collect(Collectors.toList());
// Initialize data fields in advance because they might be accessed in

@ -24,6 +24,8 @@ import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@ -38,9 +40,16 @@ public class UserDefinedFunctionDescriptor implements Serializable {
private final String className;
private final DataType returnTypeHint;
private final boolean isCdcPipelineUdf;
private final Map<String, String> parameters;
public UserDefinedFunctionDescriptor(String name, String classpath) {
this(name, classpath, new HashMap<>());
}
public UserDefinedFunctionDescriptor(
String name, String classpath, Map<String, String> parameters) {
this.name = name;
this.parameters = parameters;
this.classpath = classpath;
this.className = classpath.substring(classpath.lastIndexOf('.') + 1);
try {
@ -107,6 +116,10 @@ public class UserDefinedFunctionDescriptor implements Serializable {
return className;
}
public Map<String, String> getParameters() {
return parameters;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -115,13 +128,19 @@ public class UserDefinedFunctionDescriptor implements Serializable {
if (o == null || getClass() != o.getClass()) {
return false;
}
UserDefinedFunctionDescriptor context = (UserDefinedFunctionDescriptor) o;
return Objects.equals(name, context.name) && Objects.equals(classpath, context.classpath);
UserDefinedFunctionDescriptor that = (UserDefinedFunctionDescriptor) o;
return isCdcPipelineUdf == that.isCdcPipelineUdf
&& Objects.equals(name, that.name)
&& Objects.equals(classpath, that.classpath)
&& Objects.equals(className, that.className)
&& Objects.equals(returnTypeHint, that.returnTypeHint)
&& Objects.equals(parameters, that.parameters);
}
@Override
public int hashCode() {
return Objects.hash(name, classpath);
return Objects.hash(
name, classpath, className, returnTypeHint, isCdcPipelineUdf, parameters);
}
@Override
@ -133,6 +152,15 @@ public class UserDefinedFunctionDescriptor implements Serializable {
+ ", classpath='"
+ classpath
+ '\''
+ ", className='"
+ className
+ '\''
+ ", returnTypeHint="
+ returnTypeHint
+ ", isCdcPipelineUdf="
+ isCdcPipelineUdf
+ ", parameters="
+ parameters
+ '}';
}
}

@ -36,6 +36,7 @@ import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
@ -123,10 +124,11 @@ public class TransformParser {
if (udf.getReturnTypeHint() != null) {
// This UDF has return type hint annotation
returnTypeInference =
o ->
o.getTypeFactory()
.createSqlType(
convertCalciteType(udf.getReturnTypeHint()));
o -> {
RelDataTypeFactory typeFactory = o.getTypeFactory();
DataType returnTypeHint = udf.getReturnTypeHint();
return convertCalciteType(typeFactory, returnTypeHint);
};
} else {
// Infer it from eval method return type
returnTypeInference = o -> function.getReturnType(o.getTypeFactory());

@ -310,4 +310,36 @@ public class TransformSqlOperatorTable extends ReflectiveSqlOperatorTable {
// Cast Functions
// --------------
public static final SqlFunction CAST = SqlStdOperatorTable.CAST;
public static final SqlFunction AI_CHAT_PREDICT =
new SqlFunction(
"AI_CHAT_PREDICT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
// Define the GET_EMBEDDING function
public static final SqlFunction GET_EMBEDDING =
new SqlFunction(
"GET_EMBEDDING",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
// Define the AI_LANGCHAIN_PREDICT function
public static final SqlFunction AI_LANGCHAIN_PREDICT =
new SqlFunction(
"AI_LANGCHAIN_PREDICT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
}

@ -51,8 +51,7 @@ public class TransformTable extends AbstractTable {
for (Column column : columns) {
names.add(column.getName());
RelDataType sqlType =
relDataTypeFactory.createSqlType(
DataTypeConverter.convertCalciteType(column.getType()));
DataTypeConverter.convertCalciteType(relDataTypeFactory, column.getType());
types.add(sqlType);
}
return relDataTypeFactory.createStructType(Pair.zip(names, types));

@ -17,18 +17,27 @@
package org.apache.flink.cdc.runtime.typeutils;
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.data.DecimalData;
import org.apache.flink.cdc.common.data.GenericArrayData;
import org.apache.flink.cdc.common.data.GenericMapData;
import org.apache.flink.cdc.common.data.LocalZonedTimestampData;
import org.apache.flink.cdc.common.data.MapData;
import org.apache.flink.cdc.common.data.TimestampData;
import org.apache.flink.cdc.common.data.binary.BinaryStringData;
import org.apache.flink.cdc.common.schema.Column;
import org.apache.flink.cdc.common.types.ArrayType;
import org.apache.flink.cdc.common.types.BinaryType;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.types.DecimalType;
import org.apache.flink.cdc.common.types.MapType;
import org.apache.flink.cdc.common.types.RowType;
import org.apache.flink.cdc.common.types.TimestampType;
import org.apache.flink.cdc.common.types.VarBinaryType;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlTypeName;
import java.math.BigDecimal;
@ -39,8 +48,11 @@ import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/** A data type converter. */
public class DataTypeConverter {
@ -90,50 +102,67 @@ public class DataTypeConverter {
case ROW:
return Object.class;
case ARRAY:
return ArrayData.class;
case MAP:
return MapData.class;
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
}
- public static SqlTypeName convertCalciteType(DataType dataType) {
+ public static RelDataType convertCalciteType(
+ RelDataTypeFactory typeFactory, DataType dataType) {
switch (dataType.getTypeRoot()) {
case BOOLEAN:
- return SqlTypeName.BOOLEAN;
+ return typeFactory.createSqlType(SqlTypeName.BOOLEAN);
case TINYINT:
- return SqlTypeName.TINYINT;
+ return typeFactory.createSqlType(SqlTypeName.TINYINT);
case SMALLINT:
- return SqlTypeName.SMALLINT;
+ return typeFactory.createSqlType(SqlTypeName.SMALLINT);
case INTEGER:
- return SqlTypeName.INTEGER;
+ return typeFactory.createSqlType(SqlTypeName.INTEGER);
case BIGINT:
- return SqlTypeName.BIGINT;
+ return typeFactory.createSqlType(SqlTypeName.BIGINT);
case DATE:
- return SqlTypeName.DATE;
+ return typeFactory.createSqlType(SqlTypeName.DATE);
case TIME_WITHOUT_TIME_ZONE:
- return SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE;
+ return typeFactory.createSqlType(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE);
case TIMESTAMP_WITHOUT_TIME_ZONE:
- return SqlTypeName.TIMESTAMP;
+ return typeFactory.createSqlType(SqlTypeName.TIMESTAMP);
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
- return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+ return typeFactory.createSqlType(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE);
case FLOAT:
- return SqlTypeName.FLOAT;
+ return typeFactory.createSqlType(SqlTypeName.FLOAT);
case DOUBLE:
- return SqlTypeName.DOUBLE;
+ return typeFactory.createSqlType(SqlTypeName.DOUBLE);
case CHAR:
- return SqlTypeName.CHAR;
+ return typeFactory.createSqlType(SqlTypeName.CHAR);
case VARCHAR:
- return SqlTypeName.VARCHAR;
+ return typeFactory.createSqlType(SqlTypeName.VARCHAR);
case BINARY:
- return SqlTypeName.BINARY;
+ return typeFactory.createSqlType(SqlTypeName.BINARY);
case VARBINARY:
- return SqlTypeName.VARBINARY;
+ return typeFactory.createSqlType(SqlTypeName.VARBINARY);
case DECIMAL:
- return SqlTypeName.DECIMAL;
+ return typeFactory.createSqlType(SqlTypeName.DECIMAL);
case ROW:
- return SqlTypeName.ROW;
+ List<RelDataType> dataTypes =
+ ((RowType) dataType)
+ .getFieldTypes().stream()
+ .map((type) -> convertCalciteType(typeFactory, type))
+ .collect(Collectors.toList());
+ return typeFactory.createStructType(
+ dataTypes, ((RowType) dataType).getFieldNames());
+ case ARRAY:
+ DataType elementType = ((ArrayType) dataType).getElementType();
+ return typeFactory.createArrayType(
+ convertCalciteType(typeFactory, elementType), -1);
+ case MAP:
+ RelDataType keyType =
+ convertCalciteType(typeFactory, ((MapType) dataType).getKeyType());
+ RelDataType valueType =
+ convertCalciteType(typeFactory, ((MapType) dataType).getValueType());
+ return typeFactory.createMapType(keyType, valueType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
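// Illustrative usage (assumed, not shown in this patch): converting a nested
// CDC type now yields a fully-typed Calcite composite, e.g.
//   RelDataType rel = DataTypeConverter.convertCalciteType(
//           new JavaTypeFactoryImpl(),
//           DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.INT())));
//   // rel.getKeyType() -> VARCHAR, rel.getValueType() -> INTEGER ARRAY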
@ -173,9 +202,16 @@ public class DataTypeConverter {
return DataTypes.VARBINARY(VarBinaryType.MAX_LENGTH);
case DECIMAL:
return DataTypes.DECIMAL(relDataType.getPrecision(), relDataType.getScale());
- case ROW:
+ case ARRAY:
+ RelDataType componentType = relDataType.getComponentType();
+ return DataTypes.ARRAY(convertCalciteRelDataTypeToDataType(componentType));
+ case MAP:
+ RelDataType keyType = relDataType.getKeyType();
+ RelDataType valueType = relDataType.getValueType();
+ return DataTypes.MAP(
+ convertCalciteRelDataTypeToDataType(keyType),
+ convertCalciteRelDataTypeToDataType(valueType));
+ case ROW:
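// ROW still falls through to the unsupported branch in this Calcite-to-CDC direction.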
default:
throw new UnsupportedOperationException(
"Unsupported type: " + relDataType.getSqlTypeName());
@ -220,7 +256,9 @@ public class DataTypeConverter {
case ROW:
return value;
case ARRAY:
return convertToArray(value, (ArrayType) dataType);
case MAP:
return convertToMap(value, (MapType) dataType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
@ -264,7 +302,9 @@ public class DataTypeConverter {
case ROW:
return value;
case ARRAY:
return convertToArrayOriginal(value, (ArrayType) dataType);
case MAP:
return convertToMapOriginal(value, (MapType) dataType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
@ -378,6 +418,101 @@ public class DataTypeConverter {
return toLocalTime(obj).toSecondOfDay() * 1000;
}
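/** Wraps a Java {@link List} or Object[] into the internal {@link GenericArrayData} representation. */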
private static Object convertToArray(Object obj, ArrayType arrayType) {
if (obj instanceof ArrayData) {
return obj;
}
if (obj instanceof List) {
List<?> list = (List<?>) obj;
return new GenericArrayData(list.toArray());
}
if (obj.getClass().isArray()) {
return new GenericArrayData((Object[]) obj);
}
throw new IllegalArgumentException("Unable to convert to ArrayData: " + obj);
}
private static Object convertToArrayOriginal(Object obj, ArrayType arrayType) {
if (obj instanceof ArrayData) {
ArrayData arrayData = (ArrayData) obj;
Object[] result = new Object[arrayData.size()];
for (int i = 0; i < arrayData.size(); i++) {
result[i] = getArrayElement(arrayData, i, arrayType.getElementType());
}
return result;
}
return obj;
}
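/** Reads a single element from {@link ArrayData} at the given position according to the element type. */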
private static Object getArrayElement(ArrayData arrayData, int pos, DataType elementType) {
switch (elementType.getTypeRoot()) {
case BOOLEAN:
return arrayData.getBoolean(pos);
case TINYINT:
return arrayData.getByte(pos);
case SMALLINT:
return arrayData.getShort(pos);
case INTEGER:
return arrayData.getInt(pos);
case BIGINT:
return arrayData.getLong(pos);
case FLOAT:
return arrayData.getFloat(pos);
case DOUBLE:
return arrayData.getDouble(pos);
case CHAR:
case VARCHAR:
return arrayData.getString(pos);
case DECIMAL:
return arrayData.getDecimal(
pos,
((DecimalType) elementType).getPrecision(),
((DecimalType) elementType).getScale());
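// DATE is stored as an int of epoch days; TIME_WITHOUT_TIME_ZONE as an int of milliseconds of the day.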
case DATE:
return arrayData.getInt(pos);
case TIME_WITHOUT_TIME_ZONE:
return arrayData.getInt(pos);
case TIMESTAMP_WITHOUT_TIME_ZONE:
return arrayData.getTimestamp(pos, ((TimestampType) elementType).getPrecision());
case ARRAY:
return convertToArrayOriginal(arrayData.getArray(pos), (ArrayType) elementType);
case MAP:
return convertToMapOriginal(arrayData.getMap(pos), (MapType) elementType);
default:
throw new UnsupportedOperationException(
"Unsupported array element type: " + elementType);
}
}
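/** Wraps a Java {@link Map} into the internal {@link GenericMapData} representation. */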
private static Object convertToMap(Object obj, MapType mapType) {
if (obj instanceof MapData) {
return obj;
}
if (obj instanceof Map) {
Map<?, ?> javaMap = (Map<?, ?>) obj;
return new GenericMapData(javaMap);
}
throw new IllegalArgumentException("Unable to convert to MapData: " + obj);
}
private static Object convertToMapOriginal(Object obj, MapType mapType) {
if (obj instanceof MapData) {
MapData mapData = (MapData) obj;
Map<Object, Object> result = new HashMap<>();
ArrayData keyArray = mapData.keyArray();
ArrayData valueArray = mapData.valueArray();
for (int i = 0; i < mapData.size(); i++) {
Object key = getArrayElement(keyArray, i, mapType.getKeyType());
Object value = getArrayElement(valueArray, i, mapType.getValueType());
result.put(key, value);
}
return result;
}
return obj;
}
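// In short, the round trip is: a java.util.Map is wrapped into GenericMapData
// on the way in, and convertToMapOriginal rebuilds a plain HashMap by walking
// keyArray()/valueArray() positionally on the way out.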
private static LocalTime toLocalTime(Object obj) {
if (obj == null) {
return null;

@ -20,8 +20,10 @@ package org.apache.flink.cdc.runtime.operators.transform;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel;
import org.apache.flink.table.functions.ScalarFunction;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
@ -48,7 +50,7 @@ public class UserDefinedFunctionDescriptorTest {
public static class NotUDF {}
@Test
- void testUserDefinedFunctionDescriptor() {
+ void testUserDefinedFunctionDescriptor() throws JsonProcessingException {
assertThat(new UserDefinedFunctionDescriptor("cdc_udf", CdcUdf.class.getName()))
.extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
@ -93,5 +95,14 @@ public class UserDefinedFunctionDescriptorTest {
"not_even_exist", "not.a.valid.class.path"))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("Failed to instantiate UDF not_even_exist@not.a.valid.class.path");
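// Built-in model classes resolve through the same descriptor path as CDC
// pipeline UDFs; the embedding model reports ARRAY<FLOAT> as its return type hint.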
String name = "GET_EMBEDDING";
assertThat(new UserDefinedFunctionDescriptor(name, OpenAIEmbeddingModel.class.getName()))
.extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
.containsExactly(
"GET_EMBEDDING",
"OpenAIEmbeddingModel",
"org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel",
DataTypes.ARRAY(DataTypes.FLOAT()),
true);
}
}

@ -41,6 +41,7 @@ limitations under the License.
<module>flink-cdc-runtime</module>
<module>flink-cdc-e2e-tests</module>
<module>flink-cdc-pipeline-udf-examples</module>
<module>flink-cdc-pipeline-model</module>
</modules>
<licenses>
