diff --git a/controller/channel-test.go b/controller/channel-test.go index ef54a19..8f596f9 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -24,6 +23,7 @@ import ( "sync" "time" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" ) @@ -37,9 +37,15 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) + baseUrl := channel.GetBaseURL() + chatPath := "/v1/chat/completions" + if strings.HasPrefix(baseUrl, "https://models.inference.ai.azure.com") { + chatPath = "/chat/completions" + } + common.SysLog(fmt.Sprintf("testing channel %d with model %s path %s", channel.Id, testModel, chatPath)) c.Request = &http.Request{ Method: "POST", - URL: &url.URL{Path: "/v1/chat/completions"}, + URL: &url.URL{Path: chatPath}, Body: nil, Header: make(http.Header), } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6090b45..566175a 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "mime/multipart" "net/http" @@ -19,6 +18,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -47,6 +48,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return minimax.GetRequestURL(info) case common.ChannelTypeCustom: url := info.BaseUrl + if strings.HasPrefix(url, "https://models.inference.ai.azure.com") { + url = strings.TrimPrefix(url, "/v1") + if info.RelayMode == constant.RelayModeCompletions { + url = fmt.Sprintf("%s/%s", url, "chat/completions") + } else if info.RelayMode == constant.RelayModeEmbeddings { + url = fmt.Sprintf("%s/%s", url, "embeddings") + } + } url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil default: diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 6006bc6..044a2b6 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -46,6 +46,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeChatCompletions } else if strings.HasPrefix(path, "/v1/completions") { relayMode = RelayModeCompletions + } else if strings.HasPrefix(path, "/chat/completions") { + relayMode = RelayModeCompletions } else if strings.HasPrefix(path, "/v1/embeddings") { relayMode = RelayModeEmbeddings } else if strings.HasSuffix(path, "embeddings") {