feat: support llm chat on replicate

Laisky.Cai
2024-12-19 03:14:32 +00:00
parent 4c86e7c506
commit 502cf3315d
9 changed files with 376 additions and 28 deletions


@@ -6,6 +6,7 @@ import (
 	"io"
 	"net/http"
 	"slices"
+	"strings"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -80,8 +81,57 @@ func convertImageRemixRequest(c *gin.Context) (any, error) {
 	return rawReq.toFluxRemixRequest()
 }
 
 // ConvertRequest converts the request to the format that the target API expects.
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	if !request.Stream {
+		// TODO: support non-stream mode
+		return nil, errors.Errorf("replicate models only support stream mode for now, please set stream=true")
+	}
+
+	// Build the prompt from the OpenAI messages as "role: content" lines.
+	var promptBuilder strings.Builder
+	for _, message := range request.Messages {
+		switch msgCnt := message.Content.(type) {
+		case string:
+			promptBuilder.WriteString(message.Role)
+			promptBuilder.WriteString(": ")
+			promptBuilder.WriteString(msgCnt)
+			promptBuilder.WriteString("\n")
+		default:
+			// Skip non-string contents (e.g. multimodal parts) for now.
+		}
+	}
+
+	replicateRequest := ReplicateChatRequest{
+		Input: ChatInput{
+			Prompt:           promptBuilder.String(),
+			MaxTokens:        request.MaxTokens,
+			Temperature:      1.0,
+			TopP:             1.0,
+			PresencePenalty:  0.0,
+			FrequencyPenalty: 0.0,
+		},
+	}
+
+	// Map optional fields
+	if request.Temperature != nil {
+		replicateRequest.Input.Temperature = *request.Temperature
+	}
+	if request.TopP != nil {
+		replicateRequest.Input.TopP = *request.TopP
+	}
+	if request.PresencePenalty != nil {
+		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+	}
+	if request.FrequencyPenalty != nil {
+		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+	}
+	if request.MaxTokens > 0 {
+		replicateRequest.Input.MaxTokens = request.MaxTokens
+	} else if request.MaxTokens == 0 {
+		replicateRequest.Input.MaxTokens = 500
+	}
+
+	return replicateRequest, nil
 }
 
 func (a *Adaptor) Init(meta *meta.Meta) {
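For reference, a minimal standalone sketch (not part of the commit) of the prompt flattening that ConvertRequest performs; the message type here is a hypothetical stand-in for the relay's model.Message, which carries more fields:

package main

import (
	"fmt"
	"strings"
)

// message is a stand-in for model.Message; Content is `any` because the
// OpenAI schema allows either a string or structured multimodal parts.
type message struct {
	Role    string
	Content any
}

func main() {
	msgs := []message{
		{Role: "system", Content: "You are a terse assistant."},
		{Role: "user", Content: "Say hi."},
	}

	// Same flattening as ConvertRequest above: each string message becomes
	// one "role: content" line; non-string contents are skipped.
	var b strings.Builder
	for _, m := range msgs {
		if s, ok := m.Content.(string); ok {
			b.WriteString(m.Role)
			b.WriteString(": ")
			b.WriteString(s)
			b.WriteString("\n")
		}
	}
	fmt.Print(b.String())
	// Output:
	// system: You are a terse assistant.
	// user: Say hi.
}

The whole conversation is sent to Replicate as a single prompt string, since Replicate's language-model inputs take one prompt rather than a structured message list.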
@@ -103,7 +153,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
-	logger.Info(c, "send image request to replicate")
+	logger.Info(c, "send request to replicate")
 	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
@@ -112,6 +162,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 	case relaymode.ImagesGenerations,
 		relaymode.ImagesEdits:
 		err, usage = ImageHandler(c, resp)
+	case relaymode.ChatCompletions:
+		err, usage = ChatHandler(c, resp)
 	default:
 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 	}
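For illustration only, a sketch of the wire format the converted request could serialize to. The struct definitions and snake_case JSON tags below are assumptions modeled on Replicate's llama-style input names (prompt, max_tokens, temperature, top_p, presence_penalty, frequency_penalty); the commit's real ReplicateChatRequest and ChatInput types are defined elsewhere in this change and may differ:

package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical mirrors of the adaptor's types; tag names are assumptions.
type ChatInput struct {
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	Temperature      float64 `json:"temperature"`
	TopP             float64 `json:"top_p"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

type ReplicateChatRequest struct {
	Input ChatInput `json:"input"`
}

func main() {
	// Mirrors ConvertRequest's defaults: 500 max tokens when the client
	// sends none, temperature and top_p defaulting to 1.0.
	req := ReplicateChatRequest{Input: ChatInput{
		Prompt:      "system: You are a terse assistant.\nuser: Say hi.\n",
		MaxTokens:   500,
		Temperature: 1.0,
		TopP:        1.0,
	}}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
	// Prints a body of the shape:
	// {"input": {"prompt": "...", "max_tokens": 500, "temperature": 1, ...}}
}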