Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-01-28 13:44:16 +00:00
113 changed files with 3608 additions and 3013 deletions

View File

@@ -0,0 +1,22 @@
package aiproxy
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,194 @@
package aiproxy
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &LibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: helper.GetUUID(),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: helper.GetUUID(),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
Id: helper.GetUUID(),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var usage openai.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,32 @@
package aiproxy
type LibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type LibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type LibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type LibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []LibraryDocument `json:"documents"`
LibraryError
}
type LibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []LibraryDocument `json:"documents"`
}

View File

@@ -0,0 +1,22 @@
package ali
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

329
relay/channel/ali/main.go Normal file
View File

@@ -0,0 +1,329 @@
package ali
// import (
// "bufio"
// "encoding/json"
// "github.com/gin-gonic/gin"
// "io"
// "net/http"
// "one-api/common"
// "strings"
// )
// // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
// type AliMessage struct {
// Content string `json:"content"`
// Role string `json:"role"`
// }
// type AliInput struct {
// //Prompt string `json:"prompt"`
// Messages []AliMessage `json:"messages"`
// }
// type AliParameters struct {
// TopP float64 `json:"top_p,omitempty"`
// TopK int `json:"top_k,omitempty"`
// Seed uint64 `json:"seed,omitempty"`
// EnableSearch bool `json:"enable_search,omitempty"`
// }
// type AliChatRequest struct {
// Model string `json:"model"`
// Input AliInput `json:"input"`
// Parameters AliParameters `json:"parameters,omitempty"`
// }
// type AliEmbeddingRequest struct {
// Model string `json:"model"`
// Input struct {
// Texts []string `json:"texts"`
// } `json:"input"`
// Parameters *struct {
// TextType string `json:"text_type,omitempty"`
// } `json:"parameters,omitempty"`
// }
// type AliEmbedding struct {
// Embedding []float64 `json:"embedding"`
// TextIndex int `json:"text_index"`
// }
// type AliEmbeddingResponse struct {
// Output struct {
// Embeddings []AliEmbedding `json:"embeddings"`
// } `json:"output"`
// Usage AliUsage `json:"usage"`
// AliError
// }
// type AliError struct {
// Code string `json:"code"`
// Message string `json:"message"`
// RequestId string `json:"request_id"`
// }
// type AliUsage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
// TotalTokens int `json:"total_tokens"`
// }
// type AliOutput struct {
// Text string `json:"text"`
// FinishReason string `json:"finish_reason"`
// }
// type AliChatResponse struct {
// Output AliOutput `json:"output"`
// Usage AliUsage `json:"usage"`
// AliError
// }
// func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
// messages := make([]AliMessage, 0, len(request.Messages))
// prompt := ""
// for i := 0; i < len(request.Messages); i++ {
// message := request.Messages[i]
// if message.Role == "system" {
// messages = append(messages, AliMessage{
// User: message.Content,
// Bot: "Okay",
// })
// continue
// } else {
// if i == len(request.Messages)-1 {
// prompt = message.Content
// break
// }
// messages = append(messages, AliMessage{
// User: message.Content,
// Bot: request.Messages[i+1].Content,
// })
// i++
// }
// }
// return &AliChatRequest{
// Model: request.Model,
// Input: AliInput{
// Prompt: prompt,
// History: messages,
// },
// //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// // TopP: request.TopP,
// // TopK: 50,
// // //Seed: 0,
// // //EnableSearch: false,
// //},
// }
// }
// func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
// return &AliEmbeddingRequest{
// Model: "text-embedding-v1",
// Input: struct {
// Texts []string `json:"texts"`
// }{
// Texts: request.ParseInput(),
// },
// }
// }
// func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var aliResponse AliEmbeddingResponse
// err := json.NewDecoder(resp.Body).Decode(&aliResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// if aliResponse.Code != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: aliResponse.Message,
// Type: aliResponse.Code,
// Param: aliResponse.RequestId,
// Code: aliResponse.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }
// func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
// openAIEmbeddingResponse := OpenAIEmbeddingResponse{
// Object: "list",
// Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
// Model: "text-embedding-v1",
// Usage: Usage{TotalTokens: response.Usage.TotalTokens},
// }
// for _, item := range response.Output.Embeddings {
// openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
// Object: `embedding`,
// Index: item.TextIndex,
// Embedding: item.Embedding,
// })
// }
// return &openAIEmbeddingResponse
// }
// func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
// choice := OpenAITextResponseChoice{
// Index: 0,
// Message: Message{
// Role: "assistant",
// Content: response.Output.Text,
// },
// FinishReason: response.Output.FinishReason,
// }
// fullTextResponse := OpenAITextResponse{
// Id: response.RequestId,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: []OpenAITextResponseChoice{choice},
// Usage: Usage{
// PromptTokens: response.Usage.InputTokens,
// CompletionTokens: response.Usage.OutputTokens,
// TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
// },
// }
// return &fullTextResponse
// }
// func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = aliResponse.Output.Text
// if aliResponse.Output.FinishReason != "null" {
// finishReason := aliResponse.Output.FinishReason
// choice.FinishReason = &finishReason
// }
// response := ChatCompletionsStreamResponse{
// Id: aliResponse.RequestId,
// Object: "chat.completion.chunk",
// Created: common.GetTimestamp(),
// Model: "ernie-bot",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
// func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var usage Usage
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 5 { // ignore blank line or wrong format
// continue
// }
// if data[:5] != "data:" {
// continue
// }
// data = data[5:]
// dataChan <- data
// }
// stopChan <- true
// }()
// setEventStreamHeaders(c)
// lastResponseText := ""
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var aliResponse AliChatResponse
// err := json.Unmarshal([]byte(data), &aliResponse)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// if aliResponse.Usage.OutputTokens != 0 {
// usage.PromptTokens = aliResponse.Usage.InputTokens
// usage.CompletionTokens = aliResponse.Usage.OutputTokens
// usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
// }
// response := streamResponseAli2OpenAI(&aliResponse)
// response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
// lastResponseText = aliResponse.Output.Text
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &usage
// }
// func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var aliResponse AliChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &aliResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if aliResponse.Code != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: aliResponse.Message,
// Type: aliResponse.Code,
// Param: aliResponse.RequestId,
// Code: aliResponse.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseAli2OpenAI(&aliResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }

View File

@@ -0,0 +1,71 @@
package ali
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -0,0 +1,22 @@
package anthropic
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,204 @@
package anthropic
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"io"
"net/http"
"strings"
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
// messages, err := textRequest.TextMessages()
// if err != nil {
// log.Panicf("invalid message type: %T", textRequest.Messages)
// }
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse Response
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,29 @@
package anthropic
type Metadata struct {
UserId string `json:"user_id"`
}
type Request struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error Error `json:"error"`
}

View File

@@ -0,0 +1,22 @@
package baidu
// import (
// "github.com/gin-gonic/gin"
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "net/http"
// )
// type Adaptor struct {
// }
// func (a *Adaptor) Auth(c *gin.Context) error {
// return nil
// }
// func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
// return nil, nil
// }
// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
// return nil, nil, nil
// }

359
relay/channel/baidu/main.go Normal file
View File

@@ -0,0 +1,359 @@
package baidu
// import (
// "bufio"
// "encoding/json"
// "errors"
// "fmt"
// "github.com/gin-gonic/gin"
// "io"
// "net/http"
// "one-api/common"
// "strings"
// "sync"
// "time"
// )
// // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
// type BaiduTokenResponse struct {
// ExpiresIn int `json:"expires_in"`
// AccessToken string `json:"access_token"`
// }
// type BaiduMessage struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
// type BaiduChatRequest struct {
// Messages []BaiduMessage `json:"messages"`
// Stream bool `json:"stream"`
// UserId string `json:"user_id,omitempty"`
// }
// type BaiduError struct {
// ErrorCode int `json:"error_code"`
// ErrorMsg string `json:"error_msg"`
// }
// type BaiduChatResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Result string `json:"result"`
// IsTruncated bool `json:"is_truncated"`
// NeedClearHistory bool `json:"need_clear_history"`
// Usage Usage `json:"usage"`
// BaiduError
// }
// type BaiduChatStreamResponse struct {
// BaiduChatResponse
// SentenceId int `json:"sentence_id"`
// IsEnd bool `json:"is_end"`
// }
// type BaiduEmbeddingRequest struct {
// Input []string `json:"input"`
// }
// type BaiduEmbeddingData struct {
// Object string `json:"object"`
// Embedding []float64 `json:"embedding"`
// Index int `json:"index"`
// }
// type BaiduEmbeddingResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Data []BaiduEmbeddingData `json:"data"`
// Usage Usage `json:"usage"`
// BaiduError
// }
// type BaiduAccessToken struct {
// AccessToken string `json:"access_token"`
// Error string `json:"error,omitempty"`
// ErrorDescription string `json:"error_description,omitempty"`
// ExpiresIn int64 `json:"expires_in,omitempty"`
// ExpiresAt time.Time `json:"-"`
// }
// var baiduTokenStore sync.Map
// func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
// messages := make([]BaiduMessage, 0, len(request.Messages))
// for _, message := range request.Messages {
// if message.Role == "system" {
// messages = append(messages, BaiduMessage{
// Role: "user",
// Content: message.Content,
// })
// messages = append(messages, BaiduMessage{
// Role: "assistant",
// Content: "Okay",
// })
// } else {
// messages = append(messages, BaiduMessage{
// Role: message.Role,
// Content: message.Content,
// })
// }
// }
// return &BaiduChatRequest{
// Messages: messages,
// Stream: request.Stream,
// }
// }
// func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
// choice := OpenAITextResponseChoice{
// Index: 0,
// Message: Message{
// Role: "assistant",
// Content: response.Result,
// },
// FinishReason: "stop",
// }
// fullTextResponse := OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: response.Created,
// Choices: []OpenAITextResponseChoice{choice},
// Usage: response.Usage,
// }
// return &fullTextResponse
// }
// func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = baiduResponse.Result
// if baiduResponse.IsEnd {
// choice.FinishReason = &stopFinishReason
// }
// response := ChatCompletionsStreamResponse{
// Id: baiduResponse.Id,
// Object: "chat.completion.chunk",
// Created: baiduResponse.Created,
// Model: "ernie-bot",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
// func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
// return &BaiduEmbeddingRequest{
// Input: request.ParseInput(),
// }
// }
// func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
// openAIEmbeddingResponse := OpenAIEmbeddingResponse{
// Object: "list",
// Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
// Model: "baidu-embedding",
// Usage: response.Usage,
// }
// for _, item := range response.Data {
// openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
// Object: item.Object,
// Index: item.Index,
// Embedding: item.Embedding,
// })
// }
// return &openAIEmbeddingResponse
// }
// func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var usage Usage
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n"); i >= 0 {
// return i + 1, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// if len(data) < 6 { // ignore blank line or wrong format
// continue
// }
// data = data[6:]
// dataChan <- data
// }
// stopChan <- true
// }()
// setEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// var baiduResponse BaiduChatStreamResponse
// err := json.Unmarshal([]byte(data), &baiduResponse)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// if baiduResponse.Usage.TotalTokens != 0 {
// usage.TotalTokens = baiduResponse.Usage.TotalTokens
// usage.PromptTokens = baiduResponse.Usage.PromptTokens
// usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
// }
// response := streamResponseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, &usage
// }
// func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var baiduResponse BaiduChatResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &baiduResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if baiduResponse.ErrorMsg != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: baiduResponse.ErrorMsg,
// Type: "baidu_error",
// Param: "",
// Code: baiduResponse.ErrorCode,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }
// func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var baiduResponse BaiduEmbeddingResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &baiduResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if baiduResponse.ErrorMsg != "" {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: baiduResponse.ErrorMsg,
// Type: "baidu_error",
// Param: "",
// Code: baiduResponse.ErrorCode,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }
// func getBaiduAccessToken(apiKey string) (string, error) {
// if val, ok := baiduTokenStore.Load(apiKey); ok {
// var accessToken BaiduAccessToken
// if accessToken, ok = val.(BaiduAccessToken); ok {
// // soon this will expire
// if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
// go func() {
// _, _ = getBaiduAccessTokenHelper(apiKey)
// }()
// }
// return accessToken.AccessToken, nil
// }
// }
// accessToken, err := getBaiduAccessTokenHelper(apiKey)
// if err != nil {
// return "", err
// }
// if accessToken == nil {
// return "", errors.New("getBaiduAccessToken return a nil token")
// }
// return (*accessToken).AccessToken, nil
// }
// func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
// parts := strings.Split(apiKey, "|")
// if len(parts) != 2 {
// return nil, errors.New("invalid baidu apikey")
// }
// req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
// parts[0], parts[1]), nil)
// if err != nil {
// return nil, err
// }
// req.Header.Add("Content-Type", "application/json")
// req.Header.Add("Accept", "application/json")
// res, err := impatientHTTPClient.Do(req)
// if err != nil {
// return nil, err
// }
// defer res.Body.Close()
// var accessToken BaiduAccessToken
// err = json.NewDecoder(res.Body).Decode(&accessToken)
// if err != nil {
// return nil, err
// }
// if accessToken.Error != "" {
// return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
// }
// if accessToken.AccessToken == "" {
// return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
// }
// accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
// baiduTokenStore.Store(apiKey, accessToken)
// return &accessToken, nil
// }

View File

@@ -0,0 +1,50 @@
package baidu
// import (
// "github.com/songquanpeng/one-api/relay/channel/openai"
// "time"
// )
// type ChatResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Result string `json:"result"`
// IsTruncated bool `json:"is_truncated"`
// NeedClearHistory bool `json:"need_clear_history"`
// Usage openai.Usage `json:"usage"`
// Error
// }
// type ChatStreamResponse struct {
// ChatResponse
// SentenceId int `json:"sentence_id"`
// IsEnd bool `json:"is_end"`
// }
// type EmbeddingRequest struct {
// Input []string `json:"input"`
// }
// type EmbeddingData struct {
// Object string `json:"object"`
// Embedding []float64 `json:"embedding"`
// Index int `json:"index"`
// }
// type EmbeddingResponse struct {
// Id string `json:"id"`
// Object string `json:"object"`
// Created int64 `json:"created"`
// Data []EmbeddingData `json:"data"`
// Usage openai.Usage `json:"usage"`
// Error
// }
// type AccessToken struct {
// AccessToken string `json:"access_token"`
// Error string `json:"error,omitempty"`
// ErrorDescription string `json:"error_description,omitempty"`
// ExpiresIn int64 `json:"expires_in,omitempty"`
// ExpiresAt time.Time `json:"-"`
// }

View File

@@ -0,0 +1,22 @@
package google
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,312 @@
package google
import (
"bufio"
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
GeminiVisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == openai.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == openai.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
if err != nil {
logger.Warn(context.TODO(),
fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
continue
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
logger.Info(context.TODO(),
fmt.Sprintf("send %d images to gemini-pro-vision", len(parts)))
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: openai.Message{
Role: "assistant",
Content: "",
},
FinishReason: constant.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,80 @@
package google
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []openai.Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}

View File

@@ -0,0 +1,175 @@
package google
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"io"
"net/http"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: openai.Message{
Role: "assistant",
Content: candidate.Content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
logger.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
fullTextResponse.Model = model
completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
usage := openai.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,15 @@
package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor interface {
GetRequestURL() string
Auth(c *gin.Context) error
ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
DoRequest(request *openai.GeneralOpenAIRequest) error
DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
}

View File

@@ -0,0 +1,21 @@
package openai
import (
"github.com/gin-gonic/gin"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,6 @@
package openai
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)

View File

@@ -0,0 +1,145 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"io"
"net/http"
"strings"
)
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case constant.RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
case constant.RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
var textResponse SlimTextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += CountTokenText(choice.Message.StringContent(), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
return nil, &textResponse.Usage
}

View File

@@ -0,0 +1,288 @@
package openai
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type WhisperJSONResponse struct {
Text string `json:"text,omitempty"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type UsageOrResponseText struct {
*Usage
ResponseText string
}
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type ErrorWithStatusCode struct {
Error
StatusCode int `json:"status_code"`
}
type SlimTextResponse struct {
Choices []TextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error Error `json:"error"`
}
type TextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type TextResponse struct {
Id string `json:"id"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []TextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type EmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type EmbeddingResponse struct {
Object string `json:"object"`
Data []EmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}

View File

@@ -0,0 +1,254 @@
package openai
import (
"errors"
"fmt"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"math"
"strings"
)
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
logger.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
}
logger.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return defaultTokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
func CountTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
switch v := message.Content.(type) {
case string:
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
switch m["type"] {
case "text":
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
case "image_url":
imageUrl, ok := m["image_url"].(map[string]any)
if ok {
url := imageUrl["url"].(string)
detail := ""
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
}
}
}
}
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
// func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
// tokenEncoder := getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
// if model == "gpt-3.5-turbo-0301" {
// tokensPerMessage = 4
// tokensPerName = -1 // If there's a name, the role is omitted
// } else {
// tokensPerMessage = 3
// tokensPerName = 1
// }
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// for _, cnt := range message.Content {
// switch cnt.Type {
// case OpenaiVisionMessageContentTypeText:
// tokenNum += getTokenNum(tokenEncoder, cnt.Text)
// case OpenaiVisionMessageContentTypeImageUrl:
// imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
// if err != nil {
// return 0, errors.Wrap(err, "failed to decode base64 image")
// }
// if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
// return 0, errors.Wrap(err, "failed to count vision image token")
// } else {
// tokenNum += imgtoken
// }
// }
// }
// tokenNum += getTokenNum(tokenEncoder, message.Role)
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum, nil
// }
const (
lowDetailCost = 85
highDetailCostPerTile = 170
additionalCost = 85
)
// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) {
var fetchSize = true
var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
// According to the official guide, "low" disable the high-res model,
// and only receive low-res 512px x 512px version of the image, indicating
// that image is treated as low-res when size is smaller than 512px x 512px,
// then we can assume that image size larger than 512px x 512px is treated
// as high-res. Then we have the following logic:
// if detail == "" || detail == "auto" {
// width, height, err = image.GetImageSize(url)
// if err != nil {
// return 0, err
// }
// fetchSize = false
// // not sure if this is correct
// if width > 512 || height > 512 {
// detail = "high"
// } else {
// detail = "low"
// }
// }
// However, in my test, it seems to be always the same as "high".
// The following image, which is 125x50, is still treated as high-res, taken
// 255 tokens in the response of non-stream chat completion api.
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
if detail == "" || detail == "auto" {
// assume by test, not sure if this is correct
detail = "high"
}
switch detail {
case "low":
return lowDetailCost, nil
case "high":
if fetchSize {
width, height, err = image.GetImageSize(url)
if err != nil {
return 0, err
}
}
if width > 2048 || height > 2048 { // max(width, height) > 2048
ratio := float64(2048) / math.Max(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
if width > 768 && height > 768 { // min(width, height) > 768
ratio := float64(768) / math.Min(float64(width), float64(height))
width = int(float64(width) * ratio)
height = int(float64(height) * ratio)
}
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
result := numSquares*highDetailCostPerTile + additionalCost
return result, nil
default:
return 0, errors.New("invalid detail option")
}
}
func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
}
return 0
}
func CountTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}

View File

@@ -0,0 +1,13 @@
package openai
func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
Error := Error{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &ErrorWithStatusCode{
Error: Error,
StatusCode: statusCode,
}
}

View File

@@ -0,0 +1,22 @@
package tencent
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,234 @@
package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"io"
"net/http"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, Message{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &ChatRequest{
Timestamp: helper.GetTimestamp(),
Expired: helper.GetTimestamp() + 24*60*60,
QueryID: helper.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Object: "chat.completion",
Created: helper.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := openai.TextResponseChoice{
Index: 0,
Message: openai.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &constant.StopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
var TencentResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &openai.ErrorWithStatusCode{
Error: openai.Error{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse.Model = "hunyuan"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func GetSign(req ChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

View File

@@ -0,0 +1,63 @@
package tencent
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []Message `json:"messages"`
}
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type ResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages Message `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta Message `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type ChatResponse struct {
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage openai.Usage `json:"usage,omitempty"` // token 数量
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}

View File

@@ -0,0 +1,22 @@
package xunfei
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

View File

@@ -0,0 +1,306 @@
package xunfei
// import (
// "crypto/hmac"
// "crypto/sha256"
// "encoding/base64"
// "encoding/json"
// "fmt"
// "github.com/gin-gonic/gin"
// "github.com/gorilla/websocket"
// "io"
// "net/http"
// "net/url"
// "one-api/common"
// "strings"
// "time"
// )
// // https://console.xfyun.cn/services/cbm
// // https://www.xfyun.cn/doc/spark/Web.html
// type XunfeiMessage struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
// type XunfeiChatRequest struct {
// Header struct {
// AppId string `json:"app_id"`
// } `json:"header"`
// Parameter struct {
// Chat struct {
// Domain string `json:"domain,omitempty"`
// Temperature float64 `json:"temperature,omitempty"`
// TopK int `json:"top_k,omitempty"`
// MaxTokens int `json:"max_tokens,omitempty"`
// Auditing bool `json:"auditing,omitempty"`
// } `json:"chat"`
// } `json:"parameter"`
// Payload struct {
// Message struct {
// Text []XunfeiMessage `json:"text"`
// } `json:"message"`
// } `json:"payload"`
// }
// type XunfeiChatResponseTextItem struct {
// Content string `json:"content"`
// Role string `json:"role"`
// Index int `json:"index"`
// }
// type XunfeiChatResponse struct {
// Header struct {
// Code int `json:"code"`
// Message string `json:"message"`
// Sid string `json:"sid"`
// Status int `json:"status"`
// } `json:"header"`
// Payload struct {
// Choices struct {
// Status int `json:"status"`
// Seq int `json:"seq"`
// Text []XunfeiChatResponseTextItem `json:"text"`
// } `json:"choices"`
// Usage struct {
// //Text struct {
// // QuestionTokens string `json:"question_tokens"`
// // PromptTokens string `json:"prompt_tokens"`
// // CompletionTokens string `json:"completion_tokens"`
// // TotalTokens string `json:"total_tokens"`
// //} `json:"text"`
// Text Usage `json:"text"`
// } `json:"usage"`
// } `json:"payload"`
// }
// func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
// messages := make([]XunfeiMessage, 0, len(request.Messages))
// for _, message := range request.Messages {
// if message.Role == "system" {
// messages = append(messages, XunfeiMessage{
// Role: "user",
// Content: message.Content,
// })
// messages = append(messages, XunfeiMessage{
// Role: "assistant",
// Content: "Okay",
// })
// } else {
// messages = append(messages, XunfeiMessage{
// Role: message.Role,
// Content: message.Content,
// })
// }
// }
// xunfeiRequest := XunfeiChatRequest{}
// xunfeiRequest.Header.AppId = xunfeiAppId
// xunfeiRequest.Parameter.Chat.Domain = domain
// xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
// xunfeiRequest.Parameter.Chat.TopK = request.N
// xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
// xunfeiRequest.Payload.Message.Text = messages
// return &xunfeiRequest
// }
// func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
// if len(response.Payload.Choices.Text) == 0 {
// response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
// {
// Content: "",
// },
// }
// }
// choice := OpenAITextResponseChoice{
// Index: 0,
// Message: Message{
// Role: "assistant",
// Content: response.Payload.Choices.Text[0].Content,
// },
// FinishReason: stopFinishReason,
// }
// fullTextResponse := OpenAITextResponse{
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: []OpenAITextResponseChoice{choice},
// Usage: response.Payload.Usage.Text,
// }
// return &fullTextResponse
// }
// func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
// if len(xunfeiResponse.Payload.Choices.Text) == 0 {
// xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
// {
// Content: "",
// },
// }
// }
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
// if xunfeiResponse.Payload.Choices.Status == 2 {
// choice.FinishReason = &stopFinishReason
// }
// response := ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: common.GetTimestamp(),
// Model: "SparkDesk",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
// func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
// HmacWithShaToBase64 := func(algorithm, data, key string) string {
// mac := hmac.New(sha256.New, []byte(key))
// mac.Write([]byte(data))
// encodeData := mac.Sum(nil)
// return base64.StdEncoding.EncodeToString(encodeData)
// }
// ul, err := url.Parse(hostUrl)
// if err != nil {
// fmt.Println(err)
// }
// date := time.Now().UTC().Format(time.RFC1123)
// signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
// sign := strings.Join(signString, "\n")
// sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
// authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
// "hmac-sha256", "host date request-line", sha)
// authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
// v := url.Values{}
// v.Add("host", ul.Host)
// v.Add("date", date)
// v.Add("authorization", authorization)
// callUrl := hostUrl + "?" + v.Encode()
// return callUrl
// }
// func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
// domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
// dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
// if err != nil {
// return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
// }
// setEventStreamHeaders(c)
// var usage Usage
// c.Stream(func(w io.Writer) bool {
// select {
// case xunfeiResponse := <-dataChan:
// usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
// usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
// usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
// response := streamResponseXunfei2OpenAI(&xunfeiResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// return nil, &usage
// }
// func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
// domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
// dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
// if err != nil {
// return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
// }
// var usage Usage
// var content string
// var xunfeiResponse XunfeiChatResponse
// stop := false
// for !stop {
// select {
// case xunfeiResponse = <-dataChan:
// if len(xunfeiResponse.Payload.Choices.Text) == 0 {
// continue
// }
// content += xunfeiResponse.Payload.Choices.Text[0].Content
// usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
// usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
// usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
// case stop = <-stopChan:
// }
// }
// xunfeiResponse.Payload.Choices.Text[0].Content = content
// response := responseXunfei2OpenAI(&xunfeiResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// _, _ = c.Writer.Write(jsonResponse)
// return nil, &usage
// }
// func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
// d := websocket.Dialer{
// HandshakeTimeout: 5 * time.Second,
// }
// conn, resp, err := d.Dial(authUrl, nil)
// if err != nil || resp.StatusCode != 101 {
// return nil, nil, err
// }
// data := requestOpenAI2Xunfei(textRequest, appId, domain)
// err = conn.WriteJSON(data)
// if err != nil {
// return nil, nil, err
// }
// dataChan := make(chan XunfeiChatResponse)
// stopChan := make(chan bool)
// go func() {
// for {
// _, msg, err := conn.ReadMessage()
// if err != nil {
// common.SysError("error reading stream response: " + err.Error())
// break
// }
// var response XunfeiChatResponse
// err = json.Unmarshal(msg, &response)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// break
// }
// dataChan <- response
// if response.Payload.Choices.Status == 2 {
// err := conn.Close()
// if err != nil {
// common.SysError("error closing websocket connection: " + err.Error())
// }
// break
// }
// }
// stopChan <- true
// }()
// return dataChan, stopChan, nil
// }
// func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
// query := c.Request.URL.Query()
// apiVersion := query.Get("api-version")
// if apiVersion == "" {
// apiVersion = c.GetString("api_version")
// }
// if apiVersion == "" {
// apiVersion = "v1.1"
// common.SysLog("api_version not found, use default: " + apiVersion)
// }
// domain := "general"
// if apiVersion != "v1.1" {
// domain += strings.Split(apiVersion, ".")[0]
// }
// authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
// return domain, authUrl
// }

View File

@@ -0,0 +1,61 @@
package xunfei
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type ChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type ChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []ChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text openai.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}

View File

@@ -0,0 +1,22 @@
package zhipu
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/openai"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Auth(c *gin.Context) error {
return nil
}
func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
return nil, nil, nil
}

301
relay/channel/zhipu/main.go Normal file
View File

@@ -0,0 +1,301 @@
package zhipu
// import (
// "bufio"
// "encoding/json"
// "github.com/gin-gonic/gin"
// "github.com/golang-jwt/jwt"
// "io"
// "net/http"
// "one-api/common"
// "strings"
// "sync"
// "time"
// )
// // https://open.bigmodel.cn/doc/api#chatglm_std
// // chatglm_std, chatglm_lite
// // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
// type ZhipuMessage struct {
// Role string `json:"role"`
// Content string `json:"content"`
// }
// type ZhipuRequest struct {
// Prompt []ZhipuMessage `json:"prompt"`
// Temperature float64 `json:"temperature,omitempty"`
// TopP float64 `json:"top_p,omitempty"`
// RequestId string `json:"request_id,omitempty"`
// Incremental bool `json:"incremental,omitempty"`
// }
// type ZhipuResponseData struct {
// TaskId string `json:"task_id"`
// RequestId string `json:"request_id"`
// TaskStatus string `json:"task_status"`
// Choices []ZhipuMessage `json:"choices"`
// Usage `json:"usage"`
// }
// type ZhipuResponse struct {
// Code int `json:"code"`
// Msg string `json:"msg"`
// Success bool `json:"success"`
// Data ZhipuResponseData `json:"data"`
// }
// type ZhipuStreamMetaResponse struct {
// RequestId string `json:"request_id"`
// TaskId string `json:"task_id"`
// TaskStatus string `json:"task_status"`
// Usage `json:"usage"`
// }
// type zhipuTokenData struct {
// Token string
// ExpiryTime time.Time
// }
// var zhipuTokens sync.Map
// var expSeconds int64 = 24 * 3600
// func getZhipuToken(apikey string) string {
// data, ok := zhipuTokens.Load(apikey)
// if ok {
// tokenData := data.(zhipuTokenData)
// if time.Now().Before(tokenData.ExpiryTime) {
// return tokenData.Token
// }
// }
// split := strings.Split(apikey, ".")
// if len(split) != 2 {
// common.SysError("invalid zhipu key: " + apikey)
// return ""
// }
// id := split[0]
// secret := split[1]
// expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
// expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
// timestamp := time.Now().UnixNano() / 1e6
// payload := jwt.MapClaims{
// "api_key": id,
// "exp": expMillis,
// "timestamp": timestamp,
// }
// token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
// token.Header["alg"] = "HS256"
// token.Header["sign_type"] = "SIGN"
// tokenString, err := token.SignedString([]byte(secret))
// if err != nil {
// return ""
// }
// zhipuTokens.Store(apikey, zhipuTokenData{
// Token: tokenString,
// ExpiryTime: expiryTime,
// })
// return tokenString
// }
// func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
// messages := make([]ZhipuMessage, 0, len(request.Messages))
// for _, message := range request.Messages {
// if message.Role == "system" {
// messages = append(messages, ZhipuMessage{
// Role: "system",
// Content: message.Content,
// })
// messages = append(messages, ZhipuMessage{
// Role: "user",
// Content: "Okay",
// })
// } else {
// messages = append(messages, ZhipuMessage{
// Role: message.Role,
// Content: message.Content,
// })
// }
// }
// return &ZhipuRequest{
// Prompt: messages,
// Temperature: request.Temperature,
// TopP: request.TopP,
// Incremental: false,
// }
// }
// func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
// fullTextResponse := OpenAITextResponse{
// Id: response.Data.TaskId,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
// Usage: response.Data.Usage,
// }
// for i, choice := range response.Data.Choices {
// openaiChoice := OpenAITextResponseChoice{
// Index: i,
// Message: Message{
// Role: choice.Role,
// Content: strings.Trim(choice.Content, "\""),
// },
// FinishReason: "",
// }
// if i == len(response.Data.Choices)-1 {
// openaiChoice.FinishReason = "stop"
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// }
// return &fullTextResponse
// }
// func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = zhipuResponse
// response := ChatCompletionsStreamResponse{
// Object: "chat.completion.chunk",
// Created: common.GetTimestamp(),
// Model: "chatglm",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response
// }
// func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
// var choice ChatCompletionsStreamResponseChoice
// choice.Delta.Content = ""
// choice.FinishReason = &stopFinishReason
// response := ChatCompletionsStreamResponse{
// Id: zhipuResponse.RequestId,
// Object: "chat.completion.chunk",
// Created: common.GetTimestamp(),
// Model: "chatglm",
// Choices: []ChatCompletionsStreamResponseChoice{choice},
// }
// return &response, &zhipuResponse.Usage
// }
// func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var usage *Usage
// scanner := bufio.NewScanner(resp.Body)
// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// if atEOF && len(data) == 0 {
// return 0, nil, nil
// }
// if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
// return i + 2, data[0:i], nil
// }
// if atEOF {
// return len(data), data, nil
// }
// return 0, nil, nil
// })
// dataChan := make(chan string)
// metaChan := make(chan string)
// stopChan := make(chan bool)
// go func() {
// for scanner.Scan() {
// data := scanner.Text()
// lines := strings.Split(data, "\n")
// for i, line := range lines {
// if len(line) < 5 {
// continue
// }
// if line[:5] == "data:" {
// dataChan <- line[5:]
// if i != len(lines)-1 {
// dataChan <- "\n"
// }
// } else if line[:5] == "meta:" {
// metaChan <- line[5:]
// }
// }
// }
// stopChan <- true
// }()
// setEventStreamHeaders(c)
// c.Stream(func(w io.Writer) bool {
// select {
// case data := <-dataChan:
// response := streamResponseZhipu2OpenAI(data)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case data := <-metaChan:
// var zhipuResponse ZhipuStreamMetaResponse
// err := json.Unmarshal([]byte(data), &zhipuResponse)
// if err != nil {
// common.SysError("error unmarshalling stream response: " + err.Error())
// return true
// }
// response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
// jsonResponse, err := json.Marshal(response)
// if err != nil {
// common.SysError("error marshalling stream response: " + err.Error())
// return true
// }
// usage = zhipuUsage
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// return true
// case <-stopChan:
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
// return false
// }
// })
// err := resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// return nil, usage
// }
// func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
// var zhipuResponse ZhipuResponse
// responseBody, err := io.ReadAll(resp.Body)
// if err != nil {
// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
// }
// err = resp.Body.Close()
// if err != nil {
// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
// }
// err = json.Unmarshal(responseBody, &zhipuResponse)
// if err != nil {
// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
// }
// if !zhipuResponse.Success {
// return &OpenAIErrorWithStatusCode{
// OpenAIError: OpenAIError{
// Message: zhipuResponse.Msg,
// Type: "zhipu_error",
// Param: "",
// Code: zhipuResponse.Code,
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
// fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
// jsonResponse, err := json.Marshal(fullTextResponse)
// if err != nil {
// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
// }
// c.Writer.Header().Set("Content-Type", "application/json")
// c.Writer.WriteHeader(resp.StatusCode)
// _, err = c.Writer.Write(jsonResponse)
// return nil, &fullTextResponse.Usage
// }

View File

@@ -0,0 +1,46 @@
package zhipu
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"time"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Request struct {
Prompt []Message `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []Message `json:"choices"`
openai.Usage `json:"usage"`
}
type Response struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ResponseData `json:"data"`
}
type StreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
openai.Usage `json:"usage"`
}
type tokenData struct {
Token string
ExpiryTime time.Time
}