Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai 2024-07-10 06:25:26 +00:00
commit 65022b0e3e
2 changed files with 27 additions and 4 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
@ -28,14 +29,32 @@ func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta a.meta = meta
} }
// WorkerAI cannot be used across accounts with AIGateWay
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
func (a *Adaptor) isAIGateWay(baseURL string) bool {
return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
isAIGateWay := a.isAIGateWay(meta.BaseURL)
var urlPrefix string
if isAIGateWay {
urlPrefix = meta.BaseURL
} else {
urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
}
switch meta.Mode { switch meta.Mode {
case relaymode.ChatCompletions: case relaymode.ChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
case relaymode.Embeddings: case relaymode.Embeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
default: default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil if isAIGateWay {
return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
}
return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
} }
} }

View File

@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format if len(data) < dataPrefixLength { // ignore blank line or wrong format
@ -41,6 +42,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
} }
if strings.HasPrefix(data[dataPrefixLength:], done) { if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data) render.StringData(c, data)
doneRendered = true
continue continue
} }
switch relayMode { switch relayMode {
@ -81,7 +83,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
if !doneRendered {
render.Done(c) render.Done(c)
}
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {