Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-12-28 02:35:56 +08:00)

Comparing main...5f6b515bb3 (27 commits)
Commits:
5f6b515bb3
638a4fb77d
4e2430e5d3
b2d6aa783b
761ee32d19
c426b64b3d
eaef9629a4
d236477531
34c7523f01
c893672635
bbfaf1fb95
adcf4712e6
969fdca9ef
6708eed8a0
ad63c9e66f
76e8199026
413fcde382
6e634b85cf
a0d7d5a965
de10e102bd
c61d6440f9
3a8924d7af
95527d76ef
7ec33793b7
1a6812182b
5ba60433d7
480f248a3d
.gitignore (vendored)
@@ -12,4 +12,5 @@ cmd.md
 .env
 /one-api
 temp
 .DS_Store
+/__debug_bin*
@@ -44,4 +44,4 @@ COPY --from=builder2 /build/one-api /
 
 EXPOSE 3000
 WORKDIR /data
 ENTRYPOINT ["/one-api"]
@@ -385,7 +385,7 @@ graph LR
     + Example: `NODE_TYPE=slave`
 9. `CHANNEL_UPDATE_FREQUENCY`: when set, channel balances are refreshed periodically; the unit is minutes. If unset, no refresh is performed.
     + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
 10. `CHANNEL_TEST_FREQUENCY`: when set, channels are tested periodically; the unit is minutes. If unset, no testing is performed.
     + Example: `CHANNEL_TEST_FREQUENCY=1440`
 11. `POLLING_INTERVAL`: the interval between requests when batch-updating channel balances and testing availability; the unit is seconds. Defaults to no interval.
     + Example: `POLLING_INTERVAL=5`
@@ -164,3 +164,6 @@ var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
 
 var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
 var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")
+
+// OpenrouterProviderSort is used to determine the order of the providers in the openrouter
+var OpenrouterProviderSort = env.String("OPENROUTER_PROVIDER_SORT", "")
@@ -1,6 +1,9 @@
 package conv
 
 func AsString(v any) string {
-	str, _ := v.(string)
-	return str
+	if str, ok := v.(string); ok {
+		return str
+	}
+
+	return ""
 }
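Behaviorally the two versions are identical: a failed two-value type assertion already yields the zero value, so the old code also returned "" for non-strings; the rewrite only makes the fallback explicit. A minimal sketch of the function's behavior, with the import path assumed from this repository's layout:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/conv"
)

func main() {
	fmt.Println(conv.AsString("hello")) // "hello"
	fmt.Println(conv.AsString(42))      // "" (non-strings fall through to the zero value)
}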
common/helper/audio.go (new file, 62 lines)
@@ -0,0 +1,62 @@
+package helper
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"os"
+	"os/exec"
+	"strconv"
+
+	"github.com/pkg/errors"
+)
+
+// SaveTmpFile saves data to a temporary file. The filename would be appended with a random string.
+func SaveTmpFile(filename string, data io.Reader) (string, error) {
+	if data == nil {
+		return "", errors.New("data is nil")
+	}
+
+	f, err := os.CreateTemp("", "*-"+filename)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
+	}
+	defer f.Close()
+
+	_, err = io.Copy(f, data)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
+	}
+
+	return f.Name(), nil
+}
+
+// GetAudioTokens returns the number of tokens in an audio file.
+func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
+	filename, err := SaveTmpFile("audio", audio)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to save audio to temporary file")
+	}
+	defer os.Remove(filename)
+
+	duration, err := GetAudioDuration(ctx, filename)
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio tokens")
+	}
+
+	return duration * tokensPerSecond, nil
+}
+
+// GetAudioDuration returns the duration of an audio file in seconds.
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
+	c := exec.CommandContext(ctx, "/usr/bin/ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
+	output, err := c.Output()
+	if err != nil {
+		return 0, errors.Wrap(err, "failed to get audio duration")
+	}
+
+	// Actually gpt-4-audio calculates tokens with 0.1s precision,
+	// while whisper calculates tokens with 1s precision
+	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+}
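A minimal usage sketch for the new helpers, assuming ffprobe is available at /usr/bin/ffprobe as the code above hard-codes; the local file name and the 10 tokens-per-second rate are arbitrary illustrations, not documented constants:

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/songquanpeng/one-api/common/helper"
)

func main() {
	f, err := os.Open("sample.m4a") // any local audio file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// GetAudioTokens copies the reader to a temp file, probes its
	// duration with ffprobe, and multiplies by the given rate.
	tokens, err := helper.GetAudioTokens(context.Background(), f, 10)
	if err != nil {
		panic(err)
	}
	fmt.Printf("estimated audio tokens: %.1f\n", tokens)
}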
common/helper/audio_test.go (new file, 68 lines)
@@ -0,0 +1,68 @@
+package helper
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"os"
+	"os/exec"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestGetAudioDuration(t *testing.T) {
+	// skip if there is no ffmpeg installed
+	_, err := exec.LookPath("ffmpeg")
+	if err != nil {
+		t.Skip("ffmpeg not installed, skipping test")
+	}
+
+	t.Run("should return correct duration for a valid audio file", func(t *testing.T) {
+		tmpFile, err := os.CreateTemp("", "test_audio*.mp3")
+		require.NoError(t, err)
+		defer os.Remove(tmpFile.Name())
+
+		// download test audio file
+		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
+		require.NoError(t, err)
+		defer resp.Body.Close()
+
+		_, err = io.Copy(tmpFile, resp.Body)
+		require.NoError(t, err)
+		require.NoError(t, tmpFile.Close())
+
+		duration, err := GetAudioDuration(context.Background(), tmpFile.Name())
+		require.NoError(t, err)
+		require.Equal(t, duration, 3.904)
+	})
+
+	t.Run("should return an error for a non-existent file", func(t *testing.T) {
+		_, err := GetAudioDuration(context.Background(), "non_existent_file.mp3")
+		require.Error(t, err)
+	})
+}
+
+func TestGetAudioTokens(t *testing.T) {
+	// skip if there is no ffmpeg installed
+	_, err := exec.LookPath("ffmpeg")
+	if err != nil {
+		t.Skip("ffmpeg not installed, skipping test")
+	}
+
+	t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
+		// download test audio file
+		resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
+		require.NoError(t, err)
+		defer resp.Body.Close()
+
+		tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
+		require.NoError(t, err)
+		require.Equal(t, tokens, 195.2)
+	})
+
+	t.Run("should return an error for a non-existent file", func(t *testing.T) {
+		_, err := GetAudioTokens(context.Background(), nil, 1)
+		require.Error(t, err)
+	})
+}
@@ -6,13 +6,13 @@ import (
 	"html/template"
 	"log"
 	"net"
+	"net/http"
 	"os/exec"
 	"runtime"
 	"strconv"
 	"strings"
 
 	"github.com/gin-gonic/gin"
 
 	"github.com/songquanpeng/one-api/common/random"
 )
@@ -32,6 +32,14 @@ func OpenBrowser(url string) {
 	}
 }
 
+// RespondError sends a JSON response with a success status and an error message.
+func RespondError(c *gin.Context, err error) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": false,
+		"message": err.Error(),
+	})
+}
+
 func GetIp() (ip string) {
 	ips, err := net.InterfaceAddrs()
 	if err != nil {
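A sketch of how a gin handler could use the new RespondError helper; the route, port, and doWork function are hypothetical, and the helper package path is assumed from this diff's surroundings:

package main

import (
	"errors"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common/helper"
)

func main() {
	r := gin.Default()
	r.GET("/demo", func(c *gin.Context) {
		if err := doWork(); err != nil {
			// Responds 200 with {"success": false, "message": "..."},
			// matching the project's response envelope convention.
			helper.RespondError(c, err)
			return
		}
		c.JSON(200, gin.H{"success": true})
	})
	_ = r.Run(":8080")
}

func doWork() error { return errors.New("something failed") }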
@@ -5,6 +5,7 @@ import (
 	"time"
 )
 
+// GetTimestamp get current timestamp in seconds
 func GetTimestamp() int64 {
 	return time.Now().Unix()
 }
@@ -106,6 +106,8 @@ func testChannel(ctx context.Context, channel *model.Channel, request *relaymode
 	if err != nil {
 		return "", err, nil
 	}
+	c.Set(ctxkey.ConvertedRequest, convertedRequest)
+
 	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
 		return "", err, nil
@@ -7,6 +7,7 @@ import (
 
 	"gorm.io/gorm"
 
+	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/utils"
 )
@@ -42,7 +43,7 @@ func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority b
 		err = channelQuery.Order("RAND()").First(&ability).Error
 	}
 	if err != nil {
-		return nil, err
+		return nil, errors.Wrap(err, "get random satisfied channel")
 	}
 	channel := Channel{}
 	channel.Id = ability.ChannelId
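For reference, errors.Wrap from github.com/pkg/errors keeps the original error as the cause while prefixing a message, so callers can still match on the underlying gorm error; a small sketch:

package main

import (
	"fmt"

	"github.com/pkg/errors"
)

func main() {
	cause := fmt.Errorf("record not found")
	err := errors.Wrap(cause, "get random satisfied channel")
	fmt.Println(err)                        // get random satisfied channel: record not found
	fmt.Println(errors.Cause(err) == cause) // true
}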
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return aiProxyLibraryRequest, nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -36,8 +36,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 
 	// https://x.com/alexalbert__/status/1812921642143900036
 	// claude-3-5-sonnet can support 8k context
-	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
-		req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
+	if strings.HasPrefix(meta.ActualModelName, "claude-3-7-sonnet") {
+		req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
 	}
 
 	return nil
@@ -47,10 +47,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return ConvertRequest(*request), nil
+	return ConvertRequest(c, *request)
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -3,11 +3,13 @@ package anthropic
 var ModelList = []string{
 	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 	"claude-3-haiku-20240307",
-	"claude-3-5-haiku-20241022",
 	"claude-3-5-haiku-latest",
+	"claude-3-5-haiku-20241022",
 	"claude-3-sonnet-20240229",
 	"claude-3-opus-20240229",
+	"claude-3-5-sonnet-latest",
 	"claude-3-5-sonnet-20240620",
 	"claude-3-5-sonnet-20241022",
-	"claude-3-5-sonnet-latest",
+	"claude-3-7-sonnet-latest",
+	"claude-3-7-sonnet-20250219",
 }
@@ -2,18 +2,21 @@ package anthropic
 
 import (
 	"bufio"
+	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
+	"math"
 	"net/http"
 	"strings"
 
 	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/render"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -36,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string {
 	}
 }
 
-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
+// isModelSupportThinking is used to check if the model supports extended thinking
+func isModelSupportThinking(model string) bool {
+	if strings.Contains(model, "claude-3-7-sonnet") {
+		return true
+	}
+
+	return false
+}
+
+func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 
 	for _, tool := range textRequest.Tools {
@@ -61,7 +73,27 @@ func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Re
 		TopK:        textRequest.TopK,
 		Stream:      textRequest.Stream,
 		Tools:       claudeTools,
+		Thinking:    textRequest.Thinking,
 	}
+
+	if isModelSupportThinking(textRequest.Model) &&
+		c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
+		claudeRequest.Thinking = &model.Thinking{
+			Type:         "enabled",
+			BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
+		}
+	}
+
+	if isModelSupportThinking(textRequest.Model) &&
+		claudeRequest.Thinking != nil {
+		if claudeRequest.MaxTokens <= 1024 {
+			return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking")
+		}
+
+		// top_p must be nil when using extended thinking
+		claudeRequest.TopP = nil
+	}
+
 	if len(claudeTools) > 0 {
 		claudeToolChoice := struct {
 			Type string `json:"type"`
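A sketch of the opt-in path from the client side, assuming a local one-api instance and a placeholder key. The `?thinking` query flag triggers a default budget of min(1024, max_tokens/2); an explicit thinking object in the body is passed through unchanged, and in both cases max_tokens must be greater than 1024:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	body := []byte(`{
		"model": "claude-3-7-sonnet-latest",
		"max_tokens": 2048,
		"messages": [{"role": "user", "content": "hi"}]
	}`)
	// Appending ?thinking asks the relay to enable extended thinking
	// when the model supports it and no thinking config was supplied.
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:3000/v1/chat/completions?thinking", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-placeholder")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}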
@@ -127,7 +159,9 @@ func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Re
 			var content Content
 			if part.Type == model.ContentTypeText {
 				content.Type = "text"
-				content.Text = part.Text
+				if part.Text != nil {
+					content.Text = *part.Text
+				}
 			} else if part.Type == model.ContentTypeImageURL {
 				content.Type = "image"
 				content.Source = &ImageSource{
@@ -142,13 +176,14 @@ func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Re
 		claudeMessage.Content = contents
 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 	}
-	return &claudeRequest
+	return &claudeRequest, nil
 }
 
 // https://docs.anthropic.com/claude/reference/messages-streaming
 func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
 	var response *Response
 	var responseText string
+	var reasoningText string
 	var stopReason string
 	tools := make([]model.Tool, 0)
 
@@ -158,6 +193,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	case "content_block_start":
 		if claudeResponse.ContentBlock != nil {
 			responseText = claudeResponse.ContentBlock.Text
+			if claudeResponse.ContentBlock.Thinking != nil {
+				reasoningText = *claudeResponse.ContentBlock.Thinking
+			}
+
 			if claudeResponse.ContentBlock.Type == "tool_use" {
 				tools = append(tools, model.Tool{
 					Id: claudeResponse.ContentBlock.Id,
@@ -172,6 +211,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	case "content_block_delta":
 		if claudeResponse.Delta != nil {
 			responseText = claudeResponse.Delta.Text
+			if claudeResponse.Delta.Thinking != nil {
+				reasoningText = *claudeResponse.Delta.Thinking
+			}
+
 			if claudeResponse.Delta.Type == "input_json_delta" {
 				tools = append(tools, model.Tool{
 					Function: model.Function{
@@ -189,9 +232,20 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
 			stopReason = *claudeResponse.Delta.StopReason
 		}
+	case "thinking_delta":
+		if claudeResponse.Delta != nil && claudeResponse.Delta.Thinking != nil {
+			reasoningText = *claudeResponse.Delta.Thinking
+		}
+	case "ping",
+		"message_stop",
+		"content_block_stop":
+	default:
+		logger.SysErrorf("unknown stream response type %q", claudeResponse.Type)
 	}
 
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = responseText
+	choice.Delta.Reasoning = &reasoningText
 	if len(tools) > 0 {
 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 		choice.Delta.ToolCalls = tools
@@ -209,11 +263,23 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 
 func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 	var responseText string
-	if len(claudeResponse.Content) > 0 {
-		responseText = claudeResponse.Content[0].Text
-	}
+	var reasoningText string
+
 	tools := make([]model.Tool, 0)
 	for _, v := range claudeResponse.Content {
+		switch v.Type {
+		case "thinking":
+			if v.Thinking != nil {
+				reasoningText += *v.Thinking
+			} else {
+				logger.Errorf(context.Background(), "thinking is nil in response")
+			}
+		case "text":
+			responseText += v.Text
+		default:
+			logger.Warnf(context.Background(), "unknown response type %q", v.Type)
+		}
+
 		if v.Type == "tool_use" {
 			args, _ := json.Marshal(v.Input)
 			tools = append(tools, model.Tool{
@@ -226,11 +292,13 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 			})
 		}
 	}
 
 	choice := openai.TextResponseChoice{
 		Index: 0,
 		Message: model.Message{
 			Role:      "assistant",
 			Content:   responseText,
+			Reasoning: &reasoningText,
 			Name:      nil,
 			ToolCalls: tools,
 		},
@@ -277,6 +345,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 		data = strings.TrimPrefix(data, "data:")
 		data = strings.TrimSpace(data)
 
+		logger.Debugf(c.Request.Context(), "stream <- %q\n", data)
+
 		var claudeResponse StreamResponse
 		err := json.Unmarshal([]byte(data), &claudeResponse)
 		if err != nil {
@@ -344,6 +414,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+
+	logger.Debugf(c.Request.Context(), "response <- %s\n", string(responseBody))
+
 	var claudeResponse Response
 	err = json.Unmarshal(responseBody, &claudeResponse)
 	if err != nil {
@@ -1,5 +1,7 @@
 package anthropic
 
+import "github.com/songquanpeng/one-api/relay/model"
+
 // https://docs.anthropic.com/claude/reference/messages_post
 
 type Metadata struct {
@@ -22,6 +24,9 @@ type Content struct {
 	Input     any    `json:"input,omitempty"`
 	Content   string `json:"content,omitempty"`
 	ToolUseId string `json:"tool_use_id,omitempty"`
+	// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
 }
 
 type Message struct {
@@ -54,6 +59,7 @@ type Request struct {
 	Tools      []Tool `json:"tools,omitempty"`
 	ToolChoice any    `json:"tool_choice,omitempty"`
 	//Metadata    `json:"metadata,omitempty"`
+	Thinking *model.Thinking `json:"thinking,omitempty"`
 }
 
 type Usage struct {
@@ -84,6 +90,8 @@ type Delta struct {
 	PartialJson  string  `json:"partial_json,omitempty"`
 	StopReason   *string `json:"stop_reason"`
 	StopSequence *string `json:"stop_sequence"`
+	Thinking     *string `json:"thinking,omitempty"`
+	Signature    *string `json:"signature,omitempty"`
 }
 
 type StreamResponse struct {
@@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 	return nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
 
-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
+
 	c.Set(ctxkey.RequestModel, request.Model)
 	c.Set(ctxkey.ConvertedRequest, claudeReq)
 	return claudeReq, nil
@@ -36,6 +36,8 @@ var AwsModelIDMap = map[string]string{
 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
 	"claude-3-5-sonnet-latest":   "anthropic.claude-3-5-sonnet-20241022-v2:0",
 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
+	"claude-3-7-sonnet-latest":   "anthropic.claude-3-7-sonnet-20250219-v1:0",
+	"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
 }
 
 func awsModelID(requestModel string) (string, error) {
@@ -47,13 +49,14 @@ func awsModelID(requestModel string) (string, error) {
 }
 
 func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -101,13 +104,14 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
 
 func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	createdTime := helper.GetTimestamp()
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -1,6 +1,9 @@
 package aws
 
-import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+import (
+	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+	"github.com/songquanpeng/one-api/relay/model"
+)
 
 // Request is the request to AWS Claude
 //
@@ -17,4 +20,5 @@ type Request struct {
 	StopSequences []string         `json:"stop_sequences,omitempty"`
 	Tools         []anthropic.Tool `json:"tools,omitempty"`
 	ToolChoice    any              `json:"tool_choice,omitempty"`
+	Thinking      *model.Thinking  `json:"thinking,omitempty"`
 }
@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
 }
 
 func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -140,13 +141,14 @@ func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
 
 func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	createdTime := helper.GetTimestamp()
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 	return nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
relay/adaptor/aws/utils/consts.go (new file, 75 lines)
@@ -0,0 +1,75 @@
+package utils
+
+import (
+	"context"
+	"slices"
+	"strings"
+
+	"github.com/songquanpeng/one-api/common/logger"
+)
+
+// CrossRegionInferences is a list of model IDs that support cross-region inference.
+//
+// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
+//
+// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
+var CrossRegionInferences = []string{
+	"us.amazon.nova-lite-v1:0",
+	"us.amazon.nova-micro-v1:0",
+	"us.amazon.nova-pro-v1:0",
+	"us.anthropic.claude-3-5-haiku-20241022-v1:0",
+	"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+	"us.anthropic.claude-3-haiku-20240307-v1:0",
+	"us.anthropic.claude-3-opus-20240229-v1:0",
+	"us.anthropic.claude-3-sonnet-20240229-v1:0",
+	"us.meta.llama3-1-405b-instruct-v1:0",
+	"us.meta.llama3-1-70b-instruct-v1:0",
+	"us.meta.llama3-1-8b-instruct-v1:0",
+	"us.meta.llama3-2-11b-instruct-v1:0",
+	"us.meta.llama3-2-1b-instruct-v1:0",
+	"us.meta.llama3-2-3b-instruct-v1:0",
+	"us.meta.llama3-2-90b-instruct-v1:0",
+	"us.meta.llama3-3-70b-instruct-v1:0",
+	"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
+	"eu.amazon.nova-lite-v1:0",
+	"eu.amazon.nova-micro-v1:0",
+	"eu.amazon.nova-pro-v1:0",
+	"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"eu.anthropic.claude-3-haiku-20240307-v1:0",
+	"eu.anthropic.claude-3-sonnet-20240229-v1:0",
+	"eu.meta.llama3-2-1b-instruct-v1:0",
+	"eu.meta.llama3-2-3b-instruct-v1:0",
+	"apac.amazon.nova-lite-v1:0",
+	"apac.amazon.nova-micro-v1:0",
+	"apac.amazon.nova-pro-v1:0",
+	"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"apac.anthropic.claude-3-haiku-20240307-v1:0",
+	"apac.anthropic.claude-3-sonnet-20240229-v1:0",
+}
+
+// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID.
+func ConvertModelID2CrossRegionProfile(model, region string) string {
+	var regionPrefix string
+	switch prefix := strings.Split(region, "-")[0]; prefix {
+	case "us", "eu":
+		regionPrefix = prefix
+	case "ap":
+		regionPrefix = "apac"
+	default:
+		// not supported, return original model
+		return model
+	}
+
+	newModelID := regionPrefix + "." + model
+	if slices.Contains(CrossRegionInferences, newModelID) {
+		logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
+		return newModelID
+	}
+
+	// not found, return original model
+	return model
+}
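A sketch of the conversion logic in isolation; the region names are illustrative inputs:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)

func main() {
	// us-east-1 maps to the "us." prefix, which is in CrossRegionInferences.
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(
		"anthropic.claude-3-7-sonnet-20250219-v1:0", "us-east-1"))
	// -> us.anthropic.claude-3-7-sonnet-20250219-v1:0

	// "ap" regions are translated to the "apac." prefix.
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(
		"anthropic.claude-3-5-sonnet-20240620-v1:0", "ap-northeast-1"))
	// -> apac.anthropic.claude-3-5-sonnet-20240620-v1:0

	// Unsupported region prefixes fall back to the original model ID.
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(
		"anthropic.claude-3-7-sonnet-20250219-v1:0", "ca-central-1"))
	// -> anthropic.claude-3-7-sonnet-20250219-v1:0 (unchanged)
}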
@@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -19,7 +19,7 @@ type Adaptor struct {
 }
 
 // ConvertImageRequest implements adaptor.Adaptor.
-func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	return nil, errors.New("not implemented")
 }
 
@@ -19,9 +19,8 @@ import (
 )
 
 func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
-	p, _ := textRequest.Prompt.(string)
 	return &Request{
-		Prompt:      p,
+		Prompt:      textRequest.Prompt,
 		MaxTokens:   textRequest.MaxTokens,
 		Stream:      textRequest.Stream,
 		Temperature: textRequest.Temperature,
@@ -15,7 +15,7 @@ import (
 type Adaptor struct{}
 
 // ConvertImageRequest implements adaptor.Adaptor.
-func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	return nil, errors.New("not implemented")
 }
 
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return ConvertRequest(*request), nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return convertedRequest, nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -25,7 +25,7 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	defaultVersion := config.GeminiVersion
-	if strings.Contains(meta.ActualModelName, "gemini-2.0") ||
+	if strings.Contains(meta.ActualModelName, "gemini-2") ||
 		strings.Contains(meta.ActualModelName, "gemini-1.5") {
 		defaultVersion = "v1beta"
 	}
@@ -66,7 +66,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -19,6 +19,9 @@ var ModelsSupportSystemInstruction = []string{
 	// "gemini-1.5-pro-experimental",
 	"gemini-2.0-flash", "gemini-2.0-flash-exp",
 	"gemini-2.0-flash-thinking-exp-01-21",
+	"gemini-2.0-flash-lite",
+	// "gemini-2.0-flash-exp-image-generation",
+	"gemini-2.0-pro-exp-02-05",
 }
 
 // IsModelSupportSystemInstruction check if the model support system instruction.
@@ -8,19 +8,18 @@ import (
 	"net/http"
 	"strings"
 
-	"github.com/songquanpeng/one-api/common/render"
+	"github.com/gin-gonic/gin"
 
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/common/random"
+	"github.com/songquanpeng/one-api/common/render"
+	"github.com/songquanpeng/one-api/relay/adaptor/geminiv2"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-
-	"github.com/gin-gonic/gin"
 )
 
 // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -61,9 +60,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			},
 		},
 		GenerationConfig: ChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
 			MaxOutputTokens: textRequest.MaxTokens,
+			ResponseModalities: geminiv2.GetModelModalities(textRequest.Model),
 		},
 	}
 	if textRequest.ResponseFormat != nil {
@@ -106,9 +106,9 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 		var parts []Part
 		imageNum := 0
 		for _, part := range openaiContent {
-			if part.Type == model.ContentTypeText {
+			if part.Type == model.ContentTypeText && part.Text != nil && *part.Text != "" {
 				parts = append(parts, Part{
-					Text: part.Text,
+					Text: *part.Text,
 				})
 			} else if part.Type == model.ContentTypeImageURL {
 				imageNum += 1
@@ -258,19 +258,52 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 		if candidate.Content.Parts[0].FunctionCall != nil {
 			choice.Message.ToolCalls = getToolCalls(&candidate)
 		} else {
+			// Handle text and image content
 			var builder strings.Builder
+			var contentItems []model.MessageContent
+
 			for _, part := range candidate.Content.Parts {
-				if i > 0 {
-					builder.WriteString("\n")
+				if part.Text != "" {
+					// For text parts
+					if i > 0 {
+						builder.WriteString("\n")
+					}
+					builder.WriteString(part.Text)
+
+					// Add to content items
+					contentItems = append(contentItems, model.MessageContent{
+						Type: model.ContentTypeText,
+						Text: &part.Text,
+					})
+				}
+
+				if part.InlineData != nil && part.InlineData.MimeType != "" && part.InlineData.Data != "" {
+					// For inline image data
+					imageURL := &model.ImageURL{
+						// The data is already base64 encoded
+						Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
+					}
+
+					contentItems = append(contentItems, model.MessageContent{
+						Type:     model.ContentTypeImageURL,
+						ImageURL: imageURL,
+					})
 				}
-				builder.WriteString(part.Text)
 			}
-			choice.Message.Content = builder.String()
+
+			// If we have multiple content types, use structured content format
+			if len(contentItems) > 1 || (len(contentItems) == 1 && contentItems[0].Type != model.ContentTypeText) {
+				choice.Message.Content = contentItems
+			} else {
+				// Otherwise use the simple string content format
+				choice.Message.Content = builder.String()
+			}
 		}
 	} else {
 		choice.Message.Content = ""
 		choice.FinishReason = candidate.FinishReason
 	}
 
 	fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 }
 return &fullTextResponse
@@ -278,14 +311,78 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = geminiResponse.GetResponseText()
-	//choice.FinishReason = &constant.StopFinishReason
+	choice.Delta.Role = "assistant"
+
+	// Check if we have any candidates
+	if len(geminiResponse.Candidates) == 0 {
+		return nil
+	}
+
+	// Get the first candidate
+	candidate := geminiResponse.Candidates[0]
+
+	// Check if there are parts in the content
+	if len(candidate.Content.Parts) == 0 {
+		return nil
+	}
+
+	// Handle different content types in the parts
+	for _, part := range candidate.Content.Parts {
+		// Handle text content
+		if part.Text != "" {
+			// Store as string for simple text responses
+			textContent := part.Text
+			choice.Delta.Content = textContent
+		}
+
+		// Handle image content
+		if part.InlineData != nil && part.InlineData.MimeType != "" && part.InlineData.Data != "" {
+			// Create a structured response for image content
+			imageUrl := fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data)
+
+			// If we already have text content, create a mixed content response
+			if strContent, ok := choice.Delta.Content.(string); ok && strContent != "" {
+				// Convert the existing text content and add the image
+				messageContents := []model.MessageContent{
+					{
+						Type: model.ContentTypeText,
+						Text: &strContent,
+					},
+					{
+						Type: model.ContentTypeImageURL,
+						ImageURL: &model.ImageURL{
+							Url: imageUrl,
+						},
+					},
+				}
+				choice.Delta.Content = messageContents
+			} else {
+				// Only have image content
+				choice.Delta.Content = []model.MessageContent{
+					{
+						Type: model.ContentTypeImageURL,
+						ImageURL: &model.ImageURL{
+							Url: imageUrl,
+						},
+					},
+				}
+			}
+		}
+
+		// Handle function calls (if present)
+		if part.FunctionCall != nil {
+			choice.Delta.ToolCalls = getToolCalls(&candidate)
+		}
+	}
+
+	// Create response
 	var response openai.ChatCompletionsStreamResponse
 	response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
 	response.Created = helper.GetTimestamp()
 	response.Object = "chat.completion.chunk"
 	response.Model = "gemini"
 	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
+
 	return &response
 }
 
@@ -311,17 +408,23 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
+
+	buffer := make([]byte, 1024*1024) // 1MB buffer
+	scanner.Buffer(buffer, len(buffer))
+
 	common.SetEventStreamHeaders(c)
+
 	for scanner.Scan() {
 		data := scanner.Text()
 		data = strings.TrimSpace(data)
+
 		if !strings.HasPrefix(data, "data: ") {
 			continue
 		}
 		data = strings.TrimPrefix(data, "data: ")
 		data = strings.TrimSuffix(data, "\"")
+
+		fmt.Printf(">> gemini response: %s\n", data)
+
 		var geminiResponse ChatResponse
 		err := json.Unmarshal([]byte(data), &geminiResponse)
 		if err != nil {
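The 1MB buffer matters because bufio.Scanner caps tokens at 64KB by default (bufio.MaxScanTokenSize) and fails with bufio.ErrTooLong on longer lines, which base64-encoded image parts in the SSE stream can easily exceed. A self-contained sketch:

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// A single 200KB line would overflow the default 64KB scanner buffer;
	// enlarging the buffer, as the hunk above does, lets it through.
	long := strings.Repeat("x", 200*1024)
	s := bufio.NewScanner(strings.NewReader(long))
	buf := make([]byte, 1024*1024)
	s.Buffer(buf, len(buf))
	for s.Scan() {
	}
	fmt.Println("scan error:", s.Err()) // <nil> with the enlarged buffer
}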
@@ -361,6 +464,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 	if err != nil {
 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
+
 	err = resp.Body.Close()
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
@@ -6,6 +6,19 @@ type ChatRequest struct {
 	GenerationConfig  ChatGenerationConfig `json:"generation_config,omitempty"`
 	Tools             []ChatTools          `json:"tools,omitempty"`
 	SystemInstruction *ChatContent         `json:"system_instruction,omitempty"`
+	ModelVersion      string               `json:"model_version,omitempty"`
+	UsageMetadata     *UsageMetadata       `json:"usage_metadata,omitempty"`
+}
+
+type UsageMetadata struct {
+	PromptTokenCount    int                   `json:"promptTokenCount,omitempty"`
+	TotalTokenCount     int                   `json:"totalTokenCount,omitempty"`
+	PromptTokensDetails []PromptTokensDetails `json:"promptTokensDetails,omitempty"`
+}
+
+type PromptTokensDetails struct {
+	Modality   string `json:"modality,omitempty"`
+	TokenCount int    `json:"tokenCount,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -66,12 +79,13 @@ type ChatTools struct {
 }
 
 type ChatGenerationConfig struct {
 	ResponseMimeType string   `json:"responseMimeType,omitempty"`
 	ResponseSchema   any      `json:"responseSchema,omitempty"`
 	Temperature      *float64 `json:"temperature,omitempty"`
 	TopP             *float64 `json:"topP,omitempty"`
 	TopK             float64  `json:"topK,omitempty"`
 	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
 	CandidateCount   int      `json:"candidateCount,omitempty"`
 	StopSequences    []string `json:"stopSequences,omitempty"`
+	ResponseModalities []string `json:"responseModalities,omitempty"`
 }
@@ -1,15 +1,39 @@
 package geminiv2

+import "strings"
+
 // https://ai.google.dev/models/gemini

 var ModelList = []string{
 	"gemini-pro", "gemini-1.0-pro",
-	// "gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it",
+	"gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it",
+	"gemma-3-27b-it",
 	"gemini-1.5-flash", "gemini-1.5-flash-8b",
 	"gemini-1.5-pro", "gemini-1.5-pro-experimental",
 	"text-embedding-004", "aqa",
 	"gemini-2.0-flash", "gemini-2.0-flash-exp",
 	"gemini-2.0-flash-lite-preview-02-05",
 	"gemini-2.0-flash-thinking-exp-01-21",
+	"gemini-2.0-flash-exp-image-generation",
 	"gemini-2.0-pro-exp-02-05",
+	"gemini-2.5-pro-exp-03-25",
+}
+
+const (
+	ModalityText  = "TEXT"
+	ModalityImage = "IMAGE"
+)
+
+// GetModelModalities returns the modalities of the model.
+func GetModelModalities(model string) []string {
+	if strings.Contains(model, "-image-generation") {
+		return []string{ModalityText, ModalityImage}
+	}
+
+	// Until 2025-03-26, the following models do not accept the responseModalities field
+	if model == "gemini-2.5-pro-exp-03-25" {
+		return nil
+	}
+
+	return []string{ModalityText}
 }
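
GetModelModalities is presumably consumed where the adaptor builds the Gemini generation config, so image-capable models ask for both modalities while gemini-2.5-pro-exp-03-25 omits the field entirely. A hedged sketch of that wiring (geminiRequest and model are assumed names, not shown in this hunk):

    // Only set responseModalities when the helper returns a non-nil slice.
    if modalities := geminiv2.GetModelModalities(model); modalities != nil {
        geminiRequest.GenerationConfig.ResponseModalities = modalities
    }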
@@ -1,27 +1,32 @@
 package groq

+// ModelList is a list of models that can be used with Groq.
+//
 // https://console.groq.com/docs/models

 var ModelList = []string{
+	// Regular Models
+	"distil-whisper-large-v3-en",
 	"gemma2-9b-it",
-	"llama-3.1-70b-versatile",
+	"llama-3.3-70b-versatile",
 	"llama-3.1-8b-instant",
-	"llama-3.2-11b-text-preview",
-	"llama-3.2-11b-vision-preview",
-	"llama-3.2-1b-preview",
-	"llama-3.2-3b-preview",
-	"llama-3.2-90b-text-preview",
-	"llama-3.2-90b-vision-preview",
 	"llama-guard-3-8b",
 	"llama3-70b-8192",
 	"llama3-8b-8192",
-	"llama3-groq-70b-8192-tool-use-preview",
-	"llama3-groq-8b-8192-tool-use-preview",
-	"llava-v1.5-7b-4096-preview",
 	"mixtral-8x7b-32768",
-	"distil-whisper-large-v3-en",
 	"whisper-large-v3",
 	"whisper-large-v3-turbo",
+
+	// Preview Models
+	"qwen-qwq-32b",
+	"mistral-saba-24b",
+	"qwen-2.5-coder-32b",
+	"qwen-2.5-32b",
+	"deepseek-r1-distill-qwen-32b",
 	"deepseek-r1-distill-llama-70b-specdec",
 	"deepseek-r1-distill-llama-70b",
+	"llama-3.2-1b-preview",
+	"llama-3.2-3b-preview",
+	"llama-3.2-11b-vision-preview",
+	"llama-3.2-90b-vision-preview",
+	"llama-3.3-70b-specdec",
 }
@@ -13,7 +13,7 @@ type Adaptor interface {
 	GetRequestURL(meta *meta.Meta) (string, error)
 	SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
 	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
-	ConvertImageRequest(request *model.ImageRequest) (any, error)
+	ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
 	DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
 	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
 	GetModelList() []string
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -43,7 +43,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 		for _, part := range openaiContent {
 			switch part.Type {
 			case model.ContentTypeText:
-				contentText = part.Text
+				if part.Text != nil {
+					contentText = *part.Text
+				}
 			case model.ContentTypeImageURL:
 				_, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
 				imageUrls = append(imageUrls, data)
@@ -1,14 +1,18 @@
 package openai

 import (
-	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"strings"

 	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/ctxkey"
+	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/adaptor"
 	"github.com/songquanpeng/one-api/relay/adaptor/alibailian"
 	"github.com/songquanpeng/one-api/relay/adaptor/baiduv2"
@@ -16,6 +20,8 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/geminiv2"
 	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
 	"github.com/songquanpeng/one-api/relay/adaptor/novita"
+	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+	"github.com/songquanpeng/one-api/relay/billing/ratio"
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -33,16 +39,24 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	switch meta.ChannelType {
 	case channeltype.Azure:
+		defaultVersion := meta.Config.APIVersion
+
+		// https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning?tabs=python#api--feature-support
+		if strings.HasPrefix(meta.ActualModelName, "o1") ||
+			strings.HasPrefix(meta.ActualModelName, "o3") {
+			defaultVersion = "2024-12-01-preview"
+		}
+
 		if meta.Mode == relaymode.ImagesGenerations {
 			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
-			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion)
+			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, defaultVersion)
 			return fullRequestURL, nil
 		}

 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
-		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion)
+		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, defaultVersion)
 		task := strings.TrimPrefix(requestURL, "/v1/")
 		model_ := meta.ActualModelName
 		model_ = strings.Replace(model_, ".", "", -1)
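
With the override above, a chat completion against an Azure o1/o3 deployment resolves to a URL of roughly this shape (resource and deployment names are placeholders):

    https://{resource_name}.openai.azure.com/openai/deployments/o1-mini/chat/completions?api-version=2024-12-01-preview

All other models keep the api-version configured on the channel.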
@@ -85,28 +99,92 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	if request.Stream {
+
+	meta := meta.GetByContext(c)
+	switch meta.ChannelType {
+	case channeltype.OpenRouter:
+		includeReasoning := true
+		request.IncludeReasoning = &includeReasoning
+		// Apply the configured default sort only when the request does not
+		// already specify one (note the parentheses around the nil/empty check).
+		if (request.Provider == nil || request.Provider.Sort == "") &&
+			config.OpenrouterProviderSort != "" {
+			if request.Provider == nil {
+				request.Provider = &openrouter.RequestProvider{}
+			}
+
+			request.Provider.Sort = config.OpenrouterProviderSort
+		}
+	default:
+	}
+
+	if request.Stream && !config.EnforceIncludeUsage {
+		logger.Warn(c.Request.Context(),
+			"please set ENFORCE_INCLUDE_USAGE=true to ensure accurate billing in stream mode")
+	}
+
+	if config.EnforceIncludeUsage && request.Stream {
 		// always return usage in stream mode
 		if request.StreamOptions == nil {
 			request.StreamOptions = &model.StreamOptions{}
 		}
 		request.StreamOptions.IncludeUsage = true
 	}
+
+	// o1/o1-mini/o1-preview do not support system prompt/max_tokens/temperature
+	if strings.HasPrefix(meta.ActualModelName, "o1") ||
+		strings.HasPrefix(meta.ActualModelName, "o3") {
+		temperature := float64(1)
+		request.Temperature = &temperature // Only the default (1) value is supported
+
+		request.MaxTokens = 0
+		request.Messages = func(raw []model.Message) (filtered []model.Message) {
+			for i := range raw {
+				if raw[i].Role != "system" {
+					filtered = append(filtered, raw[i])
+				}
+			}
+
+			return
+		}(request.Messages)
+	}
+
+	// web search do not support system prompt/max_tokens/temperature
+	if strings.HasPrefix(meta.ActualModelName, "gpt-4o-search") ||
+		strings.HasPrefix(meta.ActualModelName, "gpt-4o-mini-search") {
+		request.Temperature = nil
+		request.TopP = nil
+		request.PresencePenalty = nil
+		request.N = nil
+		request.FrequencyPenalty = nil
+	}
+
+	if request.Stream && !config.EnforceIncludeUsage &&
+		(strings.HasPrefix(request.Model, "gpt-4o-audio") ||
+			strings.HasPrefix(request.Model, "gpt-4o-mini-audio")) {
+		// TODO: Since it is not clear how to implement billing in stream mode,
+		// it is temporarily not supported
+		return nil, errors.New("set ENFORCE_INCLUDE_USAGE=true to enable stream mode for gpt-4o-audio")
+	}
+
 	return request, nil
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
 	return request, nil
 }

-func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context,
+	meta *meta.Meta,
+	requestBody io.Reader) (*http.Response, error) {
 	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context,
+	resp *http.Response,
+	meta *meta.Meta) (usage *model.Usage,
+	err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
 		err, responseText, usage = StreamHandler(c, resp, meta.Mode)
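
For illustration, a request like the hypothetical one below would leave ConvertRequest with its system message filtered out, MaxTokens zeroed, and Temperature pinned to the only supported value:

    req := &model.GeneralOpenAIRequest{
        Model:     "o1-mini",
        MaxTokens: 1024, // reset to 0 by the o1/o3 branch
        Messages: []model.Message{
            {Role: "system", Content: "You are terse."}, // dropped
            {Role: "user", Content: "Hi"},               // kept
        },
    }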
@@ -121,10 +199,61 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 		switch meta.Mode {
 		case relaymode.ImagesGenerations:
 			err, _ = ImageHandler(c, resp)
+		case relaymode.ImagesEdits:
+			err, _ = ImagesEditsHandler(c, resp)
 		default:
 			err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
 		}
 	}

+	// -------------------------------------
+	// calculate web-search tool cost
+	// -------------------------------------
+	if usage != nil {
+		searchContextSize := "medium"
+		var req *model.GeneralOpenAIRequest
+		if vi, ok := c.Get(ctxkey.ConvertedRequest); ok {
+			if req, ok = vi.(*model.GeneralOpenAIRequest); ok {
+				if req != nil &&
+					req.WebSearchOptions != nil &&
+					req.WebSearchOptions.SearchContextSize != nil {
+					searchContextSize = *req.WebSearchOptions.SearchContextSize
+				}

+				// The USD-per-call rates below use float constants (30.0 / 1000,
+				// not 30 / 1000) so the constant division is not truncated to zero.
+				switch {
+				case strings.HasPrefix(meta.ActualModelName, "gpt-4o-search"):
+					switch searchContextSize {
+					case "low":
+						usage.ToolsCost += int64(math.Ceil(30.0 / 1000 * ratio.QuotaPerUsd))
+					case "medium":
+						usage.ToolsCost += int64(math.Ceil(35.0 / 1000 * ratio.QuotaPerUsd))
+					case "high":
+						usage.ToolsCost += int64(math.Ceil(40.0 / 1000 * ratio.QuotaPerUsd))
+					default:
+						return nil, ErrorWrapper(
+							errors.Errorf("invalid search context size %q", searchContextSize),
+							"invalid search context size: "+searchContextSize,
+							http.StatusBadRequest)
+					}
+				case strings.HasPrefix(meta.ActualModelName, "gpt-4o-mini-search"):
+					switch searchContextSize {
+					case "low":
+						usage.ToolsCost += int64(math.Ceil(25.0 / 1000 * ratio.QuotaPerUsd))
+					case "medium":
+						usage.ToolsCost += int64(math.Ceil(27.5 / 1000 * ratio.QuotaPerUsd))
+					case "high":
+						usage.ToolsCost += int64(math.Ceil(30.0 / 1000 * ratio.QuotaPerUsd))
+					default:
+						return nil, ErrorWrapper(
+							errors.Errorf("invalid search context size %q", searchContextSize),
+							"invalid search context size: "+searchContextSize,
+							http.StatusBadRequest)
+					}
+				}
+			}
+		}
+	}
+
 	return
 }
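
A worked example of the accounting above: assuming ratio.QuotaPerUsd converts dollars to quota at 500,000 quota per USD, one gpt-4o-search call with the default "medium" context size adds ceil(35.0 / 1000 × 500,000) = 17,500 quota to usage.ToolsCost, matching a $35-per-1000-calls price tier; the "low" and "high" tiers scale the same way.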
@@ -7,11 +7,10 @@ var ModelList = []string{
 	"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
 	"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
-	"gpt-4o", "gpt-4o-2024-05-13",
-	"gpt-4o-2024-08-06",
-	"gpt-4o-2024-11-20",
-	"chatgpt-4o-latest",
+	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
+	"gpt-4o-mini-audio-preview", "gpt-4o-mini-audio-preview-2024-12-17",
+	"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01",
 	"gpt-4-vision-preview",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
@@ -24,4 +23,8 @@ var ModelList = []string{
 	"o1", "o1-2024-12-17",
 	"o1-preview", "o1-preview-2024-09-12",
 	"o1-mini", "o1-mini-2024-09-12",
+	"o3-mini", "o3-mini-2025-01-31",
+	"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
+	// https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat
+	"gpt-4o-search-preview", "gpt-4o-mini-search-preview",
 }
@@ -3,12 +3,30 @@ package openai
 import (
 	"bytes"
 	"encoding/json"
-	"github.com/gin-gonic/gin"
-	"github.com/songquanpeng/one-api/relay/model"
 	"io"
 	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/relay/model"
 )

+// ImagesEditsHandler simply copies the upstream response body to the client.
+//
+// https://platform.openai.com/docs/api-reference/images/createEdit
+func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	defer resp.Body.Close()
+
+	// Copy headers before WriteHeader: values set afterwards would be ignored.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
+		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, nil
+}
+
 func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	var imageResponse ImageResponse
 	responseBody, err := io.ReadAll(resp.Body)
@@ -5,15 +5,16 @@ import (
 	"bytes"
 	"encoding/json"
 	"io"
+	"math"
 	"net/http"
 	"strings"

-	"github.com/songquanpeng/one-api/common/render"
-
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/render"
+	"github.com/songquanpeng/one-api/relay/billing/ratio"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/relaymode"
 )
@@ -24,128 +25,300 @@ const (
 	dataPrefixLength = len(dataPrefix)
 )

+// StreamHandler processes streaming responses from OpenAI API
+// It handles incremental content delivery and accumulates the final response text
+// Returns error (if any), accumulated response text, and token usage information
 func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
+	// Initialize accumulators for the response
 	responseText := ""
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(bufio.ScanLines)
+	reasoningText := ""
 	var usage *model.Usage

+	// Set up scanner for reading the stream line by line
+	scanner := bufio.NewScanner(resp.Body)
+	buffer := make([]byte, 256*1024) // 256KB buffer for large messages
+	scanner.Buffer(buffer, len(buffer))
+	scanner.Split(bufio.ScanLines)
+
+	// Set response headers for SSE
 	common.SetEventStreamHeaders(c)

 	doneRendered := false

+	// Process each line from the stream
 	for scanner.Scan() {
-		data := scanner.Text()
-		if len(data) < dataPrefixLength { // ignore blank line or wrong format
-			continue
+		data := NormalizeDataLine(scanner.Text())
+
+		// logger.Debugf(c.Request.Context(), "stream response: %s", data)
+
+		// Skip lines that don't match expected format
+		if len(data) < dataPrefixLength {
+			continue // Ignore blank line or wrong format
 		}
+
+		// Verify line starts with expected prefix
 		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 			continue
 		}
+
+		// Check for stream termination
 		if strings.HasPrefix(data[dataPrefixLength:], done) {
 			render.StringData(c, data)
 			doneRendered = true
 			continue
 		}
+
+		// Process based on relay mode
 		switch relayMode {
 		case relaymode.ChatCompletions:
 			var streamResponse ChatCompletionsStreamResponse
+
+			// Parse the JSON response
 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 			if err != nil {
-				logger.SysError("error unmarshalling stream response: " + err.Error())
-				render.StringData(c, data) // if error happened, pass the data to client
-				continue                   // just ignore the error
+				logger.Errorf(c.Request.Context(), "unmarshalling stream data %q got %+v", data, err)
+				render.StringData(c, data) // Pass raw data to client if parsing fails
+				continue
 			}
+
+			// Skip empty choices (Azure specific behavior)
 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
-				// but for empty choice and no usage, we should not pass it to client, this is for azure
-				continue // just ignore empty choice
+				continue
 			}
-			render.StringData(c, data)
+
+			// Process each choice in the response
 			for _, choice := range streamResponse.Choices {
+				// Extract reasoning content from different possible fields
+				currentReasoningChunk := extractReasoningContent(&choice.Delta)
+
+				// Update accumulated reasoning text
+				if currentReasoningChunk != "" {
+					reasoningText += currentReasoningChunk
+				}
+
+				// Set the reasoning content in the format requested by client
+				choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
+
+				// Accumulate response content
 				responseText += conv.AsString(choice.Delta.Content)
 			}
+
+			// Send the processed data to the client
+			render.StringData(c, data)
+
+			// Update usage information if available
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage
 			}
+
 		case relaymode.Completions:
+			// Send the data immediately for Completions mode
 			render.StringData(c, data)
+
 			var streamResponse CompletionsStreamResponse
 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 			if err != nil {
 				logger.SysError("error unmarshalling stream response: " + err.Error())
 				continue
 			}
+
+			// Accumulate text from all choices
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Text
 			}
 		}
 	}

+	// Check for scanner errors
 	if err := scanner.Err(); err != nil {
 		logger.SysError("error reading stream: " + err.Error())
 	}

+	// Ensure stream termination is sent to client
 	if !doneRendered {
 		render.Done(c)
 	}

-	err := resp.Body.Close()
-	if err != nil {
+	// Clean up resources
+	if err := resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}

-	return nil, responseText, usage
+	// Return the complete response text (reasoning + content) and usage
+	return nil, reasoningText + responseText, usage
 }

+// Helper function to extract reasoning content from message delta
+func extractReasoningContent(delta *model.Message) string {
+	content := ""
+
+	// Extract reasoning from different possible fields
+	if delta.Reasoning != nil {
+		content += *delta.Reasoning
+		delta.Reasoning = nil
+	}
+
+	if delta.ReasoningContent != nil {
+		content += *delta.ReasoningContent
+		delta.ReasoningContent = nil
+	}
+
+	return content
+}
+
+// Handler processes non-streaming responses from OpenAI API
+// Returns error (if any) and token usage information
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
-	var textResponse SlimTextResponse
+	// Read the entire response body
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
+
+	// Close the original response body
+	if err = resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
+
+	// Parse the response JSON
+	var textResponse SlimTextResponse
+	if err = json.Unmarshal(responseBody, &textResponse); err != nil {
 		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+
+	// Check for API errors
 	if textResponse.Error.Type != "" {
 		return &model.ErrorWithStatusCode{
 			Error:      textResponse.Error,
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	// Reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the HTTPClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
+	// Process reasoning content in each choice
+	for _, msg := range textResponse.Choices {
+		reasoningContent := processReasoningContent(&msg)
+
+		// Set reasoning in requested format if content exists
+		if reasoningContent != "" {
+			msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
+		}
 	}

+	// Reset response body for forwarding to client
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))
+
+	// Forward all response headers (not just first value of each)
+	for k, values := range resp.Header {
+		for _, v := range values {
+			c.Writer.Header().Add(k, v)
+		}
+	}
+
+	// Set response status and copy body to client
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
+	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
+
+	// Close the reset body
+	if err = resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}

-	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
+	// Calculate token usage if not provided by API
+	calculateTokenUsage(&textResponse, promptTokens, modelName)
+
+	return nil, &textResponse.Usage
+}
+
+// processReasoningContent is a helper function to extract and process reasoning content from the message
+func processReasoningContent(msg *TextResponseChoice) string {
+	var reasoningContent string
+
+	// Check different locations for reasoning content
+	switch {
+	case msg.Reasoning != nil:
+		reasoningContent = *msg.Reasoning
+		msg.Reasoning = nil
+	case msg.ReasoningContent != nil:
+		reasoningContent = *msg.ReasoningContent
+		msg.ReasoningContent = nil
+	case msg.Message.Reasoning != nil:
+		reasoningContent = *msg.Message.Reasoning
+		msg.Message.Reasoning = nil
+	case msg.Message.ReasoningContent != nil:
+		reasoningContent = *msg.Message.ReasoningContent
+		msg.Message.ReasoningContent = nil
+	}
+
+	return reasoningContent
+}
+
+// Helper function to calculate token usage
+func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
+	// Calculate tokens if not provided by the API
+	if response.Usage.TotalTokens == 0 ||
+		(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
 		completionTokens := 0
-		for _, choice := range textResponse.Choices {
+		for _, choice := range response.Choices {
+			// Count content tokens
 			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
+
+			// Count reasoning tokens in all possible locations
+			if choice.Message.Reasoning != nil {
+				completionTokens += CountToken(*choice.Message.Reasoning)
+			}
+			if choice.Message.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.Message.ReasoningContent)
+			}
+			if choice.Reasoning != nil {
+				completionTokens += CountToken(*choice.Reasoning)
+			}
+			if choice.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.ReasoningContent)
+			}
 		}
-		textResponse.Usage = model.Usage{
+
+		// Set usage values
+		response.Usage = model.Usage{
 			PromptTokens:     promptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      promptTokens + completionTokens,
 		}
+	} else if hasAudioTokens(response) {
+		// Handle audio tokens conversion
+		calculateAudioTokens(response, modelName)
 	}
-	return nil, &textResponse.Usage
 }
+
+// Helper function to check if response has audio tokens
+func hasAudioTokens(response *SlimTextResponse) bool {
+	return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
+		(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
+}
+
+// Helper function to calculate audio token usage
+func calculateAudioTokens(response *SlimTextResponse, modelName string) {
+	// Convert audio tokens for prompt
+	if response.PromptTokensDetails != nil {
+		response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(response.PromptTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName),
+			))
+	}
+
+	// Convert audio tokens for completion
+	if response.CompletionTokensDetails != nil {
+		response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(response.CompletionTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
+			))
+	}
+
+	// Calculate total tokens
+	response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
+}
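
A worked example of calculateAudioTokens: given PromptTokensDetails{TextTokens: 100, AudioTokens: 40} and an assumed audio prompt ratio of 2.0, billed prompt tokens become 100 + ceil(40 × 2.0) = 180. On the completion side the audio tokens are additionally multiplied by the completion ratio before the same ceiling is applied.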
@@ -1,6 +1,10 @@
 package openai

-import "github.com/songquanpeng/one-api/relay/model"
+import (
+	"mime/multipart"
+
+	"github.com/songquanpeng/one-api/relay/model"
+)

 type TextContent struct {
 	Type string `json:"type,omitempty"`
@@ -71,6 +75,24 @@ type TextToSpeechRequest struct {
 	ResponseFormat string `json:"response_format"`
 }

+type AudioTranscriptionRequest struct {
+	File                 *multipart.FileHeader `form:"file" binding:"required"`
+	Model                string                `form:"model" binding:"required"`
+	Language             string                `form:"language"`
+	Prompt               string                `form:"prompt"`
+	ReponseFormat        string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
+	Temperature          float64               `form:"temperature"`
+	TimestampGranularity []string              `form:"timestamp_granularity"`
+}
+
+type AudioTranslationRequest struct {
+	File           *multipart.FileHeader `form:"file" binding:"required"`
+	Model          string                `form:"model" binding:"required"`
+	Prompt         string                `form:"prompt"`
+	ResponseFormat string                `form:"response_format" binding:"oneof=json text srt verbose_json vtt"`
+	Temperature    float64               `form:"temperature"`
+}
+
 type UsageOrResponseText struct {
 	*model.Usage
 	ResponseText string
@@ -110,12 +132,14 @@ type EmbeddingResponse struct {
 	model.Usage `json:"usage"`
 }

+// ImageData represents an image in the response
 type ImageData struct {
 	Url           string `json:"url,omitempty"`
 	B64Json       string `json:"b64_json,omitempty"`
 	RevisedPrompt string `json:"revised_prompt,omitempty"`
 }

+// ImageResponse represents the response structure for image generations
 type ImageResponse struct {
 	Created int64       `json:"created"`
 	Data    []ImageData `json:"data"`
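
The new form-bound structs are presumably populated from multipart requests via gin's binding; a minimal sketch (handler context assumed, not part of this diff):

    var req AudioTranscriptionRequest
    if err := c.ShouldBind(&req); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }
    // req.File carries the uploaded audio; Model, Language, etc. come from form fields.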
@@ -1,16 +1,20 @@
 package openai

 import (
-	"errors"
+	"bytes"
+	"context"
+	"encoding/base64"
 	"fmt"
 	"math"
 	"strings"

+	"github.com/pkg/errors"
 	"github.com/pkoukk/tiktoken-go"

 	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/billing/ratio"
 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }

-func CountTokenMessages(messages []model.Message, model string) int {
-	tokenEncoder := getTokenEncoder(model)
+// CountTokenMessages counts the number of tokens in a list of messages.
+func CountTokenMessages(ctx context.Context,
+	messages []model.Message, actualModel string) int {
+	tokenEncoder := getTokenEncoder(actualModel)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 	var tokensPerMessage int
 	var tokensPerName int
-	if model == "gpt-3.5-turbo-0301" {
+	if actualModel == "gpt-3.5-turbo-0301" {
 		tokensPerMessage = 4
 		tokensPerName = -1 // If there's a name, the role is omitted
 	} else {
 		tokensPerMessage = 3
 		tokensPerName = 1
 	}

 	tokenNum := 0
+	var totalAudioTokens float64
 	for _, message := range messages {
 		tokenNum += tokensPerMessage
-		switch v := message.Content.(type) {
-		case string:
-			tokenNum += getTokenNum(tokenEncoder, v)
-		case []any:
-			for _, it := range v {
-				m := it.(map[string]any)
-				switch m["type"] {
-				case "text":
-					if textValue, ok := m["text"]; ok {
-						if textString, ok := textValue.(string); ok {
-							tokenNum += getTokenNum(tokenEncoder, textString)
-						}
-					}
-				case "image_url":
-					imageUrl, ok := m["image_url"].(map[string]any)
-					if ok {
-						url := imageUrl["url"].(string)
-						detail := ""
-						if imageUrl["detail"] != nil {
-							detail = imageUrl["detail"].(string)
-						}
-						imageTokens, err := countImageTokens(url, detail, model)
-						if err != nil {
-							logger.SysError("error counting image tokens: " + err.Error())
-						} else {
-							tokenNum += imageTokens
-						}
-					}
+		contents := message.ParseContent()
+		for _, content := range contents {
+			switch content.Type {
+			case model.ContentTypeText:
+				if content.Text != nil {
+					tokenNum += getTokenNum(tokenEncoder, *content.Text)
+				}
+			case model.ContentTypeImageURL:
+				imageTokens, err := countImageTokens(
+					content.ImageURL.Url,
+					content.ImageURL.Detail,
+					actualModel)
+				if err != nil {
+					logger.SysError("error counting image tokens: " + err.Error())
+				} else {
+					tokenNum += imageTokens
+				}
+			case model.ContentTypeInputAudio:
+				audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
+				if err != nil {
+					logger.SysError("error decoding audio data: " + err.Error())
+				}
+
+				audioTokens, err := helper.GetAudioTokens(ctx,
+					bytes.NewReader(audioData),
+					ratio.GetAudioPromptTokensPerSecond(actualModel))
+				if err != nil {
+					logger.SysError("error counting audio tokens: " + err.Error())
+				} else {
+					totalAudioTokens += audioTokens
 				}
 			}
 		}

+		tokenNum += int(math.Ceil(totalAudioTokens))
+
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
 		if message.Name != nil {
 			tokenNum += tokensPerName
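
Judging from the call site above, the audio branch is duration-based: helper.GetAudioTokens presumably measures the clip length and multiplies by the per-model rate from ratio.GetAudioPromptTokensPerSecond, so a 12-second clip at an assumed 10 prompt tokens per second would contribute ceil(12 × 10) = 120 tokens. (The helper body is not part of this hunk; this is an inference from its signature.)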
@@ -3,6 +3,7 @@ package openai
 import (
 	"context"
 	"fmt"
+	"strings"

 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -21,3 +22,11 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
 		StatusCode: statusCode,
 	}
 }
+
+func NormalizeDataLine(data string) string {
+	if strings.HasPrefix(data, "data:") {
+		content := strings.TrimLeft(data[len("data:"):], " ")
+		return "data: " + content
+	}
+	return data
+}
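
NormalizeDataLine canonicalizes the SSE field separator so the later `data: ` prefix checks also match providers that emit `data:` without a space. For example:

    NormalizeDataLine(`data:{"id":"x"}`) // => `data: {"id":"x"}`
    NormalizeDataLine(`data:  [DONE]`)   // => `data: [DONE]`
    NormalizeDataLine(`event: ping`)     // unchanged (no "data:" prefix)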
relay/adaptor/openrouter/model.go (new file, 22 lines)
@@ -0,0 +1,22 @@
+package openrouter
+
+// RequestProvider customizes how your requests are routed using the provider object
+// in the request body for Chat Completions and Completions.
+//
+// https://openrouter.ai/docs/features/provider-routing
+type RequestProvider struct {
+	// Order is the list of provider names to try in order (e.g. ["Anthropic", "OpenAI"]). Default: empty
+	Order []string `json:"order,omitempty"`
+	// AllowFallbacks controls whether to allow backup providers when the primary is unavailable. Default: true
+	AllowFallbacks bool `json:"allow_fallbacks,omitempty"`
+	// RequireParameters restricts routing to providers that support all parameters in your request. Default: false
+	RequireParameters bool `json:"require_parameters,omitempty"`
+	// DataCollection controls whether to use providers that may store data ("allow" or "deny"). Default: "allow"
+	DataCollection string `json:"data_collection,omitempty" binding:"omitempty,oneof=allow deny"`
+	// Ignore is the list of provider names to skip for this request. Default: empty
+	Ignore []string `json:"ignore,omitempty"`
+	// Quantizations is the list of quantization levels to filter by (e.g. ["int4", "int8"]). Default: empty
+	Quantizations []string `json:"quantizations,omitempty"`
+	// Sort sorts providers by price, throughput, or latency (e.g. "price" or "throughput"). Default: empty
+	Sort string `json:"sort,omitempty" binding:"omitempty,oneof=price throughput latency"`
+}
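
A hedged usage sketch (assuming encoding/json; not part of this diff): the struct marshals into the `provider` object OpenRouter expects alongside the normal chat payload.

    p := openrouter.RequestProvider{
        Order:          []string{"Anthropic", "OpenAI"},
        AllowFallbacks: true,
        Sort:           "throughput",
    }
    body, _ := json.Marshal(map[string]any{
        "model":    "anthropic/claude-3.7-sonnet",
        "messages": []map[string]string{{"role": "user", "content": "Hi"}},
        "provider": p,
    })
    // body now carries provider-routing hints for OpenRouter.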
@@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return ConvertRequest(*request), nil
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -25,11 +25,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 		Prompt: Prompt{
 			Messages: make([]ChatMessage, 0, len(textRequest.Messages)),
 		},
 		Temperature: textRequest.Temperature,
-		CandidateCount: textRequest.N,
-		TopP:           textRequest.TopP,
-		TopK:           textRequest.MaxTokens,
+		TopP:        textRequest.TopP,
+		TopK:        textRequest.MaxTokens,
 	}

+	if textRequest.N != nil {
+		palmRequest.CandidateCount = *textRequest.N
+	} else {
+		palmRequest.CandidateCount = 1
+	}
+
 	for _, message := range textRequest.Messages {
 		palmMessage := ChatMessage{
 			Content: message.StringContent(),
@@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 	return nil
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	return nil, errors.Errorf("not implement")
 }

@@ -23,7 +23,7 @@ type Adaptor struct {
 }

 // ConvertImageRequest implements adaptor.Adaptor.
-func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	return DrawImageRequest{
 		Input: ImageInput{
 			Steps: 25,
@@ -33,9 +33,16 @@ var ModelList = []string{
 	// -------------------------------------
 	// language model
 	// -------------------------------------
+	"anthropic/claude-3.5-haiku",
+	"anthropic/claude-3.5-sonnet",
+	"anthropic/claude-3.7-sonnet",
+	"deepseek-ai/deepseek-r1",
 	"ibm-granite/granite-20b-code-instruct-8k",
 	"ibm-granite/granite-3.0-2b-instruct",
 	"ibm-granite/granite-3.0-8b-instruct",
+	"ibm-granite/granite-3.1-2b-instruct",
+	"ibm-granite/granite-3.1-8b-instruct",
+	"ibm-granite/granite-3.2-8b-instruct",
 	"ibm-granite/granite-8b-code-instruct-128k",
 	"meta/llama-2-13b",
 	"meta/llama-2-13b-chat",
@@ -50,7 +57,6 @@ var ModelList = []string{
 	"meta/meta-llama-3-8b-instruct",
 	"mistralai/mistral-7b-instruct-v0.2",
 	"mistralai/mistral-7b-v0.1",
-	"mistralai/mixtral-8x7b-instruct-v0.1",
 	// -------------------------------------
 	// video model
 	// -------------------------------------
@@ -69,7 +69,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return convertedRequest, nil
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -105,7 +105,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 	return nil
 }

-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -19,6 +19,7 @@ var ModelList = []string{
 	"claude-3-5-sonnet@20240620",
 	"claude-3-5-sonnet-v2@20241022",
 	"claude-3-5-haiku@20241022",
+	"claude-3-7-sonnet@20250219",
 }

 const anthropicVersion = "vertex-2023-10-16"
@@ -31,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
 
-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
 
 	req := Request{
 		AnthropicVersion: anthropicVersion,
 		// Model: claudeReq.Model,
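Review note: anthropic.ConvertRequest now takes the gin context and can fail. A sketch of the resulting call pattern, with the signature taken from the hunk above and a hypothetical function name:

package relaysketch

import (
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
	"github.com/songquanpeng/one-api/relay/model"
)

// convertClaude mirrors the hunk above: the error from the converter is
// wrapped so the failure carries call-site context up the stack.
func convertClaude(c *gin.Context, req *model.GeneralOpenAIRequest) (any, error) {
	claudeReq, err := anthropic.ConvertRequest(c, *req)
	if err != nil {
		return nil, errors.Wrap(err, "convert request")
	}
	return claudeReq, nil
}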
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	return nil, nil
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
@@ -41,10 +41,15 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
 	xunfeiRequest.Header.AppId = xunfeiAppId
 	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
-	xunfeiRequest.Parameter.Chat.TopK = request.N
 	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 	xunfeiRequest.Payload.Message.Text = messages
 
+	if request.N != nil {
+		xunfeiRequest.Parameter.Chat.TopK = *request.N
+	} else {
+		xunfeiRequest.Parameter.Chat.TopK = 1
+	}
+
 	if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
 		functions := make([]model.Function, len(request.Tools))
 		for i, tool := range request.Tools {
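Review note: switching N from int to *int lets the converter distinguish "client omitted n" from an explicit zero. A dependency-free sketch of the pattern; the helper name is hypothetical:

package main

import "fmt"

// topKOrDefault mirrors the nil check above: a nil pointer means the
// client omitted "n", so a default of 1 is used instead of a bare zero.
func topKOrDefault(n *int) int {
	if n != nil {
		return *n
	}
	return 1
}

func main() {
	three := 3
	fmt.Println(topKOrDefault(&three), topKOrDefault(nil)) // 3 1
}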
@@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 }
 
-func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
File diff suppressed because it is too large
@@ -8,18 +8,16 @@ import (
 	"net/http"
 	"strings"
 
-	"github.com/songquanpeng/one-api/common/helper"
-	"github.com/songquanpeng/one-api/relay/constant/role"
-
 	"github.com/gin-gonic/gin"
 
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
 	"github.com/songquanpeng/one-api/relay/channeltype"
+	"github.com/songquanpeng/one-api/relay/constant/role"
 	"github.com/songquanpeng/one-api/relay/controller/validator"
 	"github.com/songquanpeng/one-api/relay/meta"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -45,10 +43,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
 	return textRequest, nil
 }
 
-func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
+func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 	switch relayMode {
 	case relaymode.ChatCompletions:
-		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
+		return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model)
 	case relaymode.Completions:
 		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relaymode.Moderations:
@@ -94,19 +92,30 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
 	return preConsumedQuota, nil
 }
 
-func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) {
+func postConsumeQuota(ctx context.Context,
+	usage *relaymodel.Usage,
+	meta *meta.Meta,
+	textRequest *relaymodel.GeneralOpenAIRequest,
+	ratio float64,
+	preConsumedQuota int64,
+	modelRatio float64,
+	groupRatio float64,
+	systemPromptReset bool) (quota int64) {
 	if usage == nil {
 		logger.Error(ctx, "usage is nil, which is unexpected")
 		return
 	}
-	var quota int64
 	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
 	promptTokens := usage.PromptTokens
+	// It appears that DeepSeek's official service automatically merges ReasoningTokens into CompletionTokens,
+	// but the behavior of third-party providers may differ, so for now we do not add them manually.
+	// completionTokens := usage.CompletionTokens + usage.CompletionTokensDetails.ReasoningTokens
 	completionTokens := usage.CompletionTokens
-	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+	quota = int64(math.Ceil((float64(promptTokens)+float64(completionTokens)*completionRatio)*ratio)) + usage.ToolsCost
 	if ratio != 0 && quota <= 0 {
 		quota = 1
 	}
 
 	totalTokens := promptTokens + completionTokens
 	if totalTokens == 0 {
 		// in this case, must be some error happened
@@ -122,7 +131,13 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	if err != nil {
 		logger.Error(ctx, "error update user quota cache: "+err.Error())
 	}
-	logContent := fmt.Sprintf("倍率:%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio)
+
+	var logContent string
+	if usage.ToolsCost == 0 {
+		logContent = fmt.Sprintf("倍率:%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio)
+	} else {
+		logContent = fmt.Sprintf("倍率:%.2f × %.2f × %.2f, tools cost %d", modelRatio, groupRatio, completionRatio, usage.ToolsCost)
+	}
 	model.RecordConsumeLog(ctx, &model.Log{
 		UserId:    meta.UserId,
 		ChannelId: meta.ChannelId,
@@ -138,6 +153,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	})
 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
+
+	return quota
 }
 
 func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
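Review note: the billing change is easiest to see with numbers. Per the hunk above, quota = ceil((promptTokens + completionTokens × completionRatio) × ratio) + ToolsCost, floored at 1 whenever ratio is non-zero. A self-contained sketch with illustrative values; none of these numbers come from the diff:

package main

import (
	"fmt"
	"math"
)

func main() {
	// Illustrative values only; the real ratios come from billingratio.
	promptTokens, completionTokens := 100, 200
	completionRatio, ratio := 2.0, 0.5
	var toolsCost int64 = 30

	quota := int64(math.Ceil((float64(promptTokens)+float64(completionTokens)*completionRatio)*ratio)) + toolsCost
	if ratio != 0 && quota <= 0 {
		quota = 1 // never bill zero for a non-free model
	}
	fmt.Println(quota) // (100 + 200*2.0) * 0.5 = 250, plus 30 tools cost = 280
}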
@@ -157,7 +157,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 		channeltype.Ali,
 		channeltype.Replicate,
 		channeltype.Baidu:
-		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
+		finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
 		if err != nil {
 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
 		}
@@ -10,6 +10,7 @@ import (
 	"github.com/gin-gonic/gin"
 
 	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay"
 	"github.com/songquanpeng/one-api/relay/adaptor"
@@ -44,7 +45,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	groupRatio := billingratio.GetGroupRatio(meta.Group)
 	ratio := modelRatio * groupRatio
 	// pre-consume quota
-	promptTokens := getPromptTokens(textRequest, meta.Mode)
+	promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode)
 	meta.PromptTokens = promptTokens
 	preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
 	if bizErr != nil {
@@ -104,6 +105,8 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
 		logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
 		return nil, err
 	}
+	c.Set(ctxkey.ConvertedRequest, convertedRequest)
+
 	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
 		logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
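Review note: getRequestBody now stashes the converted request under ctxkey.ConvertedRequest, so later stages can inspect it. A minimal sketch of reading it back; the helper name is hypothetical:

package relaysketch

import (
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/ctxkey"
)

// convertedRequestFrom retrieves whatever getRequestBody stored; callers
// must type-assert to the adaptor-specific request type before use.
func convertedRequestFrom(c *gin.Context) (any, bool) {
	return c.Get(ctxkey.ConvertedRequest)
}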
@@ -1,5 +1,7 @@
 package model
 
+import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+
 type ResponseFormat struct {
 	Type       string      `json:"type,omitempty"`
 	JsonSchema *JSONSchema `json:"json_schema,omitempty"`
@@ -23,49 +25,103 @@ type StreamOptions struct {
 
 type GeneralOpenAIRequest struct {
 	// https://platform.openai.com/docs/api-reference/chat/create
 	Messages []Message `json:"messages,omitempty"`
 	Model    string    `json:"model,omitempty"`
 	Store    *bool     `json:"store,omitempty"`
-	ReasoningEffort     *string         `json:"reasoning_effort,omitempty"`
 	Metadata            any             `json:"metadata,omitempty"`
-	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
+	// FrequencyPenalty is a number between -2.0 and 2.0 that penalizes
+	// new tokens based on their existing frequency in the text so far,
+	// default is 0.
+	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty" binding:"omitempty,min=-2,max=2"`
 	LogitBias           any             `json:"logit_bias,omitempty"`
 	Logprobs            *bool           `json:"logprobs,omitempty"`
 	TopLogprobs         *int            `json:"top_logprobs,omitempty"`
 	MaxTokens           int             `json:"max_tokens,omitempty"`
 	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
-	N                   int             `json:"n,omitempty"`
+	// N is how many chat completion choices to generate for each input message,
+	// default to 1.
+	N *int `json:"n,omitempty" binding:"omitempty,min=0"`
+	// ReasoningEffort constrains effort on reasoning for reasoning models, reasoning models only.
+	ReasoningEffort *string `json:"reasoning_effort,omitempty" binding:"omitempty,oneof=low medium high"`
+	// Modalities currently the model only programmatically allows modalities = [“text”, “audio”]
 	Modalities          []string        `json:"modalities,omitempty"`
 	Prediction          any             `json:"prediction,omitempty"`
 	Audio               *Audio          `json:"audio,omitempty"`
-	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
+	// PresencePenalty is a number between -2.0 and 2.0 that penalizes
+	// new tokens based on whether they appear in the text so far, default is 0.
+	PresencePenalty     *float64        `json:"presence_penalty,omitempty" binding:"omitempty,min=-2,max=2"`
 	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
 	Seed                float64         `json:"seed,omitempty"`
-	ServiceTier         *string         `json:"service_tier,omitempty"`
+	ServiceTier         *string         `json:"service_tier,omitempty" binding:"omitempty,oneof=default auto"`
 	Stop                any             `json:"stop,omitempty"`
 	Stream              bool            `json:"stream,omitempty"`
 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
 	Temperature         *float64        `json:"temperature,omitempty"`
 	TopP                *float64        `json:"top_p,omitempty"`
 	TopK                int             `json:"top_k,omitempty"`
 	Tools               []Tool          `json:"tools,omitempty"`
 	ToolChoice          any             `json:"tool_choice,omitempty"`
 	ParallelTooCalls    *bool           `json:"parallel_tool_calls,omitempty"`
 	User                string          `json:"user,omitempty"`
 	FunctionCall        any             `json:"function_call,omitempty"`
 	Functions           any             `json:"functions,omitempty"`
 	// https://platform.openai.com/docs/api-reference/embeddings/create
 	Input          any    `json:"input,omitempty"`
 	EncodingFormat string `json:"encoding_format,omitempty"`
 	Dimensions     int    `json:"dimensions,omitempty"`
 	// https://platform.openai.com/docs/api-reference/images/create
-	Prompt  any     `json:"prompt,omitempty"`
+	Prompt  string  `json:"prompt,omitempty"`
 	Quality *string `json:"quality,omitempty"`
 	Size    string  `json:"size,omitempty"`
 	Style   *string `json:"style,omitempty"`
+	WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
 
 	// Others
 	Instruction string `json:"instruction,omitempty"`
 	NumCtx      int    `json:"num_ctx,omitempty"`
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Provider         *openrouter.RequestProvider `json:"provider,omitempty"`
+	IncludeReasoning *bool                       `json:"include_reasoning,omitempty"`
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking *Thinking `json:"thinking,omitempty"`
 }
+
+// WebSearchOptions is the tool searches the web for relevant results to use in a response.
+type WebSearchOptions struct {
+	// SearchContextSize is the high level guidance for the amount of context window space to use for the search,
+	// default is "medium".
+	SearchContextSize *string       `json:"search_context_size,omitempty" binding:"omitempty,oneof=low medium high"`
+	UserLocation      *UserLocation `json:"user_location,omitempty"`
+}
+
+// UserLocation is a struct that contains the location of the user.
+type UserLocation struct {
+	// Approximate is the approximate location parameters for the search.
+	Approximate UserLocationApproximate `json:"approximate" binding:"required"`
+	// Type is the type of location approximation.
+	Type string `json:"type" binding:"required,oneof=approximate"`
+}
+
+// UserLocationApproximate is a struct that contains the approximate location of the user.
+type UserLocationApproximate struct {
+	// City is the city of the user, e.g. San Francisco.
+	City *string `json:"city,omitempty"`
+	// Country is the country of the user, e.g. US.
+	Country *string `json:"country,omitempty"`
+	// Region is the region of the user, e.g. California.
+	Region *string `json:"region,omitempty"`
+	// Timezone is the IANA timezone of the user, e.g. America/Los_Angeles.
+	Timezone *string `json:"timezone,omitempty"`
+}
+
+// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking
+type Thinking struct {
+	Type         string `json:"type"`
+	BudgetTokens int    `json:"budget_tokens" binding:"omitempty,min=1024"`
+}
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
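Review note: the new binding tags only bite where the request is bound with validation. A minimal sketch of how gin enforces them, assuming a standard ShouldBindJSON call rather than the project's actual handler:

package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
)

func main() {
	r := gin.Default()
	r.POST("/v1/chat/completions", func(c *gin.Context) {
		var req relaymodel.GeneralOpenAIRequest
		// ShouldBindJSON runs the validator, so reasoning_effort must be
		// low/medium/high and frequency_penalty must lie in [-2, 2].
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.Status(http.StatusOK)
	})
	_ = r.Run(":3000")
}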
@@ -1,12 +1,106 @@
 package model
 
+import (
+	"context"
+	"strings"
+
+	"github.com/songquanpeng/one-api/common/logger"
+)
+
+// ReasoningFormat is the format of reasoning content,
+// can be set by the reasoning_format parameter in the request url.
+type ReasoningFormat string
+
+const (
+	ReasoningFormatUnspecified ReasoningFormat = ""
+	// ReasoningFormatReasoningContent is the reasoning format used by deepseek official API
+	ReasoningFormatReasoningContent ReasoningFormat = "reasoning_content"
+	// ReasoningFormatReasoning is the reasoning format used by openrouter
+	ReasoningFormatReasoning ReasoningFormat = "reasoning"
+
+	// ReasoningFormatThinkTag is the reasoning format used by 3rd party deepseek-r1 providers.
+	//
+	// Deprecated: I believe <think> is a very poor format, especially in stream mode, it is difficult to extract and convert.
+	// Considering that only a few deepseek-r1 third-party providers use this format, it has been decided to no longer support it.
+	// ReasoningFormatThinkTag ReasoningFormat = "think-tag"
+
+	// ReasoningFormatThinking is the reasoning format used by anthropic
+	ReasoningFormatThinking ReasoningFormat = "thinking"
+)
+
 type Message struct {
 	Role string `json:"role,omitempty"`
-	Content          any `json:"content,omitempty"`
-	ReasoningContent any `json:"reasoning_content,omitempty"`
+	// Content is a string or a list of objects
+	Content    any     `json:"content,omitempty"`
 	Name       *string `json:"name,omitempty"`
 	ToolCalls  []Tool  `json:"tool_calls,omitempty"`
 	ToolCallId string  `json:"tool_call_id,omitempty"`
+	Audio      *messageAudio    `json:"audio,omitempty"`
+	Annotation []AnnotationItem `json:"annotation,omitempty"`
+
+	// -------------------------------------
+	// Deepseek-specific fields
+	// https://api-docs.deepseek.com/api/create-chat-completion
+	// -------------------------------------
+	// Prefix forces the model to begin its answer with the supplied prefix in the assistant message.
+	// To enable this feature, set base_url to "https://api.deepseek.com/beta".
+	Prefix *bool `json:"prefix,omitempty"`
+	// ReasoningContent is Used for the deepseek-reasoner model in the Chat
+	// Prefix Completion feature as the input for the CoT in the last assistant message.
+	// When using this feature, the prefix parameter must be set to true.
+	ReasoningContent *string `json:"reasoning_content,omitempty"`
+
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Reasoning *string `json:"reasoning,omitempty"`
+	Refusal   *bool   `json:"refusal,omitempty"`
+
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
+}
+
+type AnnotationItem struct {
+	Type        string      `json:"type" binding:"oneof=url_citation"`
+	UrlCitation UrlCitation `json:"url_citation"`
+}
+
+// UrlCitation is a URL citation when using web search.
+type UrlCitation struct {
+	// EndIndex is the index of the last character of the URL citation in the message.
+	EndIndex int `json:"end_index"`
+	// StartIndex is the index of the first character of the URL citation in the message.
+	StartIndex int `json:"start_index"`
+	// Title is the title of the web resource.
+	Title string `json:"title"`
+	// Url is the URL of the web resource.
+	Url string `json:"url"`
+}
+
+// SetReasoningContent sets the reasoning content based on the format
+func (m *Message) SetReasoningContent(format string, reasoningContent string) {
+	switch ReasoningFormat(strings.ToLower(strings.TrimSpace(format))) {
+	case ReasoningFormatReasoningContent:
+		m.ReasoningContent = &reasoningContent
+		// case ReasoningFormatThinkTag:
+		// 	m.Content = fmt.Sprintf("<think>%s</think>%s", reasoningContent, m.Content)
+	case ReasoningFormatThinking:
+		m.Thinking = &reasoningContent
+	case ReasoningFormatReasoning,
+		ReasoningFormatUnspecified:
+		m.Reasoning = &reasoningContent
+	default:
+		logger.Warnf(context.TODO(), "unknown reasoning format: %q", format)
+	}
+}
+
+type messageAudio struct {
+	Id         string `json:"id"`
+	Data       string `json:"data,omitempty"`
+	ExpiredAt  int    `json:"expired_at,omitempty"`
+	Transcript string `json:"transcript,omitempty"`
 }
 
 func (m Message) IsStringContent() bool {
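Review note: a short usage sketch of the new SetReasoningContent helper, with field names taken from the hunk above and purely illustrative values:

package relaysketch

import relaymodel "github.com/songquanpeng/one-api/relay/model"

// reasoningExample routes upstream reasoning text into the field the
// client asked for via the reasoning_format query parameter.
func reasoningExample() *relaymodel.Message {
	msg := &relaymodel.Message{Role: "assistant", Content: "final answer"}
	// "reasoning_content" (deepseek style) fills msg.ReasoningContent;
	// "thinking" would fill msg.Thinking; "" or "reasoning" fills
	// msg.Reasoning (openrouter style); anything else only logs a warning.
	msg.SetReasoningContent("reasoning_content", "chain of thought...")
	return msg
}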
@@ -27,6 +121,7 @@ func (m Message) StringContent() string {
 		if !ok {
 			continue
 		}
+
 		if contentMap["type"] == ContentTypeText {
 			if subStr, ok := contentMap["text"].(string); ok {
 				contentStr += subStr
@@ -35,6 +130,7 @@ func (m Message) StringContent() string {
 		}
 		return contentStr
 	}
+
 	return ""
 }
 
@@ -44,10 +140,11 @@ func (m Message) ParseContent() []MessageContent {
 	if ok {
 		contentList = append(contentList, MessageContent{
 			Type: ContentTypeText,
-			Text: content,
+			Text: &content,
 		})
 		return contentList
 	}
+
 	anyList, ok := m.Content.([]any)
 	if ok {
 		for _, contentItem := range anyList {
@@ -60,7 +157,7 @@ func (m Message) ParseContent() []MessageContent {
 				if subStr, ok := contentMap["text"].(string); ok {
 					contentList = append(contentList, MessageContent{
 						Type: ContentTypeText,
-						Text: subStr,
+						Text: &subStr,
 					})
 				}
 			case ContentTypeImageURL:
@@ -72,8 +169,21 @@ func (m Message) ParseContent() []MessageContent {
 					},
 				})
 			}
+		case ContentTypeInputAudio:
+			if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
+				contentList = append(contentList, MessageContent{
+					Type: ContentTypeInputAudio,
+					InputAudio: &InputAudio{
+						Data:   subObj["data"].(string),
+						Format: subObj["format"].(string),
+					},
+				})
+			}
+		default:
+			logger.Warnf(context.TODO(), "unknown content type: %s", contentMap["type"])
 			}
 		}
 
 		return contentList
 	}
 	return nil
@@ -85,7 +195,23 @@ type ImageURL struct {
 }
 
 type MessageContent struct {
-	Type     string    `json:"type,omitempty"`
-	Text     string    `json:"text"`
-	ImageURL *ImageURL `json:"image_url,omitempty"`
+	// Type should be one of the following: text/input_audio
+	Type       string      `json:"type,omitempty"`
+	Text       *string     `json:"text,omitempty"`
+	ImageURL   *ImageURL   `json:"image_url,omitempty"`
+	InputAudio *InputAudio `json:"input_audio,omitempty"`
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
+}
+
+type InputAudio struct {
+	// Data is the base64 encoded audio data
+	Data string `json:"data" binding:"required"`
+	// Format is the audio format, should be one of the
+	// following: mp3/mp4/mpeg/mpga/m4a/wav/webm/pcm16.
+	// When stream=true, format should be pcm16
+	Format string `json:"format"`
 }
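Review note: the new input_audio branch expects OpenAI-style audio content parts. A sketch of a message that would exercise it, assuming ContentTypeInputAudio is the literal "input_audio"; the base64 payload is truncated and illustrative:

package relaysketch

import relaymodel "github.com/songquanpeng/one-api/relay/model"

// audioParts builds a message shaped like the new branch expects and
// runs it through ParseContent.
func audioParts() []relaymodel.MessageContent {
	m := relaymodel.Message{
		Role: "user",
		Content: []any{
			map[string]any{
				"type": "input_audio", // assumed value of ContentTypeInputAudio
				"input_audio": map[string]any{
					"data":   "UklGRg...", // base64 audio, truncated for illustration
					"format": "wav",
				},
			},
		},
	}
	return m.ParseContent()
}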
@@ -1,17 +1,22 @@
 package model
 
+// Usage is the token usage information returned by OpenAI API.
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
 	TotalTokens      int `json:"total_tokens"`
+	// PromptTokensDetails may be empty for some models
+	PromptTokensDetails *usagePromptTokensDetails `json:"prompt_tokens_details,omitempty"`
+	// CompletionTokensDetails may be empty for some models
+	CompletionTokensDetails *usageCompletionTokensDetails `json:"completion_tokens_details,omitempty"`
+	ServiceTier       string `json:"service_tier,omitempty"`
+	SystemFingerprint string `json:"system_fingerprint,omitempty"`
 
-	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
-}
-
-type CompletionTokensDetails struct {
-	ReasoningTokens          int `json:"reasoning_tokens"`
-	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
-	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// -------------------------------------
+	// Custom fields
+	// -------------------------------------
+	// ToolsCost is the cost of using tools, in quota.
+	ToolsCost int64 `json:"tools_cost,omitempty"`
 }
 
 type Error struct {
@@ -25,3 +30,20 @@ type ErrorWithStatusCode struct {
 	Error
 	StatusCode int `json:"status_code"`
 }
+
+type usagePromptTokensDetails struct {
+	CachedTokens int `json:"cached_tokens"`
+	AudioTokens  int `json:"audio_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens  int `json:"text_tokens"`
+	ImageTokens int `json:"image_tokens"`
+}
+
+type usageCompletionTokensDetails struct {
+	ReasoningTokens          int `json:"reasoning_tokens"`
+	AudioTokens              int `json:"audio_tokens"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens int `json:"text_tokens"`
+}
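Review note: ToolsCost is a custom, quota-denominated field that postConsumeQuota adds on top of token-based billing (see the controller hunks earlier). A sketch of a relay step charging a tool surcharge; the amount and helper name are illustrative:

package relaysketch

import relaymodel "github.com/songquanpeng/one-api/relay/model"

// addToolsCost charges a quota-denominated surcharge for a tool call;
// postConsumeQuota later adds usage.ToolsCost on top of token billing.
func addToolsCost(usage *relaymodel.Usage) {
	usage.ToolsCost += 1500 // illustrative amount, not from the diff
}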
|||||||
@@ -13,4 +13,5 @@ const (
|
|||||||
AudioTranslation
|
AudioTranslation
|
||||||
// Proxy is a special relay mode for proxying requests to custom upstream
|
// Proxy is a special relay mode for proxying requests to custom upstream
|
||||||
Proxy
|
Proxy
|
||||||
|
ImagesEdits
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ const LogsTable = () => {
     if (isAdminUser) {
       url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
     } else {
-      url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+      url = `/api/log/self?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
     }
     const res = await API.get(url);
     const { success, message, data } = res.data;
@@ -225,7 +225,7 @@ const LogsTable = () => {
    if (isAdminUser) {
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
    } else {
-     url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+     url = `/api/log/self?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
    }
    const res = await API.get(url);
    const { success, message, data } = res.data;
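Review note: both frontend hunks drop the trailing slash before the query string (/api/log/self/?p=... becomes /api/log/self?p=...). With gin's default trailing-slash redirect, the old URL took a 301 detour before reaching the handler; the new one matches the route directly. A sketch of the server side this must line up with, as a hypothetical router excerpt rather than the project's actual router file:

package main

import "github.com/gin-gonic/gin"

func main() {
	// The handler is registered without a trailing slash, so
	// "/api/log/self?p=0" matches directly while "/api/log/self/?p=0"
	// relies on a redirect.
	r := gin.Default()
	r.GET("/api/log/self", func(c *gin.Context) {
		c.JSON(200, gin.H{"page": c.Query("p")})
	})
	_ = r.Run(":3000")
}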