diff --git a/pom.xml b/pom.xml index b8aa41e..526e9e4 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ com.unfbx chatgpt-java - 1.1.5 + 1.1.5-custom chatgpt-java OpenAI Java SDK, OpenAI Api for Java. ChatGPT Java SDK . https://chatgpt-java.unfbx.com @@ -116,7 +116,7 @@ com.knuddels jtokkit - 0.6.1 + 0.5.0 diff --git a/src/main/java/com/unfbx/chatgpt/WenXinApi.java b/src/main/java/com/unfbx/chatgpt/WenXinApi.java new file mode 100644 index 0000000..687f5b1 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/WenXinApi.java @@ -0,0 +1,18 @@ +package com.unfbx.chatgpt; + +import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletionResponse; +import io.reactivex.Single; +import retrofit2.http.Body; +import retrofit2.http.POST; + +public interface WenXinApi { + /** + * 最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 + * + * @param chatCompletion chat completion + * @return 返回答案 + */ + @POST("rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions") + Single chatCompletion(@Body ChatCompletion chatCompletion); +} \ No newline at end of file diff --git a/src/main/java/com/unfbx/chatgpt/WenXinStreamClient.java b/src/main/java/com/unfbx/chatgpt/WenXinStreamClient.java new file mode 100644 index 0000000..22c1bec --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/WenXinStreamClient.java @@ -0,0 +1,250 @@ +package com.unfbx.chatgpt; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.http.ContentType; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.unfbx.chatgpt.constant.OpenAIConst; +import com.unfbx.chatgpt.constant.WenXinConst; +import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletion; +import com.unfbx.chatgpt.exception.BaseException; +import com.unfbx.chatgpt.exception.CommonError; +import com.unfbx.chatgpt.function.KeyRandomStrategy; +import com.unfbx.chatgpt.function.KeyStrategyFunction; +import com.unfbx.chatgpt.interceptor.*; +import com.unfbx.chatgpt.sse.ConsoleEventSourceListener; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.jetbrains.annotations.NotNull; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + + +/** + * 描述: open ai 客户端 + * + * @author https:www.unfbx.com + * 2023-02-28 + */ + +@Slf4j +public class WenXinStreamClient { + @Getter + @NotNull + private List apiKey; + /** + * 自定义api host使用builder的方式构造client + */ + @Getter + private String apiHost; + /** + * 自定义的okHttpClient + * 如果不自定义 ,就是用sdk默认的OkHttpClient实例 + */ + @Getter + private OkHttpClient okHttpClient; + + /** + * api key的获取策略 + */ + @Getter + private KeyStrategyFunction, String> keyStrategy; + + @Getter + private WenXinApi wenXinApi; + + /** + * 自定义鉴权处理拦截器
+ * 可以不设置,默认实现:DefaultOpenAiAuthInterceptor
+ * 如需自定义实现参考:DealKeyWithOpenAiAuthInterceptor + * + * @see DynamicKeyOpenAiAuthInterceptor + * @see DefaultOpenAiAuthInterceptor + */ + @Getter + private WenXinAuthInterceptor authInterceptor; + + /** + * 构造实例对象 + * + * @param builder + */ + private WenXinStreamClient(Builder builder) { + if (CollectionUtil.isEmpty(builder.apiKey)) { + throw new BaseException(CommonError.API_KEYS_NOT_NUL); + } + apiKey = builder.apiKey; + + if (StrUtil.isBlank(builder.apiHost)) { + builder.apiHost = WenXinConst.WenXin_HOST; + } + apiHost = builder.apiHost; + + if (Objects.isNull(builder.keyStrategy)) { + builder.keyStrategy = new KeyRandomStrategy(); + } + keyStrategy = builder.keyStrategy; + + if (Objects.isNull(builder.authInterceptor)) { + builder.authInterceptor = new DefaultWenXinAuthInterceptor(); + } + authInterceptor = builder.authInterceptor; + //设置apiKeys和key的获取策略 + authInterceptor.setApiKey(this.apiKey); + authInterceptor.setKeyStrategy(this.keyStrategy); + + if (Objects.isNull(builder.okHttpClient)) { + builder.okHttpClient = this.okHttpClient(); + } else { + //自定义的okhttpClient 需要增加api keys + builder.okHttpClient = builder.okHttpClient + .newBuilder() + .addInterceptor(authInterceptor) + .build(); + } + okHttpClient = builder.okHttpClient; + + this.wenXinApi = new Retrofit.Builder() + .baseUrl(apiHost) + .client(okHttpClient) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .addConverterFactory(JacksonConverterFactory.create()) + .build().create(WenXinApi.class); + } + + /** + * 创建默认的OkHttpClient + */ + private OkHttpClient okHttpClient() { + if (Objects.isNull(this.authInterceptor)) { + this.authInterceptor = new DefaultWenXinAuthInterceptor(); + } + this.authInterceptor.setApiKey(this.apiKey); + this.authInterceptor.setKeyStrategy(this.keyStrategy); + return new OkHttpClient + .Builder() + .addInterceptor(this.authInterceptor) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); + } + + + /** + * 流式输出,最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型 + * + * @param chatCompletion 问答参数 + * @param eventSourceListener sse监听器 + * @see ConsoleEventSourceListener + */ + public void streamChatCompletion(ChatCompletion chatCompletion, EventSourceListener eventSourceListener) { + if (Objects.isNull(eventSourceListener)) { + log.error("参数异常:EventSourceListener不能为空,可以参考:com.unfbx.chatgpt.sse.ConsoleEventSourceListener"); + throw new BaseException(CommonError.PARAM_ERROR); + } + if (!chatCompletion.isStream()) { + chatCompletion.setStream(true); + } + try { + EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); + ObjectMapper mapper = new ObjectMapper(); + String requestBody = mapper.writeValueAsString(chatCompletion); + Request request = new Request.Builder() + .url(this.apiHost + "rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions") + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + //创建事件 + factory.newEventSource(request, eventSourceListener); + } catch (JsonProcessingException e) { + log.error("请求参数解析异常:{}", e); + e.printStackTrace(); + } catch (Exception e) { + log.error("请求参数解析异常:{}", e); + e.printStackTrace(); + } + } + + + /** + * 构造 + * + * @return Builder + */ + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private @NotNull List apiKey; + /** + * api请求地址,结尾处有斜杠 + * + * @see OpenAIConst + */ + private String apiHost; + + /** + * 自定义OkhttpClient + */ + private OkHttpClient okHttpClient; + + + /** + * api key的获取策略 + */ + private KeyStrategyFunction keyStrategy; + + /** + * 自定义鉴权拦截器 + */ + private WenXinAuthInterceptor authInterceptor; + + public Builder() { + } + + public Builder apiKey(@NotNull List val) { + apiKey = val; + return this; + } + + /** + * @param val api请求地址,结尾处有斜杠 + * @return Builder + * @see OpenAIConst + */ + public Builder apiHost(String val) { + apiHost = val; + return this; + } + + public Builder keyStrategy(KeyStrategyFunction val) { + keyStrategy = val; + return this; + } + + public Builder okHttpClient(OkHttpClient val) { + okHttpClient = val; + return this; + } + + public Builder authInterceptor(WenXinAuthInterceptor val) { + authInterceptor = val; + return this; + } + + public WenXinStreamClient build() { + return new WenXinStreamClient(this); + } + } +} diff --git a/src/main/java/com/unfbx/chatgpt/constant/WenXinConst.java b/src/main/java/com/unfbx/chatgpt/constant/WenXinConst.java new file mode 100644 index 0000000..629be00 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/constant/WenXinConst.java @@ -0,0 +1,12 @@ +package com.unfbx.chatgpt.constant; + +/** + * 描述: + * + * @author https:www.unfbx.com + * @since 2023-03-06 + */ +public class WenXinConst { + + public final static String WenXin_HOST = "https://aip.baidubce.com/"; +} diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletion.java b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletion.java new file mode 100644 index 0000000..77ff7fc --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletion.java @@ -0,0 +1,67 @@ +package com.unfbx.chatgpt.entity.chat.wenxin; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.Functions; +import lombok.*; +import lombok.extern.slf4j.Slf4j; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +/** + * 文心一言模型参数 + */ +@Data +@Builder +@Slf4j +@JsonInclude(JsonInclude.Include.NON_NULL) +@AllArgsConstructor +public class ChatCompletion implements Serializable { + /** + * 使用什么取样温度,0到2之间。较高的值(如0.8)将使输出更加随机,而较低的值(如0.2)将使输出更加集中和确定。 + *

+ * We generally recommend altering this or but not both.top_p + */ + @Builder.Default + private double temperature = 0.2; + + /** + * 使用温度采样的替代方法称为核心采样,其中模型考虑具有top_p概率质量的令牌的结果。因此,0.1 意味着只考虑包含前 10% 概率质量的代币。 + *

+ * 我们通常建议更改此设置,但不要同时更改两者。temperature + */ + @JsonProperty("top_p") + @Builder.Default + private Double topP = 0.8d; + + /** + * 通过对已生成的token增加惩罚,减少重复生成的现象。说明: + * (1)值越大表示惩罚越大 + * (2)默认1.0,取值范围:[1.0, 2.0] + */ + @JsonProperty("penalty_score") + @Builder.Default + private Double penalty_score = 1d; + + /** + * 是否流式输出. + * default:false + * + * @see com.unfbx.chatgpt.OpenAiStreamClient + */ + @Builder.Default + private boolean stream = false; + + /** + * 用户唯一值,确保接口不被重复调用 + */ + private String user_id; + + /** + * 问题描述 + */ + @NonNull + private List messages; +} diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletionResponse.java b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletionResponse.java new file mode 100644 index 0000000..2df6ce2 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/ChatCompletionResponse.java @@ -0,0 +1,29 @@ +package com.unfbx.chatgpt.entity.chat.wenxin; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.unfbx.chatgpt.entity.chat.ChatChoice; +import com.unfbx.chatgpt.entity.common.Usage; +import lombok.Data; + +import java.io.Serializable; +import java.util.List; + +/** + * 描述: chat答案类 + * + * @author https:www.unfbx.com + * 2023-03-02 + */ +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class ChatCompletionResponse implements Serializable { + private String id; + private String object; + private long created; + private long sentence_id; + private Boolean is_end; + private Boolean is_truncated; + private String result; + private Boolean need_clear_history; + private int ban_round; +} diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/Message.java b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/Message.java new file mode 100644 index 0000000..3c069b7 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/wenxin/Message.java @@ -0,0 +1,106 @@ + +package com.unfbx.chatgpt.entity.chat.wenxin; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.FunctionCall; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.Getter; + +import java.io.Serializable; + +/** + * 描述: + * + * @author https:www.unfbx.com + * @since 2023-03-02 + */ +@Data +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public class Message implements Serializable { + + /** + * 目前支持四个中角色参考官网,进行情景输入: + */ + private String role; + + private String content; + + private String name; + + public static Builder builder() { + return new Builder(); + } + + /** + * 构造函数 + * + * @param role 角色 + * @param content 描述主题信息 + * @param name name + * @param functionCall functionCall + */ + public Message(String role, String content, String name, FunctionCall functionCall) { + this.role = role; + this.content = content; + this.name = name; + } + + public Message() { + } + + private Message(Builder builder) { + setRole(builder.role); + setContent(builder.content); + setName(builder.name); + } + + + @Getter + @AllArgsConstructor + public enum Role { + + //SYSTEM("system"), + USER("user"), + ASSISTANT("assistant"), + //FUNCTION("function"), + ; + private final String name; + } + + public static final class Builder { + private String role; + private String content; + private String name; + + public Builder() { + } + + public Builder role(Role role) { + this.role = role.getName(); + return this; + } + + public Builder role(String role) { + this.role = role; + return this; + } + + public Builder content(String content) { + this.content = content; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Message build() { + return new Message(this); + } + } +} diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/DefaultWenXinAuthInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/DefaultWenXinAuthInterceptor.java new file mode 100644 index 0000000..bb9ced6 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/DefaultWenXinAuthInterceptor.java @@ -0,0 +1,65 @@ +package com.unfbx.chatgpt.interceptor; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * 描述:请求增加header apikey + * + * @author https:www.unfbx.com + * @since 2023-03-23 + */ +@Slf4j +public class DefaultWenXinAuthInterceptor extends WenXinAuthInterceptor { + /** + * 请求头处理 + */ + public DefaultWenXinAuthInterceptor() { + super.setWarringConfig(null); + } + + /** + * 构造方法 + * + * @param warringConfig 所有的key都失效后的告警参数配置 + */ + public DefaultWenXinAuthInterceptor(Map warringConfig) { + super.setWarringConfig(warringConfig); + } + + /** + * 拦截器鉴权 + * + * @param chain Chain + * @return Response对象 + * @throws IOException io异常 + */ + @Override + public Response intercept(Chain chain) throws IOException { + Request original = chain.request(); + return chain.proceed(auth(super.getKey(), original)); + } + + /** + * key失效或者禁用后的处理逻辑 + * 默认不处理 + * + * @param apiKey 返回新的api keys集合 + * @return 新的apiKey集合 + */ + @Override + protected List onErrorDealApiKeys(String apiKey) { + return super.getApiKey(); + } + + @Override + protected void noHaveActiveKeyWarring() { + log.error("--------> [告警] 没有可用的key!!!"); + return; + } +} diff --git a/src/main/java/com/unfbx/chatgpt/interceptor/WenXinAuthInterceptor.java b/src/main/java/com/unfbx/chatgpt/interceptor/WenXinAuthInterceptor.java new file mode 100644 index 0000000..1c200c5 --- /dev/null +++ b/src/main/java/com/unfbx/chatgpt/interceptor/WenXinAuthInterceptor.java @@ -0,0 +1,91 @@ +package com.unfbx.chatgpt.interceptor; + + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.http.ContentType; +import cn.hutool.http.Header; +import com.unfbx.chatgpt.exception.BaseException; +import com.unfbx.chatgpt.exception.CommonError; +import com.unfbx.chatgpt.function.KeyStrategyFunction; +import lombok.Getter; +import lombok.Setter; +import okhttp3.HttpUrl; +import okhttp3.Interceptor; +import okhttp3.Request; +import retrofit2.http.Url; + +import java.net.URL; +import java.util.List; +import java.util.Map; + +public abstract class WenXinAuthInterceptor implements Interceptor { + + + /** + * key 集合 + */ + @Getter + @Setter + private List apiKey; + /** + * 自定义的key的使用策略 + */ + @Getter + @Setter + private KeyStrategyFunction, String> keyStrategy; + + /** + * 预警触发参数配置,配置参数实现飞书、钉钉、企业微信、邮箱预警等功能 + */ + @Getter + @Setter + private Map warringConfig; + + /** + * 自定义apiKeys的处理逻辑 + * + * @param errorKey 错误的key + * @return 返回值是新的apiKeys + */ + protected abstract List onErrorDealApiKeys(String errorKey); + + /** + * 所有的key都失效后,自定义预警配置 + * 可以通过warringConfig配置参数实现飞书、钉钉、企业微信、邮箱预警等 + */ + protected abstract void noHaveActiveKeyWarring(); + + + /** + * 获取请求key + * + * @return key + */ + public final String getKey() { + if (CollectionUtil.isEmpty(apiKey)) { + this.noHaveActiveKeyWarring(); + throw new BaseException(CommonError.NO_ACTIVE_API_KEYS); + } + return keyStrategy.apply(apiKey); + } + + /** + * 默认的鉴权处理方法 + * + * @param key api key + * @param original 源请求体 + * @return 请求体 + */ + public Request auth(String key, Request original) { + String url = original.url().url().toString(); + url = url + "?access_token="+key; + //url.newBuilder().addQueryParameter("access_token",original.url().pa).build(); + + return original.newBuilder() + .url(url) + //.header(Header.AUTHORIZATION.getValue(), "Bearer " + key) + .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue()) + .method(original.method(), original.body()) + .build(); + } +} diff --git a/src/main/java/com/unfbx/chatgpt/utils/TikTokensUtil.java b/src/main/java/com/unfbx/chatgpt/utils/TikTokensUtil.java index a74e999..e6d7c06 100644 --- a/src/main/java/com/unfbx/chatgpt/utils/TikTokensUtil.java +++ b/src/main/java/com/unfbx/chatgpt/utils/TikTokensUtil.java @@ -271,4 +271,4 @@ public static ModelType getModelTypeByName(String name) { log.warn("[{}]模型不存在或者暂不支持计算tokens", name); return null; } -} +} \ No newline at end of file diff --git a/src/test/java/com/unfbx/chatgpt/WenXinStream.java b/src/test/java/com/unfbx/chatgpt/WenXinStream.java new file mode 100644 index 0000000..69303ba --- /dev/null +++ b/src/test/java/com/unfbx/chatgpt/WenXinStream.java @@ -0,0 +1,49 @@ +package com.unfbx.chatgpt; + +import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletion; +import com.unfbx.chatgpt.entity.chat.wenxin.Message; +import com.unfbx.chatgpt.interceptor.OpenAILogger; +import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor; +import com.unfbx.chatgpt.sse.ConsoleEventSourceListener; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; + +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class WenXinStream { + public static void main(String[] args) { + //国内访问需要做代理,国外服务器不需要 + //Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)); + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.BODY); + OkHttpClient okHttpClient = new OkHttpClient + .Builder() + //.proxy(proxy)//自定义代理 + .addInterceptor(httpLoggingInterceptor)//自定义日志输出 + .addInterceptor(new OpenAiResponseInterceptor())//自定义返回值拦截 + .connectTimeout(10, TimeUnit.SECONDS)//自定义超时时间 + .writeTimeout(30, TimeUnit.SECONDS)//自定义超时时间 + .readTimeout(30, TimeUnit.SECONDS)//自定义超时时间 + .build(); + + WenXinStreamClient client = WenXinStreamClient.builder() + .apiKey(Arrays.asList("24.ffc3e99e9fd85a7d275a43b6eb71651e.2592000.1702114689.xxx")) + .okHttpClient(okHttpClient) + //自己做了代理就传代理地址,没有可不不传 +// .apiHost("https://自己代理的服务器地址/") + .build(); + + ConsoleEventSourceListener eventSourceListener = new ConsoleEventSourceListener(); + Message message = Message.builder().role(Message.Role.USER).content("你好").build(); + ChatCompletion chatCompletion = ChatCompletion.builder().messages(Arrays.asList(message)).build(); + client.streamChatCompletion(chatCompletion,eventSourceListener); + CountDownLatch countDownLatch = new CountDownLatch(1); + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } +}