feat: add scholarai channel

This commit is contained in:
linzhaoming 2024-06-12 20:35:14 +08:00
parent 58591b8c2a
commit 318a086e10
12 changed files with 7857 additions and 1510 deletions

View File

@ -0,0 +1 @@
v0.0.1

View File

@ -231,6 +231,7 @@ const (
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeScholarAI = 36
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@ -272,4 +273,5 @@ var ChannelBaseURLs = []string{
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"https://api.scholarai.io", //36
}

View File

@ -3,11 +3,12 @@ package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/relay/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {

View File

@ -0,0 +1,60 @@
package scholarai
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
// Adaptor implements the channel adaptor interface for the ScholarAI
// upstream, translating OpenAI-style relay requests into ScholarAI calls.
type Adaptor struct {
}
// Init is a no-op for ScholarAI: no per-request state needs to be prepared.
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
// GetRequestURL builds the ScholarAI chat-completions endpoint from the
// channel's configured base URL.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	return info.BaseUrl + "/api/chat/completions", nil
}
// SetupRequestHeader applies the shared relay headers and then attaches the
// ScholarAI-specific API-key header.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)
	// ScholarAI authenticates via a custom header rather than a Bearer token.
	req.Header.Set("X-ScholarAI-API-Key", info.ApiKey)
	return nil
}
// ConvertRequest maps a generic OpenAI-style chat request onto the ScholarAI
// wire format. A nil request is rejected up front.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}
	converted := requestOpenAI2ScholarAI(*request)
	return converted, nil
}
// DoRequest delegates the outbound HTTP call to the shared channel helper.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}
// DoResponse consumes the upstream HTTP response. For streaming requests the
// event stream is relayed and token usage is estimated from the accumulated
// text; otherwise the full JSON body is converted and usage comes from the
// handler itself.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	if info.IsStream {
		var responseText string
		err, responseText = scholarAIStreamHandler(c, resp)
		// Usage is reconstructed by counting tokens in the streamed text;
		// the error from the usage helper is intentionally discarded.
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	} else {
		err, usage = scholarAIHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
	}
	return
}
// GetModelList returns the model names this channel advertises.
func (a *Adaptor) GetModelList() []string {
	return ModelList
}
// GetChannelName returns the identifier under which this adaptor is registered.
func (a *Adaptor) GetChannelName() string {
	return ChannelName
}

View File

@ -0,0 +1,9 @@
package scholarai
// ModelList enumerates the model names the ScholarAI channel advertises.
var ModelList = []string{
	"scholarai",
	"gpt-4o",
	"gpt-4-turbo",
}

// ChannelName is the identifier used when registering this adaptor.
var ChannelName = "scholarai"

View File

@ -0,0 +1,28 @@
package scholarai
type ScholarAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ScholarAIChatRequest struct {
Model string `json:"model,omitempty"`
Messages []ScholarAIMessage `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ScholarAITextResponseChoice struct {
Index int `json:"index"`
ScholarAIMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type ScholarAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Logprods string `json:"logprods"`
Choices []ScholarAITextResponseChoice `json:"choices"`
SystemFingerprint string `json:"system_fingerprint"`
}

View File

@ -0,0 +1,141 @@
package scholarai
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// requestOpenAI2ScholarAI flattens an OpenAI-style chat request into the
// ScholarAI format: every incoming message is folded into one
// "role: content" line inside a single user message.
func requestOpenAI2ScholarAI(request dto.GeneralOpenAIRequest) *ScholarAIChatRequest {
	var prompt strings.Builder
	for _, m := range request.Messages {
		prompt.WriteString(m.Role)
		prompt.WriteString(": ")
		prompt.WriteString(m.StringContent())
		prompt.WriteString("\n")
	}
	msgs := []ScholarAIMessage{{Role: "user", Content: prompt.String()}}
	return &ScholarAIChatRequest{
		Model:    request.Model,
		Messages: msgs,
		Stream:   request.Stream,
	}
}
// responseScholarAI2OpenAI converts a ScholarAI completion response into the
// OpenAI-style response shape used by the relay.
// NOTE(review): response.Model is not propagated to the converted response —
// confirm whether dto.OpenAITextResponse carries a model field that should
// be filled in here.
func responseScholarAI2OpenAI(response *ScholarAITextResponse) *dto.OpenAITextResponse {
	fullTextResponse := dto.OpenAITextResponse{
		Id:      response.Id,
		Object:  response.Object,
		Created: response.Created,
		// Pre-size to avoid repeated append growth.
		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Choices)),
	}
	for _, choice := range response.Choices {
		// Content is stored JSON-encoded (hence the Marshal); marshalling a
		// plain string cannot fail, so the error is safely discarded.
		content, _ := json.Marshal(choice.ScholarAIMessage.Content)
		c := dto.OpenAITextResponseChoice{
			Index:        choice.Index,
			FinishReason: choice.FinishReason,
			Message: dto.Message{
				Content: content,
				Role:    choice.ScholarAIMessage.Role,
			},
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, c)
	}
	return &fullTextResponse
}
// scholarAIHandler relays a non-streaming ScholarAI completion: it decodes
// the upstream body, rewrites it into the OpenAI response shape, computes
// token usage, and writes the result to the client.
func scholarAIHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	// Renamed from "ScholarAITextResponse", which shadowed the type name.
	var textResponse ScholarAITextResponse
	err = json.Unmarshal(responseBody, &textResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// Guard against an empty choices array: indexing Choices[0] blindly
	// would panic on error-shaped or malformed upstream responses.
	if len(textResponse.Choices) == 0 {
		return service.OpenAIErrorWrapper(fmt.Errorf("response has no choices"), "empty_response_choices", http.StatusInternalServerError), nil
	}
	fullTextResponse := responseScholarAI2OpenAI(&textResponse)
	// Token-count errors are deliberately ignored: usage accounting is
	// best-effort and must not fail the relay.
	completionTokens, _ := service.CountTokenText(textResponse.Choices[0].ScholarAIMessage.Content, model)
	usage := dto.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Write errors cannot be reported to the client at this point; a
	// best-effort write matches the rest of the relay handlers.
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &fullTextResponse.Usage
}
// scholarAIStreamHandler relays a streaming ScholarAI response line by line,
// re-emitting each chunk as an OpenAI-compatible stream event and
// accumulating the delta text so the caller can estimate token usage.
func scholarAIStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
	// strings.Builder avoids the quadratic cost of += concatenation.
	var responseText strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	// bufio.ScanLines replaces the previous hand-rolled split function: it is
	// the same line splitting without an O(len) string conversion per call,
	// and it additionally tolerates CRLF line endings.
	scanner.Split(bufio.ScanLines)
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		for scanner.Scan() {
			line := scanner.Text()
			// Blank keep-alive lines carry no payload; skip them instead of
			// logging a JSON error for each one.
			if strings.TrimSpace(line) == "" {
				continue
			}
			dataChan <- line
		}
		stopChan <- true
	}()
	service.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			// NOTE(review): lines are unmarshalled as raw JSON; if the
			// upstream prefixes events with "data: ", that prefix must be
			// stripped first — confirm against a live ScholarAI stream.
			var response dto.ChatCompletionsStreamResponse
			err := json.Unmarshal([]byte(data), &response)
			if err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			if len(response.Choices) != 0 {
				responseText.WriteString(response.Choices[0].Delta.GetContentString())
			}
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText.String()
}

View File

@ -20,7 +20,7 @@ const (
APITypePerplexity
APITypeAws
APITypeCohere
APITypeScholarAI
APITypeDummy // this one is only for count, do not add any channel after this
)
@ -57,6 +57,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
case common.ChannelTypeScholarAI:
apiType = APITypeScholarAI
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@ -12,6 +12,7 @@ import (
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/scholarai"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
@ -51,6 +52,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &aws.Adaptor{}
case constant.APITypeCohere:
return &cohere.Adaptor{}
case constant.APITypeScholarAI:
return &scholarai.Adaptor{}
}
return nil
}

6331
web/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -97,6 +97,7 @@ export const CHANNEL_OPTIONS = [
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },
{ key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
{ key: 36, text: 'ScholarAI', value: 36, color: 'green', label: 'ScholarAI' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{
key: 22,

File diff suppressed because it is too large Load Diff