diff --git a/common/constants.go b/common/constants.go index cbb7861..3f0cfe8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -211,6 +211,7 @@ const ( ChannelTypeGemini = 24 ChannelTypeMoonshot = 25 ChannelTypeZhipu_v4 = 26 + ChannelTypePerplexity = 27 ) var ChannelBaseURLs = []string{ @@ -240,5 +241,5 @@ var ChannelBaseURLs = []string{ "https://hunyuan.cloud.tencent.com", //23 "https://generativelanguage.googleapis.com", //24 "https://api.moonshot.cn", //25 - "https://open.bigmodel.cn", //26 + "https://api.perplexity.ai", //26 } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go new file mode 100644 index 0000000..4722bb7 --- /dev/null +++ b/relay/channel/perplexity/adaptor.go @@ -0,0 +1,63 @@ +package perplexity + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + if request.TopP >= 1 { + request.TopP = 0.99 + } + return requestOpenAI2Perplexity(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/perplexity/constants.go b/relay/channel/perplexity/constants.go new file mode 100644 index 0000000..dc15541 --- /dev/null +++ b/relay/channel/perplexity/constants.go @@ -0,0 +1,7 @@ +package perplexity + +var ModelList = []string{ + "sonar-small-chat", "sonar-small-online", "sonar-medium-chat", "sonar-medium-online", "mistral-7b-instruct", "mixtral-8x7b-instruct", +} + +var ChannelName = "perplexity" diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go new file mode 100644 index 0000000..9772aea --- /dev/null +++ b/relay/channel/perplexity/relay-perplexity.go @@ -0,0 +1,21 @@ +package perplexity + +import "one-api/dto" + +func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { + messages := make([]dto.Message, 0, len(request.Messages)) + for _, message := range request.Messages { + messages = append(messages, dto.Message{ + Role: message.Role, + Content: message.Content, + }) + } + return &dto.GeneralOpenAIRequest{ + Model: request.Model, + Stream: request.Stream, + Messages: messages, + Temperature: request.Temperature, + TopP: request.TopP, + MaxTokens: request.MaxTokens, + } +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 51b62c1..79d06da 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -16,6 +16,7 @@ const ( APITypeTencent APITypeGemini APITypeZhipu_v4 + APITypePerplexity APITypeDummy // this one is only for count, do not add any channel after this ) @@ -43,6 +44,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeGemini case common.ChannelTypeZhipu_v4: apiType = APITypeZhipu_v4 + case common.ChannelTypePerplexity: + apiType = APITypePerplexity } return apiType }