feat: add tool choice configuration and update stream handling in Gemini

update ToolConfig to use pointer type in ChatRequest
This commit is contained in:
mxdlzg 2024-12-23 12:58:57 +08:00
parent a858292b54
commit ba50a137ea
2 changed files with 50 additions and 9 deletions

View File

@ -34,6 +34,12 @@ var mimeTypeMap = map[string]string{
"text": "text/plain", "text": "text/plain",
} }
// toolChoiceTypeMap translates the OpenAI tool_choice string values
// ("none"/"auto"/"required") into the corresponding Gemini
// FunctionCallingConfig mode enum names (NONE/AUTO/ANY).
// NOTE(review): a lookup with any other key yields "" — callers must
// handle the missing-key case explicitly.
var toolChoiceTypeMap = map[string]string{
"none": "NONE",
"auto": "AUTO",
"required": "ANY",
}
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{ geminiRequest := ChatRequest{
@ -92,6 +98,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
}, },
} }
} }
// Map the OpenAI tool_choice option onto Gemini's function-calling
// configuration (FunctionCallingConfig mode + optional name allowlist).
if textRequest.ToolChoice != nil {
	geminiRequest.ToolConfig = &ToolConfig{
		FunctionCallingConfig: FunctionCallingConfig{
			// Default to "AUTO" — the valid Gemini enum value — so an
			// unrecognized tool_choice shape still yields a legal request.
			// (The previous lowercase "auto" is not a valid mode.)
			Mode: "AUTO",
		},
	}
	switch mode := textRequest.ToolChoice.(type) {
	case string:
		// "none"/"auto"/"required" -> NONE/AUTO/ANY. Unknown strings keep
		// the AUTO default instead of sending an empty mode (the bare map
		// lookup previously produced "" for unknown keys).
		if geminiMode, ok := toolChoiceTypeMap[mode]; ok {
			geminiRequest.ToolConfig.FunctionCallingConfig.Mode = geminiMode
		}
	case map[string]interface{}:
		// Object form {"type":"function","function":{"name":...}} forces a
		// call to one specific function: mode ANY restricted to that name.
		geminiRequest.ToolConfig.FunctionCallingConfig.Mode = "ANY"
		if fn, ok := mode["function"].(map[string]interface{}); ok {
			if name, ok := fn["name"].(string); ok {
				geminiRequest.ToolConfig.FunctionCallingConfig.AllowedFunctionNames = []string{name}
			}
		}
	}
}
shouldAddDummyModelMessage := false shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
content := ChatContent{ content := ChatContent{
@ -186,10 +210,16 @@ func (g *ChatResponse) GetResponseText() string {
if g == nil { if g == nil {
return "" return ""
} }
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { var builder strings.Builder
return g.Candidates[0].Content.Parts[0].Text for _, candidate := range g.Candidates {
for idx, part := range candidate.Content.Parts {
if idx > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
} }
return "" return builder.String()
} }
type ChatCandidate struct { type ChatCandidate struct {
@ -252,8 +282,8 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
choice.Message.ToolCalls = getToolCalls(&candidate) choice.Message.ToolCalls = getToolCalls(&candidate)
} else { } else {
var builder strings.Builder var builder strings.Builder
for _, part := range candidate.Content.Parts { for idx, part := range candidate.Content.Parts {
if i > 0 { if idx > 0 {
builder.WriteString("\n") builder.WriteString("\n")
} }
builder.WriteString(part.Text) builder.WriteString(part.Text)

View File

@ -1,10 +1,12 @@
package gemini package gemini
type ChatRequest struct { type ChatRequest struct {
Contents []ChatContent `json:"contents"` Contents []ChatContent `json:"contents"`
SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"` SystemInstruction *ChatContent `json:"system_instruction,omitempty"`
GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"` SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
Tools []ChatTools `json:"tools,omitempty"` GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
Tools []ChatTools `json:"tools,omitempty"`
ToolConfig *ToolConfig `json:"tool_config,omitempty"`
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
@ -74,3 +76,12 @@ type ChatGenerationConfig struct {
CandidateCount int `json:"candidateCount,omitempty"` CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"` StopSequences []string `json:"stopSequences,omitempty"`
} }
// FunctionCallingConfig controls how the model may invoke the declared
// functions. Mode carries one of the Gemini mode enum names (NONE, AUTO,
// ANY — see toolChoiceTypeMap); AllowedFunctionNames restricts which
// functions may be called when a restricted mode is used.
// NOTE(review): Gemini's REST reference shows camelCase field names
// (mode/allowedFunctionNames); proto3 JSON parsing also accepts the
// original snake_case names — confirm against the target endpoint.
type FunctionCallingConfig struct {
Mode string `json:"mode,omitempty"`
AllowedFunctionNames []string `json:"allowed_function_names,omitempty"`
}
// ToolConfig wraps the function-calling configuration attached to a
// ChatRequest (serialized as the tool_config field; always present in the
// JSON since the inner field carries no omitempty).
type ToolConfig struct {
FunctionCallingConfig FunctionCallingConfig `json:"function_calling_config"`
}