merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-08-01 15:12:54 +08:00
18 changed files with 184 additions and 89 deletions

View File

@@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
if resp != nil {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
}

View File

@@ -1,6 +1,7 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",

View File

@@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)

View File

@@ -83,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
continue
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
content.Parts = parts

View File

@@ -3,14 +3,18 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type OllamaEmbeddingRequest struct {
@@ -21,6 +25,3 @@ type OllamaEmbeddingRequest struct {
type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"`
}
//type OllamaOptions struct {
//}

View File

@@ -28,14 +28,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
}

View File

@@ -22,7 +22,7 @@ import (
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
responseId := ""
var responseId string
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
@@ -86,7 +86,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
if err == nil {
if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) {
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
containStreamUsage = true
usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
@@ -109,14 +115,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -133,14 +134,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {