Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-10-29 01:15:54 +00:00
19 changed files with 112 additions and 37 deletions

View File

@@ -6,4 +6,5 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
}

View File

@@ -4,18 +4,19 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -27,6 +28,11 @@ const (
VisionMaxImageNum = 16
)
var mimeTypeMap = map[string]string{
"json_object": "application/json",
"text": "text/plain",
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
@@ -55,6 +61,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.ResponseFormat != nil {
if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
geminiRequest.GenerationConfig.ResponseMimeType = mimeType
}
if textRequest.ResponseFormat.JsonSchema != nil {
geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
}
}
if textRequest.Tools != nil {
functions := make([]model.Function, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {

View File

@@ -65,10 +65,12 @@ type ChatTools struct {
}
type ChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}

View File

@@ -4,14 +4,21 @@ package groq
var ModelList = []string{
"gemma-7b-it",
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"gemma2-9b-it",
"llama-3.1-405b-reasoning",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"llama-3.2-11b-text-preview",
"llama-3.2-11b-vision-preview",
"llama-3.2-1b-preview",
"llama-3.2-3b-preview",
"llama-3.2-90b-text-preview",
"llama-guard-3-8b",
"llama3-70b-8192",
"llama3-8b-8192",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llava-v1.5-7b-4096-preview",
"mixtral-8x7b-32768",
"distil-whisper-large-v3-en",
"whisper-large-v3",
}

View File

@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
if request.Stream {
// always return usage in stream mode
if request.StreamOptions == nil {
request.StreamOptions = &model.StreamOptions{}
}
request.StreamOptions.IncludeUsage = true
}
return request, nil
}

View File

@@ -7,5 +7,6 @@ var ModelList = []string{
"SparkDesk-v3.1",
"SparkDesk-v3.1-128K",
"SparkDesk-v3.5",
"SparkDesk-v3.5-32K",
"SparkDesk-v4.0",
}

View File

@@ -292,6 +292,8 @@ func apiVersion2domain(apiVersion string) string {
return "pro-128k"
case "v3.5":
return "generalv3.5"
case "v3.5-32K":
return "max-32k"
case "v4.0":
return "4.0Ultra"
}
@@ -303,7 +305,10 @@ func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (strin
domain := apiVersion2domain(apiVersion)
switch apiVersion {
case "v3.1-128K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/pro-128k", apiVersion), apiKey, apiSecret)
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
break
case "v3.5-32K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
break
default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)