[FLINK-37209][transform] Add built-in Qwen model for transform.
parent
36da15a500
commit
1f1f13eb73
@ -0,0 +1,71 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-cdc-pipeline-model</artifactId>
|
||||
<version>${revision}</version>
|
||||
</parent>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<artifactId>flink-cdc-pipeline-model-openai</artifactId>
|
||||
|
||||
<name>flink-cdc-pipeline-model-openai</name>
|
||||
|
||||
<properties>
|
||||
<langchain4j.version>0.23.0</langchain4j.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||
<artifactId>service</artifactId>
|
||||
<version>0.12.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test-jar</id>
|
||||
<goals>
|
||||
<goal>test-jar</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
@ -0,0 +1,66 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-cdc-pipeline-model</artifactId>
|
||||
<version>${revision}</version>
|
||||
</parent>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<artifactId>flink-cdc-pipeline-model-qwen</artifactId>
|
||||
|
||||
<name>flink-cdc-pipeline-model-openai</name>
|
||||
|
||||
<properties>
|
||||
<langchain4j.version>0.23.0</langchain4j.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test-jar</id>
|
||||
<goals>
|
||||
<goal>test-jar</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.flink.cdc.runtime.model;
|
||||
|
||||
import org.apache.flink.cdc.common.configuration.ConfigOption;
|
||||
import org.apache.flink.cdc.common.configuration.ConfigOptions;
|
||||
|
||||
/** Options of built-in qwen model. */
|
||||
public class ModelOptions {
|
||||
|
||||
// Options for Qwen Model.
|
||||
public static final ConfigOption<String> QWEN_MODEL_NAME =
|
||||
ConfigOptions.key("qwen.model")
|
||||
.stringType()
|
||||
.noDefaultValue()
|
||||
.withDescription("Name of model to be called.");
|
||||
|
||||
public static final ConfigOption<String> QWEN_API_KEY =
|
||||
ConfigOptions.key("qwen.apikey")
|
||||
.stringType()
|
||||
.noDefaultValue()
|
||||
.withDescription("Api Key for verification of the Model server.");
|
||||
|
||||
public static final ConfigOption<String> QWEN_CHAT_PROMPT =
|
||||
ConfigOptions.key("qwen.chat.prompt")
|
||||
.stringType()
|
||||
.noDefaultValue()
|
||||
.withDescription("Prompt for chat using OpenAI.");
|
||||
}
|
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.flink.cdc.runtime.model;
|
||||
|
||||
import org.apache.flink.cdc.common.configuration.Configuration;
|
||||
import org.apache.flink.cdc.common.types.DataType;
|
||||
import org.apache.flink.cdc.common.types.DataTypes;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
|
||||
import org.apache.flink.cdc.common.utils.Preconditions;
|
||||
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.apache.flink.cdc.runtime.model.ModelOptions.QWEN_API_KEY;
|
||||
import static org.apache.flink.cdc.runtime.model.ModelOptions.QWEN_CHAT_PROMPT;
|
||||
import static org.apache.flink.cdc.runtime.model.ModelOptions.QWEN_MODEL_NAME;
|
||||
|
||||
/**
|
||||
* A {@link UserDefinedFunction} that use Model defined by Qwen to generate text, refer to <a
|
||||
* href="https://docs.langchain4j.dev/integrations/language-models/dashscope">docs</a>}.
|
||||
*/
|
||||
public class QwenChatModel implements UserDefinedFunction {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(QwenChatModel.class);
|
||||
|
||||
private dev.langchain4j.model.dashscope.QwenChatModel chatModel;
|
||||
|
||||
private String modelName;
|
||||
|
||||
private String prompt;
|
||||
|
||||
public String eval(String input) {
|
||||
return chat(input);
|
||||
}
|
||||
|
||||
private String chat(String input) {
|
||||
if (input == null || input.trim().isEmpty()) {
|
||||
LOG.warn("Empty or null input provided for embedding.");
|
||||
return "";
|
||||
}
|
||||
if (prompt != null) {
|
||||
input = prompt + ": " + input;
|
||||
}
|
||||
return chatModel
|
||||
.generate(Collections.singletonList(new UserMessage(input)))
|
||||
.content()
|
||||
.text();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getReturnType() {
|
||||
return DataTypes.STRING();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
|
||||
Configuration modelOptions = userDefinedFunctionContext.configuration();
|
||||
this.modelName = modelOptions.get(QWEN_MODEL_NAME);
|
||||
Preconditions.checkNotNull(modelName, QWEN_MODEL_NAME.key() + " should not be empty.");
|
||||
String apiKey = modelOptions.get(QWEN_API_KEY);
|
||||
Preconditions.checkNotNull(apiKey, QWEN_API_KEY.key() + " should not be empty.");
|
||||
this.prompt = modelOptions.get(QWEN_CHAT_PROMPT);
|
||||
this.chatModel =
|
||||
dev.langchain4j.model.dashscope.QwenChatModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "QwenChatModel{"
|
||||
+ "chatModel="
|
||||
+ chatModel
|
||||
+ ", modelName='"
|
||||
+ modelName
|
||||
+ '\''
|
||||
+ ", prompt='"
|
||||
+ prompt
|
||||
+ '\''
|
||||
+ '}';
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
LOG.info("Closed OpenAIChatModel " + modelName);
|
||||
}
|
||||
}
|
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.flink.cdc.runtime.model;
|
||||
|
||||
import org.apache.flink.cdc.common.configuration.Configuration;
|
||||
import org.apache.flink.cdc.common.data.ArrayData;
|
||||
import org.apache.flink.cdc.common.data.GenericArrayData;
|
||||
import org.apache.flink.cdc.common.types.DataType;
|
||||
import org.apache.flink.cdc.common.types.DataTypes;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunction;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
|
||||
import org.apache.flink.cdc.common.utils.Preconditions;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.flink.cdc.runtime.model.ModelOptions.QWEN_API_KEY;
|
||||
import static org.apache.flink.cdc.runtime.model.ModelOptions.QWEN_MODEL_NAME;
|
||||
|
||||
/**
|
||||
* A {@link UserDefinedFunction} that use Model defined by Qwen to generate vector data, refer to <a
|
||||
* href="https://docs.langchain4j.dev/integrations/embedding-models/dashscope">docs</a>}.
|
||||
*/
|
||||
public class QwenEmbeddingModel implements UserDefinedFunction {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(QwenEmbeddingModel.class);
|
||||
|
||||
private String modelName;
|
||||
|
||||
private dev.langchain4j.model.dashscope.QwenEmbeddingModel embeddingModel;
|
||||
|
||||
public ArrayData eval(String input) {
|
||||
return getEmbedding(input);
|
||||
}
|
||||
|
||||
private ArrayData getEmbedding(String input) {
|
||||
if (input == null || input.trim().isEmpty()) {
|
||||
LOG.debug("Empty or null input provided for embedding.");
|
||||
return new GenericArrayData(new Float[0]);
|
||||
}
|
||||
|
||||
TextSegment textSegment = new TextSegment(input, new Metadata());
|
||||
|
||||
List<Embedding> embeddings =
|
||||
embeddingModel.embedAll(Collections.singletonList(textSegment)).content();
|
||||
|
||||
if (embeddings != null && !embeddings.isEmpty()) {
|
||||
List<Float> embeddingList = embeddings.get(0).vectorAsList();
|
||||
Float[] embeddingArray = embeddingList.toArray(new Float[0]);
|
||||
return new GenericArrayData(embeddingArray);
|
||||
} else {
|
||||
LOG.warn("No embedding results returned for input: {}", input);
|
||||
return new GenericArrayData(new Float[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getReturnType() {
|
||||
return DataTypes.ARRAY(DataTypes.FLOAT());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void open(UserDefinedFunctionContext userDefinedFunctionContext) {
|
||||
Configuration modelOptions = userDefinedFunctionContext.configuration();
|
||||
this.modelName = modelOptions.get(QWEN_MODEL_NAME);
|
||||
Preconditions.checkNotNull(modelName, QWEN_MODEL_NAME.key() + " should not be empty.");
|
||||
String apiKey = modelOptions.get(QWEN_API_KEY);
|
||||
Preconditions.checkNotNull(apiKey, QWEN_API_KEY.key() + " should not be empty.");
|
||||
LOG.info("Opening QwenEmbeddingModel " + modelName);
|
||||
this.embeddingModel =
|
||||
dev.langchain4j.model.dashscope.QwenEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
LOG.info("Closed OpenAIEmbeddingModel " + modelName);
|
||||
}
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.flink.cdc.runtime.model;
|
||||
|
||||
import org.apache.flink.cdc.common.configuration.Configuration;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/** A test for {@link QwenChatModel}. */
|
||||
public class TestQwenChatModel {
|
||||
@Test
|
||||
@Disabled("For manual test as there is a limit for quota.")
|
||||
public void testEval() {
|
||||
QwenChatModel qwenChatModel = new QwenChatModel();
|
||||
Configuration configuration = new Configuration();
|
||||
configuration.set(ModelOptions.QWEN_API_KEY, "Your_API_KEY");
|
||||
configuration.set(ModelOptions.QWEN_MODEL_NAME, "qwen-plus");
|
||||
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
|
||||
qwenChatModel.open(userDefinedFunctionContext);
|
||||
String response = qwenChatModel.eval("Who invented the electric light?");
|
||||
Assertions.assertFalse(response.isEmpty());
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.flink.cdc.runtime.model;
|
||||
|
||||
import org.apache.flink.cdc.common.configuration.Configuration;
|
||||
import org.apache.flink.cdc.common.data.ArrayData;
|
||||
import org.apache.flink.cdc.common.udf.UserDefinedFunctionContext;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
/** A test for {@link QwenEmbeddingModel}. */
|
||||
public class TestQwenEmbeddingModel {
|
||||
|
||||
@Test
|
||||
@Disabled("For manual test as there is a limit for quota.")
|
||||
public void testEval() {
|
||||
QwenEmbeddingModel qwenEmbeddingModel = new QwenEmbeddingModel();
|
||||
Configuration configuration = new Configuration();
|
||||
configuration.set(ModelOptions.QWEN_API_KEY, "Your_API_KEY");
|
||||
configuration.set(ModelOptions.QWEN_MODEL_NAME, "text-embedding-v1");
|
||||
UserDefinedFunctionContext userDefinedFunctionContext = () -> configuration;
|
||||
qwenEmbeddingModel.open(userDefinedFunctionContext);
|
||||
ArrayData arrayData =
|
||||
qwenEmbeddingModel.eval("Flink CDC is a streaming data integration tool");
|
||||
Assertions.assertNotNull(arrayData);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue