From 06154e96743a7217538f96ccc53f2f9d15bba865 Mon Sep 17 00:00:00 2001
From: Kunni
Date: Tue, 26 Nov 2024 16:40:45 +0800
Subject: [PATCH] [FLINK-36525][transform] Support for AI Model Integration for
 Data Processing (#3753)

---
 .../content.zh/docs/core-concept/transform.md |  71 +++++++
 docs/content/docs/core-concept/transform.md   |  72 ++++++++
 .../parser/YamlPipelineDefinitionParser.java  |  45 ++++-
 .../YamlPipelineDefinitionParserTest.java     |  34 +++-
 .../definitions/pipeline-definition-full.yaml |   6 +
 .../cdc/common/udf/UserDefinedFunction.java   |  13 +-
 .../udf/UserDefinedFunctionContext.java       |  26 +++
 flink-cdc-composer/pom.xml                    |   6 +
 .../cdc/composer/definition/ModelDef.java     |  92 ++++++++++
 .../cdc/composer/definition/PipelineDef.java  |  23 ++-
 .../composer/flink/FlinkPipelineComposer.java |  10 +-
 .../flink/translator/TransformTranslator.java |  41 ++++-
 .../flink/FlinkPipelineUdfITCase.java         |  74 ++++++++
 flink-cdc-pipeline-model/pom.xml              |  81 ++++++++
 .../flink/cdc/runtime/model/ModelOptions.java |  50 +++++
 .../cdc/runtime/model/OpenAIChatModel.java    |  97 ++++++++++
 .../runtime/model/OpenAIEmbeddingModel.java   | 109 +++++++++++
 .../runtime/model/TestOpenAIChatModel.java    |  40 ++++
 .../model/TestOpenAIEmbeddingModel.java       |  43 +++++
 flink-cdc-runtime/pom.xml                     |   6 +
 .../transform/PostTransformOperator.java      |  24 ++-
 .../transform/PreTransformOperator.java       |  13 +-
 .../UserDefinedFunctionDescriptor.java        |  34 +++-
 .../cdc/runtime/parser/TransformParser.java   |  10 +-
 .../metadata/TransformSqlOperatorTable.java   |  32 ++++
 .../parser/metadata/TransformTable.java       |   3 +-
 .../runtime/typeutils/DataTypeConverter.java  | 173 ++++++++++++++++--
 .../UserDefinedFunctionDescriptorTest.java    |  13 +-
 pom.xml                                       |   1 +
 29 files changed, 1183 insertions(+), 59 deletions(-)
 create mode 100644 flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunctionContext.java
 create mode 100644 flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java
 create mode 100644 flink-cdc-pipeline-model/pom.xml
 create mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java
 create mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIChatModel.java
 create mode 100644 flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java
 create mode 100644 flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java
 create mode 100644 flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java

diff --git a/docs/content.zh/docs/core-concept/transform.md b/docs/content.zh/docs/core-concept/transform.md
index da2a6a4ee..adff80fa3 100644
--- a/docs/content.zh/docs/core-concept/transform.md
+++ b/docs/content.zh/docs/core-concept/transform.md
@@ -356,6 +356,77 @@ transform:
     filter: inc(id) < 100
 ```

+## Embedding AI Model
+
+Embedding AI models can be used in transform rules.
+To use an embedding AI model, you need to download the jar of the built-in model, and then add `--jar {$BUILT_IN_MODEL_PATH}` to your flink-cdc.sh command.
+
+How to define an Embedding AI Model:
+
+```yaml
+pipeline:
+  model:
+    - model-name: CHAT
+      class-name: OpenAIChatModel
+      openai.model: gpt-4o-mini
+      openai.host: https://xxxx
+      openai.apikey: abcd1234
+      openai.chat.prompt: please summarize this
+    - model-name: GET_EMBEDDING
+      class-name: OpenAIEmbeddingModel
+      openai.model: text-embedding-3-small
+      openai.host: https://xxxx
+      openai.apikey: abcd1234
+```
+Note:
+* `model-name` is a required parameter common to all supported models; it is the function name that is called in `projection` or `filter`.
+* `class-name` is a required parameter common to all supported models; available values can be found in [All Supported Models](#all-supported-models).
+* `openai.model`, `openai.host`, `openai.apikey` and `openai.chat.prompt` are model-specific parameters; see [All Supported Models](#all-supported-models) for whether each one is required.
+
+How to use an Embedding AI Model:
+
+```yaml
+transform:
+  - source-table: db.\.*
+    projection: "*, inc(inc(inc(id))) as inc_id, GET_EMBEDDING(page) as emb, CHAT(page) as summary"
+    filter: inc(id) < 100
+pipeline:
+  model:
+    - model-name: CHAT
+      class-name: OpenAIChatModel
+      openai.model: gpt-4o-mini
+      openai.host: http://langchain4j.dev/demo/openai/v1
+      openai.apikey: demo
+      openai.chat.prompt: please summarize this
+    - model-name: GET_EMBEDDING
+      class-name: OpenAIEmbeddingModel
+      openai.model: text-embedding-3-small
+      openai.host: http://langchain4j.dev/demo/openai/v1
+      openai.apikey: demo
+```
+Here, GET_EMBEDDING is defined through `model-name` in `pipeline`.
+
+### All Supported Models
+
+The following built-in models are provided:
+
+#### OpenAIChatModel
+
+| parameter          | type   | optional/required | meaning                                                                                                                                |
+|--------------------|--------|-------------------|----------------------------------------------------------------------------------------------------------------------------------------|
+| openai.model       | STRING | required          | Name of the model to be called, for example: "gpt-4o-mini". Available options are "gpt-4o-mini", "gpt-4o", "gpt-4-32k", "gpt-3.5-turbo". |
+| openai.host        | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                           |
+| openai.apikey      | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                             |
+| openai.chat.prompt | STRING | optional          | Prompt used when chatting with OpenAI, for example: "Please summarize this".                                                            |
+
+#### OpenAIEmbeddingModel
+
+| parameter     | type   | optional/required | meaning                                                                                                                                                                   |
+|---------------|--------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| openai.model  | STRING | required          | Name of the model to be called, for example: "text-embedding-3-small". Available options are "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".   |
+| openai.host   | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                                                                |
+| openai.apikey | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                                                                  |
+
 # Known limitations
 * Currently, transform doesn't work with route rules. It will be supported in future versions.
 * Computed columns cannot reference trimmed columns that do not present in final projection results. This will be fixed in future versions.
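For context on what the documented options amount to at runtime: a minimal standalone sketch (not part of this patch) of how `openai.host`, `openai.apikey` and `openai.model` are wired into a langchain4j client, assuming langchain4j 0.23.0 — the version this patch pins — and the demo placeholder values from the docs above.

```java
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;

import java.util.List;

public class EmbeddingSketch {
    public static void main(String[] args) {
        // Values taken from the documented example configuration.
        OpenAiEmbeddingModel model =
                OpenAiEmbeddingModel.builder()
                        .baseUrl("http://langchain4j.dev/demo/openai/v1") // openai.host
                        .apiKey("demo") // openai.apikey
                        .modelName("text-embedding-3-small") // openai.model
                        .build();
        // Embed one string, as GET_EMBEDDING(page) does per row.
        List<Float> vector =
                model.embed(TextSegment.from("Flink CDC is a streaming data integration tool"))
                        .content()
                        .vectorAsList();
        System.out.println(vector.size());
    }
}
```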
diff --git a/docs/content/docs/core-concept/transform.md b/docs/content/docs/core-concept/transform.md
index e8d58f254..d1d725f92 100644
--- a/docs/content/docs/core-concept/transform.md
+++ b/docs/content/docs/core-concept/transform.md
@@ -356,6 +356,78 @@ transform:
     filter: inc(id) < 100
 ```

+## Embedding AI Model
+
+Embedding AI models can be used in transform rules.
+To use an embedding AI model, you need to download the jar of the built-in model, and then add `--jar {$BUILT_IN_MODEL_PATH}` to your flink-cdc.sh command.
+
+How to define an Embedding AI Model:
+
+```yaml
+pipeline:
+  model:
+    - model-name: CHAT
+      class-name: OpenAIChatModel
+      openai.model: gpt-4o-mini
+      openai.host: https://xxxx
+      openai.apikey: abcd1234
+      openai.chat.prompt: please summarize this
+    - model-name: GET_EMBEDDING
+      class-name: OpenAIEmbeddingModel
+      openai.model: text-embedding-3-small
+      openai.host: https://xxxx
+      openai.apikey: abcd1234
+```
+Note:
+* `model-name` is a required parameter common to all supported models; it is the function name that is called in `projection` or `filter`.
+* `class-name` is a required parameter common to all supported models; available values can be found in [All Supported Models](#all-supported-models).
+* `openai.model`, `openai.host`, `openai.apikey` and `openai.chat.prompt` are model-specific parameters; see [All Supported Models](#all-supported-models) for whether each one is required.
+
+How to use an Embedding AI Model:
+
+```yaml
+transform:
+  - source-table: db.\.*
+    projection: "*, inc(inc(inc(id))) as inc_id, GET_EMBEDDING(page) as emb, CHAT(page) as summary"
+    filter: inc(id) < 100
+pipeline:
+  model:
+    - model-name: CHAT
+      class-name: OpenAIChatModel
+      openai.model: gpt-4o-mini
+      openai.host: http://langchain4j.dev/demo/openai/v1
+      openai.apikey: demo
+      openai.chat.prompt: please summarize this
+    - model-name: GET_EMBEDDING
+      class-name: OpenAIEmbeddingModel
+      openai.model: text-embedding-3-small
+      openai.host: http://langchain4j.dev/demo/openai/v1
+      openai.apikey: demo
+```
+Here, GET_EMBEDDING is defined through `model-name` in `pipeline`.
+
+### All Supported Models
+
+The following built-in models are provided:
+
+#### OpenAIChatModel
+
+| parameter          | type   | optional/required | meaning                                                                                                                                |
+|--------------------|--------|-------------------|----------------------------------------------------------------------------------------------------------------------------------------|
+| openai.model       | STRING | required          | Name of the model to be called, for example: "gpt-4o-mini". Available options are "gpt-4o-mini", "gpt-4o", "gpt-4-32k", "gpt-3.5-turbo". |
+| openai.host        | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                           |
+| openai.apikey      | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                             |
+| openai.chat.prompt | STRING | optional          | Prompt used when chatting with OpenAI, for example: "Please summarize this".                                                            |
+
+#### OpenAIEmbeddingModel
+
+| parameter     | type   | optional/required | meaning                                                                                                                                                                   |
+|---------------|--------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| openai.model  | STRING | required          | Name of the model to be called, for example: "text-embedding-3-small". Available options are "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".   |
+| openai.host   | STRING | required          | Host of the model server to connect to, for example: `http://langchain4j.dev/demo/openai/v1`.                                                                                |
+| openai.apikey | STRING | required          | API key used to authenticate against the model server, for example: "demo".                                                                                                  |
+
+
 # Known limitations
 * Currently, transform doesn't work with route rules. It will be supported in future versions.
 * Computed columns cannot reference trimmed columns that do not present in final projection results. This will be fixed in future versions.
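The `model` block is parsed by the CLI changes that follow; per the `parseModels` method added below, it accepts either a list of models or a single mapping (the test YAML later in this patch uses the mapping form). A rough standalone illustration of that dual shape with Jackson — class and method names here are hypothetical:

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;

import java.util.ArrayList;
import java.util.List;

public class ModelNodeSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper yaml = new ObjectMapper(new YAMLFactory());
        // Both shapes are accepted: a list of models...
        JsonNode asList =
                yaml.readTree("model:\n  - model-name: CHAT\n  - model-name: GET_EMBEDDING\n")
                        .get("model");
        // ...or a single mapping.
        JsonNode asMapping = yaml.readTree("model:\n  model-name: GET_EMBEDDING\n").get("model");
        System.out.println(collect(asList).size()); // 2
        System.out.println(collect(asMapping).size()); // 1
    }

    // Mirrors the branching in parseModels below.
    private static List<JsonNode> collect(JsonNode modelsNode) {
        List<JsonNode> models = new ArrayList<>();
        if (modelsNode.isArray()) {
            modelsNode.forEach(models::add);
        } else {
            models.add(modelsNode);
        }
        return models;
    }
}
```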
diff --git a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java
index 3e65191e2..2f471db21 100644
--- a/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java
+++ b/flink-cdc-cli/src/main/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParser.java
@@ -21,7 +21,9 @@ import org.apache.flink.cdc.common.configuration.Configuration;
 import org.apache.flink.cdc.common.event.SchemaChangeEventType;
 import org.apache.flink.cdc.common.event.SchemaChangeEventTypeFamily;
 import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
+import org.apache.flink.cdc.common.utils.Preconditions;
 import org.apache.flink.cdc.common.utils.StringUtils;
+import org.apache.flink.cdc.composer.definition.ModelDef;
 import org.apache.flink.cdc.composer.definition.PipelineDef;
 import org.apache.flink.cdc.composer.definition.RouteDef;
 import org.apache.flink.cdc.composer.definition.SinkDef;
@@ -57,6 +59,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
     private static final String ROUTE_KEY = "route";
     private static final String TRANSFORM_KEY = "transform";
     private static final String PIPELINE_KEY = "pipeline";
+    private static final String MODEL_KEY = "model";

     // Source / sink keys
     private static final String TYPE_KEY = "type";
@@ -81,6 +84,11 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
     private static final String UDF_FUNCTION_NAME_KEY = "name";
     private static final String UDF_CLASSPATH_KEY = "classpath";

+    // Model related keys
+    private static final String MODEL_NAME_KEY = "model-name";
+
+    private static final String MODEL_CLASS_NAME_KEY = "class-name";
+
     public static final String TRANSFORM_PRIMARY_KEY_KEY = "primary-keys";

     public static final String TRANSFORM_PARTITION_KEY_KEY = "partition-keys";
@@ -108,10 +116,15 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
         // UDFs are optional. We parse UDF first and remove it from the pipelineDefJsonNode since
         // it's not of plain data types and must be removed before calling toPipelineConfig.
         List<UdfDef> udfDefs = new ArrayList<>();
+        final List<ModelDef> modelDefs = new ArrayList<>();
         if (pipelineDefJsonNode.get(PIPELINE_KEY) != null) {
             Optional.ofNullable(
                             ((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(UDF_KEY))
                     .ifPresent(node -> node.forEach(udf -> udfDefs.add(toUdfDef(udf))));
+
+            Optional.ofNullable(
+                            ((ObjectNode) pipelineDefJsonNode.get(PIPELINE_KEY)).remove(MODEL_KEY))
+                    .ifPresent(node -> modelDefs.addAll(parseModels(node)));
         }

         // Pipeline configs are optional
@@ -156,7 +169,7 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
         pipelineConfig.addAll(userPipelineConfig);

         return new PipelineDef(
-                sourceDef, sinkDef, routeDefs, transformDefs, udfDefs, pipelineConfig);
+                sourceDef, sinkDef, routeDefs, transformDefs, udfDefs, modelDefs, pipelineConfig);
     }

     private SourceDef toSourceDef(JsonNode sourceNode) {
@@ -323,4 +336,34 @@ public class YamlPipelineDefinitionParser implements PipelineDefinitionParser {
                         pipelineConfigNode, new TypeReference<Map<String, String>>() {});
         return Configuration.fromMap(pipelineConfigMap);
     }
+
+    private List<ModelDef> parseModels(JsonNode modelsNode) {
+        List<ModelDef> modelDefs = new ArrayList<>();
+        Preconditions.checkNotNull(modelsNode, "`model` in `pipeline` should not be empty.");
+        if (modelsNode.isArray()) {
+            for (JsonNode modelNode : modelsNode) {
+                modelDefs.add(convertJsonNodeToModelDef(modelNode));
+            }
+        } else {
+            modelDefs.add(convertJsonNodeToModelDef(modelsNode));
+        }
+        return modelDefs;
+    }
+
+    private ModelDef convertJsonNodeToModelDef(JsonNode modelNode) {
+        String name =
+                checkNotNull(
+                                modelNode.get(MODEL_NAME_KEY),
+                                "Missing required field \"%s\" in `model`",
+                                MODEL_NAME_KEY)
+                        .asText();
+        String model =
+                checkNotNull(
+                                modelNode.get(MODEL_CLASS_NAME_KEY),
+                                "Missing required field \"%s\" in `model`",
+                                MODEL_CLASS_NAME_KEY)
+                        .asText();
+        Map<String, String> properties = mapper.convertValue(modelNode, Map.class);
+        return new ModelDef(name, model, properties);
+    }
 }
diff --git a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java
index c1c47904e..dca3e3154 100644
--- a/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java
+++ b/flink-cdc-cli/src/test/java/org/apache/flink/cdc/cli/parser/YamlPipelineDefinitionParserTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.cdc.cli.parser;
 import org.apache.flink.cdc.common.configuration.Configuration;
 import org.apache.flink.cdc.common.event.SchemaChangeEventType;
 import org.apache.flink.cdc.common.pipeline.PipelineOptions;
+import org.apache.flink.cdc.composer.definition.ModelDef;
 import org.apache.flink.cdc.composer.definition.PipelineDef;
 import org.apache.flink.cdc.composer.definition.RouteDef;
 import org.apache.flink.cdc.composer.definition.SinkDef;
@@ -39,6 +40,7 @@ import java.time.Duration;
 import java.time.ZoneId;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.Set;

 import static org.apache.flink.cdc.common.event.SchemaChangeEventType.ADD_COLUMN;
@@ -344,6 +346,18 @@ class YamlPipelineDefinitionParserTest {
                                     null,
                                     "add new uniq_id for each row")),
                     Collections.emptyList(),
+                    Collections.singletonList(
+                            new ModelDef(
+                                    "GET_EMBEDDING",
+                                    "OpenAIEmbeddingModel",
+                                    new LinkedHashMap<>(
+                                            ImmutableMap.<String, String>builder()
+                                                    .put("model-name", "GET_EMBEDDING")
+                                                    .put("class-name", "OpenAIEmbeddingModel")
+                                                    .put("openai.model", "text-embedding-3-small")
+                                                    .put("openai.host", "https://xxxx")
+                                                    .put("openai.apikey", "abcd1234")
+                                                    .build()))),
                     Configuration.fromMap(
                             ImmutableMap.<String, String>builder()
                                     .put("name", "source-database-sync-pipe")
@@ -397,7 +411,13 @@ class YamlPipelineDefinitionParserTest {
                     + "  name: source-database-sync-pipe\n"
                     + "  parallelism: 4\n"
                     + "  schema.change.behavior: evolve\n"
-                    + "  schema-operator.rpc-timeout: 1 h";
+                    + "  schema-operator.rpc-timeout: 1 h\n"
+                    + "  model:\n"
+                    + "    - model-name: GET_EMBEDDING\n"
+                    + "      class-name: OpenAIEmbeddingModel\n"
+                    + "      openai.model: text-embedding-3-small\n"
+                    + "      openai.host: https://xxxx\n"
+                    + "      openai.apikey: abcd1234";
         YamlPipelineDefinitionParser parser = new YamlPipelineDefinitionParser();
         PipelineDef pipelineDef = parser.parse(pipelineDefText, new Configuration());
         assertThat(pipelineDef).isEqualTo(fullDef);
@@ -459,6 +479,18 @@ class YamlPipelineDefinitionParserTest {
                                     null,
                                     "add new uniq_id for each row")),
                     Collections.emptyList(),
+                    Collections.singletonList(
+                            new ModelDef(
+                                    "GET_EMBEDDING",
+                                    "OpenAIEmbeddingModel",
+                                    new LinkedHashMap<>(
+                                            ImmutableMap.<String, String>builder()
+                                                    .put("model-name", "GET_EMBEDDING")
+                                                    .put("class-name", "OpenAIEmbeddingModel")
+                                                    .put("openai.model", "text-embedding-3-small")
+                                                    .put("openai.host", "https://xxxx")
+                                                    .put("openai.apikey", "abcd1234")
+                                                    .build()))),
                     Configuration.fromMap(
                             ImmutableMap.<String, String>builder()
                                     .put("name", "source-database-sync-pipe")
diff --git a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml
index b92e237d1..5a8cfb004 100644
--- a/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml
+++ b/flink-cdc-cli/src/test/resources/definitions/pipeline-definition-full.yaml
@@ -57,3 +57,9 @@ pipeline:
   parallelism: 4
   schema.change.behavior: evolve
   schema-operator.rpc-timeout: 1 h
+  model:
+    model-name: GET_EMBEDDING
+    class-name: OpenAIEmbeddingModel
+    openai.model: text-embedding-3-small
+    openai.host: https://xxxx
+    openai.apikey: abcd1234
diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunction.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunction.java
index a40878587..0133e4394 100644
--- a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunction.java
+++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunction.java
@@ -30,9 +30,20 @@ public interface UserDefinedFunction {
         return null;
     }

-    /** This will be invoked every time when a UDF got created. */
+    /**
+     * This will be invoked every time when a UDF got created.
+     *
+     * <p>This method is {@link Deprecated}, please use {@link #open(UserDefinedFunctionContext)}
+     * instead.
+     */
+    @Deprecated
     default void open() throws Exception {}

+    /** This will be invoked every time when a UDF got created. */
+    default void open(UserDefinedFunctionContext context) throws Exception {
+        open();
+    }
+
     /** This will be invoked before a UDF got destroyed. */
     default void close() throws Exception {}
 }
diff --git a/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunctionContext.java b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunctionContext.java
new file mode 100644
index 000000000..eec73e457
--- /dev/null
+++ b/flink-cdc-common/src/main/java/org/apache/flink/cdc/common/udf/UserDefinedFunctionContext.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.cdc.common.udf;
+
+import org.apache.flink.cdc.common.configuration.Configuration;
+
+/** Context for initialization of {@link UserDefinedFunction}. */
+public interface UserDefinedFunctionContext {
+
+    Configuration configuration();
+}
diff --git a/flink-cdc-composer/pom.xml b/flink-cdc-composer/pom.xml
index 97c2e0864..5971f3b45 100644
--- a/flink-cdc-composer/pom.xml
+++ b/flink-cdc-composer/pom.xml
@@ -67,6 +67,12 @@ limitations under the License.
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-cdc-pipeline-model</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>

 </project>
diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java
new file mode 100644
index 000000000..21cc6befa
--- /dev/null
+++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/ModelDef.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.cdc.composer.definition;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Common properties of model.
+ *
+ * <p>A model definition contains:
+ *
+ * <ul>
+ *   <li>modelName: The name of the function called in `projection` or `filter`.
+ *   <li>className: The class name of the model used to transform data.
+ *   <li>parameters: The parameters used to configure the model.
+ * </ul>
+ */
+public class ModelDef {
+
+    private final String modelName;
+
+    private final String className;
+
+    private final Map<String, String> parameters;
+
+    public ModelDef(String modelName, String className, Map<String, String> parameters) {
+        this.modelName = modelName;
+        this.className = className;
+        this.parameters = parameters;
+    }
+
+    public String getModelName() {
+        return modelName;
+    }
+
+    public String getClassName() {
+        return className;
+    }
+
+    public Map<String, String> getParameters() {
+        return parameters;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        ModelDef modelDef = (ModelDef) o;
+        return Objects.equals(modelName, modelDef.modelName)
+                && Objects.equals(className, modelDef.className)
+                && Objects.equals(parameters, modelDef.parameters);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelName, className, parameters);
+    }
+
+    @Override
+    public String toString() {
+        return "ModelDef{"
+                + "modelName='"
+                + modelName
+                + '\''
+                + ", className='"
+                + className
+                + '\''
+                + ", parameters="
+                + parameters
+                + '}';
+    }
+}
diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/PipelineDef.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/PipelineDef.java
index 6353c4e74..c81d45fd9 100644
--- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/PipelineDef.java
+++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/definition/PipelineDef.java
@@ -24,6 +24,7 @@ import org.apache.flink.cdc.composer.PipelineComposer;
 import org.apache.flink.cdc.composer.PipelineExecution;

 import java.time.ZoneId;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.TimeZone;
@@ -55,6 +56,7 @@ public class PipelineDef {
     private final List<RouteDef> routes;
     private final List<TransformDef> transforms;
     private final List<UdfDef> udfs;
+    private final List<ModelDef> models;
     private final Configuration config;

     public PipelineDef(
@@ -63,15 +65,27 @@ public class PipelineDef {
             List<RouteDef> routes,
             List<TransformDef> transforms,
             List<UdfDef> udfs,
+            List<ModelDef> models,
             Configuration config) {
         this.source = source;
         this.sink = sink;
         this.routes = routes;
         this.transforms = transforms;
         this.udfs = udfs;
+        this.models = models;
         this.config = evaluatePipelineTimeZone(config);
     }

+    public PipelineDef(
+            SourceDef source,
+            SinkDef sink,
+            List<RouteDef> routes,
+            List<TransformDef> transforms,
+            List<UdfDef> udfs,
+            Configuration config) {
+        this(source, sink, routes, transforms, udfs, new ArrayList<>(), config);
+    }
+
     public SourceDef getSource() {
         return source;
     }
@@ -92,6 +106,10 @@ public class PipelineDef {
         return udfs;
     }

+    public List<ModelDef> getModels() {
+        return models;
+    }
+
     public Configuration getConfig() {
         return config;
     }
@@ -109,6 +127,8 @@ public class PipelineDef {
                 + transforms
                 + ", udfs="
                 + udfs
+                + ", models="
+                + models
                 + ", config="
                 + config
                 + '}';
@@ -128,12 +148,13 @@ public class PipelineDef {
                 && Objects.equals(routes, that.routes)
                 && Objects.equals(transforms, that.transforms)
                 && Objects.equals(udfs, that.udfs)
+                && Objects.equals(models, that.models)
                 && Objects.equals(config, that.config);
     }

     @Override
     public int hashCode() {
-        return Objects.hash(source, sink, routes, transforms, udfs, config);
+        return Objects.hash(source, sink, routes, transforms, udfs, models, config);
     }

     // ------------------------------------------------------------------------
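The two `PipelineDef` constructors above keep existing callers source-compatible while letting new callers attach models. A small sketch of both call sites; the `"values"` connector identifier mirrors the integration test later in this patch, and the option keys mirror the docs:

```java
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.composer.definition.ModelDef;
import org.apache.flink.cdc.composer.definition.PipelineDef;
import org.apache.flink.cdc.composer.definition.SinkDef;
import org.apache.flink.cdc.composer.definition.SourceDef;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class PipelineDefSketch {
    public static void main(String[] args) {
        SourceDef source = new SourceDef("values", "Value Source", new Configuration());
        SinkDef sink = new SinkDef("values", "Value Sink", new Configuration());

        Map<String, String> params = new HashMap<>();
        params.put("openai.model", "text-embedding-3-small");
        params.put("openai.host", "http://langchain4j.dev/demo/openai/v1");
        params.put("openai.apikey", "demo");
        ModelDef embedding = new ModelDef("GET_EMBEDDING", "OpenAIEmbeddingModel", params);

        // Model-aware constructor added by this patch:
        PipelineDef withModels =
                new PipelineDef(
                        source,
                        sink,
                        Collections.emptyList(), // routes
                        Collections.emptyList(), // transforms
                        Collections.emptyList(), // udfs
                        Collections.singletonList(embedding),
                        new Configuration());

        // Legacy constructor still compiles and defaults models to an empty list:
        PipelineDef withoutModels =
                new PipelineDef(
                        source,
                        sink,
                        Collections.emptyList(),
                        Collections.emptyList(),
                        Collections.emptyList(),
                        new Configuration());
        System.out.println(withModels.getModels().size() + " / " + withoutModels.getModels().size());
    }
}
```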
diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java
index 114035fe9..579eb9607 100644
--- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java
+++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/FlinkPipelineComposer.java
@@ -107,7 +107,10 @@ public class FlinkPipelineComposer implements PipelineComposer {
         TransformTranslator transformTranslator = new TransformTranslator();
         stream =
                 transformTranslator.translatePreTransform(
-                        stream, pipelineDef.getTransforms(), pipelineDef.getUdfs());
+                        stream,
+                        pipelineDef.getTransforms(),
+                        pipelineDef.getUdfs(),
+                        pipelineDef.getModels());

         // Schema operator
         SchemaOperatorTranslator schemaOperatorTranslator =
@@ -124,8 +127,9 @@ public class FlinkPipelineComposer implements PipelineComposer {
                 transformTranslator.translatePostTransform(
                         stream,
                         pipelineDef.getTransforms(),
-                        pipelineDefConfig.get(PipelineOptions.PIPELINE_LOCAL_TIME_ZONE),
-                        pipelineDef.getUdfs());
+                        pipelineDef.getConfig().get(PipelineOptions.PIPELINE_LOCAL_TIME_ZONE),
+                        pipelineDef.getUdfs(),
+                        pipelineDef.getModels());

         // Build DataSink in advance as schema operator requires MetadataApplier
         DataSinkTranslator sinkTranslator = new DataSinkTranslator();
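In the translator changes that follow, a model is handed to the transform operators as the same `(name, classpath, parameters)` tuple as a UDF, with the classpath derived by prefixing the built-in model package. A tiny sketch of that mapping for the `CHAT` model from the docs:

```java
import org.apache.flink.api.java.tuple.Tuple3;

import java.util.HashMap;
import java.util.Map;

public class ModelTupleSketch {
    // Mirrors TransformTranslator#PREFIX_CLASSPATH_BUILT_IN_MODEL below.
    static final String PREFIX = "org.apache.flink.cdc.runtime.model.";

    public static void main(String[] args) {
        Map<String, String> params = new HashMap<>();
        params.put("openai.model", "gpt-4o-mini");
        // ModelDef("CHAT", "OpenAIChatModel", params) becomes:
        Tuple3<String, String, Map<String, String>> udf =
                Tuple3.of("CHAT", PREFIX + "OpenAIChatModel", params);
        System.out.println(udf.f1); // org.apache.flink.cdc.runtime.model.OpenAIChatModel
    }
}
```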
diff --git a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java
index 3e7634147..d8f009721 100644
--- a/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java
+++ b/flink-cdc-composer/src/main/java/org/apache/flink/cdc/composer/flink/translator/TransformTranslator.java
@@ -17,8 +17,9 @@

 package org.apache.flink.cdc.composer.flink.translator;

-import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.cdc.common.event.Event;
+import org.apache.flink.cdc.composer.definition.ModelDef;
 import org.apache.flink.cdc.composer.definition.TransformDef;
 import org.apache.flink.cdc.composer.definition.UdfDef;
 import org.apache.flink.cdc.runtime.operators.transform.PostTransformOperator;
@@ -26,7 +27,9 @@ import org.apache.flink.cdc.runtime.operators.transform.PreTransformOperator;
 import org.apache.flink.cdc.runtime.typeutils.EventTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;

+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;

 /**
@@ -35,8 +38,15 @@ import java.util.stream.Collectors;
  */
 public class TransformTranslator {

+    /** Package of the built-in models. */
+    public static final String PREFIX_CLASSPATH_BUILT_IN_MODEL =
+            "org.apache.flink.cdc.runtime.model.";
+
     public DataStream<Event> translatePreTransform(
-            DataStream<Event> input, List<TransformDef> transforms, List<UdfDef> udfFunctions) {
+            DataStream<Event> input,
+            List<TransformDef> transforms,
+            List<UdfDef> udfFunctions,
+            List<ModelDef> models) {
         if (transforms.isEmpty()) {
             return input;
         }
@@ -52,10 +62,11 @@ public class TransformTranslator {
                     transform.getPartitionKeys(),
                     transform.getTableOptions());
         }
+
+        preTransformFunctionBuilder.addUdfFunctions(
+                udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
         preTransformFunctionBuilder.addUdfFunctions(
-                udfFunctions.stream()
-                        .map(udf -> Tuple2.of(udf.getName(), udf.getClasspath()))
-                        .collect(Collectors.toList()));
+                models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
         return input.transform(
                 "Transform:Schema", new EventTypeInfo(), preTransformFunctionBuilder.build());
     }
@@ -64,7 +75,8 @@ public class TransformTranslator {
             DataStream<Event> input,
             List<TransformDef> transforms,
             String timezone,
-            List<UdfDef> udfFunctions) {
+            List<UdfDef> udfFunctions,
+            List<ModelDef> models) {
         if (transforms.isEmpty()) {
             return input;
         }
@@ -84,10 +96,21 @@ public class TransformTranslator {
         }
         postTransformFunctionBuilder.addTimezone(timezone);
         postTransformFunctionBuilder.addUdfFunctions(
-                udfFunctions.stream()
-                        .map(udf -> Tuple2.of(udf.getName(), udf.getClasspath()))
-                        .collect(Collectors.toList()));
+                udfFunctions.stream().map(this::udfDefToUDFTuple).collect(Collectors.toList()));
+        postTransformFunctionBuilder.addUdfFunctions(
+                models.stream().map(this::modelToUDFTuple).collect(Collectors.toList()));
         return input.transform(
                 "Transform:Data", new EventTypeInfo(), postTransformFunctionBuilder.build());
     }
+
+    private Tuple3<String, String, Map<String, String>> modelToUDFTuple(ModelDef model) {
+        return Tuple3.of(
+                model.getModelName(),
+                PREFIX_CLASSPATH_BUILT_IN_MODEL + model.getClassName(),
+                model.getParameters());
+    }
+
+    private Tuple3<String, String, Map<String, String>> udfDefToUDFTuple(UdfDef udf) {
+        return Tuple3.of(udf.getName(), udf.getClasspath(), new HashMap<>());
+    }
 }
diff --git a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java
index c3b412dc7..609c62ad7 100644
--- a/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java
+++ b/flink-cdc-composer/src/test/java/org/apache/flink/cdc/composer/flink/FlinkPipelineUdfITCase.java
@@ -21,6 +21,7 @@ import org.apache.flink.cdc.common.configuration.Configuration;
 import org.apache.flink.cdc.common.pipeline.PipelineOptions;
 import org.apache.flink.cdc.common.pipeline.SchemaChangeBehavior;
 import org.apache.flink.cdc.composer.PipelineExecution;
+import org.apache.flink.cdc.composer.definition.ModelDef;
 import org.apache.flink.cdc.composer.definition.PipelineDef;
 import org.apache.flink.cdc.composer.definition.SinkDef;
 import org.apache.flink.cdc.composer.definition.SourceDef;
@@ -35,6 +36,8 @@ import org.apache.flink.cdc.connectors.values.source.ValuesDataSourceOptions;
 import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
 import org.apache.flink.test.junit5.MiniClusterExtension;

+import org.apache.flink.shaded.guava31.com.google.common.collect.ImmutableMap;
+
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.extension.RegisterExtension;
@@ -44,8 +47,10 @@ import org.junit.jupiter.params.provider.MethodSource;

 import java.io.ByteArrayOutputStream;
 import java.io.PrintStream;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.stream.Stream;

 import static org.apache.flink.configuration.CoreOptions.ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL;
@@ -822,6 +827,75 @@ public class FlinkPipelineUdfITCase {
                         "DataChangeEvent{tableId=default_namespace.default_schema.table1, before=[2, , 4, Integer: 42, 2-42], after=[2, x, 4, Integer: 42, 2-42], op=UPDATE, meta=()}");
     }

+    @ParameterizedTest
+    @MethodSource("testParams")
+    void testTransformWithModel(ValuesDataSink.SinkApi sinkApi) throws Exception {
+        FlinkPipelineComposer composer = FlinkPipelineComposer.ofMiniCluster();
+
+        // Setup value source
+        Configuration sourceConfig = new Configuration();
+        sourceConfig.set(
+                ValuesDataSourceOptions.EVENT_SET_ID,
+                ValuesDataSourceHelper.EventSetId.TRANSFORM_TABLE);
+        SourceDef sourceDef =
+                new SourceDef(ValuesDataFactory.IDENTIFIER, "Value Source", sourceConfig);
+
+        // Setup value sink
+        Configuration sinkConfig = new Configuration();
+        sinkConfig.set(ValuesDataSinkOptions.MATERIALIZED_IN_MEMORY, true);
+        sinkConfig.set(ValuesDataSinkOptions.SINK_API, sinkApi);
+        SinkDef sinkDef = new SinkDef(ValuesDataFactory.IDENTIFIER, "Value Sink", sinkConfig);
+
+        // Setup transform
+        TransformDef transformDef =
+                new TransformDef(
+                        "default_namespace.default_schema.table1",
+                        "*, CHAT(col1) AS emb",
+                        null,
+                        "col1",
+                        null,
+                        "key1=value1",
+                        "");
+
+        // Setup pipeline
+        Configuration pipelineConfig = new Configuration();
+        pipelineConfig.set(PipelineOptions.PIPELINE_PARALLELISM, 1);
+        pipelineConfig.set(
+                PipelineOptions.PIPELINE_SCHEMA_CHANGE_BEHAVIOR, SchemaChangeBehavior.EVOLVE);
+        PipelineDef pipelineDef =
+                new PipelineDef(
+                        sourceDef,
+                        sinkDef,
+                        Collections.emptyList(),
+                        Collections.singletonList(transformDef),
+                        new ArrayList<>(),
+                        Arrays.asList(
+                                new ModelDef(
+                                        "CHAT",
+                                        "OpenAIChatModel",
+                                        new LinkedHashMap<>(
+                                                ImmutableMap.<String, String>builder()
+                                                        .put("openai.model", "gpt-4o-mini")
+                                                        .put(
+                                                                "openai.host",
+                                                                "http://langchain4j.dev/demo/openai/v1")
+                                                        .put("openai.apikey", "demo")
+                                                        .build()))),
+                        pipelineConfig);
+
+        // Execute the pipeline
+        PipelineExecution execution = composer.compose(pipelineDef);
+        execution.execute();
+
+        // Check the order and content of all received events
+        String[] outputEvents = outCaptor.toString().trim().split("\n");
+        assertThat(outputEvents)
+                .contains(
+                        "CreateTableEvent{tableId=default_namespace.default_schema.table1, schema=columns={`col1` STRING,`col2` STRING,`emb` STRING}, primaryKeys=col1, options=({key1=value1})}")
+                // The output of the model transform is not deterministic, so only count it.
+                .hasSize(9);
+    }
+
     private static Stream<Arguments> testParams() {
         return Stream.of(
                 arguments(ValuesDataSink.SinkApi.SINK_FUNCTION, "java"),
diff --git a/flink-cdc-pipeline-model/pom.xml b/flink-cdc-pipeline-model/pom.xml
new file mode 100644
index 000000000..cf51c1cc3
--- /dev/null
+++ b/flink-cdc-pipeline-model/pom.xml
@@ -0,0 +1,81 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <artifactId>flink-cdc-parent</artifactId>
+        <groupId>org.apache.flink</groupId>
+        <version>${revision}</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>flink-cdc-pipeline-model</artifactId>
+
+    <properties>
+        <langchain4j.version>0.23.0</langchain4j.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-cdc-common</artifactId>
+            <version>${project.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-test-utils-junit</artifactId>
+            <version>${flink.version}</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j</artifactId>
+            <version>${langchain4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-open-ai</artifactId>
+            <version>${langchain4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>com.theokanning.openai-gpt3-java</groupId>
+            <artifactId>service</artifactId>
+            <version>0.12.0</version>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-jar-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>test-jar</id>
+                        <goals>
+                            <goal>test-jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
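A model's `Map<String, String>` parameters are surfaced to the model class as a `Configuration` through `UserDefinedFunctionContext` (see `PostTransformOperator` further below) and read back through the `ConfigOption`s defined next. A minimal sketch of that hand-off, using the demo values from the docs:

```java
import org.apache.flink.cdc.common.configuration.Configuration;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
import org.apache.flink.cdc.runtime.model.ModelOptions;

import java.util.HashMap;
import java.util.Map;

public class ParameterHandOffSketch {
    public static void main(String[] args) {
        Map<String, String> parameters = new HashMap<>();
        parameters.put("openai.model", "gpt-4o-mini");
        parameters.put("openai.host", "http://langchain4j.dev/demo/openai/v1");
        parameters.put("openai.apikey", "demo");

        // What the operator does before reflectively invoking open(...):
        UserDefinedFunctionContext context = () -> Configuration.fromMap(parameters);

        // What the model does inside open(UserDefinedFunctionContext):
        String modelName = context.configuration().get(ModelOptions.OPENAI_MODEL_NAME);
        System.out.println(modelName); // gpt-4o-mini
    }
}
```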
diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java
new file mode 100644
index 000000000..f56b76d5b
--- /dev/null
+++ b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/ModelOptions.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.cdc.runtime.model;
+
+import org.apache.flink.cdc.common.configuration.ConfigOption;
+import org.apache.flink.cdc.common.configuration.ConfigOptions;
+
+/** Options of the built-in models. */
+public class ModelOptions {
+
+    // Options for the OpenAI models.
+    public static final ConfigOption<String> OPENAI_MODEL_NAME =
+            ConfigOptions.key("openai.model")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription("Name of the model to be called.");
+
+    public static final ConfigOption<String> OPENAI_HOST =
+            ConfigOptions.key("openai.host")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription("Host of the model server to be connected.");
+
+    public static final ConfigOption<String> OPENAI_API_KEY =
+            ConfigOptions.key("openai.apikey")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription("API key used to authenticate against the model server.");
+
+    public static final ConfigOption<String> OPENAI_CHAT_PROMPT =
+            ConfigOptions.key("openai.chat.prompt")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription("Prompt for chatting with OpenAI.");
+}
+ */ +public class OpenAIChatModel implements UserDefinedFunction { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAIChatModel.class); + + private OpenAiChatModel chatModel; + + private String modelName; + + private String host; + + private String prompt; + + public String eval(String input) { + return chat(input); + } + + private String chat(String input) { + if (input == null || input.trim().isEmpty()) { + LOG.warn("Empty or null input provided for embedding."); + return ""; + } + if (prompt != null) { + input = prompt + ": " + input; + } + return chatModel + .generate(Collections.singletonList(new UserMessage(input))) + .content() + .text(); + } + + @Override + public DataType getReturnType() { + return DataTypes.STRING(); + } + + @Override + public void open(UserDefinedFunctionContext userDefinedFunctionContext) { + Configuration modelOptions = userDefinedFunctionContext.configuration(); + this.modelName = modelOptions.get(OPENAI_MODEL_NAME); + Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); + this.host = modelOptions.get(OPENAI_HOST); + Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); + String apiKey = modelOptions.get(OPENAI_API_KEY); + Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); + this.prompt = modelOptions.get(OPENAI_CHAT_PROMPT); + LOG.info("Opening OpenAIChatModel " + modelName + " " + host); + this.chatModel = + OpenAiChatModel.builder().apiKey(apiKey).baseUrl(host).modelName(modelName).build(); + } + + @Override + public void close() { + LOG.info("Closed OpenAIChatModel " + modelName + " " + host); + } +} diff --git a/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java new file mode 100644 index 000000000..dbc29c307 --- /dev/null +++ b/flink-cdc-pipeline-model/src/main/java/org/apache/flink/cdc/runtime/model/OpenAIEmbeddingModel.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.cdc.runtime.model; + +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.data.ArrayData; +import org.apache.flink.cdc.common.data.GenericArrayData; +import org.apache.flink.cdc.common.types.DataType; +import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.common.udf.UserDefinedFunction; +import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; +import org.apache.flink.cdc.common.utils.Preconditions; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.openai.OpenAiEmbeddingModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_API_KEY; +import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_HOST; +import static org.apache.flink.cdc.runtime.model.ModelOptions.OPENAI_MODEL_NAME; + +/** + * A {@link UserDefinedFunction} that use Model defined by OpenAI to generate vector data, refer to + * docs}. + */ +public class OpenAIEmbeddingModel implements UserDefinedFunction { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAIEmbeddingModel.class); + + private String modelName; + + private String host; + + private OpenAiEmbeddingModel embeddingModel; + + public ArrayData eval(String input) { + return getEmbedding(input); + } + + private ArrayData getEmbedding(String input) { + if (input == null || input.trim().isEmpty()) { + LOG.debug("Empty or null input provided for embedding."); + return new GenericArrayData(new Float[0]); + } + + TextSegment textSegment = new TextSegment(input, new Metadata()); + + List embeddings = + embeddingModel.embedAll(Collections.singletonList(textSegment)).content(); + + if (embeddings != null && !embeddings.isEmpty()) { + List embeddingList = embeddings.get(0).vectorAsList(); + Float[] embeddingArray = embeddingList.toArray(new Float[0]); + return new GenericArrayData(embeddingArray); + } else { + LOG.warn("No embedding results returned for input: {}", input); + return new GenericArrayData(new Float[0]); + } + } + + @Override + public DataType getReturnType() { + return DataTypes.ARRAY(DataTypes.FLOAT()); + } + + @Override + public void open(UserDefinedFunctionContext userDefinedFunctionContext) { + Configuration modelOptions = userDefinedFunctionContext.configuration(); + this.modelName = modelOptions.get(OPENAI_MODEL_NAME); + Preconditions.checkNotNull(modelName, OPENAI_MODEL_NAME.key() + " should not be empty."); + this.host = modelOptions.get(OPENAI_HOST); + Preconditions.checkNotNull(host, OPENAI_HOST.key() + " should not be empty."); + String apiKey = modelOptions.get(OPENAI_API_KEY); + Preconditions.checkNotNull(apiKey, OPENAI_API_KEY.key() + " should not be empty."); + LOG.info("Opening OpenAIEmbeddingModel " + modelName + " " + host); + this.embeddingModel = + OpenAiEmbeddingModel.builder() + .apiKey(apiKey) + .baseUrl(host) + .modelName(modelName) + .build(); + } + + @Override + public void close() { + LOG.info("Closed OpenAIEmbeddingModel " + modelName + " " + host); + } +} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java new file mode 100644 index 000000000..bba2d8b25 --- /dev/null +++ 
b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIChatModel.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.model; + +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** A test for {@link OpenAIChatModel}. */ +public class TestOpenAIChatModel { + @Test + public void testEval() { + OpenAIChatModel openAIChatModel = new OpenAIChatModel(); + Configuration configuration = new Configuration(); + configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1"); + configuration.set(ModelOptions.OPENAI_API_KEY, "demo"); + configuration.set(ModelOptions.OPENAI_MODEL_NAME, "gpt-4o-mini"); + UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration; + openAIChatModel.open(userDefinedFunctionContext); + String response = openAIChatModel.eval("Who invented the electric light?"); + Assertions.assertFalse(response.isEmpty()); + } +} diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java new file mode 100644 index 000000000..118fb5628 --- /dev/null +++ b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.cdc.runtime.model; + +import org.apache.flink.cdc.common.configuration.Configuration; +import org.apache.flink.cdc.common.data.ArrayData; +import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** A test for {@link OpenAIEmbeddingModel}. 
diff --git a/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java
new file mode 100644
index 000000000..118fb5628
--- /dev/null
+++ b/flink-cdc-pipeline-model/src/test/java/org/apache/flink/cdc/runtime/model/TestOpenAIEmbeddingModel.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.cdc.runtime.model;
+
+import org.apache.flink.cdc.common.configuration.Configuration;
+import org.apache.flink.cdc.common.data.ArrayData;
+import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/** A test for {@link OpenAIEmbeddingModel}. */
+public class TestOpenAIEmbeddingModel {
+
+    @Test
+    public void testEval() {
+        OpenAIEmbeddingModel openAIEmbeddingModel = new OpenAIEmbeddingModel();
+        Configuration configuration = new Configuration();
+        configuration.set(ModelOptions.OPENAI_HOST, "http://langchain4j.dev/demo/openai/v1");
+        configuration.set(ModelOptions.OPENAI_API_KEY, "demo");
+        configuration.set(ModelOptions.OPENAI_MODEL_NAME, "text-embedding-3-small");
+        UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
+        openAIEmbeddingModel.open(userDefinedFunctionContext);
+        ArrayData arrayData =
+                openAIEmbeddingModel.eval("Flink CDC is a streaming data integration tool");
+        Assertions.assertNotNull(arrayData);
+    }
+}
diff --git a/flink-cdc-runtime/pom.xml b/flink-cdc-runtime/pom.xml
index c5dbc123f..6b1ea91a9 100644
--- a/flink-cdc-runtime/pom.xml
+++ b/flink-cdc-runtime/pom.xml
@@ -89,6 +89,12 @@ limitations under the License.
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-cdc-pipeline-model</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
\ No newline at end of file
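Because `open(UserDefinedFunctionContext)` is a default method that falls back to the deprecated `open()`, pre-existing UDFs keep working unchanged, while new ones can read parameters from the context. A minimal sketch of a context-aware UDF — the class name and option key are illustrative only, not part of this patch:

```java
import org.apache.flink.cdc.common.configuration.ConfigOptions;
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;

public class PrefixFunction implements UserDefinedFunction {

    private String prefix;

    // Invoked by the transform operators below with the configured parameters.
    @Override
    public void open(UserDefinedFunctionContext context) {
        prefix =
                context.configuration()
                        .get(
                                ConfigOptions.key("prefix.value") // illustrative option key
                                        .stringType()
                                        .defaultValue(">> "));
    }

    public String eval(String input) {
        return prefix + input;
    }
}
```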
diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java
index 8a607ffb5..dc6f07751 100644
--- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java
+++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PostTransformOperator.java
@@ -18,6 +18,8 @@
 package org.apache.flink.cdc.runtime.operators.transform;

 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.cdc.common.configuration.Configuration;
 import org.apache.flink.cdc.common.data.RecordData;
 import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
 import org.apache.flink.cdc.common.event.CreateTableEvent;
@@ -29,6 +31,7 @@ import org.apache.flink.cdc.common.event.TableId;
 import org.apache.flink.cdc.common.pipeline.PipelineOptions;
 import org.apache.flink.cdc.common.schema.Schema;
 import org.apache.flink.cdc.common.schema.Selectors;
+import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
 import org.apache.flink.cdc.common.utils.SchemaUtils;
 import org.apache.flink.cdc.runtime.parser.TransformParser;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -70,7 +73,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
     /** keep the relationship of TableId and table information. */
     private final Map<TableId, PostTransformChangeInfo> postTransformChangeInfoMap;

-    private final List<Tuple2<String, String>> udfFunctions;
+    private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
     private List<UserDefinedFunctionDescriptor> udfDescriptors;
     private transient Map<String, Object> udfFunctionInstances;
@@ -89,7 +92,8 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
     public static class Builder {
         private final List<TransformRule> transformRules = new ArrayList<>();
         private String timezone;
-        private final List<Tuple2<String, String>> udfFunctions = new ArrayList<>();
+        private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
+                new ArrayList<>();

         public PostTransformOperator.Builder addTransform(
                 String tableInclusions,
@@ -125,7 +129,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
         }

         public PostTransformOperator.Builder addUdfFunctions(
-                List<Tuple2<String, String>> udfFunctions) {
+                List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
             this.udfFunctions.addAll(udfFunctions);
             return this;
         }
@@ -138,7 +142,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
     private PostTransformOperator(
             List<TransformRule> transformRules,
             String timezone,
-            List<Tuple2<String, String>> udfFunctions) {
+            List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
         this.transformRules = transformRules;
         this.timezone = timezone;
         this.postTransformChangeInfoMap = new ConcurrentHashMap<>();
@@ -158,10 +162,7 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
         super.setup(containingTask, config, output);
         udfDescriptors =
                 udfFunctions.stream()
-                        .map(
-                                udf -> {
-                                    return new UserDefinedFunctionDescriptor(udf.f0, udf.f1);
-                                })
+                        .map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
                         .collect(Collectors.toList());
     }
@@ -539,7 +540,12 @@ public class PostTransformOperator extends AbstractStreamOperator<Event>
                 // into UserDefinedFunction interface, thus the provided UDF classes
                 // might not be compatible with the interface definition in CDC common.
                 Object udfInstance = udfFunctionInstances.get(udf.getName());
-                udfInstance.getClass().getMethod("open").invoke(udfInstance);
+                UserDefinedFunctionContext userDefinedFunctionContext =
+                        () -> Configuration.fromMap(udf.getParameters());
+                udfInstance
+                        .getClass()
+                        .getMethod("open", UserDefinedFunctionContext.class)
+                        .invoke(udfInstance, userDefinedFunctionContext);
             } else {
                 // Do nothing, Flink-style UDF lifecycle hooks are not supported
             }
diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PreTransformOperator.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PreTransformOperator.java
index b0fc8218d..538c28ddb 100644
--- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PreTransformOperator.java
+++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/operators/transform/PreTransformOperator.java
@@ -21,6 +21,7 @@ import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.cdc.common.data.binary.BinaryRecordData;
 import org.apache.flink.cdc.common.event.CreateTableEvent;
 import org.apache.flink.cdc.common.event.DataChangeEvent;
@@ -69,7 +70,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
     private final Map<TableId, PreTransformChangeInfo> preTransformChangeInfoMap;
     private final List<Tuple2<Selectors, SchemaMetadataTransform>> schemaMetadataTransformers;
     private transient ListState<byte[]> state;
-    private final List<Tuple2<String, String>> udfFunctions;
+    private final List<Tuple3<String, String, Map<String, String>>> udfFunctions;
     private List<UserDefinedFunctionDescriptor> udfDescriptors;
     private Map<TableId, PreTransformProcessor> preTransformProcessorMap;
     private Map<TableId, Boolean> hasAsteriskMap;
@@ -82,7 +83,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>

     public static class Builder {
         private final List<TransformRule> transformRules = new ArrayList<>();
-        private final List<Tuple2<String, String>> udfFunctions = new ArrayList<>();
+        private final List<Tuple3<String, String, Map<String, String>>> udfFunctions =
+                new ArrayList<>();

         public PreTransformOperator.Builder addTransform(
                 String tableInclusions, @Nullable String projection, @Nullable String filter) {
@@ -109,7 +111,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
         }

         public PreTransformOperator.Builder addUdfFunctions(
-                List<Tuple2<String, String>> udfFunctions) {
+                List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
             this.udfFunctions.addAll(udfFunctions);
             return this;
         }
@@ -120,7 +122,8 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
     }

     private PreTransformOperator(
-            List<TransformRule> transformRules, List<Tuple2<String, String>> udfFunctions) {
+            List<TransformRule> transformRules,
+            List<Tuple3<String, String, Map<String, String>>> udfFunctions) {
         this.transformRules = transformRules;
         this.preTransformChangeInfoMap = new ConcurrentHashMap<>();
         this.preTransformProcessorMap = new ConcurrentHashMap<>();
@@ -137,7 +140,7 @@ public class PreTransformOperator extends AbstractStreamOperator<Event>
         super.setup(containingTask, config, output);
         this.udfDescriptors =
                 this.udfFunctions.stream()
-                        .map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1))
+                        .map(udf -> new UserDefinedFunctionDescriptor(udf.f0, udf.f1, udf.f2))
                         .collect(Collectors.toList());

         // Initialize data fields in advance because they might be accessed in
diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java
index 547d9cd87..6d697652e 100644
--- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java
+++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/TransformParser.java
@@ -36,6 +36,7 @@ import org.apache.calcite.prepare.CalciteCatalogReader;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rel.type.RelDataTypeSystem;
 import org.apache.calcite.rex.RexBuilder;
@@ -123,10 +124,11 @@ public class TransformParser {
             if (udf.getReturnTypeHint() != null) {
                 // This UDF has return type hint annotation
                 returnTypeInference =
-                        o ->
-                                o.getTypeFactory()
-                                        .createSqlType(
-                                                convertCalciteType(udf.getReturnTypeHint()));
+                        o -> {
+                            RelDataTypeFactory typeFactory = o.getTypeFactory();
+                            DataType returnTypeHint = udf.getReturnTypeHint();
+                            return convertCalciteType(typeFactory, returnTypeHint);
+                        };
             } else {
                 // Infer it from eval method return type
                 returnTypeInference = o -> function.getReturnType(o.getTypeFactory());
diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java
index 986337c8a..bb2c3503d 100644
--- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java
+++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformSqlOperatorTable.java
@@ -310,4 +310,36 @@ public class TransformSqlOperatorTable extends ReflectiveSqlOperatorTable {
     // Cast Functions
     // --------------
     public static final SqlFunction CAST = SqlStdOperatorTable.CAST;
+
+    public static final SqlFunction AI_CHAT_PREDICT =
+            new SqlFunction(
+                    "AI_CHAT_PREDICT",
+                    SqlKind.OTHER_FUNCTION,
+                    ReturnTypes.explicit(SqlTypeName.VARCHAR),
+                    null,
+                    OperandTypes.family(
+                            SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
+                    SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+    // Define the GET_EMBEDDING function
+    public static final SqlFunction GET_EMBEDDING =
+            new SqlFunction(
+                    "GET_EMBEDDING",
+                    SqlKind.OTHER_FUNCTION,
+                    ReturnTypes.explicit(SqlTypeName.VARCHAR),
+                    null,
+                    OperandTypes.family(
+                            SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING),
+                    SqlFunctionCategory.USER_DEFINED_FUNCTION);
+
+    // Define the AI_LANGCHAIN_PREDICT function
+    public static final SqlFunction AI_LANGCHAIN_PREDICT =
+            new SqlFunction(
"AI_LANGCHAIN_PREDICT", + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.VARCHAR), + null, + OperandTypes.family( + SqlTypeFamily.STRING, SqlTypeFamily.STRING, SqlTypeFamily.STRING), + SqlFunctionCategory.USER_DEFINED_FUNCTION); } diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformTable.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformTable.java index a47146eb8..814ed74e4 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformTable.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/parser/metadata/TransformTable.java @@ -51,8 +51,7 @@ public class TransformTable extends AbstractTable { for (Column column : columns) { names.add(column.getName()); RelDataType sqlType = - relDataTypeFactory.createSqlType( - DataTypeConverter.convertCalciteType(column.getType())); + DataTypeConverter.convertCalciteType(relDataTypeFactory, column.getType()); types.add(sqlType); } return relDataTypeFactory.createStructType(Pair.zip(names, types)); diff --git a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/typeutils/DataTypeConverter.java b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/typeutils/DataTypeConverter.java index 78f7d448b..8da220293 100644 --- a/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/typeutils/DataTypeConverter.java +++ b/flink-cdc-runtime/src/main/java/org/apache/flink/cdc/runtime/typeutils/DataTypeConverter.java @@ -17,18 +17,27 @@ package org.apache.flink.cdc.runtime.typeutils; +import org.apache.flink.cdc.common.data.ArrayData; import org.apache.flink.cdc.common.data.DecimalData; +import org.apache.flink.cdc.common.data.GenericArrayData; +import org.apache.flink.cdc.common.data.GenericMapData; import org.apache.flink.cdc.common.data.LocalZonedTimestampData; +import org.apache.flink.cdc.common.data.MapData; import org.apache.flink.cdc.common.data.TimestampData; import org.apache.flink.cdc.common.data.binary.BinaryStringData; import org.apache.flink.cdc.common.schema.Column; +import org.apache.flink.cdc.common.types.ArrayType; import org.apache.flink.cdc.common.types.BinaryType; import org.apache.flink.cdc.common.types.DataType; import org.apache.flink.cdc.common.types.DataTypes; +import org.apache.flink.cdc.common.types.DecimalType; +import org.apache.flink.cdc.common.types.MapType; import org.apache.flink.cdc.common.types.RowType; +import org.apache.flink.cdc.common.types.TimestampType; import org.apache.flink.cdc.common.types.VarBinaryType; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.type.SqlTypeName; import java.math.BigDecimal; @@ -39,8 +48,11 @@ import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; /** A data type converter. 
 public class DataTypeConverter {
 
@@ -90,50 +102,67 @@ public class DataTypeConverter {
             case ROW:
                 return Object.class;
             case ARRAY:
+                return ArrayData.class;
             case MAP:
+                return MapData.class;
             default:
                 throw new UnsupportedOperationException("Unsupported type: " + dataType);
         }
     }
 
-    public static SqlTypeName convertCalciteType(DataType dataType) {
+    public static RelDataType convertCalciteType(
+            RelDataTypeFactory typeFactory, DataType dataType) {
         switch (dataType.getTypeRoot()) {
             case BOOLEAN:
-                return SqlTypeName.BOOLEAN;
+                return typeFactory.createSqlType(SqlTypeName.BOOLEAN);
             case TINYINT:
-                return SqlTypeName.TINYINT;
+                return typeFactory.createSqlType(SqlTypeName.TINYINT);
             case SMALLINT:
-                return SqlTypeName.SMALLINT;
+                return typeFactory.createSqlType(SqlTypeName.SMALLINT);
             case INTEGER:
-                return SqlTypeName.INTEGER;
+                return typeFactory.createSqlType(SqlTypeName.INTEGER);
             case BIGINT:
-                return SqlTypeName.BIGINT;
+                return typeFactory.createSqlType(SqlTypeName.BIGINT);
             case DATE:
-                return SqlTypeName.DATE;
+                return typeFactory.createSqlType(SqlTypeName.DATE);
             case TIME_WITHOUT_TIME_ZONE:
-                return SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE;
+                return typeFactory.createSqlType(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE);
             case TIMESTAMP_WITHOUT_TIME_ZONE:
-                return SqlTypeName.TIMESTAMP;
+                return typeFactory.createSqlType(SqlTypeName.TIMESTAMP);
             case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
-                return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
+                return typeFactory.createSqlType(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE);
             case FLOAT:
-                return SqlTypeName.FLOAT;
+                return typeFactory.createSqlType(SqlTypeName.FLOAT);
             case DOUBLE:
-                return SqlTypeName.DOUBLE;
+                return typeFactory.createSqlType(SqlTypeName.DOUBLE);
             case CHAR:
-                return SqlTypeName.CHAR;
+                return typeFactory.createSqlType(SqlTypeName.CHAR);
             case VARCHAR:
-                return SqlTypeName.VARCHAR;
+                return typeFactory.createSqlType(SqlTypeName.VARCHAR);
             case BINARY:
-                return SqlTypeName.BINARY;
+                return typeFactory.createSqlType(SqlTypeName.BINARY);
             case VARBINARY:
-                return SqlTypeName.VARBINARY;
+                return typeFactory.createSqlType(SqlTypeName.VARBINARY);
             case DECIMAL:
-                return SqlTypeName.DECIMAL;
+                return typeFactory.createSqlType(SqlTypeName.DECIMAL);
             case ROW:
-                return SqlTypeName.ROW;
+                List<RelDataType> dataTypes =
+                        ((RowType) dataType)
+                                .getFieldTypes().stream()
+                                        .map((type) -> convertCalciteType(typeFactory, type))
+                                        .collect(Collectors.toList());
+                return typeFactory.createStructType(
+                        dataTypes, ((RowType) dataType).getFieldNames());
             case ARRAY:
+                DataType elementType = ((ArrayType) dataType).getElementType();
+                return typeFactory.createArrayType(
+                        convertCalciteType(typeFactory, elementType), -1);
             case MAP:
+                RelDataType keyType =
+                        convertCalciteType(typeFactory, ((MapType) dataType).getKeyType());
+                RelDataType valueType =
+                        convertCalciteType(typeFactory, ((MapType) dataType).getValueType());
+                return typeFactory.createMapType(keyType, valueType);
             default:
                 throw new UnsupportedOperationException("Unsupported type: " + dataType);
         }
     }
@@ -173,9 +202,16 @@ public class DataTypeConverter {
                 return DataTypes.VARBINARY(VarBinaryType.MAX_LENGTH);
             case DECIMAL:
                 return DataTypes.DECIMAL(relDataType.getPrecision(), relDataType.getScale());
-            case ROW:
             case ARRAY:
+                RelDataType componentType = relDataType.getComponentType();
+                return DataTypes.ARRAY(convertCalciteRelDataTypeToDataType(componentType));
             case MAP:
+                RelDataType keyType = relDataType.getKeyType();
+                RelDataType valueType = relDataType.getValueType();
+                return DataTypes.MAP(
+                        convertCalciteRelDataTypeToDataType(keyType),
+                        convertCalciteRelDataTypeToDataType(valueType));
+            case ROW:
             default:
                 throw new UnsupportedOperationException(
                         "Unsupported type: " + relDataType.getSqlTypeName());
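Returning a fully built `RelDataType` from the factory (instead of a bare `SqlTypeName`, as in the hunk above) is what makes nested types expressible, since ARRAY, MAP, and ROW need element types that only a `RelDataTypeFactory` can assemble. A short sketch of the round trip, under the assumption that Calcite's `JavaTypeFactoryImpl` is an acceptable stand-in for the planner's factory:

```java
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.runtime.typeutils.DataTypeConverter;

public class NestedTypeExample {
    public static void main(String[] args) {
        RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);

        // ARRAY<MAP<STRING, INT>> is impossible to express as a single
        // SqlTypeName, but straightforward as a factory-built RelDataType.
        RelDataType relType =
                DataTypeConverter.convertCalciteType(
                        typeFactory,
                        DataTypes.ARRAY(DataTypes.MAP(DataTypes.STRING(), DataTypes.INT())));
        System.out.println(relType.getFullTypeString());

        // And back again: Calcite RelDataType -> CDC DataType, via the new
        // ARRAY/MAP branches added in the hunk above.
        System.out.println(DataTypeConverter.convertCalciteRelDataTypeToDataType(relType));
    }
}
```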
@@ -220,7 +256,9 @@ public class DataTypeConverter {
             case ROW:
                 return value;
             case ARRAY:
+                return convertToArray(value, (ArrayType) dataType);
             case MAP:
+                return convertToMap(value, (MapType) dataType);
             default:
                 throw new UnsupportedOperationException("Unsupported type: " + dataType);
         }
     }
@@ -264,7 +302,9 @@ public class DataTypeConverter {
             case ROW:
                 return value;
             case ARRAY:
+                return convertToArrayOriginal(value, (ArrayType) dataType);
             case MAP:
+                return convertToMapOriginal(value, (MapType) dataType);
             default:
                 throw new UnsupportedOperationException("Unsupported type: " + dataType);
         }
     }
@@ -378,6 +418,101 @@ public class DataTypeConverter {
         return toLocalTime(obj).toSecondOfDay() * 1000;
     }
 
+    private static Object convertToArray(Object obj, ArrayType arrayType) {
+        if (obj instanceof ArrayData) {
+            return obj;
+        }
+        if (obj instanceof List) {
+            List<?> list = (List<?>) obj;
+            GenericArrayData arrayData = new GenericArrayData(list.toArray());
+            return arrayData;
+        }
+        if (obj.getClass().isArray()) {
+            return new GenericArrayData((Object[]) obj);
+        }
+        throw new IllegalArgumentException("Unable to convert to ArrayData: " + obj);
+    }
+
+    private static Object convertToArrayOriginal(Object obj, ArrayType arrayType) {
+        if (obj instanceof ArrayData) {
+            ArrayData arrayData = (ArrayData) obj;
+            Object[] result = new Object[arrayData.size()];
+            for (int i = 0; i < arrayData.size(); i++) {
+                result[i] = getArrayElement(arrayData, i, arrayType.getElementType());
+            }
+            return result;
+        }
+        return obj;
+    }
+
+    private static Object getArrayElement(ArrayData arrayData, int pos, DataType elementType) {
+        switch (elementType.getTypeRoot()) {
+            case BOOLEAN:
+                return arrayData.getBoolean(pos);
+            case TINYINT:
+                return arrayData.getByte(pos);
+            case SMALLINT:
+                return arrayData.getShort(pos);
+            case INTEGER:
+                return arrayData.getInt(pos);
+            case BIGINT:
+                return arrayData.getLong(pos);
+            case FLOAT:
+                return arrayData.getFloat(pos);
+            case DOUBLE:
+                return arrayData.getDouble(pos);
+            case CHAR:
+            case VARCHAR:
+                return arrayData.getString(pos);
+            case DECIMAL:
+                return arrayData.getDecimal(
+                        pos,
+                        ((DecimalType) elementType).getPrecision(),
+                        ((DecimalType) elementType).getScale());
+            case DATE:
+                return arrayData.getInt(pos);
+            case TIME_WITHOUT_TIME_ZONE:
+                return arrayData.getInt(pos);
+            case TIMESTAMP_WITHOUT_TIME_ZONE:
+                return arrayData.getTimestamp(pos, ((TimestampType) elementType).getPrecision());
+            case ARRAY:
+                return convertToArrayOriginal(arrayData.getArray(pos), (ArrayType) elementType);
+            case MAP:
+                return convertToMapOriginal(arrayData.getMap(pos), (MapType) elementType);
+            default:
+                throw new UnsupportedOperationException(
+                        "Unsupported array element type: " + elementType);
+        }
+    }
+
+    private static Object convertToMap(Object obj, MapType mapType) {
+        if (obj instanceof MapData) {
+            return obj;
+        }
+        if (obj instanceof Map) {
+            Map<?, ?> javaMap = (Map<?, ?>) obj;
+            GenericMapData mapData = new GenericMapData(javaMap);
+            return mapData;
+        }
+        throw new IllegalArgumentException("Unable to convert to MapData: " + obj);
+    }
+
+    private static Object convertToMapOriginal(Object obj, MapType mapType) {
+        if (obj instanceof MapData) {
+            MapData mapData = (MapData) obj;
+            Map<Object, Object> result = new HashMap<>();
+            ArrayData keyArray = mapData.keyArray();
+            ArrayData valueArray = mapData.valueArray();
+            for (int i = 0; i < mapData.size(); i++) {
+                Object key = getArrayElement(keyArray, i, mapType.getKeyType());
+                Object value = getArrayElement(valueArray, i, mapType.getValueType());
+                result.put(key, value);
+            }
+            return result;
+        }
+        return obj;
+    }
+
     private static LocalTime toLocalTime(Object obj) {
         if (obj == null) {
             return null;
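These helpers bridge plain Java collections (the shape a UDF such as `OpenAIEmbeddingModel` can return) and the internal `ArrayData`/`MapData` representations used by `BinaryRecordData`. A rough, self-contained illustration with assumed inputs (not code from the patch), assuming flink-cdc-common is on the classpath:

```java
import java.util.HashMap;
import java.util.Map;

import org.apache.flink.cdc.common.data.GenericArrayData;
import org.apache.flink.cdc.common.data.GenericMapData;

public class ArrayMapDataExample {
    public static void main(String[] args) {
        // A Java array (or java.util.List) becomes internal ArrayData,
        // mirroring what convertToArray does with a UDF return value...
        GenericArrayData embedding = new GenericArrayData(new Object[] {0.12f, 0.34f});
        System.out.println(embedding.getFloat(0)); // read path used by getArrayElement

        // ...and a java.util.Map becomes internal MapData, as in convertToMap.
        Map<Object, Object> javaMap = new HashMap<>();
        javaMap.put("k", 1);
        GenericMapData mapData = new GenericMapData(javaMap);
        System.out.println(mapData.size());
    }
}
```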
diff --git a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java
index 3f4bc366a..9d5c8eb20 100644
--- a/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java
+++ b/flink-cdc-runtime/src/test/java/org/apache/flink/cdc/runtime/operators/transform/UserDefinedFunctionDescriptorTest.java
@@ -20,8 +20,10 @@ package org.apache.flink.cdc.runtime.operators.transform;
 import org.apache.flink.cdc.common.types.DataType;
 import org.apache.flink.cdc.common.types.DataTypes;
 import org.apache.flink.cdc.common.udf.UserDefinedFunction;
+import org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel;
 import org.apache.flink.table.functions.ScalarFunction;
 
+import com.fasterxml.jackson.core.JsonProcessingException;
 import org.junit.jupiter.api.Test;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -48,7 +50,7 @@ public class UserDefinedFunctionDescriptorTest {
     public static class NotUDF {}
 
     @Test
-    void testUserDefinedFunctionDescriptor() {
+    void testUserDefinedFunctionDescriptor() throws JsonProcessingException {
 
         assertThat(new UserDefinedFunctionDescriptor("cdc_udf", CdcUdf.class.getName()))
                 .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
@@ -93,5 +95,14 @@ public class UserDefinedFunctionDescriptorTest {
                                         "not_even_exist", "not.a.valid.class.path"))
                 .isExactlyInstanceOf(IllegalArgumentException.class)
                 .hasMessage("Failed to instantiate UDF not_even_exist@not.a.valid.class.path");
+        String name = "GET_EMBEDDING";
+        assertThat(new UserDefinedFunctionDescriptor(name, OpenAIEmbeddingModel.class.getName()))
+                .extracting("name", "className", "classpath", "returnTypeHint", "isCdcPipelineUdf")
+                .containsExactly(
+                        "GET_EMBEDDING",
+                        "OpenAIEmbeddingModel",
+                        "org.apache.flink.cdc.runtime.model.OpenAIEmbeddingModel",
+                        DataTypes.ARRAY(DataTypes.FLOAT()),
+                        true);
     }
 }
diff --git a/pom.xml b/pom.xml
index 314d7b1c5..fdef5e72a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -41,6 +41,7 @@ limitations under the License.
         <module>flink-cdc-runtime</module>
         <module>flink-cdc-e2e-tests</module>
         <module>flink-cdc-pipeline-udf-examples</module>
+        <module>flink-cdc-pipeline-model</module>