mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-10 10:33:41 +08:00
Merge branch 'main' into pr/Laisky/23
This commit is contained in:
@@ -2,11 +2,12 @@ package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
)
|
||||
|
||||
var HTTPClient *http.Client
|
||||
|
||||
@@ -1,16 +1,30 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/songquanpeng/one-api/common/env"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if SessionSecret == "" {
|
||||
fmt.Println("SESSION_SECRET not set, using random secret")
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
panic(fmt.Sprintf("failed to generate random secret: %v", err))
|
||||
}
|
||||
|
||||
SessionSecret = base64.StdEncoding.EncodeToString(key)
|
||||
}
|
||||
}
|
||||
|
||||
var SystemName = "One API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var Footer = ""
|
||||
@@ -23,7 +37,7 @@ var DisplayTokenStatEnabled = true
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var SessionSecret = os.Getenv("SESSION_SECRET")
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
@@ -112,6 +126,7 @@ var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
|
||||
var IdleTimeout = env.Int("IDLE_TIMEOUT", 30) // unit is second
|
||||
|
||||
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
|
||||
|
||||
@@ -3,9 +3,12 @@ package ctxkey
|
||||
const (
|
||||
Config = "config"
|
||||
Id = "id"
|
||||
RequestId = "X-Oneapi-Request-Id"
|
||||
Username = "username"
|
||||
Role = "role"
|
||||
Status = "status"
|
||||
ChannelModel = "channel_model"
|
||||
ChannelRatio = "channel_ratio"
|
||||
Channel = "channel"
|
||||
ChannelId = "channel_id"
|
||||
SpecificChannelId = "specific_channel_id"
|
||||
@@ -15,10 +18,12 @@ const (
|
||||
Group = "group"
|
||||
ModelMapping = "model_mapping"
|
||||
ChannelName = "channel_name"
|
||||
ContentType = "content_type"
|
||||
TokenId = "token_id"
|
||||
TokenName = "token_name"
|
||||
BaseURL = "base_url"
|
||||
AvailableModels = "available_models"
|
||||
KeyRequestBody = "key_request_body"
|
||||
SystemPrompt = "system_prompt"
|
||||
Meta = "meta"
|
||||
)
|
||||
|
||||
@@ -4,41 +4,50 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
)
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
func GetRequestBody(c *gin.Context) (requestBody []byte, err error) {
|
||||
if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil {
|
||||
return requestBodyCache.([]byte), nil
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
requestBody, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "read request body failed")
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(ctxkey.KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
requestBody, err := GetRequestBody(c)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "get request body failed")
|
||||
}
|
||||
|
||||
// check v should be a pointer
|
||||
if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr {
|
||||
return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
err = json.Unmarshal(requestBody, v)
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
err = c.ShouldBind(&v)
|
||||
err = c.ShouldBind(v)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "unmarshal request body failed")
|
||||
}
|
||||
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
|
||||
63
common/helper/audio.go
Normal file
63
common/helper/audio.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
if data == nil {
|
||||
return "", errors.New("data is nil")
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "*-"+filename)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, data)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
|
||||
// GetAudioTokens returns the number of tokens in an audio file.
|
||||
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) {
|
||||
filename, err := SaveTmpFile("audio", audio)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to save audio to temporary file")
|
||||
}
|
||||
defer os.Remove(filename)
|
||||
|
||||
duration, err := GetAudioDuration(ctx, filename)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio tokens")
|
||||
}
|
||||
|
||||
return int(math.Ceil(duration)) * tokensPerSecond, nil
|
||||
}
|
||||
|
||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||
c := exec.CommandContext(ctx, "/usr/bin/ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
|
||||
// Actually gpt-4-audio calculates tokens with 0.1s precision,
|
||||
// while whisper calculates tokens with 1s precision
|
||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||
}
|
||||
55
common/helper/audio_test.go
Normal file
55
common/helper/audio_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAudioDuration(t *testing.T) {
|
||||
t.Run("should return correct duration for a valid audio file", func(t *testing.T) {
|
||||
tmpFile, err := os.CreateTemp("", "test_audio*.mp3")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// download test audio file
|
||||
resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, err = io.Copy(tmpFile, resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tmpFile.Close())
|
||||
|
||||
duration, err := GetAudioDuration(context.Background(), tmpFile.Name())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, duration, 3.904)
|
||||
})
|
||||
|
||||
t.Run("should return an error for a non-existent file", func(t *testing.T) {
|
||||
_, err := GetAudioDuration(context.Background(), "non_existent_file.mp3")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAudioTokens(t *testing.T) {
|
||||
t.Run("should return correct tokens for a valid audio file", func(t *testing.T) {
|
||||
// download test audio file
|
||||
resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a")
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
tokens, err := GetAudioTokens(context.Background(), resp.Body, 50)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tokens, 200)
|
||||
})
|
||||
|
||||
t.Run("should return an error for a non-existent file", func(t *testing.T) {
|
||||
_, err := GetAudioTokens(context.Background(), nil, 1)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"html/template"
|
||||
"log"
|
||||
"net"
|
||||
@@ -11,6 +9,9 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetTimestamp get current timestamp in seconds
|
||||
func GetTimestamp() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
@@ -3,11 +3,12 @@ package common
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,15 +28,15 @@ func printHelp() {
|
||||
func Init() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
}
|
||||
// if *PrintVersion {
|
||||
// fmt.Println(Version)
|
||||
// os.Exit(0)
|
||||
// }
|
||||
|
||||
if *PrintHelp {
|
||||
printHelp()
|
||||
os.Exit(0)
|
||||
}
|
||||
// if *PrintHelp {
|
||||
// printHelp()
|
||||
// os.Exit(0)
|
||||
// }
|
||||
|
||||
if os.Getenv("SESSION_SECRET") != "" {
|
||||
if os.Getenv("SESSION_SECRET") == "random_string" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"net"
|
||||
"net/smtp"
|
||||
@@ -18,7 +19,7 @@ func shouldAuth() bool {
|
||||
|
||||
func SendEmail(subject string, receiver string, content string) error {
|
||||
if receiver == "" {
|
||||
return fmt.Errorf("receiver is empty")
|
||||
return errors.Errorf("receiver is empty")
|
||||
}
|
||||
if config.SMTPFrom == "" { // for compatibility
|
||||
config.SMTPFrom = config.SMTPAccount
|
||||
@@ -57,7 +58,7 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
var err error
|
||||
if config.SMTPPort == 465 {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
InsecureSkipVerify: false,
|
||||
ServerName: config.SMTPServer,
|
||||
}
|
||||
conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
)
|
||||
|
||||
@@ -18,5 +18,5 @@ func Notify(by string, title string, description string, content string) error {
|
||||
if by == ByMessagePusher {
|
||||
return SendMessage(title, description, content)
|
||||
}
|
||||
return fmt.Errorf("unknown notify method: %s", by)
|
||||
return errors.Errorf("unknown notify method: %s", by)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package message
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetUUID() string {
|
||||
|
||||
Reference in New Issue
Block a user