mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
feat: first response time support gemini and claude
This commit is contained in:
parent
c834289f2c
commit
f2654692e8
@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
@ -9,8 +9,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func stopReasonClaude2OpenAI(reason string) string {
|
func stopReasonClaude2OpenAI(reason string) string {
|
||||||
@ -246,7 +248,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
|||||||
return &fullTextResponse
|
return &fullTextResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
var usage *dto.Usage
|
var usage *dto.Usage
|
||||||
usage = &dto.Usage{}
|
usage = &dto.Usage{}
|
||||||
@ -278,10 +280,15 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
|
isFirst := true
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
// some implementations may add \r at the end of data
|
// some implementations may add \r at the end of data
|
||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
var claudeResponse ClaudeResponse
|
var claudeResponse ClaudeResponse
|
||||||
@ -302,7 +309,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
// message_start, 获取usage
|
||||||
responseId = claudeResponse.Message.Id
|
responseId = claudeResponse.Message.Id
|
||||||
modelName = claudeResponse.Message.Model
|
info.UpstreamModelName = claudeResponse.Message.Model
|
||||||
usage.PromptTokens = claudeUsage.InputTokens
|
usage.PromptTokens = claudeUsage.InputTokens
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
responseText += claudeResponse.Delta.Text
|
responseText += claudeResponse.Delta.Text
|
||||||
@ -316,7 +323,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
//response.Id = responseId
|
//response.Id = responseId
|
||||||
response.Id = responseId
|
response.Id = responseId
|
||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
response.Model = modelName
|
response.Model = info.UpstreamModelName
|
||||||
|
|
||||||
jsonStr, err := json.Marshal(response)
|
jsonStr, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -335,13 +342,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
}
|
}
|
||||||
if usage.CompletionTokens == 0 {
|
if usage.CompletionTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
|
|||||||
|
|
||||||
// 定义一个映射,存储模型名称和对应的版本
|
// 定义一个映射,存储模型名称和对应的版本
|
||||||
var modelVersionMap = map[string]string{
|
var modelVersionMap = map[string]string{
|
||||||
"gemini-1.5-pro-latest": "v1beta",
|
"gemini-1.5-pro-latest": "v1beta",
|
||||||
"gemini-1.5-flash-latest": "v1beta",
|
"gemini-1.5-flash-latest": "v1beta",
|
||||||
"gemini-ultra": "v1beta",
|
"gemini-ultra": "v1beta",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
version, beta := modelVersionMap[info.UpstreamModelName]
|
||||||
if !beta {
|
if !beta {
|
||||||
if info.ApiVersion != "" {
|
if info.ApiVersion != "" {
|
||||||
version = info.ApiVersion
|
version = info.ApiVersion
|
||||||
} else {
|
} else {
|
||||||
version = "v1"
|
version = "v1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = geminiChatStreamHandler(c, resp)
|
err, responseText = geminiChatStreamHandler(c, resp, info)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -160,7 +161,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
dataChan := make(chan string)
|
dataChan := make(chan string)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
@ -190,10 +191,15 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
|
isFirst := true
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
|
if isFirst {
|
||||||
|
isFirst = false
|
||||||
|
info.FirstResponseTime = time.Now()
|
||||||
|
}
|
||||||
// this is used to prevent annoying \ related format bug
|
// this is used to prevent annoying \ related format bug
|
||||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||||
type dummyStruct struct {
|
type dummyStruct struct {
|
||||||
|
@ -38,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
// firstResponseTime = time.Now() - 1 second
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||||
|
|
||||||
info := &RelayInfo{
|
info := &RelayInfo{
|
||||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: c.GetString("base_url"),
|
BaseUrl: c.GetString("base_url"),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
TokenId: tokenId,
|
TokenId: tokenId,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Group: group,
|
Group: group,
|
||||||
TokenUnlimited: tokenUnlimited,
|
TokenUnlimited: tokenUnlimited,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
ApiType: apiType,
|
FirstResponseTime: startTime.Add(-time.Second),
|
||||||
ApiVersion: c.GetString("api_version"),
|
ApiType: apiType,
|
||||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
ApiVersion: c.GetString("api_version"),
|
||||||
Organization: c.GetString("channel_organization"),
|
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||||
|
Organization: c.GetString("channel_organization"),
|
||||||
}
|
}
|
||||||
if info.BaseUrl == "" {
|
if info.BaseUrl == "" {
|
||||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||||
|
Loading…
Reference in New Issue
Block a user