diff --git a/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java b/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java new file mode 100644 index 0000000..b162bb5 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/AzureOpenAiApi.java @@ -0,0 +1,81 @@ +package com.unfbx.chatgpt; + +import com.unfbx.chatgpt.entity.chat.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; +import com.unfbx.chatgpt.entity.completions.Completion; +import com.unfbx.chatgpt.entity.completions.CompletionResponse; +import com.unfbx.chatgpt.entity.embeddings.Embedding; +import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse; +import io.reactivex.Single; +import retrofit2.http.Body; +import retrofit2.http.POST; + +/** + * Azure OpenAI api接口 + * api版本:2023-03-15-preview + * + * apiHost: + * https://${your-resource-name}.openai.azure.com/openai/deployments/${deployment-id}/ + * + * 文档: + * https://learn.microsoft.com/zh-cn/azure/cognitive-services/openai/reference + * swagger:https://github.com/Azure/azure-rest-api-specs/blob/main/specification/cognitiveservices/data-plane/AzureOpenAI/inference/stable/2022-12-01/inference.json + * + * @author skywalker + * @since 2023/5/7 17:22 + */ +public interface AzureOpenAiApi extends OpenAiApi { + + /** + * 与OpenAiApi接口保持一直,不额外增加参数传递api-version字段 + */ + String API_VERSION_QUERY_STRING = "?api-version=2023-03-15-preview"; + + /** + * 文本问答 + * Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. + * + * 注意: + * logprobs, best_of and echo parameters are not available on gpt-35-turbo model. + * azure版本api,在gpt-35-turbo model下,不支持传递logprobs, best_of, echo这3个参数,需要置为null + * + * 示例: + * Completion q = Completion.builder() + * .prompt("who are you?") + * .logprobs(null) + * .bestOf(null) + * .echo(null) + * .maxTokens(16) + * .build(); + * + * @param completion 问答参数 + * @return Single CompletionResponse + */ + @POST("completions" + API_VERSION_QUERY_STRING) + Single completions(@Body Completion completion); + + /** + * 文本向量计算 + * + * 注意: + * Too many inputs for model None. The max number of inputs is 1. We hope to increase the number of inputs per request soon. + * Azure版本api只支持传递一个input + * + * @param embedding 向量参数 + * @return Single EmbeddingResponse + */ + @POST("embeddings" + API_VERSION_QUERY_STRING) + Single embeddings(@Body Embedding embedding); + + /** + * 最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 + * + * @param chatCompletion chat completion + * @return 返回答案 + */ + @Override + @POST("chat/completions" + API_VERSION_QUERY_STRING) + Single chatCompletion(@Body ChatCompletion chatCompletion); + + +} diff --git a/src/main/java/com/unfbx/chatgpt/OpenAiClient.java b/src/main/java/com/unfbx/chatgpt/OpenAiClient.java index f2194c3..938c278 100644 --- a/src/main/java/com/unfbx/chatgpt/OpenAiClient.java +++ b/src/main/java/com/unfbx/chatgpt/OpenAiClient.java @@ -146,12 +146,17 @@ private OpenAiClient(Builder builder) { .build(); } okHttpClient = builder.okHttpClient; + + if (Objects.isNull(builder.openAiApiClass)) { + builder.openAiApiClass = OpenAiApi.class; + } + this.openAiApi = new Retrofit.Builder() .baseUrl(apiHost) .client(okHttpClient) .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) .addConverterFactory(JacksonConverterFactory.create()) - .build().create(OpenAiApi.class); + .build().create(builder.openAiApiClass); } @@ -831,6 +836,11 @@ public static final class Builder { */ private OpenAiAuthInterceptor authInterceptor; + /** + * api接口类 + */ + private Class openAiApiClass; + public Builder() { } @@ -864,6 +874,11 @@ public Builder authInterceptor(OpenAiAuthInterceptor val) { return this; } + public Builder openAiApiClass(Class val) { + openAiApiClass = val; + return this; + } + public OpenAiClient build() { return new OpenAiClient(this); } diff --git a/src/main/java/com/unfbx/chatgpt/entity/completions/Completion.java b/src/main/java/com/unfbx/chatgpt/entity/completions/Completion.java index 77fb645..266ea45 100644 --- a/src/main/java/com/unfbx/chatgpt/entity/completions/Completion.java +++ b/src/main/java/com/unfbx/chatgpt/entity/completions/Completion.java @@ -76,7 +76,7 @@ public class Completion implements Serializable { private Integer logprobs; @Builder.Default - private boolean echo = false; + private Boolean echo = false; private List stop; diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/AzureDynamicKeyOpenAiAuthInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/AzureDynamicKeyOpenAiAuthInterceptor.java new file mode 100644 index 0000000..276c734 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/AzureDynamicKeyOpenAiAuthInterceptor.java @@ -0,0 +1,35 @@ +package com.unfbx.chatgpt.interceptor; + +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import lombok.NoArgsConstructor; +import okhttp3.Request; + +/** + * azure另一种鉴权方式,通过api-key + * 文档: + * https://learn.microsoft.com/zh-cn/azure/cognitive-services/openai/reference + * + * @author skywalker + * @since 2023/5/7 17:22 + */ +@NoArgsConstructor +public class AzureDynamicKeyOpenAiAuthInterceptor extends DynamicKeyOpenAiAuthInterceptor { + + /** + * 默认的鉴权处理方法 + * + * @param key api key + * @param original 源请求体 + * @return 请求体 + */ + @Override + public Request auth(String key, Request original) { + Request request = original.newBuilder() + .header("api-key", key) + .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .method(original.method(), original.body()) + .build(); + return request; + } +} diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java new file mode 100644 index 0000000..86008f5 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/AzureOpenAiAuthInterceptor.java @@ -0,0 +1,30 @@ +package com.unfbx.chatgpt.interceptor; + +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import lombok.NoArgsConstructor; +import okhttp3.Request; + +/** + * @author skywalker + * @since 2023/5/7 17:22 + */ +@NoArgsConstructor +public class AzureOpenAiAuthInterceptor extends DefaultOpenAiAuthInterceptor { + /** + * 默认的鉴权处理方法 + * + * @param key api key + * @param original 源请求体 + * @return 请求体 + */ + @Override + public Request auth(String key, Request original) { + Request request = original.newBuilder() + .header("api-key", key) + .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .method(original.method(), original.body()) + .build(); + return request; + } +}