Upgrade spring ai version to 1.0.0-M1 (#3759)

pull/3762/head
zhangqian9158 8 months ago committed by GitHub
parent c21921d522
commit 0d81c19f6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -27,7 +27,7 @@
<rocketmq.version>5.1.4</rocketmq.version>
<!-- Spring AI -->
<spring.ai.version>0.8.1</spring.ai.version>
<spring.ai.version>1.0.0-M1</spring.ai.version>
<dashscope-sdk-java.version>2.14.0</dashscope-sdk-java.version>
<!-- scheduling -->

@ -25,9 +25,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@ -47,15 +47,15 @@ public class TongYiSimpleServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ChatClient chatClient;
private final ChatModel chatModel;
private final StreamingChatClient streamingChatClient;
private final StreamingChatModel streamingChatModel;
@Autowired
public TongYiSimpleServiceImpl(ChatClient chatClient, StreamingChatClient streamingChatClient) {
public TongYiSimpleServiceImpl(ChatModel chatModel, StreamingChatModel streamingChatModel) {
this.chatClient = chatClient;
this.streamingChatClient = streamingChatClient;
this.chatModel = chatModel;
this.streamingChatModel = streamingChatModel;
}
@Override
@ -63,7 +63,7 @@ public class TongYiSimpleServiceImpl extends AbstractTongYiServiceImpl {
Prompt prompt = new Prompt(new UserMessage(message));
return chatClient.call(prompt).getResult().getOutput().getContent();
return chatModel.call(prompt).getResult().getOutput().getContent();
}
@Override
@ -71,7 +71,7 @@ public class TongYiSimpleServiceImpl extends AbstractTongYiServiceImpl {
StringBuilder fullContent = new StringBuilder();
streamingChatClient.stream(new Prompt(message))
streamingChatModel.stream(new Prompt(message))
.flatMap(chatResponse -> Flux.fromIterable(chatResponse.getResults()))
.map(content -> content.getOutput().getContent())
.doOnNext(fullContent::append)

@ -22,7 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.beans.factory.annotation.Autowired;
@ -42,12 +42,12 @@ public class TongYiImagesServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ImageClient imageClient;
private final ImageModel imageModel;
@Autowired
public TongYiImagesServiceImpl(ImageClient client) {
public TongYiImagesServiceImpl(ImageModel imageModel) {
this.imageClient = client;
this.imageModel = imageModel;
}
@Override
@ -55,7 +55,7 @@ public class TongYiImagesServiceImpl extends AbstractTongYiServiceImpl {
var prompt = new ImagePrompt(imgPrompt);
return imageClient.call(prompt);
return imageModel.call(prompt);
}
}

@ -25,8 +25,8 @@ import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.parser.BeanOutputParser;
@ -47,10 +47,10 @@ public class TongYiOutputParseServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ChatClient chatClient;
private final ChatModel chatModel;
public TongYiOutputParseServiceImpl(ChatClient chatClient) {
this.chatClient = chatClient;
public TongYiOutputParseServiceImpl(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Override
@ -66,7 +66,7 @@ public class TongYiOutputParseServiceImpl extends AbstractTongYiServiceImpl {
""";
PromptTemplate promptTemplate = new PromptTemplate(userMessage, Map.of("actor", actor, "format", format));
Prompt prompt = promptTemplate.create();
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
// {@link BeanOutputParser#getFormat}
// simple solve.

@ -24,8 +24,8 @@ import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Value;
@ -47,13 +47,13 @@ public class TongYiPromptTemplateServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ChatClient chatClient;
private final ChatModel chatModel;
@Value("classpath:/prompts/joke-prompt.st")
private Resource jokeResource;
public TongYiPromptTemplateServiceImpl(ChatClient chatClient) {
this.chatClient = chatClient;
public TongYiPromptTemplateServiceImpl(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Override
@ -62,6 +62,6 @@ public class TongYiPromptTemplateServiceImpl extends AbstractTongYiServiceImpl {
PromptTemplate promptTemplate = new PromptTemplate(jokeResource);
Prompt prompt = promptTemplate.create(Map.of("adjective", adjective, "topic", topic));
return chatClient.call(prompt).getResult().getOutput();
return chatModel.call(prompt).getResult().getOutput();
}
}

@ -25,9 +25,9 @@ import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Value;
@ -46,10 +46,10 @@ public class TongYiRolesServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ChatClient chatClient;
private final ChatModel chatModel;
public TongYiRolesServiceImpl(ChatClient chatClient) {
this.chatClient = chatClient;
public TongYiRolesServiceImpl(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Value("classpath:/prompts/assistant-message.st")
@ -74,6 +74,6 @@ public class TongYiRolesServiceImpl extends AbstractTongYiServiceImpl {
Prompt prompt = new Prompt(List.of(systemPromptTemplateMessage, userMessage));
return chatClient.call(prompt).getResult().getOutput();
return chatModel.call(prompt).getResult().getOutput();
}
}

@ -26,8 +26,8 @@ import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Value;
@ -48,10 +48,10 @@ public class TongYiStuffServiceImpl extends AbstractTongYiServiceImpl {
private static final Logger logger = LoggerFactory.getLogger(TongYiService.class);
private final ChatClient chatClient;
private final ChatModel chatModel;
public TongYiStuffServiceImpl(ChatClient chatClient) {
this.chatClient = chatClient;
public TongYiStuffServiceImpl(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Value("classpath:/docs/wikipedia-curling.md")
@ -76,7 +76,7 @@ public class TongYiStuffServiceImpl extends AbstractTongYiServiceImpl {
}
Prompt prompt = promptTemplate.create(map);
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
return new Completion(generation.getOutput().getContent());
}
}

@ -20,7 +20,7 @@ import java.util.List;
import com.alibaba.cloud.ai.example.tongyi.service.AbstractTongYiServiceImpl;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.stereotype.Service;
/**
@ -31,17 +31,17 @@ import org.springframework.stereotype.Service;
@Service
public class TongYiTextEmbeddingServiceImpl extends AbstractTongYiServiceImpl {
private final EmbeddingClient embeddingClient;
private final EmbeddingModel embeddingModel;
public TongYiTextEmbeddingServiceImpl(EmbeddingClient embeddingClient) {
public TongYiTextEmbeddingServiceImpl(EmbeddingModel embeddingModel) {
this.embeddingClient = embeddingClient;
this.embeddingModel = embeddingModel;
}
@Override
public List<Double> textEmbedding(String text) {
return embeddingClient.embed(text);
return embeddingModel.embed(text);
}
}

@ -20,11 +20,11 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
@ -48,13 +48,13 @@ public class RAGService {
@Value("${topk:10}")
private int topK;
private final ChatClient client;
private final ChatModel chatModel;
private final VectorStore store;
public RAGService(ChatClient client, VectorStore store) {
public RAGService(ChatModel chatModel, VectorStore store) {
this.client = client;
this.chatModel = chatModel;
this.store = store;
}
@ -67,7 +67,7 @@ public class RAGService {
UserMessage userMessage = new UserMessage(message);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
return response.getResult();
}

@ -18,7 +18,7 @@ package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import java.nio.ByteBuffer;
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.model.Model;
/**
* @author yuluo
@ -27,7 +27,7 @@ import org.springframework.ai.model.ModelClient;
*/
@FunctionalInterface
public interface SpeechClient extends ModelClient<SpeechPrompt, SpeechResponse> {
public interface SpeechClient extends Model<SpeechPrompt, SpeechResponse> {
/**
* Generates spoken audio from the provided text message.

@ -20,7 +20,7 @@ import java.nio.ByteBuffer;
import reactor.core.publisher.Flux;
import org.springframework.ai.model.StreamingModelClient;
import org.springframework.ai.model.StreamingModel;
/**
* @author yuluo
@ -29,7 +29,7 @@ import org.springframework.ai.model.StreamingModelClient;
*/
@FunctionalInterface
public interface SpeechStreamClient extends StreamingModelClient<SpeechPrompt, SpeechResponse> {
public interface SpeechStreamClient extends StreamingModel<SpeechPrompt, SpeechResponse> {
/**
* Generates a stream of audio bytes from the provided text message.

@ -34,7 +34,7 @@ import com.alibaba.dashscope.audio.asr.transcription.TranscriptionQueryParam;
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionResult;
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionTaskResult;
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.model.Model;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
@ -47,7 +47,7 @@ import org.springframework.util.Assert;
*/
public class TongYiAudioTranscriptionClient
implements ModelClient<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
/**
* TongYi models options.

@ -15,12 +15,12 @@
*/
package com.alibaba.cloud.ai.tongyi.chat;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import com.alibaba.cloud.ai.tongyi.exception.TongYiException;
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
import com.alibaba.dashscope.aigc.generation.Generation;
@ -41,11 +41,11 @@ import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
@ -54,14 +54,15 @@ import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.CollectionUtils;
/**
* {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal Alibaba DashScope}
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
* backed by {@link Generation}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.0.0-RC1
* @see ChatClient
* @see ChatModel
* @see com.alibaba.dashscope.aigc.generation
*/
@ -70,7 +71,7 @@ public class TongYiChatClient extends
com.alibaba.dashscope.common.Message,
ConversationParam,
GenerationResult>
implements ChatClient, StreamingChatClient {
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(TongYiChatClient.class);
@ -157,13 +158,13 @@ public class TongYiChatClient extends
msgManager.add(chatCompletions);
List<org.springframework.ai.chat.Generation> generations =
List<org.springframework.ai.chat.model.Generation> generations =
chatCompletions
.getOutput()
.getChoices()
.stream()
.map(choice ->
new org.springframework.ai.chat.Generation(
new org.springframework.ai.chat.model.Generation(
choice
.getMessage()
.getContent()
@ -201,7 +202,7 @@ public class TongYiChatClient extends
.getMessage()
.getContent())
.map(content -> {
var gen = new org.springframework.ai.chat.Generation(content)
var gen = new org.springframework.ai.chat.model.Generation(content)
.withGenerationMetadata(generateChoiceMetadata(
message.getOutput()
.getChoices()
@ -449,6 +450,20 @@ public class TongYiChatClient extends
return result;
}
@Override
protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
final Flowable<GenerationResult> genRes;
try {
genRes = generation.streamCall(request);
}
catch (NoApiKeyException | InputRequiredException e) {
logger.warn("TongYi chat client: " + e.getMessage());
throw new TongYiException(e.getMessage());
}
return Flux.from(genRes);
}
@Override
protected boolean isToolFunctionCall(GenerationResult response) {

@ -32,7 +32,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
@ -47,7 +47,7 @@ import org.springframework.util.Assert;
* {@see TextEmbeddingClient}
*/
public class TongYiTextEmbeddingClient extends AbstractEmbeddingClient {
public class TongYiTextEmbeddingClient extends AbstractEmbeddingModel {
private final Logger logger = LoggerFactory.getLogger(TongYiTextEmbeddingClient.class);

@ -30,8 +30,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
@ -48,7 +48,7 @@ import static com.alibaba.cloud.ai.tongyi.metadata.TongYiImagesResponseMetadata.
* @since 2023.0.0.0-RC1
*/
public class TongYiImagesClient implements ImageClient {
public class TongYiImagesClient implements ImageModel {
private final Logger logger = LoggerFactory.getLogger(TongYiImagesClient.class);

@ -16,6 +16,8 @@
package com.alibaba.cloud.ai.tongyi.metadata;
import java.util.HashMap;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@ -23,6 +25,8 @@ import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
/**
* {@link ChatResponseMetadata} implementation for {@literal Alibaba DashScope}.
*
@ -31,7 +35,7 @@ import org.springframework.util.Assert;
* @since 2023.0.0.0-RC1
*/
public class TongYiAiChatResponseMetadata implements ChatResponseMetadata {
public class TongYiAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";

@ -16,6 +16,7 @@
package com.alibaba.cloud.ai.tongyi.metadata;
import java.util.HashMap;
import java.util.Objects;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
@ -31,7 +32,7 @@ import org.springframework.util.Assert;
* @since 2023.0.0.0-RC1
*/
public class TongYiImagesResponseMetadata implements ImageResponseMetadata {
public class TongYiImagesResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private final Long created;
@ -74,6 +75,7 @@ public class TongYiImagesResponseMetadata implements ImageResponseMetadata {
this.usage = usage;
}
@Override
public Long getCreated() {
return created;
}
@ -94,7 +96,7 @@ public class TongYiImagesResponseMetadata implements ImageResponseMetadata {
this.metrics = metrics;
}
@Override
public Long created() {
return this.created;
}

@ -15,6 +15,7 @@
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import java.util.HashMap;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisUsage;
@ -26,13 +27,15 @@ import org.springframework.ai.model.ResponseMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.0.0-RC1
*/
public class TongYiAudioSpeechResponseMetadata implements ResponseMetadata {
public class TongYiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
private SpeechSynthesisUsage usage;

@ -16,6 +16,8 @@
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import java.util.HashMap;
import javax.annotation.Nullable;
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionResult;
@ -26,13 +28,15 @@ import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.util.Assert;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.0.0
*/
public class TongYiAudioTranscriptionResponseMetadata implements ResponseMetadata {
public class TongYiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
@Nullable
private RateLimit rateLimit;

Loading…
Cancel
Save