mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 17:16:38 +08:00
Merge remote-tracking branch 'origin/upstream/main'
This commit is contained in:
commit
65022b0e3e
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
@ -28,14 +29,32 @@ func (a *Adaptor) Init(meta *meta.Meta) {
|
|||||||
a.meta = meta
|
a.meta = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WorkerAI cannot be used across accounts with AIGateWay
|
||||||
|
// https://developers.cloudflare.com/ai-gateway/providers/workersai/#openai-compatible-endpoints
|
||||||
|
// https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/workers-ai
|
||||||
|
func (a *Adaptor) isAIGateWay(baseURL string) bool {
|
||||||
|
return strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") && strings.HasSuffix(baseURL, "/workers-ai")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
|
isAIGateWay := a.isAIGateWay(meta.BaseURL)
|
||||||
|
var urlPrefix string
|
||||||
|
if isAIGateWay {
|
||||||
|
urlPrefix = meta.BaseURL
|
||||||
|
} else {
|
||||||
|
urlPrefix = fmt.Sprintf("%s/client/v4/accounts/%s/ai", meta.BaseURL, meta.Config.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
switch meta.Mode {
|
switch meta.Mode {
|
||||||
case relaymode.ChatCompletions:
|
case relaymode.ChatCompletions:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
|
return fmt.Sprintf("%s/v1/chat/completions", urlPrefix), nil
|
||||||
case relaymode.Embeddings:
|
case relaymode.Embeddings:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
|
return fmt.Sprintf("%s/v1/embeddings", urlPrefix), nil
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
|
if isAIGateWay {
|
||||||
|
return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
|||||||
|
|
||||||
common.SetEventStreamHeaders(c)
|
common.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
doneRendered := false
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
if len(data) < dataPrefixLength { // ignore blank line or wrong format
|
if len(data) < dataPrefixLength { // ignore blank line or wrong format
|
||||||
@ -41,6 +42,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
|||||||
}
|
}
|
||||||
if strings.HasPrefix(data[dataPrefixLength:], done) {
|
if strings.HasPrefix(data[dataPrefixLength:], done) {
|
||||||
render.StringData(c, data)
|
render.StringData(c, data)
|
||||||
|
doneRendered = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
@ -81,7 +83,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
|||||||
logger.SysError("error reading stream: " + err.Error())
|
logger.SysError("error reading stream: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
render.Done(c)
|
if !doneRendered {
|
||||||
|
render.Done(c)
|
||||||
|
}
|
||||||
|
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user