[FLINK-36525][transform] Support for AI Model Integration for Data Processing (#3753)

Kunni 3 months ago committed by GitHub
parent c969957f02
commit 06154e9674

@ -356,6 +356,77 @@ transform:
    filter: inc(id) < 100
```
## Embedding AI Model
Embedding AI Models can be used in transform rules.
To use an Embedding AI Model, download the jar of the built-in model and add `--jar {$BUILT_IN_MODEL_PATH}` to your flink-cdc.sh command.
How to define an Embedding AI Model:
```yaml
pipeline:
  model:
    - model-name: CHAT
      class-name: OpenAIChatModel
      openai.model: gpt-4o-mini
      openai.host: https://xxxx
      openai.apikey: abcd1234
      openai.chat.prompt: please summarize this
    - model-name: GET_EMBEDDING
      class-name: OpenAIEmbeddingModel
      openai.model: text-embedding-3-small
      openai.host: https://xxxx
      openai.apikey: abcd1234
```
Note:
* `model-name` is a required parameter common to all supported models; it is the function name called in `projection` or `filter`.
* `class-name` is a required parameter common to all supported models; available values are listed in [All supported models](#all-supported-models).
* `openai.model`, `openai.host`, `openai.apikey` and `openai.chat.prompt` are model-specific parameters defined by the chosen model.
How to use an Embedding AI Model:
```yaml
transform:
  - source-table: db.\.*
    projection: "*, inc(inc(inc(id))) as inc_id, GET_EMBEDDING(page) as emb, CHAT(page) as summary"
    filter: inc(id) < 100
pipeline:
  model:
    - model-name: CHAT
      class-name: OpenAIChatModel
      openai.model: gpt-4o-mini
      openai.host: http://langchain4j.dev/demo/openai/v1
      openai.apikey: demo
      openai.chat.prompt: please summarize this
    - model-name: GET_EMBEDDING
      class-name: OpenAIEmbeddingModel
      openai.model: text-embedding-3-small
      openai.host: http://langchain4j.dev/demo/openai/v1
      openai.apikey: demo
```
Here, `GET_EMBEDDING` and `CHAT` are defined through `model-name` in `pipeline`.
### All supported models
The following built-in models are provided:
#### OpenAIChatModel
| parameter          | type   | optional/required | meaning                                                                                                                                 |
|--------------------|--------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|
| openai.model       | STRING | required          | Name of the model to call, for example: "gpt-4o-mini". Available options are "gpt-4o-mini", "gpt-4o", "gpt-4-32k" and "gpt-3.5-turbo". |
| openai.host        | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                          |
| openai.apikey      | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                            |
| openai.chat.prompt | STRING | optional          | Prompt prepended to the input when chatting with OpenAI, for example: "Please summarize this".                                         |
#### OpenAIEmbeddingModel
| parameter     | type   | optional/required | meaning                                                                                                                                                                   |
|---------------|--------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| openai.model  | STRING | required          | Name of the model to call, for example: "text-embedding-3-small". Available options are "text-embedding-3-small", "text-embedding-3-large" and "text-embedding-ada-002". |
| openai.host   | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                                                            |
| openai.apikey | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                                                              |
# Known limitations
* Currently, transform doesn't work with route rules. It will be supported in future versions.
* Computed columns cannot reference trimmed columns that are not present in the final projection results. This will be fixed in future versions.

@ -21,7 +21,9 @@ import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.event.SchemaChangeEventType;
import org.apache.flink.cdc.common.event.SchemaChangeEventTypeFamily;
import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
import org.apache.flink.cdc.common.utils.Preconditions;
import org.apache.flink.cdc.common.utils.StringUtils;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.RouteDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
@ -57,6 +59,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
private static final String ROUTE_KEY = "route";
private static final String TRANSFORM_KEY = "transform";
private static final String PIPELINE_KEY = "pipeline";
private static final String MODEL_KEY = "model";

// Source / sink keys
private static final String TYPE_KEY = "type";
@ -81,6 +84,11 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
private static final String UDF_FUNCTION_NAME_KEY = "name";
private static final String UDF_CLASSPATH_KEY = "classpath";

// Model related keys
private static final String MODEL_NAME_KEY = "model-name";
private static final String MODEL_CLASS_NAME_KEY = "class-name";

public static final String TRANSFORM_PRIMARY_KEY_KEY = "primary-keys";
public static final String TRANSFORM_PARTITION_KEY_KEY = "partition-keys";
@ -108,10 +116,15 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
// UDFs are optional. We parse UDF first and remove it from the pipelineDefJsonNode since
// it's not of plain data types and must be removed before calling toPipelineConfig.
List<UdfDef> udfDefs = new ArrayList<>();
final List<ModelDef> modelDefs = new ArrayList<>();
if (pipelineDefJsonNode.get(PIPELINE_KEY) != null) {
Optional.ofNullable(
((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(UDF_KEY))
.ifPresent(node -> node.forEach(udf -> udfDefs.add(toUdfDef(udf))));
Optional.ofNullable(
((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(MODEL_KEY))
.ifPresent(node -> modelDefs.addAll(parseModels(node)));
}

// Pipeline configs are optional
@ -156,7 +169,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
pipelineConfig.addAll(userPipelineConfig);

return new PipelineDef(
sourceDef, sinkDef, routeDefs, transformDefs, udfDefs, modelDefs, pipelineConfig);
}

private SourceDef toSourceDef(JsonNode sourceNode) {
@ -323,4 +336,34 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
pipelineConfigNode, new TypeReference<Map<String, String>>() {});
return Configuration.fromMap(pipelineConfigMap);
}
private List<ModelDef> parseModels(JsonNode modelsNode) {
List<ModelDef> modelDefs = new ArrayList<>();
Preconditions.checkNotNull(modelsNode, "`model` in `pipeline` should not be empty.");
if (modelsNode.isArray()) {
for (JsonNode modelNode : modelsNode) {
modelDefs.add(convertJsonNodeToModelDef(modelNode));
}
} else {
modelDefs.add(convertJsonNodeToModelDef(modelsNode));
}
return modelDefs;
}
private ModelDef convertJsonNodeToModelDef(JsonNode modelNode) {
String name =
checkNotNull(
modelNode.get(MODEL_NAME_KEY),
"Missing required field \"%s\" in `model`",
MODEL_NAME_KEY)
.asText();
String model =
checkNotNull(
modelNode.get(MODEL_CLASS_NAME_KEY),
"Missing required field \"%s\" in `model`",
MODEL_CLASS_NAME_KEY)
.asText();
Map<String, String> properties = mapper.convertValue(modelNode, Map.class);
return new ModelDef(name, model, properties);
}
}
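For reference, a hedged sketch of the `ModelDef` the parser would produce for the `GET_EMBEDDING` entry documented above; `ModelDefExample` is illustrative and not part of this change. Because the whole node is converted via `mapper.convertValue(modelNode, Map.class)`, the `model-name` and `class-name` keys also stay inside the parameter map, which the tests below rely on.

```java
import org.apache.flink.cdc.composer.definition.ModelDef;

import java.util.LinkedHashMap;
import java.util.Map;

class ModelDefExample {
    // Sketch of the ModelDef produced for the GET_EMBEDDING block in the docs above.
    static ModelDef getEmbeddingModelDef() {
        Map<String, String> properties = new LinkedHashMap<>();
        properties.put("model-name", "GET_EMBEDDING");
        properties.put("class-name", "OpenAIEmbeddingModel");
        properties.put("openai.model", "text-embedding-3-small");
        properties.put("openai.host", "https://xxxx");
        properties.put("openai.apikey", "abcd1234");
        return new ModelDef("GET_EMBEDDING", "OpenAIEmbeddingModel", properties);
    }
}
```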

@ -20,6 +20,7 @@ package org.apache.flink.cdc.cli.parser;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.event.SchemaChangeEventType;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.RouteDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
@ -39,6 +40,7 @@ import java.time.Duration;
import java.time.ZoneId;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Set;

import static org.apache.flink.cdc.common.event.SchemaChangeEventType.ADD_COLUMN;
@ -344,6 +346,18 @@ class YamlPipelineDefinitionParserTest {
null,
"add new uniq_id for each row")),
Collections.emptyList(),
Collections.singletonList(
new ModelDef(
"GET_EMBEDDING",
"OpenAIEmbeddingModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("model-name", "GET_EMBEDDING")
.put("class-name", "OpenAIEmbeddingModel")
.put("openai.model", "text-embedding-3-small")
.put("openai.host", "https://xxxx")
.put("openai.apikey", "abcd1234")
.build()))),
Configuration.fromMap(
ImmutableMap.<String, String>builder()
.put("name", "source-database-sync-pipe")
@ -397,7 +411,13 @@ class YamlPipelineDefinitionParserTest {
+ "  name: source-database-sync-pipe\n"
+ "  parallelism: 4\n"
+ "  schema.change.behavior: evolve\n"
+ "  schema-operator.rpc-timeout: 1 h\n"
+ " model:\n"
+ " - model-name: GET_EMBEDDING\n"
+ " class-name: OpenAIEmbeddingModel\n"
+ " openai.model: text-embedding-3-small\n"
+ " openai.host: https://xxxx\n"
+ " openai.apikey: abcd1234";
YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser();
PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration());
assertThat(pipelineDef).isEqualTo(fullDef);
@ -459,6 +479,18 @@ class YamlPipelineDefinitionParserTest {
null,
"add new uniq_id for each row")),
Collections.emptyList(),
Collections.singletonList(
new ModelDef(
"GET_EMBEDDING",
"OpenAIEmbeddingModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("model-name", "GET_EMBEDDING")
.put("class-name", "OpenAIEmbeddingModel")
.put("openai.model", "text-embedding-3-small")
.put("openai.host", "https://xxxx")
.put("openai.apikey", "abcd1234")
.build()))),
Configuration.fromMap(
ImmutableMap.<String, String>builder()
.put("name", "source-database-sync-pipe")

@ -57,3 +57,9 @@ pipeline:
  parallelism: 4
  schema.change.behavior: evolve
  schema-operator.rpc-timeout: 1 h
  model:
    model-name: GET_EMBEDDING
    class-name: OpenAIEmbeddingModel
    openai.model: text-embedding-3-small
    openai.host: https://xxxx
    openai.apikey: abcd1234

@ -30,9 +30,20 @@ public interface UserDefinedFunction {
return null;
}
/**
 * This will be invoked every time when a UDF got created.
 *
 * <p>This method is {@link Deprecated}, please use {@link #open(UserDefinedFunctionContext)}
 * instead.
 */
@Deprecated
default void open() throws Exception {}
/** This will be invoked every time when a UDF got created. */
default void open(UserDefinedFunctionContext context) throws Exception {
open();
}
/** This will be invoked before a UDF got destroyed. */
default void close() throws Exception {}
}

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.common.udf;
import org.apache.flink.cdc.common.configuration.Configuration;
/** Context for initialization of {@link UserDefinedFunction}. */
public interface UserDefinedFunctionContext {
Configuration configuration();
}
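A minimal hedged sketch of a UDF using the new context-aware lifecycle hook; `PrefixFunction` and its `prefix` option are assumptions made only for this illustration, not part of this commit.

```java
import org.apache.flink.cdc.common.configuration.ConfigOption;
import org.apache.flink.cdc.common.configuration.ConfigOptions;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;

/** Hypothetical UDF showing how parameters arrive through the new open() hook. */
public class PrefixFunction implements UserDefinedFunction {

    // "prefix" is an assumed option name used only for this example.
    private static final ConfigOption<String> PREFIX =
            ConfigOptions.key("prefix").stringType().defaultValue(">> ");

    private String prefix;

    public String eval(String input) {
        return prefix + input;
    }

    @Override
    public DataType getReturnType() {
        return DataTypes.STRING();
    }

    @Override
    public void open(UserDefinedFunctionContext context) {
        // Read the function's own parameters from the supplied context; the built-in
        // OpenAIChatModel and OpenAIEmbeddingModel below follow the same pattern.
        prefix = context.configuration().get(PREFIX);
    }
}
```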

@ -67,6 +67,12 @@ limitations under the License.
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-pipeline-model</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<!-- This is for testing Scala UDF.-->
<dependency>

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.composer.definition;
import java.util.Map;
import java.util.Objects;
/**
 * Common properties of a model.
 *
 * <p>A model definition contains:
 *
 * <ul>
 *   <li>modelName: The name of the function used in transform expressions.
 *   <li>className: The class name of the model used to transform data.
 *   <li>parameters: The parameters used to configure the model.
* </ul>
*/
public class ModelDef {
private final String modelName;
private final String className;
private final Map<String, String> parameters;
public ModelDef(String modelName, String className, Map<String, String> parameters) {
this.modelName = modelName;
this.className = className;
this.parameters = parameters;
}
public String getModelName() {
return modelName;
}
public String getClassName() {
return className;
}
public Map<String, String> getParameters() {
return parameters;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ModelDef modelDef = (ModelDef) o;
return Objects.equals(modelName, modelDef.modelName)
&& Objects.equals(className, modelDef.className)
&& Objects.equals(parameters, modelDef.parameters);
}
@Override
public int hashCode() {
return Objects.hash(modelName, className, parameters);
}
@Override
public String toString() {
return "ModelDef{"
+ "name='"
+ modelName
+ '\''
+ ", model='"
+ className
+ '\''
+ ", parameters="
+ parameters
+ '}';
}
}

@ -24,6 +24,7 @@ import org.apache.flink.cdc.composer.PipelineComposer;
import org.apache.flink.cdc.composer.PipelineExecution;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TimeZone;
@ -55,6 +56,7 @@ public class PipelineDef {
private final List<RouteDef> routes;
private final List<TransformDef> transforms;
private final List<UdfDef> udfs;
private final List<ModelDef> models;
private final Configuration config;

public PipelineDef(
@ -63,15 +65,27 @@ public class PipelineDef {
List<RouteDef> routes,
List<TransformDef> transforms,
List<UdfDef> udfs,
List<ModelDef> models,
Configuration config) {
this.source = source;
this.sink = sink;
this.routes = routes;
this.transforms = transforms;
this.udfs = udfs;
this.models = models;
this.config = evaluatePipelineTimeZone(config);
}
public PipelineDef(
SourceDef source,
SinkDef sink,
List<RouteDef> routes,
List<TransformDef> transforms,
List<UdfDef> udfs,
Configuration config) {
this(source, sink, routes, transforms, udfs, new ArrayList<>(), config);
}
public SourceDef getSource() {
return source;
}
@ -92,6 +106,10 @@ public class PipelineDef {
return udfs;
}

public List<ModelDef> getModels() {
return models;
}

public Configuration getConfig() {
return config;
}
@ -109,6 +127,8 @@ public class PipelineDef {
+ transforms
+ ", udfs="
+ udfs
+ ", models="
+ models
+ ", config="
+ config
+ '}';
@ -128,12 +148,13 @@ public class PipelineDef {
&& Objects.equals(routes, that.routes)
&& Objects.equals(transforms, that.transforms)
&& Objects.equals(udfs, that.udfs)
&& Objects.equals(models, that.models)
&& Objects.equals(config, that.config);
}

@Override
public int hashCode() {
return Objects.hash(source, sink, routes, transforms, udfs, models, config);
}

// ------------------------------------------------------------------------

@ -107,7 +107,10 @@ public class FlinkPipelineComposer implements PipelineComposer {
TransformTranslator transformTranslator = new TransformTranslator();
stream =
transformTranslator.translatePreTransform(
stream,
pipelineDef.getTransforms(),
pipelineDef.getUdfs(),
pipelineDef.getModels());

// Schema operator
SchemaOperatorTranslator schemaOperatorTranslator =
@ -124,8 +127,9 @@ public class FlinkPipelineComposer implements PipelineComposer {
transformTranslator.translatePostTransform(
stream,
pipelineDef.getTransforms(),
pipelineDef.getConfig().get(PipelineOptions.PIPELINE_LOCAL_TIME_ZONE),
pipelineDef.getUdfs(),
pipelineDef.getModels());

// Build DataSink in advance as schema operator requires MetadataApplier
DataSinkTranslator sinkTranslator = new DataSinkTranslator();

@ -17,8 +17,9 @@
package org.apache.flink.cdc.composer.flink.translator;

import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.event.Event;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.TransformDef;
import org.apache.flink.cdc.composer.definition.UdfDef;
import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperator;
@ -26,7 +27,9 @@ import org.apache.flink.cdc.runtime.operators.transform.PreTransformOperator;
import org.apache.flink.cdc.runtime.typeutils.EventTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
@ -35,8 +38,15 @@ import java.util.stream.Collectors;
 */
public class TransformTranslator {
/** Package of built-in model. */
public static final String PREFIX_CLASSPATH_BUILT_IN_MODEL =
"org.apache.flink.cdc.runtime.model.";
public DataStream<Event> translatePreTransform(
DataStream<Event> input,
List<TransformDef> transforms,
List<UdfDef> udfFunctions,
List<ModelDef> models) {
if (transforms.isEmpty()) {
return input;
}
@ -52,10 +62,11 @@ public class TransformTranslator {
transform.getPartitionKeys(),
transform.getTableOptions());
}
preTransformFunctionBuilder.addUdfFunctions(
udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
preTransformFunctionBuilder.addUdfFunctions(
models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
return input.transform(
"Transform:Schema", new EventTypeInfo(), preTransformFunctionBuilder.build());
}
@ -64,7 +75,8 @@ public class TransformTranslator {
DataStream<Event> input,
List<TransformDef> transforms,
String timezone,
List<UdfDef> udfFunctions,
List<ModelDef> models) {
if (transforms.isEmpty()) {
return input;
}
@ -84,10 +96,21 @@ public class TransformTranslator {
}
postTransformFunctionBuilder.addTimezone(timezone);
postTransformFunctionBuilder.addUdfFunctions(
udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
postTransformFunctionBuilder.addUdfFunctions(
models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
return input.transform(
"Transform:Data", new EventTypeInfo(), postTransformFunctionBuilder.build());
}
private Tuple3<String, String, Map<String, String>> modelToUDFTuple(ModelDef model) {
return Tuple3.of(
model.getModelName(),
PREFIX_CLASSPATH_BUILT_IN_MODEL + model.getClassName(),
model.getParameters());
}
private Tuple3<String, String, Map<String, String>> udfDefToUDFTuple(UdfDef udf) {
return Tuple3.of(udf.getName(), udf.getClasspath(), new HashMap<>());
}
}

@ -21,6 +21,7 @@ import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
import org.apache.flink.cdc.composer.PipelineExecution;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
import org.apache.flink.cdc.composer.definition.SourceDef;
@ -35,6 +36,8 @@ import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceOptions;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.test.junit5.MiniClusterExtension;

import org.apache.flink.shaded.guava31.com.google.common.collect.ImmutableMap;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
@ -44,8 +47,10 @@ import org.junit.jupiter.params.provider.MethodSource;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.stream.Stream;

import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL;
@ -822,6 +827,75 @@ public class FlinkPipelineUdfITCase {
"DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, , 4, Integer: 42, 2-42], after=[2, x, 4, Integer: 42, 2-42], op=UPDATE, meta=()}");
}
@ParameterizedTest
@MethodSource("testParams")
void testTransformWithModel(ValuesDataSink.SinkApi sinkApi) throws Exception {
FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster();
// Setup value source
Configuration sourceConfig = new Configuration();
sourceConfig.set(
ValuesDataSourceOptions.EVENT_SET_ID,
ValuesDataSourceHelper.EventSetId.TRANSFORM_TABLE);
SourceDef sourceDef =
new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig);
// Setup value sink
Configuration sinkConfig = new Configuration();
sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true);
sinkConfig.set(ValuesDataSinkOptions.SINK_API, sinkApi);
SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig);
// Setup transform
TransformDef transformDef =
new TransformDef(
"default_namespace.default_schema.table1",
"*, CHAT(col1) AS emb",
null,
"col1",
null,
"key1=value1",
"");
// Setup pipeline
Configuration pipelineConfig = new Configuration();
pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1);
pipelineConfig.set(
PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE);
PipelineDef pipelineDef =
new PipelineDef(
sourceDef,
sinkDef,
Collections.emptyList(),
Collections.singletonList(transformDef),
new ArrayList<>(),
Arrays.asList(
new ModelDef(
"CHAT",
"OpenAIChatModel",
new LinkedHashMap<>(
ImmutableMap.<String, String>builder()
.put("openai.model", "gpt-4o-mini")
.put(
"openai.host",
"http://langchain4j.dev/demo/openai/v1")
.put("openai.apikey", "demo")
.build()))),
pipelineConfig);
// Execute the pipeline
PipelineExecution execution = composer.compose(pipelineDef);
execution.execute();
// Check the order and content of all received events
String[] outputEvents = outCaptor.toString().trim().split("\n");
assertThat(outputEvents)
.contains(
"CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING,`col2` STRING,`emb` STRING}, primaryKeys=col1, options=({key1=value1})}")
// The result produced by the model is not deterministic, so only the event count is checked.
.hasSize(9);
}
private static Stream<Arguments> testParams() {
return Stream.of(
arguments(ValuesDataSink.SinkApi.SINK_FUNCTION, "java"),

@ -0,0 +1,81 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>flink-cdc-parent</artifactId>
<groupId>org.apache.flink</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>flink-cdc-pipeline-model</artifactId>
<properties>
<langchain4j.version>0.23.0</langchain4j.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-common</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-test-utils-junit</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>0.12.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<executions>
<execution>
<id>test-jar</id>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.ConfigOption;
import org.apache.flink.cdc.common.configuration.ConfigOptions;
/** Options of built-in model. */
public class ModelOptions {
// Options for Open AI Model.
public static final ConfigOption<String> OPENAI_MODEL_NAME =
ConfigOptions.key("openai.model")
.stringType()
.noDefaultValue()
.withDescription("Name of model to be called.");
public static final ConfigOption<String> OPENAI_HOST =
ConfigOptions.key("openai.host")
.stringType()
.noDefaultValue()
.withDescription("Host of the Model server to be connected.");
public static final ConfigOption<String> OPENAI_API_KEY =
ConfigOptions.key("openai.apikey")
.stringType()
.noDefaultValue()
.withDescription("Api Key for verification of the Model server.");
public static final ConfigOption<String> OPENAI_CHAT_PROMPT =
ConfigOptions.key("openai.chat.prompt")
.stringType()
.noDefaultValue()
.withDescription("Prompt for chat using OpenAI.");
}
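A short hedged sketch of how these options are read back: a model's parameter map is wrapped into a `Configuration` (as `PostTransformOperator` does when opening a UDF) and queried through the typed options above; `ModelOptionsExample` is illustrative only and not part of this change.

```java
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.runtime.model.ModelOptions;

import java.util.HashMap;
import java.util.Map;

class ModelOptionsExample {
    static String readModelName() {
        // The parameter map of a ModelDef becomes a Configuration...
        Map<String, String> parameters = new HashMap<>();
        parameters.put("openai.model", "text-embedding-3-small");
        parameters.put("openai.host", "http://langchain4j.dev/demo/openai/v1");
        parameters.put("openai.apikey", "demo");
        Configuration configuration = Configuration.fromMap(parameters);

        // ...and is read through the typed ConfigOptions declared above.
        return configuration.get(ModelOptions.OPENAI_MODEL_NAME);
    }
}
```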

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.Preconditions;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collections;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_CHAT_PROMPT;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME;
/**
 * A {@link UserDefinedFunction} that uses a model provided by OpenAI to generate text, refer to <a
 * href="https://docs.langchain4j.dev/integrations/language-models/open-ai/">docs</a>.
*/
public class OpenAIChatModel implements UserDefinedFunction {
private static final Logger LOG = LoggerFactory.getLogger(OpenAIChatModel.class);
private OpenAiChatModel chatModel;
private String modelName;
private String host;
private String prompt;
public String eval(String input) {
return chat(input);
}
private String chat(String input) {
if (input == null || input.trim().isEmpty()) {
LOG.warn("Empty or null input provided for chat.");
return "";
}
if (prompt != null) {
input = prompt + ": " + input;
}
return chatModel
.generate(Collections.singletonList(new UserMessage(input)))
.content()
.text();
}
@Override
public DataType getReturnType() {
return DataTypes.STRING();
}
@Override
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
Configuration modelOptions = userDefinedFunctionContext.configuration();
this.modelName = modelOptions.get(OPENAI_MODEL_NAME);
Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty.");
this.host = modelOptions.get(OPENAI_HOST);
Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty.");
String apiKey = modelOptions.get(OPENAI_API_KEY);
Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty.");
this.prompt = modelOptions.get(OPENAI_CHAT_PROMPT);
LOG.info("Opening OpenAIChatModel " + modelName + " " + host);
this.chatModel =
OpenAiChatModel.builder().apiKey(apiKey).baseUrl(host).modelName(modelName).build();
}
@Override
public void close() {
LOG.info("Closed OpenAIChatModel " + modelName + " " + host);
}
}

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.data.GenericArrayData;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.Preconditions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collections;
import java.util.List;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST;
import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME;
/**
 * A {@link UserDefinedFunction} that uses a model provided by OpenAI to generate vector data, refer to
 * <a href="https://docs.langchain4j.dev/integrations/language-models/open-ai/">docs</a>.
*/
public class OpenAIEmbeddingModel implements UserDefinedFunction {
private static final Logger LOG = LoggerFactory.getLogger(OpenAIEmbeddingModel.class);
private String modelName;
private String host;
private OpenAiEmbeddingModel embeddingModel;
public ArrayData eval(String input) {
return getEmbedding(input);
}
private ArrayData getEmbedding(String input) {
if (input == null || input.trim().isEmpty()) {
LOG.debug("Empty or null input provided for embedding.");
return new GenericArrayData(new Float[0]);
}
TextSegment textSegment = new TextSegment(input, new Metadata());
List<Embedding> embeddings =
embeddingModel.embedAll(Collections.singletonList(textSegment)).content();
if (embeddings != null && !embeddings.isEmpty()) {
List<Float> embeddingList = embeddings.get(0).vectorAsList();
Float[] embeddingArray = embeddingList.toArray(new Float[0]);
return new GenericArrayData(embeddingArray);
} else {
LOG.warn("No embedding results returned for input: {}", input);
return new GenericArrayData(new Float[0]);
}
}
@Override
public DataType getReturnType() {
return DataTypes.ARRAY(DataTypes.FLOAT());
}
@Override
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
Configuration modelOptions = userDefinedFunctionContext.configuration();
this.modelName = modelOptions.get(OPENAI_MODEL_NAME);
Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty.");
this.host = modelOptions.get(OPENAI_HOST);
Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty.");
String apiKey = modelOptions.get(OPENAI_API_KEY);
Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty.");
LOG.info("Opening OpenAIEmbeddingModel " + modelName + " " + host);
this.embeddingModel =
OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(host)
.modelName(modelName)
.build();
}
@Override
public void close() {
LOG.info("Closed OpenAIEmbeddingModel " + modelName + " " + host);
}
}

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/** A test for {@link OpenAIChatModel}. */
public class TestOpenAIChatModel {
@Test
public void testEval() {
OpenAIChatModel openAIChatModel = new OpenAIChatModel();
Configuration configuration = new Configuration();
configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini");
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
openAIChatModel.open(userDefinedFunctionContext);
String response = openAIChatModel.eval("Who invented the electric light?");
Assertions.assertFalse(response.isEmpty());
}
}

@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cdc.runtime.model;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/** A test for {@link OpenAIEmbeddingModel}. */
public class TestOpenAIEmbeddingModel {
@Test
public void testEval() {
OpenAIEmbeddingModel openAIEmbeddingModel = new OpenAIEmbeddingModel();
Configuration configuration = new Configuration();
configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
configuration.set(ModelOptions.OPENAI_MODEL_NAME, "text-embedding-3-small");
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
openAIEmbeddingModel.open(userDefinedFunctionContext);
ArrayData arrayData =
openAIEmbeddingModel.eval("Flink CDC is a streaming data integration tool");
Assertions.assertNotNull(arrayData);
}
}

@ -89,6 +89,12 @@ limitations under the License.
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-cdc-pipeline-model</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

@ -18,6 +18,8 @@
package org.apache.flink.cdc.runtime.operators.transform;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.data.RecordData;
import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
import org.apache.flink.cdc.common.event.CreateTableEvent;
@ -29,6 +31,7 @@ import org.apache.flink.cdc.common.event.TableId;
import org.apache.flink.cdc.common.pipeline.PipelineOptions;
import org.apache.flink.cdc.common.schema.Schema;
import org.apache.flink.cdc.common.schema.Selectors;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.common.utils.SchemaUtils;
import org.apache.flink.cdc.runtime.parser.TransformParser;
import org.apache.flink.streaming.api.graph.StreamConfig;
@ -70,7 +73,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
/** keep the relationship of TableId and table information. */
private final Map<TableId, PostTransformChangeInfo> postTransformChangeInfoMap;

private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
private List<UserDefinedFunctionDescriptor> udfDescriptors;
private transient Map<String, Object> udfFunctionInstances;
@ -89,7 +92,8 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
public static class Builder {
private final List<TransformRule> transformRules = new ArrayList<>();
private String timezone;
private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
new ArrayList<>();
public PostTransformOperator.Builder addTransform(
String tableInclusions,
@ -125,7 +129,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
}

public PostTransformOperator.Builder addUdfFunctions(
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.udfFunctions.addAll(udfFunctions);
return this;
}
@ -138,7 +142,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
private PostTransformOperator(
List<TransformRule> transformRules,
String timezone,
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.transformRules = transformRules;
this.timezone = timezone;
this.postTransformChangeInfoMap = new ConcurrentHashMap<>();
@ -158,10 +162,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
super.setup(containingTask, config, output);
udfDescriptors =
udfFunctions.stream()
.map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
.collect(Collectors.toList());
}
@ -539,7 +540,12 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
// into UserDefinedFunction interface, thus the provided UDF classes
// might not be compatible with the interface definition in CDC common.
Object udfInstance = udfFunctionInstances.get(udf.getName());
UserDefinedFunctionContext userDefinedFunctionContext =
() -> Configuration.fromMap(udf.getParameters());
udfInstance
.getClass()
.getMethod("open", UserDefinedFunctionContext.class)
.invoke(udfInstance, userDefinedFunctionContext);
} else {
// Do nothing, Flink-style UDF lifecycle hooks are not supported
}

@ -21,6 +21,7 @@ import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
import org.apache.flink.cdc.common.event.CreateTableEvent;
import org.apache.flink.cdc.common.event.DataChangeEvent;
@ -69,7 +70,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
private final Map<TableId, PreTransformChangeInfo> preTransformChangeInfoMap; private final Map<TableId, PreTransformChangeInfo> preTransformChangeInfoMap;
private final List<Tuple2<Selectors, SchemaMetadataTransform>> schemaMetadataTransformers; private final List<Tuple2<Selectors, SchemaMetadataTransform>> schemaMetadataTransformers;
private transient ListState<byte[]> state; private transient ListState<byte[]> state;
private final List<Tuple2<String, String>> udfFunctions; private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
private List<UserDefinedFunctionDescriptor> udfDescriptors; private List<UserDefinedFunctionDescriptor> udfDescriptors;
private Map<TableId, PreTransformProcessor> preTransformProcessorMap; private Map<TableId, PreTransformProcessor> preTransformProcessorMap;
private Map<TableId, Boolean> hasAsteriskMap; private Map<TableId, Boolean> hasAsteriskMap;
@ -82,7 +83,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
public static class Builder { public static class Builder {
private final List<TransformRule> transformRules = new ArrayList<>(); private final List<TransformRule> transformRules = new ArrayList<>();
private final List<Tuple2<String, String>> udfFunctions = new ArrayList<>(); private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
new ArrayList<>();
public PreTransformOperator.Builder addTransform( public PreTransformOperator.Builder addTransform(
String tableInclusions, @Nullable String projection, @Nullable String filter) { String tableInclusions, @Nullable String projection, @Nullable String filter) {
@ -109,7 +111,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
} }
public PreTransformOperator.Builder addUdfFunctions( public PreTransformOperator.Builder addUdfFunctions(
List<Tuple2<String, String>> udfFunctions) { List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.udfFunctions.addAll(udfFunctions); this.udfFunctions.addAll(udfFunctions);
return this; return this;
} }
@ -120,7 +122,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
} }
private PreTransformOperator( private PreTransformOperator(
List<TransformRule> transformRules, List<Tuple2<String, String>> udfFunctions) { List<TransformRule> transformRules,
List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
this.transformRules = transformRules; this.transformRules = transformRules;
this.preTransformChangeInfoMap = new ConcurrentHashMap<>(); this.preTransformChangeInfoMap = new ConcurrentHashMap<>();
this.preTransformProcessorMap = new ConcurrentHashMap<>(); this.preTransformProcessorMap = new ConcurrentHashMap<>();
@ -137,7 +140,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
super.setup(containingTask, config, output); super.setup(containingTask, config, output);
this.udfDescriptors = this.udfDescriptors =
this.udfFunctions.stream() this.udfFunctions.stream()
.map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1)) .map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
.collect(Collectors.toList()); .collect(Collectors.toList());
// Initialize data fields in advance because they might be accessed in // Initialize data fields in advance because they might be accessed in

@@ -24,6 +24,8 @@ import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
+ import java.util.HashMap;
+ import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -38,9 +40,16 @@ public class UserDefinedFunctionDescriptor implements Serializable {
private final String className;
private final DataType returnTypeHint;
private final boolean isCdcPipelineUdf;
+ private final Map<String, String> parameters;
public UserDefinedFunctionDescriptor(String name, String classpath) {
+ this(name, classpath, new HashMap<>());
+ }
+ public UserDefinedFunctionDescriptor(
+ String name, String classpath, Map<String, String> parameters) {
this.name = name;
+ this.parameters = parameters;
this.classpath = classpath;
this.className = classpath.substring(classpath.lastIndexOf('.') + 1);
try {
@@ -107,6 +116,10 @@ public class UserDefinedFunctionDescriptor implements Serializable {
return className;
}
+ public Map<String, String> getParameters() {
+ return parameters;
+ }
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -115,13 +128,19 @@ public class UserDefinedFunctionDescriptor implements Serializable {
if (o == null || getClass() != o.getClass()) {
return false;
}
- UserDefinedFunctionDescriptor context = (UserDefinedFunctionDescriptor) o;
- return Objects.equals(name, context.name) && Objects.equals(classpath, context.classpath);
+ UserDefinedFunctionDescriptor that = (UserDefinedFunctionDescriptor) o;
+ return isCdcPipelineUdf == that.isCdcPipelineUdf
+ && Objects.equals(name, that.name)
+ && Objects.equals(classpath, that.classpath)
+ && Objects.equals(className, that.className)
+ && Objects.equals(returnTypeHint, that.returnTypeHint)
+ && Objects.equals(parameters, that.parameters);
}
@Override
public int hashCode() {
- return Objects.hash(name, classpath);
+ return Objects.hash(
+ name, classpath, className, returnTypeHint, isCdcPipelineUdf, parameters);
}
@Override
@@ -133,6 +152,15 @@ public class UserDefinedFunctionDescriptor implements Serializable {
+ ", classpath='"
+ classpath
+ '\''
+ + ", className='"
+ + className
+ + '\''
+ + ", returnTypeHint="
+ + returnTypeHint
+ + ", isCdcPipelineUdf="
+ + isCdcPipelineUdf
+ + ", parameters="
+ + parameters
+ '}';
}
}
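Since the descriptor now stores the parameter map and folds it into `equals`/`hashCode`, two descriptors for the same class but different configurations are treated as distinct. A minimal sketch, assuming the descriptor lives in the same package as its test below and that the flink-cdc-pipeline-model classes are on the classpath:

```java
import org.apache.flink.cdc.runtime.operators.transform.UserDefinedFunctionDescriptor;

import java.util.HashMap;
import java.util.Map;

public class DescriptorSketch {
    public static void main(String[] args) {
        Map<String, String> params = new HashMap<>();
        params.put("openai.apikey", "demo"); // illustrative value

        UserDefinedFunctionDescriptor withParams =
                new UserDefinedFunctionDescriptor(
                        "GET_EMBEDDING",
                        "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel",
                        params);
        UserDefinedFunctionDescriptor withoutParams =
                new UserDefinedFunctionDescriptor(
                        "GET_EMBEDDING",
                        "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel");

        // The two-argument constructor delegates with an empty map, so the
        // parameter maps differ and the descriptors are no longer equal.
        System.out.println(withParams.equals(withoutParams)); // false
        System.out.println(withParams.getParameters());       // {openai.apikey=demo}
    }
}
```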

@@ -36,6 +36,7 @@ import org.apache.calcite.prepare.CalciteCatalogReader;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
+ import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
@@ -123,10 +124,11 @@ public class TransformParser {
if (udf.getReturnTypeHint() != null) {
// This UDF has return type hint annotation
returnTypeInference =
- o ->
- o.getTypeFactory()
- .createSqlType(
- convertCalciteType(udf.getReturnTypeHint()));
+ o -> {
+ RelDataTypeFactory typeFactory = o.getTypeFactory();
+ DataType returnTypeHint = udf.getReturnTypeHint();
+ return convertCalciteType(typeFactory, returnTypeHint);
+ };
} else {
// Infer it from eval method return type
returnTypeInference = o -> function.getReturnType(o.getTypeFactory());
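Routing the return type hint through the factory-aware `DataTypeConverter.convertCalciteType` is what allows composite hints such as `ARRAY<FLOAT>` (the embedding model's return type) to be expressed as a full `RelDataType` instead of a bare `SqlTypeName`. A minimal sketch using a standalone Calcite type factory:

```java
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.runtime.typeutils.DataTypeConverter;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;

public class ReturnTypeHintSketch {
    public static void main(String[] args) {
        RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);

        // ARRAY<FLOAT>: the old SqlTypeName-based conversion had no way to
        // carry the element type, the factory-based one does.
        RelDataType embeddingType =
                DataTypeConverter.convertCalciteType(
                        typeFactory, DataTypes.ARRAY(DataTypes.FLOAT()));

        System.out.println(embeddingType.getSqlTypeName());   // ARRAY
        System.out.println(embeddingType.getComponentType()); // FLOAT
    }
}
```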

@@ -310,4 +310,36 @@ public class TransformSqlOperatorTable extends ReflectiveSqlOperatorTable {
// Cast Functions
// --------------
public static final SqlFunction CAST = SqlStdOperatorTable.CAST;
public static final SqlFunction AI_CHAT_PREDICT =
new SqlFunction(
"AI_CHAT_PREDICT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
// Define the GET_EMBEDDING function
public static final SqlFunction GET_EMBEDDING =
new SqlFunction(
"GET_EMBEDDING",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
// Define the AI_LANGCHAIN_PREDICT function
public static final SqlFunction AI_LANGCHAIN_PREDICT =
new SqlFunction(
"AI_LANGCHAIN_PREDICT",
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.VARCHAR),
null,
OperandTypes.family(
SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
SqlFunctionCategory.USER_DEFINED_FUNCTION);
}

@@ -51,8 +51,7 @@ public class TransformTable extends AbstractTable {
for (Column column : columns) {
names.add(column.getName());
RelDataType sqlType =
- relDataTypeFactory.createSqlType(
- DataTypeConverter.convertCalciteType(column.getType()));
+ DataTypeConverter.convertCalciteType(relDataTypeFactory, column.getType());
types.add(sqlType);
}
return relDataTypeFactory.createStructType(Pair.zip(names, types));

@@ -17,18 +17,27 @@
package org.apache.flink.cdc.runtime.typeutils;
+ import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.data.DecimalData;
+ import org.apache.flink.cdc.common.data.GenericArrayData;
+ import org.apache.flink.cdc.common.data.GenericMapData;
import org.apache.flink.cdc.common.data.LocalZonedTimestampData;
+ import org.apache.flink.cdc.common.data.MapData;
import org.apache.flink.cdc.common.data.TimestampData;
import org.apache.flink.cdc.common.data.binary.BinaryStringData;
import org.apache.flink.cdc.common.schema.Column;
+ import org.apache.flink.cdc.common.types.ArrayType;
import org.apache.flink.cdc.common.types.BinaryType;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
+ import org.apache.flink.cdc.common.types.DecimalType;
+ import org.apache.flink.cdc.common.types.MapType;
import org.apache.flink.cdc.common.types.RowType;
+ import org.apache.flink.cdc.common.types.TimestampType;
import org.apache.flink.cdc.common.types.VarBinaryType;
import org.apache.calcite.rel.type.RelDataType;
+ import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlTypeName;
import java.math.BigDecimal;
@@ -39,8 +48,11 @@ import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
+ import java.util.HashMap;
import java.util.List;
+ import java.util.Map;
import java.util.concurrent.TimeUnit;
+ import java.util.stream.Collectors;
/** A data type converter. */
public class DataTypeConverter {
@@ -90,50 +102,67 @@ public class DataTypeConverter {
case ROW:
return Object.class;
case ARRAY:
+ return ArrayData.class;
case MAP:
+ return MapData.class;
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
}
- public static SqlTypeName convertCalciteType(DataType dataType) {
+ public static RelDataType convertCalciteType(
+ RelDataTypeFactory typeFactory, DataType dataType) {
switch (dataType.getTypeRoot()) {
case BOOLEAN:
- return SqlTypeName.BOOLEAN;
+ return typeFactory.createSqlType(SqlTypeName.BOOLEAN);
case TINYINT:
- return SqlTypeName.TINYINT;
+ return typeFactory.createSqlType(SqlTypeName.TINYINT);
case SMALLINT:
- return SqlTypeName.SMALLINT;
+ return typeFactory.createSqlType(SqlTypeName.SMALLINT);
case INTEGER:
- return SqlTypeName.INTEGER;
+ return typeFactory.createSqlType(SqlTypeName.INTEGER);
case BIGINT:
- return SqlTypeName.BIGINT;
+ return typeFactory.createSqlType(SqlTypeName.BIGINT);
case DATE:
- return SqlTypeName.DATE;
+ return typeFactory.createSqlType(SqlTypeName.DATE);
case TIME_WITHOUT_TIME_ZONE:
- return SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE;
+ return typeFactory.createSqlType(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE);
case TIMESTAMP_WITHOUT_TIME_ZONE:
- return SqlTypeName.TIMESTAMP;
+ return typeFactory.createSqlType(SqlTypeName.TIMESTAMP);
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
- return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+ return typeFactory.createSqlType(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE);
case FLOAT:
- return SqlTypeName.FLOAT;
+ return typeFactory.createSqlType(SqlTypeName.FLOAT);
case DOUBLE:
- return SqlTypeName.DOUBLE;
+ return typeFactory.createSqlType(SqlTypeName.DOUBLE);
case CHAR:
- return SqlTypeName.CHAR;
+ return typeFactory.createSqlType(SqlTypeName.CHAR);
case VARCHAR:
- return SqlTypeName.VARCHAR;
+ return typeFactory.createSqlType(SqlTypeName.VARCHAR);
case BINARY:
- return SqlTypeName.BINARY;
+ return typeFactory.createSqlType(SqlTypeName.BINARY);
case VARBINARY:
- return SqlTypeName.VARBINARY;
+ return typeFactory.createSqlType(SqlTypeName.VARBINARY);
case DECIMAL:
- return SqlTypeName.DECIMAL;
+ return typeFactory.createSqlType(SqlTypeName.DECIMAL);
case ROW:
- return SqlTypeName.ROW;
+ List<RelDataType> dataTypes =
+ ((RowType) dataType)
+ .getFieldTypes().stream()
+ .map((type) -> convertCalciteType(typeFactory, type))
+ .collect(Collectors.toList());
+ return typeFactory.createStructType(
+ dataTypes, ((RowType) dataType).getFieldNames());
case ARRAY:
+ DataType elementType = ((ArrayType) dataType).getElementType();
+ return typeFactory.createArrayType(
+ convertCalciteType(typeFactory, elementType), -1);
case MAP:
+ RelDataType keyType =
+ convertCalciteType(typeFactory, ((MapType) dataType).getKeyType());
+ RelDataType valueType =
+ convertCalciteType(typeFactory, ((MapType) dataType).getValueType());
+ return typeFactory.createMapType(keyType, valueType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
@@ -173,9 +202,16 @@ public class DataTypeConverter {
return DataTypes.VARBINARY(VarBinaryType.MAX_LENGTH);
case DECIMAL:
return DataTypes.DECIMAL(relDataType.getPrecision(), relDataType.getScale());
- case ROW:
case ARRAY:
+ RelDataType componentType = relDataType.getComponentType();
+ return DataTypes.ARRAY(convertCalciteRelDataTypeToDataType(componentType));
case MAP:
+ RelDataType keyType = relDataType.getKeyType();
+ RelDataType valueType = relDataType.getValueType();
+ return DataTypes.MAP(
+ convertCalciteRelDataTypeToDataType(keyType),
+ convertCalciteRelDataTypeToDataType(valueType));
+ case ROW:
default:
throw new UnsupportedOperationException(
"Unsupported type: " + relDataType.getSqlTypeName());
@@ -220,7 +256,9 @@ public class DataTypeConverter {
case ROW:
return value;
case ARRAY:
+ return convertToArray(value, (ArrayType) dataType);
case MAP:
+ return convertToMap(value, (MapType) dataType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
@@ -264,7 +302,9 @@ public class DataTypeConverter {
case ROW:
return value;
case ARRAY:
+ return convertToArrayOriginal(value, (ArrayType) dataType);
case MAP:
+ return convertToMapOriginal(value, (MapType) dataType);
default:
throw new UnsupportedOperationException("Unsupported type: " + dataType);
}
@@ -378,6 +418,101 @@ public class DataTypeConverter {
return toLocalTime(obj).toSecondOfDay() * 1000;
}
private static Object convertToArray(Object obj, ArrayType arrayType) {
if (obj instanceof ArrayData) {
return obj;
}
if (obj instanceof List) {
List<?> list = (List<?>) obj;
GenericArrayData arrayData = new GenericArrayData(list.toArray());
return arrayData;
}
if (obj.getClass().isArray()) {
return new GenericArrayData((Object[]) obj);
}
throw new IllegalArgumentException("Unable to convert to ArrayData: " + obj);
}
private static Object convertToArrayOriginal(Object obj, ArrayType arrayType) {
if (obj instanceof ArrayData) {
ArrayData arrayData = (ArrayData) obj;
Object[] result = new Object[arrayData.size()];
for (int i = 0; i < arrayData.size(); i++) {
result[i] = getArrayElement(arrayData, i, arrayType.getElementType());
}
return result;
}
return obj;
}
private static Object getArrayElement(ArrayData arrayData, int pos, DataType elementType) {
switch (elementType.getTypeRoot()) {
case BOOLEAN:
return arrayData.getBoolean(pos);
case TINYINT:
return arrayData.getByte(pos);
case SMALLINT:
return arrayData.getShort(pos);
case INTEGER:
return arrayData.getInt(pos);
case BIGINT:
return arrayData.getLong(pos);
case FLOAT:
return arrayData.getFloat(pos);
case DOUBLE:
return arrayData.getDouble(pos);
case CHAR:
case VARCHAR:
return arrayData.getString(pos);
case DECIMAL:
return arrayData.getDecimal(
pos,
((DecimalType) elementType).getPrecision(),
((DecimalType) elementType).getScale());
case DATE:
return arrayData.getInt(pos);
case TIME_WITHOUT_TIME_ZONE:
return arrayData.getInt(pos);
case TIMESTAMP_WITHOUT_TIME_ZONE:
return arrayData.getTimestamp(pos, ((TimestampType) elementType).getPrecision());
case ARRAY:
return convertToArrayOriginal(arrayData.getArray(pos), (ArrayType) elementType);
case MAP:
return convertToMapOriginal(arrayData.getMap(pos), (MapType) elementType);
default:
throw new UnsupportedOperationException(
"Unsupported array element type: " + elementType);
}
}
private static Object convertToMap(Object obj, MapType mapType) {
if (obj instanceof MapData) {
return obj;
}
if (obj instanceof Map) {
Map<?, ?> javaMap = (Map<?, ?>) obj;
GenericMapData mapData = new GenericMapData(javaMap);
return mapData;
}
throw new IllegalArgumentException("Unable to convert to MapData: " + obj);
}
private static Object convertToMapOriginal(Object obj, MapType mapType) {
if (obj instanceof MapData) {
MapData mapData = (MapData) obj;
Map<Object, Object> result = new HashMap<>();
ArrayData keyArray = mapData.keyArray();
ArrayData valueArray = mapData.valueArray();
for (int i = 0; i < mapData.size(); i++) {
Object key = getArrayElement(keyArray, i, mapType.getKeyType());
Object value = getArrayElement(valueArray, i, mapType.getValueType());
result.put(key, value);
}
return result;
}
return obj;
}
private static LocalTime toLocalTime(Object obj) {
if (obj == null) {
return null;
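The new array and map paths are what let model functions return collection-typed values such as the `ARRAY<FLOAT>` embedding. A minimal sketch of the underlying conversions, using only the `GenericArrayData`/`GenericMapData` constructors and accessors that appear above:

```java
import org.apache.flink.cdc.common.data.ArrayData;
import org.apache.flink.cdc.common.data.GenericArrayData;
import org.apache.flink.cdc.common.data.GenericMapData;
import org.apache.flink.cdc.common.data.MapData;
import org.apache.flink.cdc.common.data.binary.BinaryStringData;

import java.util.HashMap;
import java.util.Map;

public class CollectionDataSketch {
    public static void main(String[] args) {
        // A Java array becomes GenericArrayData, mirroring convertToArray above.
        ArrayData embedding = new GenericArrayData(new Object[] {0.1f, 0.2f, 0.3f});
        System.out.println(embedding.getFloat(0)); // 0.1

        // A Java map becomes GenericMapData, mirroring convertToMap above.
        Map<Object, Object> attributes = new HashMap<>();
        attributes.put(
                BinaryStringData.fromString("model"),
                BinaryStringData.fromString("text-embedding-3-small"));
        MapData mapData = new GenericMapData(attributes);
        System.out.println(mapData.size()); // 1
    }
}
```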

@@ -20,8 +20,10 @@ package org.apache.flink.cdc.runtime.operators.transform;
import org.apache.flink.cdc.common.types.DataType;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
+ import org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel;
import org.apache.flink.table.functions.ScalarFunction;
+ import com.fasterxml.jackson.core.JsonProcessingException;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
@@ -48,7 +50,7 @@ public class UserDefinedFunctionDescriptorTest {
public static class NotUDF {}
@Test
- void testUserDefinedFunctionDescriptor() {
+ void testUserDefinedFunctionDescriptor() throws JsonProcessingException {
assertThat(new UserDefinedFunctionDescriptor("cdc_udf", CdcUdf.class.getName()))
.extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
@@ -93,5 +95,14 @@ public class UserDefinedFunctionDescriptorTest {
"not_even_exist", "not.a.valid.class.path"))
.isExactlyInstanceOf(IllegalArgumentException.class)
.hasMessage("Failed to instantiate UDF not_even_exist@not.a.valid.class.path");
+ String name = "GET_EMBEDDING";
+ assertThat(new UserDefinedFunctionDescriptor(name, OpenAIEmbeddingModel.class.getName()))
+ .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
+ .containsExactly(
+ "GET_EMBEDDING",
+ "OpenAIEmbeddingModel",
+ "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel",
+ DataTypes.ARRAY(DataTypes.FLOAT()),
+ true);
}
}

@@ -41,6 +41,7 @@ limitations under the License.
<module>flink-cdc-runtime</module>
<module>flink-cdc-e2e-tests</module>
<module>flink-cdc-pipeline-udf-examples</module>
+ <module>flink-cdc-pipeline-model</module>
</modules>
<licenses>
