|
|
|
@ -15,12 +15,12 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
package com.alibaba.cloud.ai.tongyi.chat;
|
|
|
|
|
|
|
|
|
|
import java.util.HashSet;
|
|
|
|
|
import java.util.List;
|
|
|
|
|
import java.util.Objects;
|
|
|
|
|
import java.util.Set;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import com.alibaba.cloud.ai.tongyi.exception.TongYiException;
|
|
|
|
|
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
|
|
|
|
|
import com.alibaba.dashscope.aigc.generation.Generation;
|
|
|
|
@ -41,11 +41,11 @@ import org.slf4j.LoggerFactory;
|
|
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
import reactor.core.scheduler.Schedulers;
|
|
|
|
|
|
|
|
|
|
import org.springframework.ai.chat.ChatClient;
|
|
|
|
|
import org.springframework.ai.chat.ChatResponse;
|
|
|
|
|
import org.springframework.ai.chat.StreamingChatClient;
|
|
|
|
|
import org.springframework.ai.chat.messages.Message;
|
|
|
|
|
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
|
|
|
|
import org.springframework.ai.chat.model.ChatModel;
|
|
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
|
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
|
|
import org.springframework.ai.model.ModelOptionsUtils;
|
|
|
|
@ -54,14 +54,15 @@ import org.springframework.ai.model.function.FunctionCallbackContext;
|
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
|
import org.springframework.util.CollectionUtils;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal Alibaba DashScope}
|
|
|
|
|
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
|
|
|
|
|
* backed by {@link Generation}.
|
|
|
|
|
*
|
|
|
|
|
* @author yuluo
|
|
|
|
|
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
|
|
|
|
* @since 2023.0.0.0-RC1
|
|
|
|
|
* @see ChatClient
|
|
|
|
|
* @see ChatModel
|
|
|
|
|
* @see com.alibaba.dashscope.aigc.generation
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
@ -70,7 +71,7 @@ public class TongYiChatClient extends
|
|
|
|
|
com.alibaba.dashscope.common.Message,
|
|
|
|
|
ConversationParam,
|
|
|
|
|
GenerationResult>
|
|
|
|
|
implements ChatClient, StreamingChatClient {
|
|
|
|
|
implements ChatModel, StreamingChatModel {
|
|
|
|
|
|
|
|
|
|
private static final Logger logger = LoggerFactory.getLogger(TongYiChatClient.class);
|
|
|
|
|
|
|
|
|
@ -157,13 +158,13 @@ public class TongYiChatClient extends
|
|
|
|
|
|
|
|
|
|
msgManager.add(chatCompletions);
|
|
|
|
|
|
|
|
|
|
List<org.springframework.ai.chat.Generation> generations =
|
|
|
|
|
List<org.springframework.ai.chat.model.Generation> generations =
|
|
|
|
|
chatCompletions
|
|
|
|
|
.getOutput()
|
|
|
|
|
.getChoices()
|
|
|
|
|
.stream()
|
|
|
|
|
.map(choice ->
|
|
|
|
|
new org.springframework.ai.chat.Generation(
|
|
|
|
|
new org.springframework.ai.chat.model.Generation(
|
|
|
|
|
choice
|
|
|
|
|
.getMessage()
|
|
|
|
|
.getContent()
|
|
|
|
@ -201,7 +202,7 @@ public class TongYiChatClient extends
|
|
|
|
|
.getMessage()
|
|
|
|
|
.getContent())
|
|
|
|
|
.map(content -> {
|
|
|
|
|
var gen = new org.springframework.ai.chat.Generation(content)
|
|
|
|
|
var gen = new org.springframework.ai.chat.model.Generation(content)
|
|
|
|
|
.withGenerationMetadata(generateChoiceMetadata(
|
|
|
|
|
message.getOutput()
|
|
|
|
|
.getChoices()
|
|
|
|
@ -449,6 +450,20 @@ public class TongYiChatClient extends
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
|
|
|
|
|
final Flowable<GenerationResult> genRes;
|
|
|
|
|
try {
|
|
|
|
|
genRes = generation.streamCall(request);
|
|
|
|
|
}
|
|
|
|
|
catch (NoApiKeyException | InputRequiredException e) {
|
|
|
|
|
logger.warn("TongYi chat client: " + e.getMessage());
|
|
|
|
|
throw new TongYiException(e.getMessage());
|
|
|
|
|
}
|
|
|
|
|
return Flux.from(genRes);
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected boolean isToolFunctionCall(GenerationResult response) {
|
|
|
|
|
|
|
|
|
|